refactor: tool

Yeuoly
2024-09-20 23:48:48 +08:00
parent 3c1d32e3ac
commit 91cb80f795
29 changed files with 498 additions and 906 deletions

View File

@@ -2,7 +2,6 @@ import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Optional, Union, cast
from core.agent.entities import AgentEntity, AgentToolEntity
@@ -23,6 +22,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContent,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
@@ -31,18 +31,15 @@ from core.model_runtime.entities.message_entities import (
)
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
ToolRuntimeVariablePool,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@@ -59,11 +56,9 @@ class BaseAgentRunner(AppRunner):
queue_manager: AppQueueManager,
message: Message,
user_id: str,
model_instance: ModelInstance,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None,
) -> None:
"""
Agent runner
@@ -93,8 +88,6 @@ class BaseAgentRunner(AppRunner):
self.user_id = user_id
self.memory = memory
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
# init callback
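The constructor hunks above drop the ToolRuntimeVariablePool / ToolConversationVariables plumbing and fix a typing bug: model_instance: ModelInstance = None declared a non-optional type with a None default. A minimal sketch of why the new required-argument form is preferable (toy class, standard library only):

from typing import Optional

class Engine:
    pass

# before: the annotation lied about None, so strict type checkers reject it
#     def start(engine: Engine = None): ...

# after: either make the argument required...
def start(engine: Engine) -> None:
    print("starting", engine)

# ...or state explicitly that None is allowed and handle it
def start_maybe(engine: Optional[Engine] = None) -> None:
    print("starting", engine or Engine())

start(Engine())
start_maybe()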
@@ -162,11 +155,10 @@ class BaseAgentRunner(AppRunner):
agent_tool=tool,
invoke_from=self.application_generate_entity.invoke_from,
)
tool_entity.load_variables(self.variables_pool)
assert tool_entity.entity.description
message_tool = PromptMessageTool(
name=tool.tool_name,
description=tool_entity.description.llm,
description=tool_entity.entity.description.llm,
parameters={
"type": "object",
"properties": {},
@@ -201,9 +193,11 @@ class BaseAgentRunner(AppRunner):
"""
convert dataset retriever tool to prompt message tool
"""
assert tool.entity.description
prompt_tool = PromptMessageTool(
name=tool.identity.name,
description=tool.description.llm,
name=tool.entity.identity.name,
description=tool.entity.description.llm,
parameters={
"type": "object",
"properties": {},
@@ -232,7 +226,7 @@ class BaseAgentRunner(AppRunner):
tool_instances = {}
prompt_messages_tools = []
for tool in self.app_config.agent.tools if self.app_config.agent else []:
for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
try:
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
except Exception:
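The rewritten loop header guards both an optional agent config and a possibly-None tools list in one expression. A standalone sketch of the same pattern, with hypothetical stand-in classes:

from dataclasses import dataclass
from typing import Optional

@dataclass
class AgentConfig:
    tools: Optional[list[str]] = None

@dataclass
class AppConfig:
    agent: Optional[AgentConfig] = None

def tool_names(app_config: AppConfig) -> list[str]:
    # "or []" covers tools=None; the trailing conditional covers agent=None
    return list(app_config.agent.tools or [] if app_config.agent else [])

assert tool_names(AppConfig()) == []
assert tool_names(AppConfig(agent=AgentConfig())) == []
assert tool_names(AppConfig(agent=AgentConfig(tools=["search"]))) == ["search"]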
@@ -249,7 +243,7 @@ class BaseAgentRunner(AppRunner):
# save prompt tool
prompt_messages_tools.append(prompt_tool)
# save tool entity
tool_instances[dataset_tool.identity.name] = dataset_tool
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
return tool_instances, prompt_messages_tools
@@ -328,25 +322,29 @@ class BaseAgentRunner(AppRunner):
def save_agent_thought(
self,
agent_thought: MessageAgentThought,
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: Union[str, dict],
tool_invoke_meta: Union[str, dict],
answer: str,
tool_name: str | None,
tool_input: Union[str, dict, None],
thought: str | None,
observation: Union[str, dict, None],
tool_invoke_meta: Union[str, dict, None],
answer: str | None,
messages_ids: list[str],
llm_usage: LLMUsage = None,
) -> MessageAgentThought:
llm_usage: LLMUsage | None = None,
):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
updated_agent_thought = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
)
if not updated_agent_thought:
raise ValueError("agent thought not found")
if thought is not None:
agent_thought.thought = thought
updated_agent_thought.thought = thought
if tool_name is not None:
agent_thought.tool = tool_name
updated_agent_thought.tool = tool_name
if tool_input is not None:
if isinstance(tool_input, dict):
@@ -355,7 +353,7 @@ class BaseAgentRunner(AppRunner):
except Exception as e:
tool_input = json.dumps(tool_input)
agent_thought.tool_input = tool_input
updated_agent_thought.tool_input = tool_input
if observation is not None:
if isinstance(observation, dict):
@@ -364,27 +362,27 @@ class BaseAgentRunner(AppRunner):
except Exception as e:
observation = json.dumps(observation)
agent_thought.observation = observation
updated_agent_thought.observation = observation
if answer is not None:
agent_thought.answer = answer
updated_agent_thought.answer = answer
if messages_ids is not None and len(messages_ids) > 0:
agent_thought.message_files = json.dumps(messages_ids)
updated_agent_thought.message_files = json.dumps(messages_ids)
if llm_usage:
agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit
agent_thought.message_unit_price = llm_usage.prompt_unit_price
agent_thought.answer_token = llm_usage.completion_tokens
agent_thought.answer_price_unit = llm_usage.completion_price_unit
agent_thought.answer_unit_price = llm_usage.completion_unit_price
agent_thought.tokens = llm_usage.total_tokens
agent_thought.total_price = llm_usage.total_price
updated_agent_thought.message_token = llm_usage.prompt_tokens
updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
updated_agent_thought.answer_token = llm_usage.completion_tokens
updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
updated_agent_thought.tokens = llm_usage.total_tokens
updated_agent_thought.total_price = llm_usage.total_price
# check if tool labels is not empty
labels = agent_thought.tool_labels or {}
tools = agent_thought.tool.split(";") if agent_thought.tool else []
labels = updated_agent_thought.tool_labels or {}
tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
for tool in tools:
if not tool:
continue
@@ -395,7 +393,7 @@ class BaseAgentRunner(AppRunner):
else:
labels[tool] = {"en_US": tool, "zh_Hans": tool}
agent_thought.tool_labels_str = json.dumps(labels)
updated_agent_thought.tool_labels_str = json.dumps(labels)
if tool_invoke_meta is not None:
if isinstance(tool_invoke_meta, dict):
@@ -404,28 +402,11 @@ class BaseAgentRunner(AppRunner):
except Exception as e:
tool_invoke_meta = json.dumps(tool_invoke_meta)
agent_thought.tool_meta_str = tool_invoke_meta
updated_agent_thought.tool_meta_str = tool_invoke_meta
db.session.commit()
db.session.close()
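The save path above re-queries the row by id instead of mutating the passed-in agent_thought, which may be stale or detached from the active session, and fails fast when the row is missing. A minimal sketch of that re-fetch-then-update pattern with a toy model (assuming SQLAlchemy 2.x; all names here are stand-ins):

from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Thought(Base):
    __tablename__ = "thoughts"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    answer: Mapped[str] = mapped_column(String, default="")

def save_answer(session: Session, thought_id: str, answer: str) -> None:
    # re-fetch inside the current session instead of trusting a stale object
    row = session.query(Thought).filter(Thought.id == thought_id).first()
    if row is None:
        raise ValueError("thought not found")
    row.answer = answer
    session.commit()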
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
"""
convert tool variables to db variables
"""
db_variables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
)
.first()
)
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize agent history
@@ -515,6 +496,7 @@ class BaseAgentRunner(AppRunner):
files = message.message_files
if files:
assert message.app_model_config
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if file_extra_config:
@@ -525,7 +507,7 @@ class BaseAgentRunner(AppRunner):
if not file_objs:
return UserPromptMessage(content=message.query)
else:
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)]
for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content)
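Annotating the list as list[PromptMessageContent] up front is what lets the later append of non-text content type-check; without it a checker infers list[TextPromptMessageContent] and rejects other subclasses. A sketch with stand-in classes:

class PromptMessageContent:
    pass

class TextPromptMessageContent(PromptMessageContent):
    def __init__(self, data: str) -> None:
        self.data = data

class FilePromptMessageContent(PromptMessageContent):
    def __init__(self, url: str) -> None:
        self.url = url

# the explicit base-class annotation keeps the heterogeneous appends legal
contents: list[PromptMessageContent] = [TextPromptMessageContent(data="hello")]
contents.append(FilePromptMessageContent(url="https://example.com/a.png"))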

View File

@@ -1,6 +1,6 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Generator
from collections.abc import Generator, Mapping, Sequence
from typing import Optional, Union
from core.agent.base_agent_runner import BaseAgentRunner
@@ -12,6 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
@@ -26,11 +27,11 @@ from models.model import Message
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage] = None
_agent_scratchpad: list[AgentScratchpadUnit] = None
_instruction: str = None
_query: str = None
_prompt_messages_tools: list[PromptMessage] = None
_historic_prompt_messages: list[PromptMessage]
_agent_scratchpad: list[AgentScratchpadUnit]
_instruction: str
_query: str
_prompt_messages_tools: Sequence[PromptMessageTool]
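Replacing the = None class defaults with bare annotations removes both the type lie (list[PromptMessage] is not Optional here) and the shared class-level state; the attributes become per-instance declarations that must be assigned before use. A compact illustration:

class Runner:
    history: list[str]  # declaration only; no value lives on the class

r = Runner()
r.history = []  # each instance assigns its own list
r.history.append("hi")
assert not hasattr(Runner, "history")  # nothing is shared on the class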
def run(
self,
@@ -41,6 +42,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
Run Cot agent application
"""
app_generate_entity = self.application_generate_entity
self._repack_app_generate_entity(app_generate_entity)
self._init_react_state(query)
@@ -53,9 +55,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
app_generate_entity.model_conf.stop.append("Observation")
app_config = self.app_config
assert app_config.agent
# init instruction
inputs = inputs or {}
assert app_config.prompt_template.simple_prompt_template
instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
@@ -63,13 +67,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
# convert tools into ModelRuntime Tool format
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools
function_call_state = True
llm_usage = {"usage": None}
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = ""
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
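Typing the accumulator as dict[str, Optional[LLMUsage]] makes the None-until-first-chunk state explicit to the checker. A standalone sketch of the same accumulation, with a toy dataclass standing in for LLMUsage:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Usage:
    prompt_tokens: int = 0
    completion_tokens: int = 0

def increase_usage(final: dict[str, Optional[Usage]], delta: Usage) -> None:
    current = final["usage"]
    if current is None:
        final["usage"] = delta  # first chunk: adopt it wholesale
    else:
        current.prompt_tokens += delta.prompt_tokens
        current.completion_tokens += delta.completion_tokens

acc: dict[str, Optional[Usage]] = {"usage": None}
increase_usage(acc, Usage(3, 5))
increase_usage(acc, Usage(2, 1))
assert acc["usage"] == Usage(5, 6)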
@@ -115,10 +120,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
callbacks=[],
)
# check llm result
if not chunks:
raise ValueError("failed to invoke llm")
usage_dict = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
@@ -139,11 +140,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk
# detect action
assert scratchpad.agent_response is not None
scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action
else:
assert scratchpad.agent_response is not None
scratchpad.agent_response += chunk
assert scratchpad.thought is not None
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
@@ -152,6 +156,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
assert scratchpad.thought is not None
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
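The assert-then-strip step above gives the type checker a non-None thought and substitutes a fallback when the model emitted only whitespace. The same idiom in isolation:

from typing import Optional

def normalize_thought(raw: Optional[str]) -> str:
    assert raw is not None  # narrowed for the type checker
    return raw.strip() or "I am thinking about how to help you"

assert normalize_thought("  ") == "I am thinking about how to help you"
assert normalize_thought(" plan first ") == "plan first"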
@@ -168,7 +173,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_invoke_meta={},
thought=scratchpad.thought,
observation="",
answer=scratchpad.agent_response,
answer=scratchpad.agent_response or "",
messages_ids=[],
llm_usage=usage_dict["usage"],
)
@@ -248,7 +253,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
messages_ids=[],
)
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@@ -266,7 +270,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
def _handle_invoke_action(
self,
action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
tool_instances: Mapping[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, ToolInvokeMeta]:
@@ -307,15 +311,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish files
for message_file_id, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
message_file_ids.append(message_file_id.id)
return tool_invoke_response, tool_invoke_meta
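With the variables pool gone, message_files now yields record objects rather than bare ids, so both the queue event and the collected list read message_file.id. A sketch with hypothetical stand-ins:

from dataclasses import dataclass
from typing import Callable

@dataclass
class MessageFile:
    id: str

def publish_files(message_files: list[tuple[MessageFile, str]], publish: Callable[[str], object]) -> list[str]:
    ids: list[str] = []
    for message_file, _save_as in message_files:
        publish(message_file.id)  # the event carries the row id, not the object
        ids.append(message_file.id)
    return ids

assert publish_files([(MessageFile("f1"), ""), (MessageFile("f2"), "x")], print) == ["f1", "f2"]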
@@ -369,18 +370,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return message
def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] = None
self, current_session_messages: list[PromptMessage] | None = None
) -> list[PromptMessage]:
"""
organize historic prompt messages
"""
result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit = None
current_scratchpad: AgentScratchpadUnit | None = None
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
@@ -400,6 +402,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
pass
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
elif isinstance(message, UserPromptMessage):
if scratchpads:

View File

@@ -4,6 +4,7 @@ from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContent,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
@@ -16,6 +17,9 @@ class CotChatAgentRunner(CotAgentRunner):
"""
Organize system prompt
"""
assert self.app_config.agent
assert self.app_config.agent.prompt
prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt
@@ -27,12 +31,12 @@ class CotChatAgentRunner(CotAgentRunner):
return SystemPromptMessage(content=system_prompt)
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content)
@@ -57,8 +61,10 @@ class CotChatAgentRunner(CotAgentRunner):
assistant_message = AssistantPromptMessage(content="")
for unit in agent_scratchpad:
if unit.is_final():
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}"
else:
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n"

View File

@@ -2,7 +2,7 @@ import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Union
from typing import Any, Optional, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
@@ -11,6 +11,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
@@ -38,18 +39,20 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
assert app_config.agent
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
# continue to run until there is not any tool call
function_call_state = True
llm_usage = {"usage": None}
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = ""
# get tracing instance
trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
@@ -99,7 +102,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
current_llm_usage = None
if self.stream_tool_call:
if isinstance(chunks, Generator):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
@@ -133,7 +136,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
yield chunk
else:
result: LLMResult = chunks
result = chunks
# check if there is any tool call
if self.check_blocking_tool_calls(result):
function_call_state = True
@@ -236,15 +239,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
# publish files
for message_file_id, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
message_file_ids.append(message_file_id.id)
tool_response = {
"tool_call_id": tool_call_id,
@@ -290,7 +290,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
iteration_step += 1
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
@@ -321,9 +320,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True
return False
def extract_tool_calls(
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
"""
Extract tool calls from llm result chunk
@@ -346,7 +343,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls
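Dropping Union[None, ...] from the return annotations means callers can iterate the result unconditionally; "no tool calls" is simply the empty list. A tiny sketch:

def extract_calls(chunks: list[str]) -> list[str]:
    # return [] rather than None when nothing matches
    return [c for c in chunks if c.startswith("tool:")]

for call in extract_calls([]):  # safe even when there is nothing to do
    print(call)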
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
"""
Extract blocking tool calls from llm result
@@ -370,7 +367,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return tool_calls
def _init_system_message(
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
self, prompt_template: str, prompt_messages: list[PromptMessage]
) -> list[PromptMessage]:
"""
Initialize system message
@@ -385,12 +382,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return prompt_messages
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content)