diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
index b1144d9c4667..9df82f0ab917 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py
@@ -28,7 +28,7 @@
     SystemMessage,
     UserMessage,
 )
-from autogen_core.tools import FunctionTool, BaseTool
+from autogen_core.tools import BaseTool, FunctionTool
 from pydantic import BaseModel
 from typing_extensions import Self
@@ -268,6 +268,7 @@ def __init__(
         system_message: (
             str | None
         ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
+        token_callback: Callable | None = None,
         reflect_on_tool_use: bool = False,
         tool_call_summary_format: str = "{result}",
         memory: Sequence[Memory] | None = None,
@@ -289,6 +290,7 @@ def __init__(
         else:
             self._system_messages = [SystemMessage(content=system_message)]
         self._tools: List[BaseTool[Any, Any]] = []
+        self._token_callback = token_callback
         if tools is not None:
             if model_client.model_info["function_calling"] is False:
                 raise ValueError("The model does not support function calling.")
@@ -383,9 +385,26 @@ async def on_messages_stream(
 
         # Generate an inference result based on the current model context.
         llm_messages = self._system_messages + await self._model_context.get_messages()
-        model_result = await self._model_client.create(
-            llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
-        )
+
+        # if token_callback is set, use create_stream to get the tokens as they are
+        # generated and call the token_callback with the tokens
+        if self._token_callback is not None:
+            async for model_result in self._model_client.create_stream(
+                llm_messages,
+                tools=self._tools + self._handoff_tools,
+                cancellation_token=cancellation_token,
+            ):
+                # if the result is a string, it is a token to be streamed back
+                if isinstance(model_result, str):
+                    await self._token_callback(model_result)
+                else:
+                    break
+        else:
+            model_result = await self._model_client.create(
+                llm_messages,
+                tools=self._tools + self._handoff_tools,
+                cancellation_token=cancellation_token,
+            )
 
         # Add the response to the model context.
         await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
@@ -465,7 +484,24 @@ async def on_messages_stream(
         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
             llm_messages = self._system_messages + await self._model_context.get_messages()
-            model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
+
+            # if token_callback is set, use create_stream to get the tokens as they are
+            # generated and call the token_callback with the tokens
+            if self._token_callback is not None:
+                async for model_result in self._model_client.create_stream(
+                    llm_messages,
+                    cancellation_token=cancellation_token,
+                ):
+                    # if the result is a string, it is a token to be streamed back
+                    if isinstance(model_result, str):
+                        await self._token_callback(model_result)
+                    else:
+                        break
+            else:
+                model_result = await self._model_client.create(
+                    llm_messages,
+                    cancellation_token=cancellation_token,
+                )
             assert isinstance(model_result.content, str)
             # Add the response to the model context.
             await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index 131e7288a658..6444ce01fee7 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -617,7 +617,7 @@ async def create_stream(
         json_output: Optional[bool] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
-        max_consecutive_empty_chunk_tolerance: int = 0,
+        max_consecutive_empty_chunk_tolerance: int = 10,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:
         """
         Creates an AsyncGenerator that will yield a stream of chat completions based on the provided messages and tools.
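For reviewers, a minimal usage sketch of the new token_callback parameter follows. The agent name, model choice, and the on_token coroutine are illustrative assumptions and not part of this diff; any async callable accepting a str should work as the callback.

import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def on_token(token: str) -> None:
    # Hypothetical consumer: receives each streamed token as the model client yields it.
    print(token, end="", flush=True)


async def main() -> None:
    # Assumed configuration; any ChatCompletionClient that implements create_stream
    # should exercise the callback path added in this change.
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    agent = AssistantAgent(
        name="assistant",
        model_client=model_client,
        token_callback=on_token,  # new parameter introduced by this diff
    )
    await agent.on_messages(
        [TextMessage(content="Say hello.", source="user")],
        cancellation_token=CancellationToken(),
    )


asyncio.run(main())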