3
3
import re
4
4
import sys
5
5
from dataclasses import dataclass , field
6
- from typing import Dict , List , Optional , Union , Tuple
6
+ from typing import Dict , List , Optional , Union , Tuple , Callable
7
7
8
8
9
9
from ..code_utils import content_str
@@ -42,7 +42,16 @@ class GroupChat:
42
42
- "manual": the next speaker is selected manually by user input.
43
43
- "random": the next speaker is selected randomly.
44
44
- "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
45
-
45
+ - a customized speaker selection function (Callable): the function will be called to select the next speaker.
46
+ The function should take the last speaker and the group chat as input and return one of the following:
47
+ 1. an `Agent` class, it must be one of the agents in the group chat.
48
+ 2. a string from ['auto', 'manual', 'random', 'round_robin'] to select a default method to use.
49
+ 3. None, which would terminate the conversation gracefully.
50
+ ```python
51
+ def custom_speaker_selection_func(
52
+ last_speaker: Agent, groupchat: GroupChat
53
+ ) -> Union[Agent, str, None]:
54
+ ```
46
55
- allow_repeat_speaker: whether to allow the same speaker to speak consecutively.
47
56
Default is True, in which case all speakers are allowed to speak consecutively.
48
57
If `allow_repeat_speaker` is a list of Agents, then only those listed agents are allowed to repeat.
@@ -67,7 +76,7 @@ class GroupChat:
67
76
max_round : Optional [int ] = 10
68
77
admin_name : Optional [str ] = "Admin"
69
78
func_call_filter : Optional [bool ] = True
70
- speaker_selection_method : Optional [str ] = "auto"
79
+ speaker_selection_method : Optional [Union [ str , Callable ] ] = "auto"
71
80
allow_repeat_speaker : Optional [Union [bool , List [Agent ]]] = None
72
81
allowed_or_disallowed_speaker_transitions : Optional [Dict ] = None
73
82
speaker_transitions_type : Optional [str ] = None
@@ -277,11 +286,36 @@ def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[A
277
286
return random .choice (agents )
278
287
279
288
def _prepare_and_select_agents (
280
- self , last_speaker : Agent
289
+ self ,
290
+ last_speaker : Agent ,
281
291
) -> Tuple [Optional [Agent ], List [Agent ], Optional [List [Dict ]]]:
282
- if self .speaker_selection_method .lower () not in self ._VALID_SPEAKER_SELECTION_METHODS :
292
+ # If self.speaker_selection_method is a callable, call it to get the next speaker.
293
+ # If self.speaker_selection_method is a string, return it.
294
+ speaker_selection_method = self .speaker_selection_method
295
+ if isinstance (self .speaker_selection_method , Callable ):
296
+ selected_agent = self .speaker_selection_method (last_speaker , self )
297
+ if selected_agent is None :
298
+ raise NoEligibleSpeakerException (
299
+ "Custom speaker selection function returned None. Terminating conversation."
300
+ )
301
+ elif isinstance (selected_agent , Agent ):
302
+ if selected_agent in self .agents :
303
+ return selected_agent , self .agents , None
304
+ else :
305
+ raise ValueError (
306
+ f"Custom speaker selection function returned an agent { selected_agent .name } not in the group chat."
307
+ )
308
+ elif isinstance (selected_agent , str ):
309
+ # If returned a string, assume it is a speaker selection method
310
+ speaker_selection_method = selected_agent
311
+ else :
312
+ raise ValueError (
313
+ f"Custom speaker selection function returned an object of type { type (selected_agent )} instead of Agent or str."
314
+ )
315
+
316
+ if speaker_selection_method .lower () not in self ._VALID_SPEAKER_SELECTION_METHODS :
283
317
raise ValueError (
284
- f"GroupChat speaker_selection_method is set to '{ self . speaker_selection_method } '. "
318
+ f"GroupChat speaker_selection_method is set to '{ speaker_selection_method } '. "
285
319
f"It should be one of { self ._VALID_SPEAKER_SELECTION_METHODS } (case insensitive). "
286
320
)
287
321
@@ -300,7 +334,7 @@ def _prepare_and_select_agents(
300
334
f"GroupChat is underpopulated with { n_agents } agents. "
301
335
"Please add more agents to the GroupChat or use direct communication instead."
302
336
)
303
- elif n_agents == 2 and self . speaker_selection_method .lower () != "round_robin" and allow_repeat_speaker :
337
+ elif n_agents == 2 and speaker_selection_method .lower () != "round_robin" and allow_repeat_speaker :
304
338
logger .warning (
305
339
f"GroupChat is underpopulated with { n_agents } agents. "
306
340
"Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, "
@@ -366,11 +400,11 @@ def _prepare_and_select_agents(
366
400
367
401
# Use the selected speaker selection method
368
402
select_speaker_messages = None
369
- if self . speaker_selection_method .lower () == "manual" :
403
+ if speaker_selection_method .lower () == "manual" :
370
404
selected_agent = self .manual_select_speaker (graph_eligible_agents )
371
- elif self . speaker_selection_method .lower () == "round_robin" :
405
+ elif speaker_selection_method .lower () == "round_robin" :
372
406
selected_agent = self .next_agent (last_speaker , graph_eligible_agents )
373
- elif self . speaker_selection_method .lower () == "random" :
407
+ elif speaker_selection_method .lower () == "random" :
374
408
selected_agent = self .random_select_speaker (graph_eligible_agents )
375
409
else :
376
410
selected_agent = None
0 commit comments