mirror of
https://github.com/langgenius/dify.git
synced 2026-04-30 07:28:05 +08:00
feat: support agent node
This commit is contained in:
3
api/core/workflow/nodes/agent/__init__.py
Normal file
3
api/core/workflow/nodes/agent/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .agent_node import AgentNode
|
||||
|
||||
__all__ = ["AgentNode"]
|
||||
85
api/core/workflow/nodes/agent/agent_node.py
Normal file
85
api/core/workflow/nodes/agent/agent_node.py
Normal file
@ -0,0 +1,85 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event.event import RunCompletedEvent
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class AgentNode(ToolNode):
|
||||
"""
|
||||
Agent Node
|
||||
"""
|
||||
|
||||
_node_data_cls = AgentNodeData
|
||||
_node_type = NodeType.AGENT
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the agent node
|
||||
"""
|
||||
node_data = cast(AgentNodeData, self.node_data)
|
||||
|
||||
try:
|
||||
strategy = get_plugin_agent_strategy(
|
||||
tenant_id=self.tenant_id,
|
||||
plugin_unique_identifier=node_data.plugin_unique_identifier,
|
||||
agent_strategy_provider_name=node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=f"Failed to get agent strategy: {str(e)}",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
agent_parameters = strategy.get_parameters()
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
)
|
||||
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
user_id=self.user_id,
|
||||
app_id=self.app_id,
|
||||
# TODO: conversation id and message id
|
||||
)
|
||||
except Exception as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=f"Failed to invoke agent: {str(e)}",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# convert tool messages
|
||||
yield from self._transform_message(message_stream, {}, parameters_for_log)
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=f"Failed to transform agent message: {str(e)}",
|
||||
)
|
||||
)
|
||||
51
api/core/workflow/nodes/agent/entities.py
Normal file
51
api/core/workflow/nodes/agent/entities.py
Normal file
@ -0,0 +1,51 @@
|
||||
from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, ValidationInfo, field_validator
|
||||
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
class AgentEntity(BaseModel):
|
||||
agent_strategy_provider_name: str # redundancy
|
||||
agent_strategy_name: str
|
||||
agent_strategy_label: str # redundancy
|
||||
agent_configurations: dict[str, Any]
|
||||
plugin_unique_identifier: str
|
||||
|
||||
@field_validator("agent_configurations", mode="before")
|
||||
@classmethod
|
||||
def validate_agent_configurations(cls, value, values: ValidationInfo):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("agent_configurations must be a dictionary")
|
||||
|
||||
for key in values.data.get("agent_configurations", {}):
|
||||
value = values.data.get("agent_configurations", {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f"{key} must be a string")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData, AgentEntity):
|
||||
class AgentInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError("value must be a string, int, float, or bool")
|
||||
return typ
|
||||
|
||||
agent_parameters: dict[str, AgentInput]
|
||||
@ -22,3 +22,4 @@ class NodeType(StrEnum):
|
||||
VARIABLE_ASSIGNER = "assigner"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
AGENT = "agent"
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||
from core.workflow.nodes.answer import AnswerNode
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.code import CodeNode
|
||||
@ -101,4 +102,8 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
LATEST_VERSION: ListOperatorNode,
|
||||
"1": ListOperatorNode,
|
||||
},
|
||||
NodeType.AGENT: {
|
||||
LATEST_VERSION: AgentNode,
|
||||
"1": AgentNode,
|
||||
},
|
||||
}
|
||||
|
||||
@ -234,8 +234,8 @@ class ParameterExtractorNode(LLMNode):
|
||||
if not isinstance(invoke_result, LLMResult):
|
||||
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
|
||||
|
||||
text = invoke_result.message.content
|
||||
if not isinstance(text, str | None):
|
||||
text = invoke_result.message.content or ""
|
||||
if not isinstance(text, str):
|
||||
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
|
||||
|
||||
usage = invoke_result.usage
|
||||
|
||||
Reference in New Issue
Block a user