import argparse
import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader

from transformers import BertConfig, BertForQuestionAnswering, BertTokenizer
from transformers import squad_convert_examples_to_features
from transformers.data.processors.squad import SquadV1Processor
# Parse command-line parameters up front, before any data is downloaded or converted
parser = argparse.ArgumentParser()
parser.add_argument(
    "--batch-size",
    type=int,
    default=8,
    help="Batch size for inference (default: 8)",
)
args = parser.parse_args()
batch_size = args.batch_size
cache_dir = os.path.join(".", "cache_models")
if not os.path.exists(cache_dir):
    os.makedirs(cache_dir)
# The predict file contains the SQuAD v1.1 dev set: questions, contexts, and answers for the model.
predict_file_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"
predict_file = os.path.join(cache_dir, "dev-v1.1.json")
if not os.path.exists(predict_file):
    import wget

    print("Start downloading the predict file.")
    wget.download(predict_file_url, predict_file)
    print("Predict file downloaded.")
####################################################################
# Define the BERT base model and its sequence-length settings
model_name_or_path = "bert-base-cased"
max_seq_length = 128
doc_stride = 128  # stride used when a context is too long and has to be split across several features
max_query_length = 64
# Total number of samples to run inference on; it should be large enough to get a stable latency measurement.
total_samples = 100
device = torch.device("cpu")
####################################################################
# Load pretrained model and tokenizer
config_class, model_class, tokenizer_class = (
    BertConfig,
    BertForQuestionAnswering,
    BertTokenizer,
)
config = config_class.from_pretrained(model_name_or_path, cache_dir=cache_dir)
# print(config)
tokenizer = tokenizer_class.from_pretrained(
    model_name_or_path, do_lower_case=False, cache_dir=cache_dir  # cased checkpoint, so keep the original casing
)
model = model_class.from_pretrained(
    model_name_or_path, from_tf=False, config=config, cache_dir=cache_dir
)
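# Optional (not in the original script): report the model size, which is
# useful context when comparing the latency numbers measured below.
num_params = sum(p.numel() for p in model.parameters())
print("Loaded {} with {:.1f}M parameters".format(model_name_or_path, num_params / 1e6))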
# Load some examples
processor = SquadV1Processor()
examples = processor.get_dev_examples(None, filename=predict_file)
features, dataset = squad_convert_examples_to_features(
    examples=examples[:total_samples],  # convert just enough examples for this script
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    doc_stride=doc_stride,
    max_query_length=max_query_length,
    is_training=False,
    return_dataset="pt",
    threads=1,
)
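# Optional sanity check (not in the original script): each item of the resulting
# TensorDataset is a tuple of tensors whose first three entries are input_ids,
# attention_mask and token_type_ids, each of length max_seq_length.
print(
    "Converted {} features; first example tensor shapes: {}".format(
        len(dataset), [tuple(t.shape) for t in dataset[0][:3]]
    )
)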
# Create a DataLoader over the feature dataset
data_loader = DataLoader(
    dataset, batch_size=batch_size, shuffle=False  # keep the original example order
)
# Output location for the exported ONNX model (the export itself is not part of this script)
output_dir = os.path.join(".", "onnx_models")
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
export_model_path = os.path.join(output_dir, "bert-base-cased-squad.onnx")
# Set the model to inference mode and move it to the target device
model.eval()
model.to(device)
# Take one batch from the DataLoader to use for the timed runs
test_iterator = iter(data_loader)
batch = next(test_iterator)
input_ids = batch[0].to(device) # Shape: (batch_size, max_seq_length)
attention_mask = batch[1].to(device) # Shape: (batch_size, max_seq_length)
token_type_ids = batch[2].to(device) # Shape: (batch_size, max_seq_length)
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "token_type_ids": token_type_ids,
}
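# Optional warm-up (an addition, not in the original script): a few untimed
# forward passes let one-time costs such as thread-pool start-up and memory
# allocation settle before the timed loop below.
with torch.no_grad():
    for _ in range(5):
        model(**inputs)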
# Measure inference time: run 50 timed iterations under torch.no_grad()
# (gradients are not needed for inference) and average to reduce measurement noise.
num_iterations = 50
start = time.time()
with torch.no_grad():
    for _ in range(num_iterations):
        outputs = model(**inputs)
end = time.time()
latency = (end - start) / num_iterations
# Report the average per-batch latency
print("Average inference time for batch size {} is {:.4f} seconds".format(batch_size, latency))