refactor: tool

This commit is contained in:
Yeuoly
2024-09-20 23:48:48 +08:00
parent 3c1d32e3ac
commit 91cb80f795
29 changed files with 498 additions and 906 deletions

View File

@ -2,7 +2,7 @@ import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Union
from typing import Any, Optional, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
@ -11,6 +11,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
@ -38,18 +39,20 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
assert app_config.agent
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
# continue to run until there is not any tool call
function_call_state = True
llm_usage = {"usage": None}
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = ""
# get tracing instance
trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
@ -99,7 +102,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
current_llm_usage = None
if self.stream_tool_call:
if isinstance(chunks, Generator):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
@ -133,7 +136,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
yield chunk
else:
result: LLMResult = chunks
result = chunks
# check if there is any tool call
if self.check_blocking_tool_calls(result):
function_call_state = True
@ -236,15 +239,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
# publish files
for message_file_id, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
message_file_ids.append(message_file_id.id)
tool_response = {
"tool_call_id": tool_call_id,
@ -290,7 +290,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
iteration_step += 1
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@ -321,9 +320,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True
return False
def extract_tool_calls(
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
"""
Extract tool calls from llm result chunk
@ -346,7 +343,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
"""
Extract blocking tool calls from llm result
@ -370,7 +367,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls
def _init_system_message(
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
self, prompt_template: str, prompt_messages: list[PromptMessage]
) -> list[PromptMessage]:
"""
Initialize system message
@ -385,12 +382,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return prompt_messages
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content)