Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Make FunctionTools Declarative #5052

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
TimeoutTermination,
TokenUsageTermination,
)
from autogen_core import ComponentLoader, ComponentModel
from autogen_core import ComponentLoader, ComponentModel, CancellationToken
from autogen_core.tools import FunctionTool
from autogen_core.code_executor import ImportFromModule


@pytest.mark.asyncio
Expand Down Expand Up @@ -92,3 +94,59 @@ async def test_termination_declarative() -> None:
# Test loading complex composition
loaded_composite = ComponentLoader.load_component(composite_config)
assert isinstance(loaded_composite, AndTerminationCondition)


@pytest.mark.asyncio
async def test_function_tool() -> None:
"""Test FunctionTool with different function types and features."""

# Test sync and async functions
def sync_func(x: int, y: str) -> str:
return y * x

async def async_func(x: float, y: float, cancellation_token: CancellationToken) -> float:
if cancellation_token.is_cancelled():
raise Exception("Cancelled")
return x + y

# Create tools with different configurations
sync_tool = FunctionTool(
func=sync_func, description="Multiply string", global_imports=[ImportFromModule("typing", ("Dict",))]
)
async_tool = FunctionTool(
func=async_func,
description="Add numbers",
name="custom_adder",
global_imports=[ImportFromModule("autogen_core", ("CancellationToken",))],
)

# Test serialization and config

sync_config = sync_tool.dump_component()
assert isinstance(sync_config, ComponentModel)
assert sync_config.config["name"] == "sync_func"
assert len(sync_config.config["global_imports"]) == 1
assert not sync_config.config["has_cancellation_support"]

async_config = async_tool.dump_component()
assert async_config.config["name"] == "custom_adder"
assert async_config.config["has_cancellation_support"]

# Test deserialization and execution
loaded_sync = FunctionTool.load_component(sync_config, FunctionTool)
loaded_async = FunctionTool.load_component(async_config, FunctionTool)

# Test execution and validation
token = CancellationToken()
assert await loaded_sync.run_json({"x": 2, "y": "test"}, token) == "testtest"
assert await loaded_async.run_json({"x": 1.5, "y": 2.5}, token) == 4.0

# Test error cases
with pytest.raises(ValueError):
# Type error
await loaded_sync.run_json({"x": "invalid", "y": "test"}, token)

cancelled_token = CancellationToken()
cancelled_token.cancel()
with pytest.raises(Exception, match="Cancelled"):
await loaded_async.run_json({"x": 1.0, "y": 2.0}, cancelled_token)
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
module: str
imports: Tuple[Union[str, Alias], ...]

## backward compatibility
# backward compatibility
def __init__(
self,
module: str,
Expand Down Expand Up @@ -214,3 +214,11 @@

content += " ..."
return content


def to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
return _to_code(func)

Check warning on line 220 in python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py#L220

Added line #L220 was not covered by tests


def import_to_str(im: Import) -> str:
return _import_to_str(im)

Check warning on line 224 in python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/code_executor/_func_with_reqs.py#L224

Added line #L224 was not covered by tests
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from ._base import BaseTool, BaseToolWithState, ParametersSchema, Tool, ToolSchema
from ._function_tool import FunctionTool

Expand Down
9 changes: 7 additions & 2 deletions python/packages/autogen-core/src/autogen_core/tools/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing_extensions import NotRequired

from .. import CancellationToken
from .._component_config import ComponentBase
from .._function_utils import normalize_annotated_type

T = TypeVar("T", bound=BaseModel, contravariant=True)
Expand Down Expand Up @@ -56,7 +57,9 @@ def load_state_json(self, state: Mapping[str, Any]) -> None: ...
StateT = TypeVar("StateT", bound=BaseModel)


class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT]):
class BaseTool(ABC, Tool, Generic[ArgsT, ReturnT], ComponentBase[BaseModel]):
component_type = "tool"

def __init__(
self,
args_type: Type[ArgsT],
Expand Down Expand Up @@ -132,7 +135,7 @@ def load_state_json(self, state: Mapping[str, Any]) -> None:
pass


class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT]):
class BaseToolWithState(BaseTool[ArgsT, ReturnT], ABC, Generic[ArgsT, ReturnT, StateT], ComponentBase[BaseModel]):
def __init__(
self,
args_type: Type[ArgsT],
Expand All @@ -144,6 +147,8 @@ def __init__(
super().__init__(args_type, return_type, name, description)
self._state_type = state_type

component_type = "tool"

@abstractmethod
def save_state(self) -> StateT: ...

Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,32 @@
import asyncio
import functools
from typing import Any, Callable
from textwrap import dedent
from typing import Any, Callable, Sequence

from pydantic import BaseModel
from typing_extensions import Self

from .. import CancellationToken
from .._component_config import Component
from .._function_utils import (
args_base_model_from_signature,
get_typed_signature,
)
from ..code_executor._func_with_reqs import Import, import_to_str, to_code
from ._base import BaseTool


class FunctionTool(BaseTool[BaseModel, BaseModel]):
class FunctionToolConfig(BaseModel):
"""Configuration for a function tool."""

source_code: str
name: str
description: str
global_imports: Sequence[Import]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to understand this part: Import looks like a union of str, ImportModule, and Alias, it's a bit unclear to me what it is. Another thing, since import will be also included in the FunctionToolConfig, are they going to be imported at the load time of the FunctionTool or at call time?

Copy link
Collaborator Author

@victordibia victordibia Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question.
It will only be called at call time, when we load the component (wont be called when we dump component).

global_imports=[
        ImportFromModule("typing", ("Literal",)),
        ImportFromModule("autogen_core", ("CancellationToken",)),
        ImportFromModule("autogen_core.tools", ("FunctionTool",))
    ]

when serialized looks like

'global_imports': [{'module': 'typing', 'imports': ('Literal',)}, {'module': 'autogen_core', 'imports': ('CancellationToken',)}, {'module': 'autogen_core.tools', 'imports': ('FunctionTool',)}]

The goal here is to ensure we can dump and load.

  • dump:

    • convert _func to string, as source_code in FunctionToolConfig , store specified global_imports needed in function (e.g, type variables)
  • load (goal is to return an instance of the serialized FunctionTool)

    • use exec to load global_imports. Import is used here because it lets us specify something like from autogen_core import CancellationToken, it is already used in other parts of the lib. This is when the import is called
    • use exec to convert source_code from string to func
    • return an instance of FunctionTool(func=func)

Happy to review other proposals.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the global imports will be evaluated at FunctionTool.load(...) time.

While this has the same effect as importing any module in python, what it does effectively is, during application running, it can load and execute arbitrary code when deserializing a component. Not great for security.

I think what needs to happen by default instead, is to best-effort check the current global scope for whether the imports have already been loaded and provide an extra keyword argument to optionally automatically evaluate the required imports. Make sure the label this keyword argument as "insecure!" with

```{caution}
Evaluating the global_imports can lead to arbitrary code execution in your application's environment. It can reveal secretes and keys. Make sure you understand the implications. 
```

In the API doc.

In fact, we should do the same for FunctionTool.load(...). While it doesn't execute the function directly, it allows arbitrary code to be loaded and potentially executed when the application is running. Make sure to label it as such as well.

has_cancellation_support: bool


class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]):
"""
Create custom tools by wrapping standard Python functions.

Expand Down Expand Up @@ -64,8 +78,14 @@
asyncio.run(example())
"""

def __init__(self, func: Callable[..., Any], description: str, name: str | None = None) -> None:
component_provider_override = "autogen_core.tools.FunctionTool"
component_config_schema = FunctionToolConfig

def __init__(
self, func: Callable[..., Any], description: str, name: str | None = None, global_imports: Sequence[Import] = []
) -> None:
self._func = func
self._global_imports = global_imports
signature = get_typed_signature(func)
func_name = name or func.__name__
args_model = args_base_model_from_signature(func_name + "args", signature)
Expand Down Expand Up @@ -98,3 +118,44 @@
result = await future

return result

def _to_config(self) -> FunctionToolConfig:
return FunctionToolConfig(

Check warning on line 123 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L123

Added line #L123 was not covered by tests
source_code=dedent(to_code(self._func)),
global_imports=self._global_imports,
name=self.name,
description=self.description,
has_cancellation_support=self._has_cancellation_support,
)

@classmethod
def _from_config(cls, config: FunctionToolConfig) -> Self:
exec_globals: dict[str, Any] = {}

Check warning on line 133 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L133

Added line #L133 was not covered by tests

# Execute imports first
for import_stmt in config.global_imports:
import_code = import_to_str(import_stmt)
try:
exec(import_code, exec_globals)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(

Check warning on line 141 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L136-L141

Added lines #L136 - L141 were not covered by tests
f"Failed to import {import_code}: Module not found. Please ensure the module is installed."
) from e
except ImportError as e:
raise ImportError(f"Failed to import {import_code}: {str(e)}") from e
except Exception as e:
raise RuntimeError(f"Unexpected error while importing {import_code}: {str(e)}") from e

Check warning on line 147 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L144-L147

Added lines #L144 - L147 were not covered by tests

# Execute function code
try:
exec(config.source_code, exec_globals)
func_name = config.source_code.split("def ")[1].split("(")[0]
except Exception as e:
raise ValueError(f"Could not compile and load function: {e}") from e

Check warning on line 154 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L150-L154

Added lines #L150 - L154 were not covered by tests

# Get function and verify it's callable
func: Callable[..., Any] = exec_globals[func_name]
if not callable(func):
raise TypeError(f"Expected function but got {type(func)}")

Check warning on line 159 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L157-L159

Added lines #L157 - L159 were not covered by tests

return cls(func, "", None)

Check warning on line 161 in python/packages/autogen-core/src/autogen_core/tools/_function_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/tools/_function_tool.py#L161

Added line #L161 was not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@
Image,
MessageHandlerContext,
)
from autogen_core.models import FinishReasons
from autogen_core.logging import LLMCallEvent
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
ChatCompletionTokenLogprob,
CreateResult,
FinishReasons,
FunctionExecutionResultMessage,
LLMMessage,
ModelCapabilities, # type: ignore
Expand Down
Loading