Commit 37a29ef

zhongkaifu#1. Adding bias cell for LSTM
1 parent 2d58c7a commit 37a29ef

File tree: 3 files changed, +38 -59 lines changed
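Reading the three diffs together (this summary is an editor's gloss, not commit text): the last hidden cell, index L1 - 1, becomes a constant bias cell. Its output is initialized to 1.0, re-pinned on every reset, and skipped by the forward and backward Parallel.For loops, so the weights leading out of it act as trainable bias terms for the next layer. Along the way, the cell output gains the standard tanh squashing, and gradient clipping moves from the raw error to the final weight delta. A minimal, self-contained sketch of the bias-cell idea (hypothetical names, not RNNSharp's API):

using System;

class BiasCellSketch
{
    static void Main()
    {
        int L1 = 4;                        // hidden layer size, as in the diff
        double[] cellOutput = new double[L1];
        cellOutput[L1 - 1] = 1.0;          // bias cell: pinned to 1.0, never recomputed

        // Real cells stop at L1 - 1, mirroring Parallel.For(0, L1 - 1, ...):
        for (int j = 0; j < L1 - 1; j++)
            cellOutput[j] = Math.Tanh(0.5 * j); // stand-in for the LSTM update

        // A weight from cellOutput[L1 - 1] into the next layer therefore
        // contributes weight * 1.0, i.e. a learned bias term.
        Console.WriteLine(string.Join(", ", cellOutput));
    }
}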

RNNSharp/BiRNN.cs

Lines changed: 1 addition & 1 deletion
@@ -243,7 +243,7 @@ public SimpleLayer[] InnerDecode(Sequence pSequence, out SimpleLayer[] outputHid
     {
         State state = pSequence.States[curState];
         forwardRNN.SetInputLayer(state, curState, numStates, null);
-        forwardRNN.computeHiddenLayer(state); //compute probability distribution
+        forwardRNN.computeHiddenLayer(state);

         mForward[curState] = forwardRNN.GetHiddenLayer();
     }

RNNSharp/LSTMRNN.cs

Lines changed: 31 additions & 36 deletions
@@ -25,6 +25,7 @@ public class LSTMCell : SimpleCell
     public double netCellState;
     public double previousCellState;
     public double cellState;
+    public double yCellState;

     //internal weights and deltas
     public double wCellIn;
@@ -66,8 +67,8 @@ public class LSTMRNN : RNN
6667
private new Vector4 vecMaxGrad;
6768
private new Vector4 vecMinGrad;
6869

69-
private new Vector3 vecMaxGrad3;
70-
private new Vector3 vecMinGrad3;
70+
private Vector3 vecMaxGrad3;
71+
private Vector3 vecMinGrad3;
7172

7273
public LSTMRNN()
7374
{
@@ -453,31 +454,24 @@ public override void initWeights()
         }
     }

-    public void LSTMCellInit(LSTMCell c)
+    public void LSTMCellInit(LSTMCell c, bool bBias = false)
     {
-        //input gate
-        c.netIn = 0;
-        c.yIn = 0;
-
-        //forget gate
-        c.netForget = 0;
-        c.yForget = 0;
-
-        //cell state
-        c.netCellState = 0;
-        c.previousCellState = 0; //this is important
+        c.previousCellState = 0;
         c.cellState = 0;

         //partial derivatives
         c.dSWCellIn = 0;
         c.dSWCellForget = 0;

-        //output gate
-        c.netOut = 0;
-        c.yOut = 0;
-
-        //cell output
-        c.cellOutput = 0;
+        if (bBias == false)
+        {
+            //cell output
+            c.cellOutput = 0;
+        }
+        else
+        {
+            c.cellOutput = 1.0;
+        }
     }

     public override void CleanStatus()
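The reset is also slimmer than before: the gate nets and activations (netIn, yIn, netForget, yForget, netOut, yOut) are no longer zeroed here, presumably because computeHiddenLayer overwrites them every step, and the new bBias flag decides whether cellOutput is cleared or pinned to 1.0. A self-contained sketch of the new reset semantics (a standalone illustration, not RNNSharp itself):

using System;

class Cell { public double previousCellState, cellState, cellOutput; }

class ResetSketch
{
    static void CellInit(Cell c, bool bBias = false)
    {
        c.previousCellState = 0;
        c.cellState = 0;
        c.cellOutput = bBias ? 1.0 : 0.0; // the bias cell keeps emitting 1.0
    }

    static void Main()
    {
        int L1 = 4;
        var cells = new Cell[L1];
        for (int i = 0; i < L1; i++)
        {
            cells[i] = new Cell();
            CellInit(cells[i], i == L1 - 1); // last cell is the bias cell
        }
        Console.WriteLine(cells[L1 - 1].cellOutput); // prints 1 after every reset
    }
}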
@@ -544,7 +538,7 @@ private void CreateCell(BinaryReader br)
     for (int i = 0; i < L1; i++)
     {
         neuHidden[i] = new LSTMCell();
-        LSTMCellInit(neuHidden[i]);
+        LSTMCellInit(neuHidden[i], i == L1 - 1);
     }

     if (br != null)
@@ -626,27 +620,26 @@ public override void LearnNet(State state, int numStates, int curState)
     int sparseFeatureSize = sparse.Count;

     //put variables for derivatives in weight class and cell class
-    Parallel.For(0, L1, parallelOption, i =>
+    Parallel.For(0, L1 - 1, parallelOption, i =>
     {
         LSTMCell c = neuHidden[i];

         //using the error find the gradient of the output gate
-        var gradientOutputGate = (float)(SigmoidDerivative(c.netOut) * c.cellState * c.er);
+        var gradientOutputGate = (float)(SigmoidDerivative(c.netOut) * TanH(c.cellState) * c.er);

         //internal cell state error
-        var cellStateError = (float)(c.yOut * c.er);
+        var cellStateError = (float)(c.er);

         Vector4 vecErr = new Vector4(cellStateError, cellStateError, cellStateError, gradientOutputGate);
-        vecErr = Vector4.Clamp(vecErr, vecMinGrad, vecMaxGrad);

         var Sigmoid2Derivative_ci_netCellState_mul_ci_yIn = TanHDerivative(c.netCellState) * c.yIn;
         var Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn = TanH(c.netCellState) * SigmoidDerivative(c.netIn);
         var ci_previousCellState_mul_SigmoidDerivative_ci_netForget = c.previousCellState * SigmoidDerivative(c.netForget);

         Vector3 vecDerivate = new Vector3(
-            (float)Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn,
-            (float)ci_previousCellState_mul_SigmoidDerivative_ci_netForget,
-            (float)Sigmoid2Derivative_ci_netCellState_mul_ci_yIn);
+            (float)(Sigmoid2_ci_netCellState_mul_SigmoidDerivative_ci_netIn * c.yOut),
+            (float)(ci_previousCellState_mul_SigmoidDerivative_ci_netForget * c.yOut),
+            (float)(Sigmoid2Derivative_ci_netCellState_mul_ci_yIn * c.yOut));
         float c_yForget = (float)c.yForget;
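Why the two gradient edits travel together (an editor's chain-rule sketch in the diff's own names, not text from the commit): once the forward pass emits $y_j = y^{out}_j \tanh(s_j)$, the exact derivatives are

\[
\frac{\partial E}{\partial net^{out}_j} = \sigma'\!\left(net^{out}_j\right)\tanh(s_j)\,\varepsilon_j,
\qquad
\frac{\partial E}{\partial s_j} = y^{out}_j\left(1 - \tanh^2 s_j\right)\varepsilon_j,
\]

where $\varepsilon_j$ is c.er. The first expression is exactly the new gradientOutputGate. For the state path, the diff keeps cellStateError = er and moves the $y^{out}_j$ factor into vecDerivate instead; the $(1 - \tanh^2 s_j)$ factor is not visible in this hunk, so either it is applied elsewhere or it is dropped as a deliberate simplification.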
@@ -668,8 +661,9 @@ public override void LearnNet(State state, int numStates, int curState)
668661
//Computing final err delta
669662
Vector4 vecDelta = new Vector4(wd, entry.Value);
670663
vecDelta = vecErr * vecDelta;
664+
vecDelta = Vector4.Clamp(vecDelta, vecMinGrad, vecMaxGrad);
671665

672-
//Computing actual learning rate
666+
//Computing actual learning rate
673667
Vector4 vecLearningRate = ComputeLearningRate(vecDelta, ref wlr_i[entry.Key]);
674668
w_i[entry.Key] += vecLearningRate * vecDelta;
675669
}
@@ -693,6 +687,7 @@ public override void LearnNet(State state, int numStates, int curState)

     Vector4 vecDelta = new Vector4(wd, feature);
     vecDelta = vecErr * vecDelta;
+    vecDelta = Vector4.Clamp(vecDelta, vecMinGrad, vecMaxGrad);

     //Computing actual learning rate
     Vector4 vecLearningRate = ComputeLearningRate(vecDelta, ref wlr_i[j]);
@@ -709,11 +704,10 @@ public override void LearnNet(State state, int numStates, int curState)
     //update internal weights
     Vector3 vecCellDelta = new Vector3((float)c.dSWCellIn, (float)c.dSWCellForget, (float)c.cellState);
     Vector3 vecCellErr = new Vector3(cellStateError, cellStateError, gradientOutputGate);
+    vecCellDelta = vecCellErr * vecCellDelta;

     //Normalize err by gradient cut-off
-    vecCellErr = Vector3.Clamp(vecCellErr, vecMinGrad3, vecMaxGrad3);
-
-    vecCellDelta = vecCellErr * vecCellDelta;
+    vecCellDelta = Vector3.Clamp(vecCellDelta, vecMinGrad3, vecMaxGrad3);

     //Computing actual learning rate
     Vector3 vecCellLearningRate = ComputeLearningRate(vecCellDelta, ref CellLearningRate[i]);
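All three update sites now clamp after multiplying by the input, so what gets bounded is the actual weight delta rather than the upstream error signal. A tiny self-contained sketch of the difference, using System.Numerics as the diff does (the values are made up):

using System;
using System.Numerics;

class ClipSketch
{
    static void Main()
    {
        var minGrad = new Vector4(-15f);
        var maxGrad = new Vector4(15f);
        var err = new Vector4(1f);                // error signal per component
        var input = new Vector4(30f, 1f, 1f, 1f); // one large feature value

        // Old order: clamp the error, then multiply. A large input can
        // still push the delta past the cut-off (here 1 * 30 = 30).
        var before = Vector4.Clamp(err, minGrad, maxGrad) * input;

        // New order: multiply first, then clamp the delta itself (30 -> 15).
        var after = Vector4.Clamp(err * input, minGrad, maxGrad);

        Console.WriteLine($"{before} vs {after}");
    }
}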
@@ -737,7 +731,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
     var sparse = state.SparseData;
     int sparseFeatureSize = sparse.Count;

-    Parallel.For(0, L1, parallelOption, j =>
+    Parallel.For(0, L1 - 1, parallelOption, j =>
     {
         LSTMCell cell_j = neuHidden[j];
@@ -780,14 +774,15 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
     cell_j.netForget += cell_j.previousCellState * cell_j.wCellForget;
     cell_j.yForget = Sigmoid(cell_j.netForget);

+    cell_j.yCellState = TanH(cell_j.netCellState);
     if (cell_j.mask == true)
     {
         cell_j.cellState = 0;
     }
     else
     {
         //cell state is equal to the previous cell state multiplied by the forget gate plus the cell input multiplied by the input gate
-        cell_j.cellState = cell_j.yForget * cell_j.previousCellState + cell_j.yIn * TanH(cell_j.netCellState);
+        cell_j.cellState = cell_j.yForget * cell_j.previousCellState + cell_j.yIn * cell_j.yCellState;
     }

     if (isTrain == false)
@@ -801,7 +796,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
     //squash output gate
     cell_j.yOut = Sigmoid(cell_j.netOut);

-    cell_j.cellOutput = cell_j.cellState * cell_j.yOut;
+    cell_j.cellOutput = TanH(cell_j.cellState) * cell_j.yOut;

     neuHidden[j] = cell_j;
 });
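Taken together, the computeHiddenLayer edits move the cell to the standard squashed-output LSTM update, with $\tanh(netCellState)$ now cached once in the new yCellState field instead of being recomputed:

\[
s_j(t) = y^{forget}_j\, s_j(t-1) + y^{in}_j \tanh\!\left(net^{cell}_j\right),
\qquad
y_j(t) = y^{out}_j \tanh\!\left(s_j(t)\right),
\]

where the output used to be the unsquashed $y_j(t) = y^{out}_j\, s_j(t)$.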
@@ -825,7 +820,7 @@ public override void netReset(bool updateNet = false) //cleans hidden layer ac
     for (int i = 0; i < L1; i++)
     {
         neuHidden[i].mask = false;
-        LSTMCellInit(neuHidden[i]);
+        LSTMCellInit(neuHidden[i], i == L1 - 1);
     }

     if (Dropout > 0 && updateNet == true)

RNNSharp/RNNEncoder.cs

Lines changed: 6 additions & 22 deletions
@@ -109,11 +109,15 @@ public void Train()
     lastAlpha = rnn.LearningRate;

     //Validate the model on the validation corpus
-    bool betterValidateNet = false;
     if (ValidationSet != null)
     {
         Logger.WriteLine("Verify model on validated corpus.");
-        betterValidateNet = rnn.ValidateNet(ValidationSet, iter);
+        if (rnn.ValidateNet(ValidationSet, iter) == true)
+        {
+            //We got a better result on the validation corpus, so save this model
+            Logger.WriteLine("Saving better model into file {0}...", m_modelSetting.ModelFile);
+            rnn.SaveModel(m_modelSetting.ModelFile);
+        }
     }

     if (ppl >= lastPPL)
@@ -122,26 +126,6 @@ public void Train()
         rnn.LearningRate = rnn.LearningRate / 2.0f;
     }

-    if (betterValidateNet == true)
-    {
-        //We got better result on validated corpus, save this model
-        Logger.WriteLine("Saving better model into file {0}...", m_modelSetting.ModelFile);
-        rnn.SaveModel(m_modelSetting.ModelFile);
-    }
-
-
-    //if ((ValidationSet != null && betterValidateNet == false) ||
-    //    (ValidationSet == null && ppl >= lastPPL))
-    //{
-    //    rnn.LearningRate = rnn.LearningRate / 2.0f;
-    //}
-    //else
-    //{
-    //    //If current model is better than before, save it into file
-    //    Logger.WriteLine("Saving better model into file {0}...", m_modelSetting.ModelFile);
-    //    rnn.SaveModel(m_modelSetting.ModelFile);
-    //}
-
     lastPPL = ppl;

     iter++;
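With the dead betterValidateNet flag and the commented-out alternative gone, the loop's policy reads directly off the diff: save the model immediately when validation improves, and halve the learning rate whenever training perplexity stops improving. A self-contained sketch of that policy (TrainOneEpoch, Validate, and SaveModel are hypothetical stand-ins, not RNNSharp's API):

using System;

class TrainLoopSketch
{
    static readonly Random rng = new Random(7);
    static double TrainOneEpoch() => rng.NextDouble() * 100; // fake perplexity
    static bool Validate(int iter) => iter % 2 == 0;          // fake "improved?" check
    static void SaveModel(string path) => Console.WriteLine($"saved {path}");

    static void Main()
    {
        double lastPPL = double.MaxValue, learningRate = 0.1;
        for (int iter = 0; iter < 5; iter++)
        {
            double ppl = TrainOneEpoch();
            if (Validate(iter))              // better result on the validation set:
                SaveModel("model.bin");      // save right away (no deferred flag)
            if (ppl >= lastPPL)
                learningRate /= 2.0;         // anneal when training perplexity stalls
            lastPPL = ppl;
        }
        Console.WriteLine($"final lr = {learningRate}");
    }
}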
