Skip to content

Commit 54018d4

Browse files
committed
Improve LSTM cell states
1 parent 37a29ef commit 54018d4

File tree

3 files changed

+97
-23
lines changed

3 files changed

+97
-23
lines changed

RNNSharp/LSTMRNN.cs

Lines changed: 88 additions & 22 deletions
Original file line number · Diff line number · Diff line change
@@ -28,18 +28,28 @@ public class LSTMCell : SimpleCell
2828
public double yCellState;
2929

3030
//internal weights and deltas
31+
public double wPeepholeIn;
32+
public double wPeepholeForget;
33+
public double wPeepholeOut;
34+
35+
//partial derivatives
36+
public double dSWPeepholeIn;
37+
public double dSWPeepholeForget;
38+
3139
public double wCellIn;
3240
public double wCellForget;
41+
public double wCellState;
3342
public double wCellOut;
3443

35-
//partial derivatives
3644
public double dSWCellIn;
3745
public double dSWCellForget;
38-
//double dSWCellState;
46+
public double dSWCellState;
3947

4048
//output gate
4149
public double netOut;
4250
public double yOut;
51+
52+
public double previousCellOutput;
4353
}
4454

4555
public class LSTMRNN : RNN
@@ -56,7 +66,8 @@ public class LSTMRNN : RNN
5666

5767
protected Vector4[][] Input2HiddenLearningRate;
5868
protected Vector4[][] Feature2HiddenLearningRate;
59-
protected Vector3[] CellLearningRate;
69+
protected Vector3[] PeepholeLearningRate;
70+
protected Vector4[] CellLearningRate;
6071

6172
protected Vector3[][] input2hiddenDeri;
6273
protected Vector3[][] feature2hiddenDeri;
@@ -322,8 +333,14 @@ public void SaveHiddenLayerWeights(BinaryWriter fo)
322333
{
323334
for (int i = 0; i < L1; i++)
324335
{
336+
fo.Write(neuHidden[i].wPeepholeIn);
337+
fo.Write(neuHidden[i].wPeepholeForget);
338+
// fo.Write(neuHidden[i].wCellState);
339+
fo.Write(neuHidden[i].wPeepholeOut);
340+
325341
fo.Write(neuHidden[i].wCellIn);
326342
fo.Write(neuHidden[i].wCellForget);
343+
fo.Write(neuHidden[i].wCellState);
327344
fo.Write(neuHidden[i].wCellOut);
328345
}
329346
}
@@ -460,8 +477,13 @@ public void LSTMCellInit(LSTMCell c, bool bBias = false)
460477
c.cellState = 0;
461478

462479
//partial derivatives
480+
c.dSWPeepholeIn = 0;
481+
c.dSWPeepholeForget = 0;
482+
// c.dSWCellState = 0;
483+
463484
c.dSWCellIn = 0;
464485
c.dSWCellForget = 0;
486+
c.dSWCellState = 0;
465487

466488
if (bBias == false)
467489
{
@@ -482,7 +504,8 @@ public override void CleanStatus()
482504
Feature2HiddenLearningRate = new Vector4[L1][];
483505
}
484506

485-
CellLearningRate = new Vector3[L1];
507+
PeepholeLearningRate = new Vector3[L1];
508+
CellLearningRate = new Vector4[L1];
486509
Parallel.For(0, L1, parallelOption, i =>
487510
{
488511
Input2HiddenLearningRate[i] = new Vector4[L0];
@@ -502,7 +525,6 @@ public override void CleanStatus()
502525

503526
vecMaxGrad3 = new Vector3((float)GradientCutoff, (float)GradientCutoff, (float)GradientCutoff);
504527
vecMinGrad3 = new Vector3((float)(-GradientCutoff), (float)(-GradientCutoff), (float)(-GradientCutoff));
505-
506528
}
507529

508530
public override void InitMem()
@@ -546,8 +568,13 @@ private void CreateCell(BinaryReader br)
546568
//Load weight from input file
547569
for (int i = 0; i < L1; i++)
548570
{
571+
neuHidden[i].wPeepholeIn = br.ReadDouble();
572+
neuHidden[i].wPeepholeForget = br.ReadDouble();
573+
neuHidden[i].wPeepholeOut = br.ReadDouble();
574+
549575
neuHidden[i].wCellIn = br.ReadDouble();
550576
neuHidden[i].wCellForget = br.ReadDouble();
577+
neuHidden[i].wCellState = br.ReadDouble();
551578
neuHidden[i].wCellOut = br.ReadDouble();
552579
}
553580
}
@@ -557,8 +584,13 @@ private void CreateCell(BinaryReader br)
557584
for (int i = 0; i < L1; i++)
558585
{
559586
//internal weights, also important
587+
neuHidden[i].wPeepholeIn = RandInitWeight();
588+
neuHidden[i].wPeepholeForget = RandInitWeight();
589+
neuHidden[i].wPeepholeOut = RandInitWeight();
590+
560591
neuHidden[i].wCellIn = RandInitWeight();
561592
neuHidden[i].wCellForget = RandInitWeight();
593+
neuHidden[i].wCellState = RandInitWeight();
562594
neuHidden[i].wCellOut = RandInitWeight();
563595
}
564596
}
@@ -628,18 +660,18 @@ public override void LearnNet(State state, int numStates, int curState)
628660
var gradientOutputGate = (float)(SigmoidDerivative(c.netOut) * TanH(c.cellState) * c.er);
629661

630662
//internal cell state error
631-
var cellStateError = (float)(c.er);
663+
var cellStateError = (float)(c.yOut * c.er * TanHDerivative(c.cellState));
632664

633665
Vector4 vecErr = new Vector4(cellStateError, cellStateError, cellStateError, gradientOutputGate);
634666

635-
var Sigmoid2Derivative_ci_netCellState_mul_ci_yIn = TanHDerivative(c.netCellState) * c.yIn;
636667
var Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn = TanH(c.netCellState) * SigmoidDerivative(c.netIn);
637668
var ci_previousCellState_mul_SigmoidDerivative_ci_netForget = c.previousCellState * SigmoidDerivative(c.netForget);
669+
var Sigmoid2Derivative_ci_netCellState_mul_ci_yIn = TanHDerivative(c.netCellState) * c.yIn;
638670

639671
Vector3 vecDerivate = new Vector3(
640-
(float)(Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c.yOut),
641-
(float)(ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c.yOut),
642-
(float)(Sigmoid2Derivative_ci_netCellState_mul_ci_yIn * c.yOut));
672+
(float)(Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn),
673+
(float)(ci_previousCellState_mul_SigmoidDerivative_ci_netForget),
674+
(float)(Sigmoid2Derivative_ci_netCellState_mul_ci_yIn));
643675
float c_yForget = (float)c.yForget;
644676

645677

@@ -695,28 +727,59 @@ public override void LearnNet(State state, int numStates, int curState)
695727
}
696728
}
697729

730+
//Update peephole weights
731+
698732
//partial derivatives for internal connections
699-
c.dSWCellIn = c.dSWCellIn * c.yForget + Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c.cellState;
733+
c.dSWPeepholeIn = c.dSWPeepholeIn * c.yForget + Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c.previousCellState;
700734

701735
//partial derivatives for internal connections, initially zero as dS is zero and previous cell state is zero
702-
c.dSWCellForget = c.dSWCellForget * c.yForget + ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c.previousCellState;
736+
c.dSWPeepholeForget = c.dSWPeepholeForget * c.yForget + ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c.previousCellState;
703737

704738
//update internal weights
705-
Vector3 vecCellDelta = new Vector3((float)c.dSWCellIn, (float)c.dSWCellForget, (float)c.cellState);
706-
Vector3 vecCellErr = new Vector3(cellStateError, cellStateError, gradientOutputGate);
707-
vecCellDelta = vecCellErr * vecCellDelta;
739+
Vector3 vecCellDelta = new Vector3((float)c.dSWPeepholeIn, (float)c.dSWPeepholeForget, (float)c.cellState);
740+
Vector3 vecErr3 = new Vector3(cellStateError, cellStateError, gradientOutputGate);
741+
742+
vecCellDelta = vecErr3 * vecCellDelta;
708743

709744
//Normalize err by gradient cut-off
710745
vecCellDelta = Vector3.Clamp(vecCellDelta, vecMinGrad3, vecMaxGrad3);
711746

712747
//Computing actual learning rate
713-
Vector3 vecCellLearningRate = ComputeLearningRate(vecCellDelta, ref CellLearningRate[i]);
748+
Vector3 vecCellLearningRate = ComputeLearningRate(vecCellDelta, ref PeepholeLearningRate[i]);
714749

715750
vecCellDelta = vecCellLearningRate * vecCellDelta;
716751

717-
c.wCellIn += vecCellDelta.X;
718-
c.wCellForget += vecCellDelta.Y;
719-
c.wCellOut += vecCellDelta.Z;
752+
c.wPeepholeIn += vecCellDelta.X;
753+
c.wPeepholeForget += vecCellDelta.Y;
754+
c.wPeepholeOut += vecCellDelta.Z;
755+
756+
757+
758+
//Update cell weights
759+
//partial derivatives for internal connections
760+
c.dSWCellIn = c.dSWCellIn * c.yForget + Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c.previousCellOutput;
761+
762+
//partial derivatives for internal connections, initially zero as dS is zero and previous cell state is zero
763+
c.dSWCellForget = c.dSWCellForget * c.yForget + ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c.previousCellOutput;
764+
765+
c.dSWCellState = c.dSWCellState * c.yForget + Sigmoid2Derivative_ci_netCellState_mul_ci_yIn * c.previousCellOutput;
766+
767+
Vector4 vecCellDelta4 = new Vector4((float)c.dSWCellIn, (float)c.dSWCellForget, (float)c.dSWCellState, (float)c.previousCellOutput);
768+
vecCellDelta4 = vecErr * vecCellDelta4;
769+
770+
//Normalize err by gradient cut-off
771+
vecCellDelta4 = Vector4.Clamp(vecCellDelta4, vecMinGrad, vecMaxGrad);
772+
773+
//Computing actual learning rate
774+
Vector4 vecCellLearningRate4 = ComputeLearningRate(vecCellDelta4, ref CellLearningRate[i]);
775+
776+
vecCellDelta4 = vecCellLearningRate4 * vecCellDelta4;
777+
778+
c.wCellIn += vecCellDelta4.X;
779+
c.wCellForget += vecCellDelta4.Y;
780+
c.wCellState += vecCellDelta4.Z;
781+
c.wCellOut += vecCellDelta4.W;
782+
720783

721784
neuHidden[i] = c;
722785
});
@@ -737,6 +800,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
737800

738801
//hidden(t-1) -> hidden(t)
739802
cell_j.previousCellState = cell_j.cellState;
803+
cell_j.previousCellOutput = cell_j.cellOutput;
740804

741805
Vector4 vecCell_j = Vector4.Zero;
742806
//Apply sparse weights
@@ -766,15 +830,17 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
766830
cell_j.netOut = vecCell_j.W;
767831

768832
//include internal connection multiplied by the previous cell state
769-
cell_j.netIn += cell_j.previousCellState * cell_j.wCellIn;
833+
cell_j.netIn += cell_j.previousCellState * cell_j.wPeepholeIn + cell_j.previousCellOutput * cell_j.wCellIn;
770834
//squash input
771835
cell_j.yIn = Sigmoid(cell_j.netIn);
772836

773837
//include internal connection multiplied by the previous cell state
774-
cell_j.netForget += cell_j.previousCellState * cell_j.wCellForget;
838+
cell_j.netForget += cell_j.previousCellState * cell_j.wPeepholeForget + cell_j.previousCellOutput * cell_j.wCellForget;
775839
cell_j.yForget = Sigmoid(cell_j.netForget);
776840

841+
cell_j.netCellState += cell_j.previousCellOutput * cell_j.wCellState;
777842
cell_j.yCellState = TanH(cell_j.netCellState);
843+
778844
if (cell_j.mask == true)
779845
{
780846
cell_j.cellState = 0;
@@ -791,7 +857,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
791857
}
792858

793859
////include the internal connection multiplied by the CURRENT cell state
794-
cell_j.netOut += cell_j.cellState * cell_j.wCellOut;
860+
cell_j.netOut += cell_j.cellState * cell_j.wPeepholeOut + cell_j.previousCellOutput * cell_j.wCellOut;
795861

796862
//squash output gate
797863
cell_j.yOut = Sigmoid(cell_j.netOut);

RNNSharp/RNNEncoder.cs

Lines changed: 8 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -119,7 +119,14 @@ public void Train()
119119
rnn.SaveModel(m_modelSetting.ModelFile);
120120
}
121121
}
122-
122+
else if (ppl < lastPPL)
123+
{
124+
//We don't have a validation corpus, but we got a better result on the training corpus
125+
//so save this model as the best one so far
126+
Logger.WriteLine("Saving better model into file {0}...", m_modelSetting.ModelFile);
127+
rnn.SaveModel(m_modelSetting.ModelFile);
128+
}
129+
123130
if (ppl >= lastPPL)
124131
{
125132
//We cannot get a better result on training corpus, so reduce learning rate

RNNSharp/neuron.cs

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -8,6 +8,7 @@ public class SimpleCell
88
{
99
//cell output
1010
public double cellOutput;
11+
1112
public double er;
1213
public bool mask;
1314
}

0 commit comments

Comments
 (0)