mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
refactor apps
This commit is contained in:
@ -5,17 +5,15 @@ from datetime import datetime
|
||||
from mimetypes import guess_extension
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
from core.app.app_queue_manager import AppQueueManager
|
||||
from core.app.base_app_runner import AppRunner
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.application_entities import (
|
||||
AgentEntity,
|
||||
AgentToolEntity,
|
||||
ApplicationGenerateEntity,
|
||||
AppOrchestrationConfigEntity,
|
||||
InvokeFrom,
|
||||
ModelConfigEntity,
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
EasyUIBasedAppGenerateEntity,
|
||||
InvokeFrom, EasyUIBasedModelConfigEntity,
|
||||
)
|
||||
from core.file.message_file_parser import FileTransferMethod
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
@ -50,9 +48,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseAgentRunner(AppRunner):
|
||||
def __init__(self, tenant_id: str,
|
||||
application_generate_entity: ApplicationGenerateEntity,
|
||||
app_orchestration_config: AppOrchestrationConfigEntity,
|
||||
model_config: ModelConfigEntity,
|
||||
application_generate_entity: EasyUIBasedAppGenerateEntity,
|
||||
app_config: AgentChatAppConfig,
|
||||
model_config: EasyUIBasedModelConfigEntity,
|
||||
config: AgentEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
@ -66,7 +64,7 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
Agent runner
|
||||
:param tenant_id: tenant id
|
||||
:param app_orchestration_config: app orchestration config
|
||||
:param app_config: app generate entity
|
||||
:param model_config: model config
|
||||
:param config: dataset config
|
||||
:param queue_manager: queue manager
|
||||
@ -78,7 +76,7 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
self.tenant_id = tenant_id
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.app_orchestration_config = app_orchestration_config
|
||||
self.app_config = app_config
|
||||
self.model_config = model_config
|
||||
self.config = config
|
||||
self.queue_manager = queue_manager
|
||||
@ -97,16 +95,16 @@ class BaseAgentRunner(AppRunner):
|
||||
# init dataset tools
|
||||
hit_callback = DatasetIndexToolCallbackHandler(
|
||||
queue_manager=queue_manager,
|
||||
app_id=self.application_generate_entity.app_id,
|
||||
app_id=self.app_config.app_id,
|
||||
message_id=message.id,
|
||||
user_id=user_id,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
)
|
||||
self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
|
||||
tenant_id=tenant_id,
|
||||
dataset_ids=app_orchestration_config.dataset.dataset_ids if app_orchestration_config.dataset else [],
|
||||
retrieve_config=app_orchestration_config.dataset.retrieve_config if app_orchestration_config.dataset else None,
|
||||
return_resource=app_orchestration_config.show_retrieve_source,
|
||||
dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
|
||||
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
|
||||
return_resource=app_config.additional_features.show_retrieve_source,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
hit_callback=hit_callback
|
||||
)
|
||||
@ -124,14 +122,15 @@ class BaseAgentRunner(AppRunner):
|
||||
else:
|
||||
self.stream_tool_call = False
|
||||
|
||||
def _repack_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
|
||||
def _repack_app_generate_entity(self, app_generate_entity: EasyUIBasedAppGenerateEntity) \
|
||||
-> EasyUIBasedAppGenerateEntity:
|
||||
"""
|
||||
Repack app orchestration config
|
||||
Repack app generate entity
|
||||
"""
|
||||
if app_orchestration_config.prompt_template.simple_prompt_template is None:
|
||||
app_orchestration_config.prompt_template.simple_prompt_template = ''
|
||||
if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
|
||||
app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
|
||||
|
||||
return app_orchestration_config
|
||||
return app_generate_entity
|
||||
|
||||
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
@ -351,7 +350,7 @@ class BaseAgentRunner(AppRunner):
|
||||
))
|
||||
|
||||
db.session.close()
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def create_agent_thought(self, message_id: str, message: str,
|
||||
@ -462,7 +461,7 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
|
||||
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Transform tool message into agent thought
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Literal, Union
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.app.app_queue_manager import PublishFrom
|
||||
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
|
||||
from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
@ -27,7 +27,7 @@ from core.tools.errors import (
|
||||
from models.model import Conversation, Message
|
||||
|
||||
|
||||
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
class CotAgentRunner(BaseAgentRunner):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ['wenxin']
|
||||
|
||||
@ -39,30 +39,33 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
app_orchestration_config = self.app_orchestration_config
|
||||
self._repack_app_orchestration_config(app_orchestration_config)
|
||||
app_generate_entity = self.application_generate_entity
|
||||
self._repack_app_generate_entity(app_generate_entity)
|
||||
|
||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||
self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
|
||||
|
||||
if 'Observation' not in app_orchestration_config.model_config.stop:
|
||||
if app_orchestration_config.model_config.provider not in self._ignore_observation_providers:
|
||||
app_orchestration_config.model_config.stop.append('Observation')
|
||||
# check model mode
|
||||
if 'Observation' not in app_generate_entity.model_config.stop:
|
||||
if app_generate_entity.model_config.provider not in self._ignore_observation_providers:
|
||||
app_generate_entity.model_config.stop.append('Observation')
|
||||
|
||||
app_config = self.app_config
|
||||
|
||||
# override inputs
|
||||
inputs = inputs or {}
|
||||
instruction = self.app_orchestration_config.prompt_template.simple_prompt_template
|
||||
instruction = app_config.prompt_template.simple_prompt_template
|
||||
instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
|
||||
prompt_messages = self.history_prompt_messages
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
prompt_messages_tools: list[PromptMessageTool] = []
|
||||
tool_instances = {}
|
||||
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
||||
for tool in app_config.agent.tools if app_config.agent else []:
|
||||
try:
|
||||
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
|
||||
except Exception:
|
||||
@ -122,11 +125,11 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
|
||||
# update prompt messages
|
||||
prompt_messages = self._organize_cot_prompt_messages(
|
||||
mode=app_orchestration_config.model_config.mode,
|
||||
mode=app_generate_entity.model_config.mode,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=prompt_messages_tools,
|
||||
agent_scratchpad=agent_scratchpad,
|
||||
agent_prompt_message=app_orchestration_config.agent.prompt,
|
||||
agent_prompt_message=app_config.agent.prompt,
|
||||
instruction=instruction,
|
||||
input=query
|
||||
)
|
||||
@ -136,9 +139,9 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
# invoke model
|
||||
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_orchestration_config.model_config.parameters,
|
||||
model_parameters=app_generate_entity.model_config.parameters,
|
||||
tools=[],
|
||||
stop=app_orchestration_config.model_config.stop,
|
||||
stop=app_generate_entity.model_config.stop,
|
||||
stream=True,
|
||||
user=self.user_id,
|
||||
callbacks=[],
|
||||
@ -550,7 +553,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
"""
|
||||
convert agent scratchpad list to str
|
||||
"""
|
||||
next_iteration = self.app_orchestration_config.agent.prompt.next_iteration
|
||||
next_iteration = self.app_config.agent.prompt.next_iteration
|
||||
|
||||
result = ''
|
||||
for scratchpad in agent_scratchpad:
|
||||
|
||||
61
api/core/agent/entities.py
Normal file
61
api/core/agent/entities.py
Normal file
@ -0,0 +1,61 @@
|
||||
from enum import Enum
|
||||
from typing import Literal, Any, Union, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentToolEntity(BaseModel):
|
||||
"""
|
||||
Agent Tool Entity.
|
||||
"""
|
||||
provider_type: Literal["builtin", "api"]
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any] = {}
|
||||
|
||||
|
||||
class AgentPromptEntity(BaseModel):
|
||||
"""
|
||||
Agent Prompt Entity.
|
||||
"""
|
||||
first_prompt: str
|
||||
next_iteration: str
|
||||
|
||||
|
||||
class AgentScratchpadUnit(BaseModel):
|
||||
"""
|
||||
Agent First Prompt Entity.
|
||||
"""
|
||||
|
||||
class Action(BaseModel):
|
||||
"""
|
||||
Action Entity.
|
||||
"""
|
||||
action_name: str
|
||||
action_input: Union[dict, str]
|
||||
|
||||
agent_response: Optional[str] = None
|
||||
thought: Optional[str] = None
|
||||
action_str: Optional[str] = None
|
||||
observation: Optional[str] = None
|
||||
action: Optional[Action] = None
|
||||
|
||||
|
||||
class AgentEntity(BaseModel):
|
||||
"""
|
||||
Agent Entity.
|
||||
"""
|
||||
|
||||
class Strategy(Enum):
|
||||
"""
|
||||
Agent Strategy.
|
||||
"""
|
||||
CHAIN_OF_THOUGHT = 'chain-of-thought'
|
||||
FUNCTION_CALLING = 'function-calling'
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
strategy: Strategy
|
||||
prompt: Optional[AgentPromptEntity] = None
|
||||
tools: list[AgentToolEntity] = None
|
||||
max_iteration: int = 5
|
||||
@ -34,9 +34,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
"""
|
||||
Run FunctionCall agent application
|
||||
"""
|
||||
app_orchestration_config = self.app_orchestration_config
|
||||
app_generate_entity = self.application_generate_entity
|
||||
|
||||
prompt_template = self.app_orchestration_config.prompt_template.simple_prompt_template or ''
|
||||
app_config = self.app_config
|
||||
|
||||
prompt_template = app_config.prompt_template.simple_prompt_template or ''
|
||||
prompt_messages = self.history_prompt_messages
|
||||
prompt_messages = self.organize_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
@ -47,7 +49,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# convert tools into ModelRuntime Tool format
|
||||
prompt_messages_tools: list[PromptMessageTool] = []
|
||||
tool_instances = {}
|
||||
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
|
||||
for tool in app_config.agent.tools if app_config.agent else []:
|
||||
try:
|
||||
prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
|
||||
except Exception:
|
||||
@ -67,7 +69,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_instances[dataset_tool.identity.name] = dataset_tool
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_orchestration_config.agent.max_iteration, 5) + 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
@ -110,9 +112,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_orchestration_config.model_config.parameters,
|
||||
model_parameters=app_generate_entity.model_config.parameters,
|
||||
tools=prompt_messages_tools,
|
||||
stop=app_orchestration_config.model_config.stop,
|
||||
stop=app_generate_entity.model_config.stop,
|
||||
stream=self.stream_tool_call,
|
||||
user=self.user_id,
|
||||
callbacks=[],
|
||||
|
||||
Reference in New Issue
Block a user