mirror of
https://github.com/langgenius/dify.git
synced 2026-05-21 09:17:27 +08:00
Compare commits
5 Commits
main
...
laipz8200/
| Author | SHA1 | Date | |
|---|---|---|---|
| e2bd7dbd0d | |||
| 37636f78f5 | |||
| b7834a42b6 | |||
| 21387b3beb | |||
| e4619b5b73 |
@ -565,12 +565,6 @@ GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
|
||||
# Seconds of idle time before scaling down workers (default: 5.0)
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
|
||||
|
||||
# Workflow storage configuration
|
||||
# Options: rdbms, hybrid
|
||||
# rdbms: Use only the relational database (default)
|
||||
# hybrid: Save new data to object storage, read from both object storage and RDBMS
|
||||
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
|
||||
|
||||
# Repository configuration
|
||||
# Core workflow execution repository implementation
|
||||
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
|
||||
|
||||
@ -791,11 +791,6 @@ class WorkflowNodeExecutionConfig(BaseSettings):
|
||||
default=100,
|
||||
)
|
||||
|
||||
WORKFLOW_NODE_EXECUTION_STORAGE: str = Field(
|
||||
default="rdbms",
|
||||
description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'",
|
||||
)
|
||||
|
||||
|
||||
class RepositoryConfig(BaseSettings):
|
||||
"""
|
||||
@ -803,18 +798,12 @@ class RepositoryConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
|
||||
description="Repository implementation for WorkflowExecution. Options: "
|
||||
"'core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository' (default), "
|
||||
"'core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository'",
|
||||
description="Repository implementation for WorkflowExecution. Specify as a module path.",
|
||||
default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
|
||||
)
|
||||
|
||||
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
|
||||
description="Repository implementation for WorkflowNodeExecution. Options: "
|
||||
"'core.repositories.sqlalchemy_workflow_node_execution_repository."
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository' (default), "
|
||||
"'core.repositories.celery_workflow_node_execution_repository."
|
||||
"CeleryWorkflowNodeExecutionRepository'",
|
||||
description="Repository implementation for WorkflowNodeExecution. Specify as a module path.",
|
||||
default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
)
|
||||
|
||||
|
||||
@ -235,6 +235,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
),
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
invoke_from=invoke_from,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
|
||||
@ -195,6 +195,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
),
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
invoke_from=invoke_from,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
|
||||
@ -162,6 +162,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
),
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
invoke_from=invoke_from,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
|
||||
@ -4,17 +4,19 @@ This layer mirrors the former ``WorkflowCycleManager`` responsibilities by
|
||||
listening to ``GraphEngineEvent`` instances directly and persisting workflow
|
||||
and node execution state via the injected repositories.
|
||||
|
||||
The design keeps domain persistence concerns inside the engine thread, while
|
||||
allowing presentation layers to remain read-only observers of repository
|
||||
state.
|
||||
The layer owns domain-to-persistence event handling, while the injected
|
||||
repositories choose the write strategy. Debug executions use synchronous
|
||||
writes so developer tools can read DB state immediately; non-debug app
|
||||
executions use a Celery-backed write path to keep DB writes out of the engine
|
||||
thread.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.helper.trace_id_helper import ParentTraceContext
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
@ -59,6 +61,21 @@ class PersistenceWorkflowInfo:
|
||||
graph_data: Mapping[str, Any]
|
||||
|
||||
|
||||
def should_use_async_workflow_persistence(invoke_from: InvokeFrom) -> bool:
|
||||
"""Return whether workflow execution state should be persisted through Celery."""
|
||||
return invoke_from != InvokeFrom.DEBUGGER
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AsyncPersistenceConfigurable(Protocol):
|
||||
def set_async_persistence(self, enabled: bool) -> None: ...
|
||||
|
||||
|
||||
def _configure_async_persistence(repository: object, enabled: bool) -> None:
|
||||
if isinstance(repository, AsyncPersistenceConfigurable):
|
||||
repository.set_async_persistence(enabled)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _NodeRuntimeSnapshot:
|
||||
"""Lightweight cache to keep node metadata across event phases."""
|
||||
@ -77,10 +94,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity | WorkflowAppGenerateEntity,
|
||||
workflow_info: PersistenceWorkflowInfo,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
invoke_from: InvokeFrom | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -88,6 +106,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
self._workflow_info = workflow_info
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
use_async_persistence = should_use_async_workflow_persistence(
|
||||
invoke_from or application_generate_entity.invoke_from
|
||||
)
|
||||
_configure_async_persistence(self._workflow_execution_repository, use_async_persistence)
|
||||
_configure_async_persistence(self._workflow_node_execution_repository, use_async_persistence)
|
||||
self._trace_manager = trace_manager
|
||||
|
||||
self._workflow_execution: WorkflowExecution | None = None
|
||||
|
||||
@ -2,8 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||
from .factory import (
|
||||
DifyCoreRepositoryFactory,
|
||||
OrderConfig,
|
||||
@ -15,8 +13,6 @@ from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutio
|
||||
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"CeleryWorkflowExecutionRepository",
|
||||
"CeleryWorkflowNodeExecutionRepository",
|
||||
"DifyCoreRepositoryFactory",
|
||||
"OrderConfig",
|
||||
"RepositoryImportError",
|
||||
|
||||
@ -1,125 +0,0 @@
|
||||
"""
|
||||
Celery-based implementation of the WorkflowExecutionRepository.
|
||||
|
||||
This implementation uses Celery tasks for asynchronous storage operations,
|
||||
providing improved performance by offloading database operations to background workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.factory import WorkflowExecutionRepository
|
||||
from graphon.entities import WorkflowExecution
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import Account, CreatorUserRole, EndUser
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from tasks.workflow_execution_tasks import (
|
||||
save_workflow_execution_task,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
"""
|
||||
Celery-based implementation of the WorkflowExecutionRepository interface.
|
||||
|
||||
This implementation provides asynchronous storage capabilities by using Celery tasks
|
||||
to handle database operations in background workers. This improves performance by
|
||||
reducing the blocking time for workflow execution storage operations.
|
||||
|
||||
Key features:
|
||||
- Asynchronous save operations using Celery tasks
|
||||
- Support for multi-tenancy through tenant/app filtering
|
||||
- Automatic retry and error handling through Celery
|
||||
"""
|
||||
|
||||
_session_factory: sessionmaker
|
||||
_tenant_id: str
|
||||
_app_id: str | None
|
||||
_triggered_from: WorkflowRunTriggeredFrom | None
|
||||
_creator_user_id: str
|
||||
_creator_user_role: CreatorUserRole
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
user: Account | EndUser,
|
||||
app_id: str | None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None,
|
||||
):
|
||||
"""
|
||||
Initialize the repository with Celery task configuration and context information.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine for fallback operations
|
||||
user: Account or EndUser object containing tenant_id, user ID, and role information
|
||||
app_id: App ID for filtering by application (can be None)
|
||||
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
|
||||
"""
|
||||
# Store session factory for fallback operations
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_factory = session_factory
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id = extract_tenant_id(user)
|
||||
if not tenant_id:
|
||||
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
# Store app context
|
||||
self._app_id = app_id
|
||||
|
||||
# Extract user context
|
||||
self._triggered_from = triggered_from
|
||||
self._creator_user_id = user.id
|
||||
|
||||
# Determine user role based on user type
|
||||
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
|
||||
|
||||
logger.info(
|
||||
"Initialized CeleryWorkflowExecutionRepository for tenant %s, app %s, triggered_from %s",
|
||||
self._tenant_id,
|
||||
self._app_id,
|
||||
self._triggered_from,
|
||||
)
|
||||
|
||||
def save(self, execution: WorkflowExecution):
|
||||
"""
|
||||
Save or update a WorkflowExecution instance asynchronously using Celery.
|
||||
|
||||
This method queues the save operation as a Celery task and returns immediately,
|
||||
providing improved performance for high-throughput scenarios.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowExecution instance to save or update
|
||||
"""
|
||||
try:
|
||||
# Serialize execution for Celery task
|
||||
execution_data = execution.model_dump()
|
||||
|
||||
# Queue the save operation as a Celery task (fire and forget)
|
||||
save_workflow_execution_task.delay( # type: ignore
|
||||
execution_data=execution_data,
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id or "",
|
||||
triggered_from=self._triggered_from.value if self._triggered_from else "",
|
||||
creator_user_id=self._creator_user_id,
|
||||
creator_user_role=self._creator_user_role.value,
|
||||
)
|
||||
|
||||
logger.debug("Queued async save for workflow execution: %s", execution.id_)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to queue save operation for execution %s", execution.id_)
|
||||
# In case of Celery failure, we could implement a fallback to synchronous save
|
||||
# For now, we'll re-raise the exception
|
||||
raise
|
||||
@ -1,196 +0,0 @@
|
||||
"""
|
||||
Celery-based implementation of the WorkflowNodeExecutionRepository.
|
||||
|
||||
This implementation uses Celery tasks for asynchronous storage operations,
|
||||
providing improved performance by offloading database operations to background workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.factory import (
|
||||
OrderConfig,
|
||||
WorkflowNodeExecutionRepository,
|
||||
)
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import Account, CreatorUserRole, EndUser
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
from tasks.workflow_node_execution_tasks import (
|
||||
save_workflow_node_execution_task,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
||||
"""
|
||||
Celery-based implementation of the WorkflowNodeExecutionRepository interface.
|
||||
|
||||
This implementation provides asynchronous storage capabilities by using Celery tasks
|
||||
to handle database operations in background workers. This improves performance by
|
||||
reducing the blocking time for workflow node execution storage operations.
|
||||
|
||||
Key features:
|
||||
- Asynchronous save operations using Celery tasks
|
||||
- In-memory cache for immediate reads
|
||||
- Support for multi-tenancy through tenant/app filtering
|
||||
- Automatic retry and error handling through Celery
|
||||
"""
|
||||
|
||||
_session_factory: sessionmaker
|
||||
_tenant_id: str
|
||||
_app_id: str | None
|
||||
_triggered_from: WorkflowNodeExecutionTriggeredFrom | None
|
||||
_creator_user_id: str
|
||||
_creator_user_role: CreatorUserRole
|
||||
_execution_cache: dict[str, WorkflowNodeExecution]
|
||||
_workflow_execution_mapping: dict[str, list[str]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
user: Account | EndUser,
|
||||
app_id: str | None,
|
||||
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
|
||||
):
|
||||
"""
|
||||
Initialize the repository with Celery task configuration and context information.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine for fallback operations
|
||||
user: Account or EndUser object containing tenant_id, user ID, and role information
|
||||
app_id: App ID for filtering by application (can be None)
|
||||
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
|
||||
"""
|
||||
# Store session factory for fallback operations
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_factory = session_factory
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id = extract_tenant_id(user)
|
||||
if not tenant_id:
|
||||
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
# Store app context
|
||||
self._app_id = app_id
|
||||
|
||||
# Extract user context
|
||||
self._triggered_from = triggered_from
|
||||
self._creator_user_id = user.id
|
||||
|
||||
# Determine user role based on user type
|
||||
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
|
||||
|
||||
# In-memory cache for workflow node executions
|
||||
self._execution_cache = {}
|
||||
|
||||
# Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval
|
||||
self._workflow_execution_mapping = {}
|
||||
|
||||
logger.info(
|
||||
"Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s",
|
||||
self._tenant_id,
|
||||
self._app_id,
|
||||
self._triggered_from,
|
||||
)
|
||||
|
||||
def save(self, execution: WorkflowNodeExecution):
|
||||
"""
|
||||
Save or update a WorkflowNodeExecution instance to cache and asynchronously to database.
|
||||
|
||||
This method stores the execution in cache immediately for fast reads and queues
|
||||
the save operation as a Celery task without tracking the task status.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to save or update
|
||||
"""
|
||||
try:
|
||||
# Store in cache immediately for fast reads
|
||||
self._execution_cache[execution.id] = execution
|
||||
|
||||
# Update workflow execution mapping for efficient retrieval
|
||||
if execution.workflow_execution_id:
|
||||
if execution.workflow_execution_id not in self._workflow_execution_mapping:
|
||||
self._workflow_execution_mapping[execution.workflow_execution_id] = []
|
||||
if execution.id not in self._workflow_execution_mapping[execution.workflow_execution_id]:
|
||||
self._workflow_execution_mapping[execution.workflow_execution_id].append(execution.id)
|
||||
|
||||
# Serialize execution for Celery task
|
||||
execution_data = execution.model_dump()
|
||||
|
||||
# Queue the save operation as a Celery task (fire and forget)
|
||||
save_workflow_node_execution_task.delay(
|
||||
execution_data=execution_data,
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id or "",
|
||||
triggered_from=self._triggered_from.value if self._triggered_from else "",
|
||||
creator_user_id=self._creator_user_id,
|
||||
creator_user_role=self._creator_user_role.value,
|
||||
)
|
||||
|
||||
logger.debug("Cached and queued async save for workflow node execution: %s", execution.id)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to cache or queue save operation for node execution %s", execution.id)
|
||||
# In case of Celery failure, we could implement a fallback to synchronous save
|
||||
# For now, we'll re-raise the exception
|
||||
raise
|
||||
|
||||
def get_by_workflow_execution(
|
||||
self,
|
||||
workflow_execution_id: str,
|
||||
order_config: OrderConfig | None = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all workflow node executions for a workflow execution from cache.
|
||||
|
||||
Args:
|
||||
workflow_execution_id: The workflow execution identifier
|
||||
order_config: Optional configuration for ordering results
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowNodeExecution instances
|
||||
"""
|
||||
try:
|
||||
# Get execution IDs for this workflow execution from cache
|
||||
execution_ids = self._workflow_execution_mapping.get(workflow_execution_id, [])
|
||||
|
||||
# Retrieve executions from cache
|
||||
result = []
|
||||
for execution_id in execution_ids:
|
||||
if execution_id in self._execution_cache:
|
||||
result.append(self._execution_cache[execution_id])
|
||||
|
||||
# Apply ordering if specified
|
||||
if order_config and result:
|
||||
# Sort based on the order configuration
|
||||
reverse = order_config.order_direction == "desc"
|
||||
|
||||
# Sort by multiple fields if specified
|
||||
for field_name in reversed(order_config.order_by):
|
||||
result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse)
|
||||
|
||||
logger.debug(
|
||||
"Retrieved %d workflow node executions for execution %s from cache",
|
||||
len(result),
|
||||
workflow_execution_id,
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get workflow node executions for execution %s from cache",
|
||||
workflow_execution_id,
|
||||
)
|
||||
return []
|
||||
@ -20,6 +20,7 @@ from models import (
|
||||
WorkflowRun,
|
||||
)
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from tasks.workflow_execution_tasks import save_workflow_execution_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -36,6 +37,8 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
performance by reducing database queries.
|
||||
"""
|
||||
|
||||
_use_async_persistence: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
@ -81,6 +84,16 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
# Initialize in-memory cache for workflow executions
|
||||
# Key: execution_id, Value: WorkflowRun (DB model)
|
||||
self._execution_cache: dict[str, WorkflowRun] = {}
|
||||
self._use_async_persistence = False
|
||||
|
||||
def set_async_persistence(self, enabled: bool) -> None:
|
||||
"""
|
||||
Configure whether save operations should be queued through Celery.
|
||||
|
||||
Debug executions keep this disabled so the debugger can immediately read persisted
|
||||
workflow state. Non-debug app executions enable it from the persistence layer.
|
||||
"""
|
||||
self._use_async_persistence = enabled
|
||||
|
||||
def _to_domain_model(self, db_model: WorkflowRun) -> WorkflowExecution:
|
||||
"""
|
||||
@ -190,6 +203,10 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
Args:
|
||||
execution: The WorkflowExecution domain entity to persist
|
||||
"""
|
||||
if self._use_async_persistence:
|
||||
self._queue_async_save(execution)
|
||||
return
|
||||
|
||||
# Convert domain model to database model using tenant context and other attributes
|
||||
db_model = self._to_db_model(execution)
|
||||
|
||||
@ -209,3 +226,20 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
|
||||
# Update the in-memory cache for faster subsequent lookups
|
||||
self._execution_cache[db_model.id] = db_model
|
||||
|
||||
def _queue_async_save(self, execution: WorkflowExecution) -> None:
|
||||
if not self._triggered_from:
|
||||
raise ValueError("triggered_from is required in repository constructor")
|
||||
if not self._creator_user_id:
|
||||
raise ValueError("created_by is required in repository constructor")
|
||||
if not self._creator_user_role:
|
||||
raise ValueError("created_by_role is required in repository constructor")
|
||||
|
||||
save_workflow_execution_task.delay(
|
||||
execution_data=execution.model_dump(),
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id or "",
|
||||
triggered_from=self._triggered_from.value,
|
||||
creator_user_id=self._creator_user_id,
|
||||
creator_user_role=self._creator_user_role.value,
|
||||
)
|
||||
|
||||
@ -37,6 +37,10 @@ from models.model import UploadFile
|
||||
from models.workflow import WorkflowNodeExecutionOffload
|
||||
from services.file_service import FileService
|
||||
from services.variable_truncator import VariableTruncator
|
||||
from tasks.workflow_node_execution_tasks import (
|
||||
save_workflow_node_execution_data_task,
|
||||
save_workflow_node_execution_task,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -60,6 +64,8 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
performance by reducing database queries.
|
||||
"""
|
||||
|
||||
_use_async_persistence: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
@ -108,6 +114,16 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
|
||||
# Initialize FileService for handling offloaded data
|
||||
self._file_service = FileService(session_factory)
|
||||
self._use_async_persistence = False
|
||||
|
||||
def set_async_persistence(self, enabled: bool) -> None:
|
||||
"""
|
||||
Configure whether save operations should be queued through Celery.
|
||||
|
||||
Debug executions keep this disabled so node data is readable immediately. Non-debug
|
||||
app executions enable it from the workflow persistence layer.
|
||||
"""
|
||||
self._use_async_persistence = enabled
|
||||
|
||||
def _create_truncator(self) -> VariableTruncator:
|
||||
return VariableTruncator(
|
||||
@ -336,6 +352,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
# Only the final call contains the complete inputs and outputs payloads, so earlier invocations
|
||||
# must tolerate missing data without attempting to offload variables.
|
||||
|
||||
if self._use_async_persistence:
|
||||
self._queue_async_save(execution)
|
||||
return
|
||||
|
||||
# Convert domain model to database model using tenant context and other attributes
|
||||
db_model = self._to_db_model(execution)
|
||||
|
||||
@ -400,6 +420,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
self._node_execution_cache[db_model.node_execution_id] = db_model
|
||||
|
||||
def save_execution_data(self, execution: WorkflowNodeExecution):
|
||||
if self._use_async_persistence:
|
||||
self._queue_async_save_execution_data(execution)
|
||||
return
|
||||
|
||||
domain_model = execution
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
query = WorkflowNodeExecutionModel.preload_offload_data(select(WorkflowNodeExecutionModel)).where(
|
||||
@ -457,6 +481,40 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
session.merge(db_model)
|
||||
session.flush()
|
||||
|
||||
def _queue_async_save(self, execution: WorkflowNodeExecution) -> None:
|
||||
if not self._triggered_from:
|
||||
raise ValueError("triggered_from is required in repository constructor")
|
||||
if not self._creator_user_id:
|
||||
raise ValueError("created_by is required in repository constructor")
|
||||
if not self._creator_user_role:
|
||||
raise ValueError("created_by_role is required in repository constructor")
|
||||
|
||||
save_workflow_node_execution_task.delay(
|
||||
execution_data=execution.model_dump(),
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id or "",
|
||||
triggered_from=self._triggered_from.value,
|
||||
creator_user_id=self._creator_user_id,
|
||||
creator_user_role=self._creator_user_role.value,
|
||||
)
|
||||
|
||||
def _queue_async_save_execution_data(self, execution: WorkflowNodeExecution) -> None:
|
||||
if not self._triggered_from:
|
||||
raise ValueError("triggered_from is required in repository constructor")
|
||||
if not self._creator_user_id:
|
||||
raise ValueError("created_by is required in repository constructor")
|
||||
if not self._creator_user_role:
|
||||
raise ValueError("created_by_role is required in repository constructor")
|
||||
|
||||
save_workflow_node_execution_data_task.delay(
|
||||
execution_data=execution.model_dump(),
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id or "",
|
||||
triggered_from=self._triggered_from.value,
|
||||
creator_user_id=self._creator_user_id,
|
||||
creator_user_role=self._creator_user_role.value,
|
||||
)
|
||||
|
||||
def get_db_models_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
|
||||
@ -108,9 +108,10 @@ def _create_workflow_run_from_execution(
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
workflow_run.error = execution.error_message
|
||||
workflow_run.elapsed_time = execution.elapsed_time
|
||||
workflow_run.elapsed_time = _calculate_elapsed_time(execution)
|
||||
workflow_run.total_tokens = execution.total_tokens
|
||||
workflow_run.total_steps = execution.total_steps
|
||||
workflow_run.exceptions_count = execution.exceptions_count
|
||||
workflow_run.created_by_role = creator_user_role
|
||||
workflow_run.created_by = creator_user_id
|
||||
workflow_run.created_at = execution.started_at
|
||||
@ -129,7 +130,14 @@ def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: Wo
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
workflow_run.error = execution.error_message
|
||||
workflow_run.elapsed_time = execution.elapsed_time
|
||||
workflow_run.elapsed_time = _calculate_elapsed_time(execution)
|
||||
workflow_run.total_tokens = execution.total_tokens
|
||||
workflow_run.total_steps = execution.total_steps
|
||||
workflow_run.exceptions_count = execution.exceptions_count
|
||||
workflow_run.finished_at = execution.finished_at
|
||||
|
||||
|
||||
def _calculate_elapsed_time(execution: WorkflowExecution) -> float:
|
||||
if execution.finished_at is None:
|
||||
return execution.elapsed_time
|
||||
return max((execution.finished_at - execution.started_at).total_seconds(), 0.0)
|
||||
|
||||
@ -7,7 +7,7 @@ improving performance by offloading storage operations to background workers.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
@ -17,9 +17,14 @@ from graphon.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
)
|
||||
from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from models import CreatorUserRole, WorkflowNodeExecutionModel
|
||||
from models import Account, CreatorUserRole, EndUser, WorkflowNodeExecutionModel
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -34,7 +39,7 @@ def save_workflow_node_execution_task(
|
||||
creator_user_role: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Asynchronously save or update a workflow node execution to the database.
|
||||
Asynchronously save or update workflow node execution metadata to the database.
|
||||
|
||||
Args:
|
||||
execution_data: Serialized WorkflowNodeExecution data
|
||||
@ -49,20 +54,16 @@ def save_workflow_node_execution_task(
|
||||
"""
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
# Deserialize execution data
|
||||
execution = WorkflowNodeExecution.model_validate(execution_data)
|
||||
|
||||
# Check if node execution already exists
|
||||
existing_execution = session.scalar(
|
||||
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
|
||||
)
|
||||
|
||||
if existing_execution:
|
||||
# Update existing node execution
|
||||
_update_node_execution_from_domain(existing_execution, execution)
|
||||
logger.debug("Updated existing workflow node execution: %s", execution.id)
|
||||
_update_node_execution_metadata(existing_execution, execution)
|
||||
logger.debug("Updated existing workflow node execution metadata: %s", execution.id)
|
||||
else:
|
||||
# Create new node execution
|
||||
node_execution = _create_node_execution_from_domain(
|
||||
execution=execution,
|
||||
tenant_id=tenant_id,
|
||||
@ -79,10 +80,76 @@ def save_workflow_node_execution_task(
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown"))
|
||||
# Retry the task with exponential backoff
|
||||
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
|
||||
|
||||
|
||||
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
|
||||
def save_workflow_node_execution_data_task(
|
||||
self,
|
||||
execution_data: dict[str, Any],
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: str,
|
||||
creator_user_id: str,
|
||||
creator_user_role: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Asynchronously save full workflow node execution data to the database.
|
||||
|
||||
This path preserves the SQLAlchemy repository's truncation and offload behavior while
|
||||
moving the blocking database and storage work out of the workflow engine thread.
|
||||
"""
|
||||
try:
|
||||
execution = WorkflowNodeExecution.model_validate(execution_data)
|
||||
repository = _create_sqlalchemy_repository(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
creator_user_id=creator_user_id,
|
||||
creator_user_role=creator_user_role,
|
||||
)
|
||||
repository.save(execution)
|
||||
repository.save_execution_data(execution)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception("Failed to save workflow node execution data %s", execution_data.get("id", "unknown"))
|
||||
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
|
||||
|
||||
|
||||
def _create_sqlalchemy_repository(
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: str,
|
||||
creator_user_id: str,
|
||||
creator_user_role: str,
|
||||
) -> "SQLAlchemyWorkflowNodeExecutionRepository":
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
|
||||
session_maker = session_factory.get_session_maker()
|
||||
role = CreatorUserRole(creator_user_role)
|
||||
user: Account | EndUser | None
|
||||
with session_maker() as session:
|
||||
if role == CreatorUserRole.ACCOUNT:
|
||||
user = session.get(Account, creator_user_id)
|
||||
if user is not None:
|
||||
user.set_tenant_id(tenant_id)
|
||||
else:
|
||||
user = session.get(EndUser, creator_user_id)
|
||||
|
||||
if user is None:
|
||||
raise ValueError(f"Creator user {creator_user_id} not found for workflow node execution persistence")
|
||||
|
||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_maker,
|
||||
user=user,
|
||||
app_id=app_id or None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
|
||||
)
|
||||
|
||||
|
||||
def _create_node_execution_from_domain(
|
||||
execution: WorkflowNodeExecution,
|
||||
tenant_id: str,
|
||||
@ -108,16 +175,11 @@ def _create_node_execution_from_domain(
|
||||
node_execution.title = execution.title
|
||||
node_execution.node_execution_id = execution.node_execution_id
|
||||
|
||||
# Serialize complex data as JSON
|
||||
node_execution.inputs = "{}"
|
||||
node_execution.process_data = "{}"
|
||||
node_execution.outputs = "{}"
|
||||
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
|
||||
node_execution.process_data = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
|
||||
)
|
||||
node_execution.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
# Convert metadata enum keys to strings for JSON serialization
|
||||
if execution.metadata:
|
||||
metadata_for_json = {
|
||||
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
|
||||
@ -137,20 +199,11 @@ def _create_node_execution_from_domain(
|
||||
return node_execution
|
||||
|
||||
|
||||
def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution):
|
||||
def _update_node_execution_metadata(node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution):
|
||||
"""
|
||||
Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
|
||||
Update WorkflowNodeExecutionModel metadata without changing persisted data payload fields.
|
||||
"""
|
||||
# Update serialized data
|
||||
json_converter = WorkflowRuntimeTypeConverter()
|
||||
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
|
||||
node_execution.process_data = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
|
||||
)
|
||||
node_execution.outputs = (
|
||||
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
|
||||
)
|
||||
# Convert metadata enum keys to strings for JSON serialization
|
||||
if execution.metadata:
|
||||
metadata_for_json = {
|
||||
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
|
||||
@ -159,7 +212,6 @@ def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionMode
|
||||
else:
|
||||
node_execution.execution_metadata = "{}"
|
||||
|
||||
# Update other fields
|
||||
node_execution.status = execution.status
|
||||
node_execution.error = execution.error
|
||||
node_execution.elapsed_time = execution.elapsed_time
|
||||
|
||||
@ -5,8 +5,12 @@ from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.workflow.layers.persistence import (
|
||||
PersistenceWorkflowInfo,
|
||||
WorkflowPersistenceLayer,
|
||||
should_use_async_workflow_persistence,
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceTask, TraceTaskName
|
||||
from core.workflow.system_variables import SystemVariableKey, build_system_variables
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
@ -39,6 +43,7 @@ class _RepoRecorder:
|
||||
def __init__(self) -> None:
|
||||
self.saved: list[object] = []
|
||||
self.saved_exec_data: list[object] = []
|
||||
self.async_enabled: bool | None = None
|
||||
|
||||
def save(self, entity):
|
||||
self.saved.append(entity)
|
||||
@ -46,6 +51,9 @@ class _RepoRecorder:
|
||||
def save_execution_data(self, entity):
|
||||
self.saved_exec_data.append(entity)
|
||||
|
||||
def set_async_persistence(self, enabled: bool) -> None:
|
||||
self.async_enabled = enabled
|
||||
|
||||
|
||||
def _naive_utc_now() -> datetime:
|
||||
return datetime.now(UTC).replace(tzinfo=None)
|
||||
@ -55,6 +63,7 @@ def _make_layer(
|
||||
system_variables: list | None = None,
|
||||
*,
|
||||
extras: dict | None = None,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
trace_manager: object | None = None,
|
||||
):
|
||||
system_variables = system_variables or build_system_variables(
|
||||
@ -74,7 +83,7 @@ def _make_layer(
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=None,
|
||||
invoke_from=invoke_from,
|
||||
trace_manager=None,
|
||||
workflow_execution_id="run-id",
|
||||
extras=extras or {},
|
||||
@ -96,6 +105,7 @@ def _make_layer(
|
||||
workflow_info=workflow_info,
|
||||
workflow_execution_repository=workflow_execution_repo,
|
||||
workflow_node_execution_repository=workflow_node_execution_repo,
|
||||
invoke_from=invoke_from,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
layer.initialize(read_only_state, command_channel=None)
|
||||
@ -104,6 +114,30 @@ def _make_layer(
|
||||
|
||||
|
||||
class TestWorkflowPersistenceLayer:
|
||||
@pytest.mark.parametrize(
|
||||
("invoke_from", "expected"),
|
||||
[
|
||||
(InvokeFrom.DEBUGGER, False),
|
||||
(InvokeFrom.WEB_APP, True),
|
||||
(InvokeFrom.SERVICE_API, True),
|
||||
(InvokeFrom.TRIGGER, True),
|
||||
],
|
||||
)
|
||||
def test_should_use_async_workflow_persistence(self, invoke_from: InvokeFrom, expected: bool):
|
||||
assert should_use_async_workflow_persistence(invoke_from) is expected
|
||||
|
||||
def test_configures_repositories_for_debug_synchronous_persistence(self):
|
||||
_, exec_repo, node_repo, _ = _make_layer(invoke_from=InvokeFrom.DEBUGGER)
|
||||
|
||||
assert exec_repo.async_enabled is False
|
||||
assert node_repo.async_enabled is False
|
||||
|
||||
def test_configures_repositories_for_non_debug_async_persistence(self):
|
||||
_, exec_repo, node_repo, _ = _make_layer(invoke_from=InvokeFrom.WEB_APP)
|
||||
|
||||
assert exec_repo.async_enabled is True
|
||||
assert node_repo.async_enabled is True
|
||||
|
||||
def test_on_graph_start_resets_state(self):
|
||||
layer, _, _, _ = _make_layer()
|
||||
layer._workflow_execution = object()
|
||||
|
||||
@ -1,248 +0,0 @@
|
||||
"""
|
||||
Unit tests for CeleryWorkflowExecutionRepository.
|
||||
|
||||
These tests verify the Celery-based asynchronous storage functionality
|
||||
for workflow execution data.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from graphon.entities import WorkflowExecution
|
||||
from graphon.enums import WorkflowType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory():
|
||||
"""Mock SQLAlchemy session factory."""
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# Create a real sessionmaker with in-memory SQLite for testing
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
"""Mock Account user."""
|
||||
account = Mock(spec=Account)
|
||||
account.id = str(uuid4())
|
||||
account.current_tenant_id = str(uuid4())
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_end_user():
|
||||
"""Mock EndUser."""
|
||||
user = Mock(spec=EndUser)
|
||||
user.id = str(uuid4())
|
||||
user.tenant_id = str(uuid4())
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_execution():
|
||||
"""Sample WorkflowExecution for testing."""
|
||||
return WorkflowExecution.new(
|
||||
id_=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input1": "value1"},
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
|
||||
class TestCeleryWorkflowExecutionRepository:
|
||||
"""Test cases for CeleryWorkflowExecutionRepository."""
|
||||
|
||||
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
|
||||
"""Test repository initialization with sessionmaker."""
|
||||
app_id = "test-app-id"
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
assert repo._tenant_id == mock_account.current_tenant_id
|
||||
assert repo._app_id == app_id
|
||||
assert repo._triggered_from == triggered_from
|
||||
assert repo._creator_user_id == mock_account.id
|
||||
assert repo._creator_user_role is not None
|
||||
|
||||
def test_init_basic_functionality(self, mock_session_factory, mock_account):
|
||||
"""Test repository initialization basic functionality."""
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
|
||||
# Verify basic initialization
|
||||
assert repo._tenant_id == mock_account.current_tenant_id
|
||||
assert repo._app_id == "test-app"
|
||||
assert repo._triggered_from == WorkflowRunTriggeredFrom.DEBUGGING
|
||||
|
||||
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
|
||||
"""Test repository initialization with EndUser."""
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_end_user,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
assert repo._tenant_id == mock_end_user.tenant_id
|
||||
|
||||
def test_init_without_tenant_id_raises_error(self, mock_session_factory):
|
||||
"""Test that initialization fails without tenant_id."""
|
||||
# Create a mock Account with no tenant_id
|
||||
user = Mock(spec=Account)
|
||||
user.current_tenant_id = None
|
||||
user.id = str(uuid4())
|
||||
|
||||
with pytest.raises(ValueError, match="User must have a tenant_id"):
|
||||
CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=user,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
|
||||
def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution):
|
||||
"""Test that save operation queues a Celery task without tracking."""
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
# Verify Celery task was queued with correct parameters
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args[1]
|
||||
|
||||
assert call_args["execution_data"] == sample_workflow_execution.model_dump()
|
||||
assert call_args["tenant_id"] == mock_account.current_tenant_id
|
||||
assert call_args["app_id"] == "test-app"
|
||||
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN
|
||||
assert call_args["creator_user_id"] == mock_account.id
|
||||
|
||||
# Verify no task tracking occurs (no _pending_saves attribute)
|
||||
assert not hasattr(repo, "_pending_saves")
|
||||
|
||||
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
|
||||
def test_save_handles_celery_failure(
|
||||
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
"""Test that save operation handles Celery task failures."""
|
||||
mock_task.delay.side_effect = Exception("Celery is down")
|
||||
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Celery is down"):
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
|
||||
def test_save_operation_fire_and_forget(
|
||||
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
"""Test that save operation works in fire-and-forget mode."""
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
# Test that save doesn't block or maintain state
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
# Verify no pending saves are tracked (no _pending_saves attribute)
|
||||
assert not hasattr(repo, "_pending_saves")
|
||||
|
||||
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
|
||||
def test_multiple_save_operations(self, mock_task, mock_session_factory, mock_account):
|
||||
"""Test multiple save operations work correctly."""
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
# Create multiple executions
|
||||
exec1 = WorkflowExecution.new(
|
||||
id_=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input1": "value1"},
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
exec2 = WorkflowExecution.new(
|
||||
id_=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input2": "value2"},
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Save both executions
|
||||
repo.save(exec1)
|
||||
repo.save(exec2)
|
||||
|
||||
# Should work without issues and not maintain state (no _pending_saves attribute)
|
||||
assert not hasattr(repo, "_pending_saves")
|
||||
|
||||
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
|
||||
def test_save_with_different_user_types(self, mock_task, mock_session_factory, mock_end_user):
|
||||
"""Test save operation with different user types."""
|
||||
repo = CeleryWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_end_user,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
execution = WorkflowExecution.new(
|
||||
id_=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input1": "value1"},
|
||||
started_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
repo.save(execution)
|
||||
|
||||
# Verify task was called with EndUser context
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args[1]
|
||||
assert call_args["tenant_id"] == mock_end_user.tenant_id
|
||||
assert call_args["creator_user_id"] == mock_end_user.id
|
||||
@ -1,349 +0,0 @@
|
||||
"""
|
||||
Unit tests for CeleryWorkflowNodeExecutionRepository.
|
||||
|
||||
These tests verify the Celery-based asynchronous storage functionality
|
||||
for workflow node execution data.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||
from core.repositories.factory import OrderConfig
|
||||
from graphon.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory():
|
||||
"""Mock SQLAlchemy session factory."""
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# Create a real sessionmaker with in-memory SQLite for testing
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
"""Mock Account user."""
|
||||
account = Mock(spec=Account)
|
||||
account.id = str(uuid4())
|
||||
account.current_tenant_id = str(uuid4())
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_end_user():
|
||||
"""Mock EndUser."""
|
||||
user = Mock(spec=EndUser)
|
||||
user.id = str(uuid4())
|
||||
user.tenant_id = str(uuid4())
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_node_execution():
|
||||
"""Sample WorkflowNodeExecution for testing."""
|
||||
return WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
node_execution_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_execution_id=str(uuid4()),
|
||||
index=1,
|
||||
node_id="test_node",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Test Node",
|
||||
inputs={"input1": "value1"},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
|
||||
class TestCeleryWorkflowNodeExecutionRepository:
|
||||
"""Test cases for CeleryWorkflowNodeExecutionRepository."""
|
||||
|
||||
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
|
||||
"""Test repository initialization with sessionmaker."""
|
||||
app_id = "test-app-id"
|
||||
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
assert repo._tenant_id == mock_account.current_tenant_id
|
||||
assert repo._app_id == app_id
|
||||
assert repo._triggered_from == triggered_from
|
||||
assert repo._creator_user_id == mock_account.id
|
||||
assert repo._creator_user_role is not None
|
||||
|
||||
def test_init_with_cache_initialized(self, mock_session_factory, mock_account):
|
||||
"""Test repository initialization with cache properly initialized."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
|
||||
assert repo._execution_cache == {}
|
||||
assert repo._workflow_execution_mapping == {}
|
||||
|
||||
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
|
||||
"""Test repository initialization with EndUser."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_end_user,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
assert repo._tenant_id == mock_end_user.tenant_id
|
||||
|
||||
def test_init_without_tenant_id_raises_error(self, mock_session_factory):
|
||||
"""Test that initialization fails without tenant_id."""
|
||||
# Create a mock Account with no tenant_id
|
||||
user = Mock(spec=Account)
|
||||
user.current_tenant_id = None
|
||||
user.id = str(uuid4())
|
||||
|
||||
with pytest.raises(ValueError, match="User must have a tenant_id"):
|
||||
CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=user,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_save_caches_and_queues_celery_task(
|
||||
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
|
||||
):
|
||||
"""Test that save operation caches execution and queues a Celery task."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
repo.save(sample_workflow_node_execution)
|
||||
|
||||
# Verify Celery task was queued with correct parameters
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args[1]
|
||||
|
||||
assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
|
||||
assert call_args["tenant_id"] == mock_account.current_tenant_id
|
||||
assert call_args["app_id"] == "test-app"
|
||||
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
assert call_args["creator_user_id"] == mock_account.id
|
||||
|
||||
# Verify execution is cached
|
||||
assert sample_workflow_node_execution.id in repo._execution_cache
|
||||
assert repo._execution_cache[sample_workflow_node_execution.id] == sample_workflow_node_execution
|
||||
|
||||
# Verify workflow execution mapping is updated
|
||||
assert sample_workflow_node_execution.workflow_execution_id in repo._workflow_execution_mapping
|
||||
assert (
|
||||
sample_workflow_node_execution.id
|
||||
in repo._workflow_execution_mapping[sample_workflow_node_execution.workflow_execution_id]
|
||||
)
|
||||
|
||||
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_save_handles_celery_failure(
|
||||
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
|
||||
):
|
||||
"""Test that save operation handles Celery task failures."""
|
||||
mock_task.delay.side_effect = Exception("Celery is down")
|
||||
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Celery is down"):
|
||||
repo.save(sample_workflow_node_execution)
|
||||
|
||||
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_get_by_workflow_execution_from_cache(
|
||||
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
|
||||
):
|
||||
"""Test that get_by_workflow_execution retrieves executions from cache."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Save execution to cache first
|
||||
repo.save(sample_workflow_node_execution)
|
||||
|
||||
workflow_execution_id = sample_workflow_node_execution.workflow_execution_id
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="asc")
|
||||
|
||||
result = repo.get_by_workflow_execution(workflow_execution_id, order_config)
|
||||
|
||||
# Verify results were retrieved from cache
|
||||
assert len(result) == 1
|
||||
assert result[0].id == sample_workflow_node_execution.id
|
||||
assert result[0] is sample_workflow_node_execution
|
||||
|
||||
def test_get_by_workflow_execution_without_order_config(self, mock_session_factory, mock_account):
|
||||
"""Test get_by_workflow_execution without order configuration."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
result = repo.get_by_workflow_execution("workflow-run-id")
|
||||
|
||||
# Should return empty list since nothing in cache
|
||||
assert len(result) == 0
|
||||
|
||||
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_cache_operations(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution):
|
||||
"""Test cache operations work correctly."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Test saving to cache
|
||||
repo.save(sample_workflow_node_execution)
|
||||
|
||||
# Verify cache contains the execution
|
||||
assert sample_workflow_node_execution.id in repo._execution_cache
|
||||
|
||||
# Test retrieving from cache
|
||||
result = repo.get_by_workflow_execution(sample_workflow_node_execution.workflow_execution_id)
|
||||
assert len(result) == 1
|
||||
assert result[0].id == sample_workflow_node_execution.id
|
||||
|
||||
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_multiple_executions_same_workflow(self, mock_task, mock_session_factory, mock_account):
|
||||
"""Test multiple executions for the same workflow."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Create multiple executions for the same workflow
|
||||
workflow_execution_id = str(uuid4())
|
||||
exec1 = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
node_execution_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
index=1,
|
||||
node_id="node1",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Node 1",
|
||||
inputs={"input1": "value1"},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
exec2 = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
node_execution_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
index=2,
|
||||
node_id="node2",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="Node 2",
|
||||
inputs={"input2": "value2"},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Save both executions
|
||||
repo.save(exec1)
|
||||
repo.save(exec2)
|
||||
|
||||
# Verify both are cached and mapped
|
||||
assert len(repo._execution_cache) == 2
|
||||
assert len(repo._workflow_execution_mapping[workflow_execution_id]) == 2
|
||||
|
||||
# Test retrieval
|
||||
result = repo.get_by_workflow_execution(workflow_execution_id)
|
||||
assert len(result) == 2
|
||||
|
||||
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_ordering_functionality(self, mock_task, mock_session_factory, mock_account):
|
||||
"""Test ordering functionality works correctly."""
|
||||
repo = CeleryWorkflowNodeExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test-app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Create executions with different indices
|
||||
workflow_execution_id = str(uuid4())
|
||||
exec1 = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
node_execution_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
index=2,
|
||||
node_id="node2",
|
||||
node_type=BuiltinNodeTypes.START,
|
||||
title="Node 2",
|
||||
inputs={},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
exec2 = WorkflowNodeExecution(
|
||||
id=str(uuid4()),
|
||||
node_execution_id=str(uuid4()),
|
||||
workflow_id=str(uuid4()),
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
index=1,
|
||||
node_id="node1",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="Node 1",
|
||||
inputs={},
|
||||
status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Save in random order
|
||||
repo.save(exec1)
|
||||
repo.save(exec2)
|
||||
|
||||
# Test ascending order
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="asc")
|
||||
result = repo.get_by_workflow_execution(workflow_execution_id, order_config)
|
||||
assert len(result) == 2
|
||||
assert result[0].index == 1
|
||||
assert result[1].index == 2
|
||||
|
||||
# Test descending order
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||
result = repo.get_by_workflow_execution(workflow_execution_id, order_config)
|
||||
assert len(result) == 2
|
||||
assert result[0].index == 2
|
||||
assert result[1].index == 1
|
||||
@ -1,5 +1,5 @@
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@ -240,6 +240,78 @@ class TestSQLAlchemyWorkflowExecutionRepository:
|
||||
cached_model = repo._execution_cache[sample_workflow_execution.id_]
|
||||
assert cached_model.id == sample_workflow_execution.id_
|
||||
|
||||
def test_save_rejects_existing_run_from_other_tenant(
|
||||
self, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
existing_run = WorkflowRun()
|
||||
existing_run.id = sample_workflow_execution.id_
|
||||
existing_run.tenant_id = "other-tenant"
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
session.get.return_value = existing_run
|
||||
|
||||
with pytest.raises(ValueError, match="Unauthorized access to workflow run"):
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
session.merge.assert_not_called()
|
||||
session.commit.assert_not_called()
|
||||
|
||||
@patch("core.repositories.sqlalchemy_workflow_execution_repository.save_workflow_execution_task")
|
||||
def test_save_queues_celery_task_when_async_persistence_enabled(
|
||||
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
repo.set_async_persistence(True)
|
||||
|
||||
repo.save(sample_workflow_execution)
|
||||
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args.kwargs
|
||||
assert call_args["execution_data"] == sample_workflow_execution.model_dump()
|
||||
assert call_args["tenant_id"] == mock_account.current_tenant_id
|
||||
assert call_args["app_id"] == "test_app"
|
||||
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN
|
||||
assert call_args["creator_user_id"] == mock_account.id
|
||||
session = mock_session_factory.return_value.__enter__.return_value
|
||||
session.merge.assert_not_called()
|
||||
|
||||
@patch("core.repositories.sqlalchemy_workflow_execution_repository.save_workflow_execution_task")
|
||||
def test_queue_async_save_requires_context(
|
||||
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_account,
|
||||
app_id="test_app",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
repo._triggered_from = None
|
||||
with pytest.raises(ValueError, match="triggered_from is required"):
|
||||
repo._queue_async_save(sample_workflow_execution)
|
||||
|
||||
repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
repo._creator_user_id = None
|
||||
with pytest.raises(ValueError, match="created_by is required"):
|
||||
repo._queue_async_save(sample_workflow_execution)
|
||||
|
||||
repo._creator_user_id = "user-id"
|
||||
repo._creator_user_role = None
|
||||
with pytest.raises(ValueError, match="created_by_role is required"):
|
||||
repo._queue_async_save(sample_workflow_execution)
|
||||
|
||||
mock_task.delay.assert_not_called()
|
||||
|
||||
def test_save_uses_execution_started_at_when_record_does_not_exist(
|
||||
self, mock_session_factory, mock_account, sample_workflow_execution
|
||||
):
|
||||
|
||||
@ -6,7 +6,7 @@ from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
@ -599,6 +599,124 @@ def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.Monkey
|
||||
assert merged.inputs == '{"a": 1}'
|
||||
|
||||
|
||||
@patch("core.repositories.sqlalchemy_workflow_node_execution_repository.save_workflow_node_execution_task")
|
||||
def test_save_queues_celery_task_when_async_persistence_enabled(mock_task, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
repo.set_async_persistence(True)
|
||||
execution = _execution(inputs={"a": 1})
|
||||
|
||||
repo.save(execution)
|
||||
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args.kwargs
|
||||
assert call_args["execution_data"] == execution.model_dump()
|
||||
assert call_args["tenant_id"] == "tenant"
|
||||
assert call_args["app_id"] == "app"
|
||||
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
assert call_args["creator_user_id"] == "user"
|
||||
assert set(call_args) == {
|
||||
"execution_data",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"triggered_from",
|
||||
"creator_user_id",
|
||||
"creator_user_role",
|
||||
}
|
||||
|
||||
|
||||
@patch("core.repositories.sqlalchemy_workflow_node_execution_repository.save_workflow_node_execution_data_task")
|
||||
def test_save_execution_data_queues_celery_task_when_async_persistence_enabled(
|
||||
mock_task, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
repo.set_async_persistence(True)
|
||||
execution = _execution(inputs={"a": 1})
|
||||
|
||||
repo.save_execution_data(execution)
|
||||
|
||||
mock_task.delay.assert_called_once()
|
||||
call_args = mock_task.delay.call_args.kwargs
|
||||
assert call_args["execution_data"] == execution.model_dump()
|
||||
assert call_args["tenant_id"] == "tenant"
|
||||
assert call_args["app_id"] == "app"
|
||||
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
assert call_args["creator_user_id"] == "user"
|
||||
|
||||
|
||||
def test_queue_async_save_requires_context(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
execution = _execution()
|
||||
|
||||
repo._triggered_from = None
|
||||
with pytest.raises(ValueError, match="triggered_from is required"):
|
||||
repo._queue_async_save(execution)
|
||||
|
||||
repo._triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
repo._creator_user_id = None
|
||||
with pytest.raises(ValueError, match="created_by is required"):
|
||||
repo._queue_async_save(execution)
|
||||
|
||||
repo._creator_user_id = "user"
|
||||
repo._creator_user_role = None
|
||||
with pytest.raises(ValueError, match="created_by_role is required"):
|
||||
repo._queue_async_save(execution)
|
||||
|
||||
|
||||
def test_queue_async_save_execution_data_requires_context(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.FileService",
|
||||
lambda *_: SimpleNamespace(upload_file=Mock()),
|
||||
)
|
||||
repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=Mock(spec=sessionmaker),
|
||||
user=_mock_account(),
|
||||
app_id="app",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
execution = _execution()
|
||||
|
||||
repo._triggered_from = None
|
||||
with pytest.raises(ValueError, match="triggered_from is required"):
|
||||
repo._queue_async_save_execution_data(execution)
|
||||
|
||||
repo._triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
repo._creator_user_id = None
|
||||
with pytest.raises(ValueError, match="created_by is required"):
|
||||
repo._queue_async_save_execution_data(execution)
|
||||
|
||||
repo._creator_user_id = "user"
|
||||
repo._creator_user_role = None
|
||||
with pytest.raises(ValueError, match="created_by_role is required"):
|
||||
repo._queue_async_save_execution_data(execution)
|
||||
|
||||
|
||||
def test_save_retries_duplicate_and_logs_non_duplicate(
|
||||
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
|
||||
78
api/tests/unit_tests/tasks/test_workflow_execution_tasks.py
Normal file
78
api/tests/unit_tests/tasks/test_workflow_execution_tasks.py
Normal file
@ -0,0 +1,78 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from graphon.entities import WorkflowExecution
|
||||
from graphon.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from models import CreatorUserRole, WorkflowRun
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from tasks.workflow_execution_tasks import (
|
||||
_calculate_elapsed_time,
|
||||
_create_workflow_run_from_execution,
|
||||
_update_workflow_run_from_execution,
|
||||
)
|
||||
|
||||
|
||||
def _execution(
|
||||
*,
|
||||
elapsed_time: float = 3.5,
|
||||
exceptions_count: int = 2,
|
||||
finished_at: datetime | None = None,
|
||||
) -> WorkflowExecution:
|
||||
started_at = datetime(2026, 1, 1, 12, 0, 0)
|
||||
return WorkflowExecution(
|
||||
id_="workflow-run-id",
|
||||
workflow_id="workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"input": "value"},
|
||||
outputs={"output": "value"},
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
error_message="",
|
||||
elapsed_time=elapsed_time,
|
||||
total_tokens=100,
|
||||
total_steps=5,
|
||||
exceptions_count=exceptions_count,
|
||||
started_at=started_at,
|
||||
finished_at=finished_at,
|
||||
)
|
||||
|
||||
|
||||
def test_create_workflow_run_calculates_elapsed_time_and_exceptions_count() -> None:
|
||||
execution = _execution(finished_at=datetime(2026, 1, 1, 12, 0, 12), exceptions_count=3)
|
||||
|
||||
workflow_run = _create_workflow_run_from_execution(
|
||||
execution=execution,
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
creator_user_id="user-id",
|
||||
creator_user_role=CreatorUserRole.ACCOUNT,
|
||||
)
|
||||
|
||||
assert workflow_run.elapsed_time == 12.0
|
||||
assert workflow_run.exceptions_count == 3
|
||||
|
||||
|
||||
def test_update_workflow_run_calculates_elapsed_time_and_exceptions_count() -> None:
|
||||
workflow_run = WorkflowRun()
|
||||
execution = _execution(finished_at=datetime(2026, 1, 1, 12, 0, 8), exceptions_count=4)
|
||||
|
||||
_update_workflow_run_from_execution(workflow_run, execution)
|
||||
|
||||
assert workflow_run.elapsed_time == 8.0
|
||||
assert workflow_run.exceptions_count == 4
|
||||
|
||||
|
||||
def test_calculate_elapsed_time_uses_runtime_elapsed_time_until_finished() -> None:
|
||||
execution = _execution(finished_at=None)
|
||||
execution.started_at = datetime.now(UTC).replace(tzinfo=None) - timedelta(seconds=4)
|
||||
|
||||
elapsed_time = _calculate_elapsed_time(execution)
|
||||
|
||||
assert 3.9 <= elapsed_time <= 5.0
|
||||
|
||||
|
||||
def test_calculate_elapsed_time_clamps_negative_duration_to_zero() -> None:
|
||||
execution = _execution(finished_at=datetime(2026, 1, 1, 11, 59, 59))
|
||||
|
||||
assert _calculate_elapsed_time(execution) == 0.0
|
||||
@ -1,488 +1,327 @@
|
||||
# """
|
||||
# Unit tests for workflow node execution Celery tasks.
|
||||
|
||||
# These tests verify the asynchronous storage functionality for workflow node execution data,
|
||||
# including truncation and offloading logic.
|
||||
# """
|
||||
|
||||
# import json
|
||||
# from unittest.mock import MagicMock, Mock, patch
|
||||
# from uuid import uuid4
|
||||
|
||||
# import pytest
|
||||
|
||||
# from graphon.entities.workflow_node_execution import (
|
||||
# WorkflowNodeExecution,
|
||||
# WorkflowNodeExecutionStatus,
|
||||
# )
|
||||
# from graphon.enums import BuiltinNodeTypes
|
||||
# from libs.datetime_utils import naive_utc_now
|
||||
# from models import WorkflowNodeExecutionModel
|
||||
# from models.enums import ExecutionOffLoadType
|
||||
# from models.model import UploadFile
|
||||
# from models.workflow import WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
|
||||
# from tasks.workflow_node_execution_tasks import (
|
||||
# _create_truncator,
|
||||
# _json_encode,
|
||||
# _replace_or_append_offload,
|
||||
# _truncate_and_upload_async,
|
||||
# save_workflow_node_execution_data_task,
|
||||
# save_workflow_node_execution_task,
|
||||
# )
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def sample_execution_data():
|
||||
# """Sample execution data for testing."""
|
||||
# execution = WorkflowNodeExecution(
|
||||
# id=str(uuid4()),
|
||||
# node_execution_id=str(uuid4()),
|
||||
# workflow_id=str(uuid4()),
|
||||
# workflow_execution_id=str(uuid4()),
|
||||
# index=1,
|
||||
# node_id="test_node",
|
||||
# node_type=BuiltinNodeTypes.LLM,
|
||||
# title="Test Node",
|
||||
# inputs={"input_key": "input_value"},
|
||||
# outputs={"output_key": "output_value"},
|
||||
# process_data={"process_key": "process_value"},
|
||||
# status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
# created_at=naive_utc_now(),
|
||||
# )
|
||||
# return execution.model_dump()
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def mock_db_model():
|
||||
# """Mock database model for testing."""
|
||||
# db_model = Mock(spec=WorkflowNodeExecutionModel)
|
||||
# db_model.id = "test-execution-id"
|
||||
# db_model.offload_data = []
|
||||
# return db_model
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def mock_file_service():
|
||||
# """Mock file service for testing."""
|
||||
# file_service = Mock()
|
||||
# mock_upload_file = Mock(spec=UploadFile)
|
||||
# mock_upload_file.id = "mock-file-id"
|
||||
# file_service.upload_file.return_value = mock_upload_file
|
||||
# return file_service
|
||||
|
||||
|
||||
# class TestSaveWorkflowNodeExecutionDataTask:
|
||||
# """Test cases for save_workflow_node_execution_data_task."""
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_execution_data_task_success(
|
||||
# self, mock_select, mock_sessionmaker, sample_execution_data, mock_db_model
|
||||
# ):
|
||||
# """Test successful execution of save_workflow_node_execution_data_task."""
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model
|
||||
|
||||
# # Execute task
|
||||
# result = save_workflow_node_execution_data_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# user_data={"user_id": "test-user-id", "user_type": "account"},
|
||||
# )
|
||||
|
||||
# # Verify success
|
||||
# assert result is True
|
||||
# mock_session.merge.assert_called_once_with(mock_db_model)
|
||||
# mock_session.commit.assert_called_once()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_execution_data_task_execution_not_found(self, mock_select, mock_sessionmaker,
|
||||
# sample_execution_data):
|
||||
# """Test task when execution is not found in database."""
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
# mock_session.execute.return_value.scalars.return_value.first.return_value = None
|
||||
|
||||
# # Execute task
|
||||
# result = save_workflow_node_execution_data_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# user_data={"user_id": "test-user-id", "user_type": "account"},
|
||||
# )
|
||||
|
||||
# # Verify failure
|
||||
# assert result is False
|
||||
# mock_session.merge.assert_not_called()
|
||||
# mock_session.commit.assert_not_called()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_execution_data_task_with_truncation(self, mock_select, mock_sessionmaker, mock_db_model):
|
||||
# """Test task with data that requires truncation."""
|
||||
# # Create execution with large data
|
||||
# large_data = {"large_field": "x" * 10000}
|
||||
# execution = WorkflowNodeExecution(
|
||||
# id=str(uuid4()),
|
||||
# node_execution_id=str(uuid4()),
|
||||
# workflow_id=str(uuid4()),
|
||||
# workflow_execution_id=str(uuid4()),
|
||||
# index=1,
|
||||
# node_id="test_node",
|
||||
# node_type=BuiltinNodeTypes.LLM,
|
||||
# title="Test Node",
|
||||
# inputs=large_data,
|
||||
# outputs=large_data,
|
||||
# process_data=large_data,
|
||||
# status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
# created_at=naive_utc_now(),
|
||||
# )
|
||||
# execution_data = execution.model_dump()
|
||||
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model
|
||||
|
||||
# # Create mock upload file
|
||||
# mock_upload_file = Mock(spec=UploadFile)
|
||||
# mock_upload_file.id = "mock-file-id"
|
||||
|
||||
# # Execute task
|
||||
# with patch("tasks.workflow_node_execution_tasks._truncate_and_upload_async") as mock_truncate:
|
||||
# # Mock truncation results
|
||||
# mock_truncate.return_value = {
|
||||
# "truncated_value": {"large_field": "[TRUNCATED]"},
|
||||
# "file": mock_upload_file,
|
||||
# "offload": WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# node_execution_id=execution.id,
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# file_id=mock_upload_file.id,
|
||||
# ),
|
||||
# }
|
||||
|
||||
# result = save_workflow_node_execution_data_task(
|
||||
# execution_data=execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# user_data={"user_id": "test-user-id", "user_type": "account"},
|
||||
# )
|
||||
|
||||
# # Verify success and truncation was called
|
||||
# assert result is True
|
||||
# assert mock_truncate.call_count == 3 # inputs, outputs, process_data
|
||||
# mock_session.merge.assert_called_once_with(mock_db_model)
|
||||
# mock_session.commit.assert_called_once()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# def test_save_execution_data_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data):
|
||||
# """Test task retry mechanism on exception."""
|
||||
# # Setup mock to raise exception
|
||||
# mock_sessionmaker.side_effect = Exception("Database error")
|
||||
|
||||
# # Create a mock task instance with proper retry behavior
|
||||
# with patch.object(save_workflow_node_execution_data_task, "retry") as mock_retry:
|
||||
# mock_retry.side_effect = Exception("Retry called")
|
||||
|
||||
# # Execute task and expect retry
|
||||
# with pytest.raises(Exception, match="Retry called"):
|
||||
# save_workflow_node_execution_data_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# user_data={"user_id": "test-user-id", "user_type": "account"},
|
||||
# )
|
||||
|
||||
# # Verify retry was called
|
||||
# mock_retry.assert_called_once()
|
||||
|
||||
|
||||
# class TestTruncateAndUploadAsync:
|
||||
# """Test cases for _truncate_and_upload_async function."""
|
||||
|
||||
# def test_truncate_and_upload_with_none_values(self, mock_file_service):
|
||||
# """Test _truncate_and_upload_async with None values."""
|
||||
# # The function handles None values internally, so we test with empty dict instead
|
||||
# result = _truncate_and_upload_async(
|
||||
# values={},
|
||||
# execution_id="test-id",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# user_data={"user_id": "test-user", "user_type": "account"},
|
||||
# file_service=mock_file_service,
|
||||
# )
|
||||
|
||||
# # Empty dict should not require truncation
|
||||
# assert result is None
|
||||
# mock_file_service.upload_file.assert_not_called()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
|
||||
# def test_truncate_and_upload_no_truncation_needed(self, mock_create_truncator, mock_file_service):
|
||||
# """Test _truncate_and_upload_async when no truncation is needed."""
|
||||
# # Mock truncator to return no truncation
|
||||
# mock_truncator = Mock()
|
||||
# mock_truncator.truncate_variable_mapping.return_value = ({"small": "data"}, False)
|
||||
# mock_create_truncator.return_value = mock_truncator
|
||||
|
||||
# small_values = {"small": "data"}
|
||||
# result = _truncate_and_upload_async(
|
||||
# values=small_values,
|
||||
# execution_id="test-id",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# user_data={"user_id": "test-user", "user_type": "account"},
|
||||
# file_service=mock_file_service,
|
||||
# )
|
||||
|
||||
# assert result is None
|
||||
# mock_file_service.upload_file.assert_not_called()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
|
||||
# @patch("models.Account")
|
||||
# @patch("models.Tenant")
|
||||
# def test_truncate_and_upload_with_account_user(
|
||||
# self, mock_tenant_class, mock_account_class, mock_create_truncator, mock_file_service
|
||||
# ):
|
||||
# """Test _truncate_and_upload_async with account user."""
|
||||
# # Mock truncator to return truncation needed
|
||||
# mock_truncator = Mock()
|
||||
# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True)
|
||||
# mock_create_truncator.return_value = mock_truncator
|
||||
|
||||
# # Mock user and tenant creation
|
||||
# mock_account = Mock()
|
||||
# mock_account.id = "test-user"
|
||||
# mock_account_class.return_value = mock_account
|
||||
|
||||
# mock_tenant = Mock()
|
||||
# mock_tenant.id = "test-tenant"
|
||||
# mock_tenant_class.return_value = mock_tenant
|
||||
|
||||
# large_values = {"large": "x" * 10000}
|
||||
# result = _truncate_and_upload_async(
|
||||
# values=large_values,
|
||||
# execution_id="test-id",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# user_data={"user_id": "test-user", "user_type": "account"},
|
||||
# file_service=mock_file_service,
|
||||
# )
|
||||
|
||||
# # Verify result structure
|
||||
# assert result is not None
|
||||
# assert "truncated_value" in result
|
||||
# assert "file" in result
|
||||
# assert "offload" in result
|
||||
# assert result["truncated_value"] == {"truncated": "data"}
|
||||
|
||||
# # Verify file upload was called
|
||||
# mock_file_service.upload_file.assert_called_once()
|
||||
# upload_call = mock_file_service.upload_file.call_args
|
||||
# assert upload_call[1]["filename"] == "node_execution_test-id_inputs.json"
|
||||
# assert upload_call[1]["mimetype"] == "application/json"
|
||||
# assert upload_call[1]["user"] == mock_account
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
|
||||
# @patch("models.EndUser")
|
||||
# def test_truncate_and_upload_with_end_user(self, mock_end_user_class, mock_create_truncator, mock_file_service):
|
||||
# """Test _truncate_and_upload_async with end user."""
|
||||
# # Mock truncator to return truncation needed
|
||||
# mock_truncator = Mock()
|
||||
# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True)
|
||||
# mock_create_truncator.return_value = mock_truncator
|
||||
|
||||
# # Mock end user creation
|
||||
# mock_end_user = Mock()
|
||||
# mock_end_user.id = "test-user"
|
||||
# mock_end_user.tenant_id = "test-tenant"
|
||||
# mock_end_user_class.return_value = mock_end_user
|
||||
|
||||
# large_values = {"large": "x" * 10000}
|
||||
# result = _truncate_and_upload_async(
|
||||
# values=large_values,
|
||||
# execution_id="test-id",
|
||||
# type_=ExecutionOffLoadType.OUTPUTS,
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# user_data={"user_id": "test-user", "user_type": "end_user"},
|
||||
# file_service=mock_file_service,
|
||||
# )
|
||||
|
||||
# # Verify result structure
|
||||
# assert result is not None
|
||||
# assert result["truncated_value"] == {"truncated": "data"}
|
||||
|
||||
# # Verify file upload was called with end user
|
||||
# mock_file_service.upload_file.assert_called_once()
|
||||
# upload_call = mock_file_service.upload_file.call_args
|
||||
# assert upload_call[1]["filename"] == "node_execution_test-id_outputs.json"
|
||||
# assert upload_call[1]["user"] == mock_end_user
|
||||
|
||||
|
||||
# class TestHelperFunctions:
|
||||
# """Test cases for helper functions."""
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.dify_config")
|
||||
# def test_create_truncator(self, mock_config):
|
||||
# """Test _create_truncator function."""
|
||||
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
|
||||
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
|
||||
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500
|
||||
|
||||
# truncator = _create_truncator()
|
||||
|
||||
# # Verify truncator was created with correct config
|
||||
# assert truncator is not None
|
||||
|
||||
# def test_json_encode(self):
|
||||
# """Test _json_encode function."""
|
||||
# test_data = {"key": "value", "number": 42}
|
||||
# result = _json_encode(test_data)
|
||||
|
||||
# assert isinstance(result, str)
|
||||
# decoded = json.loads(result)
|
||||
# assert decoded == test_data
|
||||
|
||||
# def test_replace_or_append_offload_replace_existing(self):
|
||||
# """Test _replace_or_append_offload replaces existing offload of same type."""
|
||||
# existing_offload = WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# node_execution_id="test-execution",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# file_id="old-file-id",
|
||||
# )
|
||||
|
||||
# new_offload = WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# node_execution_id="test-execution",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# file_id="new-file-id",
|
||||
# )
|
||||
|
||||
# result = _replace_or_append_offload([existing_offload], new_offload)
|
||||
|
||||
# assert len(result) == 1
|
||||
# assert result[0].file_id == "new-file-id"
|
||||
|
||||
# def test_replace_or_append_offload_append_new_type(self):
|
||||
# """Test _replace_or_append_offload appends new offload of different type."""
|
||||
# existing_offload = WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# node_execution_id="test-execution",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# file_id="inputs-file-id",
|
||||
# )
|
||||
|
||||
# new_offload = WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# node_execution_id="test-execution",
|
||||
# type_=ExecutionOffLoadType.OUTPUTS,
|
||||
# file_id="outputs-file-id",
|
||||
# )
|
||||
|
||||
# result = _replace_or_append_offload([existing_offload], new_offload)
|
||||
|
||||
# assert len(result) == 2
|
||||
# file_ids = [offload.file_id for offload in result]
|
||||
# assert "inputs-file-id" in file_ids
|
||||
# assert "outputs-file-id" in file_ids
|
||||
|
||||
|
||||
# class TestSaveWorkflowNodeExecutionTask:
|
||||
# """Test cases for save_workflow_node_execution_task."""
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_workflow_node_execution_task_create_new(self, mock_select, mock_sessionmaker,
|
||||
# sample_execution_data):
|
||||
# """Test creating a new workflow node execution."""
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
# mock_session.scalar.return_value = None # No existing execution
|
||||
|
||||
# # Execute task
|
||||
# result = save_workflow_node_execution_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
# creator_user_id="test-user-id",
|
||||
# creator_user_role="account",
|
||||
# )
|
||||
|
||||
# # Verify success
|
||||
# assert result is True
|
||||
# mock_session.add.assert_called_once()
|
||||
# mock_session.commit.assert_called_once()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_workflow_node_execution_task_update_existing(
|
||||
# self, mock_select, mock_sessionmaker, sample_execution_data
|
||||
# ):
|
||||
# """Test updating an existing workflow node execution."""
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# existing_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
# mock_session.scalar.return_value = existing_execution
|
||||
|
||||
# # Execute task
|
||||
# result = save_workflow_node_execution_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
# creator_user_id="test-user-id",
|
||||
# creator_user_role="account",
|
||||
# )
|
||||
|
||||
# # Verify success
|
||||
# assert result is True
|
||||
# mock_session.add.assert_not_called() # Should not add new, just update existing
|
||||
# mock_session.commit.assert_called_once()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# def test_save_workflow_node_execution_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data):
|
||||
# """Test task retry mechanism on exception."""
|
||||
# # Setup mock to raise exception
|
||||
# mock_sessionmaker.side_effect = Exception("Database error")
|
||||
|
||||
# # Create a mock task instance with proper retry behavior
|
||||
# with patch.object(save_workflow_node_execution_task, "retry") as mock_retry:
|
||||
# mock_retry.side_effect = Exception("Retry called")
|
||||
|
||||
# # Execute task and expect retry
|
||||
# with pytest.raises(Exception, match="Retry called"):
|
||||
# save_workflow_node_execution_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
# creator_user_id="test-user-id",
|
||||
# creator_user_role="account",
|
||||
# )
|
||||
|
||||
# # Verify retry was called
|
||||
# mock_retry.assert_called_once()
|
||||
from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from graphon.entities.workflow_node_execution import WorkflowNodeExecution
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from models import Account, EndUser
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
|
||||
from tasks.workflow_node_execution_tasks import (
|
||||
_create_node_execution_from_domain,
|
||||
_create_sqlalchemy_repository,
|
||||
_update_node_execution_metadata,
|
||||
save_workflow_node_execution_data_task,
|
||||
save_workflow_node_execution_task,
|
||||
)
|
||||
|
||||
|
||||
def _execution(
|
||||
*,
|
||||
metadata: Mapping[WorkflowNodeExecutionMetadataKey, object] | None = None,
|
||||
) -> WorkflowNodeExecution:
|
||||
if metadata is None:
|
||||
metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 10}
|
||||
|
||||
return WorkflowNodeExecution(
|
||||
id="exec-id",
|
||||
node_execution_id="node-exec-id",
|
||||
workflow_id="workflow-id",
|
||||
workflow_execution_id="run-id",
|
||||
index=1,
|
||||
node_id="node-id",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
title="LLM",
|
||||
inputs={"input": "value"},
|
||||
process_data={"process": "value"},
|
||||
outputs={"output": "value"},
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata=metadata,
|
||||
created_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
finished_at=datetime.now(UTC).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
|
||||
def test_create_node_execution_persists_metadata_without_data_payloads() -> None:
|
||||
db_model = _create_node_execution_from_domain(
|
||||
execution=_execution(),
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
creator_user_id="user-id",
|
||||
creator_user_role=CreatorUserRole.ACCOUNT,
|
||||
)
|
||||
|
||||
assert db_model.inputs == "{}"
|
||||
assert db_model.process_data == "{}"
|
||||
assert db_model.outputs == "{}"
|
||||
assert db_model.execution_metadata == '{"total_tokens": 10}'
|
||||
|
||||
|
||||
def test_create_node_execution_defaults_empty_metadata() -> None:
|
||||
db_model = _create_node_execution_from_domain(
|
||||
execution=_execution(metadata={}),
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
creator_user_id="user-id",
|
||||
creator_user_role=CreatorUserRole.ACCOUNT,
|
||||
)
|
||||
|
||||
assert db_model.execution_metadata == "{}"
|
||||
|
||||
|
||||
def test_update_node_execution_metadata_preserves_data_payloads() -> None:
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
db_model.inputs = '{"old_input": true}'
|
||||
db_model.process_data = '{"old_process": true}'
|
||||
db_model.outputs = '{"old_output": true}'
|
||||
|
||||
_update_node_execution_metadata(db_model, _execution())
|
||||
|
||||
assert db_model.inputs == '{"old_input": true}'
|
||||
assert db_model.process_data == '{"old_process": true}'
|
||||
assert db_model.outputs == '{"old_output": true}'
|
||||
assert db_model.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
|
||||
def test_update_node_execution_metadata_defaults_empty_metadata() -> None:
|
||||
db_model = WorkflowNodeExecutionModel()
|
||||
|
||||
_update_node_execution_metadata(db_model, _execution(metadata={}))
|
||||
|
||||
assert db_model.execution_metadata == "{}"
|
||||
|
||||
|
||||
@patch("tasks.workflow_node_execution_tasks._create_sqlalchemy_repository")
|
||||
def test_save_workflow_node_execution_data_task_uses_sqlalchemy_repository(mock_create_repository: Mock) -> None:
|
||||
repository = Mock()
|
||||
mock_create_repository.return_value = repository
|
||||
execution = _execution()
|
||||
|
||||
result = save_workflow_node_execution_data_task.run(
|
||||
execution_data=execution.model_dump(),
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
creator_user_id="user-id",
|
||||
creator_user_role=CreatorUserRole.ACCOUNT.value,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_create_repository.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
creator_user_id="user-id",
|
||||
creator_user_role=CreatorUserRole.ACCOUNT.value,
|
||||
)
|
||||
saved_execution = repository.save.call_args.args[0]
|
||||
saved_data_execution = repository.save_execution_data.call_args.args[0]
|
||||
assert saved_execution.model_dump() == execution.model_dump()
|
||||
assert saved_data_execution.model_dump() == execution.model_dump()
|
||||
|
||||
|
||||
@patch("tasks.workflow_node_execution_tasks._create_sqlalchemy_repository")
|
||||
def test_save_workflow_node_execution_data_task_retries_on_failure(mock_create_repository: Mock) -> None:
|
||||
mock_create_repository.side_effect = RuntimeError("db unavailable")
|
||||
execution = _execution()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
save_workflow_node_execution_data_task,
|
||||
"retry",
|
||||
side_effect=RuntimeError("retry requested"),
|
||||
) as retry,
|
||||
pytest.raises(RuntimeError, match="retry requested"),
|
||||
):
|
||||
save_workflow_node_execution_data_task.run(
|
||||
execution_data=execution.model_dump(),
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
creator_user_id="user-id",
|
||||
creator_user_role=CreatorUserRole.ACCOUNT.value,
|
||||
)
|
||||
|
||||
retry.assert_called_once()
|
||||
assert isinstance(retry.call_args.kwargs["exc"], RuntimeError)
|
||||
assert retry.call_args.kwargs["countdown"] == 60
|
||||
|
||||
|
||||
@patch("tasks.workflow_node_execution_tasks.session_factory.create_session")
|
||||
def test_save_workflow_node_execution_task_creates_metadata_record(mock_create_session: Mock) -> None:
|
||||
session = _TaskSession(existing_execution=None)
|
||||
mock_create_session.return_value = session
|
||||
|
||||
result = save_workflow_node_execution_task.run(
|
||||
execution_data=_execution().model_dump(),
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
creator_user_id="user-id",
|
||||
creator_user_role=CreatorUserRole.ACCOUNT.value,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert session.committed is True
|
||||
assert isinstance(session.added_execution, WorkflowNodeExecutionModel)
|
||||
assert session.added_execution.inputs == "{}"
|
||||
assert session.added_execution.process_data == "{}"
|
||||
assert session.added_execution.outputs == "{}"
|
||||
|
||||
|
||||
@patch("tasks.workflow_node_execution_tasks.session_factory.create_session")
|
||||
def test_save_workflow_node_execution_task_updates_metadata_without_payloads(mock_create_session: Mock) -> None:
|
||||
existing_execution = WorkflowNodeExecutionModel()
|
||||
existing_execution.inputs = '{"old_input": true}'
|
||||
existing_execution.process_data = '{"old_process": true}'
|
||||
existing_execution.outputs = '{"old_output": true}'
|
||||
session = _TaskSession(existing_execution=existing_execution)
|
||||
mock_create_session.return_value = session
|
||||
|
||||
result = save_workflow_node_execution_task.run(
|
||||
execution_data=_execution().model_dump(),
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
creator_user_id="user-id",
|
||||
creator_user_role=CreatorUserRole.ACCOUNT.value,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert session.committed is True
|
||||
assert session.added_execution is None
|
||||
assert existing_execution.inputs == '{"old_input": true}'
|
||||
assert existing_execution.process_data == '{"old_process": true}'
|
||||
assert existing_execution.outputs == '{"old_output": true}'
|
||||
assert existing_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
|
||||
def test_create_sqlalchemy_repository_builds_account_context(monkeypatch) -> None:
|
||||
account = Mock()
|
||||
session = _Session({Account: account})
|
||||
|
||||
def session_maker():
|
||||
return session
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.workflow_node_execution_tasks.session_factory.get_session_maker",
|
||||
lambda: session_maker,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository"
|
||||
) as repository_class:
|
||||
repository = _create_sqlalchemy_repository(
|
||||
tenant_id="tenant-id",
|
||||
app_id="",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
creator_user_id="user-id",
|
||||
creator_user_role=CreatorUserRole.ACCOUNT.value,
|
||||
)
|
||||
|
||||
assert repository == repository_class.return_value
|
||||
account.set_tenant_id.assert_called_once_with("tenant-id")
|
||||
repository_class.assert_called_once_with(
|
||||
session_factory=session_maker,
|
||||
user=account,
|
||||
app_id=None,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
def test_create_sqlalchemy_repository_builds_end_user_context(monkeypatch) -> None:
|
||||
end_user = Mock()
|
||||
session = _Session({EndUser: end_user})
|
||||
|
||||
def session_maker():
|
||||
return session
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.workflow_node_execution_tasks.session_factory.get_session_maker",
|
||||
lambda: session_maker,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository"
|
||||
) as repository_class:
|
||||
repository = _create_sqlalchemy_repository(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
creator_user_id="user-id",
|
||||
creator_user_role=CreatorUserRole.END_USER.value,
|
||||
)
|
||||
|
||||
assert repository == repository_class.return_value
|
||||
repository_class.assert_called_once_with(
|
||||
session_factory=session_maker,
|
||||
user=end_user,
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
|
||||
def test_create_sqlalchemy_repository_raises_for_missing_creator(monkeypatch) -> None:
|
||||
session = _Session({})
|
||||
|
||||
def session_maker():
|
||||
return session
|
||||
|
||||
monkeypatch.setattr(
|
||||
"tasks.workflow_node_execution_tasks.session_factory.get_session_maker",
|
||||
lambda: session_maker,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository"
|
||||
),
|
||||
pytest.raises(ValueError, match="Creator user missing-user not found"),
|
||||
):
|
||||
_create_sqlalchemy_repository(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
creator_user_id="missing-user",
|
||||
creator_user_role=CreatorUserRole.ACCOUNT.value,
|
||||
)
|
||||
|
||||
|
||||
class _Session:
|
||||
def __init__(self, users: dict[type, object]) -> None:
|
||||
self._users = users
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, traceback) -> None:
|
||||
return None
|
||||
|
||||
def get(self, model, _id: str):
|
||||
return self._users.get(model)
|
||||
|
||||
|
||||
class _TaskSession:
|
||||
def __init__(self, existing_execution: WorkflowNodeExecutionModel | None) -> None:
|
||||
self._existing_execution = existing_execution
|
||||
self.added_execution: WorkflowNodeExecutionModel | None = None
|
||||
self.committed = False
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, traceback) -> None:
|
||||
return None
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return self._existing_execution
|
||||
|
||||
def add(self, execution: WorkflowNodeExecutionModel) -> None:
|
||||
self.added_execution = execution
|
||||
|
||||
def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
@ -330,7 +330,6 @@ BAIDU_VECTOR_DB_AUTO_BUILD_ROW_COUNT_INCREMENT_RATIO=0.05
|
||||
BAIDU_VECTOR_DB_REBUILD_INDEX_TIMEOUT_IN_SECONDS=300
|
||||
HUAWEI_CLOUD_HOSTS=https://127.0.0.1:9200
|
||||
HUAWEI_CLOUD_USER=admin
|
||||
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
|
||||
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
|
||||
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
|
||||
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
|
||||
|
||||
Reference in New Issue
Block a user