Skip to content

Commit ad9a7cf

Browse files
committed
Simplify matrixVectorAdd interface
1 parent 3336396 commit ad9a7cf

File tree

4 files changed

+17
-20
lines changed

4 files changed

+17
-20
lines changed

RNNSharp/BiRNN.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ public SimpleCell[][] InnerDecode(Sequence pSequence, out SimpleCell[][] outputH
262262
seqOutput[curState] = InitSimpleCell(L2);
263263
SimpleCell[] outputCells = seqOutput[curState];
264264

265-
matrixXvectorADD(outputCells, mergedHiddenLayer[curState], Hidden2OutputWeight, 0, L2, 0, L1, 0);
265+
matrixXvectorADD(outputCells, mergedHiddenLayer[curState], Hidden2OutputWeight, L2, L1, 0);
266266

267267
for (int i = 0; i < L2; i++)
268268
{

RNNSharp/LSTMRNN.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
740740

741741
public override void computeOutput(double[] doutput)
742742
{
743-
matrixXvectorADD(OutputLayer, neuHidden, Hidden2OutputWeight, 0, L2, 0, L1, 0);
743+
matrixXvectorADD(OutputLayer, neuHidden, Hidden2OutputWeight, L2, L1, 0);
744744
if (doutput != null)
745745
{
746746
for (int i = 0; i < L2; i++)

RNNSharp/RNN.cs

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -586,39 +586,36 @@ protected double NormalizeErr(double err)
586586
return err;
587587
}
588588

589-
public void matrixXvectorADD(SimpleCell[] dest, SimpleCell[] srcvec, Matrix<double> srcmatrix, int from, int to, int from2, int to2, int type)
589+
public void matrixXvectorADD(SimpleCell[] dest, SimpleCell[] srcvec, Matrix<double> srcmatrix, int DestSize, int SrcSize, int type)
590590
{
591591
if (type == 0)
592592
{
593593
//ac mod
594-
Parallel.For(0, (to - from), parallelOption, i =>
594+
Parallel.For(0, DestSize, parallelOption, i =>
595595
{
596-
SimpleCell cell = dest[i + from];
596+
SimpleCell cell = dest[i];
597597
double[] vector_i = srcmatrix[i];
598598
cell.cellOutput = 0;
599-
for (int j = 0; j < to2 - from2; j++)
599+
for (int j = 0; j < SrcSize; j++)
600600
{
601-
cell.cellOutput += srcvec[j + from2].cellOutput * vector_i[j];
601+
cell.cellOutput += srcvec[j].cellOutput * vector_i[j];
602602
}
603603
});
604604

605605
}
606606
else
607607
{
608-
Parallel.For(0, (to - from), parallelOption, i =>
608+
Parallel.For(0, DestSize, parallelOption, i =>
609609
{
610-
SimpleCell cell = dest[i + from];
610+
SimpleCell cell = dest[i];
611611
cell.er = 0;
612-
for (int j = 0; j < to2 - from2; j++)
612+
for (int j = 0; j < SrcSize; j++)
613613
{
614-
cell.er += srcvec[j + from2].er * srcmatrix[j][i];
614+
cell.er += srcvec[j].er * srcmatrix[j][i];
615615
}
616-
});
617616

618-
for (int i = from; i < to; i++)
619-
{
620-
dest[i].er = NormalizeErr(dest[i].er);
621-
}
617+
cell.er = NormalizeErr(cell.er);
618+
});
622619
}
623620
}
624621

RNNSharp/SimpleRNN.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
137137

138138
//hidden(t-1) -> hidden(t)
139139
neuHidden = InitSimpleCell(L1);
140-
matrixXvectorADD(neuHidden, neuLastHidden, mat_hiddenBpttWeight, 0, L1, 0, L1, 0);
140+
matrixXvectorADD(neuHidden, neuLastHidden, mat_hiddenBpttWeight, L1, L1, 0);
141141

142142
//Apply feature values on hidden layer
143143
var sparse = state.SparseData;
@@ -171,7 +171,7 @@ public override void computeHiddenLayer(State state, bool isTrain = true)
171171
public override void computeOutput(double[] doutput)
172172
{
173173
//Calculate output layer
174-
matrixXvectorADD(OutputLayer, neuHidden, Hidden2OutputWeight, 0, L2, 0, L1, 0);
174+
matrixXvectorADD(OutputLayer, neuHidden, Hidden2OutputWeight, L2, L1, 0);
175175
if (doutput != null)
176176
{
177177
for (int i = 0; i < L2; i++)
@@ -187,7 +187,7 @@ public override void computeOutput(double[] doutput)
187187
public override void ComputeHiddenLayerErr()
188188
{
189189
//error output->hidden for words from specific class
190-
matrixXvectorADD(neuHidden, OutputLayer, Hidden2OutputWeight, 0, L1, 0, L2, 1);
190+
matrixXvectorADD(neuHidden, OutputLayer, Hidden2OutputWeight, L1, L2, 1);
191191

192192
//Apply drop out on error in hidden layer
193193
for (int i = 0; i < L1; i++)
@@ -254,7 +254,7 @@ void learnBptt(State state)
254254
});
255255

256256
//propagates errors hidden->input to the recurrent part
257-
matrixXvectorADD(neuLastHidden, neuHidden, mat_hiddenBpttWeight, 0, L1, 0, L1, 1);
257+
matrixXvectorADD(neuLastHidden, neuHidden, mat_hiddenBpttWeight, L1, L1, 1);
258258

259259
for (int a = 0; a < L1; a++)
260260
{

0 commit comments

Comments
 (0)