mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 21:55:58 +08:00
feat(workflow): workflow as tool output schema (#26241)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Novice <novice12185727@gmail.com>
This commit is contained in:
@ -1,4 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
@ -25,3 +27,5 @@ class ApiToolBundle(BaseModel):
|
||||
icon: str | None = None
|
||||
# openapi operation
|
||||
openapi: dict
|
||||
# output schema
|
||||
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@ -24,6 +25,31 @@ class WorkflowToolConfigurationUtils:
|
||||
|
||||
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_output(cls, graph: Mapping[str, Any]) -> Sequence[OutputVariableEntity]:
|
||||
"""
|
||||
get workflow graph output
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
outputs_by_variable: dict[str, OutputVariableEntity] = {}
|
||||
variable_order: list[str] = []
|
||||
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") != "end":
|
||||
continue
|
||||
|
||||
for output in node.get("data", {}).get("outputs", []):
|
||||
entity = OutputVariableEntity.model_validate(output)
|
||||
variable = entity.variable
|
||||
|
||||
if variable not in variable_order:
|
||||
variable_order.append(variable)
|
||||
|
||||
# Later end nodes override duplicated variable definitions.
|
||||
outputs_by_variable[variable] = entity
|
||||
|
||||
return [outputs_by_variable[variable] for variable in variable_order]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
|
||||
@ -162,6 +162,20 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
else:
|
||||
raise ValueError("variable not found")
|
||||
|
||||
# get output schema from workflow
|
||||
outputs = WorkflowToolConfigurationUtils.get_workflow_graph_output(graph)
|
||||
|
||||
reserved_keys = {"json", "text", "files"}
|
||||
|
||||
properties = {}
|
||||
for output in outputs:
|
||||
if output.variable not in reserved_keys:
|
||||
properties[output.variable] = {
|
||||
"type": output.value_type,
|
||||
"description": "",
|
||||
}
|
||||
output_schema = {"type": "object", "properties": properties}
|
||||
|
||||
return WorkflowTool(
|
||||
workflow_as_tool_id=db_provider.id,
|
||||
entity=ToolEntity(
|
||||
@ -177,6 +191,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
llm=db_provider.description,
|
||||
),
|
||||
parameters=workflow_tool_parameters,
|
||||
output_schema=output_schema,
|
||||
),
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
|
||||
@ -114,6 +114,11 @@ class WorkflowTool(Tool):
|
||||
for file in files:
|
||||
yield self.create_file_message(file) # type: ignore
|
||||
|
||||
# traverse `outputs` field and create variable messages
|
||||
for key, value in outputs.items():
|
||||
if key not in {"text", "json", "files"}:
|
||||
yield self.create_variable_message(variable_name=key, variable_value=value)
|
||||
|
||||
self._latest_usage = self._derive_usage_from_result(data)
|
||||
|
||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||
|
||||
@ -5,7 +5,7 @@ from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
|
||||
from core.workflow.enums import ErrorStrategy
|
||||
|
||||
@ -35,6 +35,45 @@ class VariableSelector(BaseModel):
|
||||
value_selector: Sequence[str]
|
||||
|
||||
|
||||
class OutputVariableType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
INTEGER = "integer"
|
||||
SECRET = "secret"
|
||||
BOOLEAN = "boolean"
|
||||
OBJECT = "object"
|
||||
FILE = "file"
|
||||
ARRAY = "array"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_FILE = "array[file]"
|
||||
ANY = "any"
|
||||
ARRAY_ANY = "array[any]"
|
||||
|
||||
|
||||
class OutputVariableEntity(BaseModel):
|
||||
"""
|
||||
Output Variable Entity.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_type: OutputVariableType
|
||||
value_selector: Sequence[str]
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def normalize_value_type(cls, v: Any) -> Any:
|
||||
"""
|
||||
Normalize value_type to handle case-insensitive array types.
|
||||
Converts 'Array[...]' to 'array[...]' for backward compatibility.
|
||||
"""
|
||||
if isinstance(v, str) and v.startswith("Array["):
|
||||
return v.lower()
|
||||
return v
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity
|
||||
|
||||
|
||||
class EndNodeData(BaseNodeData):
|
||||
@ -9,7 +8,7 @@ class EndNodeData(BaseNodeData):
|
||||
END Node Data.
|
||||
"""
|
||||
|
||||
outputs: list[VariableSelector]
|
||||
outputs: list[OutputVariableEntity]
|
||||
|
||||
|
||||
class EndStreamParam(BaseModel):
|
||||
|
||||
@ -405,6 +405,7 @@ class ToolTransformService:
|
||||
name=tool.operation_id or "",
|
||||
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
|
||||
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
|
||||
output_schema=tool.output_schema,
|
||||
parameters=tool.parameters,
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
@ -291,6 +291,10 @@ class WorkflowToolManageService:
|
||||
if len(workflow_tools) == 0:
|
||||
raise ValueError(f"Tool {db_tool.id} not found")
|
||||
|
||||
tool_entity = workflow_tools[0].entity
|
||||
# get output schema from workflow tool entity
|
||||
output_schema = tool_entity.output_schema
|
||||
|
||||
return {
|
||||
"name": db_tool.name,
|
||||
"label": db_tool.label,
|
||||
@ -299,6 +303,7 @@ class WorkflowToolManageService:
|
||||
"icon": json.loads(db_tool.icon),
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"output_schema": output_schema,
|
||||
"tool": ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool.get_tools(db_tool.tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool),
|
||||
|
||||
@ -257,7 +257,6 @@ class TestWorkflowToolManageService:
|
||||
|
||||
# Attempt to create second workflow tool with same name
|
||||
second_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
@ -309,7 +308,6 @@ class TestWorkflowToolManageService:
|
||||
|
||||
# Attempt to create workflow tool with non-existent app
|
||||
tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
@ -365,7 +363,6 @@ class TestWorkflowToolManageService:
|
||||
"required": True,
|
||||
}
|
||||
]
|
||||
|
||||
# Attempt to create workflow tool with invalid parameters
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
@ -416,7 +413,6 @@ class TestWorkflowToolManageService:
|
||||
# Create first workflow tool
|
||||
first_tool_name = fake.word()
|
||||
first_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
@ -431,7 +427,6 @@ class TestWorkflowToolManageService:
|
||||
# Attempt to create second workflow tool with same app_id but different name
|
||||
second_tool_name = fake.word()
|
||||
second_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
@ -486,7 +481,6 @@ class TestWorkflowToolManageService:
|
||||
|
||||
# Attempt to create workflow tool for app without workflow
|
||||
tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
@ -534,7 +528,6 @@ class TestWorkflowToolManageService:
|
||||
# Create initial workflow tool
|
||||
initial_tool_name = fake.word()
|
||||
initial_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
@ -621,7 +614,6 @@ class TestWorkflowToolManageService:
|
||||
|
||||
# Attempt to update non-existent workflow tool
|
||||
tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
WorkflowToolManageService.update_workflow_tool(
|
||||
user_id=account.id,
|
||||
@ -671,7 +663,6 @@ class TestWorkflowToolManageService:
|
||||
# Create first workflow tool
|
||||
first_tool_name = fake.word()
|
||||
first_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
|
||||
@ -3,7 +3,7 @@ import pytest
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
|
||||
@ -51,3 +51,166 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
||||
# actually `run` the tool.
|
||||
list(tool.invoke("test_user", {}))
|
||||
assert exc_info.value.args == ("oops",)
|
||||
|
||||
|
||||
def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that WorkflowTool should generate variable messages when there are outputs"""
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
# Mock workflow outputs
|
||||
mock_outputs = {"result": "success", "count": 42, "data": {"key": "value"}}
|
||||
|
||||
# needs to patch those methods to avoid database access.
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
|
||||
# Mock user resolution to avoid database access
|
||||
from unittest.mock import Mock
|
||||
|
||||
mock_user = Mock()
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
|
||||
|
||||
# replace `WorkflowAppGenerator.generate` 's return value.
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
|
||||
lambda *args, **kwargs: {"data": {"outputs": mock_outputs}},
|
||||
)
|
||||
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
||||
|
||||
# Execute tool invocation
|
||||
messages = list(tool.invoke("test_user", {}))
|
||||
|
||||
# Verify generated messages
|
||||
# Should contain: 3 variable messages + 1 text message + 1 JSON message = 5 messages
|
||||
assert len(messages) == 5
|
||||
|
||||
# Verify variable messages
|
||||
variable_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.VARIABLE]
|
||||
assert len(variable_messages) == 3
|
||||
|
||||
# Verify content of each variable message
|
||||
variable_dict = {msg.message.variable_name: msg.message.variable_value for msg in variable_messages}
|
||||
assert variable_dict["result"] == "success"
|
||||
assert variable_dict["count"] == 42
|
||||
assert variable_dict["data"] == {"key": "value"}
|
||||
|
||||
# Verify text message
|
||||
text_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.TEXT]
|
||||
assert len(text_messages) == 1
|
||||
assert '{"result": "success", "count": 42, "data": {"key": "value"}}' in text_messages[0].message.text
|
||||
|
||||
# Verify JSON message
|
||||
json_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.JSON]
|
||||
assert len(json_messages) == 1
|
||||
assert json_messages[0].message.json_object == mock_outputs
|
||||
|
||||
|
||||
def test_workflow_tool_should_handle_empty_outputs(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that WorkflowTool should handle empty outputs correctly"""
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
# needs to patch those methods to avoid database access.
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
|
||||
# Mock user resolution to avoid database access
|
||||
from unittest.mock import Mock
|
||||
|
||||
mock_user = Mock()
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
|
||||
|
||||
# replace `WorkflowAppGenerator.generate` 's return value.
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
|
||||
lambda *args, **kwargs: {"data": {}},
|
||||
)
|
||||
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
||||
|
||||
# Execute tool invocation
|
||||
messages = list(tool.invoke("test_user", {}))
|
||||
|
||||
# Verify generated messages
|
||||
# Should contain: 0 variable messages + 1 text message + 1 JSON message = 2 messages
|
||||
assert len(messages) == 2
|
||||
|
||||
# Verify no variable messages
|
||||
variable_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.VARIABLE]
|
||||
assert len(variable_messages) == 0
|
||||
|
||||
# Verify text message
|
||||
text_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.TEXT]
|
||||
assert len(text_messages) == 1
|
||||
assert text_messages[0].message.text == "{}"
|
||||
|
||||
# Verify JSON message
|
||||
json_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.JSON]
|
||||
assert len(json_messages) == 1
|
||||
assert json_messages[0].message.json_object == {}
|
||||
|
||||
|
||||
def test_create_variable_message():
|
||||
"""Test the functionality of creating variable messages"""
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
# Test different types of variable values
|
||||
test_cases = [
|
||||
("string_var", "test string"),
|
||||
("int_var", 42),
|
||||
("float_var", 3.14),
|
||||
("bool_var", True),
|
||||
("list_var", [1, 2, 3]),
|
||||
("dict_var", {"key": "value"}),
|
||||
]
|
||||
|
||||
for var_name, var_value in test_cases:
|
||||
message = tool.create_variable_message(var_name, var_value)
|
||||
|
||||
assert message.type == ToolInvokeMessage.MessageType.VARIABLE
|
||||
assert message.message.variable_name == var_name
|
||||
assert message.message.variable_value == var_value
|
||||
assert message.message.stream is False
|
||||
|
||||
@ -14,7 +14,7 @@ from core.workflow.graph_events import (
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
@ -110,8 +110,12 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
end_primary_data = EndNodeData(
|
||||
title="End Primary",
|
||||
outputs=[
|
||||
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
|
||||
VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]),
|
||||
OutputVariableEntity(
|
||||
variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"]
|
||||
),
|
||||
OutputVariableEntity(
|
||||
variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"]
|
||||
),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
@ -126,8 +130,14 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
end_secondary_data = EndNodeData(
|
||||
title="End Secondary",
|
||||
outputs=[
|
||||
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
|
||||
VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]),
|
||||
OutputVariableEntity(
|
||||
variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"]
|
||||
),
|
||||
OutputVariableEntity(
|
||||
variable="secondary_text",
|
||||
value_type=OutputVariableType.STRING,
|
||||
value_selector=["llm_secondary", "text"],
|
||||
),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
|
||||
@ -13,7 +13,7 @@ from core.workflow.graph_events import (
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
@ -108,8 +108,12 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
end_data = EndNodeData(
|
||||
title="End",
|
||||
outputs=[
|
||||
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
|
||||
VariableSelector(variable="resume_text", value_selector=["llm_resume", "text"]),
|
||||
OutputVariableEntity(
|
||||
variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"]
|
||||
),
|
||||
OutputVariableEntity(
|
||||
variable="resume_text", value_type=OutputVariableType.STRING, value_selector=["llm_resume", "text"]
|
||||
),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
|
||||
@ -11,7 +11,7 @@ from core.workflow.graph_events import (
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
@ -123,8 +123,12 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||
end_primary_data = EndNodeData(
|
||||
title="End Primary",
|
||||
outputs=[
|
||||
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
|
||||
VariableSelector(variable="primary_text", value_selector=["llm_primary", "text"]),
|
||||
OutputVariableEntity(
|
||||
variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"]
|
||||
),
|
||||
OutputVariableEntity(
|
||||
variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"]
|
||||
),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
@ -139,8 +143,14 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||
end_secondary_data = EndNodeData(
|
||||
title="End Secondary",
|
||||
outputs=[
|
||||
VariableSelector(variable="initial_text", value_selector=["llm_initial", "text"]),
|
||||
VariableSelector(variable="secondary_text", value_selector=["llm_secondary", "text"]),
|
||||
OutputVariableEntity(
|
||||
variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"]
|
||||
),
|
||||
OutputVariableEntity(
|
||||
variable="secondary_text",
|
||||
value_type=OutputVariableType.STRING,
|
||||
value_selector=["llm_secondary", "text"],
|
||||
),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user