|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Implements tokenization of FEN strings.""" |
|
|
|
|
|
import jaxtyping as jtp |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
# Token vocabulary: a character's position in this list is its token id.
# The order must never change, otherwise previously produced tokens would map
# to different characters.
_CHARACTERS = list(
    '0123456789'  # Digits (move counters, en passant rank).
    'abcdefgh'  # File letters; 'b' doubles as black bishop / black to move.
    'pnrkq'  # Remaining lowercase piece letters ('k'/'q' also in castling).
    'PBNRQK'  # Uppercase piece letters ('K'/'Q' also in castling).
    'w'  # White to move.
    '.'  # Empty square, also used as the padding token.
)

# Reverse lookup: character -> token id.
_CHARACTERS_INDEX = {char: token for token, char in enumerate(_CHARACTERS)}

# Board digits that FEN uses to encode a run of that many empty squares.
_SPACES_CHARACTERS = frozenset('12345678')

# Length of every tokenized position:
# 1 (side) + 64 (board) + 4 (castling) + 2 (en passant)
# + 3 (halfmove clock) + 3 (fullmove number).
SEQUENCE_LENGTH = 77
|
|
|
|
|
|
|
|
def tokenize(fen: str) -> jtp.UInt8[np.ndarray, 'T']:
    """Returns the fixed-length array of tokens for a FEN string.

    Every character of the (normalized) FEN is mapped to its id in
    `_CHARACTERS_INDEX`. The fields are laid out in this order:

      * side to move (1 token), placed in front of the board;
      * board: the rank separators '/' are dropped and every digit 1-8 is
        expanded into that many '.' tokens (64 tokens for a full board);
      * castling rights (4 tokens): right-padded with '.', and '-' (no
        castling rights) becomes '....';
      * en passant square (2 tokens): '-' (no en passant square) becomes
        '..', so that the field always has two characters;
      * halfmove clock (3 tokens), right-padded with '.';
      * fullmove number (3 tokens), right-padded with '.'.

    Args:
      fen: The board position in Forsyth-Edwards Notation. The six fields may
        be separated (and surrounded) by any amount of whitespace.

    Returns:
      A 1-D `np.uint8` array holding `SEQUENCE_LENGTH` token ids.

    Raises:
      ValueError: If `fen` does not consist of exactly six fields, or if the
        tokenized position does not have exactly `SEQUENCE_LENGTH` tokens
        (e.g. a rank with missing squares or a counter above 999).
      KeyError: If a field contains a character that is not part of the
        vocabulary.
    """
    # `split()` instead of `split(' ')`: identical for a well-formed FEN, but
    # also accepts a trailing newline or repeated separators.
    board, side, castling, en_passant, halfmoves_last, fullmoves = fen.split()

    empty_token = _CHARACTERS_INDEX['.']

    # Side to move goes first, followed by the 8 ranks without separators.
    indices = []
    for char in side + board.replace('/', ''):
        if char in _SPACES_CHARACTERS:
            # A digit stands for that many consecutive empty squares.
            indices.extend([empty_token] * int(char))
        else:
            indices.append(_CHARACTERS_INDEX[char])

    # Remaining fields, each brought to its fixed width with '.' padding.
    # `ljust` never truncates, so oversized fields are caught by the length
    # check below rather than silently cut.
    fixed_width_fields = (
        '....' if castling == '-' else castling.ljust(4, '.'),
        '..' if en_passant == '-' else en_passant,
        halfmoves_last.ljust(3, '.'),
        fullmoves.ljust(3, '.'),
    )
    for field in fixed_width_fields:
        indices.extend(_CHARACTERS_INDEX[char] for char in field)

    # Explicit check rather than `assert`, so that malformed input is still
    # rejected when Python runs with assertions disabled (`-O`).
    if len(indices) != SEQUENCE_LENGTH:
        raise ValueError(
            f'Tokenized FEN has {len(indices)} tokens, expected '
            f'{SEQUENCE_LENGTH}: {fen!r}'
        )

    # All token ids are below len(_CHARACTERS), so they fit in uint8.
    return np.asarray(indices, dtype=np.uint8)
|
|
|