import argparse
import json
import logging
import pickle
import random
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import Any

import requests
import torch
from flask import Flask, request
from safetensors.torch import load, save
from transformers import AutoTokenizer  # type: ignore

from deserve_benchmark.rater import Rater, RaterTimeLimitExceeded, Response
from deserve_benchmark.workload.oasst1 import Oasst1Dataset
from deserve_benchmark.workload.sharegpt import ShareGptDataset
from deserve_benchmark.workload.static import StaticWorkload
from deserve_benchmark.workload.utils import Workload

# Prompts are tokenized and generated token ids are decoded locally with the
# Llama 3 tokenizer.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Llama 3 end-of-sequence token ids: 128001 (<|end_of_text|>) and 128009 (<|eot_id|>).
stop_tokens = {128001, 128009}


def dumps(tensors: dict[str, torch.Tensor], metadata: dict[str, Any]) -> bytes:
    """
    Serialize tensors and metadata into a single byte payload.

    Layout: 4-byte big-endian length of the safetensors blob, followed by the
    blob itself, followed by the pickled metadata.
    """
    metadata_bytes = pickle.dumps(metadata)
    sharp_tensors: dict[str, torch.Tensor] = {}
    for k, v in tensors.items():
        if v.numel() == 0:
            # Empty tensors are replaced by a one-element placeholder; the "#"
            # prefix marks keys that were substituted this way.
            sharp_tensors[f"#{k}"] = torch.ones((1,), dtype=v.dtype)
        else:
            sharp_tensors[k] = v
    tensors_bytes = save(sharp_tensors)
    return (
        len(tensors_bytes).to_bytes(4, byteorder="big") + tensors_bytes + metadata_bytes
    )
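

# A sketch of the inverse of `dumps`, included only to document the payload
# layout. This is an assumption about how a receiver would parse it: the actual
# parsing happens inside the DeServe workers and is not part of this file, and
# `loads` below is not called anywhere in this benchmark.
def loads(data: bytes) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
    tensors_length = int.from_bytes(data[:4], byteorder="big")
    tensors_bytes = data[4 : 4 + tensors_length]
    metadata: dict[str, Any] = pickle.loads(data[4 + tensors_length :])
    tensors: dict[str, torch.Tensor] = {}
    for k, v in load(tensors_bytes).items():
        if k.startswith("#"):
            # Undo the placeholder substitution performed by `dumps`.
            tensors[k[1:]] = torch.empty((0,), dtype=v.dtype)
        else:
            tensors[k] = v
    return tensors, metadata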


class DeServeClient:
    """Benchmark client that streams prompts to the first DeServe worker and
    collects generated tokens through an HTTP callback (see `flask_service`)."""

    def __init__(
        self,
        workload: Workload,
        time_limit: int,
        first_worker_url: str,
        batch_size: int,
        max_tokens: int,
        trace: bool,
        warmup: int,
        variance: int,
    ):
        self.first_worker_url = first_worker_url
        self.batch_size = batch_size
        self.max_tokens = max_tokens
        self.network_executor = ThreadPoolExecutor(max_workers=128)
        self.deserve_executor = ThreadPoolExecutor(max_workers=128)
        self.time_limit = time_limit
        self.rater = Rater(
            workload=workload, time_limit=time_limit, trace=trace, warmup=warmup
        )
        self.variance = variance

    def flask_service(self, events: Queue[int | None]) -> None:  # type: ignore
        """Serve the callback endpoint that workers POST generated tokens to."""
        app = Flask(__name__)
        app.logger.setLevel(logging.ERROR)
        logging.getLogger("werkzeug").setLevel(logging.ERROR)

        @app.route("/update_tasks", methods=["POST"])
        def update_tasks() -> str:
            request_json = request.json
            if request_json is None:
                return "No"
            data: list[dict[str, Any]] = request_json
            for task in data:
                # "task_id" has the form "<request id>@<uuid>" (see `polling`).
                request_id = int(task["task_id"].split("@")[0])
                token = task["output_token"]
                char = tokenizer.decode(token)
                try:
                    self.rater.post(
                        Response(
                            id=request_id, payload=char, finished=(token in stop_tokens)
                        )
                    )
                except RaterTimeLimitExceeded:
                    # None tells `polling` that the time limit has been reached.
                    events.put(None)
                if token in stop_tokens:
                    # A finished request frees one slot in the in-flight batch.
                    events.put(request_id)
            return "OK"

        app.run(host="0.0.0.0", port=19000, debug=False)
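
    # Shape of the callback payload assumed by `update_tasks` above (illustrative;
    # the exact schema is defined by the DeServe workers, not by this file):
    #
    #   [
    #       {"task_id": "17@<uuid>", "output_token": 1234},
    #       {"task_id": "42@<uuid>", "output_token": 128009},
    #   ]
    #
    # Each entry carries a single generated token id for the request named by
    # the prefix of "task_id".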

    def polling(self, queue: Queue[int | None]) -> None:
        current = 0
        while True:
            # Admit a new prompt only while fewer than `batch_size` requests are
            # in flight; otherwise wait for the callback thread to free a slot.
            if current >= self.batch_size:
                value = queue.get()
                if value is None:
                    break
            else:
                current += 1
            history = self.rater.get(1)
            if len(history) == 0:
                break
            request_id = history[0].id
            prompt = history[0].history
            tokens = tokenizer.encode(prompt, return_tensors="pt")[0]
            tensors = {"x": tokens}
            metadata = {
                "task_id": str(request_id) + "@" + str(uuid.uuid4()),
                "sampling_params": {
                    "temperature": 0.0,
                    "top_p": 1.0,
                    "max_new_tokens": self.max_tokens
                    + random.randint(-self.variance, self.variance),
                },
            }
            requests.post(
                f"{self.first_worker_url}/prefill",
                data=dumps(tensors, metadata),
            )
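
    # How the two background threads started by `speedtest` cooperate: `polling`
    # admits prompts until `batch_size` are in flight and then blocks on the
    # queue; `flask_service` pushes a request id whenever a stop token arrives
    # (admitting the next prompt) and pushes None once the rater reports that
    # the time limit has been exceeded.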

    def speedtest(self) -> dict[str, Any]:
        queue: Queue[int | None] = Queue()
        flask_thread = threading.Thread(
            target=self.flask_service, args=(queue,), daemon=True
        )
        flask_thread.start()
        polling_thread = threading.Thread(
            target=self.polling, args=(queue,), daemon=True
        )
        polling_thread.start()

        try:
            if self.time_limit > 0:
                # Run for at most `time_limit` seconds, or until polling stops early.
                for _ in range(self.time_limit):
                    time.sleep(1)
                    if not polling_thread.is_alive():
                        break
            else:
                # No time limit: wait until every request in the workload has finished.
                while self.rater.requests_finished_total < self.rater.workload.size():
                    time.sleep(1)
        except KeyboardInterrupt:
            pass
        return self.rater.dump()
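

# Example invocation (illustrative; the actual script or module name depends on
# where this file lives in the repository):
#
#   python client.py --workload sharegpt --batch-size 128 --max-tokens 512 \
#       --first-worker-url http://localhost:8080 --time-limit 300
#
# The rater's summary (the result of `speedtest()`) is printed to stdout as JSON.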

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--time-limit",
        type=int,
        default=-1,
        help="benchmark time limit in seconds; -1 runs the whole workload",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=150,
        help="number of concurrent requests sent to the workers",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=1024,
        help="maximum number of new tokens per request",
    )
    parser.add_argument(
        "--workload",
        type=str,
        default="oasst1",
        help='workload name; "fixed{size}:{length}:{variance}" selects a static workload',
    )
    parser.add_argument("--first-worker-url", type=str, default="http://localhost:8080")
    parser.add_argument("--warmup", type=int, default=0, help="warmup time in seconds")
    parser.add_argument("--trace", action="store_true", default=False)
    parser.add_argument("--variance", type=int, default=0, help="variance of output token count")
    args = parser.parse_args()

    if args.workload == "oasst1":
        workload = Oasst1Dataset().into_workload()
    elif args.workload == "sharegpt":
        workload = ShareGptDataset().into_workload()
    elif args.workload.startswith("fixed"):
        # e.g. "fixed100:512:64" -> StaticWorkload(size=100, length=512, variance=64)
        raw = args.workload[len("fixed") :]
        size, length, variance = map(int, raw.split(":"))
        workload = StaticWorkload(size, length, variance)
    else:
        raise ValueError(f"Unknown workload: {args.workload}")
    client = DeServeClient(
        workload=workload,
        time_limit=args.time_limit,
        first_worker_url=args.first_worker_url,
        batch_size=args.batch_size,
        max_tokens=args.max_tokens,
        trace=args.trace,
        warmup=args.warmup,
        variance=args.variance,
    )
    print(json.dumps(client.speedtest()))