mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 05:58:14 +08:00
Merge branch 'main' into feat/pull-a-variable
This commit is contained in:
22
api/core/workflow/context/__init__.py
Normal file
22
api/core/workflow/context/__init__.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""
|
||||
Execution Context - Context management for workflow execution.
|
||||
|
||||
This package provides Flask-independent context management for workflow
|
||||
execution in multi-threaded environments.
|
||||
"""
|
||||
|
||||
from core.workflow.context.execution_context import (
|
||||
AppContext,
|
||||
ExecutionContext,
|
||||
IExecutionContext,
|
||||
NullAppContext,
|
||||
capture_current_context,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AppContext",
|
||||
"ExecutionContext",
|
||||
"IExecutionContext",
|
||||
"NullAppContext",
|
||||
"capture_current_context",
|
||||
]
|
||||
216
api/core/workflow/context/execution_context.py
Normal file
216
api/core/workflow/context/execution_context.py
Normal file
@ -0,0 +1,216 @@
|
||||
"""
|
||||
Execution Context - Abstracted context management for workflow execution.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Protocol, final, runtime_checkable
|
||||
|
||||
|
||||
class AppContext(ABC):
|
||||
"""
|
||||
Abstract application context interface.
|
||||
|
||||
This abstraction allows workflow execution to work with or without Flask
|
||||
by providing a common interface for application context management.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value by key."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get Flask extension by name (e.g., 'db', 'cache')."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enter(self) -> AbstractContextManager[None]:
|
||||
"""Enter the application context."""
|
||||
pass
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IExecutionContext(Protocol):
|
||||
"""
|
||||
Protocol for execution context.
|
||||
|
||||
This protocol defines the interface that all execution contexts must implement,
|
||||
allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
|
||||
"""
|
||||
|
||||
def __enter__(self) -> "IExecutionContext":
|
||||
"""Enter the execution context."""
|
||||
...
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
"""Exit the execution context."""
|
||||
...
|
||||
|
||||
@property
|
||||
def user(self) -> Any:
|
||||
"""Get user object."""
|
||||
...
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionContext:
|
||||
"""
|
||||
Execution context for workflow execution in worker threads.
|
||||
|
||||
This class encapsulates all context needed for workflow execution:
|
||||
- Application context (Flask app or standalone)
|
||||
- Context variables for Python contextvars
|
||||
- User information (optional)
|
||||
|
||||
It is designed to be serializable and passable to worker threads.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_context: AppContext | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
user: Any = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize execution context.
|
||||
|
||||
Args:
|
||||
app_context: Application context (Flask or standalone)
|
||||
context_vars: Python contextvars to preserve
|
||||
user: User object (optional)
|
||||
"""
|
||||
self._app_context = app_context
|
||||
self._context_vars = context_vars
|
||||
self._user = user
|
||||
|
||||
@property
|
||||
def app_context(self) -> AppContext | None:
|
||||
"""Get application context."""
|
||||
return self._app_context
|
||||
|
||||
@property
|
||||
def context_vars(self) -> contextvars.Context | None:
|
||||
"""Get context variables."""
|
||||
return self._context_vars
|
||||
|
||||
@property
|
||||
def user(self) -> Any:
|
||||
"""Get user object."""
|
||||
return self._user
|
||||
|
||||
@contextmanager
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""
|
||||
Enter this execution context.
|
||||
|
||||
This is a convenience method that creates a context manager.
|
||||
"""
|
||||
# Restore context variables if provided
|
||||
if self._context_vars:
|
||||
for var, val in self._context_vars.items():
|
||||
var.set(val)
|
||||
|
||||
# Enter app context if available
|
||||
if self._app_context is not None:
|
||||
with self._app_context.enter():
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
def __enter__(self) -> "ExecutionContext":
|
||||
"""Enter the execution context."""
|
||||
self._cm = self.enter()
|
||||
self._cm.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
"""Exit the execution context."""
|
||||
if hasattr(self, "_cm"):
|
||||
self._cm.__exit__(*args)
|
||||
|
||||
|
||||
class NullAppContext(AppContext):
|
||||
"""
|
||||
Null implementation of AppContext for non-Flask environments.
|
||||
|
||||
This is used when running without Flask (e.g., in tests or standalone mode).
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict[str, Any] | None = None) -> None:
|
||||
"""
|
||||
Initialize null app context.
|
||||
|
||||
Args:
|
||||
config: Optional configuration dictionary
|
||||
"""
|
||||
self._config = config or {}
|
||||
self._extensions: dict[str, Any] = {}
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value by key."""
|
||||
return self._config.get(key, default)
|
||||
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get extension by name."""
|
||||
return self._extensions.get(name)
|
||||
|
||||
def set_extension(self, name: str, extension: Any) -> None:
|
||||
"""Set extension by name."""
|
||||
self._extensions[name] = extension
|
||||
|
||||
@contextmanager
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter null context (no-op)."""
|
||||
yield
|
||||
|
||||
|
||||
class ExecutionContextBuilder:
|
||||
"""
|
||||
Builder for creating ExecutionContext instances.
|
||||
|
||||
This provides a fluent API for building execution contexts.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._app_context: AppContext | None = None
|
||||
self._context_vars: contextvars.Context | None = None
|
||||
self._user: Any = None
|
||||
|
||||
def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder":
|
||||
"""Set application context."""
|
||||
self._app_context = app_context
|
||||
return self
|
||||
|
||||
def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder":
|
||||
"""Set context variables."""
|
||||
self._context_vars = context_vars
|
||||
return self
|
||||
|
||||
def with_user(self, user: Any) -> "ExecutionContextBuilder":
|
||||
"""Set user."""
|
||||
self._user = user
|
||||
return self
|
||||
|
||||
def build(self) -> ExecutionContext:
|
||||
"""Build the execution context."""
|
||||
return ExecutionContext(
|
||||
app_context=self._app_context,
|
||||
context_vars=self._context_vars,
|
||||
user=self._user,
|
||||
)
|
||||
|
||||
|
||||
def capture_current_context() -> IExecutionContext:
|
||||
"""
|
||||
Capture current execution context from the calling environment.
|
||||
|
||||
Returns:
|
||||
IExecutionContext with captured context
|
||||
"""
|
||||
from context import capture_current_context
|
||||
|
||||
return capture_current_context()
|
||||
@ -7,15 +7,13 @@ Domain-Driven Design principles for improved maintainability and testability.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.workflow.context import capture_current_context
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
@ -159,17 +157,8 @@ class GraphEngine:
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
|
||||
# === Worker Pool Setup ===
|
||||
# Capture Flask app context for worker threads
|
||||
flask_app: Flask | None = None
|
||||
try:
|
||||
app = current_app._get_current_object() # type: ignore
|
||||
if isinstance(app, Flask):
|
||||
flask_app = app
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# Capture context variables for worker threads
|
||||
context_vars = contextvars.copy_context()
|
||||
# Capture execution context for worker threads
|
||||
execution_context = capture_current_context()
|
||||
|
||||
# Create worker pool for parallel node execution
|
||||
self._worker_pool = WorkerPool(
|
||||
@ -177,8 +166,7 @@ class GraphEngine:
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
layers=self._layers,
|
||||
flask_app=flask_app,
|
||||
context_vars=context_vars,
|
||||
execution_context=execution_context,
|
||||
min_workers=self._min_workers,
|
||||
max_workers=self._max_workers,
|
||||
scale_up_threshold=self._scale_up_threshold,
|
||||
|
||||
@ -5,26 +5,27 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events
|
||||
to the event_queue for the dispatcher to process.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import final
|
||||
from typing import TYPE_CHECKING, final
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import Flask
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.context import IExecutionContext
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@final
|
||||
class Worker(threading.Thread):
|
||||
@ -44,8 +45,7 @@ class Worker(threading.Thread):
|
||||
layers: Sequence[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
worker_id: int = 0,
|
||||
flask_app: Flask | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
execution_context: IExecutionContext | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize worker thread.
|
||||
@ -56,19 +56,17 @@ class Worker(threading.Thread):
|
||||
graph: Graph containing nodes to execute
|
||||
layers: Graph engine layers for node execution hooks
|
||||
worker_id: Unique identifier for this worker
|
||||
flask_app: Optional Flask application for context preservation
|
||||
context_vars: Optional context variables to preserve in worker thread
|
||||
execution_context: Optional execution context for context preservation
|
||||
"""
|
||||
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._worker_id = worker_id
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._last_task_time = time.time()
|
||||
self._execution_context = execution_context
|
||||
self._stop_event = stop_event
|
||||
self._layers = layers if layers is not None else []
|
||||
self._last_task_time = time.time()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Worker is controlled via shared stop_event from GraphEngine.
|
||||
@ -135,11 +133,9 @@ class Worker(threading.Thread):
|
||||
|
||||
error: Exception | None = None
|
||||
|
||||
if self._flask_app and self._context_vars:
|
||||
with preserve_flask_contexts(
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
):
|
||||
# Execute the node with preserved context if execution context is provided
|
||||
if self._execution_context is not None:
|
||||
with self._execution_context:
|
||||
self._invoke_node_run_start_hooks(node)
|
||||
try:
|
||||
node_events = node.run()
|
||||
|
||||
@ -8,9 +8,10 @@ DynamicScaler, and WorkerFactory into a single class.
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, final
|
||||
from typing import final
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.context import IExecutionContext
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase
|
||||
|
||||
@ -20,11 +21,6 @@ from ..worker import Worker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from contextvars import Context
|
||||
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@final
|
||||
class WorkerPool:
|
||||
@ -42,8 +38,7 @@ class WorkerPool:
|
||||
graph: Graph,
|
||||
layers: list[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
flask_app: "Flask | None" = None,
|
||||
context_vars: "Context | None" = None,
|
||||
execution_context: IExecutionContext | None = None,
|
||||
min_workers: int | None = None,
|
||||
max_workers: int | None = None,
|
||||
scale_up_threshold: int | None = None,
|
||||
@ -57,8 +52,7 @@ class WorkerPool:
|
||||
event_queue: Queue for worker events
|
||||
graph: The workflow graph
|
||||
layers: Graph engine layers for node execution hooks
|
||||
flask_app: Optional Flask app for context preservation
|
||||
context_vars: Optional context variables
|
||||
execution_context: Optional execution context for context preservation
|
||||
min_workers: Minimum number of workers
|
||||
max_workers: Maximum number of workers
|
||||
scale_up_threshold: Queue depth to trigger scale up
|
||||
@ -67,8 +61,7 @@ class WorkerPool:
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._execution_context = execution_context
|
||||
self._layers = layers
|
||||
|
||||
# Scaling parameters with defaults
|
||||
@ -152,8 +145,7 @@ class WorkerPool:
|
||||
graph=self._graph,
|
||||
layers=self._layers,
|
||||
worker_id=worker_id,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
execution_context=self._execution_context,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
import contextvars
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
@ -39,7 +37,6 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .exc import (
|
||||
InvalidIteratorValueError,
|
||||
@ -51,6 +48,7 @@ from .exc import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.context import IExecutionContext
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -252,8 +250,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
self._execute_single_iteration_parallel,
|
||||
index=index,
|
||||
item=item,
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
context_vars=contextvars.copy_context(),
|
||||
execution_context=self._capture_execution_context(),
|
||||
)
|
||||
future_to_index[future] = index
|
||||
|
||||
@ -306,11 +303,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
self,
|
||||
index: int,
|
||||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
execution_context: "IExecutionContext",
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
with execution_context:
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
events: list[GraphNodeEventBase] = []
|
||||
outputs_temp: list[object] = []
|
||||
@ -339,6 +335,12 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
graph_engine.graph_runtime_state.llm_usage,
|
||||
)
|
||||
|
||||
def _capture_execution_context(self) -> "IExecutionContext":
|
||||
"""Capture current execution context for parallel iterations."""
|
||||
from core.workflow.context import capture_current_context
|
||||
|
||||
return capture_current_context()
|
||||
|
||||
def _handle_iteration_success(
|
||||
self,
|
||||
started_at: datetime,
|
||||
|
||||
Reference in New Issue
Block a user