
Commit cffba26

feat: support reproducibility expr
1 parent 730c21e commit cffba26

14 files changed, +1019 -0 lines changed

deserve_benchmark/__init__.py

Whitespace-only changes.

deserve_benchmark/benchmark/__init__.py

Whitespace-only changes.
+194
@@ -0,0 +1,194 @@
import argparse
import json
import logging
import multiprocessing
import pickle
import random
import threading
import time
import uuid
from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor, wait
from queue import Queue
from typing import Any, Optional

import requests
import torch
from flask import Flask, request
from safetensors.torch import load, save
from transformers import AutoTokenizer  # type: ignore

from deserve_benchmark.rater import Rater, RaterTimeLimitExceeded, Response
from deserve_benchmark.workload.oasst1 import Oasst1Dataset
from deserve_benchmark.workload.sharegpt import ShareGptDataset
from deserve_benchmark.workload.static import StaticWorkload
from deserve_benchmark.workload.utils import Workload

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

stop_tokens = {128001, 128009}


def dumps(tensors: dict[str, torch.Tensor], metadata: dict[str, Any]) -> bytes:
    """
    Dump tensors and metadata into bytes
    """

    metadata_bytes = pickle.dumps(metadata)
    sharp_tensors = {}
    for k, v in tensors.items():
        if v.numel() == 0:
            sharp_tensors[f"#{k}"] = torch.ones((1,), dtype=v.dtype)
        else:
            sharp_tensors[k] = v
    tensors_bytes = save(sharp_tensors)
    return (
        len(tensors_bytes).to_bytes(4, byteorder="big") + tensors_bytes + metadata_bytes
    )


class DeServeClient:
    def __init__(
        self,
        workload: Workload,
        time_limit: int,
        first_worker_url: str,
        batch_size: int,
        max_tokens: int,
        trace: bool,
        warmup: int,
        variance: int,
    ):
        self.first_worker_url = first_worker_url
        self.batch_size = batch_size
        self.max_tokens = max_tokens
        self.network_executor = ThreadPoolExecutor(max_workers=128)
        self.deserve_executor = ThreadPoolExecutor(max_workers=128)
        self.time_limit = time_limit
        self.rater = Rater(
            workload=workload, time_limit=time_limit, trace=trace, warmup=warmup
        )
        self.variance = variance

    def flask_service(self, events: Queue[int | None]) -> None:  # type: ignore
        app = Flask(__name__)
        app.logger.setLevel(logging.ERROR)
        logging.getLogger("werkzeug").setLevel(logging.ERROR)

        @app.route("/update_tasks", methods=["POST"])
        def update_tasks() -> str:
            request_json = request.json
            if request_json is None:
                return "No"
            data: list[dict[str, Any]] = request_json
            for task in data:
                request_id = int(task["task_id"].split("@")[0])
                token = task["output_token"]
                char = tokenizer.decode(token)
                try:
                    self.rater.post(
                        Response(
                            id=request_id, payload=char, finished=(token in stop_tokens)
                        )
                    )
                except RaterTimeLimitExceeded:
                    events.put(None)
                if token in stop_tokens:
                    events.put(request_id)
            return "OK"

        app.run(host="0.0.0.0", port=19000, debug=False)

    def polling(self, queue: Queue[int | None]) -> None:
        current = 0
        while True:
            if current >= self.batch_size:
                value = queue.get()
                if value is None:
                    break
            else:
                current += 1
            history = self.rater.get(1)
            if len(history) == 0:
                break
            id = history[0].id
            prompt = history[0].history
            tokens = tokenizer.encode(prompt, return_tensors="pt")[0]
            tensors = {"x": tokens}
            metadata = {
                "task_id": str(id) + "@" + str(uuid.uuid4()),
                "sampling_params": {
                    "temperature": 0.0,
                    "top_p": 1.0,
                    "max_new_tokens": self.max_tokens
                    + random.randint(-self.variance, self.variance),
                },
            }
            response = requests.post(
                f"{self.first_worker_url}/prefill",
                data=dumps(tensors, metadata),
            )

    def speedtest(self) -> dict[str, Any]:
        queue: Queue[int | None] = Queue()
        flask_thread = threading.Thread(
            target=self.flask_service, args=[queue], daemon=True
        )
        flask_thread.start()
        polling_thread = threading.Thread(
            target=self.polling, args=[queue], daemon=True
        )
        polling_thread.start()

        try:
            if self.time_limit > 0:
                for _ in range(self.time_limit):
                    time.sleep(1)
                    if not polling_thread.is_alive():
                        break
            else:
                while self.rater.requests_finished_total < self.rater.workload.size():
                    time.sleep(1)
        except KeyboardInterrupt:
            pass
        return self.rater.dump()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--time-limit", type=int, default=-1)
    parser.add_argument(
        "--batch-size", type=int, default=150
    )  # number of concurrent requests that the controller will send to the workers
    parser.add_argument(
        "--max-tokens", type=int, default=1024
    )  # max tokens per request
    parser.add_argument(
        "--workload", type=str, default="oasst1"
    )  # workload name; if it starts with "fixed", the format is "fixed{size}:{length}:{variance}"
    parser.add_argument("--first-worker-url", type=str, default="http://localhost:8080")
    parser.add_argument("--warmup", type=int, default=0)  # warmup time in seconds
    parser.add_argument("--trace", action="store_true", default=False)
    parser.add_argument("--variance", type=int, default=0)  # variance of output tokens
    args = parser.parse_args()

    if args.workload == "oasst1":
        workload = Oasst1Dataset().into_workload()
    elif args.workload == "sharegpt":
        workload = ShareGptDataset().into_workload()
    elif args.workload.startswith("fixed"):
        raw = args.workload[len("fixed") :]
        size, length, variance = map(int, raw.split(":"))
        workload = StaticWorkload(size, length, variance)
    else:
        raise ValueError(f"Unknown workload: {args.workload}")
    client = DeServeClient(
        workload=workload,
        time_limit=args.time_limit,
        first_worker_url=args.first_worker_url,
        batch_size=args.batch_size,
        max_tokens=args.max_tokens,
        trace=args.trace,
        warmup=args.warmup,
        variance=args.variance,
    )
    print(json.dumps(client.speedtest()))
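The worker-side decoder for this framing is not part of the diff. As a minimal sketch, assuming the worker honors the same convention produced by dumps() above (a 4-byte big-endian length of the safetensors blob, then the blob itself, then pickled metadata, with a "#" key prefix marking tensors that were empty before serialization), the inverse could look like this; the name loads and the module it would live in are hypothetical:

# Sketch only: hypothetical inverse of dumps(); not part of this commit.
import pickle
from typing import Any

import torch
from safetensors.torch import load


def loads(data: bytes) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
    # First 4 bytes: big-endian length of the safetensors payload.
    tensors_len = int.from_bytes(data[:4], byteorder="big")
    tensors_bytes = data[4 : 4 + tensors_len]
    metadata_bytes = data[4 + tensors_len :]
    tensors = load(tensors_bytes)
    # Undo the "#" marker that dumps() uses for originally empty tensors.
    restored: dict[str, torch.Tensor] = {}
    for k, v in tensors.items():
        if k.startswith("#"):
            restored[k[1:]] = torch.empty((0,), dtype=v.dtype)
        else:
            restored[k] = v
    return restored, pickle.loads(metadata_bytes)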
+143
@@ -0,0 +1,143 @@
import argparse
import json
import random
import threading
import time
from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor, wait
from typing import Any

import requests
from openai import OpenAI, Stream
from openai.types.completion import Completion
from transformers import AutoTokenizer  # type: ignore

from src.workload.static import StaticWorkload
from src.workload.utils import Workload

from ..rater import Rater, RaterTimeLimitExceeded, Response
from ..workload.oasst1 import Oasst1Dataset
from ..workload.sharegpt import ShareGptDataset


class OnlineVLLMClient:
    def __init__(
        self,
        model: str,
        workload: Workload,
        time_limit: int,
        url: str,
        batch_size: int,
        max_tokens: int,
        trace: bool,
        warmup: int,
        variance: int,
    ):
        self.url = url
        self.batch_size = batch_size
        self.max_tokens = max_tokens
        self.time_limit = time_limit
        self.network_executor = ThreadPoolExecutor(max_workers=128)
        self.vllm_executor = ThreadPoolExecutor(max_workers=128)
        self.rater = Rater(
            workload=workload, time_limit=time_limit, trace=trace, warmup=warmup
        )
        self.model = model
        self.openai_client = OpenAI(
            api_key="EMPTY",
            base_url=self.url,
        )
        self.variance = variance

    def polling(self) -> None:
        while True:
            completions = self.rater.get(1)
            if len(completions) == 0:
                # no more requests
                break
            id = completions[0].id
            history = completions[0].history
            try:
                chat_stream: Stream[Completion] = self.openai_client.completions.create(
                    model=self.model,
                    prompt=history,
                    max_tokens=self.max_tokens
                    + random.randint(-self.variance, self.variance),
                    temperature=0,
                    stream=True,
                )
            except Exception as e:
                print(e)
                raise e
            for chunk in chat_stream:
                content = chunk.choices[0].text
                if content is None:
                    continue
                try:
                    self.rater.post(Response(id=id, payload=content, finished=False))
                except RaterTimeLimitExceeded:
                    return
            self.rater.post(Response(id=id, payload="", finished=True))

    def routine(self) -> None:
        try:
            futures = []
            for _ in range(self.batch_size):
                futures.append(self.vllm_executor.submit(self.polling))
            wait(futures, return_when=ALL_COMPLETED)
        except KeyboardInterrupt:
            pass

    def speedtest(self) -> dict[str, Any]:
        routine_thread = threading.Thread(target=self.routine, daemon=True)
        routine_thread.start()
        try:
            if self.time_limit != -1:
                for _ in range(self.time_limit):
                    time.sleep(1)
                    if not routine_thread.is_alive():
                        break
            else:
                routine_thread.join()
        except KeyboardInterrupt:
            pass
        return self.rater.dump()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--time-limit", type=int, default=-1)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--max-tokens", type=int, default=1024)
    parser.add_argument("--url", type=str, default="http://localhost:8000/v1")
    parser.add_argument(
        "--model-name", type=str, default="meta-llama/Meta-Llama-3-70B-Instruct"
    )
    parser.add_argument("--workload", type=str, default="oasst1")
    parser.add_argument("--trace", action="store_true", default=False)
    parser.add_argument("--warmup", type=int, default=0)
    parser.add_argument("--variance", type=int, default=0)
    args = parser.parse_args()

    if args.workload == "oasst1":
        workload = Oasst1Dataset().into_workload()
    elif args.workload == "sharegpt":
        workload = ShareGptDataset().into_workload()
    elif args.workload.startswith("fixed"):
        raw = args.workload[len("fixed") :]
        size, length, variance = map(int, raw.split(":"))
        workload = StaticWorkload(size, length, variance)
    else:
        raise ValueError(f"Unknown workload: {args.workload}")
    client = OnlineVLLMClient(
        model=args.model_name,
        workload=workload,
        time_limit=args.time_limit,
        url=args.url,
        batch_size=args.batch_size,
        max_tokens=args.max_tokens,
        trace=args.trace,
        warmup=args.warmup,
        variance=args.variance,
    )
    result = client.speedtest()
    print(json.dumps(result))
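For quick experiments the argparse entry point can be bypassed; a minimal sketch, assuming OnlineVLLMClient is importable from the benchmark module above (its exact path is not shown in this diff) and mirroring what --workload fixed100:512:32 expands to:

# Hypothetical programmatic usage; module paths and values are assumptions
# based on the imports and argument parsing in the file above.
import json

from src.workload.static import StaticWorkload

workload = StaticWorkload(100, 512, 32)  # size, length, variance from "fixed100:512:32"
client = OnlineVLLMClient(
    model="meta-llama/Meta-Llama-3-70B-Instruct",
    workload=workload,
    time_limit=60,
    url="http://localhost:8000/v1",
    batch_size=8,
    max_tokens=1024,
    trace=False,
    warmup=0,
    variance=32,
)
print(json.dumps(client.speedtest()))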
