Skip to content

Commit e483ec6

Browse files
authored
feat: integrate Agent from haystack-experimental (#9112)
* add Agent * add Agent * update imports * add state tests * reno * remove State, its utils, and tests * add pydoc yml for agents * fix module path in serialization test * fix mypy error and use ChatGenerator protocol * remove unused import * address review feedback * remove unused _load_component
1 parent 637dcb4 commit e483ec6

File tree

5 files changed

+501
-0
lines changed

5 files changed

+501
-0
lines changed

docs/pydoc/config/agents_api.yml

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
loaders:
2+
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
3+
search_path: [../../../haystack/components/agents]
4+
modules: ["agent"]
5+
ignore_when_discovered: ["__init__"]
6+
processors:
7+
- type: filter
8+
expression:
9+
documented_only: true
10+
do_not_filter_modules: false
11+
skip_empty_modules: true
12+
- type: smart
13+
- type: crossref
14+
renderer:
15+
type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer
16+
excerpt: Tool-using agents with provider-agnostic chat model support.
17+
category_slug: haystack-api
18+
title: Agents
19+
slug: agents-api
20+
order: 2
21+
markdown:
22+
descriptive_class_title: false
23+
classdef_code_block: false
24+
descriptive_module_title: true
25+
add_method_class_prefix: true
26+
add_member_class_prefix: false
27+
filename: agents_api.md
+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import sys
6+
from typing import TYPE_CHECKING
7+
8+
from lazy_imports import LazyImporter
9+
10+
_import_structure = {"agent": ["Agent"]}
11+
12+
if TYPE_CHECKING:
13+
from .agent import Agent
14+
15+
else:
16+
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)

haystack/components/agents/agent.py

+222
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from typing import Any, Dict, List, Optional
6+
7+
from haystack import component, default_from_dict, default_to_dict, logging
8+
from haystack.components.generators.chat.types import ChatGenerator
9+
from haystack.components.tools import ToolInvoker
10+
from haystack.core.serialization import import_class_by_name
11+
from haystack.dataclasses import ChatMessage
12+
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
13+
from haystack.dataclasses.streaming_chunk import SyncStreamingCallbackT
14+
from haystack.tools import Tool, deserialize_tools_inplace
15+
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
@component
21+
class Agent:
22+
"""
23+
A Haystack component that implements a tool-using agent with provider-agnostic chat model support.
24+
25+
The component processes messages and executes tools until a exit_condition condition is met.
26+
The exit_condition can be triggered either by a direct text response or by invoking a specific designated tool.
27+
28+
### Usage example
29+
```python
30+
from haystack.components.agents import Agent
31+
from haystack.components.generators.chat import OpenAIChatGenerator
32+
from haystack.dataclasses import ChatMessage
33+
from haystack.tools.tool import Tool
34+
35+
tools = [Tool(name="calculator", description="..."), Tool(name="search", description="...")]
36+
37+
agent = Agent(
38+
chat_generator=OpenAIChatGenerator(),
39+
tools=tools,
40+
exit_condition="search",
41+
)
42+
43+
# Run the agent
44+
result = agent.run(
45+
messages=[ChatMessage.from_user("Find information about Haystack")]
46+
)
47+
48+
assert "messages" in result # Contains conversation history
49+
```
50+
"""
51+
52+
def __init__(
53+
self,
54+
*,
55+
chat_generator: ChatGenerator,
56+
tools: Optional[List[Tool]] = None,
57+
system_prompt: Optional[str] = None,
58+
exit_condition: str = "text",
59+
state_schema: Optional[Dict[str, Any]] = None,
60+
max_runs_per_component: int = 100,
61+
raise_on_tool_invocation_failure: bool = False,
62+
streaming_callback: Optional[SyncStreamingCallbackT] = None,
63+
):
64+
"""
65+
Initialize the agent component.
66+
67+
:param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
68+
:param tools: List of Tool objects available to the agent
69+
:param system_prompt: System prompt for the agent.
70+
:param exit_condition: Either "text" if the agent should return when it generates a message without tool calls
71+
or the name of a tool that will cause the agent to return once the tool was executed
72+
:param state_schema: The schema for the runtime state used by the tools.
73+
:param max_runs_per_component: Maximum number of runs per component. Agent will raise an exception if a
74+
component exceeds the maximum number of runs per component.
75+
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
76+
If set to False, the exception will be turned into a chat message and passed to the LLM.
77+
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
78+
"""
79+
valid_exits = ["text"] + [tool.name for tool in tools or []]
80+
if exit_condition not in valid_exits:
81+
raise ValueError(f"Exit condition must be one of {valid_exits}")
82+
83+
if state_schema is not None:
84+
_validate_schema(state_schema)
85+
self.state_schema = state_schema or {}
86+
87+
self.chat_generator = chat_generator
88+
self.tools = tools or []
89+
self.system_prompt = system_prompt
90+
self.exit_condition = exit_condition
91+
self.max_runs_per_component = max_runs_per_component
92+
self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
93+
self.streaming_callback = streaming_callback
94+
95+
output_types = {"messages": List[ChatMessage]}
96+
for param, config in self.state_schema.items():
97+
component.set_input_type(self, name=param, type=config["type"], default=None)
98+
output_types[param] = config["type"]
99+
component.set_output_types(self, **output_types)
100+
101+
self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure)
102+
103+
self._is_warmed_up = False
104+
105+
def warm_up(self) -> None:
106+
"""
107+
Warm up the Agent.
108+
"""
109+
if not self._is_warmed_up:
110+
if hasattr(self.chat_generator, "warm_up"):
111+
self.chat_generator.warm_up()
112+
self._is_warmed_up = True
113+
114+
def to_dict(self) -> Dict[str, Any]:
115+
"""
116+
Serialize the component to a dictionary.
117+
118+
:return: Dictionary with serialized data
119+
"""
120+
if self.streaming_callback is not None:
121+
streaming_callback = serialize_callable(self.streaming_callback)
122+
else:
123+
streaming_callback = None
124+
125+
return default_to_dict(
126+
self,
127+
chat_generator=self.chat_generator.to_dict(),
128+
tools=[t.to_dict() for t in self.tools],
129+
system_prompt=self.system_prompt,
130+
exit_condition=self.exit_condition,
131+
state_schema=_schema_to_dict(self.state_schema),
132+
max_runs_per_component=self.max_runs_per_component,
133+
raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
134+
streaming_callback=streaming_callback,
135+
)
136+
137+
@classmethod
138+
def from_dict(cls, data: Dict[str, Any]) -> "Agent":
139+
"""
140+
Deserialize the agent from a dictionary.
141+
142+
:param data: Dictionary to deserialize from
143+
:return: Deserialized agent
144+
"""
145+
init_params = data.get("init_parameters", {})
146+
147+
chat_generator_class = import_class_by_name(init_params["chat_generator"]["type"])
148+
assert hasattr(chat_generator_class, "from_dict") # we know but mypy doesn't
149+
chat_generator_instance = chat_generator_class.from_dict(init_params["chat_generator"])
150+
data["init_parameters"]["chat_generator"] = chat_generator_instance
151+
152+
if "state_schema" in init_params:
153+
init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])
154+
155+
if init_params.get("streaming_callback") is not None:
156+
init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"])
157+
158+
deserialize_tools_inplace(init_params, key="tools")
159+
160+
return default_from_dict(cls, data)
161+
162+
def run(
163+
self, messages: List[ChatMessage], streaming_callback: Optional[SyncStreamingCallbackT] = None, **kwargs
164+
) -> Dict[str, Any]:
165+
"""
166+
Process messages and execute tools until the exit condition is met.
167+
168+
:param messages: List of chat messages to process
169+
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
170+
:param kwargs: Additional data to pass to the State schema used by the Agent.
171+
The keys must match the schema defined in the Agent's `state_schema`.
172+
:return: Dictionary containing messages and outputs matching the defined output types
173+
"""
174+
if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
175+
raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.")
176+
177+
state = State(schema=self.state_schema, data=kwargs)
178+
179+
if self.system_prompt is not None:
180+
messages = [ChatMessage.from_system(self.system_prompt)] + messages
181+
182+
generator_inputs: Dict[str, Any] = {"tools": self.tools}
183+
184+
selected_callback = streaming_callback or self.streaming_callback
185+
if selected_callback is not None:
186+
generator_inputs["streaming_callback"] = selected_callback
187+
188+
# Repeat until the exit condition is met
189+
counter = 0
190+
while counter < self.max_runs_per_component:
191+
# 1. Call the ChatGenerator
192+
llm_messages = self.chat_generator.run(messages=messages, **generator_inputs)["replies"]
193+
194+
# TODO Possible for LLM to return multiple messages (e.g. multiple tool calls)
195+
# Would a better check be to see if any of the messages contain a tool call?
196+
# 2. Check if the LLM response contains a tool call
197+
if llm_messages[0].tool_call is None:
198+
return {"messages": messages + llm_messages, **state.data}
199+
200+
# 3. Call the ToolInvoker
201+
# We only send the messages from the LLM to the tool invoker
202+
tool_invoker_result = self._tool_invoker.run(messages=llm_messages, state=state)
203+
tool_messages = tool_invoker_result["messages"]
204+
state = tool_invoker_result["state"]
205+
206+
# 4. Check the LLM and Tool response for the exit condition, if exit_condition is a tool name
207+
# TODO Possible for LLM to return multiple messages (e.g. multiple tool calls)
208+
# So exit condition could be missed if it's not the first message
209+
if self.exit_condition != "text" and (
210+
llm_messages[0].tool_call.tool_name == self.exit_condition
211+
and not tool_messages[0].tool_call_result.error
212+
):
213+
return {"messages": messages + llm_messages + tool_messages, **state.data}
214+
215+
# 5. Combine messages, llm_messages and tool_messages and send to the ChatGenerator
216+
messages = messages + llm_messages + tool_messages
217+
counter += 1
218+
219+
logger.warning(
220+
"Agent exceeded maximum runs per component ({max_loops}), stopping.", max_loops=self.max_runs_per_component
221+
)
222+
return {"messages": messages, **state.data}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
---
2+
highlights: >
3+
The Agent component enables tool-calling functionality with provider-agnostic chat model support and can be used as a standalone component or within a pipeline.
4+
5+
```python
6+
from haystack.components.agents import Agent
7+
from haystack.components.generators.chat import OpenAIChatGenerator
8+
from haystack.components.websearch import SerperDevWebSearch
9+
from haystack.dataclasses import ChatMessage
10+
from haystack.tools.component_tool import ComponentTool
11+
12+
web_tool = ComponentTool(
13+
component=SerperDevWebSearch(),
14+
)
15+
16+
agent = Agent(
17+
chat_generator=OpenAIChatGenerator(),
18+
tools=[web_tool],
19+
exit_condition="text",
20+
)
21+
22+
result = agent.run(
23+
messages=[ChatMessage.from_user("Find information about Haystack")]
24+
)
25+
```

0 commit comments

Comments
 (0)