Commit e08bf4a

Authored May 28, 2024
Add eval for accuracy of fine-tune vs. original CLIP
1 parent 34543e3 commit e08bf4a

7 files changed: +1019 -0 lines changed
 

ft-D-eval-imagenet-objectnet.py

+94
@@ -0,0 +1,94 @@
import torch
from torch.utils.data import Dataset, DataLoader
# Import original CLIP code with modification to bypass SHA256 checksum verification.
# Don't use this to load arbitrary third-party models; google "pickle vulnerability" for details.
# However, this allows you to use clip.load on your own (safe) fine-tuned model:
from orgclipnosha import clip
from PIL import Image
from tqdm import tqdm
import pandas as pd
import os

# Download from https://objectnet.dev/mvt/

# With "normal" fine-tuning, your model is expected to overfit on your dataset and, alas,
# become worse at generalizing to the dataset above. This is NORMAL.
# Accuracy of 0.5 - 0.7 generally means "good preservation of pre-training",
# likely a great model when using this CLIP exclusively as the text encoder (TE) for SDXL.
# Accuracy of 0.3 can be great if you trained on some "glitch art" / weird dataset.
# Accuracy < 0.1: No. You ruined the model; it won't be a good guide / TE.

# I am mainly adding this code for replication of my GmP-CLIP results.
# Make sure you point the fine-tuned model path below to the GmP model after "convert-GmP-back-to-weight":

# Load CSV labels file from dataset:
csv_file = 'path/to/mvt/dataset/data_release_2023/human_responses.csv'

clipmodel = 'ViT-L/14'
# Your fine-tuned model below:
finetunedclip = "ft-checkpoints/clip_ft_20.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models
original_model, preprocess = clip.load(clipmodel, device=device, jit=False)
finetuned_model, preprocess = clip.load(finetunedclip, device=device)

# Dataset class to load images and their corresponding labels from CSV
class CroppedImageCSVFileDataset(Dataset):
    def __init__(self, csv_file, image_folder, transform=None):
        self.data = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_name = self.data.iloc[idx]['image']
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert('RGB')  # Convert to RGB
        if self.transform:
            image = self.transform(image)

        label = self.data.iloc[idx]['label']

        return image, label

# Path to the image folder that contains ALL images from the MVT dataset:
image_folder = 'path/to/mvt/dataset/data_release_2023/all/'

# Create dataset and dataloader
dataset = CroppedImageCSVFileDataset(csv_file, image_folder, transform=preprocess)
dataloader = DataLoader(dataset, batch_size=48, shuffle=True)

# Function to evaluate a model on the custom dataset (in-batch top-1 image-to-text accuracy)
def evaluate_model(model, dataloader):
    correct = 0
    total = 0

    for batch_images, batch_labels in tqdm(dataloader):
        batch_images = batch_images.to(device)
        batch_texts = clip.tokenize(batch_labels).to(device)

        with torch.no_grad():
            image_embeddings = model.encode_image(batch_images)
            text_embeddings = model.encode_text(batch_texts)
            logits_per_image = (image_embeddings @ text_embeddings.T).softmax(dim=-1)

        # Get the top predictions
        _, top_indices = logits_per_image.topk(1, dim=-1)

        for i, label in enumerate(batch_labels):
            # Correct if the best-matching caption in the batch is this image's own label
            if label == batch_labels[top_indices[i]]:
                correct += 1
            total += 1

    accuracy = correct / total
    return accuracy

# Evaluate original and fine-tuned models
original_accuracy = evaluate_model(original_model, dataloader)
finetuned_accuracy = evaluate_model(finetuned_model, dataloader)

print(f"Original Model Accuracy on MVT ImageNet/ObjectNet: {original_accuracy}")
print(f"Fine-tuned Model Accuracy on MVT ImageNet/ObjectNet: {finetuned_accuracy}")
+109
@@ -0,0 +1,109 @@
import torch
from torch.utils.data import DataLoader
# Import original CLIP code with modification to bypass SHA256 checksum verification.
# Don't use this to load arbitrary third-party models; google "pickle vulnerability" for details.
# However, this allows you to use clip.load on your own (safe) fine-tuned model:
from orgclipnosha import clip
from PIL import Image
import json
from tqdm import tqdm
import os
import random
from torch.utils.data import Dataset

class ImageTextDataset(Dataset):
    def __init__(self, image_folder, annotations_file, transform=None):
        self.image_folder = image_folder
        self.transform = transform
        with open(annotations_file, 'r') as f:
            self.annotations = json.load(f)
        self.image_paths = list(self.annotations.keys())

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_folder, self.image_paths[idx])
        image = Image.open(image_path).convert('RGB')  # Convert to RGB
        if self.transform:
            image = self.transform(image)

        labels = self.annotations[self.image_paths[idx]]

        # Reused the normal dataloader from training; the threshold below is deliberately
        # set higher than the number of labels per image, so the "elif" branch is used:
        if len(labels) >= 20:
            label = random.choice([labels[0], labels[1]])
        elif labels:
            label = labels[1]  # Use second label = short original COCO dataset label
        else:
            label = ''  # Fallback if no labels are available

        return image, label


clipmodel = 'ViT-L/14'
# Your fine-tuned model below:
finetunedclip = "ft-checkpoints/clip_ft_20.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load models with clip.load()
original_model, preprocess = clip.load(clipmodel, device=device, jit=False)
finetuned_model, preprocess = clip.load(finetunedclip, device=device)

# Load the validation dataset the fine-tune did NOT learn on:
val_dataset = ImageTextDataset(
    "path/to/image/folder",
    "path/to/validation/dataset/labels.json",
    transform=preprocess
)

# Load the train dataset the fine-tune has learned on (overfit eval):
train_dataset = ImageTextDataset(
    "path/to/image/folder",
    "path/to/training/dataset/labels.json",
    transform=preprocess
)

batch_size = 48
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


def evaluate_model(model, data_loader):
    correct = 0
    total = 0

    for batch_images, batch_labels in tqdm(data_loader):
        batch_images = batch_images.to(device)
        batch_texts = clip.tokenize(batch_labels).to(device)

        with torch.no_grad():
            image_embeddings = model.encode_image(batch_images)
            text_embeddings = model.encode_text(batch_texts)
            logits_per_image = (image_embeddings @ text_embeddings.T).softmax(dim=-1)

        # Get the top predictions
        _, top_indices = logits_per_image.topk(1, dim=-1)

        for i, label in enumerate(batch_labels):
            # Correct if the best-matching caption in the batch is this image's own label
            if label == batch_labels[top_indices[i]]:
                correct += 1
            total += 1

    accuracy = correct / total
    return accuracy

original_accuracy = evaluate_model(original_model, val_loader)
finetuned_accuracy = evaluate_model(finetuned_model, val_loader)

print(f"Original Model Accuracy on val: {original_accuracy}")
print(f"Fine-tuned Model Accuracy on val: {finetuned_accuracy}")

original_accuracy_train = evaluate_model(original_model, train_loader)
finetuned_accuracy_train = evaluate_model(finetuned_model, train_loader)

print(f"Original Model Accuracy on train: {original_accuracy_train}")
print(f"Fine-tuned Model Accuracy on train: {finetuned_accuracy_train}")

print("\nNote: Your fine-tune should be better than the original model. However, if the difference on 'train' far exceeds the difference on 'val', this suggests overfitting (bad).")

orgclipnosha/__init__.py

+1
@@ -0,0 +1 @@
from .clip import *
1.29 MB
Binary file not shown.

orgclipnosha/clip.py

+247
@@ -0,0 +1,247 @@
import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}

def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    # Commenting out the expected_sha256 as we're bypassing checksum verification
    # expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        # Bypassing the SHA256 checksum verification
        return download_target

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # Bypassing the SHA256 checksum verification on download completion
    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            #model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            model = torch.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            state_dict = torch.load(opened_file, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def _node_get(node: torch._C.Node, key: str):
        """Gets attributes of a node which is polymorphic over return type.

        From https://github.com/pytorch/pytorch/pull/82628
        """
        sel = node.kindOf(key)
        return getattr(node, sel)(key)

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if _node_get(inputs[i].node(), "value") == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
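
One practical consequence of the modified loader above: its try branch calls torch.load(...).eval() on a local checkpoint path, so the checkpoint is expected to be a whole pickled CLIP model object rather than a bare state_dict. A minimal sketch of saving and reloading a fine-tune in that form, under that assumption and with a hypothetical checkpoint path:

# Sketch: save a fine-tuned CLIP so the modified clip.load() above can reload it.
import os
import torch
from orgclipnosha import clip

model, preprocess = clip.load("ViT-L/14", device="cpu", jit=False)
# ... fine-tuning of `model` would happen here ...

os.makedirs("ft-checkpoints", exist_ok=True)
torch.save(model, "ft-checkpoints/clip_ft_example.pt")  # whole model object, matching the eval scripts' checkpoints

reloaded, preprocess = clip.load("ft-checkpoints/clip_ft_example.pt", device="cpu")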

orgclipnosha/model.py

+436
Large diffs are not rendered by default.

orgclipnosha/simple_tokenizer.py

+132
@@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
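
As a small usage sketch of the tokenizer above (the example text is arbitrary): encode() cleans and lowercases the input and returns BPE token ids, decode() reverses the mapping, and the <|startoftext|>/<|endoftext|> tokens are added by clip.tokenize(), not by encode().

from orgclipnosha.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("A photo of a cat")
print(ids)                    # a short list of integer BPE ids
print(tokenizer.decode(ids))  # "a photo of a cat " (lowercased; '</w>' markers rendered as spaces)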
