Skip to content

Commit

Permalink
add_encoding_declaration and format code
Browse files Browse the repository at this point in the history
  • Loading branch information
glide-the committed Jul 17, 2024
1 parent e938dc0 commit b6b228e
Show file tree
Hide file tree
Showing 55 changed files with 124 additions and 51 deletions.
1 change: 1 addition & 0 deletions langchain_glm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# ruff: noqa: E402
"""Main entrypoint into package."""
from importlib import metadata
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agent_toolkits/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from langchain_glm.agent_toolkits.all_tools import AdapterAllTool, BaseToolOutput

__all__ = ["BaseToolOutput", "AdapterAllTool"]
1 change: 1 addition & 0 deletions langchain_glm/agent_toolkits/all_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from langchain_glm.agent_toolkits.all_tools.tool import (
AdapterAllTool,
BaseToolOutput,
Expand Down
16 changes: 5 additions & 11 deletions langchain_glm/agent_toolkits/all_tools/code_interpreter_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import json
import logging
from dataclasses import dataclass
Expand Down Expand Up @@ -34,9 +35,7 @@ def __init__(
**extras: Any,
) -> None:
data = CodeInterpreterToolOutput.paser_data(
tool=tool,
code_input=code_input,
code_output=code_output
tool=tool, code_input=code_input, code_output=code_output
)
super().__init__(data, "", "", **extras)
self.platform_params = platform_params
Expand All @@ -45,13 +44,10 @@ def __init__(
self.code_output = code_output

@staticmethod
def paser_data(
tool: str,
code_input: str,
code_output: Dict[str, Any]
) -> str:
def paser_data(tool: str, code_input: str, code_output: Dict[str, Any]) -> str:
return f"""Access:{tool}, Message: {code_input},{code_output}"""


@dataclass
class CodeInterpreterAllToolExecutor(AllToolExecutor):
"""platform adapter tool for code interpreter tool"""
Expand All @@ -70,9 +66,7 @@ def _python_ast_interpreter(
tool = PythonAstREPLTool()
out = tool.run(tool_input=code_input)
if str(out) == "":
raise ValueError(
f"Tool {tool.name} local sandbox is out empty"
)
raise ValueError(f"Tool {tool.name} local sandbox is out empty")
return CodeInterpreterToolOutput(
tool=tool.name,
code_input=code_input,
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agent_toolkits/all_tools/drawing_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agent_toolkits/all_tools/registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from typing import Dict, Type

from langchain_glm.agent_toolkits import AdapterAllTool
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agent_toolkits/all_tools/struct_type.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
"""IndexStructType class."""

from enum import Enum
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agent_toolkits/all_tools/tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
"""platform adapter tool """

from __future__ import annotations
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agent_toolkits/all_tools/web_browser_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from langchain_glm.agents.zhipuai_all_tools import ZhipuAIAllToolsRunnable

__all__ = ["ZhipuAIAllToolsRunnable"]
1 change: 1 addition & 0 deletions langchain_glm/agents/all_tools_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import asyncio
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agents/all_tools_bind/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from typing import Sequence

from langchain_core.language_models import BaseLanguageModel
Expand Down
21 changes: 12 additions & 9 deletions langchain_glm/agents/format_scratchpad/all_tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import json
from typing import List, Sequence, Tuple, Union

Expand Down Expand Up @@ -25,7 +26,7 @@


def _create_tool_message(
agent_action: ToolAgentAction, observation: Union[str, BaseToolOutput]
agent_action: ToolAgentAction, observation: Union[str, BaseToolOutput]
) -> ToolMessage:
"""Convert agent action and observation into a function message.
Args:
Expand All @@ -49,7 +50,7 @@ def _create_tool_message(


def format_to_zhipuai_all_tool_messages(
intermediate_steps: Sequence[Tuple[AgentAction, BaseToolOutput]],
intermediate_steps: Sequence[Tuple[AgentAction, BaseToolOutput]],
) -> List[BaseMessage]:
"""Convert (AgentAction, tool output) tuples into FunctionMessages.
Expand All @@ -67,17 +68,21 @@ def format_to_zhipuai_all_tool_messages(
if "auto" == observation.platform_params.get("sandbox", "auto"):
new_messages = [
AIMessage(content=str(observation.code_input)),
_create_tool_message(agent_action, observation)
_create_tool_message(agent_action, observation),
]

messages.extend([new for new in new_messages if new not in messages])
messages.extend(
[new for new in new_messages if new not in messages]
)
elif "none" == observation.platform_params.get("sandbox", "auto"):
new_messages = [
new_messages = [
AIMessage(content=str(observation.code_input)),
_create_tool_message(agent_action, observation.code_output)
_create_tool_message(agent_action, observation.code_output),
]

messages.extend([new for new in new_messages if new not in messages])
messages.extend(
[new for new in new_messages if new not in messages]
)
else:
raise ValueError(
f"Unknown sandbox type: {observation.platform_params.get('sandbox', 'auto')}"
Expand All @@ -94,14 +99,12 @@ def format_to_zhipuai_all_tool_messages(

elif isinstance(agent_action, WebBrowserAgentAction):
if isinstance(observation, WebBrowserToolOutput):

new_messages = [AIMessage(content=str(observation))]
messages.extend([new for new in new_messages if new not in messages])
else:
raise ValueError(f"Unknown observation type: {type(observation)}")

elif isinstance(agent_action, ToolAgentAction):

ai_msgs = AIMessage(
content=f"arguments='{agent_action.tool_input}', name='{agent_action.tool}'"
)
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agents/output_parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
"""Parsing utils to go from string to AgentAction or Agent Finish.
AgentAction means that an action should be taken.
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agents/output_parsers/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# Function to find positions of object() instances
def find_object_positions(log_chunk, obj):
return [i for i, x in enumerate(log_chunk) if x == obj]
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agents/output_parsers/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from typing import Any, Dict, Optional

from zhipuai.core import BaseModel
Expand Down
9 changes: 6 additions & 3 deletions langchain_glm/agents/output_parsers/code_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
import json
import logging
from typing import Any, Dict, List, Deque, Union

from collections import deque
from typing import Any, Deque, Dict, List, Union

from langchain.agents.output_parsers.tools import ToolAgentAction
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
Expand Down Expand Up @@ -106,7 +107,9 @@ def _paser_code_interpreter_chunk_input(
tool_call_id = (
code_interpreter_chunk[0].id if code_interpreter_chunk[0].id else "abc"
)
code_interpreter_action_result_stack: Deque[CodeInterpreterAgentAction] = deque()
code_interpreter_action_result_stack: Deque[
CodeInterpreterAgentAction
] = deque()
for i, action in enumerate(result_actions):
if len(result_actions) > len(outputs):
outputs.insert(i, [])
Expand Down
3 changes: 2 additions & 1 deletion langchain_glm/agents/output_parsers/drawing_tool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
import json
import logging
from collections import deque
from json import JSONDecodeError
from typing import Any, Dict, List, Deque, Union
from typing import Any, Deque, Dict, List, Union

from langchain.agents.output_parsers.tools import ToolAgentAction
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
Expand Down
31 changes: 17 additions & 14 deletions langchain_glm/agents/output_parsers/function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
import json
import logging
from collections import deque
from typing import Any, Dict, List, Deque, Union
from typing import Any, Deque, Dict, List, Union

from langchain.agents.output_parsers.tools import ToolAgentAction
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
Expand All @@ -13,14 +14,20 @@
from langchain_core.utils.json import parse_partial_json

from langchain_glm.agent_toolkits.all_tools.struct_type import AdapterAllToolStructType
from langchain_glm.agents.output_parsers._utils import find_object_positions, concatenate_segments
from langchain_glm.agents.output_parsers.base import AllToolsMessageToolCall, AllToolsMessageToolCallChunk
from langchain_glm.agents.output_parsers._utils import (
concatenate_segments,
find_object_positions,
)
from langchain_glm.agents.output_parsers.base import (
AllToolsMessageToolCall,
AllToolsMessageToolCallChunk,
)

logger = logging.getLogger(__name__)


def _best_effort_parse_function_tool_calls(
tool_call_chunks: List[dict],
tool_call_chunks: List[dict],
) -> List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]]:
function_chunk: List[
Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]
Expand Down Expand Up @@ -57,17 +64,13 @@ def _best_effort_parse_function_tool_calls(


def _paser_function_chunk_input(
message: BaseMessage,
function_chunk: List[
Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]
],
message: BaseMessage,
function_chunk: List[Union[AllToolsMessageToolCall, AllToolsMessageToolCallChunk]],
) -> Deque[ToolAgentAction]:

try:
function_action_result_stack: Deque[ToolAgentAction] = deque()
for _chunk in function_chunk:
if isinstance(_chunk, AllToolsMessageToolCall):

function_name = _chunk.name
_tool_input = _chunk.args
tool_call_id = _chunk.id if _chunk.id else "abc"
Expand All @@ -76,7 +79,9 @@ def _paser_function_chunk_input(
else:
tool_input = _tool_input

content_msg = f"responded: {message.content}\n" if message.content else "\n"
content_msg = (
f"responded: {message.content}\n" if message.content else "\n"
)
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"

function_action_result_stack.append(
Expand All @@ -93,6 +98,4 @@ def _paser_function_chunk_input(

except Exception as e:
logger.error(f"Error parsing function_chunk: {e}", exc_info=True)
raise OutputParserException(
f"Error parsing function_chunk: {e} "
)
raise OutputParserException(f"Error parsing function_chunk: {e} ")
16 changes: 10 additions & 6 deletions langchain_glm/agents/output_parsers/tools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
import json
import logging
from collections import deque
from json import JSONDecodeError
from typing import Any, Dict, List, Deque, Union
from typing import Any, Deque, Dict, List, Union

from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
Expand Down Expand Up @@ -32,9 +33,10 @@
_best_effort_parse_drawing_tool_tool_calls,
_paser_drawing_tool_chunk_input,
)

from langchain_glm.agents.output_parsers.function import _best_effort_parse_function_tool_calls, \
_paser_function_chunk_input
from langchain_glm.agents.output_parsers.function import (
_best_effort_parse_function_tool_calls,
_paser_function_chunk_input,
)
from langchain_glm.agents.output_parsers.web_browser import (
_best_effort_parse_web_browser_tool_calls,
_paser_web_browser_chunk_input,
Expand All @@ -44,7 +46,9 @@
logger = logging.getLogger(__name__)


def paser_ai_message_to_tool_calls(message: BaseMessage, ):
def paser_ai_message_to_tool_calls(
message: BaseMessage,
):
tool_calls = []
if message.tool_calls:
tool_calls = message.tool_calls
Expand Down Expand Up @@ -87,7 +91,7 @@ def paser_ai_message_to_tool_calls(message: BaseMessage, ):


def parse_ai_message_to_tool_action(
message: BaseMessage,
message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls."""
if not isinstance(message, AIMessage):
Expand Down
7 changes: 3 additions & 4 deletions langchain_glm/agents/output_parsers/web_browser.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
import json
import logging
from collections import deque
from json import JSONDecodeError
from typing import Any, Dict, List, Deque, Union
from typing import Any, Deque, Dict, List, Union

from langchain.agents.output_parsers.tools import ToolAgentAction
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
Expand Down Expand Up @@ -130,6 +131,4 @@ def _paser_web_browser_chunk_input(
return web_browser_action_result_stack
except Exception as e:
logger.error(f"Error parsing web_browser_chunk: {e}", exc_info=True)
raise OutputParserException(
f"Could not parse tool input: web_browser {e} "
)
raise OutputParserException(f"Could not parse tool input: web_browser {e} ")
1 change: 1 addition & 0 deletions langchain_glm/agents/output_parsers/zhipuai_all_tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from typing import List, Union

from langchain.agents.agent import MultiActionAgentOutputParser
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agents/zhipuai_all_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from langchain_glm.agents.zhipuai_all_tools.base import (
ZhipuAIAllToolsRunnable,
)
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agents/zhipuai_all_tools/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import asyncio
import json
import logging
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/agents/zhipuai_all_tools/schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import json
import uuid
from abc import abstractmethod
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
"""**Callback handlers** allow listening to events in LangChain.
**Class hierarchy:**
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/callbacks/agent_callback_handler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import asyncio
Expand Down
1 change: 1 addition & 0 deletions langchain_glm/chat_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
from langchain_glm.chat_models.base import ChatZhipuAI

__all__ = [
Expand Down
Loading

0 comments on commit b6b228e

Please sign in to comment.