mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
Merge branch 'main' into feat/rag-2
This commit is contained in:
@ -232,14 +232,14 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
Publish loop started
|
||||
"""
|
||||
self.print_text("\n[LoopRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_id}", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
|
||||
def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None:
|
||||
"""
|
||||
Publish loop next
|
||||
"""
|
||||
self.print_text("\n[LoopRunNextEvent]", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_id}", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
self.print_text(f"Loop Index: {event.index}", color="blue")
|
||||
|
||||
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None:
|
||||
@ -250,7 +250,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
"\n[LoopRunSucceededEvent]" if isinstance(event, LoopRunSucceededEvent) else "\n[LoopRunFailedEvent]",
|
||||
color="blue",
|
||||
)
|
||||
self.print_text(f"Node ID: {event.loop_id}", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
|
||||
|
||||
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
|
||||
@ -336,7 +336,7 @@ class Graph(BaseModel):
|
||||
|
||||
parallel = GraphParallel(
|
||||
start_from_node_id=start_node_id,
|
||||
parent_parallel_id=parent_parallel.id if parent_parallel else None,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
|
||||
)
|
||||
parallel_mapping[parallel.id] = parallel
|
||||
|
||||
@ -3,11 +3,13 @@ import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
@ -73,12 +75,14 @@ class AgentNode(ToolNode):
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=node_data,
|
||||
strategy=strategy,
|
||||
)
|
||||
parameters_for_log = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=node_data,
|
||||
for_log=True,
|
||||
strategy=strategy,
|
||||
)
|
||||
|
||||
# get conversation id
|
||||
@ -155,6 +159,7 @@ class AgentNode(ToolNode):
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
strategy: PluginAgentStrategy,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
@ -207,7 +212,7 @@ class AgentNode(ToolNode):
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
@ -244,9 +249,9 @@ class AgentNode(ToolNode):
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
|
||||
runtime_variable_pool = variable_pool if self.node_data.version != "1" else None
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
self.tenant_id, self.app_id, entity, self.invoke_from
|
||||
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
@ -398,3 +403,16 @@ class AgentNode(ToolNode):
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter MCP type tool
|
||||
:param strategy: plugin agent strategy
|
||||
:param tool: tool
|
||||
:return: filtered tool dict
|
||||
"""
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
else:
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]
|
||||
|
||||
@ -75,6 +75,7 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
},
|
||||
NodeType.TOOL: {
|
||||
LATEST_VERSION: ToolNode,
|
||||
"2": ToolNode,
|
||||
"1": ToolNode,
|
||||
},
|
||||
NodeType.VARIABLE_AGGREGATOR: {
|
||||
@ -124,6 +125,7 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
},
|
||||
NodeType.AGENT: {
|
||||
LATEST_VERSION: AgentNode,
|
||||
"2": AgentNode,
|
||||
"1": AgentNode,
|
||||
},
|
||||
NodeType.DATASOURCE: {
|
||||
|
||||
@ -41,6 +41,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
|
||||
if value is None:
|
||||
return typ
|
||||
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
@ -54,3 +58,22 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
return typ
|
||||
|
||||
tool_parameters: dict[str, ToolInput]
|
||||
|
||||
@field_validator("tool_parameters", mode="before")
|
||||
@classmethod
|
||||
def filter_none_tool_inputs(cls, value):
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
return {
|
||||
key: tool_input
|
||||
for key, tool_input in value.items()
|
||||
if tool_input is not None and cls._has_valid_value(tool_input)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _has_valid_value(tool_input):
|
||||
"""Check if the value is valid"""
|
||||
if isinstance(tool_input, dict):
|
||||
return tool_input.get("value") is not None
|
||||
return getattr(tool_input, "value", None) is not None
|
||||
|
||||
@ -67,8 +67,9 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
try:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
|
||||
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
@ -95,7 +96,6 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
)
|
||||
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
@ -285,7 +285,8 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
json.append(message.message.json_object)
|
||||
if message.message.json_object is not None:
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
@ -328,6 +329,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
@ -338,10 +340,12 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
id=message.message.id,
|
||||
@ -369,31 +373,31 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: dict[str, Any] = {}
|
||||
if json:
|
||||
if isinstance(json, list) and len(json) == 1:
|
||||
# If json is a list with only one element, convert it to a dictionary
|
||||
json_output = json[0] if isinstance(json[0], dict) else {"data": json[0]}
|
||||
elif isinstance(json, list):
|
||||
# If json is a list with multiple elements, create a dictionary containing all data
|
||||
json_output = {"data": json}
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
|
||||
# Step 1: append each agent log as its own dict.
|
||||
if agent_logs:
|
||||
# Add agent_logs to json output
|
||||
json_output["agent_logs"] = [
|
||||
{
|
||||
"id": log.id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
for log in agent_logs
|
||||
]
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
|
||||
if json:
|
||||
json_output.extend(json)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
|
||||
Reference in New Issue
Block a user