Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AssistantAgent Tool Call Behavior #4602

Merged
merged 24 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
ekzhu marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,42 @@ def model_post_init(self, __context: Any) -> None:
class AssistantAgent(BaseChatAgent):
"""An agent that provides assistance with tool use.

```{note}
The assistant agent is not thread-safe or coroutine-safe.
It should not be shared between multiple tasks or coroutines, and it should
not call its methods concurrently.
```
The :meth:`on_messages` returns a :class:`~autogen_agentchat.base.Response`
in which :attr:`~autogen_agentchat.base.Response.chat_message` is the final
response message.

The :meth:`on_messages_stream` creates an async generator that produces
the inner messages as they are created, and the :class:`~autogen_agentchat.base.Response`
object as the last item before closing the generator.

Tool call behavior:

* If the model returns no tool call, then the response is immediately returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
* When the model returns tool calls, they will be executed right away:
- When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
- When `reflect_on_tool_use` is True, the another model inference is made using the tool calls and results, and the text response is returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.

Hand off behavior:

* If a handoff is triggered, a :class:`~autogen_agentchat.messages.HandoffMessage` will be returned in :attr:`~autogen_agentchat.base.Response.chat_message`.
* If there are tool calls, they will also be executed right away before returning the handoff.


.. note::
The assistant agent is not thread-safe or coroutine-safe.
It should not be shared between multiple tasks or coroutines, and it should
not call its methods concurrently.

.. note::
By default, the tool call results are returned as response when tool calls are made.
So it is recommended to pay attention to the formatting of the tools return values,
especially if another agent is expecting them in a specific format.
Use `tool_call_summary_format` to customize the tool call summary, if needed.

.. note::
If multiple handoffs are detected, only the first handoff is executed.



Args:
name (str): The name of the agent.
Expand All @@ -66,11 +97,20 @@ class AssistantAgent(BaseChatAgent):
If a handoff is a string, it should represent the target agent's name.
description (str, optional): The description of the agent.
system_message (str, optional): The system message for the model.
reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
Defaults to "{result}".
When `reflect_on_tool_use` is `False`, a concatenation of all the tool call summaries, separated by a new line character ('\\n')
will be returned as the response.
Available variables: `{tool_name}`, `{arguments}`, `{result}`.
For example, `"{tool_name}: {result}"` will create a summary like `"tool_name: result"`.

Raises:
ValueError: If tool names are not unique.
ValueError: If handoff names are not unique.
ValueError: If handoff names are not unique from tool names.
ValueError: If maximum number of tool iterations is less than 1.

Examples:

Expand Down Expand Up @@ -181,6 +221,8 @@ def __init__(
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str
| None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
):
super().__init__(name=name, description=description)
self._model_client = model_client
Expand Down Expand Up @@ -231,6 +273,8 @@ def __init__(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
self._model_context: List[LLMMessage] = []
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False

@property
Expand Down Expand Up @@ -267,53 +311,77 @@ async def on_messages_stream(
# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

# Run tool calls until the model produces a string response.
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
tool_call_msg = ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage)
event_logger.debug(tool_call_msg)
# Add the tool call message to the output.
inner_messages.append(tool_call_msg)
yield tool_call_msg

# Execute the tool calls.
results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
# Check if the response is a string and return it.
if isinstance(result.content, str):
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
inner_messages=inner_messages,
)
tool_call_result_msg = ToolCallResultMessage(content=results, source=self.name)
event_logger.debug(tool_call_result_msg)
self._model_context.append(FunctionExecutionResultMessage(content=results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

# Detect handoff requests.
handoffs: List[HandoffBase] = []
for call in result.content:
if call.name in self._handoffs:
handoffs.append(self._handoffs[call.name])
if len(handoffs) > 0:
if len(handoffs) > 1:
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
# Return the output messages to signal the handoff.
yield Response(
chat_message=HandoffMessage(
content=handoffs[0].message, target=handoffs[0].target, source=self.name
),
inner_messages=inner_messages,
return

# Process tool calls.
assert isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content)
tool_call_msg = ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage)
event_logger.debug(tool_call_msg)
# Add the tool call message to the output.
inner_messages.append(tool_call_msg)
yield tool_call_msg

# Execute the tool calls.
results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
tool_call_result_msg = ToolCallResultMessage(content=results, source=self.name)
event_logger.debug(tool_call_result_msg)
self._model_context.append(FunctionExecutionResultMessage(content=results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

# Detect handoff requests.
handoffs: List[HandoffBase] = []
for call in result.content:
if call.name in self._handoffs:
handoffs.append(self._handoffs[call.name])
if len(handoffs) > 0:
if len(handoffs) > 1:
# show warning if multiple handoffs detected
warnings.warn(
f"Multiple handoffs detected only the first is executed: {[handoff.name for handoff in handoffs]}",
stacklevel=2,
)
return
# Return the output messages to signal the handoff.
yield Response(
chat_message=HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name),
inner_messages=inner_messages,
)
return

# Generate an inference result based on the current model context.
if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + self._model_context
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)
result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
assert isinstance(result.content, str)
# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

assert isinstance(result.content, str)
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
inner_messages=inner_messages,
)
# Yield the response.
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
inner_messages=inner_messages,
)
else:
# Return tool call result as the response.
tool_call_summaries: List[str] = []
for i in range(len(tool_call_msg.content)):
tool_call_summaries.append(
self._tool_call_summary_format.format(
tool_name=tool_call_msg.content[i].name,
arguments=tool_call_msg.content[i].arguments,
result=tool_call_result_msg.content[i].content,
),
)
tool_call_summary = "\n".join(tool_call_summaries)
yield Response(
chat_message=TextMessage(content=tool_call_summary, source=self.name),
inner_messages=inner_messages,
)

async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
Expand Down
101 changes: 101 additions & 0 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,106 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
result = await agent.run(task="task")

assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
assert isinstance(result.messages[1], ToolCallMessage)
assert result.messages[1].models_usage is not None
assert result.messages[1].models_usage.completion_tokens == 5
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallResultMessage)
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], TextMessage)
assert result.messages[3].content == "pass"
assert result.messages[3].models_usage is None

# Test streaming.
mock._curr_index = 0 # pyright: ignore
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1

# Test state saving and loading.
state = await agent.save_state()
agent2 = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
await agent2.load_state(state)
state2 = await agent2.save_state()
assert state == state2


@pytest.mark.asyncio
async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task"}),
),
)
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
reflect_on_tool_use=True,
)
result = await agent.run(task="task")

assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
Expand All @@ -128,6 +228,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert isinstance(result.messages[2], ToolCallResultMessage)
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], TextMessage)
assert result.messages[3].content == "Hello"
assert result.messages[3].models_usage is not None
assert result.messages[3].models_usage.completion_tokens == 5
assert result.messages[3].models_usage.prompt_tokens == 10
Expand Down
23 changes: 7 additions & 16 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
from autogen_agentchat.ui import Console
from autogen_core import AgentId, CancellationToken, FunctionCall
from autogen_core.models import FunctionExecutionResult
from autogen_core import AgentId, CancellationToken
from autogen_core.tools import FunctionTool
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
from autogen_ext.models import OpenAIChatCompletionClient, ReplayChatCompletionClient
Expand Down Expand Up @@ -306,6 +305,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
),
]
# Test with repeat tool calls once
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
tool = FunctionTool(_pass_function, name="pass", description="pass function")
Expand All @@ -320,27 +320,18 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
result = await team.run(
task="Write a program that prints 'Hello, world!'",
)

assert len(result.messages) == 6
assert len(result.messages) == 8
assert isinstance(result.messages[0], TextMessage) # task
assert isinstance(result.messages[1], ToolCallMessage) # tool call
assert isinstance(result.messages[2], ToolCallResultMessage) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], TextMessage) # tool use agent response
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
assert isinstance(result.messages[6], TextMessage) # echo agent response
assert isinstance(result.messages[7], TextMessage) # tool use agent response, that has TERMINATE
assert result.messages[7].content == "TERMINATE"

context = tool_use_agent._model_context # pyright: ignore
assert context[0].content == "Write a program that prints 'Hello, world!'"
assert isinstance(context[1].content, list)
assert isinstance(context[1].content[0], FunctionCall)
assert context[1].content[0].name == "pass"
assert context[1].content[0].arguments == json.dumps({"input": "pass"})
assert isinstance(context[2].content, list)
assert isinstance(context[2].content[0], FunctionExecutionResult)
assert context[2].content[0].content == "pass"
assert context[2].content[0].call_id == "1"
assert context[3].content == "Hello"
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

# Test streaming.
tool_use_agent._model_context.clear() # pyright: ignore
Expand Down
Loading
Loading