Skip to content

Commit c37227b

Browse files
yiranwu0joshkyhqingyun-wu
authored
Allow user to pass in a customized speaker selection method (#1791)
* init PR * update * update code check * update * update * update * update * Test the ability to have agents a,u,t,o,g,e,n speak in turn. * update * update * update * Evidence that groupchat not terminating because of the TERMINATE substring. * Raising NoEligibleSpeakerException allows graceful exit before max turns * update * To confirm with author that custom function is meant to override graph constraints * Confirmed the expected test behaviour with author * Update autogen/agentchat/groupchat.py * update * update --------- Co-authored-by: Joshua Kim <[email protected]> Co-authored-by: Qingyun Wu <[email protected]>
1 parent d711bd8 commit c37227b

File tree

6 files changed

+707
-17
lines changed

6 files changed

+707
-17
lines changed

autogen/agentchat/groupchat.py

+44-10
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
import sys
55
from dataclasses import dataclass, field
6-
from typing import Dict, List, Optional, Union, Tuple
6+
from typing import Dict, List, Optional, Union, Tuple, Callable
77

88

99
from ..code_utils import content_str
@@ -42,7 +42,16 @@ class GroupChat:
4242
- "manual": the next speaker is selected manually by user input.
4343
- "random": the next speaker is selected randomly.
4444
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
45-
45+
- a customized speaker selection function (Callable): the function will be called to select the next speaker.
46+
The function should take the last speaker and the group chat as input and return one of the following:
47+
1. an `Agent` class, it must be one of the agents in the group chat.
48+
2. a string from ['auto', 'manual', 'random', 'round_robin'] to select a default method to use.
49+
3. None, which would terminate the conversation gracefully.
50+
```python
51+
def custom_speaker_selection_func(
52+
last_speaker: Agent, groupchat: GroupChat
53+
) -> Union[Agent, str, None]:
54+
```
4655
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively.
4756
Default is True, in which case all speakers are allowed to speak consecutively.
4857
If `allow_repeat_speaker` is a list of Agents, then only those listed agents are allowed to repeat.
@@ -67,7 +76,7 @@ class GroupChat:
6776
max_round: Optional[int] = 10
6877
admin_name: Optional[str] = "Admin"
6978
func_call_filter: Optional[bool] = True
70-
speaker_selection_method: Optional[str] = "auto"
79+
speaker_selection_method: Optional[Union[str, Callable]] = "auto"
7180
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
7281
allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
7382
speaker_transitions_type: Optional[str] = None
@@ -277,11 +286,36 @@ def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[A
277286
return random.choice(agents)
278287

279288
def _prepare_and_select_agents(
280-
self, last_speaker: Agent
289+
self,
290+
last_speaker: Agent,
281291
) -> Tuple[Optional[Agent], List[Agent], Optional[List[Dict]]]:
282-
if self.speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
292+
# If self.speaker_selection_method is a callable, call it to get the next speaker.
293+
# If self.speaker_selection_method is a string, return it.
294+
speaker_selection_method = self.speaker_selection_method
295+
if isinstance(self.speaker_selection_method, Callable):
296+
selected_agent = self.speaker_selection_method(last_speaker, self)
297+
if selected_agent is None:
298+
raise NoEligibleSpeakerException(
299+
"Custom speaker selection function returned None. Terminating conversation."
300+
)
301+
elif isinstance(selected_agent, Agent):
302+
if selected_agent in self.agents:
303+
return selected_agent, self.agents, None
304+
else:
305+
raise ValueError(
306+
f"Custom speaker selection function returned an agent {selected_agent.name} not in the group chat."
307+
)
308+
elif isinstance(selected_agent, str):
309+
# If returned a string, assume it is a speaker selection method
310+
speaker_selection_method = selected_agent
311+
else:
312+
raise ValueError(
313+
f"Custom speaker selection function returned an object of type {type(selected_agent)} instead of Agent or str."
314+
)
315+
316+
if speaker_selection_method.lower() not in self._VALID_SPEAKER_SELECTION_METHODS:
283317
raise ValueError(
284-
f"GroupChat speaker_selection_method is set to '{self.speaker_selection_method}'. "
318+
f"GroupChat speaker_selection_method is set to '{speaker_selection_method}'. "
285319
f"It should be one of {self._VALID_SPEAKER_SELECTION_METHODS} (case insensitive). "
286320
)
287321

@@ -300,7 +334,7 @@ def _prepare_and_select_agents(
300334
f"GroupChat is underpopulated with {n_agents} agents. "
301335
"Please add more agents to the GroupChat or use direct communication instead."
302336
)
303-
elif n_agents == 2 and self.speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
337+
elif n_agents == 2 and speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
304338
logger.warning(
305339
f"GroupChat is underpopulated with {n_agents} agents. "
306340
"Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, "
@@ -366,11 +400,11 @@ def _prepare_and_select_agents(
366400

367401
# Use the selected speaker selection method
368402
select_speaker_messages = None
369-
if self.speaker_selection_method.lower() == "manual":
403+
if speaker_selection_method.lower() == "manual":
370404
selected_agent = self.manual_select_speaker(graph_eligible_agents)
371-
elif self.speaker_selection_method.lower() == "round_robin":
405+
elif speaker_selection_method.lower() == "round_robin":
372406
selected_agent = self.next_agent(last_speaker, graph_eligible_agents)
373-
elif self.speaker_selection_method.lower() == "random":
407+
elif speaker_selection_method.lower() == "random":
374408
selected_agent = self.random_select_speaker(graph_eligible_agents)
375409
else:
376410
selected_agent = None

notebook/agentchat_custom_model.ipynb

+1
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@
383383
"source": [
384384
"# load model here\n",
385385
"\n",
386+
"\n",
386387
"config = config_list_custom[0]\n",
387388
"device = config.get(\"device\", \"cpu\")\n",
388389
"loaded_model = AutoModelForCausalLM.from_pretrained(config[\"model\"]).to(device)\n",

0 commit comments

Comments
 (0)