Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

# Conflicts:
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/pyproject.toml
#	api/uv.lock
#	docker/docker-compose-template.yaml
#	docker/docker-compose.yaml
#	web/package.json
This commit is contained in:
jyong
2025-09-04 20:30:08 +08:00
87 changed files with 1930 additions and 393 deletions

View File

@ -1,6 +1,7 @@
from copy import deepcopy
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, PrivateAttr
from core.model_runtime.entities.llm_entities import LLMUsage
@ -8,21 +9,127 @@ from .variable_pool import VariablePool
class GraphRuntimeState(BaseModel):
variable_pool: VariablePool = Field(..., description="variable pool")
"""variable pool"""
# Private attributes to prevent direct modification
_variable_pool: VariablePool = PrivateAttr()
_start_at: float = PrivateAttr()
_total_tokens: int = PrivateAttr(default=0)
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
_outputs: dict[str, Any] = PrivateAttr(default_factory=dict)
_node_run_steps: int = PrivateAttr(default=0)
start_at: float = Field(..., description="start time")
"""start time"""
total_tokens: int = 0
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
"""llm usage info"""
def __init__(
self,
variable_pool: VariablePool,
start_at: float,
total_tokens: int = 0,
llm_usage: LLMUsage | None = None,
outputs: dict[str, Any] | None = None,
node_run_steps: int = 0,
**kwargs,
):
"""Initialize the GraphRuntimeState with validation."""
super().__init__(**kwargs)
# The `outputs` field stores the final output values generated by executing workflows or chatflows.
#
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
# after a serialization and deserialization round trip.
outputs: dict[str, Any] = Field(default_factory=dict)
# Initialize private attributes with validation
self._variable_pool = variable_pool
node_run_steps: int = 0
"""node run steps"""
self._start_at = start_at
if total_tokens < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = total_tokens
if llm_usage is None:
llm_usage = LLMUsage.empty_usage()
self._llm_usage = llm_usage
if outputs is None:
outputs = {}
self._outputs = deepcopy(outputs)
if node_run_steps < 0:
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = node_run_steps
@property
def variable_pool(self) -> VariablePool:
"""Get the variable pool."""
return self._variable_pool
@property
def start_at(self) -> float:
"""Get the start time."""
return self._start_at
@start_at.setter
def start_at(self, value: float) -> None:
"""Set the start time."""
self._start_at = value
@property
def total_tokens(self) -> int:
"""Get the total tokens count."""
return self._total_tokens
@total_tokens.setter
def total_tokens(self, value: int):
"""Set the total tokens count."""
if value < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = value
@property
def llm_usage(self) -> LLMUsage:
"""Get the LLM usage info."""
# Return a copy to prevent external modification
return self._llm_usage.model_copy()
@llm_usage.setter
def llm_usage(self, value: LLMUsage):
"""Set the LLM usage info."""
self._llm_usage = value.model_copy()
@property
def outputs(self) -> dict[str, Any]:
"""Get a copy of the outputs dictionary."""
return deepcopy(self._outputs)
@outputs.setter
def outputs(self, value: dict[str, Any]) -> None:
"""Set the outputs dictionary."""
self._outputs = deepcopy(value)
def set_output(self, key: str, value: Any) -> None:
"""Set a single output value."""
self._outputs[key] = deepcopy(value)
def get_output(self, key: str, default: Any = None) -> Any:
"""Get a single output value."""
return deepcopy(self._outputs.get(key, default))
def update_outputs(self, updates: dict[str, Any]) -> None:
"""Update multiple output values."""
for key, value in updates.items():
self._outputs[key] = deepcopy(value)
@property
def node_run_steps(self) -> int:
"""Get the node run steps count."""
return self._node_run_steps
@node_run_steps.setter
def node_run_steps(self, value: int) -> None:
"""Set the node run steps count."""
if value < 0:
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = value
def increment_node_run_steps(self) -> None:
"""Increment the node run steps by 1."""
self._node_run_steps += 1
def add_tokens(self, tokens: int) -> None:
"""Add tokens to the total count."""
if tokens < 0:
raise ValueError("tokens must be non-negative")
self._total_tokens += tokens

View File

@ -1,5 +1,16 @@
from .edge import Edge
from .graph import Graph, NodeFactory
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .graph_template import GraphTemplate
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
__all__ = ["Edge", "Graph", "GraphTemplate", "NodeFactory"]
__all__ = [
"Edge",
"Graph",
"GraphTemplate",
"NodeFactory",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",
"ReadOnlyVariablePool",
"ReadOnlyVariablePoolWrapper",
]

View File

@ -0,0 +1,59 @@
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
class ReadOnlyVariablePool(Protocol):
    """Structural type describing read-only access to a variable pool.

    Anything offering these two lookup methods satisfies the protocol;
    mutation methods are deliberately absent so layers cannot write back.
    """

    def get(self, node_id: str, variable_key: str) -> Any:
        """Return the value stored for ``variable_key`` under ``node_id``."""
        ...

    def get_all_by_node(self, node_id: str) -> dict[str, Any]:
        """Return a mapping of every variable recorded under ``node_id``."""
        ...
class ReadOnlyGraphRuntimeState(Protocol):
    """Immutable observation surface over GraphRuntimeState.

    Layers receive this interface instead of the mutable state object,
    so they can watch execution without altering it. Implementations are
    expected to hand back defensive copies from every accessor.
    """

    @property
    def variable_pool(self) -> ReadOnlyVariablePool:
        """Read-only view of the variable pool."""
        ...

    @property
    def start_at(self) -> float:
        """Execution start timestamp."""
        ...

    @property
    def total_tokens(self) -> int:
        """Accumulated token count."""
        ...

    @property
    def llm_usage(self) -> LLMUsage:
        """Copy of the LLM usage statistics."""
        ...

    @property
    def outputs(self) -> dict[str, Any]:
        """Defensive copy of the workflow outputs."""
        ...

    @property
    def node_run_steps(self) -> int:
        """Number of node execution steps performed so far."""
        ...

    def get_output(self, key: str, default: Any = None) -> Any:
        """Return a copy of a single output value, or ``default`` when absent."""
        ...

View File

@ -0,0 +1,76 @@
from copy import deepcopy
from typing import Any
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
class ReadOnlyVariablePoolWrapper:
    """Adapter exposing only the read operations of a VariablePool."""

    def __init__(self, variable_pool: VariablePool):
        self._variable_pool = variable_pool

    def get(self, node_id: str, variable_key: str) -> Any:
        """Look up one variable; non-None results are deep-copied."""
        raw = self._variable_pool.get(node_id, variable_key)
        if raw is None:
            return None
        return deepcopy(raw)

    def get_all_by_node(self, node_id: str) -> dict[str, Any]:
        """Collect deep copies of every variable stored for ``node_id``."""
        node_vars = self._variable_pool.variable_dictionary.get(node_id)
        if node_vars is None:
            return {}
        snapshot: dict[str, Any] = {}
        for key, var in node_vars.items():
            # FIXME(-LAN-): Handle the actual Variable object structure
            snapshot[key] = deepcopy(getattr(var, "value", var))
        return snapshot
class ReadOnlyGraphRuntimeStateWrapper:
    """Read-only facade over a GraphRuntimeState instance.

    Hands layers an observation-only view: every accessor returns either
    an immutable value or a defensive copy, so nothing done through this
    wrapper can change the underlying state.
    """

    def __init__(self, state: GraphRuntimeState):
        self._state = state
        self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)

    @property
    def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
        """Read-only wrapper around the state's variable pool."""
        return self._variable_pool_wrapper

    @property
    def start_at(self) -> float:
        """Execution start timestamp (immutable float)."""
        return self._state.start_at

    @property
    def total_tokens(self) -> int:
        """Accumulated token count (immutable int)."""
        return self._state.total_tokens

    @property
    def llm_usage(self) -> LLMUsage:
        """Independent copy of the LLM usage statistics."""
        # model_copy() keeps callers from mutating the tracked usage object.
        return self._state.llm_usage.model_copy()

    @property
    def outputs(self) -> dict[str, Any]:
        """Deep copy of the outputs mapping."""
        return deepcopy(self._state.outputs)

    @property
    def node_run_steps(self) -> int:
        """Node execution step counter (immutable int)."""
        return self._state.node_run_steps

    def get_output(self, key: str, default: Any = None) -> Any:
        """Return a copy of one output value, or ``default`` if absent."""
        return self._state.get_output(key, default)

View File

@ -267,10 +267,10 @@ class EventHandler:
# in runtime state, rather than allowing nodes to directly access runtime state.
for key, value in event.node_run_result.outputs.items():
if key == "answer":
existing = self._graph_runtime_state.outputs.get("answer", "")
existing = self._graph_runtime_state.get_output("answer", "")
if existing:
self._graph_runtime_state.outputs["answer"] = f"{existing}{value}"
self._graph_runtime_state.set_output("answer", f"{existing}{value}")
else:
self._graph_runtime_state.outputs["answer"] = value
self._graph_runtime_state.set_output("answer", value)
else:
self._graph_runtime_state.outputs[key] = value
self._graph_runtime_state.set_output(key, value)

View File

@ -9,7 +9,7 @@ from typing import final
from core.workflow.graph_events import GraphEngineEvent
from ..layers.base import Layer
from ..layers.base import GraphEngineLayer
@final
@ -104,10 +104,10 @@ class EventManager:
"""Initialize the event manager."""
self._events: list[GraphEngineEvent] = []
self._lock = ReadWriteLock()
self._layers: list[Layer] = []
self._layers: list[GraphEngineLayer] = []
self._execution_complete = threading.Event()
def set_layers(self, layers: list[Layer]) -> None:
def set_layers(self, layers: list[GraphEngineLayer]) -> None:
"""
Set the layers to notify on event collection.

View File

@ -17,6 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
from core.workflow.graph_events import (
GraphEngineEvent,
GraphNodeEventBase,
@ -33,12 +34,12 @@ from .entities.commands import AbortCommand
from .error_handling import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import Layer
from .layers.base import GraphEngineLayer
from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel
from .response_coordinator import ResponseStreamCoordinator
from .state_management import UnifiedStateManager
from .worker_management import SimpleWorkerPool
from .worker_management import WorkerPool
logger = logging.getLogger(__name__)
@ -186,7 +187,7 @@ class GraphEngine:
context_vars = contextvars.copy_context()
# Create worker pool for parallel node execution
self._worker_pool = SimpleWorkerPool(
self._worker_pool = WorkerPool(
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
@ -221,7 +222,7 @@ class GraphEngine:
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[Layer] = []
self._layers: list[GraphEngineLayer] = []
# === Validation ===
# Ensure all nodes share the same GraphRuntimeState instance
@ -234,7 +235,7 @@ class GraphEngine:
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
def layer(self, layer: Layer) -> "GraphEngine":
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
"""Add a layer for extending functionality."""
self._layers.append(layer)
return self
@ -288,9 +289,11 @@ class GraphEngine:
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self._event_manager.set_layers(self._layers)
# Create a read-only wrapper for the runtime state
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
for layer in self._layers:
try:
layer.initialize(self._graph_runtime_state, self._command_channel)
layer.initialize(read_only_state, self._command_channel)
except Exception as e:
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)

View File

@ -5,12 +5,12 @@ This module provides the layer infrastructure for extending GraphEngine function
with middleware-like components that can observe events and interact with execution.
"""
from .base import Layer
from .base import GraphEngineLayer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer
__all__ = [
"DebugLoggingLayer",
"ExecutionLimitsLayer",
"Layer",
"GraphEngineLayer",
]

View File

@ -7,12 +7,12 @@ intercept and respond to GraphEngine events.
from abc import ABC, abstractmethod
from core.workflow.entities import GraphRuntimeState
from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
class Layer(ABC):
class GraphEngineLayer(ABC):
"""
Abstract base class for GraphEngine layers.
@ -27,19 +27,19 @@ class Layer(ABC):
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self.graph_runtime_state: GraphRuntimeState | None = None
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None
def initialize(self, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel) -> None:
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.
Called by GraphEngine before execution starts to inject the runtime state
and command channel. This allows layers to access engine context and send
commands.
Called by GraphEngine before execution starts to inject the read-only runtime state
and command channel. This allows layers to observe engine context and send
commands, but prevents direct state modification.
Args:
graph_runtime_state: The runtime state of the graph execution
graph_runtime_state: Read-only view of the runtime state
command_channel: Channel for sending commands to the engine
"""
self.graph_runtime_state = graph_runtime_state

View File

@ -33,11 +33,11 @@ from core.workflow.graph_events import (
NodeRunSucceededEvent,
)
from .base import Layer
from .base import GraphEngineLayer
@final
class DebugLoggingLayer(Layer):
class DebugLoggingLayer(GraphEngineLayer):
"""
A layer that provides comprehensive logging of GraphEngine execution.

View File

@ -16,7 +16,7 @@ from typing import final
from typing_extensions import override
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers import Layer
from core.workflow.graph_engine.layers import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
NodeRunStartedEvent,
@ -32,7 +32,7 @@ class LimitType(Enum):
@final
class ExecutionLimitsLayer(Layer):
class ExecutionLimitsLayer(GraphEngineLayer):
"""
Layer that enforces execution limits for workflows.

View File

@ -8,7 +8,7 @@ from ..command_processing import CommandProcessor
from ..domain import GraphExecution
from ..event_management import EventManager
from ..state_management import UnifiedStateManager
from ..worker_management import SimpleWorkerPool
from ..worker_management import WorkerPool
if TYPE_CHECKING:
from ..event_management import EventHandler
@ -30,7 +30,7 @@ class ExecutionCoordinator:
event_handler: "EventHandler",
event_collector: EventManager,
command_processor: CommandProcessor,
worker_pool: SimpleWorkerPool,
worker_pool: WorkerPool,
) -> None:
"""
Initialize the execution coordinator.

View File

@ -9,7 +9,6 @@ import contextvars
import queue
import threading
import time
from collections.abc import Callable
from datetime import datetime
from typing import final
from uuid import uuid4
@ -42,8 +41,6 @@ class Worker(threading.Thread):
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
on_idle_callback: Callable[[int], None] | None = None,
on_active_callback: Callable[[int], None] | None = None,
) -> None:
"""
Initialize worker thread.
@ -55,8 +52,6 @@ class Worker(threading.Thread):
worker_id: Unique identifier for this worker
flask_app: Optional Flask application for context preservation
context_vars: Optional context variables to preserve in worker thread
on_idle_callback: Optional callback when worker becomes idle
on_active_callback: Optional callback when worker becomes active
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self._ready_queue = ready_queue
@ -66,14 +61,28 @@ class Worker(threading.Thread):
self._flask_app = flask_app
self._context_vars = context_vars
self._stop_event = threading.Event()
self._on_idle_callback = on_idle_callback
self._on_active_callback = on_active_callback
self._last_task_time = time.time()
def stop(self) -> None:
"""Signal the worker to stop processing."""
self._stop_event.set()
@property
def is_idle(self) -> bool:
"""Check if the worker is currently idle."""
# Worker is idle if it hasn't processed a task recently (within 0.2 seconds)
return (time.time() - self._last_task_time) > 0.2
@property
def idle_duration(self) -> float:
"""Get the duration in seconds since the worker last processed a task."""
return time.time() - self._last_task_time
@property
def worker_id(self) -> int:
"""Get the worker's ID."""
return self._worker_id
@override
def run(self) -> None:
"""
@ -87,15 +96,8 @@ class Worker(threading.Thread):
try:
node_id = self._ready_queue.get(timeout=0.1)
except queue.Empty:
# Notify that worker is idle
if self._on_idle_callback:
self._on_idle_callback(self._worker_id)
continue
# Notify that worker is active
if self._on_active_callback:
self._on_active_callback(self._worker_id)
self._last_task_time = time.time()
node = self._graph.nodes[node_id]
try:

View File

@ -5,8 +5,8 @@ This package manages the worker pool, including creation,
scaling, and activity tracking.
"""
from .simple_worker_pool import SimpleWorkerPool
from .worker_pool import WorkerPool
__all__ = [
"SimpleWorkerPool",
"WorkerPool",
]

View File

@ -5,6 +5,7 @@ This is a simpler implementation that merges WorkerPool, ActivityTracker,
DynamicScaler, and WorkerFactory into a single class.
"""
import logging
import queue
import threading
from typing import TYPE_CHECKING, final
@ -15,6 +16,8 @@ from core.workflow.graph_events import GraphNodeEventBase
from ..worker import Worker
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from contextvars import Context
@ -22,7 +25,7 @@ if TYPE_CHECKING:
@final
class SimpleWorkerPool:
class WorkerPool:
"""
Simple worker pool with integrated management.
@ -74,6 +77,8 @@ class SimpleWorkerPool:
self._lock = threading.RLock()
self._running = False
# No longer tracking worker states with callbacks to avoid lock contention
def start(self, initial_count: int | None = None) -> None:
"""
Start the worker pool.
@ -97,6 +102,14 @@ class SimpleWorkerPool:
else:
initial_count = min(self._min_workers + 2, self._max_workers)
logger.debug(
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
initial_count,
node_count,
self._min_workers,
self._max_workers,
)
# Create initial workers
for _ in range(initial_count):
self._create_worker()
@ -105,6 +118,10 @@ class SimpleWorkerPool:
"""Stop all workers in the pool."""
with self._lock:
self._running = False
worker_count = len(self._workers)
if worker_count > 0:
logger.debug("Stopping worker pool: %d workers", worker_count)
# Stop all workers
for worker in self._workers:
@ -134,6 +151,105 @@ class SimpleWorkerPool:
worker.start()
self._workers.append(worker)
def _remove_worker(self, worker: Worker, worker_id: int) -> None:
"""Remove a specific worker from the pool."""
# Stop the worker
worker.stop()
# Wait for it to finish
if worker.is_alive():
worker.join(timeout=2.0)
# Remove from list
if worker in self._workers:
self._workers.remove(worker)
def _try_scale_up(self, queue_depth: int, current_count: int) -> bool:
"""
Try to scale up workers if needed.
Args:
queue_depth: Current queue depth
current_count: Current number of workers
Returns:
True if scaled up, False otherwise
"""
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
old_count = current_count
self._create_worker()
logger.debug(
"Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)",
old_count,
len(self._workers),
queue_depth,
self._scale_up_threshold,
)
return True
return False
def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool:
"""
Try to scale down workers if we have excess capacity.
Args:
queue_depth: Current queue depth
current_count: Current number of workers
active_count: Number of active workers
idle_count: Number of idle workers
Returns:
True if scaled down, False otherwise
"""
# Skip if we're at minimum or have no idle workers
if current_count <= self._min_workers or idle_count == 0:
return False
# Check if we have excess capacity
has_excess_capacity = (
queue_depth <= active_count # Active workers can handle current queue
or idle_count > active_count # More idle than active workers
or (queue_depth == 0 and idle_count > 0) # No work and have idle workers
)
if not has_excess_capacity:
return False
# Find and remove idle workers that have been idle long enough
workers_to_remove = []
for worker in self._workers:
# Check if worker is idle and has exceeded idle time threshold
if worker.is_idle and worker.idle_duration >= self._scale_down_idle_time:
# Don't remove if it would leave us unable to handle the queue
remaining_workers = current_count - len(workers_to_remove) - 1
if remaining_workers >= self._min_workers and remaining_workers >= max(1, queue_depth // 2):
workers_to_remove.append((worker, worker.worker_id))
# Only remove one worker per check to avoid aggressive scaling
break
# Remove idle workers if any found
if workers_to_remove:
old_count = current_count
for worker, worker_id in workers_to_remove:
self._remove_worker(worker, worker_id)
logger.debug(
"Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, "
"queue_depth=%d, active=%d, idle=%d)",
old_count,
len(self._workers),
len(workers_to_remove),
self._scale_down_idle_time,
queue_depth,
active_count,
idle_count - len(workers_to_remove),
)
return True
return False
def check_and_scale(self) -> None:
"""Check and perform scaling if needed."""
with self._lock:
@ -143,9 +259,15 @@ class SimpleWorkerPool:
current_count = len(self._workers)
queue_depth = self._ready_queue.qsize()
# Simple scaling logic
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
self._create_worker()
# Count active vs idle workers by querying their state directly
idle_count = sum(1 for worker in self._workers if worker.is_idle)
active_count = current_count - idle_count
# Try to scale up if queue is backing up
self._try_scale_up(queue_depth, current_count)
# Try to scale down if we have excess capacity
self._try_scale_down(queue_depth, current_count, active_count, idle_count)
def get_worker_count(self) -> int:
"""Get current number of workers."""

View File

@ -579,7 +579,7 @@ class AgentNode(Node):
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
if message.message.json_object is not None:
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)

View File

@ -73,9 +73,6 @@ class DefaultValue(BaseModel):
@model_validator(mode="after")
def validate_value_type(self) -> "DefaultValue":
if self.type is None:
raise DefaultValueTypeError("type field is required")
# Type validation configuration
type_validators = {
DefaultValueType.STRING: {

View File

@ -108,8 +108,6 @@ class CodeNode(Node):
"""
if value is None:
return None
if not isinstance(value, str):
raise OutputValidationError(f"Output variable `{variable}` must be a string")
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise OutputValidationError(
@ -122,8 +120,6 @@ class CodeNode(Node):
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
if value is None:
return None
if not isinstance(value, bool):
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
return value
@ -136,8 +132,6 @@ class CodeNode(Node):
"""
if value is None:
return None
if not isinstance(value, int | float):
raise OutputValidationError(f"Output variable `{variable}` must be a number")
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise OutputValidationError(
@ -261,7 +255,13 @@ class CodeNode(Node):
)
elif output_config.type == SegmentType.NUMBER:
# check if number available
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
value = result.get(output_name)
if not isinstance(value, (int, float, None)):
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not a number,"
f" got {type(result.get(output_name))} instead."
)
checked = self._check_number(value=value, variable=f"{prefix}{dot}{output_name}")
# If the output is a boolean and the output schema specifies a NUMBER type,
# convert the boolean value to an integer.
#
@ -271,8 +271,11 @@ class CodeNode(Node):
elif output_config.type == SegmentType.STRING:
# check if string available
value = result.get("output_name")
if value is not None and not isinstance(value, str):
raise OutputValidationError(f"Output value `{value}` is not string")
transformed_result[output_name] = self._check_string(
value=result[output_name],
value=value,
variable=f"{prefix}{dot}{output_name}",
)
elif output_config.type == SegmentType.BOOLEAN:
@ -282,31 +285,36 @@ class CodeNode(Node):
)
elif output_config.type == SegmentType.ARRAY_NUMBER:
# check if array of number available
if not isinstance(result[output_name], list):
if result[output_name] is None:
value = result[output_name]
if not isinstance(value, list):
if value is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
)
for i, inner_value in enumerate(value):
if not isinstance(inner_value, (int, float)):
raise OutputValidationError(
f"The element at index {i} of output variable `{prefix}{dot}{output_name}` must be"
f" a number."
)
_ = self._check_number(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
transformed_result[output_name] = [
# If the element is a boolean and the output schema specifies a `array[number]` type,
# convert the boolean value to an integer.
#
# This ensures compatibility with existing workflows that may use
# `True` and `False` as values for NUMBER type outputs.
self._convert_boolean_to_int(
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
)
for i, value in enumerate(result[output_name])
self._convert_boolean_to_int(v)
for v in value
]
elif output_config.type == SegmentType.ARRAY_STRING:
# check if array of string available
@ -369,8 +377,9 @@ class CodeNode(Node):
]
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
# check if array of object available
if not isinstance(result[output_name], list):
if result[output_name] is None:
value = result[output_name]
if not isinstance(value, list):
if value is None:
transformed_result[output_name] = None
else:
raise OutputValidationError(
@ -378,10 +387,14 @@ class CodeNode(Node):
f" got {type(result.get(output_name))} instead."
)
else:
transformed_result[output_name] = [
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
for i, value in enumerate(result[output_name])
]
for i, inner_value in enumerate(value):
if not isinstance(inner_value, bool | None):
raise OutputValidationError(
f"Output {prefix}{dot}{output_name}[{i}] is not a boolean,"
f" got {type(inner_value)} instead."
)
_ = self._check_boolean(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
transformed_result[output_name] = value
else:
raise OutputValidationError(f"Output type {output_config.type} is not supported.")

View File

@ -263,9 +263,6 @@ class Executor:
if authorization.config is None:
raise AuthorizationConfigError("authorization config is required")
if self.auth.config.api_key is None:
raise AuthorizationConfigError("api_key is required")
if not authorization.config.header:
authorization.config.header = "Authorization"
@ -409,30 +406,22 @@ class Executor:
if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
for file_entry in self.files:
# file_entry should be (key, (filename, content, mime_type)), but handle edge cases
if len(file_entry) != 2 or not isinstance(file_entry[1], tuple) or len(file_entry[1]) < 2:
if len(file_entry) != 2 or len(file_entry[1]) < 2:
continue # skip malformed entries
key = file_entry[0]
content = file_entry[1][1]
body_string += f"--{boundary}\r\n"
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
# decode content safely
if isinstance(content, bytes):
try:
body_string += content.decode("utf-8")
except UnicodeDecodeError:
body_string += content.decode("utf-8", errors="replace")
elif isinstance(content, str):
body_string += content
else:
body_string += f"[Unsupported content type: {type(content).__name__}]"
try:
body_string += content.decode("utf-8")
except UnicodeDecodeError:
body_string += content.decode("utf-8", errors="replace")
body_string += "\r\n"
body_string += f"--{boundary}--\r\n"
elif self.node_data.body:
if self.content:
if isinstance(self.content, str):
body_string = self.content
elif isinstance(self.content, bytes):
body_string = self.content.decode("utf-8", errors="replace")
body_string = self.content.decode("utf-8", errors="replace")
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
body_string = urlencode(self.data)
elif self.data and self.node_data.body.type == "form-data":

View File

@ -170,27 +170,21 @@ class ListOperatorNode(Node):
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayBooleanSegment):
if not isinstance(condition.value, bool):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
else:
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statment should be unreachable.")
return variable
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
result = sorted(variable.value, reverse=self._node_data.order_by.value == Order.DESC)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
else:
result = _order_file(
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statement should be unreachable")
return variable
@ -304,7 +298,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
if key in {"type", "transfer_method"} and isinstance(value, Sequence):
if key in {"type", "transfer_method"}:
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
elif key == "size" and isinstance(value, str):

View File

@ -195,9 +195,8 @@ class LLMNode(Node):
generator = self._fetch_context(node_data=self._node_data)
context = None
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
context = event.context
yield event
context = event.context
yield event
if context:
node_inputs["#context#"] = context
@ -282,7 +281,7 @@ class LLMNode(Node):
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
if structured_output:
outputs["structured_output"] = structured_output.structured_output
if self._file_outputs is not None:
if self._file_outputs:
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
# Send final chunk event to indicate streaming is complete
@ -827,9 +826,7 @@ class LLMNode(Node):
prompt_template = typed_node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list) and all(
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
):
if isinstance(prompt_template, list):
for prompt in prompt_template:
if prompt.edition_type != "jinja2":
variable_template_parser = VariableTemplateParser(template=prompt.text)
@ -1063,7 +1060,7 @@ class LLMNode(Node):
return
if isinstance(contents, str):
yield contents
elif isinstance(contents, list):
else:
for item in contents:
if isinstance(item, TextPromptMessageContent):
yield item.data
@ -1077,9 +1074,6 @@ class LLMNode(Node):
else:
logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item)
else:
logger.warning("unknown contents type encountered, type=%s", type(contents))
yield str(contents)
@property
def retry(self) -> bool:

View File

@ -147,14 +147,14 @@ class LoopNode(Node):
for key, value in graph_engine.graph_runtime_state.outputs.items():
if key == "answer":
# Concatenate answer outputs with newline
existing_answer = self.graph_runtime_state.outputs.get("answer", "")
existing_answer = self.graph_runtime_state.get_output("answer", "")
if existing_answer:
self.graph_runtime_state.outputs["answer"] = f"{existing_answer}{value}"
self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}")
else:
self.graph_runtime_state.outputs["answer"] = value
self.graph_runtime_state.set_output("answer", value)
else:
# For other outputs, just update
self.graph_runtime_state.outputs[key] = value
self.graph_runtime_state.set_output(key, value)
# Update the total tokens from this iteration
cost_tokens += graph_engine.graph_runtime_state.total_tokens

View File

@ -31,8 +31,6 @@ _VALID_PARAMETER_TYPES = frozenset(
def _validate_type(parameter_type: str) -> SegmentType:
if not isinstance(parameter_type, str):
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
if parameter_type not in _VALID_PARAMETER_TYPES:
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")

View File

@ -10,7 +10,7 @@ from core.file import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
@ -38,7 +38,6 @@ from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
from .exc import (
InvalidInvokeResultError,
InvalidModelModeError,
InvalidModelTypeError,
InvalidNumberOfParametersError,
@ -304,8 +303,6 @@ class ParameterExtractorNode(Node):
)
# handle invoke result
if not isinstance(invoke_result, LLMResult):
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
text = invoke_result.message.content or ""
if not isinstance(text, str):
@ -317,9 +314,6 @@ class ParameterExtractorNode(Node):
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
if text is None:
text = ""
return text, usage, tool_call
def _generate_function_call_prompt(
@ -583,18 +577,19 @@ class ParameterExtractorNode(Node):
return int(value)
elif isinstance(value, (int, float)):
return value
elif not isinstance(value, str):
return None
if "." in value:
try:
return float(value)
except ValueError:
return None
elif isinstance(value, str):
if "." in value:
try:
return float(value)
except ValueError:
return None
else:
try:
return int(value)
except ValueError:
return None
else:
try:
return int(value)
except ValueError:
return None
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
"""
@ -697,7 +692,7 @@ class ParameterExtractorNode(Node):
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0
elif parameter.type == "bool":
elif parameter.type == "boolean":
result[parameter.name] = False
elif parameter.type in {"string", "select"}:
result[parameter.name] = ""

View File

@ -323,7 +323,7 @@ class ToolNode(Node):
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
# JSON message handling for tool node
if message.message.json_object is not None:
if message.message.json_object:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)

View File

@ -117,13 +117,8 @@ class VariableAssignerNode(Node):
case WriteMode.CLEAR:
income_value = get_zero_value(original_variable.value_type)
if income_value is None:
raise VariableOperatorNodeError("income value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

View File

@ -25,8 +25,6 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
# Only array variable can be appended or extended
# Only array variable can have elements removed
return variable_type.is_array_type()
case _:
return False
def is_variable_input_supported(*, operation: Operation):

View File

@ -274,5 +274,3 @@ class VariableAssignerNode(Node):
if not variable.value:
return variable.value
return variable.value[:-1]
case _:
raise OperationNotSupportedError(operation=operation, variable_type=variable.value_type)