mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
refactor: tool
This commit is contained in:
@ -2,7 +2,6 @@ import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
@ -23,6 +22,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
@ -31,18 +31,15 @@ from core.model_runtime.entities.message_entities import (
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolRuntimeVariablePool,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message, MessageAgentThought
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -59,11 +56,9 @@ class BaseAgentRunner(AppRunner):
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
user_id: str,
|
||||
model_instance: ModelInstance,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance = None,
|
||||
) -> None:
|
||||
"""
|
||||
Agent runner
|
||||
@ -93,8 +88,6 @@ class BaseAgentRunner(AppRunner):
|
||||
self.user_id = user_id
|
||||
self.memory = memory
|
||||
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
|
||||
self.variables_pool = variables_pool
|
||||
self.db_variables_pool = db_variables
|
||||
self.model_instance = model_instance
|
||||
|
||||
# init callback
|
||||
@ -162,11 +155,10 @@ class BaseAgentRunner(AppRunner):
|
||||
agent_tool=tool,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
)
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
|
||||
assert tool_entity.entity.description
|
||||
message_tool = PromptMessageTool(
|
||||
name=tool.tool_name,
|
||||
description=tool_entity.description.llm,
|
||||
description=tool_entity.entity.description.llm,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@ -201,9 +193,11 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
convert dataset retriever tool to prompt message tool
|
||||
"""
|
||||
assert tool.entity.description
|
||||
|
||||
prompt_tool = PromptMessageTool(
|
||||
name=tool.identity.name,
|
||||
description=tool.description.llm,
|
||||
name=tool.entity.identity.name,
|
||||
description=tool.entity.description.llm,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@ -232,7 +226,7 @@ class BaseAgentRunner(AppRunner):
|
||||
tool_instances = {}
|
||||
prompt_messages_tools = []
|
||||
|
||||
for tool in self.app_config.agent.tools if self.app_config.agent else []:
|
||||
for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
|
||||
try:
|
||||
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
|
||||
except Exception:
|
||||
@ -249,7 +243,7 @@ class BaseAgentRunner(AppRunner):
|
||||
# save prompt tool
|
||||
prompt_messages_tools.append(prompt_tool)
|
||||
# save tool entity
|
||||
tool_instances[dataset_tool.identity.name] = dataset_tool
|
||||
tool_instances[dataset_tool.entity.identity.name] = dataset_tool
|
||||
|
||||
return tool_instances, prompt_messages_tools
|
||||
|
||||
@ -328,25 +322,29 @@ class BaseAgentRunner(AppRunner):
|
||||
def save_agent_thought(
|
||||
self,
|
||||
agent_thought: MessageAgentThought,
|
||||
tool_name: str,
|
||||
tool_input: Union[str, dict],
|
||||
thought: str,
|
||||
observation: Union[str, dict],
|
||||
tool_invoke_meta: Union[str, dict],
|
||||
answer: str,
|
||||
tool_name: str | None,
|
||||
tool_input: Union[str, dict, None],
|
||||
thought: str | None,
|
||||
observation: Union[str, dict, None],
|
||||
tool_invoke_meta: Union[str, dict, None],
|
||||
answer: str | None,
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage = None,
|
||||
) -> MessageAgentThought:
|
||||
llm_usage: LLMUsage | None = None,
|
||||
):
|
||||
"""
|
||||
Save agent thought
|
||||
"""
|
||||
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
|
||||
updated_agent_thought = (
|
||||
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
|
||||
)
|
||||
if not updated_agent_thought:
|
||||
raise ValueError("agent thought not found")
|
||||
|
||||
if thought is not None:
|
||||
agent_thought.thought = thought
|
||||
updated_agent_thought.thought = thought
|
||||
|
||||
if tool_name is not None:
|
||||
agent_thought.tool = tool_name
|
||||
updated_agent_thought.tool = tool_name
|
||||
|
||||
if tool_input is not None:
|
||||
if isinstance(tool_input, dict):
|
||||
@ -355,7 +353,7 @@ class BaseAgentRunner(AppRunner):
|
||||
except Exception as e:
|
||||
tool_input = json.dumps(tool_input)
|
||||
|
||||
agent_thought.tool_input = tool_input
|
||||
updated_agent_thought.tool_input = tool_input
|
||||
|
||||
if observation is not None:
|
||||
if isinstance(observation, dict):
|
||||
@ -364,27 +362,27 @@ class BaseAgentRunner(AppRunner):
|
||||
except Exception as e:
|
||||
observation = json.dumps(observation)
|
||||
|
||||
agent_thought.observation = observation
|
||||
updated_agent_thought.observation = observation
|
||||
|
||||
if answer is not None:
|
||||
agent_thought.answer = answer
|
||||
updated_agent_thought.answer = answer
|
||||
|
||||
if messages_ids is not None and len(messages_ids) > 0:
|
||||
agent_thought.message_files = json.dumps(messages_ids)
|
||||
updated_agent_thought.message_files = json.dumps(messages_ids)
|
||||
|
||||
if llm_usage:
|
||||
agent_thought.message_token = llm_usage.prompt_tokens
|
||||
agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
agent_thought.answer_token = llm_usage.completion_tokens
|
||||
agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
agent_thought.tokens = llm_usage.total_tokens
|
||||
agent_thought.total_price = llm_usage.total_price
|
||||
updated_agent_thought.message_token = llm_usage.prompt_tokens
|
||||
updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
|
||||
updated_agent_thought.answer_token = llm_usage.completion_tokens
|
||||
updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
|
||||
updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
|
||||
updated_agent_thought.tokens = llm_usage.total_tokens
|
||||
updated_agent_thought.total_price = llm_usage.total_price
|
||||
|
||||
# check if tool labels is not empty
|
||||
labels = agent_thought.tool_labels or {}
|
||||
tools = agent_thought.tool.split(";") if agent_thought.tool else []
|
||||
labels = updated_agent_thought.tool_labels or {}
|
||||
tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
|
||||
for tool in tools:
|
||||
if not tool:
|
||||
continue
|
||||
@ -395,7 +393,7 @@ class BaseAgentRunner(AppRunner):
|
||||
else:
|
||||
labels[tool] = {"en_US": tool, "zh_Hans": tool}
|
||||
|
||||
agent_thought.tool_labels_str = json.dumps(labels)
|
||||
updated_agent_thought.tool_labels_str = json.dumps(labels)
|
||||
|
||||
if tool_invoke_meta is not None:
|
||||
if isinstance(tool_invoke_meta, dict):
|
||||
@ -404,28 +402,11 @@ class BaseAgentRunner(AppRunner):
|
||||
except Exception as e:
|
||||
tool_invoke_meta = json.dumps(tool_invoke_meta)
|
||||
|
||||
agent_thought.tool_meta_str = tool_invoke_meta
|
||||
updated_agent_thought.tool_meta_str = tool_invoke_meta
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
|
||||
"""
|
||||
convert tool variables to db variables
|
||||
"""
|
||||
db_variables = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == self.message.conversation_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize agent history
|
||||
@ -515,6 +496,7 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
files = message.message_files
|
||||
if files:
|
||||
assert message.app_model_config
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
|
||||
if file_extra_config:
|
||||
@ -525,7 +507,7 @@ class BaseAgentRunner(AppRunner):
|
||||
if not file_objs:
|
||||
return UserPromptMessage(content=message.query)
|
||||
else:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
|
||||
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)]
|
||||
for file_obj in file_objs:
|
||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
@ -12,6 +12,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
@ -26,11 +27,11 @@ from models.model import Message
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
_historic_prompt_messages: list[PromptMessage] = None
|
||||
_agent_scratchpad: list[AgentScratchpadUnit] = None
|
||||
_instruction: str = None
|
||||
_query: str = None
|
||||
_prompt_messages_tools: list[PromptMessage] = None
|
||||
_historic_prompt_messages: list[PromptMessage]
|
||||
_agent_scratchpad: list[AgentScratchpadUnit]
|
||||
_instruction: str
|
||||
_query: str
|
||||
_prompt_messages_tools: Sequence[PromptMessageTool]
|
||||
|
||||
def run(
|
||||
self,
|
||||
@ -41,6 +42,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
|
||||
app_generate_entity = self.application_generate_entity
|
||||
self._repack_app_generate_entity(app_generate_entity)
|
||||
self._init_react_state(query)
|
||||
@ -53,9 +55,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
app_generate_entity.model_conf.stop.append("Observation")
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config.agent
|
||||
|
||||
# init instruction
|
||||
inputs = inputs or {}
|
||||
assert app_config.prompt_template.simple_prompt_template
|
||||
instruction = app_config.prompt_template.simple_prompt_template
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
|
||||
@ -63,13 +67,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
self._prompt_messages_tools = prompt_messages_tools
|
||||
|
||||
function_call_state = True
|
||||
llm_usage = {"usage": None}
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
final_answer = ""
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
@ -115,10 +120,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# check llm result
|
||||
if not chunks:
|
||||
raise ValueError("failed to invoke llm")
|
||||
|
||||
usage_dict = {}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
@ -139,11 +140,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
action = chunk
|
||||
# detect action
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += json.dumps(chunk.model_dump())
|
||||
scratchpad.action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.action = action
|
||||
else:
|
||||
assert scratchpad.agent_response is not None
|
||||
scratchpad.agent_response += chunk
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought += chunk
|
||||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
@ -152,6 +156,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
|
||||
)
|
||||
|
||||
assert scratchpad.thought is not None
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
@ -168,7 +173,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
tool_invoke_meta={},
|
||||
thought=scratchpad.thought,
|
||||
observation="",
|
||||
answer=scratchpad.agent_response,
|
||||
answer=scratchpad.agent_response or "",
|
||||
messages_ids=[],
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
@ -248,7 +253,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
messages_ids=[],
|
||||
)
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@ -266,7 +270,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
def _handle_invoke_action(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
tool_instances: dict[str, Tool],
|
||||
tool_instances: Mapping[str, Tool],
|
||||
message_file_ids: list[str],
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> tuple[str, ToolInvokeMeta]:
|
||||
@ -307,15 +311,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
|
||||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
message_file_ids.append(message_file_id.id)
|
||||
|
||||
return tool_invoke_response, tool_invoke_meta
|
||||
|
||||
@ -369,18 +370,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
return message
|
||||
|
||||
def _organize_historic_prompt_messages(
|
||||
self, current_session_messages: list[PromptMessage] = None
|
||||
self, current_session_messages: list[PromptMessage] | None = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
organize historic prompt messages
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
scratchpads: list[AgentScratchpadUnit] = []
|
||||
current_scratchpad: AgentScratchpadUnit = None
|
||||
current_scratchpad: AgentScratchpadUnit | None = None
|
||||
|
||||
for message in self.history_prompt_messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
if not current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
@ -400,6 +402,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
pass
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad.observation = message.content
|
||||
elif isinstance(message, UserPromptMessage):
|
||||
if scratchpads:
|
||||
|
||||
@ -4,6 +4,7 @@ from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
@ -16,6 +17,9 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
"""
|
||||
Organize system prompt
|
||||
"""
|
||||
assert self.app_config.agent
|
||||
assert self.app_config.agent.prompt
|
||||
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
@ -27,12 +31,12 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
|
||||
return SystemPromptMessage(content=system_prompt)
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
|
||||
for file_obj in self.files:
|
||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
||||
|
||||
@ -57,8 +61,10 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
@ -11,6 +11,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
@ -38,18 +39,20 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
llm_usage = {"usage": None}
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
final_answer = ""
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
@ -99,7 +102,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
current_llm_usage = None
|
||||
|
||||
if self.stream_tool_call:
|
||||
if isinstance(chunks, Generator):
|
||||
is_first_chunk = True
|
||||
for chunk in chunks:
|
||||
if is_first_chunk:
|
||||
@ -133,7 +136,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
result: LLMResult = chunks
|
||||
result = chunks
|
||||
# check if there is any tool call
|
||||
if self.check_blocking_tool_calls(result):
|
||||
function_call_state = True
|
||||
@ -236,15 +239,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
)
|
||||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
message_file_ids.append(message_file_id.id)
|
||||
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
@ -290,7 +290,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
iteration_step += 1
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
@ -321,9 +320,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_tool_calls(
|
||||
self, llm_result_chunk: LLMResultChunk
|
||||
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""
|
||||
Extract tool calls from llm result chunk
|
||||
|
||||
@ -346,7 +343,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
return tool_calls
|
||||
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""
|
||||
Extract blocking tool calls from llm result
|
||||
|
||||
@ -370,7 +367,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
return tool_calls
|
||||
|
||||
def _init_system_message(
|
||||
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
|
||||
self, prompt_template: str, prompt_messages: list[PromptMessage]
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
@ -385,12 +382,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||
prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
|
||||
for file_obj in self.files:
|
||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user