feat: support agent node

This commit is contained in:
Yeuoly
2024-12-09 23:02:11 +08:00
parent 16b49ac436
commit ae72514cb4
15 changed files with 459 additions and 34 deletions

View File

@ -0,0 +1,3 @@
from .agent_node import AgentNode
__all__ = ["AgentNode"]

View File

@ -0,0 +1,85 @@
from collections.abc import Generator
from typing import cast
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.agent.entities import AgentNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.tool.tool_node import ToolNode
from factories.agent_factory import get_plugin_agent_strategy
from models.workflow import WorkflowNodeExecutionStatus
class AgentNode(ToolNode):
"""
Agent Node
"""
_node_data_cls = AgentNodeData
_node_type = NodeType.AGENT
def _run(self) -> Generator:
"""
Run the agent node
"""
node_data = cast(AgentNodeData, self.node_data)
try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
plugin_unique_identifier=node_data.plugin_unique_identifier,
agent_strategy_provider_name=node_data.agent_strategy_provider_name,
agent_strategy_name=node_data.agent_strategy_name,
)
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
)
)
return
agent_parameters = strategy.get_parameters()
# get parameters
parameters = self._generate_parameters(
tool_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
for_log=True,
)
try:
message_stream = strategy.invoke(
params=parameters,
user_id=self.user_id,
app_id=self.app_id,
# TODO: conversation id and message id
)
except Exception as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=f"Failed to invoke agent: {str(e)}",
)
)
try:
# convert tool messages
yield from self._transform_message(message_stream, {}, parameters_for_log)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=f"Failed to transform agent message: {str(e)}",
)
)

View File

@ -0,0 +1,51 @@
from typing import Any, Literal, Union
from pydantic import BaseModel, ValidationInfo, field_validator
from core.workflow.nodes.base.entities import BaseNodeData
class AgentEntity(BaseModel):
agent_strategy_provider_name: str # redundancy
agent_strategy_name: str
agent_strategy_label: str # redundancy
agent_configurations: dict[str, Any]
plugin_unique_identifier: str
@field_validator("agent_configurations", mode="before")
@classmethod
def validate_agent_configurations(cls, value, values: ValidationInfo):
if not isinstance(value, dict):
raise ValueError("agent_configurations must be a dictionary")
for key in values.data.get("agent_configurations", {}):
value = values.data.get("agent_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f"{key} must be a string")
return value
class AgentNodeData(BaseNodeData, AgentEntity):
class AgentInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get("value")
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool):
raise ValueError("value must be a string, int, float, or bool")
return typ
agent_parameters: dict[str, AgentInput]

View File

@ -22,3 +22,4 @@ class NodeType(StrEnum):
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"

View File

@ -1,5 +1,6 @@
from collections.abc import Mapping
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code import CodeNode
@ -101,4 +102,8 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
LATEST_VERSION: ListOperatorNode,
"1": ListOperatorNode,
},
NodeType.AGENT: {
LATEST_VERSION: AgentNode,
"1": AgentNode,
},
}

View File

@ -234,8 +234,8 @@ class ParameterExtractorNode(LLMNode):
if not isinstance(invoke_result, LLMResult):
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
text = invoke_result.message.content
if not isinstance(text, str | None):
text = invoke_result.message.content or ""
if not isinstance(text, str):
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
usage = invoke_result.usage