Compare commits

..

5 Commits

Author SHA1 Message Date
e2bd7dbd0d test(workflow): cover async persistence patch paths
Add workflow execution task coverage and async repository validation coverage for the mixed persistence implementation so Codecov patch coverage includes the changed persistence paths.
2026-05-20 21:59:39 +08:00
37636f78f5 test(workflow): cover async node persistence tasks
Expose the SQLAlchemy repository type to static analyzers without changing the runtime lazy import, and add coverage for metadata and data Celery task paths used by async workflow persistence.
2026-05-20 21:36:54 +08:00
b7834a42b6 Merge branch 'main' into laipz8200/workflow-persistence-mixed-mode 2026-05-20 21:08:51 +08:00
21387b3beb [autofix.ci] apply automated fixes 2026-05-20 13:02:56 +00:00
e4619b5b73 refactor(workflow): consolidate persistence async writes
Route workflow persistence mode from InvokeFrom so debugger executions keep synchronous DB writes while non-debug invocations enqueue Celery tasks through the default SQLAlchemy repositories.

Remove the legacy Celery workflow execution repositories, obsolete workflow node execution storage config, and tests tied only to the removed repository classes.
2026-05-20 20:59:18 +08:00
21 changed files with 853 additions and 1474 deletions

View File

@ -565,12 +565,6 @@ GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
# Seconds of idle time before scaling down workers (default: 5.0)
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# Repository configuration
# Core workflow execution repository implementation
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository

View File

@ -791,11 +791,6 @@ class WorkflowNodeExecutionConfig(BaseSettings):
default=100,
)
WORKFLOW_NODE_EXECUTION_STORAGE: str = Field(
default="rdbms",
description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'",
)
class RepositoryConfig(BaseSettings):
"""
@ -803,18 +798,12 @@ class RepositoryConfig(BaseSettings):
"""
CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowExecution. Options: "
"'core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository' (default), "
"'core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository'",
description="Repository implementation for WorkflowExecution. Specify as a module path.",
default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
)
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowNodeExecution. Options: "
"'core.repositories.sqlalchemy_workflow_node_execution_repository."
"SQLAlchemyWorkflowNodeExecutionRepository' (default), "
"'core.repositories.celery_workflow_node_execution_repository."
"CeleryWorkflowNodeExecutionRepository'",
description="Repository implementation for WorkflowNodeExecution. Specify as a module path.",
default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
)

View File

@ -235,6 +235,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
invoke_from=invoke_from,
trace_manager=self.application_generate_entity.trace_manager,
)

View File

@ -195,6 +195,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
invoke_from=invoke_from,
trace_manager=self.application_generate_entity.trace_manager,
)

View File

@ -162,6 +162,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
invoke_from=invoke_from,
trace_manager=self.application_generate_entity.trace_manager,
)

View File

@ -4,17 +4,19 @@ This layer mirrors the former ``WorkflowCycleManager`` responsibilities by
listening to ``GraphEngineEvent`` instances directly and persisting workflow
and node execution state via the injected repositories.
The design keeps domain persistence concerns inside the engine thread, while
allowing presentation layers to remain read-only observers of repository
state.
The layer owns domain-to-persistence event handling, while the injected
repositories choose the write strategy. Debug executions use synchronous
writes so developer tools can read DB state immediately; non-debug app
executions use a Celery-backed write path to keep DB writes out of the engine
thread.
"""
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Union
from typing import Any, Protocol, runtime_checkable
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.helper.trace_id_helper import ParentTraceContext
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
@ -59,6 +61,21 @@ class PersistenceWorkflowInfo:
graph_data: Mapping[str, Any]
def should_use_async_workflow_persistence(invoke_from: InvokeFrom) -> bool:
    """Decide whether workflow execution state is persisted through Celery.

    Args:
        invoke_from: Origin of the workflow invocation.

    Returns:
        ``False`` when the invocation comes from ``InvokeFrom.DEBUGGER``,
        ``True`` for every other origin.
    """
    is_debugger_run = invoke_from == InvokeFrom.DEBUGGER
    return not is_debugger_run
@runtime_checkable
class AsyncPersistenceConfigurable(Protocol):
    """Structural type for objects whose persistence mode can be toggled.

    The protocol is runtime-checkable, so ``isinstance`` accepts any object
    that exposes a ``set_async_persistence`` member.
    """

    def set_async_persistence(self, enabled: bool) -> None:
        """Turn asynchronous persistence on (``True``) or off (``False``)."""
        ...
def _configure_async_persistence(repository: object, enabled: bool) -> None:
    """Forward the async-persistence flag to repositories that accept it.

    Args:
        repository: Any repository object; it is only touched when it
            satisfies ``AsyncPersistenceConfigurable``.
        enabled: Value handed to ``repository.set_async_persistence``.
    """
    # Repositories without the hook are left exactly as they are.
    if not isinstance(repository, AsyncPersistenceConfigurable):
        return
    repository.set_async_persistence(enabled)
@dataclass(slots=True)
class _NodeRuntimeSnapshot:
"""Lightweight cache to keep node metadata across event phases."""
@ -77,10 +94,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
application_generate_entity: AdvancedChatAppGenerateEntity | WorkflowAppGenerateEntity,
workflow_info: PersistenceWorkflowInfo,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
invoke_from: InvokeFrom | None = None,
trace_manager: TraceQueueManager | None = None,
) -> None:
super().__init__()
@ -88,6 +106,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
self._workflow_info = workflow_info
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
use_async_persistence = should_use_async_workflow_persistence(
invoke_from or application_generate_entity.invoke_from
)
_configure_async_persistence(self._workflow_execution_repository, use_async_persistence)
_configure_async_persistence(self._workflow_node_execution_repository, use_async_persistence)
self._trace_manager = trace_manager
self._workflow_execution: WorkflowExecution | None = None

View File

@ -2,8 +2,6 @@
from __future__ import annotations
from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from .factory import (
DifyCoreRepositoryFactory,
OrderConfig,
@ -15,8 +13,6 @@ from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutio
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"CeleryWorkflowExecutionRepository",
"CeleryWorkflowNodeExecutionRepository",
"DifyCoreRepositoryFactory",
"OrderConfig",
"RepositoryImportError",

View File

@ -1,125 +0,0 @@
"""
Celery-based implementation of the WorkflowExecutionRepository.
This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""
import logging
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.factory import WorkflowExecutionRepository
from graphon.entities import WorkflowExecution
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.enums import WorkflowRunTriggeredFrom
from tasks.workflow_execution_tasks import (
save_workflow_execution_task,
)
logger = logging.getLogger(__name__)
class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
    """
    WorkflowExecutionRepository implementation that persists through Celery.

    ``save`` serializes the execution and enqueues ``save_workflow_execution_task``
    instead of writing to the database itself. Tenant, app, trigger source and
    creator identity are captured once in the constructor and attached to every
    queued task.
    """

    _session_factory: sessionmaker
    _tenant_id: str
    _app_id: str | None
    _triggered_from: WorkflowRunTriggeredFrom | None
    _creator_user_id: str
    _creator_user_role: CreatorUserRole

    def __init__(
        self,
        session_factory: sessionmaker | Engine,
        user: Account | EndUser,
        app_id: str | None,
        triggered_from: WorkflowRunTriggeredFrom | None,
    ):
        """
        Capture the session factory and the tenant/app/creator context.

        Args:
            session_factory: SQLAlchemy sessionmaker, or an Engine that gets wrapped in one
            user: Account or EndUser supplying the tenant id, user id and creator role
            app_id: App ID attached to queued tasks (can be None)
            triggered_from: Source of the execution trigger (can be None)

        Raises:
            ValueError: If ``session_factory`` is neither a sessionmaker nor an Engine,
                or if no tenant id can be extracted from ``user``.
        """
        # Reject unsupported factory types up front, then normalize to a sessionmaker.
        if not isinstance(session_factory, (Engine, sessionmaker)):
            raise ValueError(
                f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
            )
        if isinstance(session_factory, Engine):
            self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
        else:
            self._session_factory = session_factory

        # The tenant id is mandatory for every queued task.
        resolved_tenant_id = extract_tenant_id(user)
        if not resolved_tenant_id:
            raise ValueError("User must have a tenant_id or current_tenant_id")
        self._tenant_id = resolved_tenant_id

        self._app_id = app_id
        self._triggered_from = triggered_from
        self._creator_user_id = user.id

        # Accounts and end users are recorded under different creator roles.
        if isinstance(user, Account):
            self._creator_user_role = CreatorUserRole.ACCOUNT
        else:
            self._creator_user_role = CreatorUserRole.END_USER

        logger.info(
            "Initialized CeleryWorkflowExecutionRepository for tenant %s, app %s, triggered_from %s",
            self._tenant_id,
            self._app_id,
            self._triggered_from,
        )

    def save(self, execution: WorkflowExecution):
        """
        Queue a Celery task that saves or updates the given WorkflowExecution.

        The method returns as soon as the task has been enqueued; it does not wait
        for the write to happen.

        Args:
            execution: The WorkflowExecution instance to save or update

        Raises:
            Exception: Any error raised while serializing or enqueueing is logged
                and re-raised unchanged.
        """
        try:
            # Missing app id / trigger source are sent as empty strings.
            task_kwargs = {
                "execution_data": execution.model_dump(),
                "tenant_id": self._tenant_id,
                "app_id": self._app_id or "",
                "triggered_from": self._triggered_from.value if self._triggered_from else "",
                "creator_user_id": self._creator_user_id,
                "creator_user_role": self._creator_user_role.value,
            }
            # Fire and forget: the task result is not tracked.
            save_workflow_execution_task.delay(**task_kwargs)  # type: ignore

            logger.debug("Queued async save for workflow execution: %s", execution.id_)
        except Exception:
            logger.exception("Failed to queue save operation for execution %s", execution.id_)
            # No synchronous fallback exists; surface the failure to the caller.
            raise

View File

@ -1,196 +0,0 @@
"""
Celery-based implementation of the WorkflowNodeExecutionRepository.
This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""
import logging
from collections.abc import Sequence
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.factory import (
OrderConfig,
WorkflowNodeExecutionRepository,
)
from graphon.entities import WorkflowNodeExecution
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
from tasks.workflow_node_execution_tasks import (
save_workflow_node_execution_task,
)
logger = logging.getLogger(__name__)
class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
    """
    WorkflowNodeExecutionRepository implementation that persists through Celery.

    ``save`` keeps each execution in an in-memory cache and enqueues
    ``save_workflow_node_execution_task`` for the database write.
    ``get_by_workflow_execution`` is answered from that cache only; it never
    queries the database.
    """

    _session_factory: sessionmaker
    _tenant_id: str
    _app_id: str | None
    _triggered_from: WorkflowNodeExecutionTriggeredFrom | None
    _creator_user_id: str
    _creator_user_role: CreatorUserRole
    _execution_cache: dict[str, WorkflowNodeExecution]
    _workflow_execution_mapping: dict[str, list[str]]

    def __init__(
        self,
        session_factory: sessionmaker | Engine,
        user: Account | EndUser,
        app_id: str | None,
        triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
    ):
        """
        Capture the session factory, the tenant/app/creator context and empty caches.

        Args:
            session_factory: SQLAlchemy sessionmaker, or an Engine that gets wrapped in one
            user: Account or EndUser supplying the tenant id, user id and creator role
            app_id: App ID attached to queued tasks (can be None)
            triggered_from: Source of the execution trigger (can be None)

        Raises:
            ValueError: If ``session_factory`` is neither a sessionmaker nor an Engine,
                or if no tenant id can be extracted from ``user``.
        """
        # Reject unsupported factory types up front, then normalize to a sessionmaker.
        if not isinstance(session_factory, (Engine, sessionmaker)):
            raise ValueError(
                f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
            )
        if isinstance(session_factory, Engine):
            self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
        else:
            self._session_factory = session_factory

        # The tenant id is mandatory for every queued task.
        resolved_tenant_id = extract_tenant_id(user)
        if not resolved_tenant_id:
            raise ValueError("User must have a tenant_id or current_tenant_id")
        self._tenant_id = resolved_tenant_id

        self._app_id = app_id
        self._triggered_from = triggered_from
        self._creator_user_id = user.id

        # Accounts and end users are recorded under different creator roles.
        if isinstance(user, Account):
            self._creator_user_role = CreatorUserRole.ACCOUNT
        else:
            self._creator_user_role = CreatorUserRole.END_USER

        # execution id -> execution, plus workflow_execution_id -> ordered execution ids
        self._execution_cache = {}
        self._workflow_execution_mapping = {}

        logger.info(
            "Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",
            self._tenant_id,
            self._app_id,
            self._triggered_from,
        )

    def save(self, execution: WorkflowNodeExecution):
        """
        Cache the node execution and queue a Celery task that writes it to the database.

        The cache is updated before the task is enqueued, so the entry stays cached
        even when enqueueing fails and the exception propagates.

        Args:
            execution: The WorkflowNodeExecution instance to save or update

        Raises:
            Exception: Any error raised while caching, serializing or enqueueing is
                logged and re-raised unchanged.
        """
        try:
            self._execution_cache[execution.id] = execution

            # Remember which executions belong to which workflow run, without duplicates.
            owning_run_id = execution.workflow_execution_id
            if owning_run_id:
                known_ids = self._workflow_execution_mapping.setdefault(owning_run_id, [])
                if execution.id not in known_ids:
                    known_ids.append(execution.id)

            # Missing app id / trigger source are sent as empty strings.
            task_kwargs = {
                "execution_data": execution.model_dump(),
                "tenant_id": self._tenant_id,
                "app_id": self._app_id or "",
                "triggered_from": self._triggered_from.value if self._triggered_from else "",
                "creator_user_id": self._creator_user_id,
                "creator_user_role": self._creator_user_role.value,
            }
            # Fire and forget: the task result is not tracked.
            save_workflow_node_execution_task.delay(**task_kwargs)

            logger.debug("Cached and queued async save for workflow node execution: %s", execution.id)
        except Exception:
            logger.exception("Failed to cache or queue save operation for node execution %s", execution.id)
            # No synchronous fallback exists; surface the failure to the caller.
            raise

    def get_by_workflow_execution(
        self,
        workflow_execution_id: str,
        order_config: OrderConfig | None = None,
    ) -> Sequence[WorkflowNodeExecution]:
        """
        Return the cached node executions recorded for one workflow execution.

        Args:
            workflow_execution_id: The workflow execution identifier
            order_config: Optional configuration for ordering results

        Returns:
            The cached WorkflowNodeExecution instances, or an empty list when
            nothing is cached or when any error occurs during lookup or sorting.
        """
        try:
            cached_ids = self._workflow_execution_mapping.get(workflow_execution_id, [])
            result = [
                self._execution_cache[cached_id] for cached_id in cached_ids if cached_id in self._execution_cache
            ]

            if order_config and result:
                descending = order_config.order_direction == "desc"
                # Stable sorts applied from the last key to the first give multi-key ordering.
                # NOTE(review): a TypeError from comparing mixed values (e.g. None vs. a real
                # value) lands in the except below and yields [] — confirm that is intended.
                for field_name in reversed(order_config.order_by):
                    result.sort(
                        key=lambda node, attr=field_name: getattr(node, attr, 0),
                        reverse=descending,
                    )

            logger.debug(
                "Retrieved %d workflow node executions for execution %s from cache",
                len(result),
                workflow_execution_id,
            )
            return result
        except Exception:
            logger.exception(
                "Failed to get workflow node executions for execution %s from cache",
                workflow_execution_id,
            )
            return []

View File

@ -20,6 +20,7 @@ from models import (
WorkflowRun,
)
from models.enums import WorkflowRunTriggeredFrom
from tasks.workflow_execution_tasks import save_workflow_execution_task
logger = logging.getLogger(__name__)
@ -36,6 +37,8 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
performance by reducing database queries.
"""
_use_async_persistence: bool
def __init__(
self,
session_factory: sessionmaker | Engine,
@ -81,6 +84,16 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# Initialize in-memory cache for workflow executions
# Key: execution_id, Value: WorkflowRun (DB model)
self._execution_cache: dict[str, WorkflowRun] = {}
self._use_async_persistence = False
def set_async_persistence(self, enabled: bool) -> None:
    """
    Select the write path used by later save operations.

    Args:
        enabled: ``True`` queues saves through Celery; ``False`` keeps the
            synchronous database write.

    Debug executions leave this off so the debugger can read persisted workflow
    state immediately; the persistence layer switches it on for non-debug app
    executions.
    """
    self._use_async_persistence = enabled
def _to_domain_model(self, db_model: WorkflowRun) -> WorkflowExecution:
"""
@ -190,6 +203,10 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
Args:
execution: The WorkflowExecution domain entity to persist
"""
if self._use_async_persistence:
self._queue_async_save(execution)
return
# Convert domain model to database model using tenant context and other attributes
db_model = self._to_db_model(execution)
@ -209,3 +226,20 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# Update the in-memory cache for faster subsequent lookups
self._execution_cache[db_model.id] = db_model
def _queue_async_save(self, execution: WorkflowExecution) -> None:
if not self._triggered_from:
raise ValueError("triggered_from is required in repository constructor")
if not self._creator_user_id:
raise ValueError("created_by is required in repository constructor")
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
save_workflow_execution_task.delay(
execution_data=execution.model_dump(),
tenant_id=self._tenant_id,
app_id=self._app_id or "",
triggered_from=self._triggered_from.value,
creator_user_id=self._creator_user_id,
creator_user_role=self._creator_user_role.value,
)

View File

@ -37,6 +37,10 @@ from models.model import UploadFile
from models.workflow import WorkflowNodeExecutionOffload
from services.file_service import FileService
from services.variable_truncator import VariableTruncator
from tasks.workflow_node_execution_tasks import (
save_workflow_node_execution_data_task,
save_workflow_node_execution_task,
)
logger = logging.getLogger(__name__)
@ -60,6 +64,8 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
performance by reducing database queries.
"""
_use_async_persistence: bool
def __init__(
self,
session_factory: sessionmaker | Engine,
@ -108,6 +114,16 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
# Initialize FileService for handling offloaded data
self._file_service = FileService(session_factory)
self._use_async_persistence = False
def set_async_persistence(self, enabled: bool) -> None:
    """
    Select the write path used by later save operations.

    Args:
        enabled: ``True`` queues saves through Celery; ``False`` keeps the
            synchronous database write.

    Debug executions leave this off so node data is readable immediately; the
    workflow persistence layer switches it on for non-debug app executions.
    """
    self._use_async_persistence = enabled
def _create_truncator(self) -> VariableTruncator:
return VariableTruncator(
@ -336,6 +352,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
# Only the final call contains the complete inputs and outputs payloads, so earlier invocations
# must tolerate missing data without attempting to offload variables.
if self._use_async_persistence:
self._queue_async_save(execution)
return
# Convert domain model to database model using tenant context and other attributes
db_model = self._to_db_model(execution)
@ -400,6 +420,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self._node_execution_cache[db_model.node_execution_id] = db_model
def save_execution_data(self, execution: WorkflowNodeExecution):
if self._use_async_persistence:
self._queue_async_save_execution_data(execution)
return
domain_model = execution
with self._session_factory(expire_on_commit=False) as session:
query = WorkflowNodeExecutionModel.preload_offload_data(select(WorkflowNodeExecutionModel)).where(
@ -457,6 +481,40 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
session.merge(db_model)
session.flush()
def _queue_async_save(self, execution: WorkflowNodeExecution) -> None:
if not self._triggered_from:
raise ValueError("triggered_from is required in repository constructor")
if not self._creator_user_id:
raise ValueError("created_by is required in repository constructor")
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
save_workflow_node_execution_task.delay(
execution_data=execution.model_dump(),
tenant_id=self._tenant_id,
app_id=self._app_id or "",
triggered_from=self._triggered_from.value,
creator_user_id=self._creator_user_id,
creator_user_role=self._creator_user_role.value,
)
def _queue_async_save_execution_data(self, execution: WorkflowNodeExecution) -> None:
if not self._triggered_from:
raise ValueError("triggered_from is required in repository constructor")
if not self._creator_user_id:
raise ValueError("created_by is required in repository constructor")
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
save_workflow_node_execution_data_task.delay(
execution_data=execution.model_dump(),
tenant_id=self._tenant_id,
app_id=self._app_id or "",
triggered_from=self._triggered_from.value,
creator_user_id=self._creator_user_id,
creator_user_role=self._creator_user_role.value,
)
def get_db_models_by_workflow_run(
self,
workflow_run_id: str,

View File

@ -108,9 +108,10 @@ def _create_workflow_run_from_execution(
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.elapsed_time = _calculate_elapsed_time(execution)
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.exceptions_count = execution.exceptions_count
workflow_run.created_by_role = creator_user_role
workflow_run.created_by = creator_user_id
workflow_run.created_at = execution.started_at
@ -129,7 +130,14 @@ def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: Wo
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.elapsed_time = _calculate_elapsed_time(execution)
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.exceptions_count = execution.exceptions_count
workflow_run.finished_at = execution.finished_at
def _calculate_elapsed_time(execution: WorkflowExecution) -> float:
if execution.finished_at is None:
return execution.elapsed_time
return max((execution.finished_at - execution.started_at).total_seconds(), 0.0)

View File

@ -7,7 +7,7 @@ improving performance by offloading storage operations to background workers.
import json
import logging
from typing import Any
from typing import TYPE_CHECKING, Any
from celery import shared_task
from sqlalchemy import select
@ -17,9 +17,14 @@ from graphon.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models import Account, CreatorUserRole, EndUser, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
if TYPE_CHECKING:
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
SQLAlchemyWorkflowNodeExecutionRepository,
)
logger = logging.getLogger(__name__)
@ -34,7 +39,7 @@ def save_workflow_node_execution_task(
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow node execution to the database.
Asynchronously save or update workflow node execution metadata to the database.
Args:
execution_data: Serialized WorkflowNodeExecution data
@ -49,20 +54,16 @@ def save_workflow_node_execution_task(
"""
try:
with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)
# Check if node execution already exists
existing_execution = session.scalar(
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
)
if existing_execution:
# Update existing node execution
_update_node_execution_from_domain(existing_execution, execution)
logger.debug("Updated existing workflow node execution: %s", execution.id)
_update_node_execution_metadata(existing_execution, execution)
logger.debug("Updated existing workflow node execution metadata: %s", execution.id)
else:
# Create new node execution
node_execution = _create_node_execution_from_domain(
execution=execution,
tenant_id=tenant_id,
@ -79,10 +80,76 @@ def save_workflow_node_execution_task(
except Exception as e:
logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown"))
# Retry the task with exponential backoff
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_node_execution_data_task(
    self,
    execution_data: dict[str, Any],
    tenant_id: str,
    app_id: str,
    triggered_from: str,
    creator_user_id: str,
    creator_user_role: str,
) -> bool:
    """
    Asynchronously save full workflow node execution data to the database.

    The payload is validated into a ``WorkflowNodeExecution`` and written through
    a SQLAlchemy repository built for the given creator: ``save`` first, then
    ``save_execution_data``. Going through the repository is meant to keep its
    truncation and offload handling while the blocking work runs outside the
    workflow engine thread.

    Args:
        execution_data: Serialized WorkflowNodeExecution data
        tenant_id: Tenant ID used to build the repository
        app_id: App ID; an empty string is treated as "no app"
        triggered_from: Value of WorkflowNodeExecutionTriggeredFrom
        creator_user_id: ID of the Account or EndUser that created the execution
        creator_user_role: Value of CreatorUserRole

    Returns:
        True once both repository writes have completed.

    Raises:
        The exception produced by ``self.retry`` when any step fails; the retry
        countdown is 60 seconds doubled for every retry already made.
    """
    try:
        node_execution = WorkflowNodeExecution.model_validate(execution_data)
        node_repository = _create_sqlalchemy_repository(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=triggered_from,
            creator_user_id=creator_user_id,
            creator_user_role=creator_user_role,
        )
        node_repository.save(node_execution)
        node_repository.save_execution_data(node_execution)
    except Exception as exc:
        logger.exception("Failed to save workflow node execution data %s", execution_data.get("id", "unknown"))
        # Exponential backoff: 60s, 120s, 240s, ...
        backoff_seconds = 60 * (2**self.request.retries)
        raise self.retry(exc=exc, countdown=backoff_seconds)
    return True
def _create_sqlalchemy_repository(
    *,
    tenant_id: str,
    app_id: str,
    triggered_from: str,
    creator_user_id: str,
    creator_user_role: str,
) -> "SQLAlchemyWorkflowNodeExecutionRepository":
    """
    Build a SQLAlchemy node-execution repository for the recorded creator.

    The creator is loaded as an ``Account`` or an ``EndUser`` depending on
    ``creator_user_role``; a loaded account is given ``tenant_id`` through
    ``set_tenant_id`` before it is passed to the repository.

    Args:
        tenant_id: Tenant ID applied to a loaded Account
        app_id: App ID; an empty string is passed on as None
        triggered_from: Value of WorkflowNodeExecutionTriggeredFrom
        creator_user_id: Primary key of the creator row to load
        creator_user_role: Value of CreatorUserRole

    Raises:
        ValueError: If no creator row exists for ``creator_user_id``. Invalid
            enum values raise from the respective enum constructors.
    """
    # Imported here rather than at module level; the module only imports this
    # name under TYPE_CHECKING.
    from core.repositories.sqlalchemy_workflow_node_execution_repository import (
        SQLAlchemyWorkflowNodeExecutionRepository,
    )

    session_maker = session_factory.get_session_maker()
    role = CreatorUserRole(creator_user_role)

    creator: Account | EndUser | None
    with session_maker() as session:
        if role != CreatorUserRole.ACCOUNT:
            creator = session.get(EndUser, creator_user_id)
        else:
            creator = session.get(Account, creator_user_id)
            if creator is not None:
                creator.set_tenant_id(tenant_id)

    if creator is None:
        raise ValueError(f"Creator user {creator_user_id} not found for workflow node execution persistence")

    trigger_source = WorkflowNodeExecutionTriggeredFrom(triggered_from)
    return SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=session_maker,
        user=creator,
        app_id=app_id or None,
        triggered_from=trigger_source,
    )
def _create_node_execution_from_domain(
execution: WorkflowNodeExecution,
tenant_id: str,
@ -108,16 +175,11 @@ def _create_node_execution_from_domain(
node_execution.title = execution.title
node_execution.node_execution_id = execution.node_execution_id
# Serialize complex data as JSON
node_execution.inputs = "{}"
node_execution.process_data = "{}"
node_execution.outputs = "{}"
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
@ -137,20 +199,11 @@ def _create_node_execution_from_domain(
return node_execution
def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution):
def _update_node_execution_metadata(node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution):
"""
Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
Update WorkflowNodeExecutionModel metadata without changing persisted data payload fields.
"""
# Update serialized data
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
@ -159,7 +212,6 @@ def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionMode
else:
node_execution.execution_metadata = "{}"
# Update other fields
node_execution.status = execution.status
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time

View File

@ -5,8 +5,12 @@ from types import SimpleNamespace
import pytest
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.workflow.layers.persistence import (
PersistenceWorkflowInfo,
WorkflowPersistenceLayer,
should_use_async_workflow_persistence,
)
from core.ops.ops_trace_manager import TraceTask, TraceTaskName
from core.workflow.system_variables import SystemVariableKey, build_system_variables
from graphon.entities import WorkflowNodeExecution
@ -39,6 +43,7 @@ class _RepoRecorder:
def __init__(self) -> None:
self.saved: list[object] = []
self.saved_exec_data: list[object] = []
self.async_enabled: bool | None = None
def save(self, entity):
self.saved.append(entity)
@ -46,6 +51,9 @@ class _RepoRecorder:
def save_execution_data(self, entity):
self.saved_exec_data.append(entity)
def set_async_persistence(self, enabled: bool) -> None:
self.async_enabled = enabled
def _naive_utc_now() -> datetime:
return datetime.now(UTC).replace(tzinfo=None)
@ -55,6 +63,7 @@ def _make_layer(
system_variables: list | None = None,
*,
extras: dict | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
trace_manager: object | None = None,
):
system_variables = system_variables or build_system_variables(
@ -74,7 +83,7 @@ def _make_layer(
files=[],
user_id="user",
stream=False,
invoke_from=None,
invoke_from=invoke_from,
trace_manager=None,
workflow_execution_id="run-id",
extras=extras or {},
@ -96,6 +105,7 @@ def _make_layer(
workflow_info=workflow_info,
workflow_execution_repository=workflow_execution_repo,
workflow_node_execution_repository=workflow_node_execution_repo,
invoke_from=invoke_from,
trace_manager=trace_manager,
)
layer.initialize(read_only_state, command_channel=None)
@ -104,6 +114,30 @@ def _make_layer(
class TestWorkflowPersistenceLayer:
@pytest.mark.parametrize(
    ("invoke_from", "expected"),
    [
        (InvokeFrom.DEBUGGER, False),
        (InvokeFrom.WEB_APP, True),
        (InvokeFrom.SERVICE_API, True),
        (InvokeFrom.TRIGGER, True),
    ],
)
def test_should_use_async_workflow_persistence(self, invoke_from: InvokeFrom, expected: bool):
    """Check the async-persistence decision for each listed invocation source.

    Per the cases above, only ``InvokeFrom.DEBUGGER`` is expected to stay synchronous.
    """
    use_async = should_use_async_workflow_persistence(invoke_from)

    assert use_async is expected
def test_configures_repositories_for_debug_synchronous_persistence(self):
    """A debugger-invoked layer leaves both repository recorders with ``async_enabled`` False."""
    _layer, execution_repo, node_execution_repo, _unused = _make_layer(invoke_from=InvokeFrom.DEBUGGER)

    assert execution_repo.async_enabled is False
    assert node_execution_repo.async_enabled is False
def test_configures_repositories_for_non_debug_async_persistence(self):
    """A web-app-invoked layer leaves both repository recorders with ``async_enabled`` True."""
    _layer, execution_repo, node_execution_repo, _unused = _make_layer(invoke_from=InvokeFrom.WEB_APP)

    assert execution_repo.async_enabled is True
    assert node_execution_repo.async_enabled is True
def test_on_graph_start_resets_state(self):
layer, _, _, _ = _make_layer()
layer._workflow_execution = object()

View File

@ -1,248 +0,0 @@
"""
Unit tests for CeleryWorkflowExecutionRepository.
These tests verify the Celery-based asynchronous storage functionality
for workflow execution data.
"""
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from graphon.entities import WorkflowExecution
from graphon.enums import WorkflowType
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
@pytest.fixture
def mock_session_factory():
    """Provide a real ``sessionmaker`` bound to an in-memory SQLite engine.

    Despite the fixture name, nothing here is mocked: the factory is a genuine
    SQLAlchemy ``sessionmaker``.
    """
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    in_memory_engine = create_engine("sqlite:///:memory:")
    session_factory = sessionmaker(bind=in_memory_engine)
    return session_factory
@pytest.fixture
def mock_account():
    """Build an ``Account``-spec mock with random ``id`` and ``current_tenant_id`` strings."""
    account_double = Mock(spec=Account)
    account_double.configure_mock(id=str(uuid4()), current_tenant_id=str(uuid4()))
    return account_double
@pytest.fixture
def mock_end_user():
    """Build an ``EndUser``-spec mock with random ``id`` and ``tenant_id`` strings."""
    end_user_double = Mock(spec=EndUser)
    end_user_double.configure_mock(id=str(uuid4()), tenant_id=str(uuid4()))
    return end_user_double
@pytest.fixture
def sample_workflow_execution():
    """Create a ``WorkflowExecution`` with random ids, an empty graph and a single input."""
    empty_graph = {"nodes": [], "edges": []}
    single_input = {"input1": "value1"}
    execution = WorkflowExecution.new(
        id_=str(uuid4()),
        workflow_id=str(uuid4()),
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph=empty_graph,
        inputs=single_input,
        started_at=naive_utc_now(),
    )
    return execution
class TestCeleryWorkflowExecutionRepository:
    """Test cases for CeleryWorkflowExecutionRepository.

    Repository and execution construction is shared through two private helpers
    so each test only spells out what it varies.
    """

    @staticmethod
    def _make_repo(session_factory, user, *, app_id="test-app", triggered_from=WorkflowRunTriggeredFrom.APP_RUN):
        """Build the repository under test, defaulting to an app-run trigger."""
        return CeleryWorkflowExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=app_id,
            triggered_from=triggered_from,
        )

    @staticmethod
    def _make_execution(inputs):
        """Create a WorkflowExecution with random ids, an empty graph and the given inputs."""
        return WorkflowExecution.new(
            id_=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_type=WorkflowType.WORKFLOW,
            workflow_version="1.0",
            graph={"nodes": [], "edges": []},
            inputs=inputs,
            started_at=naive_utc_now(),
        )

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Test repository initialization with sessionmaker."""
        app_id = "test-app-id"
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN

        repo = self._make_repo(mock_session_factory, mock_account, app_id=app_id, triggered_from=triggered_from)

        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._creator_user_role is not None

    def test_init_basic_functionality(self, mock_session_factory, mock_account):
        """Test repository initialization with a debugging trigger."""
        repo = self._make_repo(
            mock_session_factory,
            mock_account,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )

        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == "test-app"
        assert repo._triggered_from == WorkflowRunTriggeredFrom.DEBUGGING

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """Test repository initialization with EndUser."""
        repo = self._make_repo(mock_session_factory, mock_end_user)

        assert repo._tenant_id == mock_end_user.tenant_id

    def test_init_without_tenant_id_raises_error(self, mock_session_factory):
        """Test that initialization fails without tenant_id."""
        # An Account mock whose tenant is explicitly unset.
        tenantless_user = Mock(spec=Account)
        tenantless_user.current_tenant_id = None
        tenantless_user.id = str(uuid4())

        with pytest.raises(ValueError, match="User must have a tenant_id"):
            self._make_repo(mock_session_factory, tenantless_user)

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution):
        """Test that save operation queues a Celery task without tracking."""
        repo = self._make_repo(mock_session_factory, mock_account)

        repo.save(sample_workflow_execution)

        # The task must be enqueued exactly once, with the repository context as kwargs.
        mock_task.delay.assert_called_once()
        queued_kwargs = mock_task.delay.call_args.kwargs
        assert queued_kwargs["execution_data"] == sample_workflow_execution.model_dump()
        assert queued_kwargs["tenant_id"] == mock_account.current_tenant_id
        assert queued_kwargs["app_id"] == "test-app"
        assert queued_kwargs["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN
        assert queued_kwargs["creator_user_id"] == mock_account.id

        # No task tracking occurs (no _pending_saves attribute).
        assert not hasattr(repo, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_handles_celery_failure(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """Test that a failure while enqueueing propagates out of save."""
        mock_task.delay.side_effect = Exception("Celery is down")
        repo = self._make_repo(mock_session_factory, mock_account)

        with pytest.raises(Exception, match="Celery is down"):
            repo.save(sample_workflow_execution)

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_operation_fire_and_forget(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """Test that save operation works in fire-and-forget mode."""
        repo = self._make_repo(mock_session_factory, mock_account)

        repo.save(sample_workflow_execution)

        # Without this check the test would pass even if save() enqueued nothing.
        mock_task.delay.assert_called_once()
        # No pending saves are tracked (no _pending_saves attribute).
        assert not hasattr(repo, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_multiple_save_operations(self, mock_task, mock_session_factory, mock_account):
        """Test multiple save operations work correctly."""
        repo = self._make_repo(mock_session_factory, mock_account)
        first_execution = self._make_execution({"input1": "value1"})
        second_execution = self._make_execution({"input2": "value2"})

        repo.save(first_execution)
        repo.save(second_execution)

        # Each save must enqueue its own task; previously nothing verified this.
        assert mock_task.delay.call_count == 2
        # The repository keeps no per-save state (no _pending_saves attribute).
        assert not hasattr(repo, "_pending_saves")

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_with_different_user_types(self, mock_task, mock_session_factory, mock_end_user):
        """Test save operation with an EndUser as the creator."""
        repo = self._make_repo(mock_session_factory, mock_end_user)
        execution = self._make_execution({"input1": "value1"})

        repo.save(execution)

        # The queued task carries the EndUser's tenant and id.
        mock_task.delay.assert_called_once()
        queued_kwargs = mock_task.delay.call_args.kwargs
        assert queued_kwargs["tenant_id"] == mock_end_user.tenant_id
        assert queued_kwargs["creator_user_id"] == mock_end_user.id

View File

@ -1,349 +0,0 @@
"""
Unit tests for CeleryWorkflowNodeExecutionRepository.
These tests verify the Celery-based asynchronous storage functionality
for workflow node execution data.
"""
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.repositories.factory import OrderConfig
from graphon.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
)
from graphon.enums import BuiltinNodeTypes
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@pytest.fixture
def mock_session_factory():
    """Provide a real ``sessionmaker`` bound to an in-memory SQLite engine.

    Despite the fixture name, the returned factory is a genuine SQLAlchemy
    ``sessionmaker`` rather than a mock.
    """
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    sqlite_engine = create_engine("sqlite:///:memory:")
    factory = sessionmaker(bind=sqlite_engine)
    return factory
@pytest.fixture
def mock_account():
    """Build an ``Account``-spec mock with random ``id`` and ``current_tenant_id`` strings."""
    fake_account = Mock(spec=Account)
    fake_account.configure_mock(id=str(uuid4()), current_tenant_id=str(uuid4()))
    return fake_account
@pytest.fixture
def mock_end_user():
    """Build an ``EndUser``-spec mock with random ``id`` and ``tenant_id`` strings."""
    fake_end_user = Mock(spec=EndUser)
    fake_end_user.configure_mock(id=str(uuid4()), tenant_id=str(uuid4()))
    return fake_end_user
@pytest.fixture
def sample_workflow_node_execution():
    """Create a running START-node ``WorkflowNodeExecution`` with random ids and one input."""
    node_execution = WorkflowNodeExecution(
        id=str(uuid4()),
        node_execution_id=str(uuid4()),
        workflow_id=str(uuid4()),
        workflow_execution_id=str(uuid4()),
        index=1,
        node_id="test_node",
        node_type=BuiltinNodeTypes.START,
        title="Test Node",
        inputs={"input1": "value1"},
        status=WorkflowNodeExecutionStatus.RUNNING,
        created_at=naive_utc_now(),
    )
    return node_execution
class TestCeleryWorkflowNodeExecutionRepository:
    """Test cases for CeleryWorkflowNodeExecutionRepository.

    Two private builders hold the construction boilerplate that every test repeats.
    """

    @staticmethod
    def _make_repo(
        session_factory,
        user,
        *,
        app_id="test-app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    ):
        """Build the repository under test, defaulting to a workflow-run trigger."""
        return CeleryWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=user,
            app_id=app_id,
            triggered_from=triggered_from,
        )

    @staticmethod
    def _make_node_execution(workflow_execution_id, *, index, node_id, node_type, title, inputs):
        """Create a running node execution with random ids under the given workflow execution."""
        return WorkflowNodeExecution(
            id=str(uuid4()),
            node_execution_id=str(uuid4()),
            workflow_id=str(uuid4()),
            workflow_execution_id=workflow_execution_id,
            index=index,
            node_id=node_id,
            node_type=node_type,
            title=title,
            inputs=inputs,
            status=WorkflowNodeExecutionStatus.RUNNING,
            created_at=naive_utc_now(),
        )

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Test repository initialization with sessionmaker."""
        app_id = "test-app-id"
        triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN

        repo = self._make_repo(mock_session_factory, mock_account, app_id=app_id, triggered_from=triggered_from)

        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._creator_user_role is not None

    def test_init_with_cache_initialized(self, mock_session_factory, mock_account):
        """Test that a new repository starts with an empty cache and mapping."""
        repo = self._make_repo(
            mock_session_factory,
            mock_account,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )

        assert repo._execution_cache == {}
        assert repo._workflow_execution_mapping == {}

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """Test repository initialization with EndUser."""
        repo = self._make_repo(mock_session_factory, mock_end_user)

        assert repo._tenant_id == mock_end_user.tenant_id

    def test_init_without_tenant_id_raises_error(self, mock_session_factory):
        """Test that initialization fails without tenant_id."""
        # An Account mock whose tenant is explicitly unset.
        tenantless_user = Mock(spec=Account)
        tenantless_user.current_tenant_id = None
        tenantless_user.id = str(uuid4())

        with pytest.raises(ValueError, match="User must have a tenant_id"):
            self._make_repo(mock_session_factory, tenantless_user)

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_save_caches_and_queues_celery_task(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that save operation caches execution and queues a Celery task."""
        repo = self._make_repo(mock_session_factory, mock_account)
        saved = sample_workflow_node_execution

        repo.save(saved)

        # The task is enqueued once with the repository context as kwargs.
        mock_task.delay.assert_called_once()
        queued_kwargs = mock_task.delay.call_args.kwargs
        assert queued_kwargs["execution_data"] == saved.model_dump()
        assert queued_kwargs["tenant_id"] == mock_account.current_tenant_id
        assert queued_kwargs["app_id"] == "test-app"
        assert queued_kwargs["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
        assert queued_kwargs["creator_user_id"] == mock_account.id

        # The execution is cached under its id.
        assert saved.id in repo._execution_cache
        assert repo._execution_cache[saved.id] == saved

        # The workflow-execution mapping points at the cached execution.
        assert saved.workflow_execution_id in repo._workflow_execution_mapping
        assert saved.id in repo._workflow_execution_mapping[saved.workflow_execution_id]

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_save_handles_celery_failure(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that a failure while enqueueing propagates out of save."""
        mock_task.delay.side_effect = Exception("Celery is down")
        repo = self._make_repo(mock_session_factory, mock_account)

        with pytest.raises(Exception, match="Celery is down"):
            repo.save(sample_workflow_node_execution)

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_get_by_workflow_execution_from_cache(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that get_by_workflow_execution retrieves executions from cache."""
        repo = self._make_repo(mock_session_factory, mock_account)

        # Populate the cache through save().
        repo.save(sample_workflow_node_execution)

        ascending_by_index = OrderConfig(order_by=["index"], order_direction="asc")
        found = repo.get_by_workflow_execution(
            sample_workflow_node_execution.workflow_execution_id, ascending_by_index
        )

        # The very same object that was saved comes back.
        assert len(found) == 1
        assert found[0].id == sample_workflow_node_execution.id
        assert found[0] is sample_workflow_node_execution

    def test_get_by_workflow_execution_without_order_config(self, mock_session_factory, mock_account):
        """Test get_by_workflow_execution without order configuration."""
        repo = self._make_repo(mock_session_factory, mock_account)

        found = repo.get_by_workflow_execution("workflow-run-id")

        # Nothing was saved, so nothing is returned.
        assert len(found) == 0

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_cache_operations(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution):
        """Test cache operations work correctly."""
        repo = self._make_repo(mock_session_factory, mock_account)

        repo.save(sample_workflow_node_execution)

        # The cache holds the saved execution.
        assert sample_workflow_node_execution.id in repo._execution_cache

        # It can be read back by workflow execution id.
        found = repo.get_by_workflow_execution(sample_workflow_node_execution.workflow_execution_id)
        assert len(found) == 1
        assert found[0].id == sample_workflow_node_execution.id

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_multiple_executions_same_workflow(self, mock_task, mock_session_factory, mock_account):
        """Test multiple executions for the same workflow."""
        repo = self._make_repo(mock_session_factory, mock_account)

        workflow_execution_id = str(uuid4())
        first = self._make_node_execution(
            workflow_execution_id,
            index=1,
            node_id="node1",
            node_type=BuiltinNodeTypes.START,
            title="Node 1",
            inputs={"input1": "value1"},
        )
        second = self._make_node_execution(
            workflow_execution_id,
            index=2,
            node_id="node2",
            node_type=BuiltinNodeTypes.LLM,
            title="Node 2",
            inputs={"input2": "value2"},
        )

        repo.save(first)
        repo.save(second)

        # Both are cached and mapped to the shared workflow execution.
        assert len(repo._execution_cache) == 2
        assert len(repo._workflow_execution_mapping[workflow_execution_id]) == 2

        found = repo.get_by_workflow_execution(workflow_execution_id)
        assert len(found) == 2

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_ordering_functionality(self, mock_task, mock_session_factory, mock_account):
        """Test ordering functionality works correctly."""
        repo = self._make_repo(mock_session_factory, mock_account)

        workflow_execution_id = str(uuid4())
        higher_index = self._make_node_execution(
            workflow_execution_id,
            index=2,
            node_id="node2",
            node_type=BuiltinNodeTypes.START,
            title="Node 2",
            inputs={},
        )
        lower_index = self._make_node_execution(
            workflow_execution_id,
            index=1,
            node_id="node1",
            node_type=BuiltinNodeTypes.LLM,
            title="Node 1",
            inputs={},
        )

        # Saved highest index first, so ordering cannot come from insertion order.
        repo.save(higher_index)
        repo.save(lower_index)

        ascending = OrderConfig(order_by=["index"], order_direction="asc")
        found = repo.get_by_workflow_execution(workflow_execution_id, ascending)
        assert len(found) == 2
        assert found[0].index == 1
        assert found[1].index == 2

        descending = OrderConfig(order_by=["index"], order_direction="desc")
        found = repo.get_by_workflow_execution(workflow_execution_id, descending)
        assert len(found) == 2
        assert found[0].index == 2
        assert found[1].index == 1

View File

@ -1,5 +1,5 @@
from datetime import UTC, datetime
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
@ -240,6 +240,78 @@ class TestSQLAlchemyWorkflowExecutionRepository:
cached_model = repo._execution_cache[sample_workflow_execution.id_]
assert cached_model.id == sample_workflow_execution.id_
def test_save_rejects_existing_run_from_other_tenant(
    self, mock_session_factory, mock_account, sample_workflow_execution
):
    """Saving over a run whose tenant is "other-tenant" raises and neither merges nor commits."""
    repository = SQLAlchemyWorkflowExecutionRepository(
        session_factory=mock_session_factory,
        user=mock_account,
        app_id="test_app",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )

    # Make the session lookup return a run that shares the id but has a different tenant.
    foreign_run = WorkflowRun()
    foreign_run.id = sample_workflow_execution.id_
    foreign_run.tenant_id = "other-tenant"
    db_session = mock_session_factory.return_value.__enter__.return_value
    db_session.get.return_value = foreign_run

    with pytest.raises(ValueError, match="Unauthorized access to workflow run"):
        repository.save(sample_workflow_execution)

    db_session.merge.assert_not_called()
    db_session.commit.assert_not_called()
@patch("core.repositories.sqlalchemy_workflow_execution_repository.save_workflow_execution_task")
def test_save_queues_celery_task_when_async_persistence_enabled(
    self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
):
    """With async persistence switched on, ``save`` enqueues the task and never calls ``session.merge``."""
    repository = SQLAlchemyWorkflowExecutionRepository(
        session_factory=mock_session_factory,
        user=mock_account,
        app_id="test_app",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )
    repository.set_async_persistence(True)

    repository.save(sample_workflow_execution)

    mock_task.delay.assert_called_once()
    queued_kwargs = mock_task.delay.call_args.kwargs
    assert queued_kwargs["execution_data"] == sample_workflow_execution.model_dump()
    assert queued_kwargs["tenant_id"] == mock_account.current_tenant_id
    assert queued_kwargs["app_id"] == "test_app"
    assert queued_kwargs["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN
    assert queued_kwargs["creator_user_id"] == mock_account.id

    db_session = mock_session_factory.return_value.__enter__.return_value
    db_session.merge.assert_not_called()
@patch("core.repositories.sqlalchemy_workflow_execution_repository.save_workflow_execution_task")
def test_queue_async_save_requires_context(
    self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
):
    """``_queue_async_save`` raises ``ValueError`` for each missing context field and enqueues nothing."""
    repository = SQLAlchemyWorkflowExecutionRepository(
        session_factory=mock_session_factory,
        user=mock_account,
        app_id="test_app",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )

    # Each step applies its attribute overrides (restoring the previously blanked
    # field first), then expects the matching error message.
    steps = [
        ({"_triggered_from": None}, "triggered_from is required"),
        (
            {"_triggered_from": WorkflowRunTriggeredFrom.APP_RUN, "_creator_user_id": None},
            "created_by is required",
        ),
        ({"_creator_user_id": "user-id", "_creator_user_role": None}, "created_by_role is required"),
    ]
    for overrides, expected_message in steps:
        for attribute, value in overrides.items():
            setattr(repository, attribute, value)
        with pytest.raises(ValueError, match=expected_message):
            repository._queue_async_save(sample_workflow_execution)

    mock_task.delay.assert_not_called()
def test_save_uses_execution_started_at_when_record_does_not_exist(
self, mock_session_factory, mock_account, sample_workflow_execution
):

View File

@ -6,7 +6,7 @@ from collections.abc import Mapping
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, Mock
from unittest.mock import MagicMock, Mock, patch
import psycopg2.errors
import pytest
@ -599,6 +599,124 @@ def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.Monkey
assert merged.inputs == '{"a": 1}'
@patch("core.repositories.sqlalchemy_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_save_queues_celery_task_when_async_persistence_enabled(mock_task, monkeypatch: pytest.MonkeyPatch) -> None:
    """With async persistence on, ``save`` enqueues one task carrying exactly the six expected kwargs."""
    # Replace FileService with a stub exposing only ``upload_file``.
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repository = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    repository.set_async_persistence(True)
    node_execution = _execution(inputs={"a": 1})

    repository.save(node_execution)

    mock_task.delay.assert_called_once()
    queued_kwargs = mock_task.delay.call_args.kwargs
    assert queued_kwargs["execution_data"] == node_execution.model_dump()
    assert queued_kwargs["tenant_id"] == "tenant"
    assert queued_kwargs["app_id"] == "app"
    assert queued_kwargs["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
    assert queued_kwargs["creator_user_id"] == "user"

    # No kwargs beyond these may be sent to the task.
    allowed_keys = {
        "execution_data",
        "tenant_id",
        "app_id",
        "triggered_from",
        "creator_user_id",
        "creator_user_role",
    }
    assert set(queued_kwargs) == allowed_keys
@patch("core.repositories.sqlalchemy_workflow_node_execution_repository.save_workflow_node_execution_data_task")
def test_save_execution_data_queues_celery_task_when_async_persistence_enabled(
    mock_task, monkeypatch: pytest.MonkeyPatch
) -> None:
    """With async persistence on, ``save_execution_data`` enqueues the data task with repository context."""
    # Replace FileService with a stub exposing only ``upload_file``.
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repository = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    repository.set_async_persistence(True)
    node_execution = _execution(inputs={"a": 1})

    repository.save_execution_data(node_execution)

    mock_task.delay.assert_called_once()
    queued_kwargs = mock_task.delay.call_args.kwargs
    assert queued_kwargs["execution_data"] == node_execution.model_dump()
    assert queued_kwargs["tenant_id"] == "tenant"
    assert queued_kwargs["app_id"] == "app"
    assert queued_kwargs["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
    assert queued_kwargs["creator_user_id"] == "user"
def test_queue_async_save_requires_context(monkeypatch: pytest.MonkeyPatch) -> None:
    """``_queue_async_save`` raises ``ValueError`` naming each context field that is missing."""
    # Replace FileService with a stub exposing only ``upload_file``.
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repository = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    node_execution = _execution()

    # Each step applies its attribute overrides (restoring the previously blanked
    # field first), then expects the matching error message.
    steps = [
        ({"_triggered_from": None}, "triggered_from is required"),
        (
            {"_triggered_from": WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, "_creator_user_id": None},
            "created_by is required",
        ),
        ({"_creator_user_id": "user", "_creator_user_role": None}, "created_by_role is required"),
    ]
    for overrides, expected_message in steps:
        for attribute, value in overrides.items():
            setattr(repository, attribute, value)
        with pytest.raises(ValueError, match=expected_message):
            repository._queue_async_save(node_execution)
def test_queue_async_save_execution_data_requires_context(monkeypatch: pytest.MonkeyPatch) -> None:
    """``_queue_async_save_execution_data`` raises ``ValueError`` naming each missing context field."""
    # Replace FileService with a stub exposing only ``upload_file``.
    monkeypatch.setattr(
        "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
        lambda *_: SimpleNamespace(upload_file=Mock()),
    )
    repository = SQLAlchemyWorkflowNodeExecutionRepository(
        session_factory=Mock(spec=sessionmaker),
        user=_mock_account(),
        app_id="app",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    node_execution = _execution()

    # Each step applies its attribute overrides (restoring the previously blanked
    # field first), then expects the matching error message.
    steps = [
        ({"_triggered_from": None}, "triggered_from is required"),
        (
            {"_triggered_from": WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, "_creator_user_id": None},
            "created_by is required",
        ),
        ({"_creator_user_id": "user", "_creator_user_role": None}, "created_by_role is required"),
    ]
    for overrides, expected_message in steps:
        for attribute, value in overrides.items():
            setattr(repository, attribute, value)
        with pytest.raises(ValueError, match=expected_message):
            repository._queue_async_save_execution_data(node_execution)
def test_save_retries_duplicate_and_logs_non_duplicate(
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:

View File

@ -0,0 +1,78 @@
from datetime import UTC, datetime, timedelta
from graphon.entities import WorkflowExecution
from graphon.enums import WorkflowExecutionStatus, WorkflowType
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
from tasks.workflow_execution_tasks import (
_calculate_elapsed_time,
_create_workflow_run_from_execution,
_update_workflow_run_from_execution,
)
def _execution(
    *,
    elapsed_time: float = 3.5,
    exceptions_count: int = 2,
    finished_at: datetime | None = None,
) -> WorkflowExecution:
    """Build a succeeded ``WorkflowExecution`` that started at 2026-01-01 12:00:00.

    Only ``elapsed_time``, ``exceptions_count`` and ``finished_at`` vary between
    tests; every other field is a fixed placeholder.
    """
    fixed_start = datetime(2026, 1, 1, 12, 0, 0)
    execution = WorkflowExecution(
        id_="workflow-run-id",
        workflow_id="workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph={"nodes": [], "edges": []},
        inputs={"input": "value"},
        outputs={"output": "value"},
        status=WorkflowExecutionStatus.SUCCEEDED,
        error_message="",
        elapsed_time=elapsed_time,
        total_tokens=100,
        total_steps=5,
        exceptions_count=exceptions_count,
        started_at=fixed_start,
        finished_at=finished_at,
    )
    return execution
def test_create_workflow_run_calculates_elapsed_time_and_exceptions_count() -> None:
    """A newly created run reports the start-to-finish gap as elapsed time and copies the exception count."""
    # The fixture starts at 12:00:00, so finishing at 12:00:12 is a 12 second run,
    # which differs from the fixture's default ``elapsed_time`` of 3.5.
    finished = datetime(2026, 1, 1, 12, 0, 12)
    source = _execution(exceptions_count=3, finished_at=finished)

    created_run = _create_workflow_run_from_execution(
        execution=source,
        tenant_id="tenant-id",
        app_id="app-id",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        creator_user_id="user-id",
        creator_user_role=CreatorUserRole.ACCOUNT,
    )

    assert created_run.elapsed_time == 12.0
    assert created_run.exceptions_count == 3
def test_update_workflow_run_calculates_elapsed_time_and_exceptions_count() -> None:
    """Updating an existing run writes the start-to-finish gap and the exception count onto it."""
    # Finishing 8 seconds after the fixture's 12:00:00 start.
    source = _execution(exceptions_count=4, finished_at=datetime(2026, 1, 1, 12, 0, 8))
    target_run = WorkflowRun()

    _update_workflow_run_from_execution(target_run, source)

    assert target_run.elapsed_time == 8.0
    assert target_run.exceptions_count == 4
def test_calculate_elapsed_time_uses_runtime_elapsed_time_until_finished() -> None:
    """Without a finish time, the elapsed time reflects how long ago the execution started."""
    running = _execution(finished_at=None)
    current_naive_utc = datetime.now(UTC).replace(tzinfo=None)
    running.started_at = current_naive_utc - timedelta(seconds=4)

    measured = _calculate_elapsed_time(running)

    # ``started_at`` is set relative to the current time, so the expectation is a range
    # around four seconds rather than an exact value.
    assert 3.9 <= measured <= 5.0
def test_calculate_elapsed_time_clamps_negative_duration_to_zero() -> None:
    """A finish time earlier than the start time yields an elapsed time of zero."""
    # One second before the fixture's 12:00:00 start.
    finished_before_start = datetime(2026, 1, 1, 11, 59, 59)
    backwards = _execution(finished_at=finished_before_start)

    elapsed = _calculate_elapsed_time(backwards)

    assert elapsed == 0.0

View File

@ -1,488 +1,327 @@
# """
# Unit tests for workflow node execution Celery tasks.
# These tests verify the asynchronous storage functionality for workflow node execution data,
# including truncation and offloading logic.
# """
# import json
# from unittest.mock import MagicMock, Mock, patch
# from uuid import uuid4
# import pytest
# from graphon.entities.workflow_node_execution import (
# WorkflowNodeExecution,
# WorkflowNodeExecutionStatus,
# )
# from graphon.enums import BuiltinNodeTypes
# from libs.datetime_utils import naive_utc_now
# from models import WorkflowNodeExecutionModel
# from models.enums import ExecutionOffLoadType
# from models.model import UploadFile
# from models.workflow import WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
# from tasks.workflow_node_execution_tasks import (
# _create_truncator,
# _json_encode,
# _replace_or_append_offload,
# _truncate_and_upload_async,
# save_workflow_node_execution_data_task,
# save_workflow_node_execution_task,
# )
# @pytest.fixture
# def sample_execution_data():
# """Sample execution data for testing."""
# execution = WorkflowNodeExecution(
# id=str(uuid4()),
# node_execution_id=str(uuid4()),
# workflow_id=str(uuid4()),
# workflow_execution_id=str(uuid4()),
# index=1,
# node_id="test_node",
# node_type=BuiltinNodeTypes.LLM,
# title="Test Node",
# inputs={"input_key": "input_value"},
# outputs={"output_key": "output_value"},
# process_data={"process_key": "process_value"},
# status=WorkflowNodeExecutionStatus.RUNNING,
# created_at=naive_utc_now(),
# )
# return execution.model_dump()
# @pytest.fixture
# def mock_db_model():
# """Mock database model for testing."""
# db_model = Mock(spec=WorkflowNodeExecutionModel)
# db_model.id = "test-execution-id"
# db_model.offload_data = []
# return db_model
# @pytest.fixture
# def mock_file_service():
# """Mock file service for testing."""
# file_service = Mock()
# mock_upload_file = Mock(spec=UploadFile)
# mock_upload_file.id = "mock-file-id"
# file_service.upload_file.return_value = mock_upload_file
# return file_service
# class TestSaveWorkflowNodeExecutionDataTask:
# """Test cases for save_workflow_node_execution_data_task."""
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_execution_data_task_success(
# self, mock_select, mock_sessionmaker, sample_execution_data, mock_db_model
# ):
# """Test successful execution of save_workflow_node_execution_data_task."""
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model
# # Execute task
# result = save_workflow_node_execution_data_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# user_data={"user_id": "test-user-id", "user_type": "account"},
# )
# # Verify success
# assert result is True
# mock_session.merge.assert_called_once_with(mock_db_model)
# mock_session.commit.assert_called_once()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_execution_data_task_execution_not_found(self, mock_select, mock_sessionmaker,
# sample_execution_data):
# """Test task when execution is not found in database."""
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# mock_session.execute.return_value.scalars.return_value.first.return_value = None
# # Execute task
# result = save_workflow_node_execution_data_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# user_data={"user_id": "test-user-id", "user_type": "account"},
# )
# # Verify failure
# assert result is False
# mock_session.merge.assert_not_called()
# mock_session.commit.assert_not_called()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_execution_data_task_with_truncation(self, mock_select, mock_sessionmaker, mock_db_model):
# """Test task with data that requires truncation."""
# # Create execution with large data
# large_data = {"large_field": "x" * 10000}
# execution = WorkflowNodeExecution(
# id=str(uuid4()),
# node_execution_id=str(uuid4()),
# workflow_id=str(uuid4()),
# workflow_execution_id=str(uuid4()),
# index=1,
# node_id="test_node",
# node_type=BuiltinNodeTypes.LLM,
# title="Test Node",
# inputs=large_data,
# outputs=large_data,
# process_data=large_data,
# status=WorkflowNodeExecutionStatus.RUNNING,
# created_at=naive_utc_now(),
# )
# execution_data = execution.model_dump()
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model
# # Create mock upload file
# mock_upload_file = Mock(spec=UploadFile)
# mock_upload_file.id = "mock-file-id"
# # Execute task
# with patch("tasks.workflow_node_execution_tasks._truncate_and_upload_async") as mock_truncate:
# # Mock truncation results
# mock_truncate.return_value = {
# "truncated_value": {"large_field": "[TRUNCATED]"},
# "file": mock_upload_file,
# "offload": WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# node_execution_id=execution.id,
# type_=ExecutionOffLoadType.INPUTS,
# file_id=mock_upload_file.id,
# ),
# }
# result = save_workflow_node_execution_data_task(
# execution_data=execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# user_data={"user_id": "test-user-id", "user_type": "account"},
# )
# # Verify success and truncation was called
# assert result is True
# assert mock_truncate.call_count == 3 # inputs, outputs, process_data
# mock_session.merge.assert_called_once_with(mock_db_model)
# mock_session.commit.assert_called_once()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# def test_save_execution_data_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data):
# """Test task retry mechanism on exception."""
# # Setup mock to raise exception
# mock_sessionmaker.side_effect = Exception("Database error")
# # Create a mock task instance with proper retry behavior
# with patch.object(save_workflow_node_execution_data_task, "retry") as mock_retry:
# mock_retry.side_effect = Exception("Retry called")
# # Execute task and expect retry
# with pytest.raises(Exception, match="Retry called"):
# save_workflow_node_execution_data_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# user_data={"user_id": "test-user-id", "user_type": "account"},
# )
# # Verify retry was called
# mock_retry.assert_called_once()
# class TestTruncateAndUploadAsync:
# """Test cases for _truncate_and_upload_async function."""
# def test_truncate_and_upload_with_none_values(self, mock_file_service):
# """Test _truncate_and_upload_async with None values."""
# # The function handles None values internally, so we test with empty dict instead
# result = _truncate_and_upload_async(
# values={},
# execution_id="test-id",
# type_=ExecutionOffLoadType.INPUTS,
# tenant_id="test-tenant",
# app_id="test-app",
# user_data={"user_id": "test-user", "user_type": "account"},
# file_service=mock_file_service,
# )
# # Empty dict should not require truncation
# assert result is None
# mock_file_service.upload_file.assert_not_called()
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
# def test_truncate_and_upload_no_truncation_needed(self, mock_create_truncator, mock_file_service):
# """Test _truncate_and_upload_async when no truncation is needed."""
# # Mock truncator to return no truncation
# mock_truncator = Mock()
# mock_truncator.truncate_variable_mapping.return_value = ({"small": "data"}, False)
# mock_create_truncator.return_value = mock_truncator
# small_values = {"small": "data"}
# result = _truncate_and_upload_async(
# values=small_values,
# execution_id="test-id",
# type_=ExecutionOffLoadType.INPUTS,
# tenant_id="test-tenant",
# app_id="test-app",
# user_data={"user_id": "test-user", "user_type": "account"},
# file_service=mock_file_service,
# )
# assert result is None
# mock_file_service.upload_file.assert_not_called()
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
# @patch("models.Account")
# @patch("models.Tenant")
# def test_truncate_and_upload_with_account_user(
# self, mock_tenant_class, mock_account_class, mock_create_truncator, mock_file_service
# ):
# """Test _truncate_and_upload_async with account user."""
# # Mock truncator to return truncation needed
# mock_truncator = Mock()
# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True)
# mock_create_truncator.return_value = mock_truncator
# # Mock user and tenant creation
# mock_account = Mock()
# mock_account.id = "test-user"
# mock_account_class.return_value = mock_account
# mock_tenant = Mock()
# mock_tenant.id = "test-tenant"
# mock_tenant_class.return_value = mock_tenant
# large_values = {"large": "x" * 10000}
# result = _truncate_and_upload_async(
# values=large_values,
# execution_id="test-id",
# type_=ExecutionOffLoadType.INPUTS,
# tenant_id="test-tenant",
# app_id="test-app",
# user_data={"user_id": "test-user", "user_type": "account"},
# file_service=mock_file_service,
# )
# # Verify result structure
# assert result is not None
# assert "truncated_value" in result
# assert "file" in result
# assert "offload" in result
# assert result["truncated_value"] == {"truncated": "data"}
# # Verify file upload was called
# mock_file_service.upload_file.assert_called_once()
# upload_call = mock_file_service.upload_file.call_args
# assert upload_call[1]["filename"] == "node_execution_test-id_inputs.json"
# assert upload_call[1]["mimetype"] == "application/json"
# assert upload_call[1]["user"] == mock_account
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
# @patch("models.EndUser")
# def test_truncate_and_upload_with_end_user(self, mock_end_user_class, mock_create_truncator, mock_file_service):
# """Test _truncate_and_upload_async with end user."""
# # Mock truncator to return truncation needed
# mock_truncator = Mock()
# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True)
# mock_create_truncator.return_value = mock_truncator
# # Mock end user creation
# mock_end_user = Mock()
# mock_end_user.id = "test-user"
# mock_end_user.tenant_id = "test-tenant"
# mock_end_user_class.return_value = mock_end_user
# large_values = {"large": "x" * 10000}
# result = _truncate_and_upload_async(
# values=large_values,
# execution_id="test-id",
# type_=ExecutionOffLoadType.OUTPUTS,
# tenant_id="test-tenant",
# app_id="test-app",
# user_data={"user_id": "test-user", "user_type": "end_user"},
# file_service=mock_file_service,
# )
# # Verify result structure
# assert result is not None
# assert result["truncated_value"] == {"truncated": "data"}
# # Verify file upload was called with end user
# mock_file_service.upload_file.assert_called_once()
# upload_call = mock_file_service.upload_file.call_args
# assert upload_call[1]["filename"] == "node_execution_test-id_outputs.json"
# assert upload_call[1]["user"] == mock_end_user
# class TestHelperFunctions:
# """Test cases for helper functions."""
# @patch("tasks.workflow_node_execution_tasks.dify_config")
# def test_create_truncator(self, mock_config):
# """Test _create_truncator function."""
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500
# truncator = _create_truncator()
# # Verify truncator was created with correct config
# assert truncator is not None
# def test_json_encode(self):
# """Test _json_encode function."""
# test_data = {"key": "value", "number": 42}
# result = _json_encode(test_data)
# assert isinstance(result, str)
# decoded = json.loads(result)
# assert decoded == test_data
# def test_replace_or_append_offload_replace_existing(self):
# """Test _replace_or_append_offload replaces existing offload of same type."""
# existing_offload = WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant",
# app_id="test-app",
# node_execution_id="test-execution",
# type_=ExecutionOffLoadType.INPUTS,
# file_id="old-file-id",
# )
# new_offload = WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant",
# app_id="test-app",
# node_execution_id="test-execution",
# type_=ExecutionOffLoadType.INPUTS,
# file_id="new-file-id",
# )
# result = _replace_or_append_offload([existing_offload], new_offload)
# assert len(result) == 1
# assert result[0].file_id == "new-file-id"
# def test_replace_or_append_offload_append_new_type(self):
# """Test _replace_or_append_offload appends new offload of different type."""
# existing_offload = WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant",
# app_id="test-app",
# node_execution_id="test-execution",
# type_=ExecutionOffLoadType.INPUTS,
# file_id="inputs-file-id",
# )
# new_offload = WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant",
# app_id="test-app",
# node_execution_id="test-execution",
# type_=ExecutionOffLoadType.OUTPUTS,
# file_id="outputs-file-id",
# )
# result = _replace_or_append_offload([existing_offload], new_offload)
# assert len(result) == 2
# file_ids = [offload.file_id for offload in result]
# assert "inputs-file-id" in file_ids
# assert "outputs-file-id" in file_ids
# class TestSaveWorkflowNodeExecutionTask:
# """Test cases for save_workflow_node_execution_task."""
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_workflow_node_execution_task_create_new(self, mock_select, mock_sessionmaker,
# sample_execution_data):
# """Test creating a new workflow node execution."""
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# mock_session.scalar.return_value = None # No existing execution
# # Execute task
# result = save_workflow_node_execution_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
# creator_user_id="test-user-id",
# creator_user_role="account",
# )
# # Verify success
# assert result is True
# mock_session.add.assert_called_once()
# mock_session.commit.assert_called_once()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_workflow_node_execution_task_update_existing(
# self, mock_select, mock_sessionmaker, sample_execution_data
# ):
# """Test updating an existing workflow node execution."""
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# existing_execution = Mock(spec=WorkflowNodeExecutionModel)
# mock_session.scalar.return_value = existing_execution
# # Execute task
# result = save_workflow_node_execution_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
# creator_user_id="test-user-id",
# creator_user_role="account",
# )
# # Verify success
# assert result is True
# mock_session.add.assert_not_called() # Should not add new, just update existing
# mock_session.commit.assert_called_once()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# def test_save_workflow_node_execution_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data):
# """Test task retry mechanism on exception."""
# # Setup mock to raise exception
# mock_sessionmaker.side_effect = Exception("Database error")
# # Create a mock task instance with proper retry behavior
# with patch.object(save_workflow_node_execution_task, "retry") as mock_retry:
# mock_retry.side_effect = Exception("Retry called")
# # Execute task and expect retry
# with pytest.raises(Exception, match="Retry called"):
# save_workflow_node_execution_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
# creator_user_id="test-user-id",
# creator_user_role="account",
# )
# # Verify retry was called
# mock_retry.assert_called_once()
from collections.abc import Mapping
from datetime import UTC, datetime
from unittest.mock import Mock, patch
import pytest
from graphon.entities.workflow_node_execution import WorkflowNodeExecution
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from models import Account, EndUser
from models.enums import CreatorUserRole
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
from tasks.workflow_node_execution_tasks import (
_create_node_execution_from_domain,
_create_sqlalchemy_repository,
_update_node_execution_metadata,
save_workflow_node_execution_data_task,
save_workflow_node_execution_task,
)
def _execution(
    *,
    metadata: Mapping[WorkflowNodeExecutionMetadataKey, object] | None = None,
) -> WorkflowNodeExecution:
    """Build a succeeded LLM ``WorkflowNodeExecution`` with non-empty inputs, process data and outputs.

    Args:
        metadata: Execution metadata to attach. ``None`` selects a default of ten total tokens;
            an explicit empty mapping is passed through unchanged.

    Returns:
        An execution whose ``created_at`` and ``finished_at`` are naive UTC timestamps taken at call time.
    """
    effective_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 10} if metadata is None else metadata

    # Two separate clock reads, taken in this order.
    created = datetime.now(UTC).replace(tzinfo=None)
    finished = datetime.now(UTC).replace(tzinfo=None)

    return WorkflowNodeExecution(
        id="exec-id",
        node_execution_id="node-exec-id",
        workflow_id="workflow-id",
        workflow_execution_id="run-id",
        index=1,
        node_id="node-id",
        node_type=BuiltinNodeTypes.LLM,
        title="LLM",
        inputs={"input": "value"},
        process_data={"process": "value"},
        outputs={"output": "value"},
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        metadata=effective_metadata,
        created_at=created,
        finished_at=finished,
    )
def test_create_node_execution_persists_metadata_without_data_payloads() -> None:
    """A freshly built row keeps the execution metadata but stores empty JSON in all three payload columns."""
    row = _create_node_execution_from_domain(
        execution=_execution(),
        tenant_id="tenant-id",
        app_id="app-id",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        creator_user_id="user-id",
        creator_user_role=CreatorUserRole.ACCOUNT,
    )

    # The fixture carries non-empty inputs/process_data/outputs, yet none of them reach the row.
    for payload in (row.inputs, row.process_data, row.outputs):
        assert payload == "{}"
    assert row.execution_metadata == '{"total_tokens": 10}'
def test_create_node_execution_defaults_empty_metadata() -> None:
    """An execution with empty metadata produces a row whose metadata column is an empty JSON object."""
    without_metadata = _execution(metadata={})

    row = _create_node_execution_from_domain(
        execution=without_metadata,
        tenant_id="tenant-id",
        app_id="app-id",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        creator_user_id="user-id",
        creator_user_role=CreatorUserRole.ACCOUNT,
    )

    assert row.execution_metadata == "{}"
def test_update_node_execution_metadata_preserves_data_payloads() -> None:
    """A metadata update changes the status but leaves previously stored payload columns alone."""
    stored_payloads = {
        "inputs": '{"old_input": true}',
        "process_data": '{"old_process": true}',
        "outputs": '{"old_output": true}',
    }
    row = WorkflowNodeExecutionModel()
    for column, serialized in stored_payloads.items():
        setattr(row, column, serialized)

    _update_node_execution_metadata(row, _execution())

    for column, serialized in stored_payloads.items():
        assert getattr(row, column) == serialized
    assert row.status == WorkflowNodeExecutionStatus.SUCCEEDED
def test_update_node_execution_metadata_defaults_empty_metadata() -> None:
    """Updating from an execution with empty metadata stores an empty JSON object in the metadata column."""
    row = WorkflowNodeExecutionModel()
    without_metadata = _execution(metadata={})

    _update_node_execution_metadata(row, without_metadata)

    assert row.execution_metadata == "{}"
@patch("tasks.workflow_node_execution_tasks._create_sqlalchemy_repository")
def test_save_workflow_node_execution_data_task_uses_sqlalchemy_repository(mock_create_repository: Mock) -> None:
    """The data task builds a repository from its context arguments and hands it the rebuilt execution.

    The execution reaches both ``save`` and ``save_execution_data`` as the first positional argument.
    """
    fake_repository = Mock()
    mock_create_repository.return_value = fake_repository
    execution = _execution()
    expected_dump = execution.model_dump()

    # Everything except the payload is forwarded verbatim to the repository factory.
    repository_context = {
        "tenant_id": "tenant-id",
        "app_id": "app-id",
        "triggered_from": WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
        "creator_user_id": "user-id",
        "creator_user_role": CreatorUserRole.ACCOUNT.value,
    }

    outcome = save_workflow_node_execution_data_task.run(
        execution_data=execution.model_dump(),
        **repository_context,
    )

    assert outcome is True
    mock_create_repository.assert_called_once_with(**repository_context)

    passed_to_save = fake_repository.save.call_args.args[0]
    passed_to_save_data = fake_repository.save_execution_data.call_args.args[0]
    assert passed_to_save.model_dump() == expected_dump
    assert passed_to_save_data.model_dump() == expected_dump
@patch("tasks.workflow_node_execution_tasks._create_sqlalchemy_repository")
def test_save_workflow_node_execution_data_task_retries_on_failure(mock_create_repository: Mock) -> None:
    """A failure while building the repository makes the data task request a retry after 60 seconds."""
    mock_create_repository.side_effect = RuntimeError("db unavailable")
    execution = _execution()

    task_kwargs = {
        "execution_data": execution.model_dump(),
        "tenant_id": "tenant-id",
        "app_id": "app-id",
        "triggered_from": WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
        "creator_user_id": "user-id",
        "creator_user_role": CreatorUserRole.ACCOUNT.value,
    }

    # ``retry`` is replaced with a stub that raises, so the retry request surfaces as an exception here.
    retry_stub = patch.object(
        save_workflow_node_execution_data_task,
        "retry",
        side_effect=RuntimeError("retry requested"),
    )
    with retry_stub as retry:
        with pytest.raises(RuntimeError, match="retry requested"):
            save_workflow_node_execution_data_task.run(**task_kwargs)

    retry.assert_called_once()
    retry_kwargs = retry.call_args.kwargs
    assert isinstance(retry_kwargs["exc"], RuntimeError)
    assert retry_kwargs["countdown"] == 60
@patch("tasks.workflow_node_execution_tasks.session_factory.create_session")
def test_save_workflow_node_execution_task_creates_metadata_record(mock_create_session: Mock) -> None:
    """When the lookup finds no row, the metadata task adds a new one with empty payload columns and commits."""
    fake_session = _TaskSession(existing_execution=None)
    mock_create_session.return_value = fake_session

    task_kwargs = {
        "execution_data": _execution().model_dump(),
        "tenant_id": "tenant-id",
        "app_id": "app-id",
        "triggered_from": WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
        "creator_user_id": "user-id",
        "creator_user_role": CreatorUserRole.ACCOUNT.value,
    }
    outcome = save_workflow_node_execution_task.run(**task_kwargs)

    assert outcome is True
    assert fake_session.committed is True

    added = fake_session.added_execution
    assert isinstance(added, WorkflowNodeExecutionModel)
    assert added.inputs == "{}"
    assert added.process_data == "{}"
    assert added.outputs == "{}"
@patch("tasks.workflow_node_execution_tasks.session_factory.create_session")
def test_save_workflow_node_execution_task_updates_metadata_without_payloads(mock_create_session: Mock) -> None:
    """When a row already exists, the metadata task updates it in place and keeps its stored payloads.

    Nothing is added to the session, a commit still happens, and the row's status is updated.
    """
    stored_payloads = {
        "inputs": '{"old_input": true}',
        "process_data": '{"old_process": true}',
        "outputs": '{"old_output": true}',
    }
    existing_row = WorkflowNodeExecutionModel()
    for column, serialized in stored_payloads.items():
        setattr(existing_row, column, serialized)

    fake_session = _TaskSession(existing_execution=existing_row)
    mock_create_session.return_value = fake_session

    task_kwargs = {
        "execution_data": _execution().model_dump(),
        "tenant_id": "tenant-id",
        "app_id": "app-id",
        "triggered_from": WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
        "creator_user_id": "user-id",
        "creator_user_role": CreatorUserRole.ACCOUNT.value,
    }
    outcome = save_workflow_node_execution_task.run(**task_kwargs)

    assert outcome is True
    assert fake_session.committed is True
    assert fake_session.added_execution is None
    for column, serialized in stored_payloads.items():
        assert getattr(existing_row, column) == serialized
    assert existing_row.status == WorkflowNodeExecutionStatus.SUCCEEDED
def test_create_sqlalchemy_repository_builds_account_context(monkeypatch) -> None:
    """For an account creator, the factory sets the tenant on the account and builds the repository with it.

    An empty ``app_id`` is expected to reach the repository as ``None``.
    """
    account = Mock()
    fake_session = _Session({Account: account})

    def make_session():
        return fake_session

    def get_session_maker():
        return make_session

    monkeypatch.setattr(
        "tasks.workflow_node_execution_tasks.session_factory.get_session_maker",
        get_session_maker,
    )

    repository_target = (
        "core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository"
    )
    with patch(repository_target) as repository_class:
        built = _create_sqlalchemy_repository(
            tenant_id="tenant-id",
            app_id="",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
            creator_user_id="user-id",
            creator_user_role=CreatorUserRole.ACCOUNT.value,
        )

    assert built == repository_class.return_value
    account.set_tenant_id.assert_called_once_with("tenant-id")
    repository_class.assert_called_once_with(
        session_factory=make_session,
        user=account,
        app_id=None,
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
def test_create_sqlalchemy_repository_builds_end_user_context(monkeypatch) -> None:
    """For an end-user creator, the factory builds the repository with the looked-up end user and the app id."""
    end_user = Mock()
    fake_session = _Session({EndUser: end_user})

    def make_session():
        return fake_session

    def get_session_maker():
        return make_session

    monkeypatch.setattr(
        "tasks.workflow_node_execution_tasks.session_factory.get_session_maker",
        get_session_maker,
    )

    repository_target = (
        "core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository"
    )
    with patch(repository_target) as repository_class:
        built = _create_sqlalchemy_repository(
            tenant_id="tenant-id",
            app_id="app-id",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
            creator_user_id="user-id",
            creator_user_role=CreatorUserRole.END_USER.value,
        )

    assert built == repository_class.return_value
    repository_class.assert_called_once_with(
        session_factory=make_session,
        user=end_user,
        app_id="app-id",
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
def test_create_sqlalchemy_repository_raises_for_missing_creator(monkeypatch) -> None:
    """The factory raises ``ValueError`` naming the creator id when the session lookup returns nothing."""
    # No users registered, so every lookup on this session yields ``None``.
    empty_session = _Session({})

    def make_session():
        return empty_session

    def get_session_maker():
        return make_session

    monkeypatch.setattr(
        "tasks.workflow_node_execution_tasks.session_factory.get_session_maker",
        get_session_maker,
    )

    repository_target = (
        "core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository"
    )
    with patch(repository_target):
        with pytest.raises(ValueError, match="Creator user missing-user not found"):
            _create_sqlalchemy_repository(
                tenant_id="tenant-id",
                app_id="app-id",
                triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
                creator_user_id="missing-user",
                creator_user_role=CreatorUserRole.ACCOUNT.value,
            )
class _Session:
def __init__(self, users: dict[type, object]) -> None:
self._users = users
def __enter__(self):
return self
def __exit__(self, exc_type, exc, traceback) -> None:
return None
def get(self, model, _id: str):
return self._users.get(model)
class _TaskSession:
    """Session double that records what was added and whether a commit was requested."""

    def __init__(self, existing_execution: WorkflowNodeExecutionModel | None) -> None:
        # Row handed back by ``scalar``; ``None`` makes the lookup come back empty.
        self._existing_execution = existing_execution
        # Observations the tests inspect after the task has run.
        self.committed = False
        self.added_execution: WorkflowNodeExecutionModel | None = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, traceback) -> None:
        """Leave the ``with`` block without suppressing any exception."""

    def scalar(self, _stmt):
        """Return the preset row regardless of the statement passed in."""
        return self._existing_execution

    def add(self, execution: WorkflowNodeExecutionModel) -> None:
        """Remember the most recently added row."""
        self.added_execution = execution

    def commit(self) -> None:
        """Flag that a commit was requested."""
        self.committed = True

View File

@ -330,7 +330,6 @@ BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300
HUAWEI_CLOUD_HOSTS=https://127.0.0.1:9200
HUAWEI_CLOUD_USER=admin
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository