Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make AgentChat Team Config Serializable #5071

Merged
merged 13 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any, AsyncGenerator, List, Mapping, Sequence

from autogen_core import CancellationToken
from autogen_core import CancellationToken, Component, ComponentModel
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
from pydantic import BaseModel
from typing_extensions import Self

from autogen_agentchat.base import Response
from autogen_agentchat.state import SocietyOfMindAgentState
Expand All @@ -16,7 +18,18 @@
from ._base_chat_agent import BaseChatAgent


class SocietyOfMindAgent(BaseChatAgent):
class SocietyOfMindAgentConfig(BaseModel):
"""The declarative configuration for a SocietyOfMindAgent."""

name: str
team: ComponentModel
model_client: ComponentModel
description: str
instruction: str
response_prompt: str


class SocietyOfMindAgent(BaseChatAgent, Component[SocietyOfMindAgentConfig]):
"""An agent that uses an inner team of agents to generate responses.
Each time the agent's :meth:`on_messages` or :meth:`on_messages_stream`
Expand Down Expand Up @@ -74,6 +87,9 @@ async def main() -> None:
asyncio.run(main())
"""

component_config_schema = SocietyOfMindAgentConfig
component_provider_override = "autogen_agentchat.agents.SocietyOfMindAgent"

DEFAULT_INSTRUCTION = "Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:"
"""str: The default instruction to use when generating a response using the
inner team's messages. The instruction will be prepended to the inner team's
Expand Down Expand Up @@ -173,3 +189,26 @@ async def save_state(self) -> Mapping[str, Any]:
async def load_state(self, state: Mapping[str, Any]) -> None:
society_of_mind_state = SocietyOfMindAgentState.model_validate(state)
await self._team.load_state(society_of_mind_state.inner_team_state)

def _to_config(self) -> SocietyOfMindAgentConfig:
return SocietyOfMindAgentConfig(
name=self.name,
team=self._team.dump_component(),
model_client=self._model_client.dump_component(),
description=self.description,
instruction=self._instruction,
response_prompt=self._response_prompt,
)

@classmethod
def _from_config(cls, config: SocietyOfMindAgentConfig) -> Self:
model_client = ChatCompletionClient.load_component(config.model_client)
team = Team.load_component(config.team)
return cls(
name=config.name,
team=team,
model_client=model_client,
description=config.description,
instruction=config.instruction,
response_prompt=config.response_prompt,
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Mapping, Sequence

from autogen_core import CancellationToken
from autogen_core import CancellationToken, ComponentBase
from pydantic import BaseModel

from ..messages import AgentEvent, ChatMessage
from ._task import TaskRunner
Expand All @@ -20,9 +21,11 @@ class Response:
or :class:`ChatMessage`."""


class ChatAgent(ABC, TaskRunner):
class ChatAgent(ABC, TaskRunner, ComponentBase[BaseModel]):
"""Protocol for a chat agent."""

component_type = "agent"

@property
@abstractmethod
def name(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from typing import Any, Mapping
from abc import ABC, abstractmethod
from typing import Any, Mapping

from autogen_core import ComponentBase
from pydantic import BaseModel

from ._task import TaskRunner


class Team(ABC, TaskRunner):
class Team(ABC, TaskRunner, ComponentBase[BaseModel]):
component_type = "team"

@abstractmethod
async def reset(self) -> None:
"""Reset the team and all its participants to its initial state."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
AgentType,
CancellationToken,
ClosureAgent,
ComponentBase,
MessageContext,
SingleThreadedAgentRuntime,
TypeSubscription,
)
from autogen_core._closure_agent import ClosureContext
from pydantic import BaseModel

from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
Expand All @@ -28,13 +30,15 @@
event_logger = logging.getLogger(EVENT_LOGGER_NAME)


class BaseGroupChat(Team, ABC):
class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
"""The base class for group chat teams.

To implement a group chat team, first create a subclass of :class:`BaseGroupChatManager` and then
create a subclass of :class:`BaseGroupChat` that uses the group chat manager.
"""

component_type = "team"

def __init__(
self,
participants: List[ChatAgent],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import logging
from typing import Callable, List

from autogen_core import Component, ComponentModel
from autogen_core.models import ChatCompletionClient
from pydantic import BaseModel
from typing_extensions import Self

from .... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from ....base import ChatAgent, TerminationCondition
Expand All @@ -13,7 +16,18 @@
event_logger = logging.getLogger(EVENT_LOGGER_NAME)


class MagenticOneGroupChat(BaseGroupChat):
class MagenticOneGroupChatConfig(BaseModel):
"""The declarative configuration for a MagenticOneGroupChat."""

participants: List[ComponentModel]
model_client: ComponentModel
termination_condition: ComponentModel | None = None
max_turns: int | None = None
max_stalls: int
final_answer_prompt: str


class MagenticOneGroupChat(BaseGroupChat, Component[MagenticOneGroupChatConfig]):
"""A team that runs a group chat with participants managed by the MagenticOneOrchestrator.
The orchestrator handles the conversation flow, ensuring that the task is completed
Expand Down Expand Up @@ -73,6 +87,9 @@ async def main() -> None:
}
"""

component_config_schema = MagenticOneGroupChatConfig
component_provider_override = "autogen_agentchat.teams.MagenticOneGroupChat"

def __init__(
self,
participants: List[ChatAgent],
Expand Down Expand Up @@ -117,3 +134,31 @@ def _create_group_chat_manager_factory(
self._final_answer_prompt,
termination_condition,
)

def _to_config(self) -> MagenticOneGroupChatConfig:
participants = [participant.dump_component() for participant in self._participants]
termination_condition = self._termination_condition.dump_component() if self._termination_condition else None
return MagenticOneGroupChatConfig(
participants=participants,
model_client=self._model_client.dump_component(),
termination_condition=termination_condition,
max_turns=self._max_turns,
max_stalls=self._max_stalls,
final_answer_prompt=self._final_answer_prompt,
)

@classmethod
def _from_config(cls, config: MagenticOneGroupChatConfig) -> Self:
participants = [ChatAgent.load_component(participant) for participant in config.participants]
model_client = ChatCompletionClient.load_component(config.model_client)
termination_condition = (
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
)
return cls(
participants,
model_client,
termination_condition=termination_condition,
max_turns=config.max_turns,
max_stalls=config.max_stalls,
final_answer_prompt=config.final_answer_prompt,
)
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Any, Callable, List, Mapping

from autogen_core import Component, ComponentModel
from pydantic import BaseModel
from typing_extensions import Self

from ...base import ChatAgent, TerminationCondition
from ...messages import AgentEvent, ChatMessage
from ...state import RoundRobinManagerState
Expand Down Expand Up @@ -61,7 +65,15 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
return current_speaker


class RoundRobinGroupChat(BaseGroupChat):
class RoundRobinGroupChatConfig(BaseModel):
"""The declarative configuration RoundRobinGroupChat."""

participants: List[ComponentModel]
termination_condition: ComponentModel | None = None
max_turns: int | None = None


class RoundRobinGroupChat(BaseGroupChat, Component[RoundRobinGroupChatConfig]):
"""A team that runs a group chat with participants taking turns in a round-robin fashion
to publish a message to all.
Expand Down Expand Up @@ -133,6 +145,9 @@ async def main() -> None:
asyncio.run(main())
"""

component_config_schema = RoundRobinGroupChatConfig
component_provider_override = "autogen_agentchat.teams.RoundRobinGroupChat"

def __init__(
self,
participants: List[ChatAgent],
Expand Down Expand Up @@ -166,3 +181,20 @@ def _factory() -> RoundRobinGroupChatManager:
)

return _factory

def _to_config(self) -> RoundRobinGroupChatConfig:
participants = [participant.dump_component() for participant in self._participants]
termination_condition = self._termination_condition.dump_component() if self._termination_condition else None
return RoundRobinGroupChatConfig(
participants=participants,
termination_condition=termination_condition,
max_turns=self._max_turns,
)

@classmethod
def _from_config(cls, config: RoundRobinGroupChatConfig) -> Self:
participants = [ChatAgent.load_component(participant) for participant in config.participants]
termination_condition = (
TerminationCondition.load_component(config.termination_condition) if config.termination_condition else None
)
return cls(participants, termination_condition=termination_condition, max_turns=config.max_turns)
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
import re
from typing import Any, Callable, Dict, List, Mapping, Sequence

from autogen_core import Component, ComponentModel
from autogen_core.models import ChatCompletionClient, SystemMessage
from pydantic import BaseModel
from typing_extensions import Self

from ... import TRACE_LOGGER_NAME
from ...agents import BaseChatAgent
from ...base import ChatAgent, TerminationCondition
from ...messages import (
AgentEvent,
Expand Down Expand Up @@ -184,7 +188,19 @@ def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dic
return mentions


class SelectorGroupChat(BaseGroupChat):
class SelectorGroupChatConfig(BaseModel):
"""The declarative configuration for SelectorGroupChat."""

participants: List[ComponentModel]
model_client: ComponentModel
termination_condition: ComponentModel | None = None
max_turns: int | None = None
selector_prompt: str
allow_repeated_speaker: bool
# selector_func: ComponentModel | None


class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
"""A group chat team that have participants takes turn to publish a message
to all, using a ChatCompletion model to select the next speaker after each message.
Expand Down Expand Up @@ -321,6 +337,9 @@ def selector_func(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
asyncio.run(main())
"""

component_config_schema = SelectorGroupChatConfig
component_provider_override = "autogen_agentchat.teams.SelectorGroupChat"

def __init__(
self,
participants: List[ChatAgent],
Expand Down Expand Up @@ -381,3 +400,30 @@ def _create_group_chat_manager_factory(
self._allow_repeated_speaker,
self._selector_func,
)

def _to_config(self) -> SelectorGroupChatConfig:
return SelectorGroupChatConfig(
participants=[participant.dump_component() for participant in self._participants],
model_client=self._model_client.dump_component(),
termination_condition=self._termination_condition.dump_component() if self._termination_condition else None,
max_turns=self._max_turns,
selector_prompt=self._selector_prompt,
allow_repeated_speaker=self._allow_repeated_speaker,
# selector_func=self._selector_func.dump_component() if self._selector_func else None,
)

@classmethod
def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
return cls(
participants=[BaseChatAgent.load_component(participant) for participant in config.participants],
model_client=ChatCompletionClient.load_component(config.model_client),
termination_condition=TerminationCondition.load_component(config.termination_condition)
if config.termination_condition
else None,
max_turns=config.max_turns,
selector_prompt=config.selector_prompt,
allow_repeated_speaker=config.allow_repeated_speaker,
# selector_func=ComponentLoader.load_component(config.selector_func, Callable[[Sequence[AgentEvent | ChatMessage]], str | None])
# if config.selector_func
# else None,
)
Loading
Loading