Skip to content

Commit

Permalink
New network with selfplay data (liamt19#65)
Browse files Browse the repository at this point in the history
bench: 4743305
  • Loading branch information
liamt19 authored Aug 20, 2024
1 parent 5cc2260 commit 3416a7b
Show file tree
Hide file tree
Showing 17 changed files with 436 additions and 199 deletions.
17 changes: 15 additions & 2 deletions Logic/Core/Position.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public unsafe partial class Position
public bool Checked => popcount(State->Checkers) != 0;

public ulong Hash => State->Hash;

public ulong PawnHash => State->PawnHash;

/// <summary>
/// The number of <see cref="StateInfo"/> items that memory will be allocated for within the StateStack, which is 256 KB.
Expand Down Expand Up @@ -88,6 +88,9 @@ public bool CanCastle(ulong boardOcc, ulong ourOcc, CastlingStatus cr)
[MethodImpl(Inline)]
public bool HasNonPawnMaterial(int pc) => (((bb.Occupancy ^ bb.Pieces[Pawn] ^ bb.Pieces[King]) & bb.Colors[pc]) != 0);

[MethodImpl(Inline)]
public bool IsCapture(Move m) => ((bb.GetPieceAtIndex(m.To) != None && !m.IsCastle) || m.IsEnPassant);


/// <summary>
/// Creates a new Position object and loads the provided FEN.
Expand Down Expand Up @@ -293,6 +296,10 @@ public void MakeMove(Move move)
// If we are capturing a rook, make sure that if that we remove that castling status from them if necessary.
RemoveCastling(GetCastlingForRook(moveTo));
}
else if (theirPiece == Pawn)
{
State->PawnHash.ZobristToggleSquare(theirColor, Pawn, moveTo);
}

// Reset the halfmove clock
State->HalfmoveClock = 0;
Expand All @@ -316,6 +323,7 @@ public void MakeMove(Move move)
int idxPawn = ((bb.Pieces[Pawn] & SquareBB[tempEPSquare - 8]) != 0) ? tempEPSquare - 8 : tempEPSquare + 8;
bb.RemovePiece(idxPawn, theirColor, Pawn);
State->Hash.ZobristToggleSquare(theirColor, Pawn, idxPawn);
State->PawnHash.ZobristToggleSquare(theirColor, Pawn, idxPawn);

// The EnPassant/Capture flags are mutually exclusive, so set CapturedPiece here
State->CapturedPiece = Pawn;
Expand All @@ -332,6 +340,8 @@ public void MakeMove(Move move)
}
}

State->PawnHash.ZobristMove(moveFrom, moveTo, ourColor, ourPiece);

// Reset the halfmove clock
State->HalfmoveClock = 0;
}
Expand All @@ -352,6 +362,8 @@ public void MakeMove(Move move)

State->Hash.ZobristToggleSquare(ourColor, ourPiece, moveTo);
State->Hash.ZobristToggleSquare(ourColor, move.PromotionTo, moveTo);

State->PawnHash.ZobristToggleSquare(ourColor, ourPiece, moveTo);
}

State->Hash.ZobristChangeToMove();
Expand Down Expand Up @@ -511,7 +523,8 @@ public void SetState()

SetCheckInfo();

State->Hash = Zobrist.GetHash(this);
State->PawnHash = 0;
State->Hash = Zobrist.GetHash(this, &State->PawnHash);
}


Expand Down
17 changes: 9 additions & 8 deletions Logic/Core/StateInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ public unsafe struct StateInfo
[FieldOffset( 64)] public fixed ulong Pinners[2];
[FieldOffset( 80)] public fixed int KingSquares[2];
[FieldOffset( 88)] public ulong Hash = 0;
[FieldOffset( 96)] public ulong Checkers = 0;
[FieldOffset(104)] public int HalfmoveClock = 0;
[FieldOffset(108)] public int EPSquare = EPNone;
[FieldOffset(112)] public int CapturedPiece = None;
[FieldOffset(116)] public int PliesFromNull = 0;
[FieldOffset(120)] public CastlingStatus CastleStatus = CastlingStatus.None;
[FieldOffset(124)] private fixed byte _pad0[4];
[FieldOffset(128)] public Accumulator* Accumulator;
[FieldOffset( 96)] public ulong PawnHash = 0;
[FieldOffset(104)] public ulong Checkers = 0;
[FieldOffset(112)] public int HalfmoveClock = 0;
[FieldOffset(116)] public int EPSquare = EPNone;
[FieldOffset(120)] public int CapturedPiece = None;
[FieldOffset(124)] public int PliesFromNull = 0;
[FieldOffset(128)] public CastlingStatus CastleStatus = CastlingStatus.None;
[FieldOffset(132)] private fixed byte _pad0[4];
[FieldOffset(136)] public Accumulator* Accumulator;

public StateInfo()
{
Expand Down
52 changes: 33 additions & 19 deletions Logic/NN/Bucketed768.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Lizard.Logic.NN
{
public static unsafe partial class Bucketed768
{
public const int InputBuckets = 5;
public const int InputBuckets = 1;
public const int InputSize = 768;
public const int HiddenSize = 1536;
public const int OutputBuckets = 8;
Expand All @@ -22,9 +22,9 @@ public static unsafe partial class Bucketed768
public const int OutputScale = 400;

/// <summary>
/// (768x5 -> 1536)x2 -> 8
/// (768 -> 1536)x2 -> 8
/// </summary>
public const string NetworkName = "L1536x5x8_cos51_from315_dfrc08b-680.bin";
public const string NetworkName = "net-009-250.bin";

public static readonly short* FeatureWeights;
public static readonly short* FeatureBiases;
Expand All @@ -41,14 +41,14 @@ public static unsafe partial class Bucketed768

private static ReadOnlySpan<int> KingBuckets =>
[
0, 0, 1, 1, 6, 6, 5, 5,
2, 2, 3, 3, 8, 8, 7, 7,
4, 4, 4, 4, 9, 9, 9, 9,
4, 4, 4, 4, 9, 9, 9, 9,
4, 4, 4, 4, 9, 9, 9, 9,
4, 4, 4, 4, 9, 9, 9, 9,
4, 4, 4, 4, 9, 9, 9, 9,
4, 4, 4, 4, 9, 9, 9, 9,
0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 1, 1, 1, 1,
];

public static int BucketForPerspective(int ksq, int perspective) => (KingBuckets[perspective == Black ? (ksq ^ 56) : ksq]);
Expand Down Expand Up @@ -157,7 +157,6 @@ public static void RefreshAccumulatorPerspectiveFull(Position pos, int perspecti
int pc = bb.GetColorAtIndex(pieceIdx);

int idx = FeatureIndexSingle(pc, pt, pieceIdx, ourKing, perspective);
var ourWeights = (Vector512<short>*)(FeatureWeights + idx);
UnrollAdd(ourAccumulation, ourAccumulation, FeatureWeights + idx);
}

Expand Down Expand Up @@ -389,24 +388,39 @@ public static void MakeMove(Position pos, Move m)
(int wFrom, int bFrom) = FeatureIndex(us, ourPiece, moveFrom, wKing, bKing);
(int wTo, int bTo) = FeatureIndex(us, m.IsPromotion ? m.PromotionTo : ourPiece, moveTo, wKing, bKing);

wUpdate.PushSubAdd(wFrom, wTo);
bUpdate.PushSubAdd(bFrom, bTo);
if (m.IsCastle)
{
int rookFromSq = moveTo;
int rookToSq = m.CastlingRookSquare();

(wTo, bTo) = FeatureIndex(us, ourPiece, m.CastlingKingSquare(), wKing, bKing);

(int wRookFrom, int bRookFrom) = FeatureIndex(us, Rook, rookFromSq, wKing, bKing);
(int wRookTo, int bRookTo) = FeatureIndex(us, Rook, rookToSq, wKing, bKing);

if (theirPiece != None)
wUpdate.PushSubSubAddAdd(wFrom, wRookFrom, wTo, wRookTo);
bUpdate.PushSubSubAddAdd(bFrom, bRookFrom, bTo, bRookTo);
}
else if (theirPiece != None)
{
(int wCap, int bCap) = FeatureIndex(them, theirPiece, moveTo, wKing, bKing);

wUpdate.PushSub(wCap);
bUpdate.PushSub(bCap);
wUpdate.PushSubSubAdd(wFrom, wCap, wTo);
bUpdate.PushSubSubAdd(bFrom, bCap, bTo);
}
else if (m.IsEnPassant)
{
int idxPawn = moveTo - ShiftUpDir(us);

(int wCap, int bCap) = FeatureIndex(them, Pawn, idxPawn, wKing, bKing);

wUpdate.PushSub(wCap);
bUpdate.PushSub(bCap);
wUpdate.PushSubSubAdd(wFrom, wCap, wTo);
bUpdate.PushSubSubAdd(bFrom, bCap, bTo);
}
else
{
wUpdate.PushSubAdd(wFrom, wTo);
bUpdate.PushSubAdd(bFrom, bTo);
}
}
}
Expand Down
39 changes: 22 additions & 17 deletions Logic/NN/Bucketed768Unroll.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ public static unsafe partial class Bucketed768

private const int StopBefore = HiddenSize / N;

private const int AVX512_1024HL = 1024 / 32;
private const int AVX512_1536HL = 1536 / 32;

private const int AVX256_1024HL = 1024 / 16;
private const int AVX256_1536HL = 1536 / 16;

public static int GetEvaluationUnrolled512(Position pos)
{
Expand Down Expand Up @@ -69,6 +64,8 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_6, SIMDClass.MultiplyLow(c_us_6, VectorT.LoadAligned(ourWeights + 6 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_7, SIMDClass.MultiplyLow(c_us_7, VectorT.LoadAligned(ourWeights + 7 * N))));

if (StopBefore == 8) goto NSTM;

var c_us_8 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 8 * N)));
var c_us_9 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 9 * N)));
var c_us_10 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 10 * N)));
Expand All @@ -86,6 +83,8 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_14, SIMDClass.MultiplyLow(c_us_14, VectorT.LoadAligned(ourWeights + 14 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_15, SIMDClass.MultiplyLow(c_us_15, VectorT.LoadAligned(ourWeights + 15 * N))));

if (StopBefore == 16) goto NSTM;

var c_us_16 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 16 * N)));
var c_us_17 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 17 * N)));
var c_us_18 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 18 * N)));
Expand Down Expand Up @@ -120,8 +119,7 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_30, SIMDClass.MultiplyLow(c_us_30, VectorT.LoadAligned(ourWeights + 30 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_31, SIMDClass.MultiplyLow(c_us_31, VectorT.LoadAligned(ourWeights + 31 * N))));

if (StopBefore == AVX512_1024HL)
goto NSTM;
if (StopBefore == 32) goto NSTM;

var c_us_32 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 32 * N)));
var c_us_33 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 33 * N)));
Expand Down Expand Up @@ -157,8 +155,7 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_46, SIMDClass.MultiplyLow(c_us_46, VectorT.LoadAligned(ourWeights + 46 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_47, SIMDClass.MultiplyLow(c_us_47, VectorT.LoadAligned(ourWeights + 47 * N))));

if (StopBefore == AVX512_1536HL)
goto NSTM;
if (StopBefore == 48) goto NSTM;

var c_us_48 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 48 * N)));
var c_us_49 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 49 * N)));
Expand Down Expand Up @@ -194,8 +191,7 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_62, SIMDClass.MultiplyLow(c_us_62, VectorT.LoadAligned(ourWeights + 62 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_63, SIMDClass.MultiplyLow(c_us_63, VectorT.LoadAligned(ourWeights + 63 * N))));

if (StopBefore == AVX256_1024HL)
goto NSTM;
if (StopBefore == 64) goto NSTM;

var c_us_64 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 64 * N)));
var c_us_65 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 65 * N)));
Expand Down Expand Up @@ -231,6 +227,8 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_78, SIMDClass.MultiplyLow(c_us_78, VectorT.LoadAligned(ourWeights + 78 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_79, SIMDClass.MultiplyLow(c_us_79, VectorT.LoadAligned(ourWeights + 79 * N))));

if (StopBefore == 80) goto NSTM;

var c_us_80 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 80 * N)));
var c_us_81 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 81 * N)));
var c_us_82 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(ourData + 82 * N)));
Expand Down Expand Up @@ -265,6 +263,8 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_94, SIMDClass.MultiplyLow(c_us_94, VectorT.LoadAligned(ourWeights + 94 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_us_95, SIMDClass.MultiplyLow(c_us_95, VectorT.LoadAligned(ourWeights + 95 * N))));

if (StopBefore == 96) goto NSTM;

#endregion


Expand All @@ -290,6 +290,8 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_6, SIMDClass.MultiplyLow(c_them_6, VectorT.LoadAligned(theirWeights + 6 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_7, SIMDClass.MultiplyLow(c_them_7, VectorT.LoadAligned(theirWeights + 7 * N))));

if (StopBefore == 8) goto END;

var c_them_8 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 8 * N)));
var c_them_9 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 9 * N)));
var c_them_10 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 10 * N)));
Expand All @@ -307,6 +309,8 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_14, SIMDClass.MultiplyLow(c_them_14, VectorT.LoadAligned(theirWeights + 14 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_15, SIMDClass.MultiplyLow(c_them_15, VectorT.LoadAligned(theirWeights + 15 * N))));

if (StopBefore == 16) goto END;

var c_them_16 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 16 * N)));
var c_them_17 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 17 * N)));
var c_them_18 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 18 * N)));
Expand Down Expand Up @@ -341,8 +345,7 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_30, SIMDClass.MultiplyLow(c_them_30, VectorT.LoadAligned(theirWeights + 30 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_31, SIMDClass.MultiplyLow(c_them_31, VectorT.LoadAligned(theirWeights + 31 * N))));

if (StopBefore == AVX512_1024HL)
goto END;
if (StopBefore == 32) goto END;

var c_them_32 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 32 * N)));
var c_them_33 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 33 * N)));
Expand Down Expand Up @@ -378,8 +381,7 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_46, SIMDClass.MultiplyLow(c_them_46, VectorT.LoadAligned(theirWeights + 46 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_47, SIMDClass.MultiplyLow(c_them_47, VectorT.LoadAligned(theirWeights + 47 * N))));

if (StopBefore == AVX512_1536HL)
goto END;
if (StopBefore == 48) goto END;

var c_them_48 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 48 * N)));
var c_them_49 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 49 * N)));
Expand Down Expand Up @@ -415,8 +417,7 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_62, SIMDClass.MultiplyLow(c_them_62, VectorT.LoadAligned(theirWeights + 62 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_63, SIMDClass.MultiplyLow(c_them_63, VectorT.LoadAligned(theirWeights + 63 * N))));

if (StopBefore == AVX256_1024HL)
goto END;
if (StopBefore == 64) goto END;

var c_them_64 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 64 * N)));
var c_them_65 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 65 * N)));
Expand Down Expand Up @@ -452,6 +453,8 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_78, SIMDClass.MultiplyLow(c_them_78, VectorT.LoadAligned(theirWeights + 78 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_79, SIMDClass.MultiplyLow(c_them_79, VectorT.LoadAligned(theirWeights + 79 * N))));

if (StopBefore == 80) goto END;

var c_them_80 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 80 * N)));
var c_them_81 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 81 * N)));
var c_them_82 = VectorT.Min(maxVec, VectorT.Max(zeroVec, VectorT.LoadAligned(theirData + 82 * N)));
Expand Down Expand Up @@ -486,6 +489,8 @@ public static int GetEvaluationUnrolled512(Position pos)
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_94, SIMDClass.MultiplyLow(c_them_94, VectorT.LoadAligned(theirWeights + 94 * N))));
sumVec = VectorT.Add(sumVec, SIMDClass.MultiplyAddAdjacent(c_them_95, SIMDClass.MultiplyLow(c_them_95, VectorT.LoadAligned(theirWeights + 95 * N))));

if (StopBefore == 96) goto END;

#endregion


Expand Down
Loading

0 comments on commit 3416a7b

Please sign in to comment.