Skip to content

Commit

Permalink
Add support for checking hash of downloaded files before use. (openai…
Browse files Browse the repository at this point in the history
…#230)

We are using tiktoken in various production scenarios and sometimes have
the problem that the download of `.tiktoken` files (e.g.,
`cl100k_base.tiktoken`) will get interrupted or fail, causing the cached
file to be corrupted in some way. In those cases, the results returned
from the encoder will be incorrect and could be damaging to our
production instances.

More often, when this happens, `Encoder.encode()` will throw an
exception such as
```
pyo3_runtime.PanicException: no entry found for key
```
which turns out to be quite hard to track down.

In an effort to make tiktoken more robust for production use, this PR
adds the `sha256` hash of each of the downloaded files to
`openai_public.py` and augments `read_file` to check for the hash, if
provided, when the file is accessed from the cache or downloaded
directly. This causes errors to be flagged at file load time, rather
than when the files are used, and provides a more meaningful error
message indicating what might have gone wrong.

This also protects users of tiktoken from scenarios where a network
issue or MITM attack could have corrupted these files in transit.
  • Loading branch information
mdwelsh authored Jan 30, 2024
1 parent 9e79899 commit 3ee6c35
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
31 changes: 24 additions & 7 deletions tiktoken/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import tempfile
import uuid
from typing import Optional

import requests

Expand All @@ -26,7 +27,12 @@ def read_file(blobpath: str) -> bytes:
return resp.content


def read_file_cached(blobpath: str) -> bytes:
def check_hash(data: bytes, hash: str) -> bool:
data_hash = hashlib.sha256(data).hexdigest()
return data_hash == hash


def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:
user_specified_cache = True
if "TIKTOKEN_CACHE_DIR" in os.environ:
cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
Expand All @@ -45,9 +51,20 @@ def read_file_cached(blobpath: str) -> bytes:
cache_path = os.path.join(cache_dir, cache_key)
if os.path.exists(cache_path):
with open(cache_path, "rb") as f:
return f.read()
data = f.read()
if expected_hash and not check_hash(data, expected_hash):
raise ValueError(
f"Hash mismatch for cached data from {blobpath} (expected {expected_hash}). "
f"Please delete the cache file at {cache_path} and try again."
)
return data

contents = read_file(blobpath)
if expected_hash and not check_hash(contents, expected_hash):
raise ValueError(
f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
f"This may indicate a corrupted download. Please try again."
)

try:
os.makedirs(cache_dir, exist_ok=True)
Expand All @@ -64,7 +81,7 @@ def read_file_cached(blobpath: str) -> bytes:


def data_gym_to_mergeable_bpe_ranks(
vocab_bpe_file: str, encoder_json_file: str
vocab_bpe_file: str, encoder_json_file: str, vocab_bpe_hash: Optional[str]=None, encoder_json_hash: Optional[str]=None
) -> dict[bytes, int]:
# NB: do not add caching to this function
rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
Expand All @@ -79,7 +96,7 @@ def data_gym_to_mergeable_bpe_ranks(
assert len(rank_to_intbyte) == 2**8

# vocab_bpe contains the merges along with associated ranks
vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode()
vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]

def decode_data_gym(value: str) -> bytes:
Expand All @@ -96,7 +113,7 @@ def decode_data_gym(value: str) -> bytes:
# check that the encoder file matches the merges file
# this sanity check is important since tiktoken assumes that ranks are ordered the same
# as merge priority
encoder_json = json.loads(read_file_cached(encoder_json_file))
encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
# drop these two special tokens if present, since they're not mergeable bpe tokens
encoder_json_loaded.pop(b"<|endoftext|>", None)
Expand All @@ -118,9 +135,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")


def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: Optional[str]=None) -> dict[bytes, int]:
# NB: do not add caching to this function
contents = read_file_cached(tiktoken_bpe_file)
contents = read_file_cached(tiktoken_bpe_file, expected_hash)
return {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
Expand Down
14 changes: 10 additions & 4 deletions tiktoken_ext/openai_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def gpt2():
mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
vocab_bpe_hash="1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5",
encoder_json_hash="196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783",
)
return {
"name": "gpt2",
Expand All @@ -23,7 +25,8 @@ def gpt2():

def r50k_base():
mergeable_ranks = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
)
return {
"name": "r50k_base",
Expand All @@ -36,7 +39,8 @@ def r50k_base():

def p50k_base():
mergeable_ranks = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
)
return {
"name": "p50k_base",
Expand All @@ -49,7 +53,8 @@ def p50k_base():

def p50k_edit():
mergeable_ranks = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
)
special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
return {
Expand All @@ -62,7 +67,8 @@ def p50k_edit():

def cl100k_base():
mergeable_ranks = load_tiktoken_bpe(
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
)
special_tokens = {
ENDOFTEXT: 100257,
Expand Down

0 comments on commit 3ee6c35

Please sign in to comment.