From a1a9f16826f3f2d8ba80b6c5fd270c1c340d6d67 Mon Sep 17 00:00:00 2001 From: Shantanu Jain Date: Mon, 12 Dec 2022 11:27:27 -0800 Subject: [PATCH] [tiktoken] hello world --- .gitignore | 42 +++ Cargo.toml | 21 ++ LICENSE | 21 ++ MANIFEST.in | 5 + Makefile | 49 +++ README.md | 28 ++ perf.svg | 373 +++++++++++++++++++++++ pyproject.toml | 8 + scripts/benchmark.py | 39 +++ scripts/redact.py | 65 ++++ setup.py | 23 ++ src/lib.rs | 559 ++++++++++++++++++++++++++++++++++ tiktoken/__init__.py | 3 + tiktoken/core.py | 310 +++++++++++++++++++ tiktoken/load.py | 97 ++++++ tiktoken/registry.py | 71 +++++ tiktoken_ext/openai_public.py | 41 +++ 17 files changed, 1755 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 LICENSE create mode 100644 MANIFEST.in create mode 100644 Makefile create mode 100644 README.md create mode 100644 perf.svg create mode 100644 pyproject.toml create mode 100644 scripts/benchmark.py create mode 100644 scripts/redact.py create mode 100644 setup.py create mode 100644 src/lib.rs create mode 100644 tiktoken/__init__.py create mode 100644 tiktoken/core.py create mode 100644 tiktoken/load.py create mode 100644 tiktoken/registry.py create mode 100644 tiktoken_ext/openai_public.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..9e090c8e --- /dev/null +++ b/.gitignore @@ -0,0 +1,42 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Environments +.env +.venv + +# Tools +.mypy_cache +.coverage +htmlcov + +# General +.DS_Store + +Cargo.lock +target/ diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..24b42fd4 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "tiktoken" +version = "0.1.0" +edition = "2021" +rust-version = "1.57.0" + +[lib] +name = "_tiktoken" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.17.3", features = ["extension-module"] } + +# tiktoken dependencies +fancy-regex = "0.10.0" +regex = "1.7.0" +rustc-hash = "1.1.0" +bstr = "1.0.1" + +[profile.release] +incremental = true diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..83ed1036 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 OpenAI, Shantanu Jain + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..cb017cd3 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,5 @@ +include *.svg +include *.toml +include Makefile +recursive-include scripts *.py +recursive-include src *.rs diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..92aec0f9 --- /dev/null +++ b/Makefile @@ -0,0 +1,49 @@ +PROJECT := tiktoken + +.PHONY: default +default: editable_install + +.PHONY: install_rust +install_rust: + which cargo >/dev/null || curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain 1.62 + +.PHONY: clean +clean: + cargo clean + pip uninstall -y $(PROJECT) + find . | grep -E '__pycache__|\.pyc' | xargs rm -rf + find . | grep -E '\.so' | xargs rm -rf + rm -rf dist/ build/ + rm -rf $(PROJECT).egg-info/ + +.PHONY: format +format: + @ which black >/dev/null || python3 -m pip install black + @ which isort >/dev/null || python3 -m pip install isort + cargo fmt -- --config group_imports=StdExternalCrate + black --line-length 100 --skip-magic-trailing-comma --quiet . + isort --line-length 100 --profile black --quiet . + + +.PHONY: format_check +format_check: + @ which black >/dev/null || python3 -m pip install black + @ which isort >/dev/null || python3 -m pip install isort + cargo fmt --check -- --config group_imports=StdExternalCrate + black --check --line-length 100 --skip-magic-trailing-comma --quiet . + isort --check --line-length 100 --profile black --quiet . + +.PHONY: lint +lint: + cargo clippy --all -- -D warnings + @ which flake8 >/dev/null || python3 -m pip install flake8==5 flake8-bugbear==22.9.11 + flake8 --ignore=E203,E501,W503,E731 --per-file-ignores="$(PROJECT)/__init__.py:F401 setup.py:E402" --exclude=build . + +.PHONY: editable_install +editable_install: + @ if [ -f $(PROJECT).egg-info ]; then \ + pip install --disable-pip-version-check --progress-bar=off setuptools wheel setuptools-rust ; \ + pip install --disable-pip-version-check --no-build-isolation -e . ; \ + else \ + pip install --disable-pip-version-check --no-deps --no-build-isolation --ignore-installed -e . ; \ + fi diff --git a/README.md b/README.md new file mode 100644 index 00000000..f0ea386d --- /dev/null +++ b/README.md @@ -0,0 +1,28 @@ +# ⏳ tiktoken + +tiktoken is a fast tokeniser. + +```python +import tiktoken +enc = tiktoken.get_encoding("gpt2") +print(enc.encode("hello world")) +``` + +The open source version of `tiktoken` can be installed from PyPI: +``` +pip install tiktoken +``` + +The tokeniser API is documented in `tiktoken/core.py`. + + +## Performance + +`tiktoken` is between 3-6x faster than huggingface's tokeniser: + +![image](./perf.svg) + +Performance measured on 1GB of text using the GPT-2 tokeniser, using `GPT2TokenizerFast` from +`tokenizers==0.13.2` and `transformers==4.24.0`. 
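+
+As a quick sketch of the wider API documented in `tiktoken/core.py` (the token ids shown are
+for the `gpt2` encoding and will differ for other encodings):
+
+```python
+import tiktoken
+
+enc = tiktoken.get_encoding("gpt2")
+
+tokens = enc.encode("hello world")          # [31373, 995]
+assert enc.decode(tokens) == "hello world"
+
+# See which bytes each token stands for
+print(enc.decode_tokens_bytes(tokens))      # [b'hello', b' world']
+
+# Special tokens raise a ValueError by default and must be allowed explicitly
+print(enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))  # [50256]
+```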
+
+
diff --git a/perf.svg b/perf.svg
new file mode 100644
index 00000000..7157ef93
--- /dev/null
+++ b/perf.svg
@@ -0,0 +1,373 @@
[perf.svg: 373 lines of SVG whose markup did not survive extraction; only the chart text remains. This is the figure referenced in README.md: throughput (y-axis, 0-40 MB/s) against thread count (x-axis: 1, 2, 4, 8, 16, 32, 64) for tiktoken versus huggingface.]
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..bb9aeeb1
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,8 @@
+[project]
+name = "tiktoken"
+dependencies = ["blobfile>=2", "regex>=2022.1.18"]
+dynamic = ["version"]
+
+[build-system]
+requires = ["setuptools", "wheel", "setuptools-rust"]
+
diff --git a/scripts/benchmark.py b/scripts/benchmark.py
new file mode 100644
index 00000000..4d679fac
--- /dev/null
+++ b/scripts/benchmark.py
@@ -0,0 +1,39 @@
+import base64
+import functools
+import gzip
+import json
+import os
+import random
+import time
+from typing import Any, cast
+
+import blobfile
+
+import tiktoken
+
+
+def benchmark_batch(documents: list[str]) -> None:
+    num_threads = int(os.environ["RAYON_NUM_THREADS"])
+    num_bytes = sum(map(len, map(str.encode, documents)))
+    print(f"num_threads: {num_threads}, num_bytes: {num_bytes}")
+
+    enc = tiktoken.get_encoding("gpt2")
+    enc.encode("warmup")
+
+    start = time.perf_counter_ns()
+    enc.encode_ordinary_batch(documents, num_threads=num_threads)
+    end = time.perf_counter_ns()
+    print(f"tiktoken \t{num_bytes / (end - start) * 1e9} bytes / s")
+
+    import transformers
+
+    hf_enc = cast(Any, transformers).GPT2TokenizerFast.from_pretrained("gpt2")
+    hf_enc.model_max_length = 1e30  # silence!
+ hf_enc.encode("warmup") + + start = time.perf_counter_ns() + hf_enc(documents) + end = time.perf_counter_ns() + print(f"huggingface \t{num_bytes / (end - start) * 1e9} bytes / s") + + diff --git a/scripts/redact.py b/scripts/redact.py new file mode 100644 index 00000000..bcf8ef12 --- /dev/null +++ b/scripts/redact.py @@ -0,0 +1,65 @@ +import argparse +import re +import subprocess +from pathlib import Path + + +def redact_file(path: Path, dry_run: bool) -> None: + if not path.exists() or path.is_dir(): + return + + text = path.read_text() + + first_line = text.splitlines()[0] + if "redact" in first_line: + if not dry_run: + path.unlink() + print(f"Deleted {path}") + return + + pattern = "|".join( + re.escape(x) + for x in [ + "# ===== redact-beg =====\n", + "# ===== redact-end =====\n", + "\n", + "\n", + ] + ) + + if re.search(pattern, text): + redacted_text = "".join(re.split(pattern, text)[::2]) + if not dry_run: + path.write_text(redacted_text) + print(f"Redacted {path}") + return + + print(f"Skipped {path}") + + +def redact(dry_run: bool) -> None: + tiktoken_root = Path(__file__).parent.parent + assert tiktoken_root.name == "tiktoken" + assert (tiktoken_root / "pyproject.toml").exists() + + try: + output = subprocess.check_output(["git", "ls-files"], cwd=tiktoken_root, text=True) + paths = [Path(p) for p in output.splitlines()] + except subprocess.CalledProcessError: + paths = list(tiktoken_root.glob("**/*")) + + for path in paths: + redact_file(path, dry_run=dry_run) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--dry-run", type=lambda x: not x or x[0].lower() != "f", default=True) + args = parser.parse_args() + redact(args.dry_run) + if args.dry_run: + print("Dry run, use --dry-run=false to actually redact files") + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..df18edad --- /dev/null +++ b/setup.py @@ -0,0 +1,23 @@ +from setuptools import setup +from setuptools_rust import Binding, RustExtension + +public = True + +if public: + version = "0.1" + +setup( + name="tiktoken", + version=version, + rust_extensions=[ + RustExtension( + "tiktoken._tiktoken", + binding=Binding.PyO3, + # Between our use of editable installs and wanting to use Rust for performance sensitive + # code, it makes sense to just always use --release + debug=False, + ) + ], + packages=["tiktoken", "tiktoken_ext"], + zip_safe=False, +) diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 00000000..8235dbb1 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,559 @@ +// This check is new and seems buggy (possibly with PyO3 interaction) +#![allow(clippy::borrow_deref_ref)] + +use std::collections::HashSet; +use std::thread; + +use fancy_regex::Regex; +use pyo3::exceptions; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyList, PyTuple}; +use pyo3::PyResult; +use rustc_hash::FxHashMap as HashMap; + +fn _byte_pair_merge(piece: &[u8], ranks: &HashMap, usize>) -> Vec> { + let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect(); + + // If you have n parts and m merges, this does O(mn) work + // We could do something with a heap and do O(m log n) work + + // Note that we hash bytes, not token pairs. As long as we train BPE the way we + // currently do, this is equivalent. An easy way to break this would be to decouple + // merge priority from token index or to prevent specific token merges. 
+ loop { + if parts.len() == 1 { + break; + } + let mut min_rank: Option<(usize, usize)> = None; + for i in 0..parts.len() - 1 { + let rank = if let Some(r) = ranks.get(&piece[parts[i].start..parts[i + 1].end]) { + *r + } else { + continue; + }; + if min_rank.is_none() || rank < min_rank.unwrap().0 { + min_rank = Some((rank, i)); + } + } + if let Some((_, i)) = min_rank { + parts[i] = parts[i].start..parts[i + 1].end; + parts.remove(i + 1); + } else { + break; + } + } + parts +} + +pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap, usize>) -> Vec { + if piece.len() == 1 { + return vec![ranks[piece]]; + } + _byte_pair_merge(piece, ranks) + .iter() + .map(|p| ranks[&piece[p.start..p.end]]) + .collect() +} + +pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap, usize>) -> Vec<&'a [u8]> { + if piece.len() == 1 { + return vec![piece]; + } + _byte_pair_merge(piece, ranks) + .iter() + .map(|p| &piece[p.start..p.end]) + .collect() +} + +// Various performance notes: +// +// Regex +// ===== +// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy +// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than +// the usual regex we use. +// +// However, given that we're using a regex parse-able by `regex`, there isn't much difference +// between using the `regex` crate and using the `fancy_regex` crate. +// +// There is an important interaction between threading, `regex` and `fancy_regex`. +// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on +// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain +// old `regex`, we don't hit this, because `find_iter` has a different code path. +// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md +// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for +// each thread. +// +// Threading +// ========= +// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL. +// So goodbye `rayon`! Let thread count etc be in control of our Python users. +// +// Caching +// ======= +// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`. +// Originally, we had one too! Without it, we were only vaguely faster than Python. +// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance +// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect +// multi-threaded performance even when I only had readers (maybed I messed something up?). +// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache! +// These are exactly the set or merges that are likely to be hot. And now we don't have to think +// about interior mutability, memory use, or cloning. +// +// Hashing +// ======= +// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win? +// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made +// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster. + +use std::num::NonZeroU64; +pub struct FakeThreadId(NonZeroU64); + +fn hash_current_thread() -> usize { + // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter + // that works great for our use case of avoiding collisions in our array. Unfortunately, + // it's private. 
However, there are only so many ways you can layout a u64, so just transmute + // https://github.com/rust-lang/rust/issues/67939 + const _: [u8; 8] = [0; std::mem::size_of::()]; + const _: [u8; 8] = [0; std::mem::size_of::()]; + let x = unsafe { + std::mem::transmute::(thread::current().id()).0 + }; + u64::from(x) as usize +} + +const MAX_NUM_THREADS: usize = 128; +#[pyclass] +struct CoreBPE { + encoder: HashMap, usize>, + special_tokens_encoder: HashMap, + decoder: HashMap>, + special_tokens_decoder: HashMap>, + regex_tls: Vec, + special_regex_tls: Vec, + sorted_token_bytes: Vec>, +} + +impl CoreBPE { + fn _get_tl_regex(&self) -> &Regex { + // See performance notes above for what this is about + // It's also a little janky, please make a better version of it! + // However, it's nice that this doesn't leak memory to short-lived threads + &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS] + } + + fn _get_tl_special_regex(&self) -> &Regex { + &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] + } + + fn _decode_native(&self, tokens: &[usize]) -> Vec { + let mut ret = Vec::with_capacity(tokens.len() * 2); + for token in tokens { + let token_bytes = self + .decoder + .get(token) + .unwrap_or_else(|| &self.special_tokens_decoder[token]); + ret.extend(token_bytes); + } + ret + } + + fn _encode_ordinary_native(&self, text: &str) -> Vec { + // This is the core of the encoding logic; the other functions in here + // just make things complicated :-) + let regex = self._get_tl_regex(); + let mut ret = vec![]; + for mat in regex.find_iter(text) { + let piece = mat.unwrap().as_str().as_bytes(); + if let Some(token) = self.encoder.get(piece) { + ret.push(*token); + continue; + } + ret.extend(&byte_pair_encode(piece, &self.encoder)); + } + ret + } + + fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec, usize) { + let special_regex = self._get_tl_special_regex(); + let regex = self._get_tl_regex(); + let mut ret = vec![]; + + let mut start = 0; + let mut last_piece_token_len = 0; + loop { + let mut next_special; + let mut start_find = start; + loop { + // Find the next allowed special token, if any + next_special = special_regex.find_from_pos(text, start_find).unwrap(); + match next_special { + Some(m) => { + if allowed_special.contains(&text[m.start()..m.end()]) { + break; + } + start_find = m.start() + 1; + } + None => break, + } + } + let end = next_special.map_or(text.len(), |m| m.start()); + + // Okay, here we go, compare this logic to _encode_ordinary_native + for mat in regex.find_iter(&text[start..end]) { + let piece = mat.unwrap().as_str().as_bytes(); + if let Some(token) = self.encoder.get(piece) { + last_piece_token_len = 1; + ret.push(*token); + continue; + } + let tokens = byte_pair_encode(piece, &self.encoder); + last_piece_token_len = tokens.len(); + ret.extend(&tokens); + } + + match next_special { + // And here we push the special token + Some(m) => { + let piece = m.as_str(); + let token = self.special_tokens_encoder[piece]; + ret.push(token); + start = m.end(); + last_piece_token_len = 0; + } + None => break, + } + } + + // last_piece_token_len is how many tokens came from the last regex split. This is used + // for determining unstable tokens, since you can't merge across (stable) regex splits + (ret, last_piece_token_len) + } + + fn _increase_last_piece_token_len( + &self, + tokens: Vec, + mut last_piece_token_len: usize, + ) -> (Vec, usize) { + // Unfortunately, the locations where our regex splits can be unstable. 
+ // For the purposes of determining unstable tokens, unstable regex splitting + // is only a problem if a split that was present disappears, since this can + // lead to merging of tokens otherwise thought to be stable. + // cl100k_base makes our life hard by including the \s*[\r\n]+ + // pattern. This can e.g. cause "\n" + " " to become "\n \n". + // Here is a quick and dirty fix: + { + let token_is_all_space = |token| { + self.decoder + .get(token) + .map(|token_bytes| { + token_bytes + .iter() + .rev() + .all(|&b| [b' ', b'\n', b'\t'].contains(&b)) + }) + .unwrap_or(false) + }; + if last_piece_token_len > 0 + && token_is_all_space(&tokens[tokens.len() - last_piece_token_len]) + { + while (last_piece_token_len < tokens.len()) + && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1]) + { + last_piece_token_len += 1; + } + } + } + debug_assert!(last_piece_token_len <= tokens.len()); + + (tokens, last_piece_token_len) + } + + fn _encode_unstable_native( + &self, + text: &str, + allowed_special: &HashSet<&str>, + ) -> (Vec, HashSet>) { + let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special); + if last_piece_token_len == 0 { + // If last_piece_token_len is zero, the last token was a special token and we have + // no unstable bytes + return (tokens, HashSet::new()); + } + let (mut tokens, last_piece_token_len) = + self._increase_last_piece_token_len(tokens, last_piece_token_len); + + let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); + tokens.truncate(tokens.len() - last_piece_token_len); + + // TODO: we should try harder to find additional stable tokens + // This would reduce the amount of retokenising when determining completions + // Refer to the logic in an older version of this file + + let mut completions = HashSet::new(); + if unstable_bytes.is_empty() { + return (tokens, completions); + } + + // This is the easy bit. Just find all single tokens that start with unstable_bytes + // (including tokens that exactly match unstable_bytes) + // Separating this from the loop below helps with performance in a common case. + let mut point = self + .sorted_token_bytes + .partition_point(|x| x.as_slice() < unstable_bytes.as_slice()); + while point < self.sorted_token_bytes.len() + && self.sorted_token_bytes[point].starts_with(&unstable_bytes) + { + completions.insert(vec![ + self.encoder[self.sorted_token_bytes[point].as_slice()], + ]); + point += 1; + } + + // Now apply even more brute force. At every (other) possible position for the straddling + // token, concatenate additional bytes from that token (if any) to unstable_bytes, + // and retokenise the whole thing and see what we get. + for i in 1..unstable_bytes.len() { + let prefix = &unstable_bytes[..i]; + let suffix = &unstable_bytes[i..]; + let mut point = self + .sorted_token_bytes + .partition_point(|x| x.as_slice() < suffix); + // TODO: Perf optimisation if suffix starts with " "? + while point < self.sorted_token_bytes.len() + && self.sorted_token_bytes[point].starts_with(suffix) + { + let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat(); + let encoded = match std::str::from_utf8(&possibility) { + // Morally, this is byte_pair_encode(&possibility, &self.encoder) + // But we might have introduced a regex split which would prevent merges. + // (particularly possible in the presence of unstable regex splits) + // So convert to UTF-8 and do regex splitting. + // E.g. with cl100k_base " !" 
gets split to " " + " !", + // but byte_pair_encode(" !") != byte_pair_encode(" ") + Ok(s) => self._encode_ordinary_native(s), + + // Technically, whether or not this arm is correct depends on whether there + // would be a regex split before the UTF-8 truncation point. + // Probably niche enough that no one will ever notice (after all, people didn't + // notice all the big holes in the previous unstable token implementation) + Err(_) => byte_pair_encode(&possibility, &self.encoder), + // Something like the following is intriguing but incorrect: + // Err(e) => self._encode_ordinary_native(unsafe { + // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) + // }), + }; + let mut seq = Vec::new(); + let mut seq_len = 0; + for token in encoded { + seq.push(token); + seq_len += self.decoder[&token].len(); + if seq_len >= unstable_bytes.len() { + break; + } + } + completions.insert(seq); + point += 1; + } + } + + // This is also not straightforward. While we generally assume that regex splits are stable, + // unfortunately, they are not. That is, if adding bytes were to make a split appear in + // unstable_bytes, this could make tokens possible which our logic would otherwise think + // would be merged. + // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could + // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token. + // Here is a quick and dirty fix: + // This isn't right if we ever remove \s+(?!\S) + if unstable_bytes.len() > 1 { + let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice()); + if unstable_bytes.len() - last_decoded.1 > 0 + && last_decoded.0.map_or(false, |c| c.is_whitespace()) + { + let mut reencoded = byte_pair_encode( + &unstable_bytes[..unstable_bytes.len() - last_decoded.1], + &self.encoder, + ); + reencoded.extend(byte_pair_encode( + &unstable_bytes[unstable_bytes.len() - last_decoded.1..], + &self.encoder, + )); + completions.insert(reencoded); + } + } + + (tokens, completions) + } +} + +#[pymethods] +impl CoreBPE { + #[new] + fn new( + encoder: HashMap, usize>, + special_tokens_encoder: HashMap, + pattern: &str, + ) -> PyResult { + let regex = Regex::new(pattern) + .map_err(|e| PyErr::new::(e.to_string()))?; + + let special_regex = { + let _parts = special_tokens_encoder + .keys() + .map(|s| fancy_regex::escape(s)) + .collect::>(); + Regex::new(&_parts.join("|")) + .map_err(|e| PyErr::new::(e.to_string()))? 
+ }; + + let decoder: HashMap> = + encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); + + assert!(encoder.len() == decoder.len()); + + let special_tokens_decoder: HashMap> = special_tokens_encoder + .iter() + .map(|(k, v)| (*v, k.as_bytes().to_vec())) + .collect(); + + // Clone because I don't know how to tell Rust I'm not going to change the map + let mut sorted_token_bytes: Vec> = encoder.keys().cloned().collect(); + sorted_token_bytes.sort(); + + Ok(CoreBPE { + encoder, + special_tokens_encoder, + decoder, + special_tokens_decoder, + regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(), + special_regex_tls: (0..MAX_NUM_THREADS) + .map(|_| special_regex.clone()) + .collect(), + sorted_token_bytes, + }) + } + + // ==================== + // Encoding + // ==================== + + fn encode_ordinary(&self, py: Python, text: &str) -> Vec { + py.allow_threads(|| self._encode_ordinary_native(text)) + } + + fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec { + py.allow_threads(|| self._encode_native(text, &allowed_special).0) + } + + fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec { + py.allow_threads(|| { + match std::str::from_utf8(bytes) { + Ok(text) => self._encode_ordinary_native(text), + Err(e) => { + let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; + let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new()); + let (mut tokens, last_piece_token_len) = + self._increase_last_piece_token_len(tokens, last_piece_token_len); + if !tokens.is_empty() && last_piece_token_len > 0 { + // Lop off the tokens from the last piece and run BPE on the remaining bytes + // Somewhat niche, but this may not be correct if we'd have had a regex + // split between the valid UTF-8 and the invalid bytes, which is why this + // method is private + let mut unstable_bytes = + self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); + unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); + + tokens.truncate(tokens.len() - last_piece_token_len); + tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder)); + } + tokens + } + } + }) + } + + fn encode_with_unstable( + &self, + py: Python, + text: &str, + allowed_special: HashSet<&str>, + ) -> Py { + let (tokens, completions) = + py.allow_threads(|| self._encode_unstable_native(text, &allowed_special)); + let py_completions = + PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..]))); + (tokens, py_completions).into_py(py) + } + + fn encode_single_token(&self, piece: &[u8]) -> PyResult { + if let Some(token) = self.encoder.get(piece).copied() { + return Ok(token); + } + if let Ok(piece_str) = std::str::from_utf8(piece) { + if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() { + return Ok(token); + } + } + Err(PyErr::new::(piece.to_owned())) + } + + fn encode_single_piece(&self, piece: &[u8]) -> Vec { + if let Some(token) = self.encoder.get(piece) { + return vec![*token]; + } + byte_pair_encode(piece, &self.encoder) + } + + // ==================== + // Decoding + // ==================== + + fn decode_bytes(&self, py: Python, tokens: Vec) -> Py { + let bytes = py.allow_threads(|| self._decode_native(&tokens)); + PyBytes::new(py, &bytes).into() + } + + fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult> { + if let Some(bytes) = self.decoder.get(&token) { + return Ok(PyBytes::new(py, bytes).into()); + } + if let Some(bytes) = self.special_tokens_decoder.get(&token) { + return 
Ok(PyBytes::new(py, bytes).into()); + } + Err(PyErr::new::(token.to_string())) + } + + // ==================== + // Miscellaneous + // ==================== + + fn token_byte_values(&self, py: Python) -> Vec> { + self.sorted_token_bytes + .iter() + .map(|x| PyBytes::new(py, x).into()) + .collect() + } +} + +#[pymodule] +fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use rustc_hash::FxHashMap as HashMap; + + use crate::byte_pair_split; + + #[test] + fn very_simple_test() { + let mut ranks = HashMap::default(); + ranks.insert(b"ab".to_vec(), 1); + ranks.insert(b"cd".to_vec(), 2); + + let res = byte_pair_split(b"abcd", &ranks); + assert_eq!(res, vec![b"ab", b"cd"]); + } +} diff --git a/tiktoken/__init__.py b/tiktoken/__init__.py new file mode 100644 index 00000000..f4b50657 --- /dev/null +++ b/tiktoken/__init__.py @@ -0,0 +1,3 @@ +from .core import Encoding as Encoding +from .registry import get_encoding as get_encoding +from .registry import list_encoding_names as list_encoding_names diff --git a/tiktoken/core.py b/tiktoken/core.py new file mode 100644 index 00000000..e200c291 --- /dev/null +++ b/tiktoken/core.py @@ -0,0 +1,310 @@ +import functools +from concurrent.futures import ThreadPoolExecutor +from typing import AbstractSet, Collection, Literal, NoReturn, Optional, Union + +import regex + +from tiktoken import _tiktoken + + +class Encoding: + def __init__( + self, + name: str, + *, + pat_str: str, + mergeable_ranks: dict[bytes, int], + special_tokens: dict[str, int], + explicit_n_vocab: Optional[int] = None, + ): + self.name = name + + self._pat_str = pat_str + self._mergeable_ranks = mergeable_ranks + self._special_tokens = special_tokens + + self.max_token_value = max( + max(mergeable_ranks.values()), max(special_tokens.values(), default=0) + ) + if explicit_n_vocab: + assert len(mergeable_ranks) + len(special_tokens) == explicit_n_vocab + assert self.max_token_value == explicit_n_vocab - 1 + + self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str) + + def __repr__(self) -> str: + return f"" + + # ==================== + # Encoding + # ==================== + + def encode_ordinary(self, text: str) -> list[int]: + """Encodes a string into tokens, ignoring special tokens. + + This is equivalent to `encode(text, disallowed_special=())` (but slightly faster). + + ``` + >>> enc.encode_ordinary("hello world") + [31373, 995] + """ + return self._core_bpe.encode_ordinary(text) + + def encode( + self, + text: str, + *, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + ) -> list[int]: + """Encodes a string into tokens. + + Special tokens are tokens are artificial tokens used to unlock capabilities from a model, + such as fill-in-the-middle. So we want to be careful about accidentally encoding special + tokens, since they can be used to trick a model into doing something we don't want it to do. + + Hence, by default, encode will raise an error if it encounters text that corresponds + to a special token. This can be controlled on a per-token level using the `allowed_special` + and `disallowed_special` parameters. In particular: + - Setting `disallowed_special` to () will prevent this function from raising errors and + cause all text corresponding to special tokens to be encoded as natural text. 
+ - Setting `allowed_special` to "all" will allow cause this function to treat all text + corresponding to special tokens to be encoded as special tokens. + + ``` + >>> enc.encode("hello world") + [31373, 995] + >>> enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}) + [50256] + >>> enc.encode("<|endoftext|>", allowed_special="all") + [50256] + >>> enc.encode("<|endoftext|>") + # Raises ValueError + >>> enc.encode("<|endoftext|>", disallowed_special=()) + [27, 91, 437, 1659, 5239, 91, 29] + ``` + """ + if allowed_special == "all": + allowed_special = self.special_tokens_set + if disallowed_special == "all": + disallowed_special = self.special_tokens_set - allowed_special + if disallowed_special: + if not isinstance(disallowed_special, frozenset): + disallowed_special = frozenset(disallowed_special) + if match := _special_token_regex(disallowed_special).search(text): + raise_disallowed_special_token(match.group()) + + return self._core_bpe.encode(text, allowed_special) + + def encode_ordinary_batch(self, text: list[str], *, num_threads: int = 8) -> list[list[int]]: + """Encodes a list of strings into tokens, in parallel, ignoring special tokens. + + This is equivalent to `encode_batch(text, disallowed_special=())` (but slightly faster). + + ``` + >>> enc.encode_batch(["hello world", "goodbye world"]) + [[31373, 995], [11274, 16390, 995]] + ``` + """ + encoder = functools.partial(self.encode_ordinary) + with ThreadPoolExecutor(num_threads) as e: + return list(e.map(encoder, text)) + + def encode_batch( + self, + text: list[str], + *, + num_threads: int = 8, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + ) -> list[list[int]]: + """Encodes a list of strings into tokens, in parallel. + + See `encode` for more details on `allowed_special` and `disallowed_special`. + + ``` + >>> enc.encode_batch(["hello world", "goodbye world"]) + [[31373, 995], [11274, 16390, 995]] + ``` + """ + if allowed_special == "all": + allowed_special = self.special_tokens_set + if disallowed_special == "all": + disallowed_special = self.special_tokens_set - allowed_special + if not isinstance(disallowed_special, frozenset): + disallowed_special = frozenset(disallowed_special) + + encoder = functools.partial( + self.encode, allowed_special=allowed_special, disallowed_special=disallowed_special + ) + with ThreadPoolExecutor(num_threads) as e: + return list(e.map(encoder, text)) + + def encode_with_unstable( + self, + text: str, + *, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006 + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + ) -> tuple[list[int], list[list[int]]]: + """Encodes a string into stable tokens and possible completion sequences. + + Note that the stable tokens will only represent a substring of `text`. + + See `encode` for more details on `allowed_special` and `disallowed_special`. + + ``` + >>> enc.encode_with_unstable("hello fanta") + ([31373], [(277, 4910), (5113, 265), ..., (8842,)]) + + >>> text = "..." 
+ >>> stable_tokens, completions = enc.encode_with_unstable(text) + >>> assert text.encode().startswith(enc.decode_bytes(stable_tokens)) + >>> assert all(enc.decode_bytes(stable_tokens + seq).startswith(text.encode()) for seq in completions) + ``` + """ + if allowed_special == "all": + allowed_special = self.special_tokens_set + if disallowed_special == "all": + disallowed_special = self.special_tokens_set - allowed_special + if disallowed_special: + if not isinstance(disallowed_special, frozenset): + disallowed_special = frozenset(disallowed_special) + if match := _special_token_regex(disallowed_special).search(text): + raise_disallowed_special_token(match.group()) + + return self._core_bpe.encode_with_unstable(text, allowed_special) + + def encode_single_token(self, text_or_bytes: Union[str, bytes]) -> int: + """Encodes text corresponding to a single token to its token value. + + NOTE: this will encode all special tokens. + + Raises `KeyError` if the token is not in the vocabulary. + + ``` + >>> enc.encode_single_token("hello") + 31373 + ``` + """ + if isinstance(text_or_bytes, str): + text_or_bytes = text_or_bytes.encode("utf-8") + return self._core_bpe.encode_single_token(text_or_bytes) + + # ==================== + # Decoding + # ==================== + + def decode_bytes(self, tokens: list[int]) -> bytes: + """Decodes a list of tokens into bytes. + + ``` + >>> enc.decode_bytes([31373, 995]) + b'hello world' + ``` + """ + return self._core_bpe.decode_bytes(tokens) + + def decode(self, tokens: list[int], errors: str = "replace") -> str: + """Decodes a list of tokens into a string. + + WARNING: the default behaviour of this function is lossy, since decoded bytes are not + guaranteed to be valid UTF-8. You can control this behaviour using the `errors` parameter, + for instance, setting `errors=strict`. + + ``` + >>> enc.decode([31373, 995]) + 'hello world' + ``` + """ + return self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors) + + def decode_single_token_bytes(self, token: int) -> bytes: + """Decodes a token into bytes. + + NOTE: this will decode all special tokens. + + Raises `KeyError` if the token is not in the vocabulary. + + ``` + >>> enc.decode_single_token_bytes(31373) + b'hello' + ``` + """ + return self._core_bpe.decode_single_token_bytes(token) + + def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]: + """Decodes a list of tokens into a list of bytes. + + Useful for visualising tokenisation. + >>> enc.decode_tokens_bytes([31373, 995]) + [b'hello', b' world'] + """ + return [self.decode_single_token_bytes(token) for token in tokens] + + # ==================== + # Miscellaneous + # ==================== + + def token_byte_values(self) -> list[bytes]: + """Returns the list of all token byte values.""" + return self._core_bpe.token_byte_values() + + @property + def eot_token(self) -> int: + return self._special_tokens["<|endoftext|>"] + + @functools.cached_property + def special_tokens_set(self) -> set[str]: + return set(self._special_tokens.keys()) + + @property + def n_vocab(self) -> int: + """For backwards compatibility. Prefer to use `enc.max_token_value + 1`.""" + return self.max_token_value + 1 + + # ==================== + # Private + # ==================== + + def _encode_single_piece(self, text_or_bytes: Union[str, bytes]) -> list[int]: + """Encodes text corresponding to bytes without a regex split. + + NOTE: this will not encode any special tokens. 
+ + ``` + >>> enc.encode_single_piece("helloqqqq") + [31373, 38227, 38227] + ``` + """ + if isinstance(text_or_bytes, str): + text_or_bytes = text_or_bytes.encode("utf-8") + return self._core_bpe.encode_single_piece(text_or_bytes) + + def _encode_only_native_bpe(self, text: str) -> list[str]: + """Encodes a string into tokens, but do regex splitting in Python.""" + _unused_pat = regex.compile(self._pat_str) + ret = [] + for piece in regex.findall(_unused_pat, text): + ret.extend(self._core_bpe.encode_single_piece(piece)) + return ret + + def _encode_bytes(self, text: bytes) -> list[int]: + return self._core_bpe._encode_bytes(text) + + +@functools.lru_cache(maxsize=128) +def _special_token_regex(tokens: frozenset[str]) -> "regex.Pattern[str]": + inner = "|".join(regex.escape(token) for token in tokens) + return regex.compile(f"({inner})") + + +def raise_disallowed_special_token(token: str) -> NoReturn: + raise ValueError( + f"Encountered text corresponding to disallowed special token {token!r}.\n" + "If you want this text to be encoded as a special token, " + f"pass it to `allowed_special`, e.g. `allowed_special={{{token!r}, ...}}`.\n" + f"If you want this text to be encoded as normal text, disable the check for this token " + f"by passing `disallowed_special=(enc.special_tokens_set - {{{token!r}}})`.\n" + "To disable this check for all special tokens, pass `disallowed_special=()`.\n" + ) diff --git a/tiktoken/load.py b/tiktoken/load.py new file mode 100644 index 00000000..06e51cc3 --- /dev/null +++ b/tiktoken/load.py @@ -0,0 +1,97 @@ +import base64 +import hashlib +import json +import os +import uuid + +import blobfile + + +def read_file_cached(blobpath: str) -> bytes: + if "TIKTOKEN_CACHE_DIR" in os.environ: + cache_dir = os.environ["TIKTOKEN_CACHE_DIR"] + elif "DATA_GYM_CACHE_DIR" in os.environ: + cache_dir = os.environ["DATA_GYM_CACHE_DIR"] + else: + cache_dir = "/tmp/data-gym-cache" + + if cache_dir == "": + # disable caching + with blobfile.BlobFile(blobpath, "rb") as f: + return f.read() + + cache_key = hashlib.sha1(blobpath.encode()).hexdigest() + + cache_path = os.path.join(cache_dir, cache_key) + if os.path.exists(cache_path): + with open(cache_path, "rb") as f: + return f.read() + + with blobfile.BlobFile(blobpath, "rb") as f: + contents = f.read() + + os.makedirs(cache_dir, exist_ok=True) + tmp_filename = cache_path + "." 
+ str(uuid.uuid4()) + ".tmp" + with open(tmp_filename, "wb") as f: + f.write(contents) + os.rename(tmp_filename, cache_path) + + return contents + + +def data_gym_to_mergeable_bpe_ranks( + vocab_bpe_file: str, encoder_json_file: str +) -> dict[bytes, int]: + # NB: do not add caching to this function + rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "] + + data_gym_byte_to_byte = {chr(b): b for b in rank_to_intbyte} + n = 0 + for b in range(2**8): + if b not in rank_to_intbyte: + rank_to_intbyte.append(b) + data_gym_byte_to_byte[chr(2**8 + n)] = b + n += 1 + assert len(rank_to_intbyte) == 2**8 + + # vocab_bpe contains the merges along with associated ranks + vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode() + bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]] + + def decode_data_gym(value: str) -> bytes: + return bytes(data_gym_byte_to_byte[b] for b in value) + + # add the single byte tokens + bpe_ranks = {bytes([b]): i for i, b in enumerate(rank_to_intbyte)} + # add the merged tokens + n = len(bpe_ranks) + for first, second in bpe_merges: + bpe_ranks[decode_data_gym(first) + decode_data_gym(second)] = n + n += 1 + + # check that the encoder file matches the merges file + # this sanity check is important since tiktoken assumes that ranks are ordered the same + # as merge priority + encoder_json = json.loads(read_file_cached(encoder_json_file)) + encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()} + # drop these two special tokens if present, since they're not mergeable bpe tokens + encoder_json_loaded.pop(b"<|endoftext|>", None) + encoder_json_loaded.pop(b"<|startoftext|>", None) + assert bpe_ranks == encoder_json_loaded + + return bpe_ranks + + +def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None: + with blobfile.BlobFile(tiktoken_bpe_file, "wb") as f: + for token, rank in sorted(bpe_ranks.items(), key=lambda x: x[1]): + f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n") + + +def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]: + # NB: do not add caching to this function + contents = read_file_cached(tiktoken_bpe_file) + return { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in contents.splitlines() if line) + } diff --git a/tiktoken/registry.py b/tiktoken/registry.py new file mode 100644 index 00000000..24bb1737 --- /dev/null +++ b/tiktoken/registry.py @@ -0,0 +1,71 @@ +import importlib +import pkgutil +import threading +from typing import Any, Callable, Optional + +import tiktoken_ext + +from tiktoken.core import Encoding + +_lock = threading.RLock() +ENCODINGS: dict[str, Encoding] = {} +ENCODING_CONSTRUCTORS: Optional[dict[str, Callable[[], dict[str, Any]]]] = None + + +def _find_constructors() -> None: + global ENCODING_CONSTRUCTORS + with _lock: + if ENCODING_CONSTRUCTORS is not None: + return + ENCODING_CONSTRUCTORS = {} + + # tiktoken_ext is a namespace package + # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes + # - we use namespace package pattern so `pkgutil.iter_modules` is fast + # - it's a separate top-level package because namespace subpackages of non-namespace + # packages don't quite do what you want with editable installs + plugin_mods = pkgutil.iter_modules(tiktoken_ext.__path__, tiktoken_ext.__name__ + ".") + + for _, mod_name, _ in plugin_mods: + mod = importlib.import_module(mod_name) + try: + constructors = 
mod.ENCODING_CONSTRUCTORS + except AttributeError as e: + raise ValueError( + f"tiktoken plugin {mod_name} does not define ENCODING_CONSTRUCTORS" + ) from e + for enc_name, constructor in constructors.items(): + if enc_name in ENCODING_CONSTRUCTORS: + raise ValueError( + f"Duplicate encoding name {enc_name} in tiktoken plugin {mod_name}" + ) + ENCODING_CONSTRUCTORS[enc_name] = constructor + + +def get_encoding(encoding_name: str) -> Encoding: + if encoding_name in ENCODINGS: + return ENCODINGS[encoding_name] + + with _lock: + if encoding_name in ENCODINGS: + return ENCODINGS[encoding_name] + + if ENCODING_CONSTRUCTORS is None: + _find_constructors() + assert ENCODING_CONSTRUCTORS is not None + + if encoding_name not in ENCODING_CONSTRUCTORS: + raise ValueError(f"Unknown encoding {encoding_name}") + + constructor = ENCODING_CONSTRUCTORS[encoding_name] + enc = Encoding(**constructor()) + ENCODINGS[encoding_name] = enc + return enc + + +def list_encoding_names() -> list[str]: + with _lock: + if ENCODING_CONSTRUCTORS is None: + _find_constructors() + assert ENCODING_CONSTRUCTORS is not None + return list(ENCODING_CONSTRUCTORS) diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py new file mode 100644 index 00000000..cc6ad3cd --- /dev/null +++ b/tiktoken_ext/openai_public.py @@ -0,0 +1,41 @@ +from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe + +ENDOFTEXT = "<|endoftext|>" +FIM_PREFIX = "<|fim_prefix|>" +FIM_MIDDLE = "<|fim_middle|>" +FIM_SUFFIX = "<|fim_suffix|>" +ENDOFPROMPT = "<|endofprompt|>" + + +def gpt2(): + mergeable_ranks = data_gym_to_mergeable_bpe_ranks( + vocab_bpe_file="az://openaipublic/gpt-2/encodings/main/vocab.bpe", + encoder_json_file="az://openaipublic/gpt-2/encodings/main/encoder.json", + ) + return { + "name": "gpt2", + "explicit_n_vocab": 50257, + "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", + "mergeable_ranks": mergeable_ranks, + "special_tokens": {"<|endoftext|>": 50256}, + } + + +def cl100k_base(): + mergeable_ranks = load_tiktoken_bpe("az://openaipublic/encodings/cl100k_base.tiktoken") + special_tokens = { + ENDOFTEXT: 100257, + FIM_PREFIX: 100258, + FIM_MIDDLE: 100259, + FIM_SUFFIX: 100260, + ENDOFPROMPT: 100276, + } + return { + "name": "cl100k_base", + "pat_str": r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""", + "mergeable_ranks": mergeable_ranks, + "special_tokens": special_tokens, + } + + +ENCODING_CONSTRUCTORS = {"gpt2": gpt2, "cl100k_base": cl100k_base}
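As the comments in `tiktoken/registry.py` explain, encodings are discovered by importing every submodule of the `tiktoken_ext` namespace package and reading its `ENCODING_CONSTRUCTORS` dict, exactly as `tiktoken_ext/openai_public.py` does above. Below is a minimal sketch of what a third-party plugin following that pattern could look like; the module name `my_encodings`, the encoding name `my50k_base`, the blob path, and the special-token id are hypothetical, and the split pattern is simply borrowed from the `gpt2` entry above for illustration.

```python
# tiktoken_ext/my_encodings.py  (hypothetical plugin module)
# Note: tiktoken_ext must remain a namespace package, so do not add an __init__.py.
from tiktoken.load import load_tiktoken_bpe

ENDOFTEXT = "<|endoftext|>"


def my50k_base():
    # The returned dict is passed as keyword arguments to tiktoken.Encoding
    # by tiktoken.registry.get_encoding.
    mergeable_ranks = load_tiktoken_bpe("az://my-bucket/encodings/my50k_base.tiktoken")
    return {
        "name": "my50k_base",
        "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": {ENDOFTEXT: 50000},
    }


ENCODING_CONSTRUCTORS = {"my50k_base": my50k_base}
```

Once such a package is installed, `tiktoken.get_encoding("my50k_base")` and `tiktoken.list_encoding_names()` pick the new encoding up without any changes to tiktoken itself.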