Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update async method #87

Merged
merged 3 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '[email protected]'
__version__ = '3.3.3'
__version__ = '3.3.4'

import os, sys, requests, json
from .chattype import Chat, Resp
Expand Down Expand Up @@ -87,6 +87,13 @@ def save_envs(env_file:str):
elif platform.startswith("darwin"):
platform = "macos"

# is jupyter notebook
try:
get_ipython
is_jupyter = True
except:
is_jupyter = False

def default_prompt(msg:str):
"""Default prompt message for the API call

Expand Down
44 changes: 42 additions & 2 deletions chattool/chattype.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .functioncall import generate_json_schema, delete_dialogue_assist
from pprint import pformat
from loguru import logger
import asyncio

class Chat():
def __init__( self
Expand Down Expand Up @@ -235,7 +236,7 @@ def getresponse( self
max_tries = max(max_tries, max_requests)
if options.get('stream'):
options['stream'] = False
warnings.warn("Use `async_stream_responses()` instead.")
warnings.warn("Use `stream_responses` instead.")
options = self._init_options(**options)
# make requests
api_key, chat_log, chat_url = self.api_key, self.chat_log, self.chat_url
Expand All @@ -258,11 +259,49 @@ async def async_stream_responses( self

Returns:
str: response text

Examples:
>>> chat = Chat("Hello")
>>> # in Jupyter notebook
>>> async for resp in chat.async_stream_responses():
>>> print(resp)
"""
async for resp in _async_stream_responses(
self.api_key, self.chat_url, self.chat_log, self.model, timeout=timeout, **options):
yield resp.delta_content if textonly else resp

def stream_responses(self, timeout:int=0, textonly:bool=True, **options):
"""Post request synchronously and stream the responses

Args:
timeout (int, optional): timeout for the API call. Defaults to 0(no timeout).
textonly (bool, optional): whether to only return the text. Defaults to True.
options (dict, optional): other options like `temperature`, `top_p`, etc.

Returns:
str: response text

Examples:
>>> chat = Chat("Hello")
>>> for resp in chat.stream_responses():
>>> print(resp)
"""
assert not chattool.is_jupyter, "use `await chat.async_stream_responses()` in Jupyter notebook"
async_gen = self.async_stream_responses(timeout=timeout, textonly=textonly, **options)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
while True:
try:
# Run the async generator to get each response
response = loop.run_until_complete(async_gen.__anext__())
yield response
except StopAsyncIteration:
# End the generator when the async generator is exhausted
break
finally:
loop.close()

# Part3: tool call
def iswaiting(self):
"""Whether the response is waiting"""
Expand Down Expand Up @@ -396,7 +435,8 @@ def get_valid_models(self, gpt_only:bool=True)->List[str]:
model_url = os.path.join(self.api_base, 'models')
elif self.base_url:
model_url = os.path.join(self.base_url, 'v1/models')
return valid_models(self.api_key, model_url, gpt_only=gpt_only)
model_list = valid_models(self.api_key, model_url, gpt_only=gpt_only)
return sorted(set(model_list))

def get_curl(self, use_env_key:bool=False, **options):
"""Get the curl command
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '3.3.3'
VERSION = '3.3.4'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest
from chattool import *

TEST_PATH = 'tests/testfiles/'

@pytest.fixture(scope="session")
def testpath():
return TEST_PATH
11 changes: 6 additions & 5 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
chatlogs = [
[{"role": "user", "content": f"Print hello using {lang}"}] for lang in langs
]
testpath = 'tests/testfiles/'

def test_simple():
# set api_key in the environment variable
Expand All @@ -32,6 +31,8 @@ async def show_resp(chat):
async for resp in chat.async_stream_responses():
print(resp.delta_content, end='')
asyncio.run(show_resp(chat))
for resp in chat.stream_responses():
print(resp, end='')

def test_async_typewriter():
def typewriter_effect(text, delay):
Expand All @@ -51,23 +52,23 @@ async def show_resp(chat):
chat = Chat("Print hello using Python")
asyncio.run(show_resp(chat))

def test_async_process():
def test_async_process(testpath):
chkpoint = testpath + "test_async.jsonl"
t = time.time()
async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, nproc=3)
async_chat_completion(chatlogs, chkpoint, nproc=3)
print(f"Time elapsed: {time.time() - t:.2f}s")

# broken test
def test_failed_async():
def test_failed_async(testpath):
api_key = chattool.api_key
chattool.api_key = "sk-invalid"
chkpoint = testpath + "test_async_fail.jsonl"
words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"]
resp = async_chat_completion(words, chkpoint, clearfile=True, nproc=3)
chattool.api_key = api_key

def test_async_process_withfunc():
def test_async_process_withfunc(testpath):
chkpoint = testpath + "test_async_withfunc.jsonl"
words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"]
def msg2log(msg):
Expand All @@ -77,7 +78,7 @@ def msg2log(msg):
return chat.chat_log
async_chat_completion(words, chkpoint, clearfile=True, nproc=3, msg2log=msg2log)

def test_normal_process():
def test_normal_process(testpath):
chkpoint = testpath + "test_nomal.jsonl"
def data2chat(data):
chat = Chat(data)
Expand Down
3 changes: 1 addition & 2 deletions tests/test_chattool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from chattool import cli
from chattool import Chat, Resp, findcost
import pytest
testpath = 'tests/testfiles/'


def test_command_line_interface():
Expand All @@ -21,7 +20,7 @@ def test_command_line_interface():
assert '--help Show this message and exit.' in help_result.output

# test for the chat class
def test_chat():
def test_chat(testpath):
# initialize
chat = Chat()
assert chat.chat_log == []
Expand Down
5 changes: 2 additions & 3 deletions tests/test_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os, responses
from chattool import Chat, load_chats, process_chats, api_key
testpath = 'tests/testfiles/'

def test_with_checkpoint():
def test_with_checkpoint(testpath):
# save chats without chatid
chat = Chat()
checkpath = testpath + "tmp.jsonl"
Expand Down Expand Up @@ -38,7 +37,7 @@ def test_with_checkpoint():
]
assert chats == [Chat(log) if log is not None else None for log in chat_logs]

def test_process_chats():
def test_process_chats(testpath):
def msg2chat(msg):
chat = Chat()
chat.system("You are a helpful translator for numbers.")
Expand Down
3 changes: 1 addition & 2 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import chattool
from chattool import Chat, save_envs, load_envs
testpath = 'tests/testfiles/'

def test_model_api_key():
api_key, model = chattool.api_key, chattool.model
Expand Down Expand Up @@ -43,7 +42,7 @@ def test_apibase():

chattool.api_base, chattool.base_url = api_base, base_url

def test_env_file():
def test_env_file(testpath):
save_envs(testpath + "chattool.env")
with open(testpath + "test.env", "w") as f:
f.write("OPENAI_API_KEY=sk-132\n")
Expand Down
3 changes: 1 addition & 2 deletions tests/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
)
import pytest, chattool, os
api_key, base_url, api_base = chattool.api_key, chattool.base_url, chattool.api_base
testpath = 'tests/testfiles/'

def test_valid_models():
if chattool.api_base:
Expand Down Expand Up @@ -40,7 +39,7 @@ def test_normalize_url():
assert normalize_url("api.openai.com") == "https://api.openai.com"
assert normalize_url("example.com/foo/bar") == "https://example.com/foo/bar"

def test_broken_requests():
def test_broken_requests(testpath):
"""Test the broken requests"""
with open(testpath + "test.txt", "w") as f:
f.write("hello world")
Expand Down
Loading