mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 14:08:18 +08:00
Merge remote-tracking branch 'upstream/main' into feat/human-input-merge-again
This commit is contained in:
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
@ -111,7 +113,7 @@ class DefaultValue(BaseModel):
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> "DefaultValue":
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators = {
|
||||
DefaultValueType.STRING: {
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import operator
|
||||
@ -72,7 +74,7 @@ class Node(Generic[NodeDataT]):
|
||||
in its output.
|
||||
"""
|
||||
|
||||
node_type: ClassVar["NodeType"]
|
||||
node_type: ClassVar[NodeType]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||
|
||||
@ -211,14 +213,14 @@ class Node(Generic[NodeDataT]):
|
||||
return None
|
||||
|
||||
# Global registry populated via __init_subclass__
|
||||
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
|
||||
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> None:
|
||||
self._graph_init_params = graph_init_params
|
||||
self.id = id
|
||||
@ -254,7 +256,7 @@ class Node(Generic[NodeDataT]):
|
||||
return
|
||||
|
||||
@property
|
||||
def graph_init_params(self) -> "GraphInitParams":
|
||||
def graph_init_params(self) -> GraphInitParams:
|
||||
return self._graph_init_params
|
||||
|
||||
@property
|
||||
@ -300,6 +302,10 @@ class Node(Generic[NodeDataT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _should_stop(self) -> bool:
|
||||
"""Check if execution should be stopped."""
|
||||
return self.graph_runtime_state.stop_event.is_set()
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
@ -368,6 +374,21 @@ class Node(Generic[NodeDataT]):
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
|
||||
if self._should_stop():
|
||||
error_message = "Execution cancelled"
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
),
|
||||
error=error_message,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
@ -474,7 +495,7 @@ class Node(Generic[NodeDataT]):
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@classmethod
|
||||
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
|
||||
def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]:
|
||||
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
||||
|
||||
Import all modules under core.workflow.nodes so subclasses register themselves on import.
|
||||
@ -484,12 +505,8 @@ class Node(Generic[NodeDataT]):
|
||||
import core.workflow.nodes as _nodes_pkg
|
||||
|
||||
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
|
||||
# Avoid importing modules that depend on the registry to prevent circular imports
|
||||
# e.g. node_factory imports node_mapping which builds the mapping here.
|
||||
if _modname in {
|
||||
"core.workflow.nodes.node_factory",
|
||||
"core.workflow.nodes.node_mapping",
|
||||
}:
|
||||
# Avoid importing modules that depend on the registry to prevent circular imports.
|
||||
if _modname == "core.workflow.nodes.node_mapping":
|
||||
continue
|
||||
importlib.import_module(_modname)
|
||||
|
||||
|
||||
@ -4,6 +4,8 @@ This module provides a unified template structure for both Answer and End nodes,
|
||||
similar to SegmentGroup but focused on template representation without values.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
@ -58,7 +60,7 @@ class Template:
|
||||
segments: list[TemplateSegmentUnion]
|
||||
|
||||
@classmethod
|
||||
def from_answer_template(cls, template_str: str) -> "Template":
|
||||
def from_answer_template(cls, template_str: str) -> Template:
|
||||
"""Create a Template from an Answer node template string.
|
||||
|
||||
Example:
|
||||
@ -107,7 +109,7 @@ class Template:
|
||||
return cls(segments=segments)
|
||||
|
||||
@classmethod
|
||||
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> "Template":
|
||||
def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template:
|
||||
"""Create a Template from an End node outputs configuration.
|
||||
|
||||
End nodes are treated as templates of concatenated variables with newlines.
|
||||
|
||||
Reference in New Issue
Block a user