mirror of
https://github.com/langgenius/dify.git
synced 2026-04-30 23:48:04 +08:00
feat(workflow): workflow as tool output schema (#26241)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Novice <novice12185727@gmail.com>
This commit is contained in:
@ -1,4 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
from collections.abc import Mapping
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
@ -25,3 +27,5 @@ class ApiToolBundle(BaseModel):
|
||||
icon: str | None = None
|
||||
# openapi operation
|
||||
openapi: dict
|
||||
# output schema
|
||||
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@ -24,6 +25,31 @@ class WorkflowToolConfigurationUtils:
|
||||
|
||||
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_output(cls, graph: Mapping[str, Any]) -> Sequence[OutputVariableEntity]:
|
||||
"""
|
||||
get workflow graph output
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
outputs_by_variable: dict[str, OutputVariableEntity] = {}
|
||||
variable_order: list[str] = []
|
||||
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") != "end":
|
||||
continue
|
||||
|
||||
for output in node.get("data", {}).get("outputs", []):
|
||||
entity = OutputVariableEntity.model_validate(output)
|
||||
variable = entity.variable
|
||||
|
||||
if variable not in variable_order:
|
||||
variable_order.append(variable)
|
||||
|
||||
# Later end nodes override duplicated variable definitions.
|
||||
outputs_by_variable[variable] = entity
|
||||
|
||||
return [outputs_by_variable[variable] for variable in variable_order]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
|
||||
@ -162,6 +162,20 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
else:
|
||||
raise ValueError("variable not found")
|
||||
|
||||
# get output schema from workflow
|
||||
outputs = WorkflowToolConfigurationUtils.get_workflow_graph_output(graph)
|
||||
|
||||
reserved_keys = {"json", "text", "files"}
|
||||
|
||||
properties = {}
|
||||
for output in outputs:
|
||||
if output.variable not in reserved_keys:
|
||||
properties[output.variable] = {
|
||||
"type": output.value_type,
|
||||
"description": "",
|
||||
}
|
||||
output_schema = {"type": "object", "properties": properties}
|
||||
|
||||
return WorkflowTool(
|
||||
workflow_as_tool_id=db_provider.id,
|
||||
entity=ToolEntity(
|
||||
@ -177,6 +191,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
llm=db_provider.description,
|
||||
),
|
||||
parameters=workflow_tool_parameters,
|
||||
output_schema=output_schema,
|
||||
),
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
|
||||
@ -114,6 +114,11 @@ class WorkflowTool(Tool):
|
||||
for file in files:
|
||||
yield self.create_file_message(file) # type: ignore
|
||||
|
||||
# traverse `outputs` field and create variable messages
|
||||
for key, value in outputs.items():
|
||||
if key not in {"text", "json", "files"}:
|
||||
yield self.create_variable_message(variable_name=key, variable_value=value)
|
||||
|
||||
self._latest_usage = self._derive_usage_from_result(data)
|
||||
|
||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||
|
||||
@ -5,7 +5,7 @@ from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
|
||||
from core.workflow.enums import ErrorStrategy
|
||||
|
||||
@ -35,6 +35,45 @@ class VariableSelector(BaseModel):
|
||||
value_selector: Sequence[str]
|
||||
|
||||
|
||||
class OutputVariableType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
INTEGER = "integer"
|
||||
SECRET = "secret"
|
||||
BOOLEAN = "boolean"
|
||||
OBJECT = "object"
|
||||
FILE = "file"
|
||||
ARRAY = "array"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_FILE = "array[file]"
|
||||
ANY = "any"
|
||||
ARRAY_ANY = "array[any]"
|
||||
|
||||
|
||||
class OutputVariableEntity(BaseModel):
|
||||
"""
|
||||
Output Variable Entity.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_type: OutputVariableType
|
||||
value_selector: Sequence[str]
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def normalize_value_type(cls, v: Any) -> Any:
|
||||
"""
|
||||
Normalize value_type to handle case-insensitive array types.
|
||||
Converts 'Array[...]' to 'array[...]' for backward compatibility.
|
||||
"""
|
||||
if isinstance(v, str) and v.startswith("Array["):
|
||||
return v.lower()
|
||||
return v
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity
|
||||
|
||||
|
||||
class EndNodeData(BaseNodeData):
|
||||
@ -9,7 +8,7 @@ class EndNodeData(BaseNodeData):
|
||||
END Node Data.
|
||||
"""
|
||||
|
||||
outputs: list[VariableSelector]
|
||||
outputs: list[OutputVariableEntity]
|
||||
|
||||
|
||||
class EndStreamParam(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user