@@ -28,18 +28,28 @@ public class LSTMCell : SimpleCell
         public double yCellState;

         //internal weights and deltas
+        public double wPeepholeIn;
+        public double wPeepholeForget;
+        public double wPeepholeOut;
+
+        //partial derivatives
+        public double dSWPeepholeIn;
+        public double dSWPeepholeForget;
+
         public double wCellIn;
         public double wCellForget;
+        public double wCellState;
         public double wCellOut;

-        //partial derivatives
         public double dSWCellIn;
         public double dSWCellForget;
-        //double dSWCellState;
+        public double dSWCellState;

         //output gate
         public double netOut;
         public double yOut;
+
+        public double previousCellOutput;
     }

     public class LSTMRNN : RNN
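Note: this hunk splits what used to be a single set of "internal" weights into two groups. The `wPeephole*` fields are classic peephole connections that let each gate see the cell state, while the repurposed `wCell*` fields (plus the new `wCellState`) connect the gates and the block input to the previous cell output, cached in the new `previousCellOutput` field. In the notation used in the notes below (our reading of this diff, not text from the source):

```latex
% s_t = cellState, y_t = cellOutput for one LSTM cell.
% Peephole weights gate on the cell state:        w^{peep}_{*} \, s
% Cell-recurrent weights gate on the cell output: w^{cell}_{*} \, y_{t-1}
```

The full gate equations these fields support appear after the computeHiddenLayer hunks below.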
@@ -56,7 +66,8 @@ public class LSTMRNN : RNN

         protected Vector4[][] Input2HiddenLearningRate;
         protected Vector4[][] Feature2HiddenLearningRate;
-        protected Vector3[] CellLearningRate;
+        protected Vector3[] PeepholeLearningRate;
+        protected Vector4[] CellLearningRate;

         protected Vector3[][] input2hiddenDeri;
         protected Vector3[][] feature2hiddenDeri;
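`PeepholeLearningRate` keeps one `Vector3` per cell (three peephole weights) and `CellLearningRate` widens to `Vector4` (the four cell-recurrent weights, including the new `wCellState`). Both feed `ComputeLearningRate` in `LearnNet` below. Its body is outside this diff; here is a minimal sketch of the AdaGrad-style scheme its call pattern suggests. The method shape and the `LearningRate` base-rate field are assumptions, not code from this commit:

```csharp
using System.Numerics;

// Hypothetical sketch: per-weight adaptive learning rate. `accumulated`
// persists across updates (e.g. CellLearningRate[i]), so each weight's
// step size shrinks as its squared gradients accumulate.
Vector4 ComputeLearningRate(Vector4 delta, ref Vector4 accumulated)
{
    accumulated += delta * delta;           // running sum of squared gradients
    return new Vector4((float)LearningRate) // assumed base learning-rate field
           / (Vector4.One + Vector4.SquareRoot(accumulated));
}
```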
@@ -322,8 +333,14 @@ public void SaveHiddenLayerWeights(BinaryWriter fo)
         {
             for (int i = 0; i < L1; i++)
             {
+                fo.Write(neuHidden[i].wPeepholeIn);
+                fo.Write(neuHidden[i].wPeepholeForget);
+                //fo.Write(neuHidden[i].wCellState);
+                fo.Write(neuHidden[i].wPeepholeOut);
+
                 fo.Write(neuHidden[i].wCellIn);
                 fo.Write(neuHidden[i].wCellForget);
+                fo.Write(neuHidden[i].wCellState);
                 fo.Write(neuHidden[i].wCellOut);
             }
         }
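Each cell now serializes seven doubles (three peephole weights, then four cell weights) instead of the previous three, so model files written before this change can no longer be loaded by `CreateCell` below; note also the stray commented-out `wCellState` write left between the peephole writes. For reference, the per-cell layout implied by this hunk:

```csharp
// Per-cell weight order in the model file (writes here must match the
// reads in CreateCell):
//   wPeepholeIn, wPeepholeForget, wPeepholeOut   // 3 peephole weights
//   wCellIn, wCellForget, wCellState, wCellOut   // 4 cell-recurrent weights
```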
@@ -460,8 +477,13 @@ public void LSTMCellInit(LSTMCell c, bool bBias = false)
             c.cellState = 0;

             //partial derivatives
+            c.dSWPeepholeIn = 0;
+            c.dSWPeepholeForget = 0;
+            //c.dSWCellState = 0;
+
             c.dSWCellIn = 0;
             c.dSWCellForget = 0;
+            c.dSWCellState = 0;

             if (bBias == false)
             {
@@ -482,7 +504,8 @@ public override void CleanStatus()
                 Feature2HiddenLearningRate = new Vector4[L1][];
             }

-            CellLearningRate = new Vector3[L1];
+            PeepholeLearningRate = new Vector3[L1];
+            CellLearningRate = new Vector4[L1];
             Parallel.For(0, L1, parallelOption, i =>
             {
                 Input2HiddenLearningRate[i] = new Vector4[L0];
@@ -502,7 +525,6 @@ public override void CleanStatus()

             vecMaxGrad3 = new Vector3((float)GradientCutoff, (float)GradientCutoff, (float)GradientCutoff);
             vecMinGrad3 = new Vector3((float)(-GradientCutoff), (float)(-GradientCutoff), (float)(-GradientCutoff));
-
         }

         public override void InitMem()
@@ -546,8 +568,13 @@ private void CreateCell(BinaryReader br)
             //Load weight from input file
             for (int i = 0; i < L1; i++)
             {
+                neuHidden[i].wPeepholeIn = br.ReadDouble();
+                neuHidden[i].wPeepholeForget = br.ReadDouble();
+                neuHidden[i].wPeepholeOut = br.ReadDouble();
+
                 neuHidden[i].wCellIn = br.ReadDouble();
                 neuHidden[i].wCellForget = br.ReadDouble();
+                neuHidden[i].wCellState = br.ReadDouble();
                 neuHidden[i].wCellOut = br.ReadDouble();
             }
         }
@@ -557,8 +584,13 @@ private void CreateCell(BinaryReader br)
             for (int i = 0; i < L1; i++)
             {
                 //internal weights, also important
+                neuHidden[i].wPeepholeIn = RandInitWeight();
+                neuHidden[i].wPeepholeForget = RandInitWeight();
+                neuHidden[i].wPeepholeOut = RandInitWeight();
+
                 neuHidden[i].wCellIn = RandInitWeight();
                 neuHidden[i].wCellForget = RandInitWeight();
+                neuHidden[i].wCellState = RandInitWeight();
                 neuHidden[i].wCellOut = RandInitWeight();
             }
         }
@@ -628,18 +660,18 @@ public override void LearnNet(State state, int numStates, int curState)
                 var gradientOutputGate = (float)(SigmoidDerivative(c.netOut) * TanH(c.cellState) * c.er);

                 //internal cell state error
-                var cellStateError = (float)(c.er);
+                var cellStateError = (float)(c.yOut * c.er * TanHDerivative(c.cellState));

                 Vector4 vecErr = new Vector4(cellStateError, cellStateError, cellStateError, gradientOutputGate);

-                var Sigmoid2Derivative_ci_netCellState_mul_ci_yIn = TanHDerivative(c.netCellState) * c.yIn;
                 var Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn = TanH(c.netCellState) * SigmoidDerivative(c.netIn);
                 var ci_previousCellState_mul_SigmoidDerivative_ci_netForget = c.previousCellState * SigmoidDerivative(c.netForget);
+                var Sigmoid2Derivative_ci_netCellState_mul_ci_yIn = TanHDerivative(c.netCellState) * c.yIn;

                 Vector3 vecDerivate = new Vector3(
-                    (float)(Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c.yOut),
-                    (float)(ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c.yOut),
-                    (float)(Sigmoid2Derivative_ci_netCellState_mul_ci_yIn * c.yOut));
+                    (float)(Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn),
+                    (float)(ci_previousCellState_mul_SigmoidDerivative_ci_netForget),
+                    (float)(Sigmoid2Derivative_ci_netCellState_mul_ci_yIn));
                 float c_yForget = (float)c.yForget;


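The `cellStateError` change is the substantive fix in this hunk. The output-gate gradient on the first line already treats the cell output as h = y_out * tanh(s); the old code nevertheless used the raw error `c.er` as the cell-state error and instead multiplied `c.yOut` into every component of `vecDerivate`. The new code applies the chain rule through h once, inside `cellStateError`, which is why the `* c.yOut` factors disappear from `vecDerivate`:

```latex
% With h = y_{out}\,\tanh(s) and \partial E/\partial h = er:
\frac{\partial E}{\partial net_{out}} = \sigma'(net_{out})\,\tanh(s)\,er
\qquad
\frac{\partial E}{\partial s} = y_{out}\,\tanh'(s)\,er
% The old cellStateError (= er) omitted the \tanh'(s) factor entirely.
```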
@@ -695,28 +727,59 @@ public override void LearnNet(State state, int numStates, int curState)
                     }
                 }

+                //Update peephole weights
+
                 //partial derivatives for internal connections
-                c.dSWCellIn = c.dSWCellIn * c.yForget + Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c.cellState;
+                c.dSWPeepholeIn = c.dSWPeepholeIn * c.yForget + Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c.previousCellState;

                 //partial derivatives for internal connections, initially zero as dS is zero and previous cell state is zero
-                c.dSWCellForget = c.dSWCellForget * c.yForget + ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c.previousCellState;
+                c.dSWPeepholeForget = c.dSWPeepholeForget * c.yForget + ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c.previousCellState;

                 //update internal weights
-                Vector3 vecCellDelta = new Vector3((float)c.dSWCellIn, (float)c.dSWCellForget, (float)c.cellState);
-                Vector3 vecCellErr = new Vector3(cellStateError, cellStateError, gradientOutputGate);
-                vecCellDelta = vecCellErr * vecCellDelta;
+                Vector3 vecCellDelta = new Vector3((float)c.dSWPeepholeIn, (float)c.dSWPeepholeForget, (float)c.cellState);
+                Vector3 vecErr3 = new Vector3(cellStateError, cellStateError, gradientOutputGate);
+
+                vecCellDelta = vecErr3 * vecCellDelta;

                 //Normalize err by gradient cut-off
                 vecCellDelta = Vector3.Clamp(vecCellDelta, vecMinGrad3, vecMaxGrad3);

                 //Computing actual learning rate
-                Vector3 vecCellLearningRate = ComputeLearningRate(vecCellDelta, ref CellLearningRate[i]);
+                Vector3 vecCellLearningRate = ComputeLearningRate(vecCellDelta, ref PeepholeLearningRate[i]);

                 vecCellDelta = vecCellLearningRate * vecCellDelta;

-                c.wCellIn += vecCellDelta.X;
-                c.wCellForget += vecCellDelta.Y;
-                c.wCellOut += vecCellDelta.Z;
+                c.wPeepholeIn += vecCellDelta.X;
+                c.wPeepholeForget += vecCellDelta.Y;
+                c.wPeepholeOut += vecCellDelta.Z;
+
+
+
+                //Update cell weights
+                //partial derivatives for internal connections
+                c.dSWCellIn = c.dSWCellIn * c.yForget + Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c.previousCellOutput;
+
+                //partial derivatives for internal connections, initially zero as dS is zero and previous cell state is zero
+                c.dSWCellForget = c.dSWCellForget * c.yForget + ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c.previousCellOutput;
+
+                c.dSWCellState = c.dSWCellState * c.yForget + Sigmoid2Derivative_ci_netCellState_mul_ci_yIn * c.previousCellOutput;
+
+                Vector4 vecCellDelta4 = new Vector4((float)c.dSWCellIn, (float)c.dSWCellForget, (float)c.dSWCellState, (float)c.previousCellOutput);
+                vecCellDelta4 = vecErr * vecCellDelta4;
+
+                //Normalize err by gradient cut-off
+                vecCellDelta4 = Vector4.Clamp(vecCellDelta4, vecMinGrad, vecMaxGrad);
+
+                //Computing actual learning rate
+                Vector4 vecCellLearningRate4 = ComputeLearningRate(vecCellDelta4, ref CellLearningRate[i]);
+
+                vecCellDelta4 = vecCellLearningRate4 * vecCellDelta4;
+
+                c.wCellIn += vecCellDelta4.X;
+                c.wCellForget += vecCellDelta4.Y;
+                c.wCellState += vecCellDelta4.Z;
+                c.wCellOut += vecCellDelta4.W;
+

                 neuHidden[i] = c;
             });
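Both weight groups follow the same truncated-RTRL pattern: each `dSW*` accumulator carries the running derivative of the cell state forward through the forget gate, and the groups differ only in which recurrent signal multiplies the local gate derivative. The peephole accumulators use the previous cell state s_{t-1} (also fixed here: the old `dSWCellIn` used the current `c.cellState`), while the cell-weight accumulators use the previous cell output y_{t-1}. For the input-gate weights, in the notation introduced above:

```latex
% Truncated-RTRL accumulators, as coded above:
\frac{\partial s_t}{\partial w^{peep}_{in}} =
  \frac{\partial s_{t-1}}{\partial w^{peep}_{in}}\, y^{forget}_t
  + \tanh(net^{c}_t)\,\sigma'(net^{in}_t)\, s_{t-1}
\qquad
\frac{\partial s_t}{\partial w^{cell}_{in}} =
  \frac{\partial s_{t-1}}{\partial w^{cell}_{in}}\, y^{forget}_t
  + \tanh(net^{c}_t)\,\sigma'(net^{in}_t)\, y_{t-1}
% The forget- and state-weight accumulators substitute
% s_{t-1}\,\sigma'(net^{forget}_t) and \tanh'(net^{c}_t)\,y^{in}_t
% for the gate-derivative factor.
```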
@@ -737,6 +800,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)

                 //hidden(t-1) -> hidden(t)
                 cell_j.previousCellState = cell_j.cellState;
+                cell_j.previousCellOutput = cell_j.cellOutput;

                 Vector4 vecCell_j = Vector4.Zero;
                 //Apply sparse weights
@@ -766,15 +830,17 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
                 cell_j.netOut = vecCell_j.W;

                 //include internal connection multiplied by the previous cell state
-                cell_j.netIn += cell_j.previousCellState * cell_j.wCellIn;
+                cell_j.netIn += cell_j.previousCellState * cell_j.wPeepholeIn + cell_j.previousCellOutput * cell_j.wCellIn;
                 //squash input
                 cell_j.yIn = Sigmoid(cell_j.netIn);

                 //include internal connection multiplied by the previous cell state
-                cell_j.netForget += cell_j.previousCellState * cell_j.wCellForget;
+                cell_j.netForget += cell_j.previousCellState * cell_j.wPeepholeForget + cell_j.previousCellOutput * cell_j.wCellForget;
                 cell_j.yForget = Sigmoid(cell_j.netForget);

+                cell_j.netCellState += cell_j.previousCellOutput * cell_j.wCellState;
                 cell_j.yCellState = TanH(cell_j.netCellState);
+
                 if (cell_j.mask == true)
                 {
                     cell_j.cellState = 0;
@@ -791,7 +857,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
                 }

                 ////include the internal connection multiplied by the CURRENT cell state
-                cell_j.netOut += cell_j.cellState * cell_j.wCellOut;
+                cell_j.netOut += cell_j.cellState * cell_j.wPeepholeOut + cell_j.previousCellOutput * cell_j.wCellOut;

                 //squash output gate
                 cell_j.yOut = Sigmoid(cell_j.netOut);
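Read together, the computeHiddenLayer hunks give every gate two recurrent terms: a peephole term on the cell state (previous state for the input and forget gates, current state for the output gate) and a new term on the previous cell output. Writing \hat{net} for the input-driven part accumulated into vecCell_j; the cell-state update itself falls between the hunks shown, and the final output y_t = y^{out}_t tanh(s_t) is inferred from the gradient code in LearnNet rather than visible here:

```latex
net^{in}_t     = \hat{net}^{in}_t     + w^{peep}_{in}\,s_{t-1}  + w^{cell}_{in}\,y_{t-1}
net^{forget}_t = \hat{net}^{forget}_t + w^{peep}_{fgt}\,s_{t-1} + w^{cell}_{fgt}\,y_{t-1}
net^{c}_t      = \hat{net}^{c}_t      + w^{cell}_{c}\,y_{t-1}
net^{out}_t    = \hat{net}^{out}_t    + w^{peep}_{out}\,s_t     + w^{cell}_{out}\,y_{t-1}
% with gate activations y^{in}_t = \sigma(net^{in}_t), etc., and
% s_t = y^{forget}_t\,s_{t-1} + y^{in}_t\,\tanh(net^{c}_t) between the hunks shown.
```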