From dab8bad9065369c643573d9834e27bb858c1af36 Mon Sep 17 00:00:00 2001
From: alpayariyak
Date: Wed, 31 Jan 2024 23:27:31 +0000
Subject: [PATCH] OpenAI Chat Completions Stream

---
 .gitignore                |  3 ++-
 src/engine.py             | 45 ++++++++++-----------------------------
 src/test_openai_stream.py | 44 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 57 insertions(+), 35 deletions(-)
 create mode 100644 src/test_openai_stream.py

diff --git a/.gitignore b/.gitignore
index b4e200e..e87478d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,5 @@ runpod.toml
 *.pyc
 .env
-test/*
\ No newline at end of file
+test/*
+vllm-base/vllm-*
diff --git a/src/engine.py b/src/engine.py
index 356fee3..3293f59 100644
--- a/src/engine.py
+++ b/src/engine.py
@@ -125,44 +125,21 @@ async def generate_openai_chat(self, llm_input, validated_sampling_params, batch
         response_generator = await self.openai_engine.create_chat_completion(chat_completion_request, DummyRequest())
         if not stream:
-            yield json.loads(response_generator.model_dump_json())
+            yield response_generator
         else:
-            batch_contents = {}
-            batch_latest_choices = {}
+            batch = ""
             batch_token_counter = 0
-            last_chunk = {}
             async for chunk_str in response_generator:
-                try:
-                    chunk = json.loads(chunk_str.removeprefix("data: ").rstrip("\n\n"))
-                except:
-                    continue
-
-                if "choices" in chunk:
-                    for choice in chunk["choices"]:
-                        choice_index = choice["index"]
-                        if "delta" in choice and "content" in choice["delta"]:
-                            batch_contents[choice_index] = batch_contents.get(choice_index, []) + [choice["delta"]["content"]]
-                            batch_latest_choices[choice_index] = choice
-                            batch_token_counter += 1
-                    last_chunk = chunk
-
-                if batch_token_counter >= batch_size:
-                    for choice_index in batch_latest_choices:
-                        batch_latest_choices[choice_index]["delta"]["content"] = batch_contents[choice_index]
-                    last_chunk["choices"] = list(batch_latest_choices.values())
-                    yield last_chunk
-
-                    batch_contents = {}
-                    batch_latest_choices = {}
-                    batch_token_counter = 0
-
-            if batch_contents:
-                for choice_index in batch_latest_choices:
-                    batch_latest_choices[choice_index]["delta"]["content"] = batch_contents[choice_index]
-                last_chunk["choices"] = list(batch_latest_choices.values())
-                yield last_chunk
-
+                if "data" in chunk_str:
+                    batch += chunk_str
+                    batch_token_counter += 1
+                    if batch_token_counter >= batch_size:
+                        yield batch
+                        batch = ""
+                        batch_token_counter = 0
+
+
     def _initialize_config(self):
         quantization = self._get_quantization()
         model, download_dir = self._get_model_name_and_path()
diff --git a/src/test_openai_stream.py b/src/test_openai_stream.py
new file mode 100644
index 0000000..86860e2
--- /dev/null
+++ b/src/test_openai_stream.py
@@ -0,0 +1,44 @@
+import os
+from utils import JobInput
+from engine import vLLMEngine
+
+os.environ["MODEL_NAME"] = "facebook/opt-125m"
+os.environ["CUSTOM_CHAT_TEMPLATE"] = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+
+vllm_engine = vLLMEngine()
+
+async def handler(job):
+    job_input = JobInput(job["input"])
+    results_generator = vllm_engine.generate(job_input)
+    async for batch in results_generator:
+        yield batch
+
+test_payload = {
+    "input": {
+        "messages": [
+            {"role": "user", "content": "Write me a 3000 word long and detailed essay about how the french revolution impacted the rest of europe over the 18th century."},
+        ],
+        "batch_size": 2,  # How many tokens to yield per batch
+        "apply_chat_template": True,
+        "sampling_params": {
+            "max_tokens": 10,
+            "temperature": 0,
+            "ignore_eos": True,
+            "n": 1
+        },
+        "stream": True,
+        "use_openai_format": True
+    }
+}
+
+async def test_handler():
+    print("Start of output")
+    print("=" * 50)
+    async for batch in handler(test_payload):
+        print(batch, end="")
+    print("=" * 50)
+    print("End of output")
+
+import asyncio
+
+asyncio.run(test_handler())
\ No newline at end of file