mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
# Conflicts: # api/core/app/apps/advanced_chat/generate_task_pipeline.py # api/pyproject.toml # api/uv.lock # docker/docker-compose-template.yaml # docker/docker-compose.yaml # web/package.json
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
@ -8,21 +9,127 @@ from .variable_pool import VariablePool
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
variable_pool: VariablePool = Field(..., description="variable pool")
|
||||
"""variable pool"""
|
||||
# Private attributes to prevent direct modification
|
||||
_variable_pool: VariablePool = PrivateAttr()
|
||||
_start_at: float = PrivateAttr()
|
||||
_total_tokens: int = PrivateAttr(default=0)
|
||||
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
|
||||
_outputs: dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
_node_run_steps: int = PrivateAttr(default=0)
|
||||
|
||||
start_at: float = Field(..., description="start time")
|
||||
"""start time"""
|
||||
total_tokens: int = 0
|
||||
"""total tokens"""
|
||||
llm_usage: LLMUsage = LLMUsage.empty_usage()
|
||||
"""llm usage info"""
|
||||
def __init__(
|
||||
self,
|
||||
variable_pool: VariablePool,
|
||||
start_at: float,
|
||||
total_tokens: int = 0,
|
||||
llm_usage: LLMUsage | None = None,
|
||||
outputs: dict[str, Any] | None = None,
|
||||
node_run_steps: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the GraphRuntimeState with validation."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# The `outputs` field stores the final output values generated by executing workflows or chatflows.
|
||||
#
|
||||
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
|
||||
# after a serialization and deserialization round trip.
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
# Initialize private attributes with validation
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
node_run_steps: int = 0
|
||||
"""node run steps"""
|
||||
self._start_at = start_at
|
||||
|
||||
if total_tokens < 0:
|
||||
raise ValueError("total_tokens must be non-negative")
|
||||
self._total_tokens = total_tokens
|
||||
|
||||
if llm_usage is None:
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
self._llm_usage = llm_usage
|
||||
|
||||
if outputs is None:
|
||||
outputs = {}
|
||||
self._outputs = deepcopy(outputs)
|
||||
|
||||
if node_run_steps < 0:
|
||||
raise ValueError("node_run_steps must be non-negative")
|
||||
self._node_run_steps = node_run_steps
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> VariablePool:
|
||||
"""Get the variable pool."""
|
||||
return self._variable_pool
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
"""Get the start time."""
|
||||
return self._start_at
|
||||
|
||||
@start_at.setter
|
||||
def start_at(self, value: float) -> None:
|
||||
"""Set the start time."""
|
||||
self._start_at = value
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get the total tokens count."""
|
||||
return self._total_tokens
|
||||
|
||||
@total_tokens.setter
|
||||
def total_tokens(self, value: int):
|
||||
"""Set the total tokens count."""
|
||||
if value < 0:
|
||||
raise ValueError("total_tokens must be non-negative")
|
||||
self._total_tokens = value
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
"""Get the LLM usage info."""
|
||||
# Return a copy to prevent external modification
|
||||
return self._llm_usage.model_copy()
|
||||
|
||||
@llm_usage.setter
|
||||
def llm_usage(self, value: LLMUsage):
|
||||
"""Set the LLM usage info."""
|
||||
self._llm_usage = value.model_copy()
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, Any]:
|
||||
"""Get a copy of the outputs dictionary."""
|
||||
return deepcopy(self._outputs)
|
||||
|
||||
@outputs.setter
|
||||
def outputs(self, value: dict[str, Any]) -> None:
|
||||
"""Set the outputs dictionary."""
|
||||
self._outputs = deepcopy(value)
|
||||
|
||||
def set_output(self, key: str, value: Any) -> None:
|
||||
"""Set a single output value."""
|
||||
self._outputs[key] = deepcopy(value)
|
||||
|
||||
def get_output(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a single output value."""
|
||||
return deepcopy(self._outputs.get(key, default))
|
||||
|
||||
def update_outputs(self, updates: dict[str, Any]) -> None:
|
||||
"""Update multiple output values."""
|
||||
for key, value in updates.items():
|
||||
self._outputs[key] = deepcopy(value)
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
"""Get the node run steps count."""
|
||||
return self._node_run_steps
|
||||
|
||||
@node_run_steps.setter
|
||||
def node_run_steps(self, value: int) -> None:
|
||||
"""Set the node run steps count."""
|
||||
if value < 0:
|
||||
raise ValueError("node_run_steps must be non-negative")
|
||||
self._node_run_steps = value
|
||||
|
||||
def increment_node_run_steps(self) -> None:
|
||||
"""Increment the node run steps by 1."""
|
||||
self._node_run_steps += 1
|
||||
|
||||
def add_tokens(self, tokens: int) -> None:
|
||||
"""Add tokens to the total count."""
|
||||
if tokens < 0:
|
||||
raise ValueError("tokens must be non-negative")
|
||||
self._total_tokens += tokens
|
||||
|
||||
@ -1,5 +1,16 @@
|
||||
from .edge import Edge
|
||||
from .graph import Graph, NodeFactory
|
||||
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
|
||||
from .graph_template import GraphTemplate
|
||||
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
|
||||
|
||||
__all__ = ["Edge", "Graph", "GraphTemplate", "NodeFactory"]
|
||||
__all__ = [
|
||||
"Edge",
|
||||
"Graph",
|
||||
"GraphTemplate",
|
||||
"NodeFactory",
|
||||
"ReadOnlyGraphRuntimeState",
|
||||
"ReadOnlyGraphRuntimeStateWrapper",
|
||||
"ReadOnlyVariablePool",
|
||||
"ReadOnlyVariablePoolWrapper",
|
||||
]
|
||||
|
||||
59
api/core/workflow/graph/graph_runtime_state_protocol.py
Normal file
59
api/core/workflow/graph/graph_runtime_state_protocol.py
Normal file
@ -0,0 +1,59 @@
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
|
||||
class ReadOnlyVariablePool(Protocol):
|
||||
"""Read-only interface for VariablePool."""
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Any:
|
||||
"""Get a variable value (read-only)."""
|
||||
...
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> dict[str, Any]:
|
||||
"""Get all variables for a node (read-only)."""
|
||||
...
|
||||
|
||||
|
||||
class ReadOnlyGraphRuntimeState(Protocol):
|
||||
"""
|
||||
Read-only view of GraphRuntimeState for layers.
|
||||
|
||||
This protocol defines a read-only interface that prevents layers from
|
||||
modifying the graph runtime state while still allowing observation.
|
||||
All methods return defensive copies to ensure immutability.
|
||||
"""
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePool:
|
||||
"""Get read-only access to the variable pool."""
|
||||
...
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
"""Get the start time (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get the total tokens count (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
"""Get a copy of LLM usage info (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, Any]:
|
||||
"""Get a defensive copy of outputs (read-only)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
"""Get the node run steps count (read-only)."""
|
||||
...
|
||||
|
||||
def get_output(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a single output value (returns a copy)."""
|
||||
...
|
||||
76
api/core/workflow/graph/read_only_state_wrapper.py
Normal file
76
api/core/workflow/graph/read_only_state_wrapper.py
Normal file
@ -0,0 +1,76 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
class ReadOnlyVariablePoolWrapper:
|
||||
"""Wrapper that provides read-only access to VariablePool."""
|
||||
|
||||
def __init__(self, variable_pool: VariablePool):
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Any:
|
||||
"""Get a variable value (returns a defensive copy)."""
|
||||
value = self._variable_pool.get(node_id, variable_key)
|
||||
return deepcopy(value) if value is not None else None
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> dict[str, Any]:
|
||||
"""Get all variables for a node (returns defensive copies)."""
|
||||
variables = {}
|
||||
if node_id in self._variable_pool.variable_dictionary:
|
||||
for key, var in self._variable_pool.variable_dictionary[node_id].items():
|
||||
# FIXME(-LAN-): Handle the actual Variable object structure
|
||||
value = var.value if hasattr(var, "value") else var
|
||||
variables[key] = deepcopy(value)
|
||||
return variables
|
||||
|
||||
|
||||
class ReadOnlyGraphRuntimeStateWrapper:
|
||||
"""
|
||||
Wrapper that provides read-only access to GraphRuntimeState.
|
||||
|
||||
This wrapper ensures that layers can observe the state without
|
||||
modifying it. All returned values are defensive copies.
|
||||
"""
|
||||
|
||||
def __init__(self, state: GraphRuntimeState):
|
||||
self._state = state
|
||||
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
|
||||
"""Get read-only access to the variable pool."""
|
||||
return self._variable_pool_wrapper
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
"""Get the start time (read-only)."""
|
||||
return self._state.start_at
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get the total tokens count (read-only)."""
|
||||
return self._state.total_tokens
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
"""Get a copy of LLM usage info (read-only)."""
|
||||
# Return a copy to prevent modification
|
||||
return self._state.llm_usage.model_copy()
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, Any]:
|
||||
"""Get a defensive copy of outputs (read-only)."""
|
||||
return deepcopy(self._state.outputs)
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
"""Get the node run steps count (read-only)."""
|
||||
return self._state.node_run_steps
|
||||
|
||||
def get_output(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a single output value (returns a copy)."""
|
||||
return self._state.get_output(key, default)
|
||||
@ -267,10 +267,10 @@ class EventHandler:
|
||||
# in runtime state, rather than allowing nodes to directly access runtime state.
|
||||
for key, value in event.node_run_result.outputs.items():
|
||||
if key == "answer":
|
||||
existing = self._graph_runtime_state.outputs.get("answer", "")
|
||||
existing = self._graph_runtime_state.get_output("answer", "")
|
||||
if existing:
|
||||
self._graph_runtime_state.outputs["answer"] = f"{existing}{value}"
|
||||
self._graph_runtime_state.set_output("answer", f"{existing}{value}")
|
||||
else:
|
||||
self._graph_runtime_state.outputs["answer"] = value
|
||||
self._graph_runtime_state.set_output("answer", value)
|
||||
else:
|
||||
self._graph_runtime_state.outputs[key] = value
|
||||
self._graph_runtime_state.set_output(key, value)
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import final
|
||||
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
from ..layers.base import Layer
|
||||
from ..layers.base import GraphEngineLayer
|
||||
|
||||
|
||||
@final
|
||||
@ -104,10 +104,10 @@ class EventManager:
|
||||
"""Initialize the event manager."""
|
||||
self._events: list[GraphEngineEvent] = []
|
||||
self._lock = ReadWriteLock()
|
||||
self._layers: list[Layer] = []
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
self._execution_complete = threading.Event()
|
||||
|
||||
def set_layers(self, layers: list[Layer]) -> None:
|
||||
def set_layers(self, layers: list[GraphEngineLayer]) -> None:
|
||||
"""
|
||||
Set the layers to notify on event collection.
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphNodeEventBase,
|
||||
@ -33,12 +34,12 @@ from .entities.commands import AbortCommand
|
||||
from .error_handling import ErrorHandler
|
||||
from .event_management import EventHandler, EventManager
|
||||
from .graph_traversal import EdgeProcessor, SkipPropagator
|
||||
from .layers.base import Layer
|
||||
from .layers.base import GraphEngineLayer
|
||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||
from .protocols.command_channel import CommandChannel
|
||||
from .response_coordinator import ResponseStreamCoordinator
|
||||
from .state_management import UnifiedStateManager
|
||||
from .worker_management import SimpleWorkerPool
|
||||
from .worker_management import WorkerPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -186,7 +187,7 @@ class GraphEngine:
|
||||
context_vars = contextvars.copy_context()
|
||||
|
||||
# Create worker pool for parallel node execution
|
||||
self._worker_pool = SimpleWorkerPool(
|
||||
self._worker_pool = WorkerPool(
|
||||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
@ -221,7 +222,7 @@ class GraphEngine:
|
||||
|
||||
# === Extensibility ===
|
||||
# Layers allow plugins to extend engine functionality
|
||||
self._layers: list[Layer] = []
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
|
||||
# === Validation ===
|
||||
# Ensure all nodes share the same GraphRuntimeState instance
|
||||
@ -234,7 +235,7 @@ class GraphEngine:
|
||||
if id(node.graph_runtime_state) != expected_state_id:
|
||||
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
|
||||
|
||||
def layer(self, layer: Layer) -> "GraphEngine":
|
||||
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
|
||||
"""Add a layer for extending functionality."""
|
||||
self._layers.append(layer)
|
||||
return self
|
||||
@ -288,9 +289,11 @@ class GraphEngine:
|
||||
def _initialize_layers(self) -> None:
|
||||
"""Initialize layers with context."""
|
||||
self._event_manager.set_layers(self._layers)
|
||||
# Create a read-only wrapper for the runtime state
|
||||
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.initialize(self._graph_runtime_state, self._command_channel)
|
||||
layer.initialize(read_only_state, self._command_channel)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
|
||||
|
||||
|
||||
@ -5,12 +5,12 @@ This module provides the layer infrastructure for extending GraphEngine function
|
||||
with middleware-like components that can observe events and interact with execution.
|
||||
"""
|
||||
|
||||
from .base import Layer
|
||||
from .base import GraphEngineLayer
|
||||
from .debug_logging import DebugLoggingLayer
|
||||
from .execution_limits import ExecutionLimitsLayer
|
||||
|
||||
__all__ = [
|
||||
"DebugLoggingLayer",
|
||||
"ExecutionLimitsLayer",
|
||||
"Layer",
|
||||
"GraphEngineLayer",
|
||||
]
|
||||
|
||||
@ -7,12 +7,12 @@ intercept and respond to GraphEngine events.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent
|
||||
|
||||
|
||||
class Layer(ABC):
|
||||
class GraphEngineLayer(ABC):
|
||||
"""
|
||||
Abstract base class for GraphEngine layers.
|
||||
|
||||
@ -27,19 +27,19 @@ class Layer(ABC):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the layer. Subclasses can override with custom parameters."""
|
||||
self.graph_runtime_state: GraphRuntimeState | None = None
|
||||
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
|
||||
self.command_channel: CommandChannel | None = None
|
||||
|
||||
def initialize(self, graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel) -> None:
|
||||
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
|
||||
"""
|
||||
Initialize the layer with engine dependencies.
|
||||
|
||||
Called by GraphEngine before execution starts to inject the runtime state
|
||||
and command channel. This allows layers to access engine context and send
|
||||
commands.
|
||||
Called by GraphEngine before execution starts to inject the read-only runtime state
|
||||
and command channel. This allows layers to observe engine context and send
|
||||
commands, but prevents direct state modification.
|
||||
|
||||
Args:
|
||||
graph_runtime_state: The runtime state of the graph execution
|
||||
graph_runtime_state: Read-only view of the runtime state
|
||||
command_channel: Channel for sending commands to the engine
|
||||
"""
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
|
||||
@ -33,11 +33,11 @@ from core.workflow.graph_events import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
||||
from .base import Layer
|
||||
from .base import GraphEngineLayer
|
||||
|
||||
|
||||
@final
|
||||
class DebugLoggingLayer(Layer):
|
||||
class DebugLoggingLayer(GraphEngineLayer):
|
||||
"""
|
||||
A layer that provides comprehensive logging of GraphEngine execution.
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from typing import final
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers import Layer
|
||||
from core.workflow.graph_engine.layers import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
@ -32,7 +32,7 @@ class LimitType(Enum):
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionLimitsLayer(Layer):
|
||||
class ExecutionLimitsLayer(GraphEngineLayer):
|
||||
"""
|
||||
Layer that enforces execution limits for workflows.
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from ..command_processing import CommandProcessor
|
||||
from ..domain import GraphExecution
|
||||
from ..event_management import EventManager
|
||||
from ..state_management import UnifiedStateManager
|
||||
from ..worker_management import SimpleWorkerPool
|
||||
from ..worker_management import WorkerPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..event_management import EventHandler
|
||||
@ -30,7 +30,7 @@ class ExecutionCoordinator:
|
||||
event_handler: "EventHandler",
|
||||
event_collector: EventManager,
|
||||
command_processor: CommandProcessor,
|
||||
worker_pool: SimpleWorkerPool,
|
||||
worker_pool: WorkerPool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the execution coordinator.
|
||||
|
||||
@ -9,7 +9,6 @@ import contextvars
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import final
|
||||
from uuid import uuid4
|
||||
@ -42,8 +41,6 @@ class Worker(threading.Thread):
|
||||
worker_id: int = 0,
|
||||
flask_app: Flask | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
on_idle_callback: Callable[[int], None] | None = None,
|
||||
on_active_callback: Callable[[int], None] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize worker thread.
|
||||
@ -55,8 +52,6 @@ class Worker(threading.Thread):
|
||||
worker_id: Unique identifier for this worker
|
||||
flask_app: Optional Flask application for context preservation
|
||||
context_vars: Optional context variables to preserve in worker thread
|
||||
on_idle_callback: Optional callback when worker becomes idle
|
||||
on_active_callback: Optional callback when worker becomes active
|
||||
"""
|
||||
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
|
||||
self._ready_queue = ready_queue
|
||||
@ -66,14 +61,28 @@ class Worker(threading.Thread):
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._stop_event = threading.Event()
|
||||
self._on_idle_callback = on_idle_callback
|
||||
self._on_active_callback = on_active_callback
|
||||
self._last_task_time = time.time()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
|
||||
@property
|
||||
def is_idle(self) -> bool:
|
||||
"""Check if the worker is currently idle."""
|
||||
# Worker is idle if it hasn't processed a task recently (within 0.2 seconds)
|
||||
return (time.time() - self._last_task_time) > 0.2
|
||||
|
||||
@property
|
||||
def idle_duration(self) -> float:
|
||||
"""Get the duration in seconds since the worker last processed a task."""
|
||||
return time.time() - self._last_task_time
|
||||
|
||||
@property
|
||||
def worker_id(self) -> int:
|
||||
"""Get the worker's ID."""
|
||||
return self._worker_id
|
||||
|
||||
@override
|
||||
def run(self) -> None:
|
||||
"""
|
||||
@ -87,15 +96,8 @@ class Worker(threading.Thread):
|
||||
try:
|
||||
node_id = self._ready_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
# Notify that worker is idle
|
||||
if self._on_idle_callback:
|
||||
self._on_idle_callback(self._worker_id)
|
||||
continue
|
||||
|
||||
# Notify that worker is active
|
||||
if self._on_active_callback:
|
||||
self._on_active_callback(self._worker_id)
|
||||
|
||||
self._last_task_time = time.time()
|
||||
node = self._graph.nodes[node_id]
|
||||
try:
|
||||
|
||||
@ -5,8 +5,8 @@ This package manages the worker pool, including creation,
|
||||
scaling, and activity tracking.
|
||||
"""
|
||||
|
||||
from .simple_worker_pool import SimpleWorkerPool
|
||||
from .worker_pool import WorkerPool
|
||||
|
||||
__all__ = [
|
||||
"SimpleWorkerPool",
|
||||
"WorkerPool",
|
||||
]
|
||||
|
||||
@ -5,6 +5,7 @@ This is a simpler implementation that merges WorkerPool, ActivityTracker,
|
||||
DynamicScaler, and WorkerFactory into a single class.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, final
|
||||
@ -15,6 +16,8 @@ from core.workflow.graph_events import GraphNodeEventBase
|
||||
|
||||
from ..worker import Worker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from contextvars import Context
|
||||
|
||||
@ -22,7 +25,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
@final
|
||||
class SimpleWorkerPool:
|
||||
class WorkerPool:
|
||||
"""
|
||||
Simple worker pool with integrated management.
|
||||
|
||||
@ -74,6 +77,8 @@ class SimpleWorkerPool:
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
|
||||
# No longer tracking worker states with callbacks to avoid lock contention
|
||||
|
||||
def start(self, initial_count: int | None = None) -> None:
|
||||
"""
|
||||
Start the worker pool.
|
||||
@ -97,6 +102,14 @@ class SimpleWorkerPool:
|
||||
else:
|
||||
initial_count = min(self._min_workers + 2, self._max_workers)
|
||||
|
||||
logger.debug(
|
||||
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
|
||||
initial_count,
|
||||
node_count,
|
||||
self._min_workers,
|
||||
self._max_workers,
|
||||
)
|
||||
|
||||
# Create initial workers
|
||||
for _ in range(initial_count):
|
||||
self._create_worker()
|
||||
@ -105,6 +118,10 @@ class SimpleWorkerPool:
|
||||
"""Stop all workers in the pool."""
|
||||
with self._lock:
|
||||
self._running = False
|
||||
worker_count = len(self._workers)
|
||||
|
||||
if worker_count > 0:
|
||||
logger.debug("Stopping worker pool: %d workers", worker_count)
|
||||
|
||||
# Stop all workers
|
||||
for worker in self._workers:
|
||||
@ -134,6 +151,105 @@ class SimpleWorkerPool:
|
||||
worker.start()
|
||||
self._workers.append(worker)
|
||||
|
||||
def _remove_worker(self, worker: Worker, worker_id: int) -> None:
|
||||
"""Remove a specific worker from the pool."""
|
||||
# Stop the worker
|
||||
worker.stop()
|
||||
|
||||
# Wait for it to finish
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=2.0)
|
||||
|
||||
# Remove from list
|
||||
if worker in self._workers:
|
||||
self._workers.remove(worker)
|
||||
|
||||
def _try_scale_up(self, queue_depth: int, current_count: int) -> bool:
|
||||
"""
|
||||
Try to scale up workers if needed.
|
||||
|
||||
Args:
|
||||
queue_depth: Current queue depth
|
||||
current_count: Current number of workers
|
||||
|
||||
Returns:
|
||||
True if scaled up, False otherwise
|
||||
"""
|
||||
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
|
||||
old_count = current_count
|
||||
self._create_worker()
|
||||
|
||||
logger.debug(
|
||||
"Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)",
|
||||
old_count,
|
||||
len(self._workers),
|
||||
queue_depth,
|
||||
self._scale_up_threshold,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool:
|
||||
"""
|
||||
Try to scale down workers if we have excess capacity.
|
||||
|
||||
Args:
|
||||
queue_depth: Current queue depth
|
||||
current_count: Current number of workers
|
||||
active_count: Number of active workers
|
||||
idle_count: Number of idle workers
|
||||
|
||||
Returns:
|
||||
True if scaled down, False otherwise
|
||||
"""
|
||||
# Skip if we're at minimum or have no idle workers
|
||||
if current_count <= self._min_workers or idle_count == 0:
|
||||
return False
|
||||
|
||||
# Check if we have excess capacity
|
||||
has_excess_capacity = (
|
||||
queue_depth <= active_count # Active workers can handle current queue
|
||||
or idle_count > active_count # More idle than active workers
|
||||
or (queue_depth == 0 and idle_count > 0) # No work and have idle workers
|
||||
)
|
||||
|
||||
if not has_excess_capacity:
|
||||
return False
|
||||
|
||||
# Find and remove idle workers that have been idle long enough
|
||||
workers_to_remove = []
|
||||
|
||||
for worker in self._workers:
|
||||
# Check if worker is idle and has exceeded idle time threshold
|
||||
if worker.is_idle and worker.idle_duration >= self._scale_down_idle_time:
|
||||
# Don't remove if it would leave us unable to handle the queue
|
||||
remaining_workers = current_count - len(workers_to_remove) - 1
|
||||
if remaining_workers >= self._min_workers and remaining_workers >= max(1, queue_depth // 2):
|
||||
workers_to_remove.append((worker, worker.worker_id))
|
||||
# Only remove one worker per check to avoid aggressive scaling
|
||||
break
|
||||
|
||||
# Remove idle workers if any found
|
||||
if workers_to_remove:
|
||||
old_count = current_count
|
||||
for worker, worker_id in workers_to_remove:
|
||||
self._remove_worker(worker, worker_id)
|
||||
|
||||
logger.debug(
|
||||
"Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, "
|
||||
"queue_depth=%d, active=%d, idle=%d)",
|
||||
old_count,
|
||||
len(self._workers),
|
||||
len(workers_to_remove),
|
||||
self._scale_down_idle_time,
|
||||
queue_depth,
|
||||
active_count,
|
||||
idle_count - len(workers_to_remove),
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def check_and_scale(self) -> None:
|
||||
"""Check and perform scaling if needed."""
|
||||
with self._lock:
|
||||
@ -143,9 +259,15 @@ class SimpleWorkerPool:
|
||||
current_count = len(self._workers)
|
||||
queue_depth = self._ready_queue.qsize()
|
||||
|
||||
# Simple scaling logic
|
||||
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
|
||||
self._create_worker()
|
||||
# Count active vs idle workers by querying their state directly
|
||||
idle_count = sum(1 for worker in self._workers if worker.is_idle)
|
||||
active_count = current_count - idle_count
|
||||
|
||||
# Try to scale up if queue is backing up
|
||||
self._try_scale_up(queue_depth, current_count)
|
||||
|
||||
# Try to scale down if we have excess capacity
|
||||
self._try_scale_down(queue_depth, current_count, active_count, idle_count)
|
||||
|
||||
def get_worker_count(self) -> int:
|
||||
"""Get current number of workers."""
|
||||
@ -579,7 +579,7 @@ class AgentNode(Node):
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
if message.message.json_object is not None:
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
@ -73,9 +73,6 @@ class DefaultValue(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> "DefaultValue":
|
||||
if self.type is None:
|
||||
raise DefaultValueTypeError("type field is required")
|
||||
|
||||
# Type validation configuration
|
||||
type_validators = {
|
||||
DefaultValueType.STRING: {
|
||||
|
||||
@ -108,8 +108,6 @@ class CodeNode(Node):
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, str):
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a string")
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
raise OutputValidationError(
|
||||
@ -122,8 +120,6 @@ class CodeNode(Node):
|
||||
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, bool):
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
|
||||
|
||||
return value
|
||||
|
||||
@ -136,8 +132,6 @@ class CodeNode(Node):
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, int | float):
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a number")
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
raise OutputValidationError(
|
||||
@ -261,7 +255,13 @@ class CodeNode(Node):
|
||||
)
|
||||
elif output_config.type == SegmentType.NUMBER:
|
||||
# check if number available
|
||||
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
|
||||
value = result.get(output_name)
|
||||
if not isinstance(value, (int, float, None)):
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not a number,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
checked = self._check_number(value=value, variable=f"{prefix}{dot}{output_name}")
|
||||
# If the output is a boolean and the output schema specifies a NUMBER type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
@ -271,8 +271,11 @@ class CodeNode(Node):
|
||||
|
||||
elif output_config.type == SegmentType.STRING:
|
||||
# check if string available
|
||||
value = result.get("output_name")
|
||||
if value is not None and not isinstance(value, str):
|
||||
raise OutputValidationError(f"Output value `{value}` is not string")
|
||||
transformed_result[output_name] = self._check_string(
|
||||
value=result[output_name],
|
||||
value=value,
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == SegmentType.BOOLEAN:
|
||||
@ -282,31 +285,36 @@ class CodeNode(Node):
|
||||
)
|
||||
elif output_config.type == SegmentType.ARRAY_NUMBER:
|
||||
# check if array of number available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
value = result[output_name]
|
||||
if not isinstance(value, list):
|
||||
if value is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
||||
for i, inner_value in enumerate(value):
|
||||
if not isinstance(inner_value, (int, float)):
|
||||
raise OutputValidationError(
|
||||
f"The element at index {i} of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" a number."
|
||||
)
|
||||
_ = self._check_number(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
transformed_result[output_name] = [
|
||||
# If the element is a boolean and the output schema specifies a `array[number]` type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
self._convert_boolean_to_int(
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
self._convert_boolean_to_int(v)
|
||||
for v in value
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_STRING:
|
||||
# check if array of string available
|
||||
@ -369,8 +377,9 @@ class CodeNode(Node):
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
value = result[output_name]
|
||||
if not isinstance(value, list):
|
||||
if value is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
@ -378,10 +387,14 @@ class CodeNode(Node):
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
transformed_result[output_name] = [
|
||||
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
for i, inner_value in enumerate(value):
|
||||
if not isinstance(inner_value, bool | None):
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name}[{i}] is not a boolean,"
|
||||
f" got {type(inner_value)} instead."
|
||||
)
|
||||
_ = self._check_boolean(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
transformed_result[output_name] = value
|
||||
|
||||
else:
|
||||
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
|
||||
|
||||
@ -263,9 +263,6 @@ class Executor:
|
||||
if authorization.config is None:
|
||||
raise AuthorizationConfigError("authorization config is required")
|
||||
|
||||
if self.auth.config.api_key is None:
|
||||
raise AuthorizationConfigError("api_key is required")
|
||||
|
||||
if not authorization.config.header:
|
||||
authorization.config.header = "Authorization"
|
||||
|
||||
@ -409,30 +406,22 @@ class Executor:
|
||||
if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files):
|
||||
for file_entry in self.files:
|
||||
# file_entry should be (key, (filename, content, mime_type)), but handle edge cases
|
||||
if len(file_entry) != 2 or not isinstance(file_entry[1], tuple) or len(file_entry[1]) < 2:
|
||||
if len(file_entry) != 2 or len(file_entry[1]) < 2:
|
||||
continue # skip malformed entries
|
||||
key = file_entry[0]
|
||||
content = file_entry[1][1]
|
||||
body_string += f"--{boundary}\r\n"
|
||||
body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n'
|
||||
# decode content safely
|
||||
if isinstance(content, bytes):
|
||||
try:
|
||||
body_string += content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
body_string += content.decode("utf-8", errors="replace")
|
||||
elif isinstance(content, str):
|
||||
body_string += content
|
||||
else:
|
||||
body_string += f"[Unsupported content type: {type(content).__name__}]"
|
||||
try:
|
||||
body_string += content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
body_string += content.decode("utf-8", errors="replace")
|
||||
body_string += "\r\n"
|
||||
body_string += f"--{boundary}--\r\n"
|
||||
elif self.node_data.body:
|
||||
if self.content:
|
||||
if isinstance(self.content, str):
|
||||
body_string = self.content
|
||||
elif isinstance(self.content, bytes):
|
||||
body_string = self.content.decode("utf-8", errors="replace")
|
||||
body_string = self.content.decode("utf-8", errors="replace")
|
||||
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
|
||||
body_string = urlencode(self.data)
|
||||
elif self.data and self.node_data.body.type == "form-data":
|
||||
|
||||
@ -170,27 +170,21 @@ class ListOperatorNode(Node):
|
||||
)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayBooleanSegment):
|
||||
if not isinstance(condition.value, bool):
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
else:
|
||||
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statment should be unreachable.")
|
||||
return variable
|
||||
|
||||
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
|
||||
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
|
||||
result = sorted(variable.value, reverse=self._node_data.order_by.value == Order.DESC)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
else:
|
||||
result = _order_file(
|
||||
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
|
||||
)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable")
|
||||
|
||||
return variable
|
||||
|
||||
@ -304,7 +298,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
|
||||
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
if key in {"type", "transfer_method"} and isinstance(value, Sequence):
|
||||
if key in {"type", "transfer_method"}:
|
||||
extract_func = _get_file_extract_string_func(key=key)
|
||||
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
|
||||
elif key == "size" and isinstance(value, str):
|
||||
|
||||
@ -195,9 +195,8 @@ class LLMNode(Node):
|
||||
generator = self._fetch_context(node_data=self._node_data)
|
||||
context = None
|
||||
for event in generator:
|
||||
if isinstance(event, RunRetrieverResourceEvent):
|
||||
context = event.context
|
||||
yield event
|
||||
context = event.context
|
||||
yield event
|
||||
if context:
|
||||
node_inputs["#context#"] = context
|
||||
|
||||
@ -282,7 +281,7 @@ class LLMNode(Node):
|
||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output.structured_output
|
||||
if self._file_outputs is not None:
|
||||
if self._file_outputs:
|
||||
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
|
||||
|
||||
# Send final chunk event to indicate streaming is complete
|
||||
@ -827,9 +826,7 @@ class LLMNode(Node):
|
||||
|
||||
prompt_template = typed_node_data.prompt_template
|
||||
variable_selectors = []
|
||||
if isinstance(prompt_template, list) and all(
|
||||
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
|
||||
):
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
if prompt.edition_type != "jinja2":
|
||||
variable_template_parser = VariableTemplateParser(template=prompt.text)
|
||||
@ -1063,7 +1060,7 @@ class LLMNode(Node):
|
||||
return
|
||||
if isinstance(contents, str):
|
||||
yield contents
|
||||
elif isinstance(contents, list):
|
||||
else:
|
||||
for item in contents:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
yield item.data
|
||||
@ -1077,9 +1074,6 @@ class LLMNode(Node):
|
||||
else:
|
||||
logger.warning("unknown item type encountered, type=%s", type(item))
|
||||
yield str(item)
|
||||
else:
|
||||
logger.warning("unknown contents type encountered, type=%s", type(contents))
|
||||
yield str(contents)
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
|
||||
@ -147,14 +147,14 @@ class LoopNode(Node):
|
||||
for key, value in graph_engine.graph_runtime_state.outputs.items():
|
||||
if key == "answer":
|
||||
# Concatenate answer outputs with newline
|
||||
existing_answer = self.graph_runtime_state.outputs.get("answer", "")
|
||||
existing_answer = self.graph_runtime_state.get_output("answer", "")
|
||||
if existing_answer:
|
||||
self.graph_runtime_state.outputs["answer"] = f"{existing_answer}{value}"
|
||||
self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}")
|
||||
else:
|
||||
self.graph_runtime_state.outputs["answer"] = value
|
||||
self.graph_runtime_state.set_output("answer", value)
|
||||
else:
|
||||
# For other outputs, just update
|
||||
self.graph_runtime_state.outputs[key] = value
|
||||
self.graph_runtime_state.set_output(key, value)
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
cost_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
|
||||
@ -31,8 +31,6 @@ _VALID_PARAMETER_TYPES = frozenset(
|
||||
|
||||
|
||||
def _validate_type(parameter_type: str) -> SegmentType:
|
||||
if not isinstance(parameter_type, str):
|
||||
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
|
||||
if parameter_type not in _VALID_PARAMETER_TYPES:
|
||||
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ from core.file import File
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import ImagePromptMessageContent
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
@ -38,7 +38,6 @@ from factories.variable_factory import build_segment_with_type
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
InvalidInvokeResultError,
|
||||
InvalidModelModeError,
|
||||
InvalidModelTypeError,
|
||||
InvalidNumberOfParametersError,
|
||||
@ -304,8 +303,6 @@ class ParameterExtractorNode(Node):
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
if not isinstance(invoke_result, LLMResult):
|
||||
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
|
||||
|
||||
text = invoke_result.message.content or ""
|
||||
if not isinstance(text, str):
|
||||
@ -317,9 +314,6 @@ class ParameterExtractorNode(Node):
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
if text is None:
|
||||
text = ""
|
||||
|
||||
return text, usage, tool_call
|
||||
|
||||
def _generate_function_call_prompt(
|
||||
@ -583,18 +577,19 @@ class ParameterExtractorNode(Node):
|
||||
return int(value)
|
||||
elif isinstance(value, (int, float)):
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
return None
|
||||
if "." in value:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
elif isinstance(value, str):
|
||||
if "." in value:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
"""
|
||||
@ -697,7 +692,7 @@ class ParameterExtractorNode(Node):
|
||||
for parameter in data.parameters:
|
||||
if parameter.type == "number":
|
||||
result[parameter.name] = 0
|
||||
elif parameter.type == "bool":
|
||||
elif parameter.type == "boolean":
|
||||
result[parameter.name] = False
|
||||
elif parameter.type in {"string", "select"}:
|
||||
result[parameter.name] = ""
|
||||
|
||||
@ -323,7 +323,7 @@ class ToolNode(Node):
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
# JSON message handling for tool node
|
||||
if message.message.json_object is not None:
|
||||
if message.message.json_object:
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
@ -117,13 +117,8 @@ class VariableAssignerNode(Node):
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = get_zero_value(original_variable.value_type)
|
||||
if income_value is None:
|
||||
raise VariableOperatorNodeError("income value not found")
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||
|
||||
|
||||
@ -25,8 +25,6 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
# Only array variable can be appended or extended
|
||||
# Only array variable can have elements removed
|
||||
return variable_type.is_array_type()
|
||||
case _:
|
||||
return False
|
||||
|
||||
|
||||
def is_variable_input_supported(*, operation: Operation):
|
||||
|
||||
@ -274,5 +274,3 @@ class VariableAssignerNode(Node):
|
||||
if not variable.value:
|
||||
return variable.value
|
||||
return variable.value[:-1]
|
||||
case _:
|
||||
raise OperationNotSupportedError(operation=operation, variable_type=variable.value_type)
|
||||
|
||||
Reference in New Issue
Block a user