Add support for checking the hash of downloaded files before use.
Matt Welsh committed Dec 21, 2023
1 parent 9e79899 commit c997326
Showing 3 changed files with 35 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "tiktoken"
-version = "0.5.2"
+version = "0.6.0"
 description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
 readme = "README.md"
 license = {file = "LICENSE"}
31 changes: 24 additions & 7 deletions tiktoken/load.py
@@ -6,6 +6,7 @@
 import os
 import tempfile
 import uuid
+from typing import Optional

 import requests

@@ -26,7 +27,12 @@ def read_file(blobpath: str) -> bytes:
     return resp.content


-def read_file_cached(blobpath: str) -> bytes:
+def check_hash(data: bytes, hash: str) -> bool:
+    data_hash = hashlib.sha256(data).hexdigest()
+    return data_hash == hash
+
+
+def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:
     user_specified_cache = True
     if "TIKTOKEN_CACHE_DIR" in os.environ:
         cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
@@ -45,9 +51,20 @@ def read_file_cached(blobpath: str) -> bytes:
     cache_path = os.path.join(cache_dir, cache_key)
     if os.path.exists(cache_path):
         with open(cache_path, "rb") as f:
-            return f.read()
+            data = f.read()
+        if expected_hash and not check_hash(data, expected_hash):
+            raise ValueError(
+                f"Hash mismatch for cached data from {blobpath} (expected {expected_hash}). "
+                f"Please delete the cache file at {cache_path} and try again."
+            )
+        return data

     contents = read_file(blobpath)
+    if expected_hash and not check_hash(contents, expected_hash):
+        raise ValueError(
+            f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
+            f"This may indicate a corrupted download. Please try again."
+        )

     try:
         os.makedirs(cache_dir, exist_ok=True)
@@ -64,7 +81,7 @@


 def data_gym_to_mergeable_bpe_ranks(
-    vocab_bpe_file: str, encoder_json_file: str
+    vocab_bpe_file: str, encoder_json_file: str, vocab_bpe_hash: Optional[str]=None, encoder_json_hash: Optional[str]=None
 ) -> dict[bytes, int]:
     # NB: do not add caching to this function
     rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
@@ -79,7 +96,7 @@ def data_gym_to_mergeable_bpe_ranks(
     assert len(rank_to_intbyte) == 2**8

     # vocab_bpe contains the merges along with associated ranks
-    vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode()
+    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
     bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]

     def decode_data_gym(value: str) -> bytes:
@@ -96,7 +113,7 @@ def decode_data_gym(value: str) -> bytes:
     # check that the encoder file matches the merges file
     # this sanity check is important since tiktoken assumes that ranks are ordered the same
     # as merge priority
-    encoder_json = json.loads(read_file_cached(encoder_json_file))
+    encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
     encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
     # drop these two special tokens if present, since they're not mergeable bpe tokens
     encoder_json_loaded.pop(b"<|endoftext|>", None)
@@ -118,9 +135,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
             f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")


-def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
+def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: Optional[str]=None) -> dict[bytes, int]:
     # NB: do not add caching to this function
-    contents = read_file_cached(tiktoken_bpe_file)
+    contents = read_file_cached(tiktoken_bpe_file, expected_hash)
     return {
         base64.b64decode(token): int(rank)
         for token, rank in (line.split() for line in contents.splitlines() if line)
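Note: the check is a straight SHA-256 hex digest of the raw file bytes, so the pinned digests in tiktoken_ext/openai_public.py below can be reproduced with a few lines of Python. A minimal sketch, assuming the file is fetched with requests the same way read_file does (URL and digest taken from this commit's r50k_base entry):

import hashlib

import requests

url = "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
resp = requests.get(url)
resp.raise_for_status()
# Same operation as check_hash above: sha256 over the raw bytes, compared as hex.
print(hashlib.sha256(resp.content).hexdigest())
# expected: 306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930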
14 changes: 10 additions & 4 deletions tiktoken_ext/openai_public.py
@@ -11,6 +11,8 @@ def gpt2():
     mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
         vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
         encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
+        vocab_bpe_hash="1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5",
+        encoder_json_hash="196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783",
     )
     return {
         "name": "gpt2",
@@ -23,7 +25,8 @@ def gpt2():

 def r50k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
+        expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
     )
     return {
         "name": "r50k_base",
@@ -36,7 +39,8 @@ def r50k_base():

 def p50k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
+        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
     )
     return {
         "name": "p50k_base",
@@ -49,7 +53,8 @@ def p50k_base():

 def p50k_edit():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
+        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
     )
     special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
     return {
@@ -62,7 +67,8 @@ def p50k_edit():

 def cl100k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
+        expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
     )
     special_tokens = {
         ENDOFTEXT: 100257,
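Note: callers outside these plugin definitions can use the new parameter the same way. A minimal sketch against load_tiktoken_bpe, reusing the cl100k_base URL and digest pinned above; a wrong digest makes the loader raise ValueError instead of silently returning corrupted ranks:

from tiktoken.load import load_tiktoken_bpe

url = "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
ranks = load_tiktoken_bpe(
    url,
    expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
)
print(len(ranks))  # number of mergeable BPE tokens

# A deliberately wrong digest fails loudly; once the file is cached, the error
# comes from the cached-data branch and suggests deleting the cache file.
try:
    load_tiktoken_bpe(url, expected_hash="0" * 64)
except ValueError as e:
    print(e)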
