mirror of
https://github.com/langgenius/dify.git
synced 2026-04-24 12:55:49 +08:00
Merge remote-tracking branch 'origin/feat/plugins' into dev/plugin-deploy
This commit is contained in:
@ -2,6 +2,6 @@ from .base_workflow_callback import WorkflowCallback
|
||||
from .workflow_logging_callback import WorkflowLoggingCallback
|
||||
|
||||
__all__ = [
|
||||
"WorkflowLoggingCallback",
|
||||
"WorkflowCallback",
|
||||
"WorkflowLoggingCallback",
|
||||
]
|
||||
|
||||
@ -39,7 +39,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
@ -65,7 +65,6 @@ class GraphEngineThreadPool(ThreadPoolExecutor):
|
||||
self.submit_count -= 1
|
||||
|
||||
def check_is_full(self) -> None:
|
||||
print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}")
|
||||
if self.submit_count > self.max_submit_count:
|
||||
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
|
||||
|
||||
@ -229,7 +228,8 @@ class GraphEngine:
|
||||
|
||||
# convert to specific node
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
node_cls = node_type_classes_mapping[node_type]
|
||||
node_version = node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .answer_node import AnswerNode
|
||||
from .entities import AnswerStreamGenerateRoute
|
||||
|
||||
__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"]
|
||||
__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"]
|
||||
|
||||
@ -153,7 +153,7 @@ class AnswerStreamGeneratorRouter:
|
||||
NodeType.IF_ELSE,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
NodeType.ITERATION,
|
||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER,
|
||||
NodeType.VARIABLE_ASSIGNER,
|
||||
}:
|
||||
answer_dependencies[answer_node_id].append(source_node_id)
|
||||
else:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
from .node import BaseNode
|
||||
|
||||
__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"]
|
||||
__all__ = ["BaseIterationNodeData", "BaseIterationState", "BaseNode", "BaseNodeData"]
|
||||
|
||||
@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
version: str = "1"
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
|
||||
@ -55,7 +55,9 @@ class BaseNode(Generic[GenericNodeData]):
|
||||
raise ValueError("Node ID is required.")
|
||||
|
||||
self.node_id = node_id
|
||||
self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))
|
||||
|
||||
node_data = self._node_data_cls.model_validate(config.get("data", {}))
|
||||
self.node_data = cast(GenericNodeData, node_data)
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
|
||||
@ -4,8 +4,8 @@ import json
|
||||
|
||||
import docx
|
||||
import pandas as pd
|
||||
import pypdfium2
|
||||
import yaml
|
||||
import pypdfium2 # type: ignore
|
||||
import yaml # type: ignore
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.email import partition_email
|
||||
from unstructured.partition.epub import partition_epub
|
||||
@ -113,7 +113,7 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
||||
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
|
||||
"""Extract text from a file based on its file extension."""
|
||||
match file_extension:
|
||||
case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml":
|
||||
case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml" | ".vtt":
|
||||
return _extract_text_from_plain_text(file_content)
|
||||
case ".json":
|
||||
return _extract_text_from_json(file_content)
|
||||
@ -237,15 +237,17 @@ def _extract_text_from_csv(file_content: bytes) -> str:
|
||||
|
||||
def _extract_text_from_excel(file_content: bytes) -> str:
|
||||
"""Extract text from an Excel file using pandas."""
|
||||
|
||||
try:
|
||||
df = pd.read_excel(io.BytesIO(file_content))
|
||||
|
||||
# Drop rows where all elements are NaN
|
||||
df.dropna(how="all", inplace=True)
|
||||
|
||||
# Convert DataFrame to Markdown table
|
||||
markdown_table = df.to_markdown(index=False)
|
||||
excel_file = pd.ExcelFile(io.BytesIO(file_content))
|
||||
markdown_table = ""
|
||||
for sheet_name in excel_file.sheet_names:
|
||||
try:
|
||||
df = excel_file.parse(sheet_name=sheet_name)
|
||||
df.dropna(how="all", inplace=True)
|
||||
# Create Markdown table two times to separate tables with a newline
|
||||
markdown_table += df.to_markdown(index=False) + "\n\n"
|
||||
except Exception as e:
|
||||
continue
|
||||
return markdown_table
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .end_node import EndNode
|
||||
from .entities import EndStreamParam
|
||||
|
||||
__all__ = ["EndStreamParam", "EndNode"]
|
||||
__all__ = ["EndNode", "EndStreamParam"]
|
||||
|
||||
@ -14,11 +14,11 @@ class NodeType(StrEnum):
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
VARIABLE_AGGREGATOR = "variable-aggregator"
|
||||
VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LOOP = "loop"
|
||||
ITERATION = "iteration"
|
||||
ITERATION_START = "iteration-start" # Fake start node for iteration.
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
CONVERSATION_VARIABLE_ASSIGNER = "assigner"
|
||||
VARIABLE_ASSIGNER = "assigner"
|
||||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
|
||||
@ -2,9 +2,9 @@ from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverRes
|
||||
from .types import NodeEvent
|
||||
|
||||
__all__ = [
|
||||
"ModelInvokeCompletedEvent",
|
||||
"NodeEvent",
|
||||
"RunCompletedEvent",
|
||||
"RunRetrieverResourceEvent",
|
||||
"RunStreamChunkEvent",
|
||||
"NodeEvent",
|
||||
"ModelInvokeCompletedEvent",
|
||||
]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData
|
||||
from .node import HttpRequestNode
|
||||
|
||||
__all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"]
|
||||
__all__ = ["BodyData", "HttpRequestNode", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "HttpRequestNodeData"]
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from mimetypes import guess_extension
|
||||
from os import path
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
@ -107,6 +105,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
node_data: HttpRequestNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
selectors: list[VariableSelector] = []
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
|
||||
if node_data.body:
|
||||
@ -149,11 +148,6 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
content = response.content
|
||||
|
||||
if is_file and content_type:
|
||||
# extract filename from url
|
||||
filename = path.basename(url)
|
||||
# extract extension if possible
|
||||
extension = guess_extension(content_type) or ".bin"
|
||||
|
||||
tool_file = ToolFileManager.create_file_by_raw(
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
@ -164,7 +158,6 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file.id,
|
||||
"type": FileType.IMAGE.value,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE.value,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
|
||||
@ -117,7 +117,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
|
||||
|
||||
# init graph engine
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
@ -163,7 +163,8 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
if self.node_data.is_parallel:
|
||||
futures: list[Future] = []
|
||||
q = Queue()
|
||||
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
|
||||
thread_pool = graph_engine.workflow_thread_pool_mapping[graph_engine.thread_pool_id]
|
||||
thread_pool._max_workers = self.node_data.parallel_nums
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
future: Future = thread_pool.submit(
|
||||
self._run_single_iter_parallel,
|
||||
@ -299,12 +300,13 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_type_classes_mapping.get(node_type)
|
||||
if not node_cls:
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
continue
|
||||
node_version = sub_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
|
||||
@ -197,7 +197,6 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {self.node_id} failed to run")
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.workflow.nodes.answer import AnswerNode
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.code import CodeNode
|
||||
@ -16,26 +18,87 @@ from core.workflow.nodes.start import StartNode
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode
|
||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
|
||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
|
||||
|
||||
node_type_classes_mapping: dict[NodeType, type[BaseNode]] = {
|
||||
NodeType.START: StartNode,
|
||||
NodeType.END: EndNode,
|
||||
NodeType.ANSWER: AnswerNode,
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
NodeType.CODE: CodeNode,
|
||||
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||
NodeType.TOOL: ToolNode,
|
||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: IterationNode,
|
||||
NodeType.ITERATION_START: IterationStartNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
|
||||
NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode,
|
||||
NodeType.LIST_OPERATOR: ListOperatorNode,
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
NodeType.START: {
|
||||
LATEST_VERSION: StartNode,
|
||||
"1": StartNode,
|
||||
},
|
||||
NodeType.END: {
|
||||
LATEST_VERSION: EndNode,
|
||||
"1": EndNode,
|
||||
},
|
||||
NodeType.ANSWER: {
|
||||
LATEST_VERSION: AnswerNode,
|
||||
"1": AnswerNode,
|
||||
},
|
||||
NodeType.LLM: {
|
||||
LATEST_VERSION: LLMNode,
|
||||
"1": LLMNode,
|
||||
},
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: {
|
||||
LATEST_VERSION: KnowledgeRetrievalNode,
|
||||
"1": KnowledgeRetrievalNode,
|
||||
},
|
||||
NodeType.IF_ELSE: {
|
||||
LATEST_VERSION: IfElseNode,
|
||||
"1": IfElseNode,
|
||||
},
|
||||
NodeType.CODE: {
|
||||
LATEST_VERSION: CodeNode,
|
||||
"1": CodeNode,
|
||||
},
|
||||
NodeType.TEMPLATE_TRANSFORM: {
|
||||
LATEST_VERSION: TemplateTransformNode,
|
||||
"1": TemplateTransformNode,
|
||||
},
|
||||
NodeType.QUESTION_CLASSIFIER: {
|
||||
LATEST_VERSION: QuestionClassifierNode,
|
||||
"1": QuestionClassifierNode,
|
||||
},
|
||||
NodeType.HTTP_REQUEST: {
|
||||
LATEST_VERSION: HttpRequestNode,
|
||||
"1": HttpRequestNode,
|
||||
},
|
||||
NodeType.TOOL: {
|
||||
LATEST_VERSION: ToolNode,
|
||||
"1": ToolNode,
|
||||
},
|
||||
NodeType.VARIABLE_AGGREGATOR: {
|
||||
LATEST_VERSION: VariableAggregatorNode,
|
||||
"1": VariableAggregatorNode,
|
||||
},
|
||||
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
|
||||
LATEST_VERSION: VariableAggregatorNode,
|
||||
"1": VariableAggregatorNode,
|
||||
}, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: {
|
||||
LATEST_VERSION: IterationNode,
|
||||
"1": IterationNode,
|
||||
},
|
||||
NodeType.ITERATION_START: {
|
||||
LATEST_VERSION: IterationStartNode,
|
||||
"1": IterationStartNode,
|
||||
},
|
||||
NodeType.PARAMETER_EXTRACTOR: {
|
||||
LATEST_VERSION: ParameterExtractorNode,
|
||||
"1": ParameterExtractorNode,
|
||||
},
|
||||
NodeType.VARIABLE_ASSIGNER: {
|
||||
LATEST_VERSION: VariableAssignerNodeV2,
|
||||
"1": VariableAssignerNodeV1,
|
||||
"2": VariableAssignerNodeV2,
|
||||
},
|
||||
NodeType.DOCUMENT_EXTRACTOR: {
|
||||
LATEST_VERSION: DocumentExtractorNode,
|
||||
"1": DocumentExtractorNode,
|
||||
},
|
||||
NodeType.LIST_OPERATOR: {
|
||||
LATEST_VERSION: ListOperatorNode,
|
||||
"1": ListOperatorNode,
|
||||
},
|
||||
}
|
||||
|
||||
@ -235,7 +235,7 @@ class ParameterExtractorNode(LLMNode):
|
||||
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
|
||||
|
||||
text = invoke_result.message.content
|
||||
if not isinstance(text, str):
|
||||
if not isinstance(text, str | None):
|
||||
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
|
||||
|
||||
usage = invoke_result.usage
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .entities import QuestionClassifierNodeData
|
||||
from .question_classifier_node import QuestionClassifierNode
|
||||
|
||||
__all__ = ["QuestionClassifierNodeData", "QuestionClassifierNode"]
|
||||
__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"]
|
||||
|
||||
@ -1,8 +0,0 @@
|
||||
from .node import VariableAssignerNode
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
__all__ = [
|
||||
"VariableAssignerNode",
|
||||
"VariableAssignerData",
|
||||
"WriteMode",
|
||||
]
|
||||
|
||||
4
api/core/workflow/nodes/variable_assigner/common/exc.py
Normal file
4
api/core/workflow/nodes/variable_assigner/common/exc.py
Normal file
@ -0,0 +1,4 @@
|
||||
class VariableOperatorNodeError(Exception):
|
||||
"""Base error type, don't use directly."""
|
||||
|
||||
pass
|
||||
19
api/core/workflow/nodes/variable_assigner/common/helpers.py
Normal file
19
api/core/workflow/nodes/variable_assigner/common/helpers.py
Normal file
@ -0,0 +1,19 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables import Variable
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from extensions.ext_database import db
|
||||
from models import ConversationVariable
|
||||
|
||||
|
||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableOperatorNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
@ -1,2 +0,0 @@
|
||||
class VariableAssignerNodeError(Exception):
|
||||
pass
|
||||
3
api/core/workflow/nodes/variable_assigner/v1/__init__.py
Normal file
3
api/core/workflow/nodes/variable_assigner/v1/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .node import VariableAssignerNode
|
||||
|
||||
__all__ = ["VariableAssignerNode"]
|
||||
@ -1,40 +1,36 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode, BaseNodeData
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from factories import variable_factory
|
||||
from models import ConversationVariable
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .exc import VariableAssignerNodeError
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
_node_data_cls: type[BaseNodeData] = VariableAssignerData
|
||||
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableAssignerNodeError("assigned variable not found")
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
match self.node_data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError("input value not found")
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableAssignerNodeError("input value not found")
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
||||
|
||||
@ -43,7 +39,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
|
||||
@ -52,8 +48,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
# Update conversation variable.
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise VariableAssignerNodeError("conversation_id not found")
|
||||
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
raise VariableOperatorNodeError("conversation_id not found")
|
||||
common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -63,18 +59,6 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
)
|
||||
|
||||
|
||||
def update_conversation_variable(conversation_id: str, variable: Variable):
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
row = session.scalar(stmt)
|
||||
if not row:
|
||||
raise VariableAssignerNodeError("conversation variable not found in the database")
|
||||
row.data = variable.model_dump_json()
|
||||
session.commit()
|
||||
|
||||
|
||||
def get_zero_value(t: SegmentType):
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
@ -86,4 +70,4 @@ def get_zero_value(t: SegmentType):
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case _:
|
||||
raise VariableAssignerNodeError(f"unsupported variable type: {t}")
|
||||
raise VariableOperatorNodeError(f"unsupported variable type: {t}")
|
||||
@ -1,6 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
@ -12,8 +11,6 @@ class WriteMode(StrEnum):
|
||||
|
||||
|
||||
class VariableAssignerData(BaseNodeData):
|
||||
title: str = "Variable Assigner"
|
||||
desc: Optional[str] = "Assign a value to a variable"
|
||||
assigned_variable_selector: Sequence[str]
|
||||
write_mode: WriteMode
|
||||
input_variable_selector: Sequence[str]
|
||||
3
api/core/workflow/nodes/variable_assigner/v2/__init__.py
Normal file
3
api/core/workflow/nodes/variable_assigner/v2/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .node import VariableAssignerNode
|
||||
|
||||
__all__ = ["VariableAssignerNode"]
|
||||
11
api/core/workflow/nodes/variable_assigner/v2/constants.py
Normal file
11
api/core/workflow/nodes/variable_assigner/v2/constants.py
Normal file
@ -0,0 +1,11 @@
|
||||
from core.variables import SegmentType
|
||||
|
||||
EMPTY_VALUE_MAPPING = {
|
||||
SegmentType.STRING: "",
|
||||
SegmentType.NUMBER: 0,
|
||||
SegmentType.OBJECT: {},
|
||||
SegmentType.ARRAY_ANY: [],
|
||||
SegmentType.ARRAY_STRING: [],
|
||||
SegmentType.ARRAY_NUMBER: [],
|
||||
SegmentType.ARRAY_OBJECT: [],
|
||||
}
|
||||
20
api/core/workflow/nodes/variable_assigner/v2/entities.py
Normal file
20
api/core/workflow/nodes/variable_assigner/v2/entities.py
Normal file
@ -0,0 +1,20 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
from .enums import InputType, Operation
|
||||
|
||||
|
||||
class VariableOperationItem(BaseModel):
|
||||
variable_selector: Sequence[str]
|
||||
input_type: InputType
|
||||
operation: Operation
|
||||
value: Any | None = None
|
||||
|
||||
|
||||
class VariableAssignerNodeData(BaseNodeData):
|
||||
version: str = "2"
|
||||
items: Sequence[VariableOperationItem]
|
||||
18
api/core/workflow/nodes/variable_assigner/v2/enums.py
Normal file
18
api/core/workflow/nodes/variable_assigner/v2/enums.py
Normal file
@ -0,0 +1,18 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class Operation(StrEnum):
|
||||
OVER_WRITE = "over-write"
|
||||
CLEAR = "clear"
|
||||
APPEND = "append"
|
||||
EXTEND = "extend"
|
||||
SET = "set"
|
||||
ADD = "+="
|
||||
SUBTRACT = "-="
|
||||
MULTIPLY = "*="
|
||||
DIVIDE = "/="
|
||||
|
||||
|
||||
class InputType(StrEnum):
|
||||
VARIABLE = "variable"
|
||||
CONSTANT = "constant"
|
||||
31
api/core/workflow/nodes/variable_assigner/v2/exc.py
Normal file
31
api/core/workflow/nodes/variable_assigner/v2/exc.py
Normal file
@ -0,0 +1,31 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
|
||||
from .enums import InputType, Operation
|
||||
|
||||
|
||||
class OperationNotSupportedError(VariableOperatorNodeError):
|
||||
def __init__(self, *, operation: Operation, varialbe_type: str):
|
||||
super().__init__(f"Operation {operation} is not supported for type {varialbe_type}")
|
||||
|
||||
|
||||
class InputTypeNotSupportedError(VariableOperatorNodeError):
|
||||
def __init__(self, *, input_type: InputType, operation: Operation):
|
||||
super().__init__(f"Input type {input_type} is not supported for operation {operation}")
|
||||
|
||||
|
||||
class VariableNotFoundError(VariableOperatorNodeError):
|
||||
def __init__(self, *, variable_selector: Sequence[str]):
|
||||
super().__init__(f"Variable {variable_selector} not found")
|
||||
|
||||
|
||||
class InvalidInputValueError(VariableOperatorNodeError):
|
||||
def __init__(self, *, value: Any):
|
||||
super().__init__(f"Invalid input value {value}")
|
||||
|
||||
|
||||
class ConversationIDNotFoundError(VariableOperatorNodeError):
|
||||
def __init__(self):
|
||||
super().__init__("conversation_id not found")
|
||||
91
api/core/workflow/nodes/variable_assigner/v2/helpers.py
Normal file
91
api/core/workflow/nodes/variable_assigner/v2/helpers.py
Normal file
@ -0,0 +1,91 @@
|
||||
from typing import Any
|
||||
|
||||
from core.variables import SegmentType
|
||||
|
||||
from .enums import Operation
|
||||
|
||||
|
||||
def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
match operation:
|
||||
case Operation.OVER_WRITE | Operation.CLEAR:
|
||||
return True
|
||||
case Operation.SET:
|
||||
return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
|
||||
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
|
||||
# Only number variable can be added, subtracted, multiplied or divided
|
||||
return variable_type == SegmentType.NUMBER
|
||||
case Operation.APPEND | Operation.EXTEND:
|
||||
# Only array variable can be appended or extended
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
case _:
|
||||
return False
|
||||
|
||||
|
||||
def is_variable_input_supported(*, operation: Operation):
|
||||
if operation in {Operation.SET, Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE}:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
match variable_type:
|
||||
case SegmentType.STRING | SegmentType.OBJECT:
|
||||
return operation in {Operation.OVER_WRITE, Operation.SET}
|
||||
case SegmentType.NUMBER:
|
||||
return operation in {
|
||||
Operation.OVER_WRITE,
|
||||
Operation.SET,
|
||||
Operation.ADD,
|
||||
Operation.SUBTRACT,
|
||||
Operation.MULTIPLY,
|
||||
Operation.DIVIDE,
|
||||
}
|
||||
case _:
|
||||
return False
|
||||
|
||||
|
||||
def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any):
|
||||
if operation == Operation.CLEAR:
|
||||
return True
|
||||
match variable_type:
|
||||
case SegmentType.STRING:
|
||||
return isinstance(value, str)
|
||||
|
||||
case SegmentType.NUMBER:
|
||||
if not isinstance(value, int | float):
|
||||
return False
|
||||
if operation == Operation.DIVIDE and value == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
case SegmentType.OBJECT:
|
||||
return isinstance(value, dict)
|
||||
|
||||
# Array & Append
|
||||
case SegmentType.ARRAY_ANY if operation == Operation.APPEND:
|
||||
return isinstance(value, str | float | int | dict)
|
||||
case SegmentType.ARRAY_STRING if operation == Operation.APPEND:
|
||||
return isinstance(value, str)
|
||||
case SegmentType.ARRAY_NUMBER if operation == Operation.APPEND:
|
||||
return isinstance(value, int | float)
|
||||
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
|
||||
return isinstance(value, dict)
|
||||
|
||||
# Array & Extend / Overwrite
|
||||
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, str | float | int | dict) for item in value)
|
||||
case SegmentType.ARRAY_STRING if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, str) for item in value)
|
||||
case SegmentType.ARRAY_NUMBER if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
|
||||
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
|
||||
|
||||
case _:
|
||||
return False
|
||||
159
api/core/workflow/nodes/variable_assigner/v2/node.py
Normal file
159
api/core/workflow/nodes/variable_assigner/v2/node.py
Normal file
@ -0,0 +1,159 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from . import helpers
|
||||
from .constants import EMPTY_VALUE_MAPPING
|
||||
from .entities import VariableAssignerNodeData
|
||||
from .enums import InputType, Operation
|
||||
from .exc import (
|
||||
ConversationIDNotFoundError,
|
||||
InputTypeNotSupportedError,
|
||||
InvalidInputValueError,
|
||||
OperationNotSupportedError,
|
||||
VariableNotFoundError,
|
||||
)
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
|
||||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
inputs = self.node_data.model_dump()
|
||||
process_data = {}
|
||||
# NOTE: This node has no outputs
|
||||
updated_variables: list[Variable] = []
|
||||
|
||||
try:
|
||||
for item in self.node_data.items:
|
||||
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
|
||||
|
||||
# ==================== Validation Part
|
||||
|
||||
# Check if variable exists
|
||||
if not isinstance(variable, Variable):
|
||||
raise VariableNotFoundError(variable_selector=item.variable_selector)
|
||||
|
||||
# Check if operation is supported
|
||||
if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation):
|
||||
raise OperationNotSupportedError(operation=item.operation, varialbe_type=variable.value_type)
|
||||
|
||||
# Check if variable input is supported
|
||||
if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported(
|
||||
operation=item.operation
|
||||
):
|
||||
raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation)
|
||||
|
||||
# Check if constant input is supported
|
||||
if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported(
|
||||
variable_type=variable.value_type, operation=item.operation
|
||||
):
|
||||
raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation)
|
||||
|
||||
# Get value from variable pool
|
||||
if (
|
||||
item.input_type == InputType.VARIABLE
|
||||
and item.operation != Operation.CLEAR
|
||||
and item.value is not None
|
||||
):
|
||||
value = self.graph_runtime_state.variable_pool.get(item.value)
|
||||
if value is None:
|
||||
raise VariableNotFoundError(variable_selector=item.value)
|
||||
# Skip if value is NoneSegment
|
||||
if value.value_type == SegmentType.NONE:
|
||||
continue
|
||||
item.value = value.value
|
||||
|
||||
# If set string / bytes / bytearray to object, try convert string to object.
|
||||
if (
|
||||
item.operation == Operation.SET
|
||||
and variable.value_type == SegmentType.OBJECT
|
||||
and isinstance(item.value, str | bytes | bytearray)
|
||||
):
|
||||
try:
|
||||
item.value = json.loads(item.value)
|
||||
except json.JSONDecodeError:
|
||||
raise InvalidInputValueError(value=item.value)
|
||||
|
||||
# Check if input value is valid
|
||||
if not helpers.is_input_value_valid(
|
||||
variable_type=variable.value_type, operation=item.operation, value=item.value
|
||||
):
|
||||
raise InvalidInputValueError(value=item.value)
|
||||
|
||||
# ==================== Execution Part
|
||||
|
||||
updated_value = self._handle_item(
|
||||
variable=variable,
|
||||
operation=item.operation,
|
||||
value=item.value,
|
||||
)
|
||||
variable = variable.model_copy(update={"value": updated_value})
|
||||
updated_variables.append(variable)
|
||||
except VariableOperatorNodeError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
# Update variables
|
||||
for variable in updated_variables:
|
||||
self.graph_runtime_state.variable_pool.add(variable.selector, variable)
|
||||
process_data[variable.name] = variable.value
|
||||
|
||||
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
|
||||
if not conversation_id:
|
||||
raise ConversationIDNotFoundError
|
||||
else:
|
||||
conversation_id = conversation_id.value
|
||||
common_helpers.update_conversation_variable(
|
||||
conversation_id=conversation_id,
|
||||
variable=variable,
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
)
|
||||
|
||||
def _handle_item(
|
||||
self,
|
||||
*,
|
||||
variable: Variable,
|
||||
operation: Operation,
|
||||
value: Any,
|
||||
):
|
||||
match operation:
|
||||
case Operation.OVER_WRITE:
|
||||
return value
|
||||
case Operation.CLEAR:
|
||||
return EMPTY_VALUE_MAPPING[variable.value_type]
|
||||
case Operation.APPEND:
|
||||
return variable.value + [value]
|
||||
case Operation.EXTEND:
|
||||
return variable.value + value
|
||||
case Operation.SET:
|
||||
return value
|
||||
case Operation.ADD:
|
||||
return variable.value + value
|
||||
case Operation.SUBTRACT:
|
||||
return variable.value - value
|
||||
case Operation.MULTIPLY:
|
||||
return variable.value * value
|
||||
case Operation.DIVIDE:
|
||||
return variable.value / value
|
||||
case _:
|
||||
raise OperationNotSupportedError(operation=operation, varialbe_type=variable.value_type)
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
@ -19,7 +19,7 @@ from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.event import NodeEvent
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from factories import file_factory
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import (
|
||||
@ -145,11 +145,8 @@ class WorkflowEntry:
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
node_cls = node_type_classes_mapping.get(node_type)
|
||||
node_cls = cast(type[BaseNode], node_cls)
|
||||
|
||||
if not node_cls:
|
||||
raise ValueError(f"Node class not found for node type {node_type}")
|
||||
node_version = node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(environment_variables=workflow.environment_variables)
|
||||
|
||||
Reference in New Issue
Block a user