mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
feat(workflow): integrate workflow entry with workflow app
This commit is contained in:
@ -1,10 +1,9 @@
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
|
||||
|
||||
class WorkflowNodeRunFailedError(Exception):
|
||||
def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str):
|
||||
self.node_id = node_id
|
||||
self.node_type = node_type
|
||||
self.node_title = node_title
|
||||
def __init__(self, node_instance: BaseNode, error: str):
|
||||
self.node_instance = node_instance
|
||||
self.error = error
|
||||
super().__init__(f"Node {node_title} run failed: {error}")
|
||||
super().__init__(f"Node {node_instance.node_data.title} run failed: {error}")
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from typing import cast
|
||||
from typing import Any, Mapping, Sequence, cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
@ -52,9 +53,16 @@ class AnswerNode(BaseNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: AnswerNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
@ -66,6 +74,6 @@ class AnswerNode(BaseNode):
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@ -67,19 +67,35 @@ class BaseNode(ABC):
|
||||
yield from result
|
||||
|
||||
@classmethod
|
||||
def extract_variable_selector_to_variable_mapping(cls, config: dict):
|
||||
def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
|
||||
node_data = cls._node_data_cls(**config.get("data", {}))
|
||||
return cls._extract_variable_selector_to_variable_mapping(node_data)
|
||||
return cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_data=node_data
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: BaseNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Any, Mapping, Optional, Sequence, Union, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
|
||||
@ -314,13 +314,19 @@ class CodeNode(BaseNode):
|
||||
return transformed_result
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: CodeNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
|
||||
return {
|
||||
variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
|
||||
node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import cast
|
||||
from typing import Any, Mapping, Sequence, cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
@ -32,9 +32,16 @@ class EndNode(BaseNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: EndNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from mimetypes import guess_extension
|
||||
from os import path
|
||||
from typing import cast
|
||||
from typing import Any, Mapping, Sequence, cast
|
||||
|
||||
from core.app.segments import parser
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
@ -107,13 +107,19 @@ class HttpRequestNode(BaseNode):
|
||||
return timeout
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: HttpRequestNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
node_data = cast(HttpRequestNodeData, node_data)
|
||||
try:
|
||||
http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)
|
||||
|
||||
@ -121,7 +127,7 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
return variable_mapping
|
||||
except Exception as e:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import cast
|
||||
from typing import Any, Mapping, Sequence, cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
@ -99,9 +99,16 @@ class IfElseNode(BaseNode):
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IfElseNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
@ -287,12 +288,67 @@ class IterationNode(BaseNode):
|
||||
variable_pool.remove([self.node_id, 'item'])
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IterationNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
'input_selector': node_data.iterator_selector,
|
||||
variable_mapping = {
|
||||
f'{node_id}.input_selector': node_data.iterator_selector,
|
||||
}
|
||||
|
||||
# init graph
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
root_node_id=node_data.start_node_id
|
||||
)
|
||||
|
||||
if not iteration_graph:
|
||||
raise ValueError('iteration graph not found')
|
||||
|
||||
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
|
||||
if sub_node_config.get('data', {}).get('iteration_id') != node_id:
|
||||
continue
|
||||
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
if not node_cls:
|
||||
continue
|
||||
|
||||
node_cls = cast(BaseNode, node_cls)
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config,
|
||||
config=sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
sub_node_variable_mapping = {}
|
||||
|
||||
# remove iteration variables
|
||||
sub_node_variable_mapping = {
|
||||
sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items()
|
||||
if value[0] != node_id
|
||||
}
|
||||
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
# remove variable out from iteration
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items()
|
||||
if value[0] not in iteration_graph.node_ids
|
||||
}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any, Mapping, Sequence, cast
|
||||
|
||||
from sqlalchemy import func
|
||||
|
||||
@ -232,11 +232,21 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
return context_list
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
node_data = node_data
|
||||
node_data = cast(cls._node_data_cls, node_data)
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: KnowledgeRetrievalNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
variable_mapping = {}
|
||||
variable_mapping['query'] = node_data.query_variable_selector
|
||||
variable_mapping[node_id + '.query'] = node_data.query_variable_selector
|
||||
return variable_mapping
|
||||
|
||||
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Optional, cast
|
||||
from typing import Any, Mapping, Optional, Sequence, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -678,13 +678,19 @@ class LLMNode(BaseNode):
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: LLMNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
node_data = cast(LLMNodeData, node_data)
|
||||
prompt_template = node_data.prompt_template
|
||||
|
||||
variable_selectors = []
|
||||
@ -734,6 +740,10 @@ class LLMNode(BaseNode):
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
variable_mapping = {
|
||||
node_id + '.' + key: value for key, value in variable_mapping.items()
|
||||
}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Optional, cast
|
||||
from typing import Any, Mapping, Optional, Sequence, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
@ -701,15 +701,19 @@ class ParameterExtractorNode(LLMNode):
|
||||
return self._model_instance, self._model_config
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[
|
||||
str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: ParameterExtractorNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
node_data = node_data
|
||||
|
||||
variable_mapping = {
|
||||
'query': node_data.query
|
||||
}
|
||||
@ -719,4 +723,8 @@ class ParameterExtractorNode(LLMNode):
|
||||
for selector in variable_template_parser.extract_variable_selectors():
|
||||
variable_mapping[selector.variable] = selector.value_selector
|
||||
|
||||
variable_mapping = {
|
||||
node_id + '.' + key: value for key, value in variable_mapping.items()
|
||||
}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Any, Mapping, Optional, Sequence, Union, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
@ -137,9 +137,19 @@ class QuestionClassifierNode(LLMNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
node_data = node_data
|
||||
node_data = cast(cls._node_data_cls, node_data)
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: QuestionClassifierNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
variable_mapping = {'query': node_data.query_variable_selector}
|
||||
variable_selectors = []
|
||||
if node_data.instruction:
|
||||
@ -147,6 +157,11 @@ class QuestionClassifierNode(LLMNode):
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
variable_mapping = {
|
||||
node_id + '.' + key: value for key, value in variable_mapping.items()
|
||||
}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
|
||||
from typing import Any, Mapping, Sequence
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
@ -28,9 +29,16 @@ class StartNode(BaseNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: StartNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Optional, cast
|
||||
from typing import Any, Mapping, Optional, Sequence, cast
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
@ -77,13 +77,19 @@ class TemplateTransformNode(BaseNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[
|
||||
str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: TemplateTransformNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
|
||||
node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
|
||||
}
|
||||
|
||||
@ -221,9 +221,16 @@ class ToolNode(BaseNode):
|
||||
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: ToolNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
@ -239,4 +246,8 @@ class ToolNode(BaseNode):
|
||||
elif input.type == 'constant':
|
||||
pass
|
||||
|
||||
result = {
|
||||
node_id + '.' + key: value for key, value in result.items()
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import cast
|
||||
from typing import Any, Mapping, Sequence, cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
@ -48,5 +48,17 @@ class VariableAggregatorNode(BaseNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: VariableAssignerNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Optional, Type, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
@ -8,13 +10,18 @@ from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable, UserFrom
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType, UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from models.workflow import (
|
||||
@ -32,18 +39,17 @@ class WorkflowEntry:
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
user_inputs: Mapping[str, Any],
|
||||
system_inputs: Mapping[SystemVariable, Any],
|
||||
call_depth: int = 0
|
||||
call_depth: int,
|
||||
variable_pool: VariablePool
|
||||
) -> None:
|
||||
"""
|
||||
:param workflow: Workflow instance
|
||||
:param user_id: user id
|
||||
:param user_from: user from
|
||||
:param invoke_from: invoke from service-api, web-app, debugger, explore
|
||||
:param user_inputs: user variables inputs
|
||||
:param system_inputs: system inputs, like: query, files
|
||||
:param call_depth: call depth
|
||||
:param variable_pool: variable pool
|
||||
:param single_step_run_iteration_id: single step run iteration id
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
@ -71,13 +77,6 @@ class WorkflowEntry:
|
||||
if not graph:
|
||||
raise ValueError('graph not found in workflow')
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=user_inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
# init workflow run state
|
||||
self.graph_engine = GraphEngine(
|
||||
tenant_id=workflow.tenant_id,
|
||||
@ -134,10 +133,160 @@ class WorkflowEntry:
|
||||
)
|
||||
return
|
||||
|
||||
def single_step_run(self, workflow: Workflow,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: dict) -> tuple[BaseNode, NodeRunResult]:
|
||||
@classmethod
|
||||
def single_step_run_iteration(
|
||||
cls,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: dict,
|
||||
callbacks: Sequence[WorkflowCallback],
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Single step run workflow node iteration
|
||||
:param workflow: Workflow instance
|
||||
:param node_id: node id
|
||||
:param user_id: user id
|
||||
:param user_inputs: user inputs
|
||||
:return:
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError('workflow graph not found')
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if 'nodes' not in graph_config or 'edges' not in graph_config:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
|
||||
if not isinstance(graph_config.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
|
||||
if not isinstance(graph_config.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
|
||||
# filter nodes only in iteration
|
||||
node_configs = [
|
||||
node for node in graph_config.get('nodes', [])
|
||||
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
|
||||
]
|
||||
|
||||
graph_config['nodes'] = node_configs
|
||||
|
||||
node_ids = [node.get('id') for node in node_configs]
|
||||
|
||||
# filter edges only in iteration
|
||||
edge_configs = [
|
||||
edge for edge in graph_config.get('edges', [])
|
||||
if (edge.get('source') is None or edge.get('source') in node_ids)
|
||||
and (edge.get('target') is None or edge.get('target') in node_ids)
|
||||
]
|
||||
|
||||
graph_config['edges'] = edge_configs
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
root_node_id=node_id
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise ValueError('graph not found in workflow')
|
||||
|
||||
# fetch node config from node id
|
||||
iteration_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get('id') == node_id:
|
||||
iteration_node_config = node
|
||||
break
|
||||
|
||||
if not iteration_node_config:
|
||||
raise ValueError('iteration node id not found in workflow graph')
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
node_cls = cast(type[BaseNode], node_cls)
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict,
|
||||
config=iteration_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
cls._mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
node_type=node_type,
|
||||
node_data=IterationNodeData(**iteration_node_config.get('data', {}))
|
||||
)
|
||||
|
||||
# init workflow run state
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
workflow_id=workflow.id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=1,
|
||||
graph=graph,
|
||||
graph_config=graph_config,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||
)
|
||||
|
||||
try:
|
||||
# run workflow
|
||||
generator = graph_engine.run()
|
||||
for event in generator:
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_event(
|
||||
graph=graph_engine.graph,
|
||||
graph_init_params=graph_engine.init_params,
|
||||
graph_runtime_state=graph_engine.graph_runtime_state,
|
||||
event=event
|
||||
)
|
||||
yield event
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when workflow entry running")
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_event(
|
||||
graph=graph_engine.graph,
|
||||
graph_init_params=graph_engine.init_params,
|
||||
graph_runtime_state=graph_engine.graph_runtime_state,
|
||||
event=GraphRunFailedEvent(
|
||||
error=str(e)
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def single_step_run(
|
||||
cls,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: dict
|
||||
) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]:
|
||||
"""
|
||||
Single step run workflow node
|
||||
:param workflow: Workflow instance
|
||||
@ -168,61 +317,74 @@ class WorkflowEntry:
|
||||
# Get node class
|
||||
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
node_cls = cast(type[BaseNode], node_cls)
|
||||
|
||||
if not node_cls:
|
||||
raise ValueError(f'Node class not found for node type {node_type}')
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=workflow.graph_dict
|
||||
)
|
||||
|
||||
# init workflow run state
|
||||
node_instance = node_cls(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
node_instance: BaseNode = node_cls(
|
||||
id=str(uuid.uuid4()),
|
||||
config=node_config,
|
||||
workflow_call_depth=0
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
workflow_id=workflow.id,
|
||||
graph_config=workflow.graph_dict,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0
|
||||
),
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config)
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict,
|
||||
config=node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
self._mapping_user_inputs_to_variable_pool(
|
||||
cls._mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
node_instance=node_instance
|
||||
node_type=node_type,
|
||||
node_data=node_instance.node_data
|
||||
)
|
||||
|
||||
# run node
|
||||
node_run_result = node_instance.run(
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
generator = node_instance.run()
|
||||
|
||||
# sign output files
|
||||
node_run_result.outputs = self.handle_special_values(node_run_result.outputs)
|
||||
return node_instance, generator
|
||||
except Exception as e:
|
||||
raise WorkflowNodeRunFailedError(
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_title=node_instance.node_data.title,
|
||||
node_instance=node_instance,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
return node_instance, node_run_result
|
||||
|
||||
@classmethod
|
||||
def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
|
||||
"""
|
||||
@ -250,33 +412,49 @@ class WorkflowEntry:
|
||||
|
||||
return new_value
|
||||
|
||||
def _mapping_user_inputs_to_variable_pool(self,
|
||||
variable_mapping: dict,
|
||||
user_inputs: dict,
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str,
|
||||
node_instance: BaseNode):
|
||||
for variable_key, variable_selector in variable_mapping.items():
|
||||
if variable_key not in user_inputs and not variable_pool.get(variable_selector):
|
||||
raise ValueError(f'Variable key {variable_key} not found in user inputs.')
|
||||
@classmethod
|
||||
def _mapping_user_inputs_to_variable_pool(
|
||||
cls,
|
||||
variable_mapping: Mapping[str, Sequence[str]],
|
||||
user_inputs: dict,
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData
|
||||
) -> None:
|
||||
for node_variable, variable_selector in variable_mapping.items():
|
||||
# fetch node id and variable key from node_variable
|
||||
node_variable_list = node_variable.split('.')
|
||||
if len(node_variable_list) < 1:
|
||||
raise ValueError(f'Invalid node variable {node_variable}')
|
||||
|
||||
node_variable_key = node_variable_list[1:]
|
||||
|
||||
if (
|
||||
node_variable_key not in user_inputs
|
||||
or node_variable not in user_inputs
|
||||
) and not variable_pool.get(variable_selector):
|
||||
raise ValueError(f'Variable key {node_variable} not found in user inputs.')
|
||||
|
||||
# fetch variable node id from variable selector
|
||||
variable_node_id = variable_selector[0]
|
||||
variable_key_list = variable_selector[1:]
|
||||
variable_key_list = cast(list[str], variable_key_list)
|
||||
|
||||
# get value
|
||||
value = user_inputs.get(variable_key)
|
||||
# get input value
|
||||
input_value = user_inputs.get(node_variable)
|
||||
if not input_value:
|
||||
input_value = user_inputs.get(node_variable_key)
|
||||
|
||||
# FIXME: temp fix for image type
|
||||
if node_instance.node_type == NodeType.LLM:
|
||||
if node_type == NodeType.LLM:
|
||||
new_value = []
|
||||
if isinstance(value, list):
|
||||
node_data = node_instance.node_data
|
||||
if isinstance(input_value, list):
|
||||
node_data = cast(LLMNodeData, node_data)
|
||||
|
||||
detail = node_data.vision.configs.detail if node_data.vision.configs else None
|
||||
|
||||
for item in value:
|
||||
for item in input_value:
|
||||
if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
|
||||
transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
|
||||
file = FileVar(
|
||||
@ -294,4 +472,4 @@ class WorkflowEntry:
|
||||
value = new_value
|
||||
|
||||
# append variable and value to variable pool
|
||||
variable_pool.add([variable_node_id] + variable_key_list, value)
|
||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
||||
|
||||
Reference in New Issue
Block a user