feat(api): Enhance multi-modal support.

This commit is contained in:
-LAN-
2024-09-02 00:46:59 +08:00
parent 7838f9f3a3
commit ea18dd1571
228 changed files with 5324 additions and 3062 deletions

View File

@ -0,0 +1,7 @@
from .base_workflow_callback import WorkflowCallback
from .workflow_logging_callback import WorkflowLoggingCallback
__all__ = [
"WorkflowLoggingCallback",
"WorkflowCallback",
]

View File

@ -0,0 +1,221 @@
from typing import Optional
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from .base_workflow_callback import WorkflowCallback
# ANSI SGR parameter strings used to colorize log output.
# Each value is spliced into "\u001b[<codes>m" by _get_colored_text;
# ";1" selects the bold/bright variant of the foreground color.
_TEXT_COLOR_MAPPING = {
    "blue": "36;1",  # bright cyan (labelled "blue" by convention here)
    "yellow": "33;1",
    "pink": "38;5;200",  # 256-color palette entry 200
    "green": "32;1",
    "red": "31;1",
}
class WorkflowLoggingCallback(WorkflowCallback):
    """Workflow callback that pretty-prints graph-engine events to stdout.

    Intended for local debugging: each event type is rendered as a colored
    banner line followed by its most relevant fields.
    """

    def __init__(self) -> None:
        # Node id of the most recent streaming node; used so the node
        # header is printed only once per run of stream chunks.
        self.current_node_id: Optional[str] = None

    def on_event(self, event: GraphEngineEvent) -> None:
        """Dispatch *event* to the handler matching its concrete type."""
        if isinstance(event, GraphRunStartedEvent):
            self.print_text("\n[GraphRunStartedEvent]", color="pink")
        elif isinstance(event, GraphRunSucceededEvent):
            self.print_text("\n[GraphRunSucceededEvent]", color="green")
        elif isinstance(event, GraphRunFailedEvent):
            self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
        elif isinstance(event, NodeRunStartedEvent):
            self.on_workflow_node_execute_started(event=event)
        elif isinstance(event, NodeRunSucceededEvent):
            self.on_workflow_node_execute_succeeded(event=event)
        elif isinstance(event, NodeRunFailedEvent):
            self.on_workflow_node_execute_failed(event=event)
        elif isinstance(event, NodeRunStreamChunkEvent):
            self.on_node_text_chunk(event=event)
        elif isinstance(event, ParallelBranchRunStartedEvent):
            self.on_workflow_parallel_started(event=event)
        elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
            self.on_workflow_parallel_completed(event=event)
        elif isinstance(event, IterationRunStartedEvent):
            self.on_workflow_iteration_started(event=event)
        elif isinstance(event, IterationRunNextEvent):
            self.on_workflow_iteration_next(event=event)
        elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
            self.on_workflow_iteration_completed(event=event)
        else:
            # Unknown event types are still surfaced by class name.
            self.print_text(f"\n[{event.__class__.__name__}]", color="blue")

    def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
        """
        Workflow node execute started
        """
        self.print_text("\n[NodeRunStartedEvent]", color="yellow")
        self.print_text(f"Node ID: {event.node_id}", color="yellow")
        self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
        self.print_text(f"Type: {event.node_type.value}", color="yellow")

    def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
        """
        Workflow node execute succeeded
        """
        route_node_state = event.route_node_state

        self.print_text("\n[NodeRunSucceededEvent]", color="green")
        self.print_text(f"Node ID: {event.node_id}", color="green")
        self.print_text(f"Node Title: {event.node_data.title}", color="green")
        self.print_text(f"Type: {event.node_type.value}", color="green")

        # Inputs/process data/outputs are only available once the node has
        # produced a run result; encode them so they render as plain JSON.
        if route_node_state.node_run_result:
            node_run_result = route_node_state.node_run_result
            self.print_text(
                f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
                color="green",
            )
            self.print_text(
                f"Process Data: "
                f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
                color="green",
            )
            self.print_text(
                f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
                color="green",
            )
            self.print_text(
                f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
                color="green",
            )

    def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
        """
        Workflow node execute failed
        """
        route_node_state = event.route_node_state

        self.print_text("\n[NodeRunFailedEvent]", color="red")
        self.print_text(f"Node ID: {event.node_id}", color="red")
        self.print_text(f"Node Title: {event.node_data.title}", color="red")
        self.print_text(f"Type: {event.node_type.value}", color="red")

        if route_node_state.node_run_result:
            node_run_result = route_node_state.node_run_result
            self.print_text(f"Error: {node_run_result.error}", color="red")
            self.print_text(
                f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
                color="red",
            )
            self.print_text(
                f"Process Data: "
                f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
                color="red",
            )
            self.print_text(
                f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
                color="red",
            )

    def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
        """
        Publish text chunk
        """
        route_node_state = event.route_node_state
        # Print the node header only when the stream switches to a new node.
        if not self.current_node_id or self.current_node_id != route_node_state.node_id:
            self.current_node_id = route_node_state.node_id
            self.print_text("\n[NodeRunStreamChunkEvent]")
            self.print_text(f"Node ID: {route_node_state.node_id}")

            node_run_result = route_node_state.node_run_result
            if node_run_result:
                self.print_text(
                    f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
                )

        # end="" so consecutive chunks join into one continuous line.
        self.print_text(event.chunk_content, color="pink", end="")

    def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
        """
        Publish parallel started
        """
        self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
        self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
        if event.in_iteration_id:
            self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")

    def on_workflow_parallel_completed(
        self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
    ) -> None:
        """
        Publish parallel completed
        """
        # Default to "blue" so `color` is always bound, even if another
        # event subtype is ever routed here (the original if/elif left
        # `color` undefined in that case).
        if isinstance(event, ParallelBranchRunFailedEvent):
            color = "red"
        else:
            color = "blue"

        self.print_text(
            "\n[ParallelBranchRunSucceededEvent]"
            if isinstance(event, ParallelBranchRunSucceededEvent)
            else "\n[ParallelBranchRunFailedEvent]",
            color=color,
        )
        self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
        if event.in_iteration_id:
            self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)

        if isinstance(event, ParallelBranchRunFailedEvent):
            self.print_text(f"Error: {event.error}", color=color)

    def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
        """
        Publish iteration started
        """
        self.print_text("\n[IterationRunStartedEvent]", color="blue")
        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")

    def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
        """
        Publish iteration next
        """
        self.print_text("\n[IterationRunNextEvent]", color="blue")
        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
        self.print_text(f"Iteration Index: {event.index}", color="blue")

    def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
        """
        Publish iteration completed
        """
        self.print_text(
            "\n[IterationRunSucceededEvent]"
            if isinstance(event, IterationRunSucceededEvent)
            else "\n[IterationRunFailedEvent]",
            color="blue",
        )
        self.print_text(f"Node ID: {event.iteration_id}", color="blue")

    def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
        """Print text with highlighting and no end characters."""
        text_to_print = self._get_colored_text(text, color) if color else text
        print(f"{text_to_print}", end=end)

    def _get_colored_text(self, text: str, color: str) -> str:
        """Get colored text."""
        color_str = _TEXT_COLOR_MAPPING[color]
        return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"

View File

@ -0,0 +1,3 @@
# Reserved first elements ("node ids") of variable-pool selectors: the
# variable pool registers system, environment, and conversation variables
# under these prefixes when it is constructed.
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"

View File

@ -1,52 +1,14 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMUsage
from models import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus
class NodeType(Enum):
    """
    Node Types.
    """

    START = "start"
    END = "end"
    ANSWER = "answer"
    LLM = "llm"
    KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
    IF_ELSE = "if-else"
    CODE = "code"
    TEMPLATE_TRANSFORM = "template-transform"
    QUESTION_CLASSIFIER = "question-classifier"
    HTTP_REQUEST = "http-request"
    TOOL = "tool"
    VARIABLE_AGGREGATOR = "variable-aggregator"
    # TODO: merge this into VARIABLE_AGGREGATOR
    VARIABLE_ASSIGNER = "variable-assigner"
    LOOP = "loop"
    ITERATION = "iteration"
    ITERATION_START = "iteration-start"  # fake start node for iteration
    PARAMETER_EXTRACTOR = "parameter-extractor"
    CONVERSATION_VARIABLE_ASSIGNER = "assigner"

    @classmethod
    def value_of(cls, value: str) -> "NodeType":
        """Return the member whose value equals *value*.

        :param value: node type value
        :raises ValueError: if no member carries that value
        """
        matched = next((member for member in cls if member.value == value), None)
        if matched is None:
            raise ValueError(f"invalid node type value {value}")
        return matched
class NodeRunMetadataKey(Enum):
class NodeRunMetadataKey(str, Enum):
"""
Node Run Metadata Key.
"""
@ -70,7 +32,7 @@ class NodeRunResult(BaseModel):
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[dict[str, Any]] = None # node inputs
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict[str, Any]] = None # process data
outputs: Optional[dict[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
@ -79,24 +41,3 @@ class NodeRunResult(BaseModel):
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
class UserFrom(Enum):
    """Origin of the acting user (console account vs. published end-user)."""

    ACCOUNT = "account"
    END_USER = "end-user"

    @classmethod
    def value_of(cls, value: str) -> "UserFrom":
        """Return the member whose value equals *value*.

        :param value: raw string to resolve
        :raises ValueError: if no member carries that value
        """
        matched = next((member for member in cls if member.value == value), None)
        if matched is None:
            raise ValueError(f"Invalid value: {value}")
        return matched

View File

@ -1,3 +1,5 @@
from collections.abc import Sequence
from pydantic import BaseModel
@ -7,4 +9,4 @@ class VariableSelector(BaseModel):
"""
variable: str
value_selector: list[str]
value_selector: Sequence[str]

View File

@ -1,20 +1,23 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariableKey
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.segments import FileSegment
from factories import variable_factory
VariableValue = Union[str, int, float, dict, list, FileVar]
from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from ..enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, File]
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
class VariablePool(BaseModel):
@ -23,46 +26,61 @@ class VariablePool(BaseModel):
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: dict[str, dict[int, Segment]] = Field(
description="Variables mapping", default=defaultdict(dict)
description="Variables mapping",
default=defaultdict(dict),
)
# TODO: This user inputs is not used for pool.
user_inputs: Mapping[str, Any] = Field(
description="User inputs",
)
system_variables: Mapping[SystemVariableKey, Any] = Field(
description="System variables",
)
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
default_factory=list,
)
conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.",
default_factory=list,
)
environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list)
def __init__(
self,
*,
system_variables: Mapping[SystemVariableKey, Any],
user_inputs: Mapping[str, Any],
environment_variables: Sequence[Variable] | None = None,
conversation_variables: Sequence[Variable] | None = None,
**kwargs,
):
environment_variables = environment_variables or []
conversation_variables = conversation_variables or []
conversation_variables: Sequence[Variable] | None = None
super().__init__(
system_variables=system_variables,
user_inputs=user_inputs,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
**kwargs,
)
@model_validator(mode="after")
def val_model_after(self):
"""
Append system variables
:return:
"""
# Add system variables to the variable pool
for key, value in self.system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Add environment variables to the variable pool
for var in self.environment_variables or []:
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool
for var in self.conversation_variables or []:
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
return self
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.
NOTE: You should not add a non-Segment value to the variable pool
even if it is allowed now.
Args:
selector (Sequence[str]): The selector for the variable.
value (VariableValue): The value of the variable.
@ -82,7 +100,7 @@ class VariablePool(BaseModel):
if isinstance(value, Segment):
v = value
else:
v = factory.build_segment(value)
v = variable_factory.build_segment(value)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = v
@ -101,10 +119,19 @@ class VariablePool(BaseModel):
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError("Invalid selector")
return None
hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key)
if value is None:
selector, attr = selector[:-1], selector[-1]
value = self.get(selector)
if isinstance(value, FileSegment):
attr = FileAttribute(attr)
attr_value = file_manager.get_attr(file=value.value, attr=attr)
return variable_factory.build_segment(attr_value)
return value
@deprecated("This method is deprecated, use `get` instead.")
@ -145,14 +172,18 @@ class VariablePool(BaseModel):
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]].pop(hash_key, None)
def remove_node(self, node_id: str, /):
"""
Remove all variables associated with a given node id.
def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template)
segments = []
for part in filter(lambda x: x, parts):
if "." in part and (variable := self.get(part.split("."))):
segments.append(variable)
else:
segments.append(variable_factory.build_segment(part))
return SegmentGroup(value=segments)
Args:
node_id (str): The node id to remove.
Returns:
None
"""
self.variable_dictionary.pop(node_id, None)
def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
segment = self.get(selector)
if isinstance(segment, FileSegment):
return segment
return None

View File

@ -3,12 +3,14 @@ from typing import Optional
from pydantic import BaseModel
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode, UserFrom
from core.workflow.nodes.base_node import BaseNode
from enums import UserFrom
from models.workflow import Workflow, WorkflowType
from .base_node_data_entities import BaseIterationState
from .node_entities import NodeRunResult
from .variable_pool import VariablePool
class WorkflowNodeAndResult:
node: BaseNode

View File

@ -0,0 +1,3 @@
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]

View File

@ -18,11 +18,10 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
# process condition
condition_processor = ConditionProcessor()
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=graph_runtime_state.variable_pool, conditions=self.condition.conditions
_, _, final_result = condition_processor.process_conditions(
variable_pool=graph_runtime_state.variable_pool,
conditions=self.condition.conditions,
operator="and",
)
# Apply the logical operator for the current case
compare_result = all(group_result)
return compare_result
return final_result

View File

@ -0,0 +1,6 @@
from .graph import Graph
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .runtime_route_state import RuntimeRouteState
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]

View File

@ -4,8 +4,8 @@ from typing import Any, Optional
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from enums import NodeType
class GraphEngineEvent(BaseModel):

View File

@ -4,12 +4,12 @@ from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
from core.workflow.nodes.end.entities import EndStreamParam
from enums import NodeType
class GraphEdge(BaseModel):

View File

@ -4,7 +4,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from enums import UserFrom
from models.workflow import WorkflowType

View File

@ -10,11 +10,7 @@ from flask import Flask, current_app
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeType,
UserFrom,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
@ -41,6 +37,7 @@ from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes
from enums import NodeType, UserFrom
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType

View File

@ -0,0 +1,3 @@
from .entities import AnswerStreamGenerateRoute
__all__ = ["AnswerStreamGenerateRoute"]

View File

@ -1,7 +1,8 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
@ -11,10 +12,11 @@ from core.workflow.nodes.answer.entities import (
)
from core.workflow.nodes.base_node import BaseNode
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
class AnswerNode(BaseNode):
class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER
@ -23,30 +25,35 @@ class AnswerNode(BaseNode):
Run node
:return:
"""
node_data = self.node_data
node_data = cast(AnswerNodeData, node_data)
# generate routes
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data)
answer = ""
files = []
for part in generate_routes:
if part.type == GenerateRouteChunk.ChunkType.VAR:
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = self.graph_runtime_state.variable_pool.get(value_selector)
if value:
answer += value.markdown
variable = self.graph_runtime_state.variable_pool.get(value_selector)
if variable:
if isinstance(variable, FileSegment):
files.append(variable.value)
elif isinstance(variable, ArrayFileSegment):
files.extend(variable.value)
answer += variable.markdown
else:
part = cast(TextGenerateRouteChunk, part)
answer += part.text
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer})
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files})
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AnswerNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -55,9 +62,6 @@ class AnswerNode(BaseNode):
:param node_data: node data
:return:
"""
node_data = node_data
node_data = cast(AnswerNodeData, node_data)
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()

View File

@ -1,5 +1,4 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
AnswerStreamGenerateRoute,
@ -8,6 +7,7 @@ from core.workflow.nodes.answer.entities import (
VarGenerateRouteChunk,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from enums import NodeType
class AnswerStreamGeneratorRouter:

View File

@ -1,8 +1,8 @@
import logging
from collections.abc import Generator
from typing import Optional, cast
from typing import cast
from core.file.file_obj import FileVar
from core.file import FILE_MODEL_IDENTITY, File
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@ -203,7 +203,7 @@ class AnswerStreamProcessor(StreamProcessor):
return files
@classmethod
def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]:
def _get_file_var_from_value(cls, value: dict | list):
"""
Get file var from value
:param value: variable value
@ -213,9 +213,9 @@ class AnswerStreamProcessor(StreamProcessor):
return None
if isinstance(value, dict):
if "__variant" in value and value["__variant"] == FileVar.__name__:
if "dify_model_identity" in value and value["dify_model_identity"] == FILE_MODEL_IDENTITY:
return value
elif isinstance(value, FileVar):
elif isinstance(value, File):
return value.to_dict()
return None

View File

@ -1,17 +1,24 @@
from abc import ABC, abstractmethod
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from typing import Any, Generic, Optional, TypeVar, cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
class BaseNode(ABC):
class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[BaseNodeData]
_node_type: NodeType
@ -45,7 +52,7 @@ class BaseNode(ABC):
raise ValueError("Node ID is required.")
self.node_id = node_id
self.node_data = self._node_data_cls(**config.get("data", {}))
self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))
@abstractmethod
def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
@ -56,11 +63,14 @@ class BaseNode(ABC):
raise NotImplementedError
def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node entry
:return:
"""
result = self._run()
try:
result = self._run()
except Exception as e:
logger.error(f"Node {self.node_id} failed to run: {e}")
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result)
@ -69,7 +79,10 @@ class BaseNode(ABC):
@classmethod
def extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], config: dict
cls,
*,
graph_config: Mapping[str, Any],
config: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -83,12 +96,16 @@ class BaseNode(ABC):
node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=node_data
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: GenericNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping

View File

@ -1,18 +1,19 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast
from typing import Any, Optional, Union
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
class CodeNode(BaseNode):
class CodeNode(BaseNode[CodeNodeData]):
_node_data_cls = CodeNodeData
_node_type = NodeType.CODE
@ -33,20 +34,13 @@ class CodeNode(BaseNode):
return code_provider.get_default_config()
def _run(self) -> NodeRunResult:
"""
Run code
:return:
"""
node_data = self.node_data
node_data = cast(CodeNodeData, node_data)
# Get code language
code_language = node_data.code_language
code = node_data.code
code_language = self.node_data.code_language
code = self.node_data.code
# Get variables
variables = {}
for variable_selector in node_data.variables:
for variable_selector in self.node_data.variables:
variable = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
@ -60,7 +54,7 @@ class CodeNode(BaseNode):
)
# Transform result
result = self._transform_result(result, node_data.outputs)
result = self._transform_result(result, self.node_data.outputs)
except (CodeExecutionError, ValueError) as e:
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
@ -316,7 +310,11 @@ class CodeNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: CodeNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: CodeNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping

View File

@ -0,0 +1,4 @@
from .document_extractor_node import DocumentExtractorNode
from .models import DocumentExtractorNodeData
__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"]

View File

@ -0,0 +1,246 @@
import csv
import io
import docx
import pandas as pd
import pypdfium2
from unstructured.partition.email import partition_email
from unstructured.partition.epub import partition_epub
from unstructured.partition.msg import partition_msg
from unstructured.partition.ppt import partition_ppt
from unstructured.partition.pptx import partition_pptx
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
from .models import DocumentExtractorNodeData
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
    """
    Extracts text content from various file types.
    Supports plain text, PDF, and DOC/DOCX files.
    """

    _node_data_cls = DocumentExtractorNodeData
    _node_type = NodeType.DOCUMENT_EXTRACTOR

    def _run(self):
        selector = self.node_data.variable_selector
        segment = self.graph_runtime_state.variable_pool.get(selector)

        # A missing variable is reported as a failed node run, not raised.
        if segment is None:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"File variable not found for selector: {selector}",
            )
        # NOTE(review): the `segment.value and` guard means a wrongly-typed
        # segment with a falsy value skips this check — confirm intended.
        if segment.value and not isinstance(segment, ArrayFileSegment | FileSegment):
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"Variable {selector} is not an ArrayFileSegment",
            )

        value = segment.value
        inputs = {"variable_selector": selector}
        # Normalize to a list for process data so single files render uniformly.
        process_data = {"documents": [value] if not isinstance(value, list) else value}

        try:
            if isinstance(value, list):
                outputs = {"text": [_extract_text_from_file(item) for item in value]}
            elif isinstance(value, File):
                outputs = {"text": _extract_text_from_file(value)}
            else:
                raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
        except DocumentExtractorError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=inputs,
                process_data=process_data,
            )
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=inputs,
            process_data=process_data,
            outputs=outputs,
        )
def _extract_text(*, file_content: bytes, mime_type: str) -> str:
    """Extract text from a file based on its MIME type.

    Dispatches to the extractor matching *mime_type*.

    Raises:
        UnsupportedFileTypeError: if no extractor handles *mime_type*.
    """
    # Plain-text-like types are matched first because text/plain may carry a
    # charset suffix (prefix match), unlike the exact matches below.
    if mime_type.startswith("text/plain") or mime_type in {"text/html", "text/htm", "text/markdown", "text/xml"}:
        return _extract_text_from_plain_text(file_content)
    extractors = {
        "application/pdf": _extract_text_from_pdf,
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": _extract_text_from_doc,
        "application/msword": _extract_text_from_doc,
        "text/csv": _extract_text_from_csv,
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": _extract_text_from_excel,
        "application/vnd.ms-excel": _extract_text_from_excel,
        "application/vnd.ms-powerpoint": _extract_text_from_ppt,
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": _extract_text_from_pptx,
        "application/epub+zip": _extract_text_from_epub,
        "message/rfc822": _extract_text_from_eml,
        "application/vnd.ms-outlook": _extract_text_from_msg,
    }
    extractor = extractors.get(mime_type)
    if extractor is None:
        raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
    return extractor(file_content)
def _extract_text_from_plain_text(file_content: bytes) -> str:
try:
return file_content.decode("utf-8")
except UnicodeDecodeError as e:
raise TextExtractionError("Failed to decode plain text file") from e
def _extract_text_from_pdf(file_content: bytes) -> str:
    """Extract the text of every page of a PDF document.

    Raises:
        TextExtractionError: wrapping any pypdfium2 failure.
    """
    try:
        document = pypdfium2.PdfDocument(io.BytesIO(file_content), autoclose=True)
        page_texts = []
        for page in document:
            text_page = page.get_textpage()
            page_texts.append(text_page.get_text_range())
            # Release native resources eagerly rather than waiting for GC.
            text_page.close()
            page.close()
        return "".join(page_texts)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e
def _extract_text_from_doc(file_content: bytes) -> str:
    """Extract paragraph text from a DOC/DOCX document, one paragraph per line.

    Raises:
        TextExtractionError: wrapping any python-docx failure.
    """
    try:
        document = docx.Document(io.BytesIO(file_content))
        paragraph_texts = (paragraph.text for paragraph in document.paragraphs)
        return "\n".join(paragraph_texts)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
def _download_file_content(file: File) -> bytes:
    """Download the content of a file based on its transfer method.

    Remote files are fetched through the SSRF-protected proxy; local files
    are read via the file manager.

    Raises:
        FileDownloadError: if required metadata is missing or the download
            fails for any reason.
    """
    try:
        if file.transfer_method == FileTransferMethod.REMOTE_URL:
            if file.remote_url is None:
                raise FileDownloadError("Missing URL for remote file")
            response = ssrf_proxy.get(file.remote_url)
            response.raise_for_status()
            return response.content
        elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
            if file.related_id is None:
                raise FileDownloadError("Missing file ID for local file")
            return file_manager.download(upload_file_id=file.related_id, tenant_id=file.tenant_id)
        else:
            raise ValueError(f"Unsupported transfer method: {file.transfer_method}")
    except FileDownloadError:
        # Bug fix: the broad handler below used to re-wrap our own specific
        # FileDownloadErrors ("Error downloading file: Missing URL ..."),
        # losing the precise message. Propagate them unchanged.
        raise
    except Exception as e:
        raise FileDownloadError(f"Error downloading file: {str(e)}") from e
def _extract_text_from_file(file: File):
    """Download *file* and return its extracted text.

    Raises:
        UnsupportedFileTypeError: if the file has no MIME type.
    """
    mime_type = file.mime_type
    if mime_type is None:
        raise UnsupportedFileTypeError("Unable to determine file type: MIME type is missing")
    return _extract_text(file_content=_download_file_content(file), mime_type=mime_type)
def _extract_text_from_csv(file_content: bytes) -> str:
try:
csv_file = io.StringIO(file_content.decode("utf-8"))
csv_reader = csv.reader(csv_file)
rows = list(csv_reader)
if not rows:
return ""
# Create markdown table
markdown_table = "| " + " | ".join(rows[0]) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n"
for row in rows[1:]:
markdown_table += "| " + " | ".join(row) + " |\n"
return markdown_table.strip()
except Exception as e:
raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e
def _extract_text_from_excel(file_content: bytes) -> str:
    """Extract text from an Excel file using pandas.

    Reads the first sheet, drops all-empty rows, and renders the frame as a
    markdown table.

    Raises:
        TextExtractionError: wrapping any pandas failure.
    """
    try:
        dataframe = pd.read_excel(io.BytesIO(file_content))
        dataframe = dataframe.dropna(how="all")
        return dataframe.to_markdown(index=False)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e
def _extract_text_from_ppt(file_content: bytes) -> str:
    """Extract slide text from a legacy PowerPoint (.ppt) file.

    Raises:
        TextExtractionError: wrapping any `unstructured` partition failure.
    """
    try:
        with io.BytesIO(file_content) as buffer:
            slide_elements = partition_ppt(file=buffer)
            return "\n".join(getattr(element, "text", "") for element in slide_elements)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from PPT: {str(e)}") from e
def _extract_text_from_pptx(file_content: bytes) -> str:
    """Extract slide text from a PowerPoint (.pptx) file.

    Raises:
        TextExtractionError: wrapping any `unstructured` partition failure.
    """
    try:
        with io.BytesIO(file_content) as buffer:
            slide_elements = partition_pptx(file=buffer)
            return "\n".join(getattr(element, "text", "") for element in slide_elements)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
def _extract_text_from_epub(file_content: bytes) -> str:
    """Extract text from an EPUB e-book.

    Raises:
        TextExtractionError: wrapping any `unstructured` partition failure.
    """
    try:
        with io.BytesIO(file_content) as buffer:
            parts = [str(element) for element in partition_epub(file=buffer)]
            return "\n".join(parts)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e
def _extract_text_from_eml(file_content: bytes) -> str:
    """Extract text from an RFC 822 e-mail (.eml) file.

    Raises:
        TextExtractionError: wrapping any `unstructured` partition failure.
    """
    try:
        with io.BytesIO(file_content) as buffer:
            parts = [str(element) for element in partition_email(file=buffer)]
            return "\n".join(parts)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e
def _extract_text_from_msg(file_content: bytes) -> str:
    """Extract text from an Outlook message (.msg) file.

    Raises:
        TextExtractionError: wrapping any `unstructured` partition failure.
    """
    try:
        with io.BytesIO(file_content) as buffer:
            parts = [str(element) for element in partition_msg(file=buffer)]
            return "\n".join(parts)
    except Exception as e:
        raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e

View File

@ -0,0 +1,14 @@
class DocumentExtractorError(Exception):
    """Base exception for errors related to the DocumentExtractorNode.

    Catch this type to handle every failure the document extractor raises.
    """
class FileDownloadError(DocumentExtractorError):
    """Exception raised when there's an error downloading a file."""
class UnsupportedFileTypeError(DocumentExtractorError):
    """Exception raised when trying to extract text from an unsupported file type."""
class TextExtractionError(DocumentExtractorError):
    """Exception raised when there's an error during text extraction from a file."""

View File

@ -0,0 +1,7 @@
from collections.abc import Sequence
from core.workflow.entities.base_node_data_entities import BaseNodeData
class DocumentExtractorNodeData(BaseNodeData):
    """Configuration payload for DocumentExtractorNode."""

    # Selector path into the workflow variable pool identifying the file
    # (or array of files) whose text should be extracted.
    variable_selector: Sequence[str]

View File

@ -0,0 +1,3 @@
from .entities import EndStreamParam
__all__ = ["EndStreamParam"]

View File

@ -1,13 +1,14 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.entities import EndNodeData
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
class EndNode(BaseNode):
class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData
_node_type = NodeType.END
@ -16,20 +17,27 @@ class EndNode(BaseNode):
Run node
:return:
"""
node_data = self.node_data
node_data = cast(EndNodeData, node_data)
output_variables = node_data.outputs
output_variables = self.node_data.outputs
outputs = {}
for variable_selector in output_variables:
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
value = variable.to_object() if variable is not None else None
outputs[variable_selector.variable] = value
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, outputs=outputs)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=outputs,
outputs=outputs,
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: EndNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping

View File

@ -1,5 +1,5 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
from enums import NodeType
class EndStreamGeneratorRouter:

View File

@ -0,0 +1,4 @@
from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData
from .http_request_node import HttpRequestNode
__all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"]

View File

@ -1,15 +1,16 @@
from typing import Literal, Optional, Union
from collections.abc import Sequence
from typing import Literal, Optional
from pydantic import BaseModel, ValidationInfo, field_validator
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from configs import dify_config
from core.workflow.entities.base_node_data_entities import BaseNodeData
class HttpRequestNodeAuthorizationConfig(BaseModel):
type: Literal[None, "basic", "bearer", "custom"]
api_key: Union[None, str] = None
header: Union[None, str] = None
type: Literal["basic", "bearer", "custom"]
api_key: str
header: str = ""
class HttpRequestNodeAuthorization(BaseModel):
@ -31,9 +32,16 @@ class HttpRequestNodeAuthorization(BaseModel):
return v
class BodyData(BaseModel):
key: str = ""
type: Literal["file", "text"]
value: str = ""
file: Sequence[str] = Field(default_factory=list)
class HttpRequestNodeBody(BaseModel):
type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json"]
data: Union[None, str] = None
type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"]
data: Sequence[BodyData] = Field(default_factory=list)
class HttpRequestNodeTimeout(BaseModel):

View File

@ -1,22 +1,35 @@
import json
from collections.abc import Mapping, Sequence
from copy import deepcopy
from random import randint
from typing import Any, Optional, Union
from typing import Any, Literal
from urllib.parse import urlencode
import httpx
from configs import dify_config
from core.file import file_manager
from core.helper import ssrf_proxy
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeBody,
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
BODY_TYPE_TO_CONTENT_TYPE = {
"json": "application/json",
"x-www-form-urlencoded": "application/x-www-form-urlencoded",
"form-data": "multipart/form-data",
"raw-text": "text/plain",
}
NON_FILE_CONTENT_TYPES = (
"application/json",
"application/xml",
"text/html",
"text/plain",
"application/x-www-form-urlencoded",
)
class HttpExecutorResponse:
@ -25,54 +38,37 @@ class HttpExecutorResponse:
def __init__(self, response: httpx.Response):
self.response = response
self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {}
self.headers = dict(response.headers)
@property
def is_file(self) -> bool:
"""
check if response is file
"""
content_type = self.get_content_type()
file_content_types = ["image", "audio", "video"]
def is_file(self):
content_type = self.content_type
content_disposition = self.response.headers.get("Content-Disposition", "")
return any(v in content_type for v in file_content_types)
def get_content_type(self) -> str:
return self.headers.get("content-type", "")
def extract_file(self) -> tuple[str, bytes]:
"""
extract file from response if content type is file related
"""
if self.is_file:
return self.get_content_type(), self.body
return "", b""
return "attachment" in content_disposition or (
not any(non_file in content_type for non_file in NON_FILE_CONTENT_TYPES)
and any(file_type in content_type for file_type in ("application/", "image/", "audio/", "video/"))
)
@property
def content(self) -> str:
if isinstance(self.response, httpx.Response):
return self.response.text
else:
raise ValueError(f"Invalid response type {type(self.response)}")
def content_type(self) -> str:
return self.headers.get("Content-Type", "")
@property
def body(self) -> bytes:
if isinstance(self.response, httpx.Response):
return self.response.content
else:
raise ValueError(f"Invalid response type {type(self.response)}")
def text(self) -> str:
return self.response.text
@property
def content(self) -> bytes:
return self.response.content
@property
def status_code(self) -> int:
if isinstance(self.response, httpx.Response):
return self.response.status_code
else:
raise ValueError(f"Invalid response type {type(self.response)}")
return self.response.status_code
@property
def size(self) -> int:
return len(self.body)
return len(self.content)
@property
def readable_size(self) -> str:
@ -85,152 +81,154 @@ class HttpExecutorResponse:
class HttpExecutor:
server_url: str
method: str
authorization: HttpRequestNodeAuthorization
params: dict[str, Any]
headers: dict[str, Any]
body: Union[None, str]
files: Union[None, dict[str, Any]]
boundary: str
variable_selectors: list[VariableSelector]
method: Literal["get", "head", "post", "put", "delete", "patch"]
url: str
params: Mapping[str, str] | None
content: str | bytes | None
data: Mapping[str, Any] | None
files: Mapping[str, bytes] | None
json: Any
headers: dict[str, str]
auth: HttpRequestNodeAuthorization
timeout: HttpRequestNodeTimeout
boundary: str
def __init__(
self,
*,
node_data: HttpRequestNodeData,
timeout: HttpRequestNodeTimeout,
variable_pool: Optional[VariablePool] = None,
variable_pool: VariablePool,
):
self.server_url = node_data.url
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
if node_data.authorization.config is None:
raise ValueError("authorization config is required")
node_data.authorization.config.api_key = variable_pool.convert_template(
node_data.authorization.config.api_key
).text
self.url: str = node_data.url
self.method = node_data.method
self.authorization = node_data.authorization
self.auth = node_data.authorization
self.timeout = timeout
self.params = {}
self.params = None
self.headers = {}
self.body = None
self.content = None
self.files = None
self.data = None
self.json = None
# init template
self.variable_selectors = []
self._init_template(node_data, variable_pool)
self.variable_pool = variable_pool
self.node_data = node_data
self._initialize()
@staticmethod
def _is_json_body(body: HttpRequestNodeBody):
"""
check if body is json
"""
if body and body.type == "json" and body.data:
try:
json.loads(body.data)
return True
except:
return False
def _initialize(self):
self._init_url()
self._init_params()
self._init_headers()
self._init_body()
return False
def _init_url(self):
self.url = self.variable_pool.convert_template(self.node_data.url).text
@staticmethod
def _to_dict(convert_text: str):
"""
Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}`
"""
kv_paris = convert_text.split("\n")
result = {}
for kv in kv_paris:
if not kv.strip():
continue
def _init_params(self):
params = self.variable_pool.convert_template(self.node_data.params).text
self.params = _plain_text_to_dict(params)
kv = kv.split(":", maxsplit=1)
if len(kv) == 1:
k, v = kv[0], ""
else:
k, v = kv
result[k.strip()] = v
return result
def _init_headers(self):
headers = self.variable_pool.convert_template(self.node_data.headers).text
self.headers = _plain_text_to_dict(headers)
def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
# extract all template in url
self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)
body = self.node_data.body
if body is None:
return
if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE:
self.headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type]
if body.type == "form-data":
self.boundary = f"----WebKitFormBoundary{_generate_random_string(16)}"
self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
# extract all template in params
params, params_variable_selectors = self._format_template(node_data.params, variable_pool)
self.params = self._to_dict(params)
def _init_body(self):
body = self.node_data.body
if body is not None:
data = body.data
match body.type:
case "none":
self.content = ""
case "raw-text":
self.content = self.variable_pool.convert_template(data[0].value).text
case "json":
json_object = json.loads(data[0].value)
self.json = self._parse_object_contains_variables(json_object)
case "binary":
file_selector = data[0].file
file_variable = self.variable_pool.get_file(file_selector)
if file_variable is None:
raise ValueError(f"cannot fetch file with selector {file_selector}")
file = file_variable.value
if file.related_id is None:
raise ValueError(f"file {file.related_id} not found")
self.content = file_manager.download(upload_file_id=file.related_id, tenant_id=file.tenant_id)
case "x-www-form-urlencoded":
form_data = {
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
item.value
).text
for item in data
}
self.data = form_data
case "form-data":
form_data = {
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
item.value
).text
for item in filter(lambda item: item.type == "text", data)
}
file_selectors = {
self.variable_pool.convert_template(item.key).text: item.file
for item in filter(lambda item: item.type == "file", data)
}
files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()}
files = {k: v for k, v in files.items() if v is not None}
files = {k: variable.value for k, variable in files.items()}
files = {
k: file_manager.download(upload_file_id=v.related_id, tenant_id=v.tenant_id)
for k, v in files.items()
if v.related_id is not None
}
# extract all template in headers
headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)
self.headers = self._to_dict(headers)
# extract all template in body
body_data_variable_selectors = []
if node_data.body:
# check if it's a valid JSON
is_valid_json = self._is_json_body(node_data.body)
body_data = node_data.body.data or ""
if body_data:
body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json)
content_type_is_set = any(key.lower() == "content-type" for key in self.headers)
if node_data.body.type == "json" and not content_type_is_set:
self.headers["Content-Type"] = "application/json"
elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set:
self.headers["Content-Type"] = "application/x-www-form-urlencoded"
if node_data.body.type in {"form-data", "x-www-form-urlencoded"}:
body = self._to_dict(body_data)
if node_data.body.type == "form-data":
self.files = {k: ("", v) for k, v in body.items()}
random_str = lambda n: "".join([chr(randint(97, 122)) for _ in range(n)])
self.boundary = f"----WebKitFormBoundary{random_str(16)}"
self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}"
else:
self.body = urlencode(body)
elif node_data.body.type in {"json", "raw-text"}:
self.body = body_data
elif node_data.body.type == "none":
self.body = ""
self.variable_selectors = (
server_url_variable_selectors
+ params_variable_selectors
+ headers_variable_selectors
+ body_data_variable_selectors
)
self.data = form_data
self.files = files
def _assembling_headers(self) -> dict[str, Any]:
authorization = deepcopy(self.authorization)
authorization = deepcopy(self.auth)
headers = deepcopy(self.headers) or {}
if self.authorization.type == "api-key":
if self.authorization.config is None:
if self.auth.type == "api-key":
if self.auth.config is None:
raise ValueError("self.authorization config is required")
if authorization.config is None:
raise ValueError("authorization config is required")
if self.authorization.config.api_key is None:
if self.auth.config.api_key is None:
raise ValueError("api_key is required")
if not authorization.config.header:
authorization.config.header = "Authorization"
if self.authorization.config.type == "bearer":
if self.auth.config.type == "bearer":
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
elif self.authorization.config.type == "basic":
elif self.auth.config.type == "basic":
headers[authorization.config.header] = f"Basic {authorization.config.api_key}"
elif self.authorization.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key
elif self.auth.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key or ""
return headers
def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutorResponse:
"""
validate the response
"""
if isinstance(response, httpx.Response):
executor_response = HttpExecutorResponse(response)
else:
raise ValueError(f"Invalid response type {type(response)}")
executor_response = HttpExecutorResponse(response)
threshold_size = (
dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE
@ -250,94 +248,133 @@ class HttpExecutor:
"""
do http request depending on api bundle
"""
kwargs = {
"url": self.server_url,
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
raise ValueError(f"Invalid http method {self.method}")
request_args = {
"url": self.url,
"data": self.data,
"files": self.files,
"json": self.json,
"content": self.content,
"headers": headers,
"params": self.params,
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"follow_redirects": True,
}
if self.method in {"get", "head", "post", "put", "delete", "patch"}:
response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
else:
raise ValueError(f"Invalid http method {self.method}")
response = getattr(ssrf_proxy, self.method)(**request_args)
return response
def invoke(self) -> HttpExecutorResponse:
"""
invoke http request
"""
# assemble headers
headers = self._assembling_headers()
# do http request
response = self._do_http_request(headers)
# validate response
return self._validate_and_parse_response(response)
def to_raw_request(self) -> str:
"""
convert to raw request
"""
server_url = self.server_url
def to_log(self):
url = self.url
if self.params:
server_url += f"?{urlencode(self.params)}"
raw_request = f"{self.method.upper()} {server_url} HTTP/1.1\n"
url += f"?{urlencode(self.params)}"
raw = f"{self.method.upper()} {url} HTTP/1.1\n"
headers = self._assembling_headers()
for k, v in headers.items():
# get authorization header
if self.authorization.type == "api-key":
if self.auth.type == "api-key":
authorization_header = "Authorization"
if self.authorization.config and self.authorization.config.header:
authorization_header = self.authorization.config.header
if self.auth.config and self.auth.config.header:
authorization_header = self.auth.config.header
if k.lower() == authorization_header.lower():
raw_request += f'{k}: {"*" * len(v)}\n'
raw += f'{k}: {"*" * len(v)}\n'
continue
raw += f"{k}: {v}\n"
raw += "\n"
raw_request += f"{k}: {v}\n"
raw_request += "\n"
# if files, use multipart/form-data with boundary
if self.files:
# if files, use multipart/form-data with boundary
boundary = self.boundary
raw_request += f"--{boundary}"
raw += f"--{boundary}"
for k, v in self.files.items():
raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n'
raw_request += f"{v[1]}\n"
raw_request += f"--{boundary}"
raw_request += "--"
else:
raw_request += self.body or ""
raw += f'\nContent-Disposition: form-data; name="{k}"\n\n'
raw += f"{v[1]}\n"
raw += f"--{boundary}"
raw += "--"
elif self.node_data.body:
if self.content:
# for binary content
if isinstance(self.content, str):
raw += self.content
elif isinstance(self.content, bytes):
raw += self.content.decode("utf-8", errors="replace")
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
# for x-www-form-urlencoded
raw += urlencode(self.data)
elif self.data and self.node_data.body.type == "form-data":
# for form-data
boundary = self.boundary
for key, value in self.data.items():
raw += f"--{boundary}\r\n"
raw += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
raw += f"{value}\r\n"
raw += f"--{boundary}--\r\n"
elif self.json:
# for json
raw += json.dumps(self.json)
elif self.node_data.body.type == "raw-text":
# for raw text
raw += self.node_data.body.data[0].value
return raw_request
return raw
def _format_template(
self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False
) -> tuple[str, list[VariableSelector]]:
"""
format template
"""
variable_template_parser = VariableTemplateParser(template=template)
variable_selectors = variable_template_parser.extract_variable_selectors()
def _parse_object_contains_variables(self, obj: str | dict | list, /) -> Mapping[str, Any] | Sequence[Any] | str:
if isinstance(obj, dict):
return {k: self._parse_object_contains_variables(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._parse_object_contains_variables(v) for v in obj]
elif isinstance(obj, str):
return self.variable_pool.convert_template(obj).text
if variable_pool:
variable_value_mapping = {}
for variable_selector in variable_selectors:
variable = variable_pool.get_any(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
if escape_quotes and isinstance(variable, str):
value = variable.replace('"', '\\"').replace("\n", "\\n")
else:
value = variable
variable_value_mapping[variable_selector.variable] = value
return variable_template_parser.format(variable_value_mapping), variable_selectors
else:
return template, variable_selectors
def _plain_text_to_dict(text: str, /) -> dict[str, str]:
"""
Convert a string of key-value pairs to a dictionary.
Each line in the input string represents a key-value pair.
Keys and values are separated by ':'.
Empty values are allowed.
Examples:
'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'}
'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'}
'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'}
Args:
convert_text (str): The input string to convert.
Returns:
dict[str, str]: A dictionary of key-value pairs.
"""
return {
key.strip(): (value[0].strip() if value else "")
for line in text.splitlines()
if line.strip()
for key, *value in [line.split(":", 1)]
}
def _generate_random_string(n: int) -> str:
"""
Generate a random string of lowercase ASCII letters.
Args:
n (int): The length of the random string to generate.
Returns:
str: A random string of lowercase ASCII letters with length n.
Example:
>>> _generate_random_string(5)
'abcde'
"""
return "".join([chr(randint(97, 122)) for _ in range(n)])

View File

@ -2,19 +2,21 @@ import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_extension
from os import path
from typing import Any, cast
from typing import Any
from configs import dify_config
from core.app.segments import parser
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.file import File, FileTransferMethod, FileType
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse
from core.workflow.utils import variable_template_parser
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
@ -23,8 +25,10 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
)
logger = logging.getLogger(__name__)
class HttpRequestNode(BaseNode):
class HttpRequestNode(BaseNode[HttpRequestNodeData]):
_node_data_cls = HttpRequestNodeData
_node_type = NodeType.HTTP_REQUEST
@ -48,51 +52,37 @@ class HttpRequestNode(BaseNode):
}
def _run(self) -> NodeRunResult:
node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
# TODO: Switch to use segment directly
if node_data.authorization.config and node_data.authorization.config.api_key:
node_data.authorization.config.api_key = parser.convert_template(
template=node_data.authorization.config.api_key, variable_pool=self.graph_runtime_state.variable_pool
).text
# init http executor
http_executor = None
process_data = {}
try:
http_executor = HttpExecutor(
node_data=node_data,
timeout=self._get_request_timeout(node_data),
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
)
process_data["request"] = http_executor.to_log()
# invoke http executor
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"status_code": response.status_code,
"body": response.text if not files else "",
"headers": response.headers,
"files": files,
},
process_data={
"request": http_executor.to_log(),
},
)
except Exception as e:
process_data = {}
if http_executor:
process_data = {
"request": http_executor.to_raw_request(),
}
logger.warning(f"http request node {self.node_id} failed to run: {e}")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
process_data=process_data,
)
files = self.extract_files(http_executor.server_url, response)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"status_code": response.status_code,
"body": response.content if not files else "",
"headers": response.headers,
"files": files,
},
process_data={
"request": http_executor.to_raw_request(),
},
)
@staticmethod
def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout:
timeout = node_data.timeout
@ -106,59 +96,76 @@ class HttpRequestNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: HttpRequestNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: HttpRequestNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
try:
http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)
selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
if node_data.body:
body_type = node_data.body.type
data = node_data.body.data
match body_type:
case "binary":
selector = data[0].file
selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector))
case "json" | "raw-text":
selectors += variable_template_parser.extract_selectors_from_template(data[0].key)
selectors += variable_template_parser.extract_selectors_from_template(data[0].value)
case "x-www-form-urlencoded":
for item in data:
selectors += variable_template_parser.extract_selectors_from_template(item.key)
selectors += variable_template_parser.extract_selectors_from_template(item.value)
case "form-data":
for item in data:
selectors += variable_template_parser.extract_selectors_from_template(item.key)
if item.type == "text":
selectors += variable_template_parser.extract_selectors_from_template(item.value)
elif item.type == "file":
selectors.append(
VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file)
)
variable_selectors = http_executor.variable_selectors
mapping = {}
for selector in selectors:
mapping[node_id + "." + selector.variable] = selector.value_selector
variable_mapping = {}
for variable_selector in variable_selectors:
variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector
return mapping
return variable_mapping
except Exception as e:
logging.exception(f"Failed to extract variable selector to variable mapping: {e}")
return {}
def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]:
def extract_files(self, url: str, response: HttpExecutorResponse) -> list[File]:
"""
Extract files from response
"""
files = []
mimetype, file_binary = response.extract_file()
content_type = response.content_type
content = response.content
if mimetype:
if content_type:
# extract filename from url
filename = path.basename(url)
# extract extension if possible
extension = guess_extension(mimetype) or ".bin"
extension = guess_extension(content_type) or ".bin"
tool_file = ToolFileManager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
file_binary=file_binary,
mimetype=mimetype,
file_binary=content,
mimetype=content_type,
)
files.append(
FileVar(
File(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file.id,
filename=filename,
extension=extension,
mime_type=mimetype,
mime_type=content_type,
)
)

View File

@ -1,6 +1,6 @@
from typing import Literal, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.utils.condition.entities import Condition
@ -21,6 +21,6 @@ class IfElseNodeData(BaseNodeData):
conditions: list[Condition]
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = None
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
cases: Optional[list[Case]] = None

View File

@ -1,14 +1,19 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any, Literal
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from typing_extensions import deprecated
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
class IfElseNode(BaseNode):
class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE
@ -17,9 +22,6 @@ class IfElseNode(BaseNode):
Run node
:return:
"""
node_data = self.node_data
node_data = cast(IfElseNodeData, node_data)
node_inputs: dict[str, list] = {"conditions": []}
process_datas: dict[str, list] = {"condition_results": []}
@ -30,15 +32,14 @@ class IfElseNode(BaseNode):
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
if node_data.cases:
for case in node_data.cases:
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions
if self.node_data.cases:
for case in self.node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions,
operator=case.logical_operator,
)
# Apply the logical operator for the current case
final_result = all(group_result) if case.logical_operator == "and" else any(group_result)
process_datas["condition_results"].append(
{
"group": case.model_dump(),
@ -53,13 +54,15 @@ class IfElseNode(BaseNode):
break
else:
# TODO: Update database then remove this
# Fallback to old structure if cases are not defined
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool, conditions=node_data.conditions
input_conditions, group_result, final_result = _should_not_use_old_function(
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
conditions=self.node_data.conditions or [],
operator=self.node_data.logical_operator or "and",
)
final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)
selected_case_id = "true" if final_result else "false"
process_datas["condition_results"].append(
@ -87,7 +90,11 @@ class IfElseNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IfElseNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -97,3 +104,18 @@ class IfElseNode(BaseNode):
:return:
"""
return {}
@deprecated("This function is deprecated. You should use the new cases structure.")
def _should_not_use_old_function(
    *,
    condition_processor: ConditionProcessor,
    variable_pool: VariablePool,
    conditions: list[Condition],
    operator: Literal["and", "or"],
):
    """Legacy fallback for `IfElseNodeData.conditions` (pre-`cases` schema).

    Delegates unchanged to ConditionProcessor.process_conditions and returns
    its (input_conditions, group_result, final_result) tuple. Kept only until
    old workflow records are migrated to the `cases` structure.
    """
    return condition_processor.process_conditions(
        variable_pool=variable_pool,
        conditions=conditions,
        operator=operator,
    )

View File

@ -5,7 +5,7 @@ from typing import Any, cast
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
@ -23,12 +23,13 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
class IterationNode(BaseNode):
class IterationNode(BaseNode[IterationNodeData]):
"""
Iteration Node.
"""
@ -40,7 +41,6 @@ class IterationNode(BaseNode):
"""
Run the node.
"""
self.node_data = cast(IterationNodeData, self.node_data)
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not iterator_list_segment:
@ -177,7 +177,7 @@ class IterationNode(BaseNode):
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove_node(node_id)
variable_pool.remove([node_id])
# move to next iteration
current_index = variable_pool.get([self.node_id, "index"])
@ -247,7 +247,11 @@ class IterationNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping

View File

@ -1,9 +1,10 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus

View File

@ -14,9 +14,10 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from enums import NodeType
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus
@ -32,15 +33,13 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(BaseNode):
class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
_node_data_cls = KnowledgeRetrievalNodeData
node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
def _run(self) -> NodeRunResult:
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector)
variable = self.graph_runtime_state.variable_pool.get_any(self.node_data.query_variable_selector)
query = variable
variables = {"query": query}
if not query:
@ -49,7 +48,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
outputs = {"result": results}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
@ -244,7 +243,11 @@ class KnowledgeRetrievalNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: KnowledgeRetrievalNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: KnowledgeRetrievalNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping

View File

@ -0,0 +1,3 @@
from .node import ListFilterNode
__all__ = ["ListFilterNode"]

View File

@ -0,0 +1,51 @@
from collections.abc import Sequence
from typing import Literal
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
# Closed set of filter operators accepted by FilterBy.comparison_operator.
# Fixed: the "≤" and "≥" literals had been lost (rendered as empty strings),
# making those operators unrepresentable; restored to match the case labels
# in _get_number_filter_func.
_Condition = Literal[
    # string conditions
    "contains",
    "startswith",
    "endswith",
    "is",
    "in",
    "empty",
    "not contains",
    "not is",
    "not in",
    "not empty",
    # number conditions
    "=",
    "!=",
    "<",
    ">",
    "≤",
    "≥",
]
class FilterBy(BaseModel):
    # Attribute key of the item being filtered (used for File attributes such
    # as "name" or "size"; ignored for plain string/number arrays).
    key: str = ""
    # One of the _Condition literals above.
    comparison_operator: _Condition = "contains"
    # Comparison operand; a sequence is used for membership-style tests.
    value: str | Sequence[str] = ""
class OrderBy(BaseModel):
    # When False, the node skips ordering entirely.
    enabled: bool = False
    # File attribute to sort by when ordering ArrayFileSegment values.
    key: str = ""
    # Sort direction.
    value: Literal["asc", "desc"] = "asc"
class Limit(BaseModel):
    # When False, no slicing is applied.
    enabled: bool = False
    # Maximum number of records kept; -1 is the "unset" sentinel.
    size: int = -1
class ListFilterNodeData(BaseNodeData):
    # Variable selector pointing at the array variable to process.
    variable: Sequence[str] = Field(default_factory=list)
    # Filters applied in order; each one narrows the current array.
    filter_by: Sequence[FilterBy]
    order_by: OrderBy
    limit: Limit

View File

@ -0,0 +1,258 @@
from collections.abc import Callable, Sequence
from typing import Literal
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
from enums.workflow_nodes import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .models import ListFilterNodeData
class ListFilterNode(BaseNode[ListFilterNodeData]):
    """Workflow node that filters, orders and truncates an array variable.

    Accepts ArrayFileSegment, ArrayNumberSegment or ArrayStringSegment input
    and outputs the resulting list plus its first and last records.
    """

    _node_data_cls = ListFilterNodeData
    _node_type = NodeType.LIST_FILter if False else NodeType.LIST_FILTER  # noqa: E501  # (see note)

    def _run(self):
        inputs = {}
        process_data = {}
        outputs = {}
        # Resolve the input array from the variable pool.
        variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
        if variable is None:
            error_message = f"Variable not found for selector: {self.node_data.variable}"
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
            )
        # NOTE(review): the type check is skipped when `variable.value` is
        # empty/falsy — presumably deliberate so empty arrays of any segment
        # type pass through; confirm.
        if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
            error_message = (
                f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
                "or ArrayStringSegment"
            )
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
            )
        # Record the pre-filter input; Files are serialized for readability.
        if isinstance(variable, ArrayFileSegment):
            process_data["variable"] = [item.to_dict() for item in variable.value]
        else:
            process_data["variable"] = variable.value
        # Filter: each FilterBy narrows the current array. String filter
        # values may be templates and are resolved against the variable pool.
        for filter_by in self.node_data.filter_by:
            if isinstance(variable, ArrayStringSegment):
                if not isinstance(filter_by.value, str):
                    raise ValueError(f"Invalid filter value: {filter_by.value}")
                value = self.graph_runtime_state.variable_pool.convert_template(filter_by.value).text
                filter_func = _get_string_filter_func(condition=filter_by.comparison_operator, value=value)
                result = list(filter(filter_func, variable.value))
                variable = variable.model_copy(update={"value": result})
            elif isinstance(variable, ArrayNumberSegment):
                if not isinstance(filter_by.value, str):
                    raise ValueError(f"Invalid filter value: {filter_by.value}")
                value = self.graph_runtime_state.variable_pool.convert_template(filter_by.value).text
                filter_func = _get_number_filter_func(condition=filter_by.comparison_operator, value=float(value))
                result = list(filter(filter_func, variable.value))
                variable = variable.model_copy(update={"value": result})
            elif isinstance(variable, ArrayFileSegment):
                if isinstance(filter_by.value, str):
                    value = self.graph_runtime_state.variable_pool.convert_template(filter_by.value).text
                else:
                    value = filter_by.value
                filter_func = _get_file_filter_func(
                    key=filter_by.key,
                    condition=filter_by.comparison_operator,
                    value=value,
                )
                result = list(filter(filter_func, variable.value))
                variable = variable.model_copy(update={"value": result})
        # Order
        if self.node_data.order_by.enabled:
            if isinstance(variable, ArrayStringSegment):
                result = _order_string(order=self.node_data.order_by.value, array=variable.value)
                variable = variable.model_copy(update={"value": result})
            elif isinstance(variable, ArrayNumberSegment):
                result = _order_number(order=self.node_data.order_by.value, array=variable.value)
                variable = variable.model_copy(update={"value": result})
            elif isinstance(variable, ArrayFileSegment):
                result = _order_file(
                    order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
                )
                variable = variable.model_copy(update={"value": result})
        # Slice: keep at most `limit.size` leading records.
        if self.node_data.limit.enabled:
            result = variable.value[: self.node_data.limit.size]
            variable = variable.model_copy(update={"value": result})
        outputs = {
            "result": variable.value,
            "first_record": variable.value[0] if variable.value else None,
            "last_record": variable.value[-1] if variable.value else None,
        }
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=inputs,
            process_data=process_data,
            outputs=outputs,
        )
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
    """Return an accessor that extracts a numeric attribute from a File.

    Only "size" is supported; any other key raises ValueError.
    """
    if key == "size":
        return lambda file: file.size
    raise ValueError(f"Invalid key: {key}")
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
match key:
case "name":
return lambda x: x.filename or ""
case "type":
return lambda x: x.type
case "extension":
return lambda x: x.extension or ""
case "mimetype":
return lambda x: x.mime_type or ""
case "transfer_method":
return lambda x: x.transfer_method
case "urL":
return lambda x: x.remote_url or ""
case _:
raise ValueError(f"Invalid key: {key}")
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
match condition:
case "contains":
return _contains(value)
case "startswith":
return _startswith(value)
case "endswith":
return _endswith(value)
case "is":
return _is(value)
case "in":
return _in(value)
case "empty":
return lambda x: x == ""
case "not contains":
return lambda x: not _contains(value)(x)
case "not is":
return lambda x: not _is(value)(x)
case "not in":
return lambda x: not _in(value)(x)
case "not empty":
return lambda x: x != ""
case _:
raise ValueError(f"Invalid condition: {condition}")
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
    """Return a membership predicate over *value* for "in" / "not in".

    Raises ValueError for any other condition.
    """
    if condition == "in":
        return lambda item: item in value
    if condition == "not in":
        return lambda item: item not in value
    raise ValueError(f"Invalid condition: {condition}")
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
match condition:
case "=":
return _eq(value)
case "!=":
return _ne(value)
case "<":
return _lt(value)
case "":
return _le(value)
case ">":
return _gt(value)
case "":
return _ge(value)
case _:
raise ValueError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
    """Build a File predicate by composing an attribute extractor with a
    string / sequence / number condition applied to the extracted value.

    Raises ValueError when the key (or key/value-type combination) is not
    supported.
    """
    # NOTE(review): these key spellings must line up with the case labels in
    # _get_file_extract_string_func; as written that helper handles
    # "mimetype"/"urL", so "mime_type"/"url" here raise ValueError from the
    # extractor — confirm the intended spellings.
    if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
        extract_func = _get_file_extract_string_func(key=key)
        return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
    if key in {"type", "transfer_method"} and isinstance(value, Sequence):
        extract_func = _get_file_extract_string_func(key=key)
        return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
    elif key == "size" and isinstance(value, str):
        # Size filters arrive as strings (template output) and are coerced.
        extract_func = _get_file_extract_number_func(key=key)
        return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
    else:
        raise ValueError(f"Invalid key: {key}")
def _contains(value: str):
    """Predicate factory: true when *value* occurs within the argument."""
    def check(candidate):
        return value in candidate
    return check
def _startswith(value: str):
    """Predicate factory: true when the argument starts with *value*."""
    def check(candidate):
        return candidate.startswith(value)
    return check
def _endswith(value: str):
    """Predicate factory: true when the argument ends with *value*."""
    def check(candidate):
        return candidate.endswith(value)
    return check
def _is(value: str):
    """Predicate factory: true when the argument equals *value*.

    Fixed: the original compared with `is` (object identity), which is
    unreliable for strings — equal but non-interned string objects compared
    False; use `==` equality instead.
    """
    return lambda x: x == value
def _in(value: str | Sequence[str]):
return lambda x: x in value
def _eq(value: int | float):
return lambda x: x == value
def _ne(value: int | float):
return lambda x: x != value
def _lt(value: int | float):
return lambda x: x < value
def _le(value: int | float):
return lambda x: x <= value
def _gt(value: int | float):
return lambda x: x > value
def _ge(value: int | float):
return lambda x: x >= value
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
    """Sort strings ascending or descending.

    The identity `key=lambda x: x` in the original was redundant and has
    been dropped; sorting behavior is unchanged.
    """
    return sorted(array, reverse=order == "desc")
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
    """Sort Files by the attribute named in *order_by*.

    String attributes and the numeric "size" attribute are handled
    separately; unknown keys raise ValueError.
    """
    # NOTE(review): "urL" and "mime_type" here must line up with the case
    # labels in _get_file_extract_string_func ("mimetype", "urL"); as written,
    # order_by="mime_type" passes this check but raises ValueError inside the
    # extractor — confirm the intended key spellings.
    if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "urL"}:
        extract_func = _get_file_extract_string_func(key=order_by)
        return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
    elif order_by == "size":
        extract_func = _get_file_extract_number_func(key=order_by)
        return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
    else:
        raise ValueError(f"Invalid order key: {order_by}")

View File

@ -0,0 +1,18 @@
from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from .llm_node import LLMNode, ModelInvokeCompleted
__all__ = [
"LLMNode",
"ModelInvokeCompleted",
"LLMNodeChatModelMessage",
"LLMNodeCompletionModelPromptTemplate",
"LLMNodeData",
"ModelConfig",
"VisionConfig",
]

View File

@ -1,17 +1,15 @@
from typing import Any, Literal, Optional, Union
from collections.abc import Sequence
from typing import Any, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.model_runtime.entities import ImagePromptMessageContent
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class ModelConfig(BaseModel):
"""
Model Config.
"""
provider: str
name: str
mode: str
@ -19,62 +17,36 @@ class ModelConfig(BaseModel):
class ContextConfig(BaseModel):
"""
Context Config.
"""
enabled: bool
variable_selector: Optional[list[str]] = None
class VisionConfigOptions(BaseModel):
variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH
class VisionConfig(BaseModel):
"""
Vision Config.
"""
class Configs(BaseModel):
"""
Configs.
"""
detail: Literal["low", "high"]
enabled: bool
configs: Optional[Configs] = None
enabled: bool = False
configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)
class PromptConfig(BaseModel):
"""
Prompt Config.
"""
jinja2_variables: Optional[list[VariableSelector]] = None
class LLMNodeChatModelMessage(ChatModelMessage):
"""
LLM Node Chat Model Message.
"""
jinja2_text: Optional[str] = None
class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
"""
LLM Node Chat Model Prompt Template.
"""
jinja2_text: Optional[str] = None
class LLMNodeData(BaseNodeData):
"""
LLM Node Data.
"""
model: ModelConfig
prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate]
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: Optional[PromptConfig] = None
memory: Optional[MemoryConfig] = None
context: ContextConfig
vision: VisionConfig
vision: VisionConfig = Field(default_factory=VisionConfig)

View File

@ -1,6 +1,5 @@
import json
from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, cast
from pydantic import BaseModel
@ -23,26 +22,28 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import (
from core.workflow.nodes.llm import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from enums import NodeType
from extensions.ext_database import db
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING:
from core.file.file_obj import FileVar
from core.file.models import File
class ModelInvokeCompleted(BaseModel):
@ -55,30 +56,23 @@ class ModelInvokeCompleted(BaseModel):
finish_reason: Optional[str] = None
class LLMNode(BaseNode):
class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node
:return:
"""
node_data = cast(LLMNodeData, deepcopy(self.node_data))
variable_pool = self.graph_runtime_state.variable_pool
def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
node_inputs = None
process_data = None
try:
# init messages template
node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template)
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data, variable_pool)
inputs = self._fetch_inputs(node_data=self.node_data)
# fetch jinja2 inputs
jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool)
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
# merge inputs
inputs.update(jinja_inputs)
@ -86,13 +80,17 @@ class LLMNode(BaseNode):
node_inputs = {}
# fetch files
files = self._fetch_files(node_data, variable_pool)
files = (
self._fetch_files(selector=self.node_data.vision.configs.variable_selector)
if self.node_data.vision.enabled
else []
)
if files:
node_inputs["#files#"] = [file.to_dict() for file in files]
# fetch context value
generator = self._fetch_context(node_data, variable_pool)
generator = self._fetch_context(node_data=self.node_data)
context = None
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
@ -103,21 +101,30 @@ class LLMNode(BaseNode):
node_inputs["#context#"] = context # type: ignore
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
model_instance, model_config = self._fetch_model_config(self.node_data.model)
# fetch memory
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)
# fetch prompt messages
if self.node_data.memory:
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
if not query:
raise ValueError("Query not found")
query = query.text
else:
query = None
prompt_messages, stop = self._fetch_prompt_messages(
node_data=node_data,
query=variable_pool.get_any(["sys", SystemVariableKey.QUERY.value]) if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
system_query=query,
inputs=inputs,
files=files,
context=context,
memory=memory,
model_config=model_config,
vision_detail=self.node_data.vision.configs.detail,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
)
process_data = {
@ -131,7 +138,7 @@ class LLMNode(BaseNode):
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
node_data_model=self.node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
@ -183,14 +190,6 @@ class LLMNode(BaseNode):
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Invoke large language model
:param node_data_model: node data model
:param model_instance: model instance
:param prompt_messages: prompt messages
:param stop: stop
:return:
"""
db.session.close()
invoke_result = model_instance.invoke_llm(
@ -216,11 +215,6 @@ class LLMNode(BaseNode):
def _handle_invoke_result(
self, invoke_result: LLMResult | Generator
) -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Handle invoke result
:param invoke_result: invoke result
:return:
"""
if isinstance(invoke_result, LLMResult):
return
@ -253,15 +247,8 @@ class LLMNode(BaseNode):
yield ModelInvokeCompleted(text=full_text, usage=usage, finish_reason=finish_reason)
def _transform_chat_messages(
self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
"""
Transform chat messages
:param messages: chat messages
:return:
"""
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
if messages.edition_type == "jinja2" and messages.jinja2_text:
messages.text = messages.jinja2_text
@ -274,13 +261,7 @@ class LLMNode(BaseNode):
return messages
def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
"""
Fetch jinja inputs
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables = {}
if not node_data.prompt_config:
@ -288,7 +269,7 @@ class LLMNode(BaseNode):
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable = variable_selector.variable
value = variable_pool.get_any(variable_selector.value_selector)
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
def parse_dict(d: dict) -> str:
"""
@ -330,13 +311,7 @@ class LLMNode(BaseNode):
return variables
def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
"""
Fetch inputs
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
inputs = {}
prompt_template = node_data.prompt_template
@ -350,7 +325,7 @@ class LLMNode(BaseNode):
variable_selectors = variable_template_parser.extract_variable_selectors()
for variable_selector in variable_selectors:
variable_value = variable_pool.get_any(variable_selector.value_selector)
variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
if variable_value is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
@ -362,7 +337,7 @@ class LLMNode(BaseNode):
template=memory.query_prompt_template
).extract_variable_selectors()
for variable_selector in query_variable_selectors:
variable_value = variable_pool.get_any(variable_selector.value_selector)
variable_value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
if variable_value is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
@ -370,36 +345,28 @@ class LLMNode(BaseNode):
return inputs
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
"""
Fetch files
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
if not node_data.vision.enabled:
def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is None:
return []
files = variable_pool.get_any(["sys", SystemVariableKey.FILES.value])
if not files:
if isinstance(variable, FileSegment):
return [variable.value]
if isinstance(variable, ArrayFileSegment):
return variable.value
# FIXME: Temporary fix for empty array,
# all variables added to variable pool should be a Segment instance.
if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0:
return []
raise ValueError(f"Invalid variable type: {type(variable)}")
return files
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
"""
Fetch context
:param node_data: node data
:param variable_pool: variable pool
:return:
"""
def _fetch_context(self, node_data: LLMNodeData) -> Generator[RunEvent, None, None]:
if not node_data.context.enabled:
return
if not node_data.context.variable_selector:
return
context_value = variable_pool.get_any(node_data.context.variable_selector)
context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector)
if context_value:
if isinstance(context_value, str):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
@ -424,11 +391,6 @@ class LLMNode(BaseNode):
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
"""
Convert to original retriever resource, temp.
:param context_dict: context dict
:return:
"""
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
@ -451,6 +413,7 @@ class LLMNode(BaseNode):
"segment_position": metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"),
"page": metadata.get("page"),
}
return source
@ -460,11 +423,6 @@ class LLMNode(BaseNode):
def _fetch_model_config(
self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
:param node_data_model: node data model
:return:
"""
model_name = node_data_model.name
provider_name = node_data_model.provider
@ -523,19 +481,15 @@ class LLMNode(BaseNode):
)
def _fetch_memory(
self, node_data_memory: Optional[MemoryConfig], variable_pool: VariablePool, model_instance: ModelInstance
self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
) -> Optional[TokenBufferMemory]:
"""
Fetch memory
:param node_data_memory: node data memory
:param variable_pool: variable pool
:return:
"""
if not node_data_memory:
return None
# get conversation id
conversation_id = variable_pool.get_any(["sys", SystemVariableKey.CONVERSATION_ID.value])
conversation_id = self.graph_runtime_state.variable_pool.get_any(
["sys", SystemVariableKey.CONVERSATION_ID.value]
)
if conversation_id is None:
return None
@ -555,43 +509,31 @@ class LLMNode(BaseNode):
def _fetch_prompt_messages(
self,
node_data: LLMNodeData,
query: Optional[str],
query_prompt_template: Optional[str],
inputs: dict[str, str],
files: list["FileVar"],
context: Optional[str],
memory: Optional[TokenBufferMemory],
*,
system_query: str | None = None,
inputs: dict[str, str] | None = None,
files: Sequence["File"],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
memory_config: MemoryConfig | None = None,
vision_detail: ImagePromptMessageContent.DETAIL,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Fetch prompt messages
:param node_data: node data
:param query: query
:param query_prompt_template: query prompt template
:param inputs: inputs
:param files: files
:param context: context
:param memory: memory
:param model_config: model config
:return:
"""
inputs = inputs or {}
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
prompt_messages = prompt_transform.get_prompt(
prompt_template=node_data.prompt_template,
prompt_template=prompt_template,
inputs=inputs,
query=query or "",
query=system_query or "",
files=files,
context=context,
memory_config=node_data.memory,
memory_config=memory_config,
memory=memory,
model_config=model_config,
query_prompt_template=query_prompt_template,
)
stop = model_config.stop
vision_enabled = node_data.vision.enabled
vision_detail = node_data.vision.configs.detail if node_data.vision.configs else None
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if prompt_message.is_empty():
@ -599,15 +541,11 @@ class LLMNode(BaseNode):
if not isinstance(prompt_message.content, str):
prompt_message_content = []
for content_item in prompt_message.content:
if (
vision_enabled
and content_item.type == PromptMessageContentType.IMAGE
and isinstance(content_item, ImagePromptMessageContent)
):
# Override vision config if LLM node has vision config
if vision_detail:
content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
for content_item in prompt_message.content or []:
if isinstance(content_item, ImagePromptMessageContent):
# Override vision config if LLM node has vision config,
# cuz vision detail is related to the configuration from FileUpload feature.
content_item.detail = vision_detail
prompt_message_content.append(content_item)
elif content_item.type == PromptMessageContentType.TEXT:
prompt_message_content.append(content_item)
@ -631,13 +569,6 @@ class LLMNode(BaseNode):
@classmethod
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
"""
Deduct LLM quota
:param tenant_id: tenant id
:param model_instance: model instance
:param usage: usage
:return:
"""
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
@ -668,7 +599,7 @@ class LLMNode(BaseNode):
else:
used_quota = 1
if used_quota is not None:
if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == model_instance.provider,
@ -680,27 +611,28 @@ class LLMNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: LLMNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LLMNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
prompt_template = node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list):
if isinstance(prompt_template, list) and all(
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
):
for prompt in prompt_template:
if prompt.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt.text)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
else:
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors()
else:
raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")
variable_mapping = {}
for variable_selector in variable_selectors:
@ -745,11 +677,6 @@ class LLMNode(BaseNode):
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
Get default config of node.
:param filters: filter by node config parameters.
:return:
"""
return {
"type": "llm",
"config": {

View File

@ -1,12 +1,12 @@
from typing import Any
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
from core.workflow.utils.condition.entities import Condition
from enums import NodeType
class LoopNode(BaseNode):
class LoopNode(BaseNode[LoopNodeData]):
"""
Loop Node.
"""

View File

@ -1,12 +1,13 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.list_filter import ListFilterNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
@ -15,6 +16,7 @@ from core.workflow.nodes.template_transform.template_transform_node import Templ
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from core.workflow.nodes.variable_assigner import VariableAssignerNode
from enums import NodeType
node_classes = {
NodeType.START: StartNode,
@ -34,4 +36,6 @@ node_classes = {
NodeType.ITERATION_START: IterationStartNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode,
NodeType.LIST_FILTER: ListFilterNode,
}

View File

@ -1,20 +1,10 @@
from typing import Any, Literal, Optional
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
class ModelConfig(BaseModel):
"""
Model Config.
"""
provider: str
name: str
mode: str
completion_params: dict[str, Any] = {}
from core.workflow.nodes.llm import ModelConfig, VisionConfig
class ParameterConfig(BaseModel):
@ -49,6 +39,7 @@ class ParameterExtractorNodeData(BaseNodeData):
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
reasoning_mode: Literal["function_call", "prompt"]
vision: VisionConfig = Field(default_factory=VisionConfig)
@field_validator("reasoning_mode", mode="before")
@classmethod
@ -64,7 +55,7 @@ class ParameterExtractorNodeData(BaseNodeData):
parameters = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema = {"description": parameter.description}
parameter_schema: dict[str, Any] = {"description": parameter.description}
if parameter.type in {"string", "select"}:
parameter_schema["type"] = "string"

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@ -22,12 +23,17 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.parameter_extractor.prompts import (
from core.workflow.utils import variable_template_parser
from enums import NodeType
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ParameterExtractorNodeData
from .prompts import (
CHAT_EXAMPLE,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
COMPLETION_GENERATE_JSON_PROMPT,
@ -36,9 +42,6 @@ from core.workflow.nodes.parameter_extractor.prompts import (
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT,
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
class ParameterExtractorNode(LLMNode):
@ -65,33 +68,39 @@ class ParameterExtractorNode(LLMNode):
}
}
def _run(self) -> NodeRunResult:
def _run(self):
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self.node_data)
variable = self.graph_runtime_state.variable_pool.get_any(node_data.query)
if not variable:
raise ValueError("Input variable content not found or is empty")
query = variable
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""
inputs = {
"query": query,
"parameters": jsonable_encoder(node_data.parameters),
"instruction": jsonable_encoder(node_data.instruction),
}
files = (
self._fetch_files(
selector=node_data.vision.configs.variable_selector,
)
if node_data.vision.enabled
else []
)
model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise ValueError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
model_schema = llm_model.get_model_schema(
model=model_config.model,
credentials=model_config.credentials,
)
if not model_schema:
raise ValueError("Model schema not found")
# fetch memory
memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance)
memory = self._fetch_memory(
node_data_memory=node_data.memory,
model_instance=model_instance,
)
if (
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
@ -99,15 +108,33 @@ class ParameterExtractorNode(LLMNode):
):
# use function call
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
node_data=node_data,
query=query,
variable_pool=self.graph_runtime_state.variable_pool,
model_config=model_config,
memory=memory,
files=files,
)
else:
# use prompt engineering
prompt_messages = self._generate_prompt_engineering_prompt(
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
data=node_data,
query=query,
variable_pool=self.graph_runtime_state.variable_pool,
model_config=model_config,
memory=memory,
files=files,
)
prompt_message_tools = []
inputs = {
"query": query,
"files": [f.to_dict() for f in files],
"parameters": jsonable_encoder(node_data.parameters),
"instruction": jsonable_encoder(node_data.instruction),
}
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
@ -119,7 +146,7 @@ class ParameterExtractorNode(LLMNode):
}
try:
text, usage, tool_call = self._invoke_llm(
text, usage, tool_call = self._invoke(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
@ -150,12 +177,12 @@ class ParameterExtractorNode(LLMNode):
error = "Failed to extract result from function call or text response, using empty result."
try:
result = self._validate_result(node_data, result)
result = self._validate_result(data=node_data, result=result or {})
except Exception as e:
error = str(e)
# transform result into standard format
result = self._transform_result(node_data, result)
result = self._transform_result(data=node_data, result=result or {})
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -170,7 +197,7 @@ class ParameterExtractorNode(LLMNode):
llm_usage=usage,
)
def _invoke_llm(
def _invoke(
self,
node_data_model: ModelConfig,
model_instance: ModelInstance,
@ -178,14 +205,6 @@ class ParameterExtractorNode(LLMNode):
tools: list[PromptMessageTool],
stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
"""
Invoke large language model
:param node_data_model: node data model
:param model_instance: model instance
:param prompt_messages: prompt messages
:param stop: stop
:return:
"""
db.session.close()
invoke_result = model_instance.invoke_llm(
@ -202,6 +221,9 @@ class ParameterExtractorNode(LLMNode):
raise ValueError(f"Invalid invoke result: {invoke_result}")
text = invoke_result.message.content
if not isinstance(text, str):
raise ValueError(f"Invalid text content type: {type(text)}. Expected str.")
usage = invoke_result.usage
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
@ -217,6 +239,7 @@ class ParameterExtractorNode(LLMNode):
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
files: Sequence[File],
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
"""
Generate function call prompt.
@ -234,7 +257,7 @@ class ParameterExtractorNode(LLMNode):
prompt_template=prompt_template,
inputs={},
query="",
files=[],
files=files,
context="",
memory_config=node_data.memory,
memory=None,
@ -296,6 +319,7 @@ class ParameterExtractorNode(LLMNode):
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
files: Sequence[File],
) -> list[PromptMessage]:
"""
Generate prompt engineering prompt.
@ -303,9 +327,23 @@ class ParameterExtractorNode(LLMNode):
model_mode = ModelMode.value_of(data.model.mode)
if model_mode == ModelMode.COMPLETION:
return self._generate_prompt_engineering_completion_prompt(data, query, variable_pool, model_config, memory)
return self._generate_prompt_engineering_completion_prompt(
node_data=data,
query=query,
variable_pool=variable_pool,
model_config=model_config,
memory=memory,
files=files,
)
elif model_mode == ModelMode.CHAT:
return self._generate_prompt_engineering_chat_prompt(data, query, variable_pool, model_config, memory)
return self._generate_prompt_engineering_chat_prompt(
node_data=data,
query=query,
variable_pool=variable_pool,
model_config=model_config,
memory=memory,
files=files,
)
else:
raise ValueError(f"Invalid model mode: {model_mode}")
@ -316,20 +354,23 @@ class ParameterExtractorNode(LLMNode):
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
files: Sequence[File],
) -> list[PromptMessage]:
"""
Generate completion prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
rest_token = self._calculate_rest_token(
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
)
prompt_template = self._get_prompt_engineering_prompt_template(
node_data, query, variable_pool, memory, rest_token
node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={"structure": json.dumps(node_data.get_parameter_json_schema())},
query="",
files=[],
files=files,
context="",
memory_config=node_data.memory,
memory=memory,
@ -345,27 +386,30 @@ class ParameterExtractorNode(LLMNode):
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: Optional[TokenBufferMemory],
files: Sequence[File],
) -> list[PromptMessage]:
"""
Generate chat prompt.
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
rest_token = self._calculate_rest_token(
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
)
prompt_template = self._get_prompt_engineering_prompt_template(
node_data,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
node_data=node_data,
query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
structure=json.dumps(node_data.get_parameter_json_schema()), text=query
),
variable_pool,
memory,
rest_token,
variable_pool=variable_pool,
memory=memory,
max_token_limit=rest_token,
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=[],
files=files,
context="",
memory_config=node_data.memory,
memory=None,
@ -425,10 +469,11 @@ class ParameterExtractorNode(LLMNode):
raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
if parameter.type.startswith("array"):
if not isinstance(result.get(parameter.name), list):
parameters = result.get(parameter.name)
if not isinstance(parameters, list):
raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
nested_type = parameter.type[6:-1]
for item in result.get(parameter.name):
for item in parameters:
if nested_type == "number" and not isinstance(item, int | float):
raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
if nested_type == "string" and not isinstance(item, str):
@ -565,18 +610,6 @@ class ParameterExtractorNode(LLMNode):
return result
def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
"""
Render instruction.
"""
variable_template_parser = VariableTemplateParser(instruction)
inputs = {}
for selector in variable_template_parser.extract_variable_selectors():
variable = variable_pool.get_any(selector.value_selector)
inputs[selector.variable] = variable
return variable_template_parser.format(inputs)
def _get_function_calling_prompt_template(
self,
node_data: ParameterExtractorNodeData,
@ -588,9 +621,9 @@ class ParameterExtractorNode(LLMNode):
model_mode = ModelMode.value_of(node_data.model.mode)
input_text = query
memory_str = ""
instruction = self._render_instruction(node_data.instruction or "", variable_pool)
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory:
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
@ -611,13 +644,13 @@ class ParameterExtractorNode(LLMNode):
variable_pool: VariablePool,
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
):
model_mode = ModelMode.value_of(node_data.model.mode)
input_text = query
memory_str = ""
instruction = self._render_instruction(node_data.instruction or "", variable_pool)
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory:
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
@ -691,7 +724,7 @@ class ParameterExtractorNode(LLMNode):
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
@ -712,7 +745,11 @@ class ParameterExtractorNode(LLMNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ParameterExtractorNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ParameterExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -721,11 +758,11 @@ class ParameterExtractorNode(LLMNode):
:param node_data: node data
:return:
"""
variable_mapping = {"query": node_data.query}
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
for selector in variable_template_parser.extract_variable_selectors():
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}

View File

@ -0,0 +1,3 @@
from .entities import QuestionClassifierNodeData
__all__ = ["QuestionClassifierNodeData"]

View File

@ -1,39 +1,21 @@
from typing import Any, Optional
from typing import Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
class ModelConfig(BaseModel):
"""
Model Config.
"""
provider: str
name: str
mode: str
completion_params: dict[str, Any] = {}
from core.workflow.nodes.llm import ModelConfig, VisionConfig
class ClassConfig(BaseModel):
"""
Class Config.
"""
id: str
name: str
class QuestionClassifierNodeData(BaseNodeData):
"""
Knowledge retrieval Node Data.
"""
query_variable_selector: list[str]
type: str = "question-classifier"
model: ModelConfig
classes: list[ClassConfig]
instruction: Optional[str] = None
memory: Optional[MemoryConfig] = None
vision: VisionConfig = Field(default_factory=VisionConfig)

View File

@ -1,25 +1,30 @@
import json
import logging
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.question_classifier.template_prompts import (
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.nodes.llm import (
LLMNode,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
ModelInvokeCompleted,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from enums import NodeType
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.workflow import WorkflowNodeExecutionStatus
from .entities import QuestionClassifierNodeData
from .template_prompts import (
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
QUESTION_CLASSIFIER_COMPLETION_PROMPT,
@ -28,39 +33,70 @@ from core.workflow.nodes.question_classifier.template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_2,
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING:
from core.file import File
class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData
node_type = NodeType.QUESTION_CLASSIFIER
_node_type = NodeType.QUESTION_CLASSIFIER
def _run(self) -> NodeRunResult:
node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
node_data = cast(QuestionClassifierNodeData, node_data)
def _run(self):
node_data = cast(QuestionClassifierNodeData, self.node_data)
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
variable = variable_pool.get(node_data.query_variable_selector)
variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
query = variable.value if variable else None
variables = {"query": query}
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
# fetch memory
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
memory = self._fetch_memory(
node_data_memory=node_data.memory,
model_instance=model_instance,
)
# fetch instruction
instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else ""
node_data.instruction = instruction
node_data.instruction = node_data.instruction or ""
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
files: Sequence[File] = (
self._fetch_files(
selector=node_data.vision.configs.variable_selector,
)
if node_data.vision.enabled
else []
)
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt(
node_data=node_data, context="", query=query, memory=memory, model_config=model_config
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query or "",
model_config=model_config,
context="",
)
prompt_template = self._get_prompt_template(
node_data=node_data,
query=query or "",
memory=memory,
max_token_limit=rest_token,
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
system_query=query,
memory=memory,
model_config=model_config,
files=files,
vision_detail=node_data.vision.configs.detail,
)
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)
result_text = ""
@ -129,7 +165,11 @@ class QuestionClassifierNode(LLMNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: QuestionClassifierNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: QuestionClassifierNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -159,40 +199,6 @@ class QuestionClassifierNode(LLMNode):
"""
return {"type": "question-classifier", "config": {"instructions": ""}}
def _fetch_prompt(
self,
node_data: QuestionClassifierNodeData,
query: str,
context: Optional[str],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Fetch prompt
:param node_data: node data
:param query: inputs
:param context: context
:param memory: memory
:param model_config: model config
:return:
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(node_data, query, model_config, context)
prompt_template = self._get_prompt_template(node_data, query, memory, rest_token)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query="",
files=[],
context=context,
memory_config=node_data.memory,
memory=None,
model_config=model_config,
)
stop = model_config.stop
return prompt_messages, stop
def _calculate_rest_token(
self,
node_data: QuestionClassifierNodeData,
@ -229,7 +235,7 @@ class QuestionClassifierNode(LLMNode):
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
@ -243,7 +249,7 @@ class QuestionClassifierNode(LLMNode):
query: str,
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
) -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
):
model_mode = ModelMode.value_of(node_data.model.mode)
classes = node_data.classes
categories = []
@ -255,31 +261,32 @@ class QuestionClassifierNode(LLMNode):
memory_str = ""
if memory:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
)
prompt_messages = []
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = ChatModelMessage(
user_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1
)
prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = ChatModelMessage(
assistant_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = ChatModelMessage(
user_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = ChatModelMessage(
assistant_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
)
prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = ChatModelMessage(
user_prompt_message_3 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER,
text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(
input_text=input_text,
@ -290,7 +297,7 @@ class QuestionClassifierNode(LLMNode):
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
return CompletionModelPromptTemplate(
return LLMNodeCompletionModelPromptTemplate(
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(
histories=memory_str,
input_text=input_text,
@ -302,23 +309,3 @@ class QuestionClassifierNode(LLMNode):
else:
raise ValueError(f"Model mode {model_mode} not support.")
def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
inputs = {}
variable_selectors = []
variable_template_parser = VariableTemplateParser(template=instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable = variable_pool.get(variable_selector.value_selector)
variable_value = variable.value if variable else None
if variable_value is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
inputs[variable_selector.variable] = variable_value
prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
instruction = prompt_template.format(prompt_inputs)
return instruction

View File

@ -1,25 +1,24 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
class StartNode(BaseNode):
class StartNode(BaseNode[StartNodeData]):
_node_data_cls = StartNodeData
_node_type = NodeType.START
def _run(self) -> NodeRunResult:
"""
Run node
:return:
"""
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables
# TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs.
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
@ -27,13 +26,10 @@ class StartNode(BaseNode):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: StartNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: StartNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}

View File

@ -1,17 +1,18 @@
import os
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
class TemplateTransformNode(BaseNode):
class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
_node_data_cls = TemplateTransformNodeData
_node_type = NodeType.TEMPLATE_TRANSFORM
@ -28,22 +29,16 @@ class TemplateTransformNode(BaseNode):
}
def _run(self) -> NodeRunResult:
"""
Run node
"""
node_data = self.node_data
node_data: TemplateTransformNodeData = cast(self._node_data_cls, node_data)
# Get variables
variables = {}
for variable_selector in node_data.variables:
for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
variables[variable_name] = value
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))

View File

@ -51,7 +51,4 @@ class ToolNodeData(BaseNodeData, ToolEntity):
raise ValueError("value must be a string, int, float, or bool")
return typ
"""
Tool Node Schema
"""
tool_parameters: dict[str, ToolInput]

View File

@ -1,24 +1,23 @@
from collections.abc import Mapping, Sequence
from os import path
from typing import Any, cast
from typing import Any
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.file.models import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models import WorkflowNodeExecutionStatus
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
class ToolNode(BaseNode):
class ToolNode(BaseNode[ToolNodeData]):
"""
Tool Node
"""
@ -27,37 +26,38 @@ class ToolNode(BaseNode):
_node_type = NodeType.TOOL
def _run(self) -> NodeRunResult:
"""
Run the tool node
"""
node_data = cast(ToolNodeData, self.node_data)
# fetch tool icon
tool_info = {"provider_type": node_data.provider_type, "provider_id": node_data.provider_id}
tool_info = {
"provider_type": self.node_data.provider_type,
"provider_id": self.node_data.provider_id,
}
# get tool runtime
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to get tool runtime: {str(e)}",
)
# get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or []
parameters = self._generate_parameters(
tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
node_data=self.node_data,
for_log=True,
)
@ -74,7 +74,9 @@ class ToolNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to invoke tool: {str(e)}",
)
@ -83,8 +85,14 @@ class ToolNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": plain_text, "files": files, "json": json},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
outputs={
"text": plain_text,
"files": files,
"json": json,
},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
inputs=parameters_for_log,
)
@ -116,29 +124,25 @@ class ToolNode(BaseNode):
if not parameter:
result[parameter_name] = None
continue
if parameter.type == ToolParameter.ToolParameterType.FILE:
result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)]
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ValueError(f"variable {tool_input.value} not exists")
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
# TODO: check if the variable exists in the variable pool
parameter_value = variable_pool.get(tool_input.value).value
else:
segment_group = parser.convert_template(
template=str(tool_input.value),
variable_pool=variable_pool,
)
parameter_value = segment_group.log if for_log else segment_group.text
result[parameter_name] = parameter_value
raise ValueError(f"unknown tool input type '{tool_input.type}'")
result[parameter_name] = parameter_value
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar], list[dict]]:
def _convert_tool_messages(
self,
messages: list[ToolInvokeMessage],
):
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
@ -156,39 +160,38 @@ class ToolNode(BaseNode):
return plain_text, files, json
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[File]:
"""
Extract tool response binary
"""
result = []
for response in tool_response:
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
url = response.message
ext = path.splitext(url)[1]
mimetype = response.meta.get("mime_type", "image/jpeg")
filename = response.save_as or url.split("/")[-1]
tool_file_id = response.save_as or url.split("/")[-1]
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
# get tool file id
tool_file_id = url.split("/")[-1].split(".")[0]
tool_file_id = str(url).split("/")[-1].split(".")[0]
result.append(
FileVar(
File(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
url=url,
remote_url=url,
related_id=tool_file_id,
filename=filename,
filename=tool_file_id,
extension=ext,
mime_type=mimetype,
)
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
tool_file_id = response.message.split("/")[-1].split(".")[0]
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
result.append(
FileVar(
File(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
@ -199,7 +202,29 @@ class ToolNode(BaseNode):
)
)
elif response.type == ToolInvokeMessage.MessageType.LINK:
pass # TODO:
url = str(response.message)
transfer_method = FileTransferMethod.TOOL_FILE
mimetype = response.meta.get("mime_type", "application/octet-stream")
tool_file_id = url.split("/")[-1].split(".")[0]
if "." in url:
extension = "." + url.split("/")[-1].split(".")[1]
else:
extension = ".bin"
file = File(
tenant_id=self.tenant_id,
type=FileType(response.save_as),
transfer_method=transfer_method,
remote_url=url,
filename=tool_file_id,
related_id=tool_file_id,
extension=extension,
mime_type=mimetype,
)
result.append(file)
elif response.type == ToolInvokeMessage.MessageType.FILE:
assert response.meta is not None
result.append(response.meta["file"])
return result
@ -218,12 +243,16 @@ class ToolNode(BaseNode):
]
)
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]):
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -236,7 +265,7 @@ class ToolNode(BaseNode):
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == "mixed":
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":

View File

@ -1,24 +1,24 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
from enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
class VariableAggregatorNode(BaseNode):
class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR
def _run(self) -> NodeRunResult:
node_data = cast(VariableAssignerNodeData, self.node_data)
# Get variables
outputs = {}
inputs = {}
if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
for selector in node_data.variables:
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get_any(selector)
if variable is not None:
outputs = {"output": variable}
@ -26,7 +26,7 @@ class VariableAggregatorNode(BaseNode):
inputs = {".".join(selector[1:]): variable}
break
else:
for group in node_data.advanced_settings.groups:
for group in self.node_data.advanced_settings.groups:
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get_any(selector)

View File

@ -1,40 +1,39 @@
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.segments import SegmentType, Variable, factory
from core.variables import SegmentType, Variable
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
from enums import NodeType
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus
from factories import variable_factory
from models import ConversationVariable
from models.workflow import WorkflowNodeExecutionStatus
from .exc import VariableAssignerNodeError
from .node_data import VariableAssignerData, WriteMode
class VariableAssignerNode(BaseNode):
class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError("assigned variable not found")
match data.write_mode:
match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]
@ -45,10 +44,10 @@ class VariableAssignerNode(BaseNode):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableAssignerNodeError(f"unsupported write mode: {data.write_mode}")
raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
@ -80,12 +79,12 @@ def update_conversation_variable(conversation_id: str, variable: Variable):
def get_zero_value(t: SegmentType):
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return factory.build_segment([])
return variable_factory.build_segment([])
case SegmentType.OBJECT:
return factory.build_segment({})
return variable_factory.build_segment({})
case SegmentType.STRING:
return factory.build_segment("")
return variable_factory.build_segment("")
case SegmentType.NUMBER:
return factory.build_segment(0)
return variable_factory.build_segment(0)
case _:
raise VariableAssignerNodeError(f"unsupported variable type: {t}")

View File

@ -1,32 +1,46 @@
from typing import Literal, Optional
from collections.abc import Sequence
from typing import Literal
from pydantic import BaseModel
from pydantic import BaseModel, Field
SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
"in",
"not in",
"all of",
# for number
"=",
"",
">",
"<",
"",
"",
"null",
"not null",
]
class SubCondition(BaseModel):
key: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None = None
class SubVariableCondition(BaseModel):
logical_operator: Literal["and", "or"]
conditions: list[SubCondition] = Field(default=list)
class Condition(BaseModel):
"""
Condition entity
"""
variable_selector: list[str]
comparison_operator: Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
# for number
"=",
"",
">",
"<",
"",
"",
"null",
"not null",
]
value: Optional[str] = None
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None = None
sub_variable_condition: SubVariableCondition | None = None

View File

@ -1,381 +1,362 @@
from collections.abc import Sequence
from typing import Any, Optional
from typing import Any, Literal
from core.file.file_obj import FileVar
from core.file import FileAttribute, file_manager
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from .entities import Condition, SubCondition, SupportedComparisonOperator
class ConditionProcessor:
def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
input_conditions = []
group_result = []
index = 0
for condition in conditions:
index += 1
actual_value = variable_pool.get_any(condition.variable_selector)
expected_value = None
if condition.value is not None:
variable_template_parser = VariableTemplateParser(template=condition.value)
variable_selectors = variable_template_parser.extract_variable_selectors()
if variable_selectors:
for variable_selector in variable_selectors:
value = variable_pool.get_any(variable_selector.value_selector)
expected_value = variable_template_parser.format({variable_selector.variable: value})
if expected_value is None:
expected_value = condition.value
else:
expected_value = condition.value
comparison_operator = condition.comparison_operator
input_conditions.append(
{
"actual_value": actual_value,
"expected_value": expected_value,
"comparison_operator": comparison_operator,
}
)
result = self.evaluate_condition(actual_value, comparison_operator, expected_value)
group_result.append(result)
return input_conditions, group_result
def evaluate_condition(
def process_conditions(
self,
actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None],
comparison_operator: str,
expected_value: Optional[str] = None,
) -> bool:
"""
Evaluate condition
:param actual_value: actual value
:param expected_value: expected value
:param comparison_operator: comparison operator
*,
variable_pool: VariablePool,
conditions: Sequence[Condition],
operator: Literal["and", "or"],
):
input_conditions = []
group_results = []
:return: bool
"""
if comparison_operator == "contains":
return self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains":
return self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with":
return self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with":
return self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is":
return self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not":
return self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty":
return self._assert_empty(actual_value)
elif comparison_operator == "not empty":
return self._assert_not_empty(actual_value)
elif comparison_operator == "=":
return self._assert_equal(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">":
return self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<":
return self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null":
return self._assert_null(actual_value)
elif comparison_operator == "not null":
return self._assert_not_null(actual_value)
else:
raise ValueError(f"Invalid comparison operator: {comparison_operator}")
for condition in conditions:
variable = variable_pool.get(condition.variable_selector)
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if isinstance(variable, ArrayFileSegment) and condition.comparison_operator in {
"contains",
"not contains",
"all of",
}:
# check sub conditions
if not condition.sub_variable_condition:
raise ValueError("Sub variable is required")
result = _process_sub_conditions(
variable=variable,
sub_conditions=condition.sub_variable_condition.conditions,
operator=condition.sub_variable_condition.logical_operator,
)
else:
actual_value = variable.value if variable else None
expected_value = condition.value
if isinstance(expected_value, str):
expected_value = variable_pool.convert_template(expected_value).text
input_conditions.append(
{
"actual_value": actual_value,
"expected_value": expected_value,
"comparison_operator": condition.comparison_operator,
}
)
result = _evaluate_condition(
value=actual_value,
operator=condition.comparison_operator,
expected=expected_value,
)
group_results.append(result)
if not isinstance(actual_value, str | list):
raise ValueError("Invalid actual value type: string or array")
final_result = all(group_results) if operator == "and" else any(group_results)
return input_conditions, group_results, final_result
if expected_value not in actual_value:
return False
return True
def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert not contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return True
def _evaluate_condition(
*,
operator: SupportedComparisonOperator,
value: Any,
expected: str | Sequence[str] | None,
) -> bool:
match operator:
case "contains":
return _assert_contains(value=value, expected=expected)
case "not contains":
return _assert_not_contains(value=value, expected=expected)
case "start with":
return _assert_start_with(value=value, expected=expected)
case "end with":
return _assert_end_with(value=value, expected=expected)
case "is":
return _assert_is(value=value, expected=expected)
case "is not":
return _assert_is_not(value=value, expected=expected)
case "empty":
return _assert_empty(value=value)
case "not empty":
return _assert_not_empty(value=value)
case "=":
return _assert_equal(value=value, expected=expected)
case "":
return _assert_not_equal(value=value, expected=expected)
case ">":
return _assert_greater_than(value=value, expected=expected)
case "<":
return _assert_less_than(value=value, expected=expected)
case "":
return _assert_greater_than_or_equal(value=value, expected=expected)
case "":
return _assert_less_than_or_equal(value=value, expected=expected)
case "null":
return _assert_null(value=value)
case "not null":
return _assert_not_null(value=value)
case "in":
return _assert_in(value=value, expected=expected)
case "not in":
return _assert_not_in(value=value, expected=expected)
case "all of" if isinstance(expected, list):
return _assert_all_of(value=value, expected=expected)
case _:
raise ValueError(f"Unsupported operator: {operator}")
if not isinstance(actual_value, str | list):
raise ValueError("Invalid actual value type: string or array")
if expected_value in actual_value:
return False
return True
def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert start with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError("Invalid actual value type: string")
if not actual_value.startswith(expected_value):
return False
return True
def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert end with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError("Invalid actual value type: string")
if not actual_value.endswith(expected_value):
return False
return True
def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError("Invalid actual value type: string")
if actual_value != expected_value:
return False
return True
def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is not
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError("Invalid actual value type: string")
if actual_value == expected_value:
return False
return True
def _assert_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert empty
:param actual_value: actual value
:return:
"""
if not actual_value:
return True
def _assert_contains(*, value: Any, expected: Any) -> bool:
if not value:
return False
def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert not empty
:param actual_value: actual value
:return:
"""
if actual_value:
return True
if not isinstance(value, str | list):
raise ValueError("Invalid actual value type: string or array")
if expected not in value:
return False
return True
def _assert_not_contains(*, value: Any, expected: Any) -> bool:
if not value:
return True
if not isinstance(value, str | list):
raise ValueError("Invalid actual value type: string or array")
if expected in value:
return False
return True
def _assert_start_with(*, value: Any, expected: Any) -> bool:
if not value:
return False
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")
if not isinstance(actual_value, int | float):
raise ValueError("Invalid actual value type: number")
if not value.startswith(expected):
return False
return True
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value != expected_value:
return False
return True
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert not equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError("Invalid actual value type: number")
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value == expected_value:
return False
return True
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert greater than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError("Invalid actual value type: number")
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value <= expected_value:
return False
return True
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert less than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError("Invalid actual value type: number")
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value >= expected_value:
return False
return True
def _assert_greater_than_or_equal(
self, actual_value: Optional[int | float], expected_value: str | int | float
) -> bool:
"""
Assert greater than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError("Invalid actual value type: number")
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value < expected_value:
return False
return True
def _assert_less_than_or_equal(
self, actual_value: Optional[int | float], expected_value: str | int | float
) -> bool:
"""
Assert less than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError("Invalid actual value type: number")
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value > expected_value:
return False
return True
def _assert_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert null
:param actual_value: actual value
:return:
"""
if actual_value is None:
return True
def _assert_end_with(*, value: Any, expected: Any) -> bool:
if not value:
return False
def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert not null
:param actual_value: actual value
:return:
"""
if actual_value is not None:
return True
if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")
if not value.endswith(expected):
return False
return True
def _assert_is(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")
class ConditionAssertionError(Exception):
    """Raised when condition evaluation fails; carries the evaluated context for reporting."""

    def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None:
        super().__init__(message)
        # Keep the raw inputs so callers can inspect which conditions failed.
        self.message = message
        self.conditions = conditions
        self.sub_condition_compare_results = sub_condition_compare_results
if value != expected:
return False
return True
def _assert_is_not(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")
if value == expected:
return False
return True
def _assert_empty(*, value: Any) -> bool:
if not value:
return True
return False
def _assert_not_empty(*, value: Any) -> bool:
if value:
return True
return False
def _assert_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
expected = int(expected)
else:
expected = float(expected)
if value != expected:
return False
return True
def _assert_not_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
expected = int(expected)
else:
expected = float(expected)
if value == expected:
return False
return True
def _assert_greater_than(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
expected = int(expected)
else:
expected = float(expected)
if value <= expected:
return False
return True
def _assert_less_than(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
expected = int(expected)
else:
expected = float(expected)
if value >= expected:
return False
return True
def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
expected = int(expected)
else:
expected = float(expected)
if value < expected:
return False
return True
def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
if value is None:
return False
if not isinstance(value, int | float):
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
expected = int(expected)
else:
expected = float(expected)
if value > expected:
return False
return True
def _assert_null(*, value: Any) -> bool:
if value is None:
return True
return False
def _assert_not_null(*, value: Any) -> bool:
if value is not None:
return True
return False
def _assert_in(*, value: Any, expected: Any) -> bool:
if not value:
return False
if not isinstance(expected, list):
raise ValueError("Invalid expected value type: array")
if value not in expected:
return False
return True
def _assert_not_in(*, value: Any, expected: Any) -> bool:
if not value:
return True
if not isinstance(expected, list):
raise ValueError("Invalid expected value type: array")
if value in expected:
return False
return True
def _assert_all_of(*, value: Any, expected: Sequence[str]) -> bool:
if not value:
return False
if not all(item in value for item in expected):
return False
return True
def _process_sub_conditions(
    variable: ArrayFileSegment,
    sub_conditions: Sequence[SubCondition],
    operator: Literal["and", "or"],
) -> bool:
    """Evaluate every sub-condition against each file in *variable* and combine
    the per-condition results with *operator* ("and" requires all to hold,
    "or" requires at least one)."""
    files = variable.value
    group_results = []
    for condition in sub_conditions:
        # Resolve which file attribute this condition inspects.
        key = FileAttribute(condition.key)
        # Extract that attribute from every file in the segment.
        values = [file_manager.get_attr(file=file, attr=key) for file in files]
        # Evaluate the condition's comparison against each extracted value.
        sub_group_results = [
            _evaluate_condition(
                value=value,
                operator=condition.comparison_operator,
                expected=condition.value,
            )
            for value in values
        ]
        # Determine the result based on the presence of "not" in the comparison operator
        # (negated operators must hold for ALL files; positive ones for ANY file).
        result = all(sub_group_results) if "not" in condition.comparison_operator else any(sub_group_results)
        group_results.append(result)
    return all(group_results) if operator == "and" else any(group_results)

View File

@ -1,42 +1,21 @@
import re
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")
def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str:
    """
    This is an alternative to the VariableTemplateParser class,
    offering the same functionality but with better readability and ease of use.

    Substitutes '{{#node_id.path#}}' placeholders in *template* with values
    looked up from *variable_pool*, then strips '<|...|>' control sequences.
    """
    # Collect the unique placeholder keys (e.g. '#node_id.query.name#') found in the template.
    variable_keys = [match[0] for match in re.findall(REGEX, template)]
    variable_keys = list(set(variable_keys))
    # This key_selector is a tuple of (key, selector) where selector is a list of keys
    # e.g. ('#node_id.query.name#', ['node_id', 'query', 'name'])
    # Only selectors with at least two segments (node id + variable path) are valid.
    key_selectors = filter(
        lambda t: len(t[1]) >= 2,
        ((key, selector.replace("#", "").split(".")) for key, selector in zip(variable_keys, variable_keys)),
    )
    # Resolve each selector once up front so the replacer below is a pure lookup.
    inputs = {key: variable_pool.get_any(selector) for key, selector in key_selectors}

    def replacer(match):
        key = match.group(1)
        # return original matched string if key not found
        value = inputs.get(key, match.group(0))
        if value is None:
            value = ""
        value = str(value)
        # remove template variables if required
        # (placeholders inside a resolved value are neutralized to '{...}' so they are not re-expanded)
        return re.sub(REGEX, r"{\1}", value)

    result = re.sub(REGEX, replacer, template)
    # Strip '<|...|>' special-token sequences from the final output.
    result = re.sub(r"<\|.*?\|>", "", result)
    return result
def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]:
    """Collect the variable selectors (e.g. '#node_id.path#') referenced in *template*."""
    found = []
    for token in SELECTOR_PATTERN.split(template):
        if not token:
            continue
        # A selector token is '#'-delimited and contains at least one dot-separated segment pair.
        if token.startswith("#") and token.endswith("#") and "." in token:
            found.append(VariableSelector(variable=f"{token}", value_selector=token[1:-1].split(".")))
    return found
class VariableTemplateParser:

View File

@ -8,10 +8,9 @@ from configs import dify_config
from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.file.models import File, FileTransferMethod, FileType, ImageConfig
from core.workflow.callbacks import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
@ -23,6 +22,7 @@ from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunEvent
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.node_mapping import node_classes
from enums import NodeType, UserFrom
from models.workflow import (
Workflow,
WorkflowType,
@ -205,32 +205,27 @@ class WorkflowEntry:
except Exception as e:
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
@classmethod
def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
"""
Handle special values
:param value: value
:return:
"""
if not value:
return None
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
return WorkflowEntry._handle_special_values(value)
new_value = dict(value) if value else {}
if isinstance(new_value, dict):
for key, val in new_value.items():
if isinstance(val, FileVar):
new_value[key] = val.to_dict()
elif isinstance(val, list):
new_val = []
for v in val:
if isinstance(v, FileVar):
new_val.append(v.to_dict())
else:
new_val.append(v)
new_value[key] = new_val
return new_value
@staticmethod
def _handle_special_values(value: Any) -> Any:
    """Recursively walk dicts and lists, converting any File instance found
    into its dict representation; every other value passes through unchanged."""
    if value is None:
        return value
    if isinstance(value, dict):
        # Recurse into each entry, preserving keys.
        return {k: WorkflowEntry._handle_special_values(v) for k, v in value.items()}
    if isinstance(value, list):
        # Recurse into each element, preserving order.
        return [WorkflowEntry._handle_special_values(item) for item in value]
    if isinstance(value, File):
        return value.to_dict()
    return value
@classmethod
def mapping_user_inputs_to_variable_pool(
@ -276,15 +271,19 @@ class WorkflowEntry:
for item in input_value:
if isinstance(item, dict) and "type" in item and item["type"] == "image":
transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
file = FileVar(
file = File(
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
url=item.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
remote_url=item.get("url")
if transfer_method == FileTransferMethod.REMOTE_URL
else None,
related_id=item.get("upload_file_id")
if transfer_method == FileTransferMethod.LOCAL_FILE
else None,
extra_config=FileExtraConfig(image_config={"detail": detail} if detail else None),
_extra_config=FileExtraConfig(
image_config=ImageConfig(detail=detail) if detail else None
),
)
new_value.append(file)