@@ -25,6 +25,7 @@ public class LSTMCell : SimpleCell
        public double netCellState;
        public double previousCellState;
        public double cellState;
+       public double yCellState;

        //internal weights and deltas
        public double wCellIn;
@@ -66,8 +67,8 @@ public class LSTMRNN : RNN
        private new Vector4 vecMaxGrad;
        private new Vector4 vecMinGrad;

-       private new Vector3 vecMaxGrad3;
-       private new Vector3 vecMinGrad3;
+       private Vector3 vecMaxGrad3;
+       private Vector3 vecMinGrad3;

        public LSTMRNN()
        {
@@ -453,31 +454,24 @@ public override void initWeights()
            }
        }

-       public void LSTMCellInit(LSTMCell c)
+       public void LSTMCellInit(LSTMCell c, bool bBias = false)
        {
-           //input gate
-           c.netIn = 0;
-           c.yIn = 0;
-
-           //forget gate
-           c.netForget = 0;
-           c.yForget = 0;
-
-           //cell state
-           c.netCellState = 0;
-           c.previousCellState = 0; //this is important
+           c.previousCellState = 0;
            c.cellState = 0;

            //partial derivatives
            c.dSWCellIn = 0;
            c.dSWCellForget = 0;

-           //output gate
-           c.netOut = 0;
-           c.yOut = 0;
-
-           //cell output
-           c.cellOutput = 0;
+           if (bBias == false)
+           {
+               //cell output
+               c.cellOutput = 0;
+           }
+           else
+           {
+               c.cellOutput = 1.0;
+           }
        }

        public override void CleanStatus()
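LSTMCellInit now takes a bBias flag, and the per-gate resets (netIn, yIn, netForget, yForget, netOut, yOut) are dropped from the initializer. As the later hunks show, callers pass i == L1 - 1, so the last hidden cell keeps cellOutput pinned at 1.0 and appears to act as a constant bias unit for the layer above; the training and forward loops are narrowed to L1 - 1 cells accordingly. A minimal sketch of the resulting invariant, using the neuHidden and L1 members from the hunks below:

```csharp
// Sketch: after initialization, the last cell behaves as a bias unit.
for (int i = 0; i < L1; i++)
{
    neuHidden[i] = new LSTMCell();
    LSTMCellInit(neuHidden[i], i == L1 - 1); // bBias is true only for the last cell
}
// Invariant: neuHidden[L1 - 1].cellOutput == 1.0, and the Parallel.For loops
// below run over [0, L1 - 1), so the bias cell's output is never overwritten.
```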
@@ -544,7 +538,7 @@ private void CreateCell(BinaryReader br)
            for (int i = 0; i < L1; i++)
            {
                neuHidden[i] = new LSTMCell();
-               LSTMCellInit(neuHidden[i]);
+               LSTMCellInit(neuHidden[i], i == L1 - 1);
            }

            if (br != null)
@@ -626,27 +620,26 @@ public override void LearnNet(State state, int numStates, int curState)
            int sparseFeatureSize = sparse.Count;

            //put variables for derivatives in weight class and cell class
-           Parallel.For(0, L1, parallelOption, i =>
+           Parallel.For(0, L1 - 1, parallelOption, i =>
            {
                LSTMCell c = neuHidden[i];

                //using the error find the gradient of the output gate
-               var gradientOutputGate = (float)(SigmoidDerivative(c.netOut) * c.cellState * c.er);
+               var gradientOutputGate = (float)(SigmoidDerivative(c.netOut) * TanH(c.cellState) * c.er);

                //internal cell state error
-               var cellStateError = (float)(c.yOut * c.er);
+               var cellStateError = (float)(c.er);

                Vector4 vecErr = new Vector4(cellStateError, cellStateError, cellStateError, gradientOutputGate);
-               vecErr = Vector4.Clamp(vecErr, vecMinGrad, vecMaxGrad);

                var Sigmoid2Derivative_ci_netCellState_mul_ci_yIn = TanHDerivative(c.netCellState) * c.yIn;
                var Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn = TanH(c.netCellState) * SigmoidDerivative(c.netIn);
                var ci_previousCellState_mul_SigmoidDerivative_ci_netForget = c.previousCellState * SigmoidDerivative(c.netForget);

                Vector3 vecDerivate = new Vector3(
-                   (float)Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn,
-                   (float)ci_previousCellState_mul_SigmoidDerivative_ci_netForget,
-                   (float)Sigmoid2Derivative_ci_netCellState_mul_ci_yIn);
+                   (float)(Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c.yOut),
+                   (float)(ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c.yOut),
+                   (float)(Sigmoid2Derivative_ci_netCellState_mul_ci_yIn * c.yOut));
                float c_yForget = (float)c.yForget;
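These deltas track the new forward definition cellOutput = Sigmoid(netOut) * TanH(cellState) (see the computeHiddenLayer hunks below). A sketch of the chain rule the new lines implement, written as comments; err stands for c.er:

```csharp
// Forward:  cellOutput = Sigmoid(netOut) * TanH(cellState)
//
// Output gate: d(cellOutput)/d(netOut) = SigmoidDerivative(netOut) * TanH(cellState)
//   => gradientOutputGate = SigmoidDerivative(netOut) * TanH(cellState) * err
//
// Cell state paths (cellState = yForget * previousCellState + yIn * TanH(netCellState)):
//   d(cellState)/d(netIn)        = TanH(netCellState) * SigmoidDerivative(netIn)
//   d(cellState)/d(netForget)    = previousCellState * SigmoidDerivative(netForget)
//   d(cellState)/d(netCellState) = TanHDerivative(netCellState) * yIn
//
// The yOut factor that previously sat in cellStateError is now folded into each
// component of vecDerivate, so cellStateError reduces to the raw error term err.
```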
@@ -668,8 +661,9 @@ public override void LearnNet(State state, int numStates, int curState)
                    //Computing final err delta
                    Vector4 vecDelta = new Vector4(wd, entry.Value);
                    vecDelta = vecErr * vecDelta;
+                   vecDelta = Vector4.Clamp(vecDelta, vecMinGrad, vecMaxGrad);

-                   //Computing actual learning rate
+                   //Computing actual learning rate
                    Vector4 vecLearningRate = ComputeLearningRate(vecDelta, ref wlr_i[entry.Key]);
                    w_i[entry.Key] += vecLearningRate * vecDelta;
                }
@@ -693,6 +687,7 @@ public override void LearnNet(State state, int numStates, int curState)

                    Vector4 vecDelta = new Vector4(wd, feature);
                    vecDelta = vecErr * vecDelta;
+                   vecDelta = Vector4.Clamp(vecDelta, vecMinGrad, vecMaxGrad);

                    //Computing actual learning rate
                    Vector4 vecLearningRate = ComputeLearningRate(vecDelta, ref wlr_i[j]);
@@ -709,11 +704,10 @@ public override void LearnNet(State state, int numStates, int curState)
                //update internal weights
                Vector3 vecCellDelta = new Vector3((float)c.dSWCellIn, (float)c.dSWCellForget, (float)c.cellState);
                Vector3 vecCellErr = new Vector3(cellStateError, cellStateError, gradientOutputGate);
+               vecCellDelta = vecCellErr * vecCellDelta;

                //Normalize err by gradient cut-off
-               vecCellErr = Vector3.Clamp(vecCellErr, vecMinGrad3, vecMaxGrad3);
-
-               vecCellDelta = vecCellErr * vecCellDelta;
+               vecCellDelta = Vector3.Clamp(vecCellDelta, vecMinGrad3, vecMaxGrad3);

                //Computing actual learning rate
                Vector3 vecCellLearningRate = ComputeLearningRate(vecCellDelta, ref CellLearningRate[i]);
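In all three update paths the gradient cut-off moves from the raw error to the final delta, so the clamp now bounds the quantity that is actually scaled by the learning rate and added to the weights. A hedged sketch of the before/after pattern; vecInput, weights, and wlr are illustrative names, not identifiers from this file:

```csharp
// Before: the error was clamped once up front, but the delta (error * input)
// could still exceed the cut-off when the input activation was large.
// vecErr   = Vector4.Clamp(vecErr, vecMinGrad, vecMaxGrad);
// vecDelta = vecErr * vecInput;

// After: clamp the delta itself, immediately before the weight update.
Vector4 vecDelta = vecErr * vecInput;                        // vecInput: illustrative
vecDelta = Vector4.Clamp(vecDelta, vecMinGrad, vecMaxGrad);
Vector4 vecLearningRate = ComputeLearningRate(vecDelta, ref wlr);
weights += vecLearningRate * vecDelta;                       // weights: illustrative
```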
@@ -737,7 +731,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
            var sparse = state.SparseData;
            int sparseFeatureSize = sparse.Count;

-           Parallel.For(0, L1, parallelOption, j =>
+           Parallel.For(0, L1 - 1, parallelOption, j =>
            {
                LSTMCell cell_j = neuHidden[j];
@@ -780,14 +774,15 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
                cell_j.netForget += cell_j.previousCellState * cell_j.wCellForget;
                cell_j.yForget = Sigmoid(cell_j.netForget);

+               cell_j.yCellState = TanH(cell_j.netCellState);
                if (cell_j.mask == true)
                {
                    cell_j.cellState = 0;
                }
                else
                {
                    //cell state equals the previous cell state multiplied by the forget gate, plus the cell input multiplied by the input gate
-                   cell_j.cellState = cell_j.yForget * cell_j.previousCellState + cell_j.yIn * TanH(cell_j.netCellState);
+                   cell_j.cellState = cell_j.yForget * cell_j.previousCellState + cell_j.yIn * cell_j.yCellState;
                }

                if (isTrain == false)
@@ -801,7 +796,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
                //squash output gate
                cell_j.yOut = Sigmoid(cell_j.netOut);

-               cell_j.cellOutput = cell_j.cellState * cell_j.yOut;
+               cell_j.cellOutput = TanH(cell_j.cellState) * cell_j.yOut;

                neuHidden[j] = cell_j;
            });
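With TanH applied to the cell state, the forward pass now matches the standard LSTM formulation h = o * tanh(c), so cellOutput is bounded to (-1, 1) instead of growing with the unbounded cell state, and the new yCellState field caches TanH(netCellState) for reuse in LearnNet. A condensed sketch of the per-cell step this method now computes (the net* sums are accumulated earlier in the method):

```csharp
cell_j.yIn        = Sigmoid(cell_j.netIn);                // input gate
cell_j.yForget    = Sigmoid(cell_j.netForget);            // forget gate
cell_j.yCellState = TanH(cell_j.netCellState);            // squashed cell input (cached)
cell_j.cellState  = cell_j.yForget * cell_j.previousCellState
                  + cell_j.yIn * cell_j.yCellState;       // c_t = f * c_{t-1} + i * g
cell_j.yOut       = Sigmoid(cell_j.netOut);               // output gate
cell_j.cellOutput = TanH(cell_j.cellState) * cell_j.yOut; // h_t = o * tanh(c_t)
```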
@@ -825,7 +820,7 @@ public override void netReset(bool updateNet = false) //cleans hidden layer ac
            for (int i = 0; i < L1; i++)
            {
                neuHidden[i].mask = false;
-               LSTMCellInit(neuHidden[i]);
+               LSTMCellInit(neuHidden[i], i == L1 - 1);
            }

            if (Dropout > 0 && updateNet == true)
0 commit comments