From c373d9b9c357ca927dfaf8cffdf3a1a72d0d213f Mon Sep 17 00:00:00 2001
From: Shantanu Jain
Date: Wed, 7 Jun 2023 15:34:08 -0700
Subject: [PATCH] Sync codebase

---
 .gitignore               |   1 +
 Cargo.toml               |   8 +-
 README.md                |  27 +++++
 pyproject.toml           |   2 +-
 tests/__init__.py        |   0
 tests/test_encoding.py   | 231 +++++++++++++++++++++++++++++++++++++++
 tests/test_helpers.py    |  22 ++++
 tests/test_misc.py       |  24 ++++
 tests/test_offsets.py    |  79 +++++++++++++
 tiktoken/_educational.py | 210 +++++++++++++++++++++++++++++++++++
 10 files changed, 599 insertions(+), 5 deletions(-)
 create mode 100644 tests/__init__.py
 create mode 100644 tests/test_encoding.py
 create mode 100644 tests/test_helpers.py
 create mode 100644 tests/test_misc.py
 create mode 100644 tests/test_offsets.py
 create mode 100644 tiktoken/_educational.py

diff --git a/.gitignore b/.gitignore
index 9e090c8e..68cdf7ff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,6 +33,7 @@ MANIFEST
 # Tools
 .mypy_cache
 .coverage
+.hypothesis
 htmlcov
 
 # General
diff --git a/Cargo.toml b/Cargo.toml
index ff5ef628..948b9f13 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,13 +9,13 @@ name = "_tiktoken"
 crate-type = ["cdylib"]
 
 [dependencies]
-pyo3 = { version = "0.17.3", features = ["extension-module"] }
+pyo3 = { version = "0.19.0", features = ["extension-module"] }
 
 # tiktoken dependencies
-fancy-regex = "0.10.0"
-regex = "1.7.0"
+fancy-regex = "0.11.0"
+regex = "1.8.3"
 rustc-hash = "1.1.0"
-bstr = "1.0.1"
+bstr = "1.5.0"
 
 [profile.release]
 incremental = true
diff --git a/README.md b/README.md
index 10088162..1a76a2c0 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,43 @@ Please post questions in the [issue tracker](https://github.com/openai/tiktoken/
 If you work at OpenAI, make sure to check the internal documentation or feel free to
 contact @shantanu.
 
+## What is BPE anyway?
+
+Models don't see text like you and me; instead, they see a sequence of numbers (known as tokens).
+Byte pair encoding (BPE) is a way of converting text into tokens. It has a couple of desirable
+properties:
+1) It's reversible and lossless, so you can convert tokens back into the original text.
+2) It works on arbitrary text, even text that is not in the tokeniser's training data.
+3) It compresses the text: the token sequence is shorter than the bytes corresponding to the
+   original text. On average, in practice, each token corresponds to about 4 bytes.
+4) It attempts to let the model see common subwords. For instance, "ing" is a common subword in
+   English, so BPE encodings will often split "encoding" into tokens like "encod" and "ing"
+   (instead of e.g. "enc" and "oding"). Because the model will then see the "ing" token again and
+   again in different contexts, it helps models generalise and better understand grammar.
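+
+For example, a quick way to check properties 1 and 3 for yourself:
+```python
+import tiktoken
+
+enc = tiktoken.get_encoding("cl100k_base")
+tokens = enc.encode("hello world")
+assert enc.decode(tokens) == "hello world"  # property 1: lossless round trip
+assert len(tokens) < len("hello world".encode("utf-8"))  # property 3: fewer tokens than bytes
+```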
+ +`tiktoken` contains an educational submodule that is friendlier if you want to learn more about +the details of BPE, including code that helps visualise the BPE procedure: +```python +from tiktoken._educational import * + +# Train a BPE tokeniser on a small amount of text +enc = train_simple_encoding() + +# Visualise how the GPT-4 encoder encodes text +enc = SimpleBytePairEncoding.from_tiktoken("cl100k_base") +enc.encode("hello world aaaaaaaaaaaa") +``` + ## Extending tiktoken diff --git a/pyproject.toml b/pyproject.toml index 4d92051b..3fc42c73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,6 @@ macos.archs = ["x86_64", "arm64"] # Warnings will be silenced with following CIBW_TEST_SKIP test-skip = "*-macosx_arm64" -before-test = "pip install pytest" +before-test = "pip install pytest hypothesis" test-command = "pytest {project}/tests" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_encoding.py b/tests/test_encoding.py new file mode 100644 index 00000000..27b21925 --- /dev/null +++ b/tests/test_encoding.py @@ -0,0 +1,231 @@ +# Note that there are more actual tests, they're just not currently public :-) + +from typing import Callable + +import hypothesis +import hypothesis.strategies as st +import pytest + +import tiktoken + +from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES + + +def test_simple(): + enc = tiktoken.get_encoding("gpt2") + assert enc.encode("hello world") == [31373, 995] + assert enc.decode([31373, 995]) == "hello world" + assert enc.encode("hello <|endoftext|>", allowed_special="all") == [31373, 220, 50256] + + enc = tiktoken.get_encoding("cl100k_base") + assert enc.encode("hello world") == [15339, 1917] + assert enc.decode([15339, 1917]) == "hello world" + assert enc.encode("hello <|endoftext|>", allowed_special="all") == [15339, 220, 100257] + + for enc_name in tiktoken.list_encoding_names(): + enc = tiktoken.get_encoding(enc_name) + for token in range(10_000): + assert enc.encode_single_token(enc.decode_single_token_bytes(token)) == token + + +def test_simple_repeated(): + enc = tiktoken.get_encoding("gpt2") + assert enc.encode("0") == [15] + assert enc.encode("00") == [405] + assert enc.encode("000") == [830] + assert enc.encode("0000") == [2388] + assert enc.encode("00000") == [20483] + assert enc.encode("000000") == [10535] + assert enc.encode("0000000") == [24598] + assert enc.encode("00000000") == [8269] + assert enc.encode("000000000") == [10535, 830] + assert enc.encode("0000000000") == [8269, 405] + assert enc.encode("00000000000") == [8269, 830] + assert enc.encode("000000000000") == [8269, 2388] + assert enc.encode("0000000000000") == [8269, 20483] + assert enc.encode("00000000000000") == [8269, 10535] + assert enc.encode("000000000000000") == [8269, 24598] + assert enc.encode("0000000000000000") == [25645] + assert enc.encode("00000000000000000") == [8269, 10535, 830] + + +def test_simple_regex(): + enc = tiktoken.get_encoding("cl100k_base") + assert enc.encode("rer") == [38149] + assert enc.encode("'rer") == [2351, 81] + assert enc.encode("today\n ") == [31213, 198, 220] + assert enc.encode("today\n \n") == [31213, 27907] + assert enc.encode("today\n \n") == [31213, 14211] + + +def test_basic_encode(): + enc = tiktoken.get_encoding("r50k_base") + assert enc.encode("hello world") == [31373, 995] + + enc = tiktoken.get_encoding("p50k_base") + assert enc.encode("hello world") == [31373, 995] + + enc = tiktoken.get_encoding("cl100k_base") + assert enc.encode("hello 
world") == [15339, 1917] + assert enc.encode(" \x850") == [220, 126, 227, 15] + + +def test_encode_empty(): + enc = tiktoken.get_encoding("r50k_base") + assert enc.encode("") == [] + + +def test_encode_bytes(): + enc = tiktoken.get_encoding("cl100k_base") + assert enc._encode_bytes(b" \xec\x8b\xa4\xed") == [62085] + + +def test_encode_surrogate_pairs(): + enc = tiktoken.get_encoding("cl100k_base") + + assert enc.encode("👍") == [9468, 239, 235] + # surrogate pair gets converted to codepoint + assert enc.encode("\ud83d\udc4d") == [9468, 239, 235] + + # lone surrogate just gets replaced + assert enc.encode("\ud83d") == enc.encode("�") + + +# ==================== +# Roundtrip +# ==================== + + +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +def test_basic_roundtrip(make_enc): + enc = make_enc() + for value in ( + "hello", + "hello ", + "hello ", + " hello", + " hello ", + " hello ", + "hello world", + "请考试我的软件!12345", + ): + assert value == enc.decode(enc.encode(value)) + assert value == enc.decode(enc.encode_ordinary(value)) + + +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +@hypothesis.given(text=st.text()) +@hypothesis.settings(deadline=None) +def test_hyp_roundtrip(make_enc: Callable[[], tiktoken.Encoding], text): + enc = make_enc() + + assert text == enc.decode(enc.encode(text)) + + +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +def test_single_token_roundtrip(make_enc: Callable[[], tiktoken.Encoding]): + enc = make_enc() + + for token in range(enc.n_vocab): + try: + token_bytes = enc.decode_single_token_bytes(token) + except KeyError: + continue + assert enc.encode_single_token(token_bytes) == token + + +# ==================== +# Special tokens +# ==================== + + +def test_special_token(): + enc = tiktoken.get_encoding("cl100k_base") + + eot = enc.encode_single_token("<|endoftext|>") + assert eot == enc.eot_token + fip = enc.encode_single_token("<|fim_prefix|>") + fim = enc.encode_single_token("<|fim_middle|>") + + text = "<|endoftext|> hello <|fim_prefix|>" + assert eot not in enc.encode(text, disallowed_special=()) + with pytest.raises(ValueError): + enc.encode(text) + with pytest.raises(ValueError): + enc.encode(text, disallowed_special="all") + with pytest.raises(ValueError): + enc.encode(text, disallowed_special={"<|endoftext|>"}) + with pytest.raises(ValueError): + enc.encode(text, disallowed_special={"<|fim_prefix|>"}) + + text = "<|endoftext|> hello <|fim_prefix|> there <|fim_middle|>" + tokens = enc.encode(text, disallowed_special=()) + assert eot not in tokens + assert fip not in tokens + assert fim not in tokens + + tokens = enc.encode(text, allowed_special="all", disallowed_special=()) + assert eot in tokens + assert fip in tokens + assert fim in tokens + + tokens = enc.encode(text, allowed_special="all", disallowed_special="all") + assert eot in tokens + assert fip in tokens + assert fim in tokens + + tokens = enc.encode(text, allowed_special={"<|fim_prefix|>"}, disallowed_special=()) + assert eot not in tokens + assert fip in tokens + assert fim not in tokens + + tokens = enc.encode(text, allowed_special={"<|endoftext|>"}, disallowed_special=()) + assert eot in tokens + assert fip not in tokens + assert fim not in tokens + + tokens = enc.encode(text, allowed_special={"<|fim_middle|>"}, disallowed_special=()) + assert eot not in tokens + assert fip not in tokens + assert fim in tokens + + +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +@hypothesis.given(text=st.text()) +@hypothesis.settings(deadline=None, 
max_examples=MAX_EXAMPLES) +def test_hyp_special_ordinary(make_enc, text: str): + enc = make_enc() + assert enc.encode_ordinary(text) == enc.encode(text, disallowed_special=()) + + +# ==================== +# Batch encoding +# ==================== + + +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +def test_batch_encode(make_enc: Callable[[], tiktoken.Encoding]): + enc = make_enc() + text1 = "hello world" + text2 = "goodbye world" + + assert enc.encode_batch([text1]) == [enc.encode(text1)] + assert enc.encode_batch([text1, text2]) == [enc.encode(text1), enc.encode(text2)] + + assert enc.encode_ordinary_batch([text1]) == [enc.encode_ordinary(text1)] + assert enc.encode_ordinary_batch([text1, text2]) == [ + enc.encode_ordinary(text1), + enc.encode_ordinary(text2), + ] + + +@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES) +@hypothesis.given(batch=st.lists(st.text())) +@hypothesis.settings(deadline=None) +def test_hyp_batch_roundtrip(make_enc: Callable[[], tiktoken.Encoding], batch): + enc = make_enc() + + encoded = enc.encode_batch(batch) + assert encoded == [enc.encode(t) for t in batch] + decoded = enc.decode_batch(encoded) + assert decoded == batch diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 00000000..2be95d26 --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,22 @@ +import bisect +import functools +import os + +import pytest + +import tiktoken + +MAX_EXAMPLES: int = int(os.environ.get("TIKTOKEN_MAX_EXAMPLES", "100")) + +ENCODINGS = ["r50k_base", "cl100k_base"] +SOME_ENCODINGS = ["cl100k_base"] + + +ENCODING_FACTORIES = [ + pytest.param(functools.partial(tiktoken.get_encoding, name), id=name) for name in ENCODINGS +] +SOME_ENCODING_FACTORIES = [ + pytest.param(functools.partial(tiktoken.get_encoding, name), id=name) for name in SOME_ENCODINGS +] + + diff --git a/tests/test_misc.py b/tests/test_misc.py new file mode 100644 index 00000000..a2b00f67 --- /dev/null +++ b/tests/test_misc.py @@ -0,0 +1,24 @@ +import subprocess +import sys + +import tiktoken + + +def test_encoding_for_model(): + enc = tiktoken.encoding_for_model("gpt2") + assert enc.name == "gpt2" + enc = tiktoken.encoding_for_model("text-davinci-003") + assert enc.name == "p50k_base" + enc = tiktoken.encoding_for_model("text-davinci-edit-001") + assert enc.name == "p50k_edit" + enc = tiktoken.encoding_for_model("gpt-3.5-turbo-0301") + assert enc.name == "cl100k_base" + + +def test_optional_blobfile_dependency(): + prog = """ +import tiktoken +import sys +assert "blobfile" not in sys.modules +""" + subprocess.check_call([sys.executable, "-c", prog]) diff --git a/tests/test_offsets.py b/tests/test_offsets.py new file mode 100644 index 00000000..31b7f8d4 --- /dev/null +++ b/tests/test_offsets.py @@ -0,0 +1,79 @@ +from typing import Callable + +import hypothesis +import pytest +from hypothesis import strategies as st + +import tiktoken + +from .test_helpers import MAX_EXAMPLES, SOME_ENCODING_FACTORIES + + +def _common_prefix_len(a, b): + i = 0 + while i < len(a) and i < len(b) and a[i] == b[i]: + i += 1 + return i + + +def _token_offsets_reference(enc, tokens): + text = enc.decode(tokens, errors="strict") + res = [] + for i in range(len(tokens)): + prefix = enc.decode(tokens[:i], errors="ignore") + res.append(_common_prefix_len(text, prefix)) + return res + + +@pytest.mark.parametrize("make_enc", SOME_ENCODING_FACTORIES) +@hypothesis.given(data=st.data()) +@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES) +def test_hyp_offsets(make_enc: Callable[[], 
tiktoken.Encoding], data): + enc = make_enc() + + tokens_st = st.lists( + st.integers(0, enc.n_vocab - 1).filter( + lambda x: x in enc._special_tokens.values() or x in enc._mergeable_ranks.values() + ), + min_size=1, + max_size=20, + ) + tokens = data.draw(tokens_st) + + # This is a dumb hack to make sure that our tokens are a valid UTF-8 string + # We could potentially drop this, see the TODO in decode_with_offsets + tokens = enc.encode(enc.decode(tokens, errors="ignore"), allowed_special="all") + assert enc.decode_with_offsets(tokens)[1] == _token_offsets_reference(enc, tokens) + + +def test_basic_offsets(): + enc = tiktoken.get_encoding("cl100k_base") + + prompt = "hello world" + p, o = enc.decode_with_offsets(enc.encode(prompt)) + assert p == prompt + assert o == [0, 5] + + prompt = "hello world<|endoftext|> green cow" + p, o = enc.decode_with_offsets(enc.encode(prompt, allowed_special="all")) + assert p == prompt + assert o == [0, 5, 11, 24, 30] + + prompt = "我非常渴望与人工智能一起工作" + p, o = enc.decode_with_offsets(enc.encode(prompt)) + assert p == prompt + assert o == [0, 1, 2, 3, 3, 4, 4, 5, 6, 7, 8, 8, 9, 10, 11, 12, 13] + + # contains the interesting tokens b'\xe0\xae\xbf\xe0\xae' and b'\xe0\xaf\x8d\xe0\xae' + # in which \xe0 is the start of a 3-byte UTF-8 character + prompt = "நடிகர் சூர்யா" + p, o = enc.decode_with_offsets(enc.encode(prompt)) + assert p == prompt + assert o == [0, 0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 8, 8, 9, 9, 10, 11, 12, 12] + + # contains the interesting token b'\xa0\xe9\x99\xa4' + # in which \xe9 is the start of a 3-byte UTF-8 character and \xa0 is a continuation byte + prompt = " Ġ除" + p, o = enc.decode_with_offsets(enc.encode(prompt)) + assert p == prompt + assert o == [0, 1] diff --git a/tiktoken/_educational.py b/tiktoken/_educational.py new file mode 100644 index 00000000..692a8bb8 --- /dev/null +++ b/tiktoken/_educational.py @@ -0,0 +1,210 @@ +"""This is an educational implementation of the byte pair encoding algorithm.""" +import collections +import itertools +from typing import Optional + +import regex + +import tiktoken + + +class SimpleBytePairEncoding: + def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None: + """Creates an Encoding object.""" + # A regex pattern string that is used to split the input text + self.pat_str = pat_str + # A dictionary mapping token bytes to their ranks. The ranks correspond to merge priority + self.mergeable_ranks = mergeable_ranks + + self._decoder = {token: token_bytes for token_bytes, token in mergeable_ranks.items()} + self._pat = regex.compile(pat_str) + + def encode(self, text: str, visualise: Optional[str] = "colour") -> list[int]: + """Encodes a string into tokens. + + >>> enc.encode("hello world") + [388, 372] + """ + # Use the regex to split the text into (approximately) words + words = self._pat.findall(text) + tokens = [] + for word in words: + # Turn each word into tokens, using the byte pair encoding algorithm + word_bytes = word.encode("utf-8") + word_tokens = bpe_encode(self.mergeable_ranks, word_bytes, visualise=visualise) + tokens.extend(word_tokens) + return tokens + + def decode_bytes(self, tokens: list[int]) -> bytes: + """Decodes a list of tokens into bytes. + + >>> enc.decode_bytes([388, 372]) + b'hello world' + """ + return b"".join(self._decoder[token] for token in tokens) + + def decode(self, tokens: list[int]) -> str: + """Decodes a list of tokens into a string. + + Decoded bytes are not guaranteed to be valid UTF-8. 
In that case, we replace + the invalid bytes with the replacement character "�". + + >>> enc.decode([388, 372]) + 'hello world' + """ + return self.decode_bytes(tokens).decode("utf-8", errors="replace") + + def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]: + """Decodes a list of tokens into a list of bytes. + + Useful for visualising how a string is tokenised. + + >>> enc.decode_tokens_bytes([388, 372]) + [b'hello', b' world'] + """ + return [self._decoder[token] for token in tokens] + + @staticmethod + def train(training_data: str, vocab_size: int, pat_str: str): + """Train a BPE tokeniser on some data!""" + mergeable_ranks = bpe_train(data=training_data, vocab_size=vocab_size, pat_str=pat_str) + return SimpleBytePairEncoding(pat_str=pat_str, mergeable_ranks=mergeable_ranks) + + @staticmethod + def from_tiktoken(encoding): + if isinstance(encoding, str): + encoding = tiktoken.get_encoding(encoding) + return SimpleBytePairEncoding( + pat_str=encoding._pat_str, mergeable_ranks=encoding._mergeable_ranks + ) + + +def bpe_encode( + mergeable_ranks: dict[bytes, int], input: bytes, visualise: Optional[str] = "colour" +) -> list[int]: + parts = [bytes([b]) for b in input] + while True: + # See the intermediate merges play out! + if visualise: + if visualise in ["colour", "color"]: + visualise_tokens(parts) + elif visualise == "simple": + print(parts) + + # Iterate over all pairs and find the pair we want to merge the most + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = mergeable_ranks.get(pair[0] + pair[1]) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + + # If there were no pairs we could merge, we're done! + if min_rank is None: + break + assert min_idx is not None + + # Otherwise, merge that pair and leave the rest unchanged. Then repeat. + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :] + + if visualise: + print() + + tokens = [mergeable_ranks[part] for part in parts] + return tokens + + +def bpe_train( + data: str, vocab_size: int, pat_str: str, visualise: Optional[str] = "colour" +) -> dict[bytes, int]: + # First, add tokens for each individual byte value + if vocab_size < 2**8: + raise ValueError("vocab_size must be at least 256, so we can encode all bytes") + ranks = {} + for i in range(2**8): + ranks[bytes([i])] = i + + # Splinter up our data into lists of bytes + # data = "Hello world" + # words = [ + # [b'H', b'e', b'l', b'l', b'o'], + # [b' ', b'w', b'o', b'r', b'l', b'd'] + # ] + words: list[list[bytes]] = [ + [bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(pat_str, data) + ] + + # Now, use our data to figure out which merges we should make + while len(ranks) < vocab_size: + # Find the most common pair. This will become our next token + stats = collections.Counter() + for piece in words: + for pair in zip(piece[:-1], piece[1:]): + stats[pair] += 1 + + most_common_pair = max(stats, key=lambda x: stats[x]) + token_bytes = most_common_pair[0] + most_common_pair[1] + token = len(ranks) + # Add the new token! + ranks[token_bytes] = token + + # Now merge that most common pair in all the words. That is, update our training data + # to reflect our decision to make that pair into a new token. + new_words = [] + for word in words: + new_word = [] + i = 0 + while i < len(word) - 1: + if (word[i], word[i + 1]) == most_common_pair: + # We found our pair! 
Merge it + new_word.append(token_bytes) + i += 2 + else: + new_word.append(word[i]) + i += 1 + if i == len(word) - 1: + new_word.append(word[i]) + new_words.append(new_word) + words = new_words + + # See the intermediate merges play out! + if visualise: + print(f"The current most common pair is {most_common_pair[0]} + {most_common_pair[1]}") + print(f"So we made {token_bytes} our {len(ranks)}th token") + if visualise in ["colour", "color"]: + print("Now the first fifty words in our training data look like:") + visualise_tokens([token for word in words[:50] for token in word]) + elif visualise == "simple": + print("Now the first twenty words in our training data look like:") + for word in words[:20]: + print(word) + print("\n") + + return ranks + + +def visualise_tokens(token_values: list[bytes]) -> None: + backgrounds = itertools.cycle( + [f"\u001b[48;5;{i}m".encode() for i in [167, 179, 185, 77, 80, 68, 134]] + ) + interleaved = itertools.chain.from_iterable(zip(backgrounds, token_values)) + print((b"".join(interleaved) + "\u001b[0m".encode()).decode("utf-8")) + + +def train_simple_encoding(): + gpt2_pattern = ( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) + with open(__file__, "r") as f: + data = f.read() + + enc = SimpleBytePairEncoding.train(data, vocab_size=600, pat_str=gpt2_pattern) + + print("This is the sequence of merges performed in order to encode 'hello world':") + tokens = enc.encode("hello world") + assert enc.decode(tokens) == "hello world" + assert enc.decode_bytes(tokens) == b"hello world" + assert enc.decode_tokens_bytes(tokens) == [b"hello", b" world"] + + return enc
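+
+
+if __name__ == "__main__":
+    # Optional demo, nothing above depends on this block: running the module directly
+    # (e.g. `python -m tiktoken._educational`) trains a toy tokeniser on this source file
+    # and then visualises how the GPT-4 tokeniser splits a string, using the helpers above.
+    toy_enc = train_simple_encoding()
+    print(toy_enc.decode_tokens_bytes(toy_enc.encode("hello world", visualise=None)))
+
+    gpt4_enc = SimpleBytePairEncoding.from_tiktoken("cl100k_base")
+    gpt4_enc.encode("hello world aaaaaaaaaaaa")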