Skip to content

Commit

Permalink
Refactored parts of NN code (liamt19#70)
Browse files Browse the repository at this point in the history
bench: 5142788
  • Loading branch information
liamt19 authored Aug 31, 2024
1 parent 2a95bad commit 4d426b7
Show file tree
Hide file tree
Showing 3 changed files with 637 additions and 498 deletions.
87 changes: 38 additions & 49 deletions Logic/NN/Bucketed768.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,25 +95,11 @@ public static void Initialize(string networkToLoad, bool exitIfFail = true)
}
}

for (int i = 0; i < FeatureWeightElements; i++)
{
FeatureWeights[i] = br.ReadInt16();
}
for (int i = 0; i < FeatureWeightElements; i++) FeatureWeights[i] = br.ReadInt16();
for (int i = 0; i < FeatureBiasElements; i++) FeatureBiases[i] = br.ReadInt16();

for (int i = 0; i < FeatureBiasElements; i++)
{
FeatureBiases[i] = br.ReadInt16();
}

for (int i = 0; i < LayerWeightElements; i++)
{
LayerWeights[i] = br.ReadInt16();
}

for (int i = 0; i < LayerBiasElements; i++)
{
LayerBiases[i] = br.ReadInt16();
}
for (int i = 0; i < LayerWeightElements; i++) LayerWeights[i] = br.ReadInt16();
for (int i = 0; i < LayerBiasElements; i++) LayerBiases[i] = br.ReadInt16();

// These weights are stored in column major order, but they are easier to use in row major order.
// The first 8 weights in the binary file are actually the first weight for each of the 8 output buckets,
Expand All @@ -125,7 +111,7 @@ public static void Initialize(string networkToLoad, bool exitIfFail = true)
NetStats("ft bias\t", FeatureBiases, FeatureBiasElements);

NetStats("fc weight", LayerWeights, LayerWeightElements);
NetStats("fc bias", LayerBiases, LayerBiasElements);
NetStats("fc bias\t", LayerBiases, LayerBiasElements);

Log("Init Bucketed768 done");
#endif
Expand All @@ -143,7 +129,7 @@ public static void RefreshAccumulatorPerspectiveFull(Position pos, int perspecti
ref Bitboard bb = ref pos.bb;

var ourAccumulation = (short*)accumulator[perspective];
Unsafe.CopyBlock(ourAccumulation, FeatureBiases, sizeof(short) * HiddenSize);
Unsafe.CopyBlock(ourAccumulation, FeatureBiases, Accumulator.ByteSize);
accumulator.NeedsRefresh[perspective] = false;
accumulator.Computed[perspective] = true;

Expand All @@ -157,7 +143,7 @@ public static void RefreshAccumulatorPerspectiveFull(Position pos, int perspecti
int pc = bb.GetColorAtIndex(pieceIdx);

int idx = FeatureIndexSingle(pc, pt, pieceIdx, ourKing, perspective);
UnrollAdd(ourAccumulation, ourAccumulation, FeatureWeights + idx);
UnrollAdd(ourAccumulation, ourAccumulation, &FeatureWeights[idx]);
}

if (pos.Owner.CachedBuckets == null)
Expand Down Expand Up @@ -203,14 +189,14 @@ public static void RefreshAccumulatorPerspective(Position pos, int perspective)
{
int sq = poplsb(&added);
int idx = FeatureIndexSingle(pc, pt, sq, ourKing, perspective);
UnrollAdd(ourAccumulation, ourAccumulation, FeatureWeights + idx);
UnrollAdd(ourAccumulation, ourAccumulation, &FeatureWeights[idx]);
}

while (removed != 0)
{
int sq = poplsb(&removed);
int idx = FeatureIndexSingle(pc, pt, sq, ourKing, perspective);
UnrollSubtract(ourAccumulation, ourAccumulation, FeatureWeights + idx);
UnrollSubtract(ourAccumulation, ourAccumulation, &FeatureWeights[idx]);
}
}
}
Expand All @@ -236,31 +222,31 @@ public static int GetEvaluation(Position pos)
int occ = (int)popcount(pos.bb.Occupancy);
int outputBucket = Math.Min((63 - occ) * (32 - occ) / 225, 7);

var ourData = (accumulator[pos.ToMove]);
var theirData = (accumulator[Not(pos.ToMove)]);
var ourWeights = (Vector256<short>*)(LayerWeights + (outputBucket * (HiddenSize * 2)));
var theirWeights = (Vector256<short>*)(LayerWeights + (outputBucket * (HiddenSize * 2)) + HiddenSize);
var ourData = accumulator[pos.ToMove];
var theirData = accumulator[Not(pos.ToMove)];
var ourWeights = (Vector256<short>*)(&LayerWeights[outputBucket * (HiddenSize * 2)]);
var theirWeights = (Vector256<short>*)(&LayerWeights[outputBucket * (HiddenSize * 2) + HiddenSize]);

for (int i = 0; i < SimdChunks; i++)
{
Vector256<short> clamp = Vector256.Min(maxVec, Vector256.Max(zeroVec, ourData[i]));
Vector256<short> mult = clamp * ourWeights[i];

(var loMult, var hiMult) = Vector256.Widen(mult);
(var loClamp, var hiClamp) = Vector256.Widen(clamp);
(var mLo, var mHi) = Vector256.Widen(mult);
(var cLo, var cHi) = Vector256.Widen(clamp);

sum = Vector256.Add(sum, Vector256.Add(loMult * loClamp, hiMult * hiClamp));
sum = Vector256.Add(sum, Vector256.Add(mLo * cLo, mHi * cHi));
}

for (int i = 0; i < SimdChunks; i++)
{
Vector256<short> clamp = Vector256.Min(maxVec, Vector256.Max(zeroVec, theirData[i]));
Vector256<short> mult = clamp * theirWeights[i];

(var loMult, var hiMult) = Vector256.Widen(mult);
(var loClamp, var hiClamp) = Vector256.Widen(clamp);
(var mLo, var mHi) = Vector256.Widen(mult);
(var cLo, var cHi) = Vector256.Widen(clamp);

sum = Vector256.Add(sum, Vector256.Add(loMult * loClamp, hiMult * hiClamp));
sum = Vector256.Add(sum, Vector256.Add(mLo * cLo, mHi * cHi));
}

int output = Vector256.Sum(sum);
Expand Down Expand Up @@ -311,8 +297,8 @@ private static (int, int) FeatureIndex(int pc, int pt, int sq, int wk, int bk)
bSq ^= 7;
}

int whiteIndex = (768 * KingBuckets[wk]) + (pc * ColorStride) + (pt * PieceStride) + (wSq);
int blackIndex = (768 * KingBuckets[bk]) + (Not(pc) * ColorStride) + (pt * PieceStride) + (bSq);
int whiteIndex = (768 * KingBuckets[wk]) + ( pc * ColorStride) + (pt * PieceStride) + wSq;
int blackIndex = (768 * KingBuckets[bk]) + (Not(pc) * ColorStride) + (pt * PieceStride) + bSq;

return (whiteIndex * HiddenSize, blackIndex * HiddenSize);
}
Expand Down Expand Up @@ -428,12 +414,15 @@ public static void MakeMove(Position pos, Move m)
[MethodImpl(Inline)]
public static void MakeNullMove(Position pos)
{
pos.State->Accumulator->CopyTo(pos.NextState->Accumulator);
var currAcc = pos.State->Accumulator;
var nextAcc = pos.NextState->Accumulator;

currAcc->CopyTo(nextAcc);

pos.NextState->Accumulator->Computed[White] = pos.State->Accumulator->Computed[White];
pos.NextState->Accumulator->Computed[Black] = pos.State->Accumulator->Computed[Black];
pos.NextState->Accumulator->Update[White].Clear();
pos.NextState->Accumulator->Update[Black].Clear();
nextAcc->Computed[White] = currAcc->Computed[White];
nextAcc->Computed[Black] = currAcc->Computed[Black];
nextAcc->Update[White].Clear();
nextAcc->Update[Black].Clear();
}


Expand Down Expand Up @@ -499,23 +488,23 @@ public static void UpdateSingle(Accumulator* prev, Accumulator* curr, int perspe
if (updates.AddCnt == 1 && updates.SubCnt == 1)
{
SubAdd(src, dst,
(FeatureWeights + updates.Subs[0]),
(FeatureWeights + updates.Adds[0]));
&FeatureWeights[updates.Subs[0]],
&FeatureWeights[updates.Adds[0]]);
}
else if (updates.AddCnt == 1 && updates.SubCnt == 2)
{
SubSubAdd(src, dst,
(FeatureWeights + updates.Subs[0]),
(FeatureWeights + updates.Subs[1]),
(FeatureWeights + updates.Adds[0]));
&FeatureWeights[updates.Subs[0]],
&FeatureWeights[updates.Subs[1]],
&FeatureWeights[updates.Adds[0]]);
}
else if (updates.AddCnt == 2 && updates.SubCnt == 2)
{
SubSubAddAdd(src, dst,
(FeatureWeights + updates.Subs[0]),
(FeatureWeights + updates.Subs[1]),
(FeatureWeights + updates.Adds[0]),
(FeatureWeights + updates.Adds[1]));
&FeatureWeights[updates.Subs[0]],
&FeatureWeights[updates.Subs[1]],
&FeatureWeights[updates.Adds[0]],
&FeatureWeights[updates.Adds[1]]);
}

curr->Computed[perspective] = true;
Expand Down
Loading

0 comments on commit 4d426b7

Please sign in to comment.