Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Azure AI Chat Completion Client #4723

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/packages/autogen-core/docs/src/reference/index.md
Original file line number Diff line number Diff line change
@@ -51,6 +51,7 @@ python/autogen_ext.teams.magentic_one
python/autogen_ext.models.cache
python/autogen_ext.models.openai
python/autogen_ext.models.replay
python/autogen_ext.models.azure
python/autogen_ext.models.semantic_kernel
python/autogen_ext.tools.langchain
python/autogen_ext.tools.graphrag
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
autogen\_ext.models.azure
==========================


.. automodule:: autogen_ext.models.azure
:members:
:undoc-members:
:show-inheritance:
18 changes: 5 additions & 13 deletions python/packages/autogen-ext/pyproject.toml
Original file line number Diff line number Diff line change
@@ -20,7 +20,11 @@ dependencies = [

[project.optional-dependencies]
langchain = ["langchain_core~= 0.3.3"]
azure = ["azure-core", "azure-identity"]
azure = [
"azure-ai-inference>=1.0.0b7",
"azure-core",
"azure-identity",
]
docker = ["docker~=7.0"]
openai = ["openai>=1.52.2", "tiktoken>=0.8.0", "aiofiles"]
file-surfer = [
@@ -52,55 +56,43 @@ diskcache = [
redis = [
"redis>=5.2.1"
]

grpc = [
"grpcio~=1.62.0", # TODO: update this once we have a stable version.
]
jupyter-executor = [
"ipykernel>=6.29.5",
"nbclient>=0.10.2",
]

semantic-kernel-core = [
"semantic-kernel>=1.17.1",
]

semantic-kernel-google = [
"semantic-kernel[google]>=1.17.1",
]

semantic-kernel-hugging-face = [
"semantic-kernel[hugging_face]>=1.17.1",
]

semantic-kernel-mistralai = [
"semantic-kernel[mistralai]>=1.17.1",
]

semantic-kernel-ollama = [
"semantic-kernel[ollama]>=1.17.1",
]

semantic-kernel-onnx = [
"semantic-kernel[onnx]>=1.17.1",
]

semantic-kernel-anthropic = [
"semantic-kernel[anthropic]>=1.17.1",
]

semantic-kernel-pandas = [
"semantic-kernel[pandas]>=1.17.1",
]

semantic-kernel-aws = [
"semantic-kernel[aws]>=1.17.1",
]

semantic-kernel-dapr = [
"semantic-kernel[dapr]>=1.17.1",
]

semantic-kernel-all = [
"semantic-kernel[google,hugging_face,mistralai,ollama,onnx,anthropic,usearch,pandas,aws,dapr]>=1.17.1",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Public API of the Azure AI model-client package: re-exports the chat
completion client and its configuration TypedDict."""

from ._azure_ai_client import AzureAIChatCompletionClient
from .config import AzureAIChatCompletionClientConfig

__all__ = ["AzureAIChatCompletionClient", "AzureAIChatCompletionClientConfig"]

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import TypedDict, Union, Optional, List, Dict, Any
from azure.ai.inference.models import (
JsonSchemaFormat,
ChatCompletionsToolDefinition,
ChatCompletionsToolChoicePreset,
ChatCompletionsNamedToolChoice,
)

from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
from autogen_core.models import ModelInfo

# Endpoint used by GitHub-hosted models that speak the Azure AI Inference API;
# presumably used elsewhere to special-case GitHub Models behavior — confirm in client code.
GITHUB_MODELS_ENDPOINT = "https://models.inference.ai.azure.com"


class AzureAIClientArguments(TypedDict, total=False):
    """Connection-level arguments for the Azure AI chat completion client.

    ``total=False``: every key is optional at the type level.
    """

    # Azure AI Inference service endpoint URL.
    endpoint: str
    # Either an API key or an async AAD token credential.
    credential: Union[AzureKeyCredential, AsyncTokenCredential]
    # Capability/metadata description (family, vision, function calling, ...) of the target model.
    model_info: ModelInfo


class AzureAICreateArguments(TypedDict, total=False):
    """Per-request arguments forwarded to the Azure AI Inference ``complete`` call.

    ``total=False``: every key is optional; unset keys fall back to service defaults.
    """

    frequency_penalty: Optional[float]
    presence_penalty: Optional[float]
    temperature: Optional[float]
    top_p: Optional[float]
    max_tokens: Optional[int]
    # Either a plain format name or a full JSON-schema response format.
    response_format: Optional[Union[str, JsonSchemaFormat]]
    # Sequences at which the service stops generating.
    stop: Optional[List[str]]
    tools: Optional[List[ChatCompletionsToolDefinition]]
    # Preset ("auto"/"none"/...), enum preset, or an explicit named tool.
    tool_choice: Optional[Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]]
    seed: Optional[int]
    # Model deployment name to target at the endpoint.
    model: Optional[str]
    # Provider-specific extra parameters passed through verbatim.
    model_extras: Optional[Dict[str, Any]]


class AzureAIChatCompletionClientConfig(AzureAIClientArguments, AzureAICreateArguments):
    """Full client configuration: connection arguments plus per-request create arguments."""

    pass
174 changes: 174 additions & 0 deletions python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import asyncio
from datetime import datetime
from typing import AsyncGenerator, Any

import pytest
from azure.ai.inference.aio import (
ChatCompletionsClient,
)


from azure.ai.inference.models import (
ChatChoice,
ChatResponseMessage,
CompletionsUsage,
)

from azure.ai.inference.models import (
ChatCompletions,
StreamingChatCompletionsUpdate,
StreamingChatChoiceUpdate,
StreamingChatResponseMessageUpdate,
)

from azure.core.credentials import AzureKeyCredential

from autogen_core import CancellationToken
from autogen_core.models import UserMessage
from autogen_ext.models.azure import AzureAIChatCompletionClient


async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]:
    """Yield a fixed sequence of streaming chat-completion updates, one per mock chunk."""
    for text in ["Hello", " Another Hello", " Yet Another Hello"]:
        await asyncio.sleep(0.1)
        choice = StreamingChatChoiceUpdate(
            index=0,
            finish_reason="stop",
            delta=StreamingChatResponseMessageUpdate(role="assistant", content=text),
        )
        yield StreamingChatCompletionsUpdate(
            id="id",
            choices=[choice],
            created=datetime.now(),
            model="model",
            usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        )


async def _mock_create(
    *args: Any, **kwargs: Any
) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]:
    """Mock replacement for ``ChatCompletionsClient.complete``.

    Returns the streaming generator when ``stream=True`` is passed, otherwise a
    single completed response after a short simulated delay.
    """
    if kwargs.get("stream", False):
        return _mock_create_stream(*args, **kwargs)

    await asyncio.sleep(0.1)
    reply = ChatResponseMessage(content="Hello", role="assistant")
    return ChatCompletions(
        id="id",
        created=datetime.now(),
        model="model",
        choices=[ChatChoice(index=0, finish_reason="stop", message=reply)],
        usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
    )


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client() -> None:
    """The client should be constructible from endpoint, credential and model info."""
    chat_client = AzureAIChatCompletionClient(
        model="model",
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "family": "unknown",
            "json_output": False,
            "function_calling": False,
            "vision": False,
        },
    )
    assert chat_client


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.MonkeyPatch) -> None:
    """A non-streaming ``create`` call should return the mocked completion content."""
    # Patch the underlying Azure SDK call so no network access is needed.
    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
    client = AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "family": "unknown",
            "json_output": False,
            "function_calling": False,
            "vision": False,
        },
    )
    result = await client.create(messages=[UserMessage(content="Hello", source="user")])
    assert result.content == "Hello"


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_stream(monkeypatch: pytest.MonkeyPatch) -> None:
    """Streaming ``create_stream`` should yield the mocked chunks in order."""
    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
    client = AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "family": "unknown",
            "json_output": False,
            "function_calling": False,
            "vision": False,
        },
    )
    chunks = [
        chunk async for chunk in client.create_stream(messages=[UserMessage(content="Hello", source="user")])
    ]

    assert chunks[0] == "Hello"
    assert chunks[1] == " Another Hello"
    assert chunks[2] == " Yet Another Hello"


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_cancel(monkeypatch: pytest.MonkeyPatch) -> None:
    """Cancelling the token before the mocked call completes should cancel ``create``."""
    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
    client = AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "family": "unknown",
            "json_output": False,
            "function_calling": False,
            "vision": False,
        },
    )
    token = CancellationToken()
    pending = asyncio.create_task(
        client.create(messages=[UserMessage(content="Hello", source="user")], cancellation_token=token)
    )
    token.cancel()
    with pytest.raises(asyncio.CancelledError):
        await pending


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_create_stream_cancel(monkeypatch: pytest.MonkeyPatch) -> None:
    """Cancelling the token should abort iteration of ``create_stream``."""
    monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create)
    client = AzureAIChatCompletionClient(
        endpoint="endpoint",
        credential=AzureKeyCredential("api_key"),
        model_info={
            "family": "unknown",
            "json_output": False,
            "function_calling": False,
            "vision": False,
        },
    )
    token = CancellationToken()
    updates = client.create_stream(
        messages=[UserMessage(content="Hello", source="user")], cancellation_token=token
    )
    token.cancel()
    with pytest.raises(asyncio.CancelledError):
        async for _ in updates:
            pass