Skip to content

Commit 91a8a89

Browse files
committed
Improve SimpleRNN using SIMD instructions
1 parent 4b86462 commit 91a8a89

File tree

7 files changed

+172
-25
lines changed

7 files changed

+172
-25
lines changed

RNNSharp/BiRNN.cs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Threading.Tasks;
44
using AdvUtils;
55
using System.Collections.Generic;
6+
using System.Numerics;
67

78
/// <summary>
89
/// RNNSharp written by Zhongkai Fu ([email protected])
@@ -13,6 +14,7 @@ class BiRNN : RNN
1314
{
1415
private RNN forwardRNN;
1516
private RNN backwardRNN;
17+
private Vector<double> vecConst2 = new Vector<double>(2.0f);
1618

1719
public BiRNN(RNN s_forwardRNN, RNN s_backwardRNN)
1820
{
@@ -56,7 +58,7 @@ public override void CleanStatus()
5658
forwardRNN.CleanStatus();
5759
backwardRNN.CleanStatus();
5860

59-
Hidden2OutputWeightLearningRate = new Matrix<float>(L2, L1);
61+
Hidden2OutputWeightLearningRate = new Matrix<double>(L2, L1);
6062
}
6163

6264
public override void initWeights()
@@ -219,7 +221,7 @@ public override void InitMem()
219221
}
220222
}
221223

222-
Hidden2OutputWeightLearningRate = new Matrix<float>(L2, L1);
224+
Hidden2OutputWeightLearningRate = new Matrix<double>(L2, L1);
223225
}
224226

225227
public SimpleLayer[] InnerDecode(Sequence pSequence, out SimpleLayer[] outputHiddenLayer, out Matrix<double> rawOutputLayer)
@@ -266,9 +268,22 @@ public SimpleLayer[] InnerDecode(Sequence pSequence, out SimpleLayer[] outputHid
266268
SimpleLayer forwardCells = mForward[curState];
267269
SimpleLayer backwardCells = mBackward[curState];
268270

269-
for (int i = 0; i < forwardRNN.L1; i++)
271+
int i = 0;
272+
while (i < forwardRNN.L1 - Vector<double>.Count)
273+
{
274+
Vector<double> v1 = new Vector<double>(forwardCells.cellOutput, i);
275+
Vector<double> v2 = new Vector<double>(backwardCells.cellOutput, i);
276+
Vector<double> v = (v1 + v2) / vecConst2;
277+
278+
v.CopyTo(cells.cellOutput, i);
279+
280+
i += Vector<float>.Count;
281+
}
282+
283+
while (i < forwardRNN.L1)
270284
{
271285
cells.cellOutput[i] = (forwardCells.cellOutput[i] + backwardCells.cellOutput[i]) / 2.0;
286+
i++;
272287
}
273288
});
274289

RNNSharp/LSTMRNN.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ public override void CleanStatus()
496496

497497
});
498498

499-
Hidden2OutputWeightLearningRate = new Matrix<float>(L2, L1);
499+
Hidden2OutputWeightLearningRate = new Matrix<double>(L2, L1);
500500
vecLearningRate = new Vector4(LearningRate, LearningRate, LearningRate, LearningRate);
501501
vecLearningRate3 = new Vector3(LearningRate, LearningRate, LearningRate);
502502
}

RNNSharp/Matrix.cs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-

1+
using System.Numerics;
2+
23
/// <summary>
34
/// RNNSharp written by Zhongkai Fu ([email protected])
45
/// </summary>
56
namespace RNNSharp
67
{
7-
public class Matrix<T>
8+
public class Matrix<T> where T : struct
89
{
910

1011
public int Height { get; set; } // the number of rows
@@ -41,7 +42,22 @@ public Matrix<T> CopyTo()
4142

4243
for (int i = 0; i < Height; i++)
4344
{
44-
m_saData[i].CopyTo(m[i], 0);
45+
T[] m_i = m[i];
46+
T[] m_saData_i = m_saData[i];
47+
int j = 0;
48+
while (j < Width - Vector<T>.Count)
49+
{
50+
Vector<T> v1 = new Vector<T>(m_saData_i, j);
51+
v1.CopyTo(m_i, j);
52+
53+
j += Vector<T>.Count;
54+
}
55+
56+
while (j < Width)
57+
{
58+
m_i[j] = m_saData_i[j];
59+
j++;
60+
}
4561
}
4662

4763
return m;

RNNSharp/RNN.cs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Threading.Tasks;
44
using System.IO;
55
using AdvUtils;
6+
using System.Numerics;
67

78
/// <summary>
89
/// RNNSharp written by Zhongkai Fu ([email protected])
@@ -61,7 +62,7 @@ abstract public class RNN
6162
public Matrix<double> CRFTagTransWeights { get; set; }
6263
public SimpleLayer OutputLayer { get; set; }
6364
public Matrix<double> Hidden2OutputWeight;
64-
public Matrix<float> Hidden2OutputWeightLearningRate;
65+
public Matrix<double> Hidden2OutputWeightLearningRate;
6566

6667
// CRF result output
6768
protected Matrix<double> CRFSeqOutput;
@@ -96,10 +97,10 @@ protected SimpleCell[] InitSimpleCell(int size)
9697
return cells;
9798
}
9899

99-
public double UpdateLearningRate(Matrix<float> m, int i, int j, double delta)
100+
public double UpdateLearningRate(Matrix<double> m, int i, int j, double delta)
100101
{
101102
double dg = m[i][j] + delta * delta;
102-
m[i][j] = (float)dg;
103+
m[i][j] = dg;
103104

104105
return LearningRate / (1.0 + Math.Sqrt(dg));
105106
}
@@ -644,10 +645,23 @@ public void matrixXvectorADD(SimpleLayer dest, SimpleLayer srcvec, Matrix<double
644645
{
645646
double[] vector_i = srcmatrix[i];
646647
double cellOutput = 0;
647-
for (int j = 0; j < SrcSize; j++)
648+
int j = 0;
649+
650+
while (j < SrcSize - Vector<double>.Count)
651+
{
652+
Vector<double> v1 = new Vector<double>(srcvec.cellOutput, j);
653+
Vector<double> v2 = new Vector<double>(vector_i, j);
654+
cellOutput += Vector.Dot<double>(v1, v2);
655+
656+
j += Vector<double>.Count;
657+
}
658+
659+
while (j < SrcSize)
648660
{
649661
cellOutput += srcvec.cellOutput[j] * vector_i[j];
662+
j++;
650663
}
664+
651665
dest.cellOutput[i] = cellOutput;
652666
});
653667

RNNSharp/RNNSharp.csproj

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@
3737
<Reference Include="System" />
3838
<Reference Include="System.Core" />
3939
<Reference Include="System.Numerics" />
40-
<Reference Include="System.Numerics.Vectors" />
40+
<Reference Include="System.Numerics.Vectors, Version=4.1.0.0, Culture=neutral, PublicKeyToken=b03f5f7f11d50a3a, processorArchitecture=MSIL">
41+
<HintPath>..\packages\System.Numerics.Vectors.4.1.0\lib\net46\System.Numerics.Vectors.dll</HintPath>
42+
<Private>True</Private>
43+
</Reference>
4144
<Reference Include="System.Xml.Linq" />
4245
<Reference Include="System.Data.DataSetExtensions" />
4346
<Reference Include="Microsoft.CSharp" />
@@ -70,6 +73,9 @@
7073
<Compile Include="Vector.cs" />
7174
<Compile Include="WordEMWrapFeaturizer.cs" />
7275
</ItemGroup>
76+
<ItemGroup>
77+
<None Include="packages.config" />
78+
</ItemGroup>
7379
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
7480
<!-- To modify your build process, add your task inside one of the targets below and uncomment it.
7581
Other similar extension points exist, see Microsoft.Common.targets.

RNNSharp/SimpleRNN.cs

Lines changed: 101 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Threading.Tasks;
33
using System.IO;
44
using AdvUtils;
5+
using System.Numerics;
56

67
/// <summary>
78
/// RNNSharp written by Zhongkai Fu ([email protected])
@@ -31,9 +32,13 @@ public class SimpleRNN : RNN
3132
protected Matrix<double> Feature2HiddenWeights { get; set; }
3233

3334
//The learning ratio of each weight
34-
protected Matrix<float> HiddenBpttWeightsLearningRate { get; set; }
35-
protected Matrix<float> Input2HiddenWeightsLearningRate { get; set; }
36-
protected Matrix<float> Feature2HiddenWeightsLearningRate { get; set; }
35+
protected Matrix<double> HiddenBpttWeightsLearningRate { get; set; }
36+
protected Matrix<double> Input2HiddenWeightsLearningRate { get; set; }
37+
protected Matrix<double> Feature2HiddenWeightsLearningRate { get; set; }
38+
39+
protected Vector<double> vecMaxGrad;
40+
protected Vector<double> vecMinGrad;
41+
protected Vector<double> vecNormalLearningRate;
3742

3843
public SimpleRNN()
3944
{
@@ -257,28 +262,55 @@ private void learnBptt(State state)
257262
//dense weight update fea->0
258263
double[] vector_a = null;
259264
double er = neuHidden.er[a];
265+
Vector<double> vecErr = new Vector<double>(er);
266+
267+
int i = 0;
260268
if (DenseFeatureSize > 0)
261269
{
262270
vector_a = mat_bptt_synf[a];
263-
for (int i = 0; i < DenseFeatureSize; i++)
271+
i = 0;
272+
while (i < DenseFeatureSize - Vector<double>.Count)
273+
{
274+
Vector<double> v1 = new Vector<double>(bptt_fea_step, i);
275+
Vector<double> v2 = new Vector<double>(vector_a, i);
276+
v2 += vecErr * v1;
277+
v2.CopyTo(vector_a, i);
278+
279+
i += Vector<double>.Count;
280+
}
281+
282+
while (i < DenseFeatureSize)
264283
{
265284
vector_a[i] += er * bptt_fea_step[i];
285+
i++;
266286
}
267287
}
268288

269289
//sparse weight update hidden->input
270290
vector_a = mat_bptt_syn0_w[a];
271-
for (int i = 0; i < sparse.Count; i++)
291+
for (i = 0; i < sparse.Count; i++)
272292
{
273293
var entry = sparse.GetEntry(i);
274294
vector_a[entry.Key] += er * entry.Value;
275295
}
276296

277297
//bptt weight update
278298
vector_a = mat_bptt_syn0_ph[a];
279-
for (int i = 0; i < L1; i++)
299+
i = 0;
300+
while (i < L1 - Vector<double>.Count)
301+
{
302+
Vector<double> v1 = new Vector<double>(neuLastHidden.cellOutput, i);
303+
Vector<double> v2 = new Vector<double>(vector_a, i);
304+
v2 += vecErr * v1;
305+
v2.CopyTo(vector_a, i);
306+
307+
i += Vector<double>.Count;
308+
}
309+
310+
while(i < L1)
280311
{
281312
vector_a[i] += er * neuLastHidden.cellOutput[i];
313+
i++;
282314
}
283315

284316
});
@@ -308,33 +340,85 @@ private void learnBptt(State state)
308340
{
309341
double[] vector_b = null;
310342
double[] vector_bf = null;
343+
double[] vector_lr = null;
311344

312345
//Update bptt feature weights
313346
vector_b = HiddenBpttWeights[b];
314347
vector_bf = mat_bptt_syn0_ph[b];
315-
for (int i = 0; i < L1; i++)
348+
vector_lr = HiddenBpttWeightsLearningRate[b];
349+
350+
int i = 0;
351+
while (i < L1 - Vector<double>.Count)
352+
{
353+
Vector<double> vecDelta = new Vector<double>(vector_bf, i);
354+
Vector<double> vecLearningRate = new Vector<double>(vector_lr, i);
355+
Vector<double> vecB = new Vector<double>(vector_b, i);
356+
vecDelta = Vector.Min<double>(vecDelta, vecMaxGrad);
357+
vecDelta = Vector.Max<double>(vecDelta, vecMinGrad);
358+
359+
vecLearningRate += (vecDelta * vecDelta);
360+
vecLearningRate.CopyTo(vector_lr, i);
361+
vecLearningRate = vecNormalLearningRate / (Vector<double>.One + Vector.SquareRoot<double>(vecLearningRate));
362+
363+
vecB += (vecLearningRate * vecDelta);
364+
vecB.CopyTo(vector_b, i);
365+
366+
Vector<double>.Zero.CopyTo(vector_bf, i);
367+
368+
i += Vector<double>.Count;
369+
}
370+
371+
while (i < L1)
316372
{
317373
double delta = NormalizeGradient(vector_bf[i]);
318374
double newLearningRate = UpdateLearningRate(HiddenBpttWeightsLearningRate, b, i, delta);
319375

320376
vector_b[i] += newLearningRate * delta;
321377
//Clean bptt weight error
322378
vector_bf[i] = 0;
379+
380+
i++;
323381
}
324382

325383
//Update dense feature weights
326384
if (DenseFeatureSize > 0)
327385
{
328386
vector_b = Feature2HiddenWeights[b];
329387
vector_bf = mat_bptt_synf[b];
330-
for (int i = 0; i < DenseFeatureSize; i++)
388+
vector_lr = Feature2HiddenWeightsLearningRate[b];
389+
390+
i = 0;
391+
while (i < DenseFeatureSize - Vector<double>.Count)
392+
{
393+
Vector<double> vecDelta = new Vector<double>(vector_bf, i);
394+
Vector<double> vecLearningRate = new Vector<double>(vector_lr, i);
395+
Vector<double> vecB = new Vector<double>(vector_b, i);
396+
vecDelta = Vector.Min<double>(vecDelta, vecMaxGrad);
397+
vecDelta = Vector.Max<double>(vecDelta, vecMinGrad);
398+
399+
vecLearningRate += (vecDelta * vecDelta);
400+
vecLearningRate.CopyTo(vector_lr, i);
401+
vecLearningRate = vecNormalLearningRate / (Vector<double>.One + Vector.SquareRoot<double>(vecLearningRate));
402+
403+
vecB += (vecLearningRate * vecDelta);
404+
vecB.CopyTo(vector_b, i);
405+
406+
vecDelta = Vector<double>.Zero;
407+
vecDelta.CopyTo(vector_bf, i);
408+
409+
i += Vector<double>.Count;
410+
}
411+
412+
while (i < DenseFeatureSize)
331413
{
332414
double delta = NormalizeGradient(vector_bf[i]);
333415
double newLearningRate = UpdateLearningRate(Feature2HiddenWeightsLearningRate, b, i, delta);
334416

335417
vector_b[i] += newLearningRate * delta;
336418
//Clean dense feature weights error
337419
vector_bf[i] = 0;
420+
421+
i++;
338422
}
339423
}
340424

@@ -347,7 +431,7 @@ private void learnBptt(State state)
347431
if (sparse == null)
348432
break;
349433

350-
for (int i = 0; i < sparse.Count; i++)
434+
for (i = 0; i < sparse.Count; i++)
351435
{
352436
int pos = sparse.GetEntry(i).Key;
353437

@@ -387,10 +471,14 @@ public void resetBpttMem()
387471

388472
public override void CleanStatus()
389473
{
390-
Hidden2OutputWeightLearningRate = new Matrix<float>(L2, L1);
391-
Input2HiddenWeightsLearningRate = new Matrix<float>(L1, L0);
392-
Feature2HiddenWeightsLearningRate = new Matrix<float>(L1, DenseFeatureSize);
393-
HiddenBpttWeightsLearningRate = new Matrix<float>(L1, L1);
474+
Hidden2OutputWeightLearningRate = new Matrix<double>(L2, L1);
475+
Input2HiddenWeightsLearningRate = new Matrix<double>(L1, L0);
476+
Feature2HiddenWeightsLearningRate = new Matrix<double>(L1, DenseFeatureSize);
477+
HiddenBpttWeightsLearningRate = new Matrix<double>(L1, L1);
478+
479+
vecMaxGrad = new Vector<double>(GradientCutoff);
480+
vecMinGrad = new Vector<double>(-GradientCutoff);
481+
vecNormalLearningRate = new Vector<double>(LearningRate);
394482
}
395483
public override void InitMem()
396484
{

RNNSharp/packages.config

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<packages>
3+
<package id="System.Globalization" version="4.0.10" targetFramework="net46" />
4+
<package id="System.Numerics.Vectors" version="4.1.0" targetFramework="net46" />
5+
<package id="System.Resources.ResourceManager" version="4.0.0" targetFramework="net46" />
6+
<package id="System.Runtime" version="4.0.20" targetFramework="net46" />
7+
<package id="System.Runtime.Extensions" version="4.0.10" targetFramework="net46" />
8+
</packages>

0 commit comments

Comments (0)