mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
Merge branch 'main' into deploy/dev
# Conflicts: # api/core/app/apps/advanced_chat/app_runner.py # api/core/app/apps/pipeline/pipeline_generator.py # api/core/entities/mcp_provider.py # api/core/helper/marketplace.py # api/models/workflow.py # api/services/tools/tools_transform_service.py # api/tasks/document_indexing_task.py # api/tests/test_containers_integration_tests/core/__init__.py # api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py # dev/start-worker # docker/.env.example # web/app/components/base/chat/embedded-chatbot/hooks.tsx # web/app/components/workflow/hooks/use-workflow.ts # web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx # web/global.d.ts # web/pnpm-lock.yaml # web/service/use-plugins.ts
This commit is contained in:
258
api/tests/fixtures/workflow/iteration_flatten_output_disabled_workflow.yml
vendored
Normal file
258
api/tests/fixtures/workflow/iteration_flatten_output_disabled_workflow.yml
vendored
Normal file
@ -0,0 +1,258 @@
|
||||
app:
|
||||
description: 'This workflow tests the iteration node with flatten_output=False.
|
||||
|
||||
|
||||
It processes [1, 2, 3], outputs [item, item*2] for each iteration.
|
||||
|
||||
|
||||
With flatten_output=False, it should output nested arrays:
|
||||
|
||||
|
||||
```
|
||||
|
||||
{"output": [[1, 2], [2, 4], [3, 6]]}
|
||||
|
||||
```'
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: workflow
|
||||
name: test_iteration_flatten_disabled
|
||||
use_icon_as_answer_icon: false
|
||||
dependencies: []
|
||||
kind: app
|
||||
version: 0.3.1
|
||||
workflow:
|
||||
conversation_variables: []
|
||||
environment_variables: []
|
||||
features:
|
||||
file_upload:
|
||||
enabled: false
|
||||
opening_statement: ''
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
enabled: false
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
graph:
|
||||
edges:
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: start
|
||||
targetType: code
|
||||
id: start-source-code-target
|
||||
source: start_node
|
||||
sourceHandle: source
|
||||
target: code_node
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: code
|
||||
targetType: iteration
|
||||
id: code-source-iteration-target
|
||||
source: code_node
|
||||
sourceHandle: source
|
||||
target: iteration_node
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
- data:
|
||||
isInIteration: true
|
||||
isInLoop: false
|
||||
iteration_id: iteration_node
|
||||
sourceType: iteration-start
|
||||
targetType: code
|
||||
id: iteration-start-source-code-inner-target
|
||||
source: iteration_nodestart
|
||||
sourceHandle: source
|
||||
target: code_inner_node
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 1002
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: iteration
|
||||
targetType: end
|
||||
id: iteration-source-end-target
|
||||
source: iteration_node
|
||||
sourceHandle: source
|
||||
target: end_node
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
nodes:
|
||||
- data:
|
||||
desc: ''
|
||||
selected: false
|
||||
title: Start
|
||||
type: start
|
||||
variables: []
|
||||
height: 54
|
||||
id: start_node
|
||||
position:
|
||||
x: 80
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 80
|
||||
y: 282
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 244
|
||||
- data:
|
||||
code: "\ndef main() -> dict:\n return {\n \"result\": [1, 2, 3],\n\
|
||||
\ }\n"
|
||||
code_language: python3
|
||||
desc: ''
|
||||
outputs:
|
||||
result:
|
||||
children: null
|
||||
type: array[number]
|
||||
selected: false
|
||||
title: Generate Array
|
||||
type: code
|
||||
variables: []
|
||||
height: 54
|
||||
id: code_node
|
||||
position:
|
||||
x: 384
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 384
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 244
|
||||
- data:
|
||||
desc: ''
|
||||
error_handle_mode: terminated
|
||||
flatten_output: false
|
||||
height: 178
|
||||
is_parallel: false
|
||||
iterator_input_type: array[number]
|
||||
iterator_selector:
|
||||
- code_node
|
||||
- result
|
||||
output_selector:
|
||||
- code_inner_node
|
||||
- result
|
||||
output_type: array[array[number]]
|
||||
parallel_nums: 10
|
||||
selected: false
|
||||
start_node_id: iteration_nodestart
|
||||
title: Iteration with Flatten Disabled
|
||||
type: iteration
|
||||
width: 388
|
||||
height: 178
|
||||
id: iteration_node
|
||||
position:
|
||||
x: 684
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 684
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 388
|
||||
zIndex: 1
|
||||
- data:
|
||||
desc: ''
|
||||
isInIteration: true
|
||||
selected: false
|
||||
title: ''
|
||||
type: iteration-start
|
||||
draggable: false
|
||||
height: 48
|
||||
id: iteration_nodestart
|
||||
parentId: iteration_node
|
||||
position:
|
||||
x: 24
|
||||
y: 68
|
||||
positionAbsolute:
|
||||
x: 708
|
||||
y: 350
|
||||
selectable: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom-iteration-start
|
||||
width: 44
|
||||
zIndex: 1002
|
||||
- data:
|
||||
code: "\ndef main(arg1: int) -> dict:\n return {\n \"result\": [arg1,\
|
||||
\ arg1 * 2],\n }\n"
|
||||
code_language: python3
|
||||
desc: ''
|
||||
isInIteration: true
|
||||
isInLoop: false
|
||||
iteration_id: iteration_node
|
||||
outputs:
|
||||
result:
|
||||
children: null
|
||||
type: array[number]
|
||||
selected: false
|
||||
title: Generate Pair
|
||||
type: code
|
||||
variables:
|
||||
- value_selector:
|
||||
- iteration_node
|
||||
- item
|
||||
value_type: number
|
||||
variable: arg1
|
||||
height: 54
|
||||
id: code_inner_node
|
||||
parentId: iteration_node
|
||||
position:
|
||||
x: 128
|
||||
y: 68
|
||||
positionAbsolute:
|
||||
x: 812
|
||||
y: 350
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 244
|
||||
zIndex: 1002
|
||||
- data:
|
||||
desc: ''
|
||||
outputs:
|
||||
- value_selector:
|
||||
- iteration_node
|
||||
- output
|
||||
value_type: array[array[number]]
|
||||
variable: output
|
||||
selected: false
|
||||
title: End
|
||||
type: end
|
||||
height: 90
|
||||
id: end_node
|
||||
position:
|
||||
x: 1132
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 1132
|
||||
y: 282
|
||||
selected: true
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 244
|
||||
viewport:
|
||||
x: -476
|
||||
y: 3
|
||||
zoom: 1
|
||||
|
||||
258
api/tests/fixtures/workflow/iteration_flatten_output_enabled_workflow.yml
vendored
Normal file
258
api/tests/fixtures/workflow/iteration_flatten_output_enabled_workflow.yml
vendored
Normal file
@ -0,0 +1,258 @@
|
||||
app:
|
||||
description: 'This workflow tests the iteration node with flatten_output=True.
|
||||
|
||||
|
||||
It processes [1, 2, 3], outputs [item, item*2] for each iteration.
|
||||
|
||||
|
||||
With flatten_output=True (default), it should output:
|
||||
|
||||
|
||||
```
|
||||
|
||||
{"output": [1, 2, 2, 4, 3, 6]}
|
||||
|
||||
```'
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: workflow
|
||||
name: test_iteration_flatten_enabled
|
||||
use_icon_as_answer_icon: false
|
||||
dependencies: []
|
||||
kind: app
|
||||
version: 0.3.1
|
||||
workflow:
|
||||
conversation_variables: []
|
||||
environment_variables: []
|
||||
features:
|
||||
file_upload:
|
||||
enabled: false
|
||||
opening_statement: ''
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
enabled: false
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
graph:
|
||||
edges:
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: start
|
||||
targetType: code
|
||||
id: start-source-code-target
|
||||
source: start_node
|
||||
sourceHandle: source
|
||||
target: code_node
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: code
|
||||
targetType: iteration
|
||||
id: code-source-iteration-target
|
||||
source: code_node
|
||||
sourceHandle: source
|
||||
target: iteration_node
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
- data:
|
||||
isInIteration: true
|
||||
isInLoop: false
|
||||
iteration_id: iteration_node
|
||||
sourceType: iteration-start
|
||||
targetType: code
|
||||
id: iteration-start-source-code-inner-target
|
||||
source: iteration_nodestart
|
||||
sourceHandle: source
|
||||
target: code_inner_node
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 1002
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: iteration
|
||||
targetType: end
|
||||
id: iteration-source-end-target
|
||||
source: iteration_node
|
||||
sourceHandle: source
|
||||
target: end_node
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
nodes:
|
||||
- data:
|
||||
desc: ''
|
||||
selected: false
|
||||
title: Start
|
||||
type: start
|
||||
variables: []
|
||||
height: 54
|
||||
id: start_node
|
||||
position:
|
||||
x: 80
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 80
|
||||
y: 282
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 244
|
||||
- data:
|
||||
code: "\ndef main() -> dict:\n return {\n \"result\": [1, 2, 3],\n\
|
||||
\ }\n"
|
||||
code_language: python3
|
||||
desc: ''
|
||||
outputs:
|
||||
result:
|
||||
children: null
|
||||
type: array[number]
|
||||
selected: false
|
||||
title: Generate Array
|
||||
type: code
|
||||
variables: []
|
||||
height: 54
|
||||
id: code_node
|
||||
position:
|
||||
x: 384
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 384
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 244
|
||||
- data:
|
||||
desc: ''
|
||||
error_handle_mode: terminated
|
||||
flatten_output: true
|
||||
height: 178
|
||||
is_parallel: false
|
||||
iterator_input_type: array[number]
|
||||
iterator_selector:
|
||||
- code_node
|
||||
- result
|
||||
output_selector:
|
||||
- code_inner_node
|
||||
- result
|
||||
output_type: array[array[number]]
|
||||
parallel_nums: 10
|
||||
selected: false
|
||||
start_node_id: iteration_nodestart
|
||||
title: Iteration with Flatten Enabled
|
||||
type: iteration
|
||||
width: 388
|
||||
height: 178
|
||||
id: iteration_node
|
||||
position:
|
||||
x: 684
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 684
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 388
|
||||
zIndex: 1
|
||||
- data:
|
||||
desc: ''
|
||||
isInIteration: true
|
||||
selected: false
|
||||
title: ''
|
||||
type: iteration-start
|
||||
draggable: false
|
||||
height: 48
|
||||
id: iteration_nodestart
|
||||
parentId: iteration_node
|
||||
position:
|
||||
x: 24
|
||||
y: 68
|
||||
positionAbsolute:
|
||||
x: 708
|
||||
y: 350
|
||||
selectable: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom-iteration-start
|
||||
width: 44
|
||||
zIndex: 1002
|
||||
- data:
|
||||
code: "\ndef main(arg1: int) -> dict:\n return {\n \"result\": [arg1,\
|
||||
\ arg1 * 2],\n }\n"
|
||||
code_language: python3
|
||||
desc: ''
|
||||
isInIteration: true
|
||||
isInLoop: false
|
||||
iteration_id: iteration_node
|
||||
outputs:
|
||||
result:
|
||||
children: null
|
||||
type: array[number]
|
||||
selected: false
|
||||
title: Generate Pair
|
||||
type: code
|
||||
variables:
|
||||
- value_selector:
|
||||
- iteration_node
|
||||
- item
|
||||
value_type: number
|
||||
variable: arg1
|
||||
height: 54
|
||||
id: code_inner_node
|
||||
parentId: iteration_node
|
||||
position:
|
||||
x: 128
|
||||
y: 68
|
||||
positionAbsolute:
|
||||
x: 812
|
||||
y: 350
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 244
|
||||
zIndex: 1002
|
||||
- data:
|
||||
desc: ''
|
||||
outputs:
|
||||
- value_selector:
|
||||
- iteration_node
|
||||
- output
|
||||
value_type: array[number]
|
||||
variable: output
|
||||
selected: false
|
||||
title: End
|
||||
type: end
|
||||
height: 90
|
||||
id: end_node
|
||||
position:
|
||||
x: 1132
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 1132
|
||||
y: 282
|
||||
selected: true
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 244
|
||||
viewport:
|
||||
x: -476
|
||||
y: 3
|
||||
zoom: 1
|
||||
|
||||
@ -1 +1 @@
|
||||
# Test containers integration tests for core RAG pipeline components
|
||||
# Core integration tests package
|
||||
|
||||
@ -0,0 +1 @@
|
||||
# App integration tests package
|
||||
@ -0,0 +1 @@
|
||||
# Layers integration tests package
|
||||
@ -0,0 +1,520 @@
|
||||
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class.
|
||||
|
||||
This test suite covers complete integration scenarios including:
|
||||
- Real database interactions using containerized PostgreSQL
|
||||
- Real storage operations using test storage backend
|
||||
- Complete workflow: event -> state serialization -> database save -> storage save
|
||||
- Testing with actual WorkflowRunService (not mocked)
|
||||
- Real Workflow and WorkflowRun instances in database
|
||||
- Database transactions and rollback behavior
|
||||
- Actual file upload and retrieval through storage
|
||||
- Workflow status transitions in database
|
||||
- Error handling with real database constraints
|
||||
- Multiple pause events in sequence
|
||||
- Integration with real ReadOnlyGraphRuntimeState implementations
|
||||
|
||||
These tests use TestContainers to spin up real services for integration testing,
|
||||
providing more reliable and realistic test scenarios than mocks.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from time import time
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||
from core.workflow.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper
|
||||
from core.workflow.runtime.variable_pool import SystemVariable, VariablePool
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models import WorkflowPause as WorkflowPauseModel
|
||||
from models.model import UploadFile
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from services.file_service import FileService
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
class _TestCommandChannelImpl:
|
||||
"""Real implementation of CommandChannel for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self._commands: list[GraphEngineCommand] = []
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
"""Fetch pending commands for this GraphEngine instance."""
|
||||
return self._commands.copy()
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to be processed by this GraphEngine instance."""
|
||||
self._commands.append(command)
|
||||
|
||||
|
||||
class TestPauseStatePersistenceLayerTestContainers:
|
||||
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class."""
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, db_session_with_containers: Session):
|
||||
"""Get database engine from TestContainers session."""
|
||||
bind = db_session_with_containers.get_bind()
|
||||
assert isinstance(bind, Engine)
|
||||
return bind
|
||||
|
||||
@pytest.fixture
|
||||
def file_service(self, engine: Engine):
|
||||
"""Create FileService instance with TestContainers engine."""
|
||||
return FileService(engine)
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_run_service(self, engine: Engine, file_service: FileService):
|
||||
"""Create WorkflowRunService instance with TestContainers engine and FileService."""
|
||||
return WorkflowRunService(engine)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service):
|
||||
"""Set up test data for each test method using TestContainers."""
|
||||
# Create test tenant and account
|
||||
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
||||
tenant = Tenant(
|
||||
name="Test Tenant",
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account = Account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set test data
|
||||
self.test_tenant_id = tenant.id
|
||||
self.test_user_id = account.id
|
||||
self.test_app_id = str(uuid.uuid4())
|
||||
self.test_workflow_id = str(uuid.uuid4())
|
||||
self.test_workflow_run_id = str(uuid.uuid4())
|
||||
|
||||
# Create test workflow
|
||||
self.test_workflow = Workflow(
|
||||
id=self.test_workflow_id,
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=self.test_user_id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Create test workflow run
|
||||
self.test_workflow_run = WorkflowRun(
|
||||
id=self.test_workflow_run_id,
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
workflow_id=self.test_workflow_id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by=self.test_user_id,
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Store session and service instances
|
||||
self.session = db_session_with_containers
|
||||
self.file_service = file_service
|
||||
self.workflow_run_service = workflow_run_service
|
||||
|
||||
# Save test data to database
|
||||
self.session.add(self.test_workflow)
|
||||
self.session.add(self.test_workflow_run)
|
||||
self.session.commit()
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
self._cleanup_test_data()
|
||||
|
||||
def _cleanup_test_data(self):
|
||||
"""Clean up test data after each test method."""
|
||||
try:
|
||||
# Clean up workflow pauses
|
||||
self.session.execute(delete(WorkflowPauseModel))
|
||||
# Clean up upload files
|
||||
self.session.execute(
|
||||
delete(UploadFile).where(
|
||||
UploadFile.tenant_id == self.test_tenant_id,
|
||||
)
|
||||
)
|
||||
# Clean up workflow runs
|
||||
self.session.execute(
|
||||
delete(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == self.test_tenant_id,
|
||||
WorkflowRun.app_id == self.test_app_id,
|
||||
)
|
||||
)
|
||||
# Clean up workflows
|
||||
self.session.execute(
|
||||
delete(Workflow).where(
|
||||
Workflow.tenant_id == self.test_tenant_id,
|
||||
Workflow.app_id == self.test_app_id,
|
||||
)
|
||||
)
|
||||
self.session.commit()
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
raise e
|
||||
|
||||
def _create_graph_runtime_state(
|
||||
self,
|
||||
outputs: dict[str, object] | None = None,
|
||||
total_tokens: int = 0,
|
||||
node_run_steps: int = 0,
|
||||
variables: dict[tuple[str, str], object] | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
) -> ReadOnlyGraphRuntimeState:
|
||||
"""Create a real GraphRuntimeState for testing."""
|
||||
start_at = time()
|
||||
|
||||
execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())
|
||||
|
||||
# Create variable pool
|
||||
variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id))
|
||||
if variables:
|
||||
for (node_id, var_key), value in variables.items():
|
||||
variable_pool.add([node_id, var_key], value)
|
||||
|
||||
# Create LLM usage
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
|
||||
# Create graph runtime state
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=start_at,
|
||||
total_tokens=total_tokens,
|
||||
llm_usage=llm_usage,
|
||||
outputs=outputs or {},
|
||||
node_run_steps=node_run_steps,
|
||||
)
|
||||
|
||||
return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state)
|
||||
|
||||
def _create_pause_state_persistence_layer(
|
||||
self,
|
||||
workflow_run: WorkflowRun | None = None,
|
||||
workflow: Workflow | None = None,
|
||||
state_owner_user_id: str | None = None,
|
||||
) -> PauseStatePersistenceLayer:
|
||||
"""Create PauseStatePersistenceLayer with real dependencies."""
|
||||
owner_id = state_owner_user_id
|
||||
if owner_id is None:
|
||||
if workflow is not None and workflow.created_by:
|
||||
owner_id = workflow.created_by
|
||||
elif workflow_run is not None and workflow_run.created_by:
|
||||
owner_id = workflow_run.created_by
|
||||
else:
|
||||
owner_id = getattr(self, "test_user_id", None)
|
||||
|
||||
assert owner_id is not None
|
||||
owner_id = str(owner_id)
|
||||
|
||||
return PauseStatePersistenceLayer(
|
||||
session_factory=self.session.get_bind(),
|
||||
state_owner_user_id=owner_id,
|
||||
)
|
||||
|
||||
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
|
||||
"""Test complete pause flow: event -> state serialization -> database save -> storage save."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
|
||||
# Create real graph runtime state with test data
|
||||
test_outputs = {"result": "test_output", "step": "intermediate"}
|
||||
test_variables = {
|
||||
("node1", "var1"): "string_value",
|
||||
("node2", "var2"): {"complex": "object"},
|
||||
}
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs=test_outputs,
|
||||
total_tokens=100,
|
||||
node_run_steps=5,
|
||||
variables=test_variables,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
# Create pause event
|
||||
event = GraphRunPausedEvent(
|
||||
reason=SchedulingPause(message="test pause"),
|
||||
outputs={"intermediate": "result"},
|
||||
)
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Verify pause state was saved to database
|
||||
self.session.refresh(self.test_workflow_run)
|
||||
workflow_run = self.session.get(WorkflowRun, self.test_workflow_run_id)
|
||||
assert workflow_run is not None
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
# Verify pause state exists in database
|
||||
pause_model = self.session.scalars(
|
||||
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.workflow_id == self.test_workflow_id
|
||||
assert pause_model.workflow_run_id == self.test_workflow_run_id
|
||||
assert pause_model.state_object_key != ""
|
||||
assert pause_model.resumed_at is None
|
||||
|
||||
storage_content = storage.load(pause_model.state_object_key).decode()
|
||||
expected_state = json.loads(graph_runtime_state.dumps())
|
||||
actual_state = json.loads(storage_content)
|
||||
|
||||
assert actual_state == expected_state
|
||||
|
||||
def test_state_persistence_and_retrieval(self, db_session_with_containers):
|
||||
"""Test that pause state can be persisted and retrieved correctly."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
|
||||
# Create complex test data
|
||||
complex_outputs = {
|
||||
"nested": {"key": "value", "number": 42},
|
||||
"list": [1, 2, 3, {"nested": "item"}],
|
||||
"boolean": True,
|
||||
"null_value": None,
|
||||
}
|
||||
complex_variables = {
|
||||
("node1", "var1"): "string_value",
|
||||
("node2", "var2"): {"complex": "object"},
|
||||
("node3", "var3"): [1, 2, 3],
|
||||
}
|
||||
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs=complex_outputs,
|
||||
total_tokens=250,
|
||||
node_run_steps=10,
|
||||
variables=complex_variables,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act - Save pause state
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Retrieve and verify
|
||||
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
|
||||
assert pause_entity is not None
|
||||
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
|
||||
|
||||
state_bytes = pause_entity.get_state()
|
||||
retrieved_state = json.loads(state_bytes.decode())
|
||||
expected_state = json.loads(graph_runtime_state.dumps())
|
||||
|
||||
assert retrieved_state == expected_state
|
||||
assert retrieved_state["outputs"] == complex_outputs
|
||||
assert retrieved_state["total_tokens"] == 250
|
||||
assert retrieved_state["node_run_steps"] == 10
|
||||
|
||||
def test_database_transaction_handling(self, db_session_with_containers):
|
||||
"""Test that database transactions are handled correctly."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs={"test": "transaction"},
|
||||
total_tokens=50,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Verify data is committed and accessible in new session
|
||||
with Session(bind=self.session.get_bind(), expire_on_commit=False) as new_session:
|
||||
workflow_run = new_session.get(WorkflowRun, self.test_workflow_run_id)
|
||||
assert workflow_run is not None
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
pause_model = new_session.scalars(
|
||||
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.workflow_run_id == self.test_workflow_run_id
|
||||
assert pause_model.resumed_at is None
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
def test_file_storage_integration(self, db_session_with_containers):
|
||||
"""Test integration with file storage system."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
|
||||
# Create large state data to test storage
|
||||
large_outputs = {"data": "x" * 10000} # 10KB of data
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs=large_outputs,
|
||||
total_tokens=1000,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Verify file was uploaded to storage
|
||||
self.session.refresh(self.test_workflow_run)
|
||||
pause_model = self.session.scalars(
|
||||
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run.id)
|
||||
).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
# Verify content in storage
|
||||
storage_content = storage.load(pause_model.state_object_key).decode()
|
||||
assert storage_content == graph_runtime_state.dumps()
|
||||
|
||||
def test_workflow_with_different_creators(self, db_session_with_containers):
|
||||
"""Test pause state with workflows created by different users."""
|
||||
# Arrange - Create workflow with different creator
|
||||
different_user_id = str(uuid.uuid4())
|
||||
different_workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=different_user_id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
different_workflow_run = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
workflow_id=different_workflow.id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by=self.test_user_id, # Run created by different user
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
self.session.add(different_workflow)
|
||||
self.session.add(different_workflow_run)
|
||||
self.session.commit()
|
||||
|
||||
layer = self._create_pause_state_persistence_layer(
|
||||
workflow_run=different_workflow_run,
|
||||
workflow=different_workflow,
|
||||
)
|
||||
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs={"creator_test": "different_creator"},
|
||||
workflow_run_id=different_workflow_run.id,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Should use workflow creator (not run creator)
|
||||
self.session.refresh(different_workflow_run)
|
||||
pause_model = self.session.scalars(
|
||||
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == different_workflow_run.id)
|
||||
).first()
|
||||
assert pause_model is not None
|
||||
|
||||
# Verify the state owner is the workflow creator
|
||||
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
|
||||
assert pause_entity is not None
|
||||
|
||||
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
|
||||
"""Test that layer ignores non-pause events."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
graph_runtime_state = self._create_graph_runtime_state()
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
# Import other event types
|
||||
from core.workflow.graph_events.graph import (
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
|
||||
# Act - Send non-pause events
|
||||
layer.on_event(GraphRunStartedEvent())
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={"result": "success"}))
|
||||
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
|
||||
|
||||
# Assert - No pause state should be created
|
||||
self.session.refresh(self.test_workflow_run)
|
||||
assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
|
||||
pause_states = (
|
||||
self.session.query(WorkflowPauseModel)
|
||||
.filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
|
||||
.all()
|
||||
)
|
||||
assert len(pause_states) == 0
|
||||
|
||||
def test_layer_requires_initialization(self, db_session_with_containers):
|
||||
"""Test that layer requires proper initialization before handling events."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
# Don't initialize - graph_runtime_state should not be set
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act & Assert - Should raise AttributeError
|
||||
with pytest.raises(AttributeError):
|
||||
layer.on_event(event)
|
||||
@ -5,6 +5,7 @@ import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -32,7 +33,7 @@ class TestAppGenerateService:
|
||||
patch("services.app_generate_service.dify_config") as mock_dify_config,
|
||||
):
|
||||
# Setup default mock returns for billing service
|
||||
mock_billing_service.get_info.return_value = {"subscription": {"plan": "sandbox"}}
|
||||
mock_billing_service.get_info.return_value = {"subscription": {"plan": CloudPlan.SANDBOX}}
|
||||
|
||||
# Setup default mock returns for workflow service
|
||||
mock_workflow_service_instance = mock_workflow_service.return_value
|
||||
@ -430,7 +431,7 @@ class TestAppGenerateService:
|
||||
|
||||
# Setup billing service mock for sandbox plan
|
||||
mock_external_service_dependencies["billing_service"].get_info.return_value = {
|
||||
"subscription": {"plan": "sandbox"}
|
||||
"subscription": {"plan": CloudPlan.SANDBOX}
|
||||
}
|
||||
|
||||
# Set BILLING_ENABLED to True for this test
|
||||
@ -461,7 +462,7 @@ class TestAppGenerateService:
|
||||
|
||||
# Setup billing service mock for sandbox plan
|
||||
mock_external_service_dependencies["billing_service"].get_info.return_value = {
|
||||
"subscription": {"plan": "sandbox"}
|
||||
"subscription": {"plan": CloudPlan.SANDBOX}
|
||||
}
|
||||
|
||||
# Set BILLING_ENABLED to True for this test
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.feature_service import FeatureModel, FeatureService, KnowledgeRateLimitModel, SystemFeatureModel
|
||||
|
||||
|
||||
@ -173,7 +174,7 @@ class TestFeatureService:
|
||||
# Set mock return value inside the patch context
|
||||
mock_external_service_dependencies["billing_service"].get_info.return_value = {
|
||||
"enabled": True,
|
||||
"subscription": {"plan": "sandbox", "interval": "monthly", "education": False},
|
||||
"subscription": {"plan": CloudPlan.SANDBOX, "interval": "monthly", "education": False},
|
||||
"members": {"size": 1, "limit": 3},
|
||||
"apps": {"size": 1, "limit": 5},
|
||||
"vector_space": {"size": 1, "limit": 2},
|
||||
@ -189,7 +190,7 @@ class TestFeatureService:
|
||||
result = FeatureService.get_features(tenant_id)
|
||||
|
||||
# Assert: Verify sandbox-specific limitations
|
||||
assert result.billing.subscription.plan == "sandbox"
|
||||
assert result.billing.subscription.plan == CloudPlan.SANDBOX
|
||||
assert result.education.activated is False
|
||||
|
||||
# Verify sandbox limitations
|
||||
|
||||
@ -11,7 +11,7 @@ from configs import dify_config
|
||||
from models import Account, Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import EndUser, UploadFile
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
from services.errors.file import BlockedFileExtensionError, FileTooLargeError, UnsupportedFileTypeError
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -943,3 +943,150 @@ class TestFileService:
|
||||
|
||||
# Should have the signed URL when source_url is empty
|
||||
assert upload_file2.source_url == "https://example.com/signed-url"
|
||||
|
||||
# Test file extension blacklist
|
||||
def test_upload_file_blocked_extension(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with blocked extension.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Mock blacklist configuration by patching the inner field
|
||||
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", "exe,bat,sh"):
|
||||
filename = "malware.exe"
|
||||
content = b"test content"
|
||||
mimetype = "application/x-msdownload"
|
||||
|
||||
with pytest.raises(BlockedFileExtensionError):
|
||||
FileService(engine).upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
def test_upload_file_blocked_extension_case_insensitive(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with blocked extension (case insensitive).
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Mock blacklist configuration by patching the inner field
|
||||
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", "exe,bat"):
|
||||
# Test with uppercase extension
|
||||
filename = "malware.EXE"
|
||||
content = b"test content"
|
||||
mimetype = "application/x-msdownload"
|
||||
|
||||
with pytest.raises(BlockedFileExtensionError):
|
||||
FileService(engine).upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
def test_upload_file_not_in_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
"""
|
||||
Test file upload with extension not in blacklist.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Mock blacklist configuration by patching the inner field
|
||||
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", "exe,bat,sh"):
|
||||
filename = "document.pdf"
|
||||
content = b"test content"
|
||||
mimetype = "application/pdf"
|
||||
|
||||
upload_file = FileService(engine).upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
assert upload_file is not None
|
||||
assert upload_file.name == filename
|
||||
assert upload_file.extension == "pdf"
|
||||
|
||||
def test_upload_file_empty_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
"""
|
||||
Test file upload with empty blacklist (default behavior).
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Mock empty blacklist configuration by patching the inner field
|
||||
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", ""):
|
||||
# Should allow all file types when blacklist is empty
|
||||
filename = "script.sh"
|
||||
content = b"#!/bin/bash\necho test"
|
||||
mimetype = "application/x-sh"
|
||||
|
||||
upload_file = FileService(engine).upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
assert upload_file is not None
|
||||
assert upload_file.extension == "sh"
|
||||
|
||||
def test_upload_file_multiple_blocked_extensions(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with multiple blocked extensions.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Mock blacklist with multiple extensions by patching the inner field
|
||||
blacklist_str = "exe,bat,cmd,com,scr,vbs,ps1,msi,dll"
|
||||
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", blacklist_str):
|
||||
for ext in blacklist_str.split(","):
|
||||
filename = f"malware.{ext}"
|
||||
content = b"test content"
|
||||
mimetype = "application/octet-stream"
|
||||
|
||||
with pytest.raises(BlockedFileExtensionError):
|
||||
FileService(engine).upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
def test_upload_file_no_extension_with_blacklist(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with no extension when blacklist is configured.
|
||||
"""
|
||||
fake = Faker()
|
||||
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Mock blacklist configuration by patching the inner field
|
||||
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", "exe,bat"):
|
||||
# Files with no extension should not be blocked
|
||||
filename = "README"
|
||||
content = b"test content"
|
||||
mimetype = "text/plain"
|
||||
|
||||
upload_file = FileService(engine).upload_file(
|
||||
filename=filename,
|
||||
content=content,
|
||||
mimetype=mimetype,
|
||||
user=account,
|
||||
)
|
||||
|
||||
assert upload_file is not None
|
||||
assert upload_file.extension == ""
|
||||
|
||||
@ -35,9 +35,7 @@ class TestWebAppAuthService:
|
||||
mock_enterprise_service.WebAppAuth.get_app_access_mode_by_id.return_value = type(
|
||||
"MockWebAppAuth", (), {"access_mode": "private"}
|
||||
)()
|
||||
mock_enterprise_service.WebAppAuth.get_app_access_mode_by_code.return_value = type(
|
||||
"MockWebAppAuth", (), {"access_mode": "private"}
|
||||
)()
|
||||
# Note: get_app_access_mode_by_code method was removed in refactoring
|
||||
|
||||
yield {
|
||||
"passport_service": mock_passport_service,
|
||||
|
||||
@ -5,6 +5,7 @@ import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.entities.document_task import DocumentTask
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document
|
||||
@ -19,7 +20,7 @@ from tasks.document_indexing_task import (
|
||||
|
||||
class TestDocumentIndexingTasks:
|
||||
"""Integration tests for document indexing tasks using testcontainers.
|
||||
|
||||
|
||||
This test class covers:
|
||||
- Core _document_indexing function
|
||||
- Deprecated document_indexing_task function
|
||||
@ -213,7 +214,7 @@ class TestDocumentIndexingTasks:
|
||||
# Configure billing features
|
||||
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
|
||||
if billing_enabled:
|
||||
mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox"
|
||||
mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX
|
||||
mock_external_service_dependencies["features"].vector_space.limit = 100
|
||||
mock_external_service_dependencies["features"].vector_space.size = 50
|
||||
|
||||
@ -462,7 +463,7 @@ class TestDocumentIndexingTasks:
|
||||
)
|
||||
|
||||
# Configure sandbox plan with batch limit
|
||||
mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox"
|
||||
mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX
|
||||
|
||||
# Create more documents than sandbox plan allows (limit is 1)
|
||||
fake = Faker()
|
||||
@ -597,7 +598,7 @@ class TestDocumentIndexingTasks:
|
||||
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
|
||||
def test_normal_document_indexing_task_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
@ -718,10 +719,10 @@ class TestDocumentIndexingTasks:
|
||||
|
||||
# Use real Redis for TenantSelfTaskQueue
|
||||
from core.rag.pipeline.queue import TenantSelfTaskQueue
|
||||
|
||||
|
||||
# Create real queue instance
|
||||
queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
|
||||
|
||||
|
||||
# Add waiting tasks to the real Redis queue
|
||||
waiting_tasks = [
|
||||
DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"]),
|
||||
@ -740,7 +741,7 @@ class TestDocumentIndexingTasks:
|
||||
|
||||
# Verify task function was called for each waiting task
|
||||
assert mock_task_func.delay.call_count == 1
|
||||
|
||||
|
||||
# Verify correct parameters for each call
|
||||
calls = mock_task_func.delay.call_args_list
|
||||
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dateset_id, "document_ids": ["waiting-doc-1"]}
|
||||
@ -782,7 +783,7 @@ class TestDocumentIndexingTasks:
|
||||
|
||||
# Create real queue instance
|
||||
queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
|
||||
|
||||
|
||||
# Add waiting task to the real Redis queue
|
||||
waiting_task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"])
|
||||
queue.push_tasks([asdict(waiting_task)])
|
||||
@ -804,7 +805,7 @@ class TestDocumentIndexingTasks:
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_task_func.delay.assert_called_once()
|
||||
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dateset_id, "document_ids": ["waiting-doc-1"]}
|
||||
@ -831,7 +832,7 @@ class TestDocumentIndexingTasks:
|
||||
dataset2, documents2 = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=1
|
||||
)
|
||||
|
||||
|
||||
tenant1_id = dataset1.tenant_id
|
||||
tenant2_id = dataset2.tenant_id
|
||||
dataset1_id = dataset1.id
|
||||
@ -845,15 +846,15 @@ class TestDocumentIndexingTasks:
|
||||
|
||||
# Use real Redis for TenantSelfTaskQueue
|
||||
from core.rag.pipeline.queue import TenantSelfTaskQueue
|
||||
|
||||
|
||||
# Create queue instances for both tenants
|
||||
queue1 = TenantSelfTaskQueue(tenant1_id, "document_indexing")
|
||||
queue2 = TenantSelfTaskQueue(tenant2_id, "document_indexing")
|
||||
|
||||
|
||||
# Add waiting tasks to both queues
|
||||
waiting_task1 = DocumentTask(tenant_id=tenant1_id, dataset_id=dataset1.id, document_ids=["tenant1-doc-1"])
|
||||
waiting_task2 = DocumentTask(tenant_id=tenant2_id, dataset_id=dataset2.id, document_ids=["tenant2-doc-1"])
|
||||
|
||||
|
||||
queue1.push_tasks([asdict(waiting_task1)])
|
||||
queue2.push_tasks([asdict(waiting_task2)])
|
||||
|
||||
|
||||
@ -0,0 +1,948 @@
|
||||
"""Comprehensive integration tests for workflow pause functionality.
|
||||
|
||||
This test suite covers complete workflow pause functionality including:
|
||||
- Real database interactions using containerized PostgreSQL
|
||||
- Real storage operations using the test storage backend
|
||||
- Complete workflow: create -> pause -> resume -> delete
|
||||
- Testing with actual FileService (not mocked)
|
||||
- Database transactions and rollback behavior
|
||||
- Actual file upload and retrieval through storage
|
||||
- Workflow status transitions in the database
|
||||
- Error handling with real database constraints
|
||||
- Concurrent access scenarios
|
||||
- Multi-tenant isolation
|
||||
- Prune functionality
|
||||
- File storage integration
|
||||
|
||||
These tests use TestContainers to spin up real services for integration testing,
|
||||
providing more reliable and realistic test scenarios than mocks.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.workflow.entities import WorkflowExecution
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models import WorkflowPause as WorkflowPauseModel
|
||||
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import UploadFile
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
_WorkflowRunError,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PauseWorkflowSuccessCase:
|
||||
"""Test case for successful pause workflow operations."""
|
||||
|
||||
name: str
|
||||
initial_status: WorkflowExecutionStatus
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PauseWorkflowFailureCase:
|
||||
"""Test case for pause workflow failure scenarios."""
|
||||
|
||||
name: str
|
||||
initial_status: WorkflowExecutionStatus
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResumeWorkflowSuccessCase:
|
||||
"""Test case for successful resume workflow operations."""
|
||||
|
||||
name: str
|
||||
initial_status: WorkflowExecutionStatus
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResumeWorkflowFailureCase:
|
||||
"""Test case for resume workflow failure scenarios."""
|
||||
|
||||
name: str
|
||||
initial_status: WorkflowExecutionStatus
|
||||
pause_resumed: bool
|
||||
set_running_status: bool = False
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrunePausesTestCase:
|
||||
"""Test case for prune pauses operations."""
|
||||
|
||||
name: str
|
||||
pause_age: timedelta
|
||||
resume_age: timedelta | None
|
||||
expected_pruned_count: int
|
||||
description: str = ""
|
||||
|
||||
|
||||
def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
|
||||
"""Create test cases for pause workflow failure scenarios."""
|
||||
return [
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_already_paused_workflow",
|
||||
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||
description="Should fail to pause an already paused workflow",
|
||||
),
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_completed_workflow",
|
||||
initial_status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
description="Should fail to pause a completed workflow",
|
||||
),
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_failed_workflow",
|
||||
initial_status=WorkflowExecutionStatus.FAILED,
|
||||
description="Should fail to pause a failed workflow",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def resume_workflow_success_cases() -> list[ResumeWorkflowSuccessCase]:
|
||||
"""Create test cases for successful resume workflow operations."""
|
||||
return [
|
||||
ResumeWorkflowSuccessCase(
|
||||
name="resume_paused_workflow",
|
||||
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||
description="Should successfully resume a paused workflow",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def resume_workflow_failure_cases() -> list[ResumeWorkflowFailureCase]:
|
||||
"""Create test cases for resume workflow failure scenarios."""
|
||||
return [
|
||||
ResumeWorkflowFailureCase(
|
||||
name="resume_already_resumed_workflow",
|
||||
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||
pause_resumed=True,
|
||||
description="Should fail to resume an already resumed workflow",
|
||||
),
|
||||
ResumeWorkflowFailureCase(
|
||||
name="resume_running_workflow",
|
||||
initial_status=WorkflowExecutionStatus.RUNNING,
|
||||
pause_resumed=False,
|
||||
set_running_status=True,
|
||||
description="Should fail to resume a running workflow",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def prune_pauses_test_cases() -> list[PrunePausesTestCase]:
|
||||
"""Create test cases for prune pauses operations."""
|
||||
return [
|
||||
PrunePausesTestCase(
|
||||
name="prune_old_active_pauses",
|
||||
pause_age=timedelta(days=7),
|
||||
resume_age=None,
|
||||
expected_pruned_count=1,
|
||||
description="Should prune old active pauses",
|
||||
),
|
||||
PrunePausesTestCase(
|
||||
name="prune_old_resumed_pauses",
|
||||
pause_age=timedelta(hours=12), # Created 12 hours ago (recent)
|
||||
resume_age=timedelta(days=7),
|
||||
expected_pruned_count=1,
|
||||
description="Should prune old resumed pauses",
|
||||
),
|
||||
PrunePausesTestCase(
|
||||
name="keep_recent_active_pauses",
|
||||
pause_age=timedelta(hours=1),
|
||||
resume_age=None,
|
||||
expected_pruned_count=0,
|
||||
description="Should keep recent active pauses",
|
||||
),
|
||||
PrunePausesTestCase(
|
||||
name="keep_recent_resumed_pauses",
|
||||
pause_age=timedelta(days=1),
|
||||
resume_age=timedelta(hours=1),
|
||||
expected_pruned_count=0,
|
||||
description="Should keep recent resumed pauses",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TestWorkflowPauseIntegration:
|
||||
"""Comprehensive integration tests for workflow pause functionality."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_data(self, db_session_with_containers):
|
||||
"""Set up test data for each test method using TestContainers."""
|
||||
# Create test tenant and account
|
||||
|
||||
tenant = Tenant(
|
||||
name="Test Tenant",
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account = Account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set test data
|
||||
self.test_tenant_id = tenant.id
|
||||
self.test_user_id = account.id
|
||||
self.test_app_id = str(uuid.uuid4())
|
||||
self.test_workflow_id = str(uuid.uuid4())
|
||||
|
||||
# Create test workflow
|
||||
self.test_workflow = Workflow(
|
||||
id=self.test_workflow_id,
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=self.test_user_id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Store session instance
|
||||
self.session = db_session_with_containers
|
||||
|
||||
# Save test data to database
|
||||
self.session.add(self.test_workflow)
|
||||
self.session.commit()
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
self._cleanup_test_data()
|
||||
|
||||
def _cleanup_test_data(self):
|
||||
"""Clean up test data after each test method."""
|
||||
# Clean up workflow pauses
|
||||
self.session.execute(delete(WorkflowPauseModel))
|
||||
# Clean up upload files
|
||||
self.session.execute(
|
||||
delete(UploadFile).where(
|
||||
UploadFile.tenant_id == self.test_tenant_id,
|
||||
)
|
||||
)
|
||||
# Clean up workflow runs
|
||||
self.session.execute(
|
||||
delete(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == self.test_tenant_id,
|
||||
WorkflowRun.app_id == self.test_app_id,
|
||||
)
|
||||
)
|
||||
# Clean up workflows
|
||||
self.session.execute(
|
||||
delete(Workflow).where(
|
||||
Workflow.tenant_id == self.test_tenant_id,
|
||||
Workflow.app_id == self.test_app_id,
|
||||
)
|
||||
)
|
||||
self.session.commit()
|
||||
|
||||
def _create_test_workflow_run(
|
||||
self, status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
|
||||
) -> WorkflowRun:
|
||||
"""Create a test workflow run with specified status."""
|
||||
workflow_run = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
workflow_id=self.test_workflow_id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=status,
|
||||
created_by=self.test_user_id,
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
self.session.add(workflow_run)
|
||||
self.session.commit()
|
||||
return workflow_run
|
||||
|
||||
def _create_test_state(self) -> str:
|
||||
"""Create a test state string."""
|
||||
return json.dumps(
|
||||
{
|
||||
"node_id": "test-node",
|
||||
"node_type": "llm",
|
||||
"status": "paused",
|
||||
"data": {"key": "value"},
|
||||
"timestamp": naive_utc_now().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
def _get_workflow_run_repository(self):
|
||||
"""Get workflow run repository instance for testing."""
|
||||
# Create session factory from the test session
|
||||
engine = self.session.get_bind()
|
||||
session_factory = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
# Create a test-specific repository that implements the missing save method
|
||||
class TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test-specific repository that implements the missing save method."""
|
||||
|
||||
def save(self, execution: WorkflowExecution):
|
||||
"""Implement the missing save method for testing."""
|
||||
# For testing purposes, we don't need to implement this method
|
||||
# as it's not used in the pause functionality tests
|
||||
pass
|
||||
|
||||
# Create and return repository instance
|
||||
repository = TestWorkflowRunRepository(session_maker=session_factory)
|
||||
return repository
|
||||
|
||||
# ==================== Complete Pause Workflow Tests ====================
|
||||
|
||||
def test_complete_pause_resume_workflow(self):
|
||||
"""Test complete workflow: create -> pause -> resume -> delete."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act - Create pause state
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Assert - Pause state created
|
||||
assert pause_entity is not None
|
||||
assert pause_entity.id is not None
|
||||
assert pause_entity.workflow_execution_id == workflow_run.id
|
||||
# Convert both to strings for comparison
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
# Verify database state
|
||||
query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
pause_model = self.session.scalars(query).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.resumed_at is None
|
||||
assert pause_model.id == pause_entity.id
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
# Act - Get pause state
|
||||
retrieved_entity = repository.get_workflow_pause(workflow_run.id)
|
||||
|
||||
# Assert - Pause state retrieved
|
||||
assert retrieved_entity is not None
|
||||
assert retrieved_entity.id == pause_entity.id
|
||||
retrieved_state = retrieved_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
# Act - Resume workflow
|
||||
resumed_entity = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# Assert - Workflow resumed
|
||||
assert resumed_entity is not None
|
||||
assert resumed_entity.id == pause_entity.id
|
||||
assert resumed_entity.resumed_at is not None
|
||||
|
||||
# Verify database state
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
self.session.refresh(pause_model)
|
||||
assert pause_model.resumed_at is not None
|
||||
|
||||
# Act - Delete pause state
|
||||
repository.delete_workflow_pause(pause_entity)
|
||||
|
||||
# Assert - Pause state deleted
|
||||
with Session(bind=self.session.get_bind()) as session:
|
||||
deleted_pause = session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert deleted_pause is None
|
||||
|
||||
def test_pause_workflow_success(self):
|
||||
"""Test successful pause workflow scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
assert pause_entity is not None
|
||||
assert pause_entity.workflow_execution_id == workflow_run.id
|
||||
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
pause_model = self.session.scalars(pause_query).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.id == pause_entity.id
|
||||
assert pause_model.resumed_at is None
|
||||
|
||||
@pytest.mark.parametrize("test_case", pause_workflow_failure_cases(), ids=lambda tc: tc.name)
|
||||
def test_pause_workflow_failure(self, test_case: PauseWorkflowFailureCase):
|
||||
"""Test pause workflow failure scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
with pytest.raises(_WorkflowRunError):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)
|
||||
def test_resume_workflow_success(self, test_case: ResumeWorkflowSuccessCase):
|
||||
"""Test successful resume workflow scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
self.session.commit()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
resumed_entity = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
assert resumed_entity is not None
|
||||
assert resumed_entity.id == pause_entity.id
|
||||
assert resumed_entity.resumed_at is not None
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
pause_model = self.session.scalars(pause_query).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.id == pause_entity.id
|
||||
assert pause_model.resumed_at is not None
|
||||
|
||||
def test_resume_running_workflow(self):
|
||||
"""Test resume workflow failure scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
self.session.add(workflow_run)
|
||||
self.session.commit()
|
||||
|
||||
with pytest.raises(_WorkflowRunError):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
def test_resume_resumed_pause(self):
|
||||
"""Test resume workflow failure scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model.resumed_at = naive_utc_now()
|
||||
self.session.add(pause_model)
|
||||
self.session.commit()
|
||||
|
||||
with pytest.raises(_WorkflowRunError):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# ==================== Error Scenario Tests ====================
|
||||
|
||||
def test_pause_nonexistent_workflow_run(self):
|
||||
"""Test pausing a non-existent workflow run."""
|
||||
# Arrange
|
||||
nonexistent_id = str(uuid.uuid4())
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="WorkflowRun not found"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id=nonexistent_id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
def test_resume_nonexistent_workflow_run(self):
|
||||
"""Test resuming a non-existent workflow run."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
nonexistent_id = str(uuid.uuid4())
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="WorkflowRun not found"):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=nonexistent_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# ==================== Prune Functionality Tests ====================
|
||||
|
||||
@pytest.mark.parametrize("test_case", prune_pauses_test_cases(), ids=lambda tc: tc.name)
|
||||
def test_prune_pauses_scenarios(self, test_case: PrunePausesTestCase):
|
||||
"""Test various prune pauses scenarios."""
|
||||
now = naive_utc_now()
|
||||
|
||||
# Create pause state
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Manually adjust timestamps for testing
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model.created_at = now - test_case.pause_age
|
||||
|
||||
if test_case.resume_age is not None:
|
||||
# Resume pause and adjust resume time
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
# Need to refresh to get the updated model
|
||||
self.session.refresh(pause_model)
|
||||
# Manually set the resumed_at to an older time for testing
|
||||
pause_model.resumed_at = now - test_case.resume_age
|
||||
self.session.commit() # Commit the resumed_at change
|
||||
# Refresh again to ensure the change is persisted
|
||||
self.session.refresh(pause_model)
|
||||
|
||||
self.session.commit()
|
||||
|
||||
# Act - Prune pauses
|
||||
expiration_time = now - timedelta(days=1, seconds=1) # Expire pauses older than 1 day (plus 1 second)
|
||||
resumption_time = now - timedelta(
|
||||
days=7, seconds=1
|
||||
) # Clean up pauses resumed more than 7 days ago (plus 1 second)
|
||||
|
||||
# Debug: Check pause state before pruning
|
||||
self.session.refresh(pause_model)
|
||||
print(f"Pause created_at: {pause_model.created_at}")
|
||||
print(f"Pause resumed_at: {pause_model.resumed_at}")
|
||||
print(f"Expiration time: {expiration_time}")
|
||||
print(f"Resumption time: {resumption_time}")
|
||||
|
||||
# Force commit to ensure timestamps are saved
|
||||
self.session.commit()
|
||||
|
||||
# Determine if the pause should be pruned based on timestamps
|
||||
should_be_pruned = False
|
||||
if test_case.resume_age is not None:
|
||||
# If resumed, check if resumed_at is older than resumption_time
|
||||
should_be_pruned = pause_model.resumed_at < resumption_time
|
||||
else:
|
||||
# If not resumed, check if created_at is older than expiration_time
|
||||
should_be_pruned = pause_model.created_at < expiration_time
|
||||
|
||||
# Act - Prune pauses
|
||||
pruned_ids = repository.prune_pauses(
|
||||
expiration=expiration_time,
|
||||
resumption_expiration=resumption_time,
|
||||
)
|
||||
|
||||
# Assert - Check pruning results
|
||||
if should_be_pruned:
|
||||
assert len(pruned_ids) == test_case.expected_pruned_count
|
||||
# Verify pause was actually deleted
|
||||
# The pause should be in the pruned_ids list if it was pruned
|
||||
assert pause_entity.id in pruned_ids
|
||||
else:
|
||||
assert len(pruned_ids) == 0
|
||||
|
||||
def test_prune_pauses_with_limit(self):
|
||||
"""Test prune pauses with limit parameter."""
|
||||
now = naive_utc_now()
|
||||
|
||||
# Create multiple pause states
|
||||
pause_entities = []
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
for i in range(5):
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
pause_entities.append(pause_entity)
|
||||
|
||||
# Make all pauses old enough to be pruned
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model.created_at = now - timedelta(days=7)
|
||||
|
||||
self.session.commit()
|
||||
|
||||
# Act - Prune with limit
|
||||
expiration_time = now - timedelta(days=1)
|
||||
resumption_time = now - timedelta(days=7)
|
||||
|
||||
pruned_ids = repository.prune_pauses(
|
||||
expiration=expiration_time,
|
||||
resumption_expiration=resumption_time,
|
||||
limit=3,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(pruned_ids) == 3
|
||||
|
||||
# Verify only 3 were deleted
|
||||
remaining_count = (
|
||||
self.session.query(WorkflowPauseModel)
|
||||
.filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]))
|
||||
.count()
|
||||
)
|
||||
assert remaining_count == 2
|
||||
|
||||
# ==================== Multi-tenant Isolation Tests ====================
|
||||
|
||||
def test_multi_tenant_pause_isolation(self):
|
||||
"""Test that pause states are properly isolated by tenant."""
|
||||
# Arrange - Create second tenant
|
||||
|
||||
tenant2 = Tenant(
|
||||
name="Test Tenant 2",
|
||||
status="normal",
|
||||
)
|
||||
self.session.add(tenant2)
|
||||
self.session.commit()
|
||||
|
||||
account2 = Account(
|
||||
email="test2@example.com",
|
||||
name="Test User 2",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
self.session.add(account2)
|
||||
self.session.commit()
|
||||
|
||||
tenant2_join = TenantAccountJoin(
|
||||
tenant_id=tenant2.id,
|
||||
account_id=account2.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
self.session.add(tenant2_join)
|
||||
self.session.commit()
|
||||
|
||||
# Create workflow for tenant 2
|
||||
workflow2 = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant2.id,
|
||||
app_id=str(uuid.uuid4()),
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=account2.id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
self.session.add(workflow2)
|
||||
self.session.commit()
|
||||
|
||||
# Create workflow runs for both tenants
|
||||
workflow_run1 = self._create_test_workflow_run()
|
||||
workflow_run2 = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant2.id,
|
||||
app_id=workflow2.app_id,
|
||||
workflow_id=workflow2.id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by=account2.id,
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
self.session.add(workflow_run2)
|
||||
self.session.commit()
|
||||
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act - Create pause for tenant 1
|
||||
pause_entity1 = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run1.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Try to access pause from tenant 2 using tenant 1's repository
|
||||
# This should work because we're using the same repository
|
||||
pause_entity2 = repository.get_workflow_pause(workflow_run2.id)
|
||||
assert pause_entity2 is None # No pause for tenant 2 yet
|
||||
|
||||
# Create pause for tenant 2
|
||||
pause_entity2 = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run2.id,
|
||||
state_owner_user_id=account2.id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Assert - Both pauses should exist and be separate
|
||||
assert pause_entity1 is not None
|
||||
assert pause_entity2 is not None
|
||||
assert pause_entity1.id != pause_entity2.id
|
||||
assert pause_entity1.workflow_execution_id != pause_entity2.workflow_execution_id
|
||||
|
||||
def test_cross_tenant_access_restriction(self):
|
||||
"""Test that cross-tenant access is properly restricted."""
|
||||
# This test would require tenant-specific repositories
|
||||
# For now, we test that pause entities are properly scoped by tenant_id
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Verify pause is properly scoped
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert pause_model.workflow_id == self.test_workflow_id
|
||||
|
||||
# ==================== File Storage Integration Tests ====================
|
||||
|
||||
def test_file_storage_integration(self):
|
||||
"""Test that state files are properly stored and retrieved."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act - Create pause state
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Assert - Verify file was uploaded to storage
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
# Verify file content in storage
|
||||
|
||||
file_key = pause_model.state_object_key
|
||||
storage_content = storage.load(file_key).decode()
|
||||
assert storage_content == test_state
|
||||
|
||||
# Verify retrieval through entity
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
def test_file_cleanup_on_pause_deletion(self):
|
||||
"""Test that files are properly handled on pause deletion."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Get file info before deletion
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
file_key = pause_model.state_object_key
|
||||
|
||||
# Act - Delete pause state
|
||||
repository.delete_workflow_pause(pause_entity)
|
||||
|
||||
# Assert - Pause record should be deleted
|
||||
self.session.expire_all() # Clear session to ensure fresh query
|
||||
deleted_pause = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert deleted_pause is None
|
||||
|
||||
try:
|
||||
content = storage.load(file_key).decode()
|
||||
pytest.fail("File should be deleted from storage after pause deletion")
|
||||
except FileNotFoundError:
|
||||
# This is expected - file should be deleted from storage
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Unexpected error when checking file deletion: {e}")
|
||||
|
||||
def test_large_state_file_handling(self):
|
||||
"""Test handling of large state files."""
|
||||
# Arrange - Create a large state (1MB)
|
||||
large_state = "x" * (1024 * 1024) # 1MB of data
|
||||
large_state_json = json.dumps({"large_data": large_state})
|
||||
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=large_state_json,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert pause_entity is not None
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == large_state_json
|
||||
|
||||
# Verify file size in database
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert pause_model.state_object_key != ""
|
||||
loaded_state = storage.load(pause_model.state_object_key)
|
||||
assert loaded_state.decode() == large_state_json
|
||||
|
||||
def test_multiple_pause_resume_cycles(self):
|
||||
"""Test multiple pause/resume cycles on the same workflow run."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act & Assert - Multiple cycles
|
||||
for i in range(3):
|
||||
state = json.dumps({"cycle": i, "data": f"state_{i}"})
|
||||
|
||||
# Reset workflow run status to RUNNING before each pause (after first cycle)
|
||||
if i > 0:
|
||||
self.session.refresh(workflow_run) # Refresh to get latest state from session
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
self.session.commit()
|
||||
self.session.refresh(workflow_run) # Refresh again after commit
|
||||
|
||||
# Pause
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=state,
|
||||
)
|
||||
assert pause_entity is not None
|
||||
|
||||
# Verify pause
|
||||
self.session.expire_all() # Clear session to ensure fresh query
|
||||
self.session.refresh(workflow_run)
|
||||
|
||||
# Use the test session directly to verify the pause
|
||||
stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run.id)
|
||||
workflow_run_with_pause = self.session.scalar(stmt)
|
||||
pause_model = workflow_run_with_pause.pause
|
||||
|
||||
# Verify pause using test session directly
|
||||
assert pause_model is not None
|
||||
assert pause_model.id == pause_entity.id
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
# Load file content using storage directly
|
||||
file_content = storage.load(pause_model.state_object_key)
|
||||
if isinstance(file_content, bytes):
|
||||
file_content = file_content.decode()
|
||||
assert file_content == state
|
||||
|
||||
# Resume
|
||||
resumed_entity = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
assert resumed_entity is not None
|
||||
assert resumed_entity.resumed_at is not None
|
||||
|
||||
# Verify resume - check that pause is marked as resumed
|
||||
self.session.expire_all() # Clear session to ensure fresh query
|
||||
stmt = select(WorkflowPauseModel).where(WorkflowPauseModel.id == pause_entity.id)
|
||||
resumed_pause_model = self.session.scalar(stmt)
|
||||
assert resumed_pause_model is not None
|
||||
assert resumed_pause_model.resumed_at is not None
|
||||
|
||||
# Verify workflow run status
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
@ -1,324 +0,0 @@
|
||||
"""
|
||||
Unit tests for WorkflowResponseConverter focusing on process_data truncation functionality.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
|
||||
|
||||
class TestWorkflowResponseConverterCenarios:
|
||||
"""Test process_data truncation in WorkflowResponseConverter."""
|
||||
|
||||
def create_mock_generate_entity(self) -> WorkflowAppGenerateEntity:
|
||||
"""Create a mock WorkflowAppGenerateEntity."""
|
||||
mock_entity = Mock(spec=WorkflowAppGenerateEntity)
|
||||
mock_app_config = Mock()
|
||||
mock_app_config.tenant_id = "test-tenant-id"
|
||||
mock_entity.app_config = mock_app_config
|
||||
mock_entity.inputs = {}
|
||||
return mock_entity
|
||||
|
||||
def create_workflow_response_converter(self) -> WorkflowResponseConverter:
|
||||
"""Create a WorkflowResponseConverter for testing."""
|
||||
|
||||
mock_entity = self.create_mock_generate_entity()
|
||||
mock_user = Mock(spec=Account)
|
||||
mock_user.id = "test-user-id"
|
||||
mock_user.name = "Test User"
|
||||
mock_user.email = "test@example.com"
|
||||
|
||||
system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id")
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=mock_entity,
|
||||
user=mock_user,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent:
|
||||
"""Create a QueueNodeStartedEvent for testing."""
|
||||
return QueueNodeStartedEvent(
|
||||
node_execution_id=node_execution_id or str(uuid.uuid4()),
|
||||
node_id="test-node-id",
|
||||
node_title="Test Node",
|
||||
node_type=NodeType.CODE,
|
||||
start_at=naive_utc_now(),
|
||||
predecessor_node_id=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
provider_type="built-in",
|
||||
provider_id="code",
|
||||
)
|
||||
|
||||
def create_node_succeeded_event(
|
||||
self,
|
||||
*,
|
||||
node_execution_id: str,
|
||||
process_data: Mapping[str, Any] | None = None,
|
||||
) -> QueueNodeSucceededEvent:
|
||||
"""Create a QueueNodeSucceededEvent for testing."""
|
||||
return QueueNodeSucceededEvent(
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.CODE,
|
||||
node_execution_id=node_execution_id,
|
||||
start_at=naive_utc_now(),
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
inputs={},
|
||||
process_data=process_data or {},
|
||||
outputs={},
|
||||
execution_metadata={},
|
||||
)
|
||||
|
||||
def create_node_retry_event(
|
||||
self,
|
||||
*,
|
||||
node_execution_id: str,
|
||||
process_data: Mapping[str, Any] | None = None,
|
||||
) -> QueueNodeRetryEvent:
|
||||
"""Create a QueueNodeRetryEvent for testing."""
|
||||
return QueueNodeRetryEvent(
|
||||
inputs={"data": "inputs"},
|
||||
outputs={"data": "outputs"},
|
||||
process_data=process_data or {},
|
||||
error="oops",
|
||||
retry_index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.CODE,
|
||||
node_title="test code",
|
||||
provider_type="built-in",
|
||||
provider_id="code",
|
||||
node_execution_id=node_execution_id,
|
||||
start_at=naive_utc_now(),
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
|
||||
def test_workflow_node_finish_response_uses_truncated_process_data(self):
|
||||
"""Test that node finish response uses get_response_process_data()."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_succeeded_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=original_data,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
if mapping == dict(original_data):
|
||||
return truncated_data, True
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
# Response should use truncated data, not original
|
||||
assert response is not None
|
||||
assert response.data.process_data == truncated_data
|
||||
assert response.data.process_data != original_data
|
||||
assert response.data.process_data_truncated is True
|
||||
|
||||
def test_workflow_node_finish_response_without_truncation(self):
|
||||
"""Test node finish response when no truncation is applied."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_succeeded_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=original_data,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
# Response should use original data
|
||||
assert response is not None
|
||||
assert response.data.process_data == original_data
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_workflow_node_finish_response_with_none_process_data(self):
|
||||
"""Test node finish response when process_data is None."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_succeeded_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=None,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
# Response should normalize missing process_data to an empty mapping
|
||||
assert response is not None
|
||||
assert response.data.process_data == {}
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_workflow_node_retry_response_uses_truncated_process_data(self):
|
||||
"""Test that node retry response uses get_response_process_data()."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_retry_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=original_data,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
if mapping == dict(original_data):
|
||||
return truncated_data, True
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
# Response should use truncated data, not original
|
||||
assert response is not None
|
||||
assert response.data.process_data == truncated_data
|
||||
assert response.data.process_data != original_data
|
||||
assert response.data.process_data_truncated is True
|
||||
|
||||
def test_workflow_node_retry_response_without_truncation(self):
|
||||
"""Test node retry response when no truncation is applied."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_retry_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=original_data,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.data.process_data == original_data
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_iteration_and_loop_nodes_return_none(self):
|
||||
"""Test that iteration and loop nodes return None (no streaming events)."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
iteration_event = QueueNodeSucceededEvent(
|
||||
node_id="iteration-node",
|
||||
node_type=NodeType.ITERATION,
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
start_at=naive_utc_now(),
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
inputs={},
|
||||
process_data={},
|
||||
outputs={},
|
||||
execution_metadata={},
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=iteration_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
assert response is None
|
||||
|
||||
loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP})
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=loop_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
assert response is None
|
||||
|
||||
def test_finish_without_start_raises(self):
|
||||
"""Ensure finish responses require a prior workflow start."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
event = self.create_node_succeeded_event(
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
process_data={},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
@ -0,0 +1,810 @@
|
||||
"""
|
||||
Unit tests for WorkflowResponseConverter focusing on process_data truncation functionality.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestWorkflowResponseConverter:
|
||||
"""Test truncation in WorkflowResponseConverter."""
|
||||
|
||||
def create_mock_generate_entity(self) -> WorkflowAppGenerateEntity:
|
||||
"""Create a mock WorkflowAppGenerateEntity."""
|
||||
mock_entity = Mock(spec=WorkflowAppGenerateEntity)
|
||||
mock_app_config = Mock()
|
||||
mock_app_config.tenant_id = "test-tenant-id"
|
||||
mock_entity.invoke_from = InvokeFrom.WEB_APP
|
||||
mock_entity.app_config = mock_app_config
|
||||
mock_entity.inputs = {}
|
||||
return mock_entity
|
||||
|
||||
def create_workflow_response_converter(self) -> WorkflowResponseConverter:
|
||||
"""Create a WorkflowResponseConverter for testing."""
|
||||
|
||||
mock_entity = self.create_mock_generate_entity()
|
||||
mock_user = Mock(spec=Account)
|
||||
mock_user.id = "test-user-id"
|
||||
mock_user.name = "Test User"
|
||||
mock_user.email = "test@example.com"
|
||||
|
||||
system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id")
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=mock_entity,
|
||||
user=mock_user,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent:
|
||||
"""Create a QueueNodeStartedEvent for testing."""
|
||||
return QueueNodeStartedEvent(
|
||||
node_execution_id=node_execution_id or str(uuid.uuid4()),
|
||||
node_id="test-node-id",
|
||||
node_title="Test Node",
|
||||
node_type=NodeType.CODE,
|
||||
start_at=naive_utc_now(),
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
provider_type="built-in",
|
||||
provider_id="code",
|
||||
)
|
||||
|
||||
def create_node_succeeded_event(
|
||||
self,
|
||||
*,
|
||||
node_execution_id: str,
|
||||
process_data: Mapping[str, Any] | None = None,
|
||||
) -> QueueNodeSucceededEvent:
|
||||
"""Create a QueueNodeSucceededEvent for testing."""
|
||||
return QueueNodeSucceededEvent(
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.CODE,
|
||||
node_execution_id=node_execution_id,
|
||||
start_at=naive_utc_now(),
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
inputs={},
|
||||
process_data=process_data or {},
|
||||
outputs={},
|
||||
execution_metadata={},
|
||||
)
|
||||
|
||||
def create_node_retry_event(
|
||||
self,
|
||||
*,
|
||||
node_execution_id: str,
|
||||
process_data: Mapping[str, Any] | None = None,
|
||||
) -> QueueNodeRetryEvent:
|
||||
"""Create a QueueNodeRetryEvent for testing."""
|
||||
return QueueNodeRetryEvent(
|
||||
inputs={"data": "inputs"},
|
||||
outputs={"data": "outputs"},
|
||||
process_data=process_data or {},
|
||||
error="oops",
|
||||
retry_index=1,
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.CODE,
|
||||
node_title="test code",
|
||||
provider_type="built-in",
|
||||
provider_id="code",
|
||||
node_execution_id=node_execution_id,
|
||||
start_at=naive_utc_now(),
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
|
||||
def test_workflow_node_finish_response_uses_truncated_process_data(self):
|
||||
"""Test that node finish response uses get_response_process_data()."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_succeeded_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=original_data,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
if mapping == dict(original_data):
|
||||
return truncated_data, True
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
# Response should use truncated data, not original
|
||||
assert response is not None
|
||||
assert response.data.process_data == truncated_data
|
||||
assert response.data.process_data != original_data
|
||||
assert response.data.process_data_truncated is True
|
||||
|
||||
def test_workflow_node_finish_response_without_truncation(self):
|
||||
"""Test node finish response when no truncation is applied."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_succeeded_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=original_data,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
# Response should use original data
|
||||
assert response is not None
|
||||
assert response.data.process_data == original_data
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_workflow_node_finish_response_with_none_process_data(self):
|
||||
"""Test node finish response when process_data is None."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_succeeded_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=None,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
# Response should normalize missing process_data to an empty mapping
|
||||
assert response is not None
|
||||
assert response.data.process_data == {}
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_workflow_node_retry_response_uses_truncated_process_data(self):
|
||||
"""Test that node retry response uses get_response_process_data()."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_retry_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=original_data,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
if mapping == dict(original_data):
|
||||
return truncated_data, True
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
# Response should use truncated data, not original
|
||||
assert response is not None
|
||||
assert response.data.process_data == truncated_data
|
||||
assert response.data.process_data != original_data
|
||||
assert response.data.process_data_truncated is True
|
||||
|
||||
def test_workflow_node_retry_response_without_truncation(self):
|
||||
"""Test node retry response when no truncation is applied."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
event = self.create_node_retry_event(
|
||||
node_execution_id=start_event.node_execution_id,
|
||||
process_data=original_data,
|
||||
)
|
||||
|
||||
def fake_truncate(mapping):
|
||||
return mapping, False
|
||||
|
||||
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
|
||||
|
||||
response = converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.data.process_data == original_data
|
||||
assert response.data.process_data_truncated is False
|
||||
|
||||
def test_iteration_and_loop_nodes_return_none(self):
|
||||
"""Test that iteration and loop nodes return None (no streaming events)."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
iteration_event = QueueNodeSucceededEvent(
|
||||
node_id="iteration-node",
|
||||
node_type=NodeType.ITERATION,
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
start_at=naive_utc_now(),
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
inputs={},
|
||||
process_data={},
|
||||
outputs={},
|
||||
execution_metadata={},
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=iteration_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
assert response is None
|
||||
|
||||
loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP})
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=loop_event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
assert response is None
|
||||
|
||||
def test_finish_without_start_raises(self):
|
||||
"""Ensure finish responses require a prior workflow start."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
event = self.create_node_succeeded_event(
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
process_data={},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test-task-id",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestCase:
|
||||
"""Test case data for table-driven tests."""
|
||||
|
||||
name: str
|
||||
invoke_from: InvokeFrom
|
||||
expected_truncation_enabled: bool
|
||||
description: str
|
||||
|
||||
|
||||
class TestWorkflowResponseConverterServiceApiTruncation:
|
||||
"""Test class for Service API truncation functionality in WorkflowResponseConverter."""
|
||||
|
||||
def create_test_app_generate_entity(self, invoke_from: InvokeFrom) -> WorkflowAppGenerateEntity:
|
||||
"""Create a test WorkflowAppGenerateEntity with specified invoke_from."""
|
||||
# Create a minimal WorkflowUIBasedAppConfig for testing
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="test_workflow_id",
|
||||
)
|
||||
|
||||
entity = WorkflowAppGenerateEntity(
|
||||
task_id="test_task_id",
|
||||
app_id="test_app_id",
|
||||
app_config=app_config,
|
||||
tenant_id="test_tenant",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
invoke_from=invoke_from,
|
||||
inputs={"test_input": "test_value"},
|
||||
user_id="test_user_id",
|
||||
stream=True,
|
||||
files=[],
|
||||
workflow_execution_id="test_workflow_exec_id",
|
||||
)
|
||||
return entity
|
||||
|
||||
def create_test_user(self) -> Account:
|
||||
"""Create a test user account."""
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
)
|
||||
# Manually set the ID for testing purposes
|
||||
account.id = "test_user_id"
|
||||
return account
|
||||
|
||||
def create_test_system_variables(self) -> SystemVariable:
|
||||
"""Create test system variables."""
|
||||
return SystemVariable()
|
||||
|
||||
def create_test_converter(self, invoke_from: InvokeFrom) -> WorkflowResponseConverter:
|
||||
"""Create WorkflowResponseConverter with specified invoke_from."""
|
||||
entity = self.create_test_app_generate_entity(invoke_from)
|
||||
user = self.create_test_user()
|
||||
system_variables = self.create_test_system_variables()
|
||||
|
||||
converter = WorkflowResponseConverter(
|
||||
application_generate_entity=entity,
|
||||
user=user,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
# ensure `workflow_run_id` is set.
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="test-task-id",
|
||||
workflow_run_id="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
)
|
||||
return converter
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
TestCase(
|
||||
name="service_api_truncation_disabled",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
expected_truncation_enabled=False,
|
||||
description="Service API calls should have truncation disabled",
|
||||
),
|
||||
TestCase(
|
||||
name="web_app_truncation_enabled",
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
expected_truncation_enabled=True,
|
||||
description="Web app calls should have truncation enabled",
|
||||
),
|
||||
TestCase(
|
||||
name="debugger_truncation_enabled",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
expected_truncation_enabled=True,
|
||||
description="Debugger calls should have truncation enabled",
|
||||
),
|
||||
TestCase(
|
||||
name="explore_truncation_enabled",
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
expected_truncation_enabled=True,
|
||||
description="Explore calls should have truncation enabled",
|
||||
),
|
||||
TestCase(
|
||||
name="published_truncation_enabled",
|
||||
invoke_from=InvokeFrom.PUBLISHED,
|
||||
expected_truncation_enabled=True,
|
||||
description="Published app calls should have truncation enabled",
|
||||
),
|
||||
],
|
||||
ids=lambda x: x.name,
|
||||
)
|
||||
def test_truncator_selection_based_on_invoke_from(self, test_case: TestCase):
|
||||
"""Test that the correct truncator is selected based on invoke_from."""
|
||||
converter = self.create_test_converter(test_case.invoke_from)
|
||||
|
||||
# Test truncation behavior instead of checking private attribute
|
||||
|
||||
# Create a test event with large data
|
||||
large_value = {"key": ["x"] * 2000} # Large data that would be truncated
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=large_value,
|
||||
process_data=large_value,
|
||||
outputs=large_value,
|
||||
error=None,
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test_task",
|
||||
)
|
||||
|
||||
# Verify response is not None
|
||||
assert response is not None
|
||||
|
||||
# Verify truncation behavior matches expectations
|
||||
if test_case.expected_truncation_enabled:
|
||||
# Truncation should be enabled for non-service-api calls
|
||||
assert response.data.inputs_truncated
|
||||
assert response.data.process_data_truncated
|
||||
assert response.data.outputs_truncated
|
||||
else:
|
||||
# SERVICE_API should not truncate
|
||||
assert not response.data.inputs_truncated
|
||||
assert not response.data.process_data_truncated
|
||||
assert not response.data.outputs_truncated
|
||||
|
||||
def test_service_api_truncator_no_op_mapping(self):
|
||||
"""Test that Service API truncator doesn't truncate variable mappings."""
|
||||
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||
|
||||
# Create a test event with large data
|
||||
large_value: dict[str, Any] = {
|
||||
"large_string": "x" * 10000, # Large string
|
||||
"large_list": list(range(2000)), # Large array
|
||||
"nested_data": {"deep_nested": {"very_deep": {"value": "x" * 5000}}},
|
||||
}
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=large_value,
|
||||
process_data=large_value,
|
||||
outputs=large_value,
|
||||
error=None,
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test_task",
|
||||
)
|
||||
|
||||
# Verify response is not None
|
||||
data = response.data
|
||||
assert data.inputs == large_value
|
||||
assert data.process_data == large_value
|
||||
assert data.outputs == large_value
|
||||
# Service API should not truncate
|
||||
assert data.inputs_truncated is False
|
||||
assert data.process_data_truncated is False
|
||||
assert data.outputs_truncated is False
|
||||
|
||||
def test_web_app_truncator_works_normally(self):
|
||||
"""Test that web app truncator still works normally."""
|
||||
converter = self.create_test_converter(InvokeFrom.WEB_APP)
|
||||
|
||||
# Create a test event with large data
|
||||
large_value = {
|
||||
"large_string": "x" * 10000, # Large string
|
||||
"large_list": list(range(2000)), # Large array
|
||||
}
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=large_value,
|
||||
process_data=large_value,
|
||||
outputs=large_value,
|
||||
error=None,
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test_task",
|
||||
)
|
||||
|
||||
# Verify response is not None
|
||||
assert response is not None
|
||||
|
||||
# Web app should truncate
|
||||
data = response.data
|
||||
assert data.inputs != large_value
|
||||
assert data.process_data != large_value
|
||||
assert data.outputs != large_value
|
||||
# The exact behavior depends on VariableTruncator implementation
|
||||
# Just verify that truncation flags are present
|
||||
assert data.inputs_truncated is True
|
||||
assert data.process_data_truncated is True
|
||||
assert data.outputs_truncated is True
|
||||
|
||||
@staticmethod
|
||||
def _create_event_by_type(
|
||||
type_: QueueEvent, inputs: Mapping[str, Any], process_data: Mapping[str, Any], outputs: Mapping[str, Any]
|
||||
) -> QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent:
|
||||
if type_ == QueueEvent.NODE_SUCCEEDED:
|
||||
return QueueNodeSucceededEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
error=None,
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
elif type_ == QueueEvent.NODE_FAILED:
|
||||
return QueueNodeFailedEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
error="oops",
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
elif type_ == QueueEvent.NODE_EXCEPTION:
|
||||
return QueueNodeExceptionEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
error="oops",
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
else:
|
||||
raise Exception("unknown type.")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"event_type",
|
||||
[
|
||||
QueueEvent.NODE_SUCCEEDED,
|
||||
QueueEvent.NODE_FAILED,
|
||||
QueueEvent.NODE_EXCEPTION,
|
||||
],
|
||||
)
|
||||
def test_service_api_node_finish_event_no_truncation(self, event_type: QueueEvent):
|
||||
"""Test that Service API doesn't truncate node finish events."""
|
||||
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||
# Create test event with large data
|
||||
large_inputs = {"input1": "x" * 5000, "input2": list(range(2000))}
|
||||
large_process_data = {"process1": "y" * 5000, "process2": {"nested": ["z"] * 2000}}
|
||||
large_outputs = {"output1": "result" * 1000, "output2": list(range(2000))}
|
||||
|
||||
event = TestWorkflowResponseConverterServiceApiTruncation._create_event_by_type(
|
||||
event_type, large_inputs, large_process_data, large_outputs
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test_task",
|
||||
)
|
||||
|
||||
# Verify response is not None
|
||||
assert response is not None
|
||||
|
||||
# Verify response contains full data (not truncated)
|
||||
assert response.data.inputs == large_inputs
|
||||
assert response.data.process_data == large_process_data
|
||||
assert response.data.outputs == large_outputs
|
||||
assert not response.data.inputs_truncated
|
||||
assert not response.data.process_data_truncated
|
||||
assert not response.data.outputs_truncated
|
||||
|
||||
def test_service_api_node_retry_event_no_truncation(self):
|
||||
"""Test that Service API doesn't truncate node retry events."""
|
||||
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||
|
||||
# Create test event with large data
|
||||
large_inputs = {"retry_input": "x" * 5000}
|
||||
large_process_data = {"retry_process": "y" * 5000}
|
||||
large_outputs = {"retry_output": "z" * 5000}
|
||||
|
||||
# First, we need to store a snapshot by simulating a start event
|
||||
start_event = QueueNodeStartedEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="Test Node",
|
||||
node_run_index=1,
|
||||
start_at=naive_utc_now(),
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
agent_strategy=None,
|
||||
provider_type="plugin",
|
||||
provider_id="test/test_plugin",
|
||||
)
|
||||
converter.workflow_node_start_to_stream_response(event=start_event, task_id="test_task")
|
||||
|
||||
# Now create retry event
|
||||
event = QueueNodeRetryEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="Test Node",
|
||||
node_run_index=1,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=large_inputs,
|
||||
process_data=large_process_data,
|
||||
outputs=large_outputs,
|
||||
error="Retry error",
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
retry_index=1,
|
||||
provider_type="plugin",
|
||||
provider_id="test/test_plugin",
|
||||
)
|
||||
|
||||
response = converter.workflow_node_retry_to_stream_response(
|
||||
event=event,
|
||||
task_id="test_task",
|
||||
)
|
||||
|
||||
# Verify response is not None
|
||||
assert response is not None
|
||||
|
||||
# Verify response contains full data (not truncated)
|
||||
assert response.data.inputs == large_inputs
|
||||
assert response.data.process_data == large_process_data
|
||||
assert response.data.outputs == large_outputs
|
||||
assert not response.data.inputs_truncated
|
||||
assert not response.data.process_data_truncated
|
||||
assert not response.data.outputs_truncated
|
||||
|
||||
def test_service_api_iteration_events_no_truncation(self):
|
||||
"""Test that Service API doesn't truncate iteration events."""
|
||||
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||
|
||||
# Test iteration start event
|
||||
large_value = {"iteration_input": ["x"] * 2000}
|
||||
|
||||
start_event = QueueIterationStartEvent(
|
||||
node_execution_id="test_iter_exec_id",
|
||||
node_id="test_iteration",
|
||||
node_type=NodeType.ITERATION,
|
||||
node_title="Test Iteration",
|
||||
node_run_index=0,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=large_value,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
response = converter.workflow_iteration_start_to_stream_response(
|
||||
task_id="test_task",
|
||||
workflow_execution_id="test_workflow_exec_id",
|
||||
event=start_event,
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.data.inputs == large_value
|
||||
assert not response.data.inputs_truncated
|
||||
|
||||
def test_service_api_loop_events_no_truncation(self):
|
||||
"""Test that Service API doesn't truncate loop events."""
|
||||
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
|
||||
|
||||
# Test loop start event
|
||||
large_inputs = {"loop_input": ["x"] * 2000}
|
||||
|
||||
start_event = QueueLoopStartEvent(
|
||||
node_execution_id="test_loop_exec_id",
|
||||
node_id="test_loop",
|
||||
node_type=NodeType.LOOP,
|
||||
node_title="Test Loop",
|
||||
start_at=naive_utc_now(),
|
||||
inputs=large_inputs,
|
||||
metadata={},
|
||||
node_run_index=0,
|
||||
)
|
||||
|
||||
response = converter.workflow_loop_start_to_stream_response(
|
||||
task_id="test_task",
|
||||
workflow_execution_id="test_workflow_exec_id",
|
||||
event=start_event,
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.data.inputs == large_inputs
|
||||
assert not response.data.inputs_truncated
|
||||
|
||||
def test_web_app_node_finish_event_truncation_works(self):
|
||||
"""Test that web app still truncates node finish events."""
|
||||
converter = self.create_test_converter(InvokeFrom.WEB_APP)
|
||||
|
||||
# Create test event with large data that should be truncated
|
||||
large_inputs = {"input1": ["x"] * 2000}
|
||||
large_process_data = {"process1": ["y"] * 2000}
|
||||
large_outputs = {"output1": ["z"] * 2000}
|
||||
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test_node_exec_id",
|
||||
node_id="test_node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=naive_utc_now(),
|
||||
inputs=large_inputs,
|
||||
process_data=large_process_data,
|
||||
outputs=large_outputs,
|
||||
error=None,
|
||||
execution_metadata=None,
|
||||
in_iteration_id=None,
|
||||
in_loop_id=None,
|
||||
)
|
||||
|
||||
response = converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id="test_task",
|
||||
)
|
||||
|
||||
# Verify response is not None
|
||||
assert response is not None
|
||||
|
||||
# Verify response contains truncated data
|
||||
# The exact behavior depends on VariableTruncator implementation
|
||||
# Just verify truncation flags are set correctly (may or may not be truncated depending on size)
|
||||
# At minimum, the truncation mechanism should work
|
||||
assert isinstance(response.data.inputs, dict)
|
||||
assert response.data.inputs_truncated
|
||||
assert isinstance(response.data.process_data, dict)
|
||||
assert response.data.process_data_truncated
|
||||
assert isinstance(response.data.outputs, dict)
|
||||
assert response.data.outputs_truncated
|
||||
@ -0,0 +1,278 @@
|
||||
import json
|
||||
from time import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||
from core.workflow.graph_events.graph import (
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
class TestDataFactory:
|
||||
"""Factory helpers for constructing graph events used in tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
|
||||
return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_started_event() -> GraphRunStartedEvent:
|
||||
return GraphRunStartedEvent()
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_succeeded_event(outputs: dict[str, object] | None = None) -> GraphRunSucceededEvent:
|
||||
return GraphRunSucceededEvent(outputs=outputs or {})
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_failed_event(
|
||||
error: str = "Test error",
|
||||
exceptions_count: int = 1,
|
||||
) -> GraphRunFailedEvent:
|
||||
return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count)
|
||||
|
||||
|
||||
class MockSystemVariableReadOnlyView:
|
||||
"""Minimal read-only system variable view for testing."""
|
||||
|
||||
def __init__(self, workflow_execution_id: str | None = None) -> None:
|
||||
self._workflow_execution_id = workflow_execution_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str | None:
|
||||
return self._workflow_execution_id
|
||||
|
||||
|
||||
class MockReadOnlyVariablePool:
|
||||
"""Mock implementation of ReadOnlyVariablePool for testing."""
|
||||
|
||||
def __init__(self, variables: dict[tuple[str, str], object] | None = None):
|
||||
self._variables = variables or {}
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
value = self._variables.get((node_id, variable_key))
|
||||
if value is None:
|
||||
return None
|
||||
mock_segment = Mock(spec=Segment)
|
||||
mock_segment.value = value
|
||||
return mock_segment
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> dict[str, object]:
|
||||
return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
|
||||
|
||||
def get_by_prefix(self, prefix: str) -> dict[str, object]:
|
||||
return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)}
|
||||
|
||||
|
||||
class MockReadOnlyGraphRuntimeState:
|
||||
"""Mock implementation of ReadOnlyGraphRuntimeState for testing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_at: float | None = None,
|
||||
total_tokens: int = 0,
|
||||
node_run_steps: int = 0,
|
||||
ready_queue_size: int = 0,
|
||||
exceptions_count: int = 0,
|
||||
outputs: dict[str, object] | None = None,
|
||||
variables: dict[tuple[str, str], object] | None = None,
|
||||
workflow_execution_id: str | None = None,
|
||||
):
|
||||
self._start_at = start_at or time()
|
||||
self._total_tokens = total_tokens
|
||||
self._node_run_steps = node_run_steps
|
||||
self._ready_queue_size = ready_queue_size
|
||||
self._exceptions_count = exceptions_count
|
||||
self._outputs = outputs or {}
|
||||
self._variable_pool = MockReadOnlyVariablePool(variables)
|
||||
self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id)
|
||||
|
||||
@property
|
||||
def system_variable(self) -> MockSystemVariableReadOnlyView:
|
||||
return self._system_variable
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePool:
|
||||
return self._variable_pool
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
return self._start_at
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
return self._total_tokens
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
return self._node_run_steps
|
||||
|
||||
@property
|
||||
def ready_queue_size(self) -> int:
|
||||
return self._ready_queue_size
|
||||
|
||||
@property
|
||||
def exceptions_count(self) -> int:
|
||||
return self._exceptions_count
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, object]:
|
||||
return self._outputs.copy()
|
||||
|
||||
@property
|
||||
def llm_usage(self):
|
||||
mock_usage = Mock()
|
||||
mock_usage.prompt_tokens = 10
|
||||
mock_usage.completion_tokens = 20
|
||||
mock_usage.total_tokens = 30
|
||||
return mock_usage
|
||||
|
||||
def get_output(self, key: str, default: object = None) -> object:
|
||||
return self._outputs.get(key, default)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"start_at": self._start_at,
|
||||
"total_tokens": self._total_tokens,
|
||||
"node_run_steps": self._node_run_steps,
|
||||
"ready_queue_size": self._ready_queue_size,
|
||||
"exceptions_count": self._exceptions_count,
|
||||
"outputs": self._outputs,
|
||||
"variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()},
|
||||
"workflow_execution_id": self._system_variable.workflow_execution_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class MockCommandChannel:
|
||||
"""Mock implementation of CommandChannel for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self._commands: list[GraphEngineCommand] = []
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
return self._commands.copy()
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
self._commands.append(command)
|
||||
|
||||
|
||||
class TestPauseStatePersistenceLayer:
|
||||
"""Unit tests for PauseStatePersistenceLayer."""
|
||||
|
||||
def test_init_with_dependency_injection(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
state_owner_user_id = "user-123"
|
||||
|
||||
layer = PauseStatePersistenceLayer(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=state_owner_user_id,
|
||||
)
|
||||
|
||||
assert layer._session_maker is session_factory
|
||||
assert layer._state_owner_user_id == state_owner_user_id
|
||||
assert not hasattr(layer, "graph_runtime_state")
|
||||
assert not hasattr(layer, "command_channel")
|
||||
|
||||
def test_initialize_sets_dependencies(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner")
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState()
|
||||
command_channel = MockCommandChannel()
|
||||
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
assert layer.graph_runtime_state is graph_runtime_state
|
||||
assert layer.command_channel is command_channel
|
||||
|
||||
def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState(
|
||||
outputs={"result": "test_output"},
|
||||
total_tokens=100,
|
||||
workflow_execution_id="run-123",
|
||||
)
|
||||
command_channel = MockCommandChannel()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event(outputs={"intermediate": "result"})
|
||||
expected_state = graph_runtime_state.dumps()
|
||||
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_called_once_with(session_factory)
|
||||
mock_repo.create_workflow_pause.assert_called_once_with(
|
||||
workflow_run_id="run-123",
|
||||
state_owner_user_id="owner-123",
|
||||
state=expected_state,
|
||||
)
|
||||
|
||||
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState()
|
||||
command_channel = MockCommandChannel()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
events = [
|
||||
TestDataFactory.create_graph_run_started_event(),
|
||||
TestDataFactory.create_graph_run_succeeded_event(),
|
||||
TestDataFactory.create_graph_run_failed_event(),
|
||||
]
|
||||
|
||||
for event in events:
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_not_called()
|
||||
mock_repo.create_workflow_pause.assert_not_called()
|
||||
|
||||
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event()
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
layer.on_event(event)
|
||||
|
||||
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState(workflow_execution_id=None)
|
||||
command_channel = MockCommandChannel()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_not_called()
|
||||
mock_repo.create_workflow_pause.assert_not_called()
|
||||
@ -23,3 +23,32 @@ def test_file():
|
||||
assert file.extension == ".png"
|
||||
assert file.mime_type == "image/png"
|
||||
assert file.size == 67
|
||||
|
||||
|
||||
def test_file_model_validate_with_legacy_fields():
|
||||
"""Test `File` model can handle data containing compatibility fields."""
|
||||
data = {
|
||||
"id": "test-file",
|
||||
"tenant_id": "test-tenant-id",
|
||||
"type": "image",
|
||||
"transfer_method": "tool_file",
|
||||
"related_id": "test-related-id",
|
||||
"filename": "image.png",
|
||||
"extension": ".png",
|
||||
"mime_type": "image/png",
|
||||
"size": 67,
|
||||
"storage_key": "test-storage-key",
|
||||
"url": "https://example.com/image.png",
|
||||
# Extra legacy fields
|
||||
"tool_file_id": "tool-file-123",
|
||||
"upload_file_id": "upload-file-456",
|
||||
"datasource_file_id": "datasource-file-789",
|
||||
}
|
||||
|
||||
# Should be able to create `File` object without raising an exception
|
||||
file = File.model_validate(data)
|
||||
|
||||
# The File object does not have tool_file_id, upload_file_id, or datasource_file_id as attributes.
|
||||
# Instead, check it does not expose unrecognized legacy fields (should raise on getattr).
|
||||
for legacy_field in ("tool_file_id", "upload_file_id", "datasource_file_id"):
|
||||
assert not hasattr(file, legacy_field)
|
||||
|
||||
@ -0,0 +1,12 @@
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
|
||||
|
||||
|
||||
def test_get_runner_script():
|
||||
code = JavascriptCodeProvider.get_default_code()
|
||||
inputs = {"arg1": "hello, ", "arg2": "world!"}
|
||||
script = NodeJsTemplateTransformer.assemble_runner_script(code, inputs)
|
||||
script_lines = script.splitlines()
|
||||
code_lines = code.splitlines()
|
||||
# Check that the first lines of script are exactly the same as code
|
||||
assert script_lines[: len(code_lines)] == code_lines
|
||||
@ -0,0 +1,12 @@
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
|
||||
|
||||
|
||||
def test_get_runner_script():
|
||||
code = Python3CodeProvider.get_default_code()
|
||||
inputs = {"arg1": "hello, ", "arg2": "world!"}
|
||||
script = Python3TemplateTransformer.assemble_runner_script(code, inputs)
|
||||
script_lines = script.splitlines()
|
||||
code_lines = code.splitlines()
|
||||
# Check that the first lines of script are exactly the same as code
|
||||
assert script_lines[: len(code_lines)] == code_lines
|
||||
@ -395,9 +395,6 @@ def test_client_capabilities_default():
|
||||
|
||||
# Assert default capabilities
|
||||
assert received_capabilities is not None
|
||||
assert received_capabilities.sampling is not None
|
||||
assert received_capabilities.roots is not None
|
||||
assert received_capabilities.roots.listChanged is True
|
||||
|
||||
|
||||
def test_client_capabilities_with_custom_callbacks():
|
||||
|
||||
@ -0,0 +1,171 @@
|
||||
"""Tests for _PrivateWorkflowPauseEntity implementation."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
|
||||
|
||||
|
||||
class TestPrivateWorkflowPauseEntity:
|
||||
"""Test _PrivateWorkflowPauseEntity implementation."""
|
||||
|
||||
def test_entity_initialization(self):
|
||||
"""Test entity initialization with required parameters."""
|
||||
# Create mock models
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
mock_pause_model.resumed_at = None
|
||||
|
||||
# Create entity
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# Verify initialization
|
||||
assert entity._pause_model is mock_pause_model
|
||||
assert entity._cached_state is None
|
||||
|
||||
def test_from_models_classmethod(self):
|
||||
"""Test from_models class method."""
|
||||
# Create mock models
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
|
||||
# Create entity using from_models
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(
|
||||
workflow_pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# Verify entity creation
|
||||
assert isinstance(entity, _PrivateWorkflowPauseEntity)
|
||||
assert entity._pause_model is mock_pause_model
|
||||
|
||||
def test_id_property(self):
|
||||
"""Test id property returns pause model ID."""
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.id = "pause-123"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
assert entity.id == "pause-123"
|
||||
|
||||
def test_workflow_execution_id_property(self):
|
||||
"""Test workflow_execution_id property returns workflow run ID."""
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.workflow_run_id = "execution-456"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
assert entity.workflow_execution_id == "execution-456"
|
||||
|
||||
def test_resumed_at_property(self):
|
||||
"""Test resumed_at property returns pause model resumed_at."""
|
||||
resumed_at = datetime(2023, 12, 25, 15, 30, 45)
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.resumed_at = resumed_at
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
assert entity.resumed_at == resumed_at
|
||||
|
||||
def test_resumed_at_property_none(self):
|
||||
"""Test resumed_at property returns None when not set."""
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.resumed_at = None
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
assert entity.resumed_at is None
|
||||
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||
def test_get_state_first_call(self, mock_storage):
|
||||
"""Test get_state loads from storage on first call."""
|
||||
state_data = b'{"test": "data", "step": 5}'
|
||||
mock_storage.load.return_value = state_data
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.state_object_key = "test-state-key"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# First call should load from storage
|
||||
result = entity.get_state()
|
||||
|
||||
assert result == state_data
|
||||
mock_storage.load.assert_called_once_with("test-state-key")
|
||||
assert entity._cached_state == state_data
|
||||
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||
def test_get_state_cached_call(self, mock_storage):
|
||||
"""Test get_state returns cached data on subsequent calls."""
|
||||
state_data = b'{"test": "data", "step": 5}'
|
||||
mock_storage.load.return_value = state_data
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
mock_pause_model.state_object_key = "test-state-key"
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# First call
|
||||
result1 = entity.get_state()
|
||||
# Second call should use cache
|
||||
result2 = entity.get_state()
|
||||
|
||||
assert result1 == state_data
|
||||
assert result2 == state_data
|
||||
# Storage should only be called once
|
||||
mock_storage.load.assert_called_once_with("test-state-key")
|
||||
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||
def test_get_state_with_pre_cached_data(self, mock_storage):
|
||||
"""Test get_state returns pre-cached data."""
|
||||
state_data = b'{"test": "data", "step": 5}'
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
# Pre-cache data
|
||||
entity._cached_state = state_data
|
||||
|
||||
# Should return cached data without calling storage
|
||||
result = entity.get_state()
|
||||
|
||||
assert result == state_data
|
||||
mock_storage.load.assert_not_called()
|
||||
|
||||
def test_entity_with_binary_state_data(self):
|
||||
"""Test entity with binary state data."""
|
||||
# Test with binary data that's not valid JSON
|
||||
binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
mock_storage.load.return_value = binary_data
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(
|
||||
pause_model=mock_pause_model,
|
||||
)
|
||||
|
||||
result = entity.get_state()
|
||||
|
||||
assert result == binary_data
|
||||
@ -3,6 +3,7 @@
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
@ -149,8 +150,8 @@ def test_pause_command():
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
|
||||
assert len(pause_events) == 1
|
||||
assert pause_events[0].reason == "User requested pause"
|
||||
assert pause_events[0].reason == SchedulingPause(message="User requested pause")
|
||||
|
||||
graph_execution = engine.graph_runtime_state.graph_execution
|
||||
assert graph_execution.is_paused
|
||||
assert graph_execution.pause_reason == "User requested pause"
|
||||
assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")
|
||||
|
||||
@ -0,0 +1,96 @@
|
||||
"""
|
||||
Test cases for the Iteration node's flatten_output functionality.
|
||||
|
||||
This module tests the iteration node's ability to:
|
||||
1. Flatten array outputs when flatten_output=True (default)
|
||||
2. Preserve nested array structure when flatten_output=False
|
||||
"""
|
||||
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def test_iteration_with_flatten_output_enabled():
|
||||
"""
|
||||
Test iteration node with flatten_output=True (default behavior).
|
||||
|
||||
The fixture implements an iteration that:
|
||||
1. Iterates over [1, 2, 3]
|
||||
2. For each item, outputs [item, item*2]
|
||||
3. With flatten_output=True, should output [1, 2, 2, 4, 3, 6]
|
||||
"""
|
||||
runner = TableTestRunner()
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="iteration_flatten_output_enabled_workflow",
|
||||
inputs={},
|
||||
expected_outputs={"output": [1, 2, 2, 4, 3, 6]},
|
||||
description="Iteration with flatten_output=True flattens nested arrays",
|
||||
use_auto_mock=False, # Run code nodes directly
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
assert result.success, f"Test failed: {result.error}"
|
||||
assert result.actual_outputs is not None, "Should have outputs"
|
||||
assert result.actual_outputs == {"output": [1, 2, 2, 4, 3, 6]}, (
|
||||
f"Expected flattened output [1, 2, 2, 4, 3, 6], got {result.actual_outputs}"
|
||||
)
|
||||
|
||||
|
||||
def test_iteration_with_flatten_output_disabled():
|
||||
"""
|
||||
Test iteration node with flatten_output=False.
|
||||
|
||||
The fixture implements an iteration that:
|
||||
1. Iterates over [1, 2, 3]
|
||||
2. For each item, outputs [item, item*2]
|
||||
3. With flatten_output=False, should output [[1, 2], [2, 4], [3, 6]]
|
||||
"""
|
||||
runner = TableTestRunner()
|
||||
|
||||
test_case = WorkflowTestCase(
|
||||
fixture_path="iteration_flatten_output_disabled_workflow",
|
||||
inputs={},
|
||||
expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]},
|
||||
description="Iteration with flatten_output=False preserves nested structure",
|
||||
use_auto_mock=False, # Run code nodes directly
|
||||
)
|
||||
|
||||
result = runner.run_test_case(test_case)
|
||||
|
||||
assert result.success, f"Test failed: {result.error}"
|
||||
assert result.actual_outputs is not None, "Should have outputs"
|
||||
assert result.actual_outputs == {"output": [[1, 2], [2, 4], [3, 6]]}, (
|
||||
f"Expected nested output [[1, 2], [2, 4], [3, 6]], got {result.actual_outputs}"
|
||||
)
|
||||
|
||||
|
||||
def test_iteration_flatten_output_comparison():
|
||||
"""
|
||||
Run both flatten_output configurations in parallel to verify the difference.
|
||||
"""
|
||||
runner = TableTestRunner()
|
||||
|
||||
test_cases = [
|
||||
WorkflowTestCase(
|
||||
fixture_path="iteration_flatten_output_enabled_workflow",
|
||||
inputs={},
|
||||
expected_outputs={"output": [1, 2, 2, 4, 3, 6]},
|
||||
description="flatten_output=True: Flattened output",
|
||||
use_auto_mock=False, # Run code nodes directly
|
||||
),
|
||||
WorkflowTestCase(
|
||||
fixture_path="iteration_flatten_output_disabled_workflow",
|
||||
inputs={},
|
||||
expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]},
|
||||
description="flatten_output=False: Nested output",
|
||||
use_auto_mock=False, # Run code nodes directly
|
||||
),
|
||||
]
|
||||
|
||||
suite_result = runner.run_table_tests(test_cases, parallel=True)
|
||||
|
||||
# Assert all tests passed
|
||||
assert suite_result.passed_tests == 2, f"Expected 2 passed tests, got {suite_result.passed_tests}"
|
||||
assert suite_result.failed_tests == 0, f"Expected 0 failed tests, got {suite_result.failed_tests}"
|
||||
assert suite_result.success_rate == 100.0, f"Expected 100% success rate, got {suite_result.success_rate}"
|
||||
32
api/tests/unit_tests/core/workflow/test_enums.py
Normal file
32
api/tests/unit_tests/core/workflow/test_enums.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""Tests for workflow pause related enums and constants."""
|
||||
|
||||
from core.workflow.enums import (
|
||||
WorkflowExecutionStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkflowExecutionStatus:
|
||||
"""Test WorkflowExecutionStatus enum."""
|
||||
|
||||
def test_is_ended_method(self):
|
||||
"""Test is_ended method for different statuses."""
|
||||
# Test ended statuses
|
||||
ended_statuses = [
|
||||
WorkflowExecutionStatus.SUCCEEDED,
|
||||
WorkflowExecutionStatus.FAILED,
|
||||
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
WorkflowExecutionStatus.STOPPED,
|
||||
]
|
||||
|
||||
for status in ended_statuses:
|
||||
assert status.is_ended(), f"{status} should be considered ended"
|
||||
|
||||
# Test non-ended statuses
|
||||
non_ended_statuses = [
|
||||
WorkflowExecutionStatus.SCHEDULED,
|
||||
WorkflowExecutionStatus.RUNNING,
|
||||
WorkflowExecutionStatus.PAUSED,
|
||||
]
|
||||
|
||||
for status in non_ended_statuses:
|
||||
assert not status.is_ended(), f"{status} should not be considered ended"
|
||||
@ -0,0 +1,202 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file.models import File, FileTransferMethod, FileType
|
||||
from core.workflow.system_variable import SystemVariable, SystemVariableReadOnlyView
|
||||
|
||||
|
||||
class TestSystemVariableReadOnlyView:
|
||||
"""Test cases for SystemVariableReadOnlyView class."""
|
||||
|
||||
def test_read_only_property_access(self):
|
||||
"""Test that all properties return correct values from wrapped instance."""
|
||||
# Create test data
|
||||
test_file = File(
|
||||
id="file-123",
|
||||
tenant_id="tenant-123",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related-123",
|
||||
)
|
||||
|
||||
datasource_info = {"key": "value", "nested": {"data": 42}}
|
||||
|
||||
# Create SystemVariable with all fields
|
||||
system_var = SystemVariable(
|
||||
user_id="user-123",
|
||||
app_id="app-123",
|
||||
workflow_id="workflow-123",
|
||||
files=[test_file],
|
||||
workflow_execution_id="exec-123",
|
||||
query="test query",
|
||||
conversation_id="conv-123",
|
||||
dialogue_count=5,
|
||||
document_id="doc-123",
|
||||
original_document_id="orig-doc-123",
|
||||
dataset_id="dataset-123",
|
||||
batch="batch-123",
|
||||
datasource_type="type-123",
|
||||
datasource_info=datasource_info,
|
||||
invoke_from="invoke-123",
|
||||
)
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test all properties
|
||||
assert read_only_view.user_id == "user-123"
|
||||
assert read_only_view.app_id == "app-123"
|
||||
assert read_only_view.workflow_id == "workflow-123"
|
||||
assert read_only_view.workflow_execution_id == "exec-123"
|
||||
assert read_only_view.query == "test query"
|
||||
assert read_only_view.conversation_id == "conv-123"
|
||||
assert read_only_view.dialogue_count == 5
|
||||
assert read_only_view.document_id == "doc-123"
|
||||
assert read_only_view.original_document_id == "orig-doc-123"
|
||||
assert read_only_view.dataset_id == "dataset-123"
|
||||
assert read_only_view.batch == "batch-123"
|
||||
assert read_only_view.datasource_type == "type-123"
|
||||
assert read_only_view.invoke_from == "invoke-123"
|
||||
|
||||
def test_defensive_copying_of_mutable_objects(self):
|
||||
"""Test that mutable objects are defensively copied."""
|
||||
# Create test data
|
||||
test_file = File(
|
||||
id="file-123",
|
||||
tenant_id="tenant-123",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related-123",
|
||||
)
|
||||
|
||||
datasource_info = {"key": "original_value"}
|
||||
|
||||
# Create SystemVariable
|
||||
system_var = SystemVariable(
|
||||
files=[test_file], datasource_info=datasource_info, workflow_execution_id="exec-123"
|
||||
)
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test files defensive copying
|
||||
files_copy = read_only_view.files
|
||||
assert isinstance(files_copy, tuple) # Should be immutable tuple
|
||||
assert len(files_copy) == 1
|
||||
assert files_copy[0].id == "file-123"
|
||||
|
||||
# Verify it's a copy (can't modify original through view)
|
||||
assert isinstance(files_copy, tuple)
|
||||
# tuples don't have append method, so they're immutable
|
||||
|
||||
# Test datasource_info defensive copying
|
||||
datasource_copy = read_only_view.datasource_info
|
||||
assert datasource_copy is not None
|
||||
assert datasource_copy["key"] == "original_value"
|
||||
|
||||
datasource_copy = cast(dict, datasource_copy)
|
||||
with pytest.raises(TypeError):
|
||||
datasource_copy["key"] = "modified value"
|
||||
|
||||
# Verify original is unchanged
|
||||
assert system_var.datasource_info is not None
|
||||
assert system_var.datasource_info["key"] == "original_value"
|
||||
assert read_only_view.datasource_info is not None
|
||||
assert read_only_view.datasource_info["key"] == "original_value"
|
||||
|
||||
def test_always_accesses_latest_data(self):
|
||||
"""Test that properties always return the latest data from wrapped instance."""
|
||||
# Create SystemVariable
|
||||
system_var = SystemVariable(user_id="original-user", workflow_execution_id="exec-123")
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Verify initial value
|
||||
assert read_only_view.user_id == "original-user"
|
||||
|
||||
# Modify the wrapped instance
|
||||
system_var.user_id = "modified-user"
|
||||
|
||||
# Verify view returns the new value
|
||||
assert read_only_view.user_id == "modified-user"
|
||||
|
||||
def test_repr_method(self):
|
||||
"""Test the __repr__ method."""
|
||||
# Create SystemVariable
|
||||
system_var = SystemVariable(workflow_execution_id="exec-123")
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test repr
|
||||
repr_str = repr(read_only_view)
|
||||
assert "SystemVariableReadOnlyView" in repr_str
|
||||
assert "system_variable=" in repr_str
|
||||
|
||||
def test_none_value_handling(self):
|
||||
"""Test that None values are properly handled."""
|
||||
# Create SystemVariable with all None values except workflow_execution_id
|
||||
system_var = SystemVariable(
|
||||
user_id=None,
|
||||
app_id=None,
|
||||
workflow_id=None,
|
||||
workflow_execution_id="exec-123",
|
||||
query=None,
|
||||
conversation_id=None,
|
||||
dialogue_count=None,
|
||||
document_id=None,
|
||||
original_document_id=None,
|
||||
dataset_id=None,
|
||||
batch=None,
|
||||
datasource_type=None,
|
||||
datasource_info=None,
|
||||
invoke_from=None,
|
||||
)
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test all None values
|
||||
assert read_only_view.user_id is None
|
||||
assert read_only_view.app_id is None
|
||||
assert read_only_view.workflow_id is None
|
||||
assert read_only_view.query is None
|
||||
assert read_only_view.conversation_id is None
|
||||
assert read_only_view.dialogue_count is None
|
||||
assert read_only_view.document_id is None
|
||||
assert read_only_view.original_document_id is None
|
||||
assert read_only_view.dataset_id is None
|
||||
assert read_only_view.batch is None
|
||||
assert read_only_view.datasource_type is None
|
||||
assert read_only_view.datasource_info is None
|
||||
assert read_only_view.invoke_from is None
|
||||
|
||||
# files should be empty tuple even when default list is empty
|
||||
assert read_only_view.files == ()
|
||||
|
||||
def test_empty_files_handling(self):
|
||||
"""Test that empty files list is handled correctly."""
|
||||
# Create SystemVariable with empty files
|
||||
system_var = SystemVariable(files=[], workflow_execution_id="exec-123")
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test files handling
|
||||
assert read_only_view.files == ()
|
||||
assert isinstance(read_only_view.files, tuple)
|
||||
|
||||
def test_empty_datasource_info_handling(self):
|
||||
"""Test that empty datasource_info is handled correctly."""
|
||||
# Create SystemVariable with empty datasource_info
|
||||
system_var = SystemVariable(datasource_info={}, workflow_execution_id="exec-123")
|
||||
|
||||
# Create read-only view
|
||||
read_only_view = SystemVariableReadOnlyView(system_var)
|
||||
|
||||
# Test datasource_info handling
|
||||
assert read_only_view.datasource_info == {}
|
||||
# Should be a copy, not the same object
|
||||
assert read_only_view.datasource_info is not system_var.datasource_info
|
||||
@ -1,8 +1,10 @@
|
||||
import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytz
|
||||
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
|
||||
|
||||
def test_naive_utc_now(monkeypatch: pytest.MonkeyPatch):
|
||||
@ -20,3 +22,247 @@ def test_naive_utc_now(monkeypatch: pytest.MonkeyPatch):
|
||||
naive_time = naive_datetime.time()
|
||||
utc_time = tz_aware_utc_now.time()
|
||||
assert naive_time == utc_time
|
||||
|
||||
|
||||
class TestParseTimeRange:
|
||||
"""Test cases for parse_time_range function."""
|
||||
|
||||
def test_parse_time_range_basic(self):
|
||||
"""Test basic time range parsing."""
|
||||
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "UTC")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start < end
|
||||
assert start.tzinfo == pytz.UTC
|
||||
assert end.tzinfo == pytz.UTC
|
||||
|
||||
def test_parse_time_range_start_only(self):
|
||||
"""Test parsing with only start time."""
|
||||
start, end = parse_time_range("2024-01-01 10:00", None, "UTC")
|
||||
|
||||
assert start is not None
|
||||
assert end is None
|
||||
assert start.tzinfo == pytz.UTC
|
||||
|
||||
def test_parse_time_range_end_only(self):
|
||||
"""Test parsing with only end time."""
|
||||
start, end = parse_time_range(None, "2024-01-01 18:00", "UTC")
|
||||
|
||||
assert start is None
|
||||
assert end is not None
|
||||
assert end.tzinfo == pytz.UTC
|
||||
|
||||
def test_parse_time_range_both_none(self):
|
||||
"""Test parsing with both times None."""
|
||||
start, end = parse_time_range(None, None, "UTC")
|
||||
|
||||
assert start is None
|
||||
assert end is None
|
||||
|
||||
def test_parse_time_range_different_timezones(self):
|
||||
"""Test parsing with different timezones."""
|
||||
# Test with US/Eastern timezone
|
||||
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "US/Eastern")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.tzinfo == pytz.UTC
|
||||
assert end.tzinfo == pytz.UTC
|
||||
# Verify the times are correctly converted to UTC
|
||||
assert start.hour == 15 # 10 AM EST = 3 PM UTC (in January)
|
||||
assert end.hour == 23 # 6 PM EST = 11 PM UTC (in January)
|
||||
|
||||
def test_parse_time_range_invalid_start_format(self):
|
||||
"""Test parsing with invalid start time format."""
|
||||
with pytest.raises(ValueError, match="time data.*does not match format"):
|
||||
parse_time_range("invalid-date", "2024-01-01 18:00", "UTC")
|
||||
|
||||
def test_parse_time_range_invalid_end_format(self):
|
||||
"""Test parsing with invalid end time format."""
|
||||
with pytest.raises(ValueError, match="time data.*does not match format"):
|
||||
parse_time_range("2024-01-01 10:00", "invalid-date", "UTC")
|
||||
|
||||
def test_parse_time_range_invalid_timezone(self):
|
||||
"""Test parsing with invalid timezone."""
|
||||
with pytest.raises(pytz.exceptions.UnknownTimeZoneError):
|
||||
parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "Invalid/Timezone")
|
||||
|
||||
def test_parse_time_range_start_after_end(self):
|
||||
"""Test parsing with start time after end time."""
|
||||
with pytest.raises(ValueError, match="start must be earlier than or equal to end"):
|
||||
parse_time_range("2024-01-01 18:00", "2024-01-01 10:00", "UTC")
|
||||
|
||||
def test_parse_time_range_start_equals_end(self):
|
||||
"""Test parsing with start time equal to end time."""
|
||||
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 10:00", "UTC")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start == end
|
||||
|
||||
def test_parse_time_range_dst_ambiguous_time(self):
|
||||
"""Test parsing during DST ambiguous time (fall back)."""
|
||||
# This test simulates DST fall back where 2:30 AM occurs twice
|
||||
with patch("pytz.timezone") as mock_timezone:
|
||||
# Mock timezone that raises AmbiguousTimeError
|
||||
mock_tz = mock_timezone.return_value
|
||||
|
||||
# Create a mock datetime object for the return value
|
||||
mock_dt = datetime.datetime(2024, 1, 1, 10, 0, 0)
|
||||
mock_utc_dt = mock_dt.replace(tzinfo=pytz.UTC)
|
||||
|
||||
# Create a proper mock for the localized datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_localized_dt = MagicMock()
|
||||
mock_localized_dt.astimezone.return_value = mock_utc_dt
|
||||
|
||||
# Set up side effects: first call raises exception, second call succeeds
|
||||
mock_tz.localize.side_effect = [
|
||||
pytz.AmbiguousTimeError("Ambiguous time"), # First call for start
|
||||
mock_localized_dt, # Second call for start (with is_dst=False)
|
||||
pytz.AmbiguousTimeError("Ambiguous time"), # First call for end
|
||||
mock_localized_dt, # Second call for end (with is_dst=False)
|
||||
]
|
||||
|
||||
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "US/Eastern")
|
||||
|
||||
# Should use is_dst=False for ambiguous times
|
||||
assert mock_tz.localize.call_count == 4 # 2 calls per time (first fails, second succeeds)
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
|
||||
def test_parse_time_range_dst_nonexistent_time(self):
|
||||
"""Test parsing during DST nonexistent time (spring forward)."""
|
||||
with patch("pytz.timezone") as mock_timezone:
|
||||
# Mock timezone that raises NonExistentTimeError
|
||||
mock_tz = mock_timezone.return_value
|
||||
|
||||
# Create a mock datetime object for the return value
|
||||
mock_dt = datetime.datetime(2024, 1, 1, 10, 0, 0)
|
||||
mock_utc_dt = mock_dt.replace(tzinfo=pytz.UTC)
|
||||
|
||||
# Create a proper mock for the localized datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
mock_localized_dt = MagicMock()
|
||||
mock_localized_dt.astimezone.return_value = mock_utc_dt
|
||||
|
||||
# Set up side effects: first call raises exception, second call succeeds
|
||||
mock_tz.localize.side_effect = [
|
||||
pytz.NonExistentTimeError("Non-existent time"), # First call for start
|
||||
mock_localized_dt, # Second call for start (with adjusted time)
|
||||
pytz.NonExistentTimeError("Non-existent time"), # First call for end
|
||||
mock_localized_dt, # Second call for end (with adjusted time)
|
||||
]
|
||||
|
||||
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "US/Eastern")
|
||||
|
||||
# Should adjust time forward by 1 hour for nonexistent times
|
||||
assert mock_tz.localize.call_count == 4 # 2 calls per time (first fails, second succeeds)
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
|
||||
def test_parse_time_range_edge_cases(self):
|
||||
"""Test edge cases for time parsing."""
|
||||
# Test with midnight times
|
||||
start, end = parse_time_range("2024-01-01 00:00", "2024-01-01 23:59", "UTC")
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.hour == 0
|
||||
assert start.minute == 0
|
||||
assert end.hour == 23
|
||||
assert end.minute == 59
|
||||
|
||||
def test_parse_time_range_different_dates(self):
|
||||
"""Test parsing with different dates."""
|
||||
start, end = parse_time_range("2024-01-01 10:00", "2024-01-02 10:00", "UTC")
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.date() != end.date()
|
||||
assert (end - start).days == 1
|
||||
|
||||
def test_parse_time_range_seconds_handling(self):
|
||||
"""Test that seconds are properly set to 0."""
|
||||
start, end = parse_time_range("2024-01-01 10:30", "2024-01-01 18:45", "UTC")
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.second == 0
|
||||
assert end.second == 0
|
||||
|
||||
def test_parse_time_range_timezone_conversion_accuracy(self):
|
||||
"""Test accurate timezone conversion."""
|
||||
# Test with a known timezone conversion
|
||||
start, end = parse_time_range("2024-01-01 12:00", "2024-01-01 12:00", "Asia/Tokyo")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.tzinfo == pytz.UTC
|
||||
assert end.tzinfo == pytz.UTC
|
||||
# Tokyo is UTC+9, so 12:00 JST = 03:00 UTC
|
||||
assert start.hour == 3
|
||||
assert end.hour == 3
|
||||
|
||||
def test_parse_time_range_summer_time(self):
|
||||
"""Test parsing during summer time (DST)."""
|
||||
# Test with US/Eastern during summer (EDT = UTC-4)
|
||||
start, end = parse_time_range("2024-07-01 12:00", "2024-07-01 12:00", "US/Eastern")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.tzinfo == pytz.UTC
|
||||
assert end.tzinfo == pytz.UTC
|
||||
# 12:00 EDT = 16:00 UTC
|
||||
assert start.hour == 16
|
||||
assert end.hour == 16
|
||||
|
||||
def test_parse_time_range_winter_time(self):
|
||||
"""Test parsing during winter time (standard time)."""
|
||||
# Test with US/Eastern during winter (EST = UTC-5)
|
||||
start, end = parse_time_range("2024-01-01 12:00", "2024-01-01 12:00", "US/Eastern")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.tzinfo == pytz.UTC
|
||||
assert end.tzinfo == pytz.UTC
|
||||
# 12:00 EST = 17:00 UTC
|
||||
assert start.hour == 17
|
||||
assert end.hour == 17
|
||||
|
||||
def test_parse_time_range_empty_strings(self):
|
||||
"""Test parsing with empty strings."""
|
||||
# Empty strings are treated as None, so they should not raise errors
|
||||
start, end = parse_time_range("", "2024-01-01 18:00", "UTC")
|
||||
assert start is None
|
||||
assert end is not None
|
||||
|
||||
start, end = parse_time_range("2024-01-01 10:00", "", "UTC")
|
||||
assert start is not None
|
||||
assert end is None
|
||||
|
||||
def test_parse_time_range_malformed_datetime(self):
|
||||
"""Test parsing with malformed datetime strings."""
|
||||
with pytest.raises(ValueError, match="time data.*does not match format"):
|
||||
parse_time_range("2024-13-01 10:00", "2024-01-01 18:00", "UTC")
|
||||
|
||||
with pytest.raises(ValueError, match="time data.*does not match format"):
|
||||
parse_time_range("2024-01-01 10:00", "2024-01-32 18:00", "UTC")
|
||||
|
||||
def test_parse_time_range_very_long_time_range(self):
|
||||
"""Test parsing with very long time range."""
|
||||
start, end = parse_time_range("2020-01-01 00:00", "2030-12-31 23:59", "UTC")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start < end
|
||||
assert (end - start).days > 3000 # More than 8 years
|
||||
|
||||
def test_parse_time_range_negative_timezone(self):
|
||||
"""Test parsing with negative timezone offset."""
|
||||
start, end = parse_time_range("2024-01-01 12:00", "2024-01-01 12:00", "America/New_York")
|
||||
|
||||
assert start is not None
|
||||
assert end is not None
|
||||
assert start.tzinfo == pytz.UTC
|
||||
assert end.tzinfo == pytz.UTC
|
||||
|
||||
@ -1,5 +1,10 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from werkzeug.wrappers import Response
|
||||
|
||||
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_WEBAPP_ACCESS_TOKEN
|
||||
from libs.token import extract_access_token, extract_webapp_access_token
|
||||
from libs import token
|
||||
from libs.token import extract_access_token, extract_webapp_access_token, set_csrf_token_to_cookie
|
||||
|
||||
|
||||
class MockRequest:
|
||||
@ -23,3 +28,35 @@ def test_extract_access_token():
|
||||
for request, expected_console, expected_webapp in test_cases:
|
||||
assert extract_access_token(request) == expected_console # pyright: ignore[reportArgumentType]
|
||||
assert extract_webapp_access_token(request) == expected_webapp # pyright: ignore[reportArgumentType]
|
||||
|
||||
|
||||
def test_real_cookie_name_uses_host_prefix_without_domain(monkeypatch):
|
||||
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
|
||||
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
|
||||
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", "", raising=False)
|
||||
|
||||
assert token._real_cookie_name("csrf_token") == "__Host-csrf_token"
|
||||
|
||||
|
||||
def test_real_cookie_name_without_host_prefix_when_domain_present(monkeypatch):
|
||||
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
|
||||
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
|
||||
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", ".example.com", raising=False)
|
||||
|
||||
assert token._real_cookie_name("csrf_token") == "csrf_token"
|
||||
|
||||
|
||||
def test_set_csrf_cookie_includes_domain_when_configured(monkeypatch):
|
||||
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
|
||||
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
|
||||
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", ".example.com", raising=False)
|
||||
|
||||
response = Response()
|
||||
request = MagicMock()
|
||||
|
||||
set_csrf_token_to_cookie(request, response, "abc123")
|
||||
|
||||
cookies = response.headers.getlist("Set-Cookie")
|
||||
assert any("csrf_token=abc123" in c for c in cookies)
|
||||
assert any("Domain=example.com" in c for c in cookies)
|
||||
assert all("__Host-" not in c for c in cookies)
|
||||
|
||||
11
api/tests/unit_tests/models/test_base.py
Normal file
11
api/tests/unit_tests/models/test_base.py
Normal file
@ -0,0 +1,11 @@
|
||||
from models.base import DefaultFieldsMixin
|
||||
|
||||
|
||||
class FooModel(DefaultFieldsMixin):
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
|
||||
|
||||
def test_repr():
|
||||
foo_model = FooModel(id="test-id")
|
||||
assert repr(foo_model) == "<FooModel(id=test-id)>"
|
||||
@ -0,0 +1,370 @@
|
||||
"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
_PrivateWorkflowPauseEntity,
|
||||
_WorkflowRunError,
|
||||
)
|
||||
|
||||
|
||||
class TestDifyAPISQLAlchemyWorkflowRunRepository:
|
||||
"""Test DifyAPISQLAlchemyWorkflowRunRepository implementation."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create a mock session."""
|
||||
return Mock(spec=Session)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(self, mock_session):
|
||||
"""Create a mock sessionmaker."""
|
||||
session_maker = Mock(spec=sessionmaker)
|
||||
|
||||
# Create a context manager mock
|
||||
context_manager = Mock()
|
||||
context_manager.__enter__ = Mock(return_value=mock_session)
|
||||
context_manager.__exit__ = Mock(return_value=None)
|
||||
session_maker.return_value = context_manager
|
||||
|
||||
# Mock session.begin() context manager
|
||||
begin_context_manager = Mock()
|
||||
begin_context_manager.__enter__ = Mock(return_value=None)
|
||||
begin_context_manager.__exit__ = Mock(return_value=None)
|
||||
mock_session.begin = Mock(return_value=begin_context_manager)
|
||||
|
||||
# Add missing session methods
|
||||
mock_session.commit = Mock()
|
||||
mock_session.rollback = Mock()
|
||||
mock_session.add = Mock()
|
||||
mock_session.delete = Mock()
|
||||
mock_session.get = Mock()
|
||||
mock_session.scalar = Mock()
|
||||
mock_session.scalars = Mock()
|
||||
|
||||
# Also support expire_on_commit parameter
|
||||
def make_session(expire_on_commit=None):
|
||||
cm = Mock()
|
||||
cm.__enter__ = Mock(return_value=mock_session)
|
||||
cm.__exit__ = Mock(return_value=None)
|
||||
return cm
|
||||
|
||||
session_maker.side_effect = make_session
|
||||
return session_maker
|
||||
|
||||
@pytest.fixture
|
||||
def repository(self, mock_session_maker):
|
||||
"""Create repository instance with mocked dependencies."""
|
||||
|
||||
# Create a testable subclass that implements the save method
|
||||
class TestableDifyAPISQLAlchemyWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
|
||||
def __init__(self, session_maker):
|
||||
# Initialize without calling parent __init__ to avoid any instantiation issues
|
||||
self._session_maker = session_maker
|
||||
|
||||
def save(self, execution):
|
||||
"""Mock implementation of save method."""
|
||||
return None
|
||||
|
||||
# Create repository instance
|
||||
repo = TestableDifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
|
||||
|
||||
return repo
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_run(self):
|
||||
"""Create a sample WorkflowRun model."""
|
||||
workflow_run = Mock(spec=WorkflowRun)
|
||||
workflow_run.id = "workflow-run-123"
|
||||
workflow_run.tenant_id = "tenant-123"
|
||||
workflow_run.app_id = "app-123"
|
||||
workflow_run.workflow_id = "workflow-123"
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
return workflow_run
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_pause(self):
|
||||
"""Create a sample WorkflowPauseModel."""
|
||||
pause = Mock(spec=WorkflowPauseModel)
|
||||
pause.id = "pause-123"
|
||||
pause.workflow_id = "workflow-123"
|
||||
pause.workflow_run_id = "workflow-run-123"
|
||||
pause.state_object_key = "workflow-state-123.json"
|
||||
pause.resumed_at = None
|
||||
pause.created_at = datetime.now(UTC)
|
||||
return pause
|
||||
|
||||
|
||||
class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test create_workflow_pause method."""
|
||||
|
||||
def test_create_workflow_pause_success(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_run: Mock,
|
||||
):
|
||||
"""Test successful workflow pause creation."""
|
||||
# Arrange
|
||||
workflow_run_id = "workflow-run-123"
|
||||
state_owner_user_id = "user-123"
|
||||
state = '{"test": "state"}'
|
||||
|
||||
mock_session.get.return_value = sample_workflow_run
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7:
|
||||
mock_uuidv7.side_effect = ["pause-123"]
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
# Act
|
||||
result = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
state_owner_user_id=state_owner_user_id,
|
||||
state=state,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, _PrivateWorkflowPauseEntity)
|
||||
assert result.id == "pause-123"
|
||||
assert result.workflow_execution_id == workflow_run_id
|
||||
|
||||
# Verify database interactions
|
||||
mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
|
||||
mock_storage.save.assert_called_once()
|
||||
mock_session.add.assert_called()
|
||||
# When using session.begin() context manager, commit is handled automatically
|
||||
# No explicit commit call is expected
|
||||
|
||||
def test_create_workflow_pause_not_found(
|
||||
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
|
||||
):
|
||||
"""Test workflow pause creation when workflow run not found."""
|
||||
# Arrange
|
||||
mock_session.get.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="WorkflowRun not found: workflow-run-123"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
state='{"test": "state"}',
|
||||
)
|
||||
|
||||
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
|
||||
|
||||
def test_create_workflow_pause_invalid_status(
|
||||
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock, sample_workflow_run: Mock
|
||||
):
|
||||
"""Test workflow pause creation when workflow not in RUNNING status."""
|
||||
# Arrange
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
mock_session.get.return_value = sample_workflow_run
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
state='{"test": "state"}',
|
||||
)
|
||||
|
||||
|
||||
class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test resume_workflow_pause method."""
|
||||
|
||||
def test_resume_workflow_pause_success(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_run: Mock,
|
||||
sample_workflow_pause: Mock,
|
||||
):
|
||||
"""Test successful workflow pause resume."""
|
||||
# Arrange
|
||||
workflow_run_id = "workflow-run-123"
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-123"
|
||||
|
||||
# Setup workflow run and pause
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
sample_workflow_run.pause = sample_workflow_pause
|
||||
sample_workflow_pause.resumed_at = None
|
||||
|
||||
mock_session.scalar.return_value = sample_workflow_run
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
|
||||
mock_now.return_value = datetime.now(UTC)
|
||||
|
||||
# Act
|
||||
result = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, _PrivateWorkflowPauseEntity)
|
||||
assert result.id == "pause-123"
|
||||
|
||||
# Verify state transitions
|
||||
assert sample_workflow_pause.resumed_at is not None
|
||||
assert sample_workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
|
||||
# Verify database interactions
|
||||
mock_session.add.assert_called()
|
||||
# When using session.begin() context manager, commit is handled automatically
|
||||
# No explicit commit call is expected
|
||||
|
||||
def test_resume_workflow_pause_not_paused(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_run: Mock,
|
||||
):
|
||||
"""Test resume when workflow is not paused."""
|
||||
# Arrange
|
||||
workflow_run_id = "workflow-run-123"
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-123"
|
||||
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
mock_session.scalar.return_value = sample_workflow_run
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
def test_resume_workflow_pause_id_mismatch(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_run: Mock,
|
||||
sample_workflow_pause: Mock,
|
||||
):
|
||||
"""Test resume when pause ID doesn't match."""
|
||||
# Arrange
|
||||
workflow_run_id = "workflow-run-123"
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-456" # Different ID
|
||||
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
sample_workflow_pause.id = "pause-123"
|
||||
sample_workflow_run.pause = sample_workflow_pause
|
||||
mock_session.scalar.return_value = sample_workflow_run
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
|
||||
class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test delete_workflow_pause method."""
|
||||
|
||||
def test_delete_workflow_pause_success(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_pause: Mock,
|
||||
):
|
||||
"""Test successful workflow pause deletion."""
|
||||
# Arrange
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-123"
|
||||
|
||||
mock_session.get.return_value = sample_workflow_pause
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
# Act
|
||||
repository.delete_workflow_pause(pause_entity=pause_entity)
|
||||
|
||||
# Assert
|
||||
mock_storage.delete.assert_called_once_with(sample_workflow_pause.state_object_key)
|
||||
mock_session.delete.assert_called_once_with(sample_workflow_pause)
|
||||
# When using session.begin() context manager, commit is handled automatically
|
||||
# No explicit commit call is expected
|
||||
|
||||
def test_delete_workflow_pause_not_found(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
):
|
||||
"""Test delete when pause not found."""
|
||||
# Arrange
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-123"
|
||||
|
||||
mock_session.get.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found: pause-123"):
|
||||
repository.delete_workflow_pause(pause_entity=pause_entity)
|
||||
|
||||
|
||||
class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test _PrivateWorkflowPauseEntity class."""
|
||||
|
||||
def test_from_models(self, sample_workflow_pause: Mock):
|
||||
"""Test creating _PrivateWorkflowPauseEntity from models."""
|
||||
# Act
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
|
||||
# Assert
|
||||
assert isinstance(entity, _PrivateWorkflowPauseEntity)
|
||||
assert entity._pause_model == sample_workflow_pause
|
||||
|
||||
def test_properties(self, sample_workflow_pause: Mock):
|
||||
"""Test entity properties."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
|
||||
# Act & Assert
|
||||
assert entity.id == sample_workflow_pause.id
|
||||
assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
|
||||
assert entity.resumed_at == sample_workflow_pause.resumed_at
|
||||
|
||||
def test_get_state(self, sample_workflow_pause: Mock):
|
||||
"""Test getting state from storage."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
expected_state = b'{"test": "state"}'
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
mock_storage.load.return_value = expected_state
|
||||
|
||||
# Act
|
||||
result = entity.get_state()
|
||||
|
||||
# Assert
|
||||
assert result == expected_state
|
||||
mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
|
||||
|
||||
def test_get_state_caching(self, sample_workflow_pause: Mock):
|
||||
"""Test state caching in get_state method."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
|
||||
expected_state = b'{"test": "state"}'
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
mock_storage.load.return_value = expected_state
|
||||
|
||||
# Act
|
||||
result1 = entity.get_state()
|
||||
result2 = entity.get_state() # Should use cache
|
||||
|
||||
# Assert
|
||||
assert result1 == expected_state
|
||||
assert result2 == expected_state
|
||||
mock_storage.load.assert_called_once() # Only called once due to caching
|
||||
@ -21,6 +21,7 @@ from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
from core.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
@ -30,6 +31,7 @@ from core.variables.segments import (
|
||||
StringSegment,
|
||||
)
|
||||
from services.variable_truncator import (
|
||||
DummyVariableTruncator,
|
||||
MaxDepthExceededError,
|
||||
TruncationResult,
|
||||
UnknownTypeError,
|
||||
@ -596,3 +598,32 @@ class TestIntegrationScenarios:
|
||||
truncated_mapping, truncated = truncator.truncate_variable_mapping(mapping)
|
||||
assert truncated is False
|
||||
assert truncated_mapping == mapping
|
||||
|
||||
|
||||
def test_dummy_variable_truncator_methods():
|
||||
"""Test DummyVariableTruncator methods work correctly."""
|
||||
truncator = DummyVariableTruncator()
|
||||
|
||||
# Test truncate_variable_mapping
|
||||
test_data: dict[str, Any] = {
|
||||
"key1": "value1",
|
||||
"key2": ["item1", "item2"],
|
||||
"large_array": list(range(2000)),
|
||||
}
|
||||
result, is_truncated = truncator.truncate_variable_mapping(test_data)
|
||||
|
||||
assert result == test_data
|
||||
assert not is_truncated
|
||||
|
||||
# Test truncate method
|
||||
segment = StringSegment(value="test string")
|
||||
result = truncator.truncate(segment)
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.result == segment
|
||||
assert result.truncated is False
|
||||
|
||||
segment = ArrayNumberSegment(value=list(range(2000)))
|
||||
result = truncator.truncate(segment)
|
||||
assert isinstance(result, TruncationResult)
|
||||
assert result.result == segment
|
||||
assert result.truncated is False
|
||||
|
||||
200
api/tests/unit_tests/services/test_workflow_run_service_pause.py
Normal file
200
api/tests/unit_tests/services/test_workflow_run_service_pause.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""Comprehensive unit tests for WorkflowRunService class.
|
||||
|
||||
This test suite covers all pause state management operations including:
|
||||
- Retrieving pause state for workflow runs
|
||||
- Saving pause state with file uploads
|
||||
- Marking paused workflows as resumed
|
||||
- Error handling and edge cases
|
||||
- Database transaction management
|
||||
- Repository-based approach testing
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
|
||||
from services.workflow_run_service import (
|
||||
WorkflowRunService,
|
||||
)
|
||||
|
||||
|
||||
class TestDataFactory:
|
||||
"""Factory class for creating test data objects."""
|
||||
|
||||
@staticmethod
|
||||
def create_workflow_run_mock(
|
||||
id: str = "workflow-run-123",
|
||||
tenant_id: str = "tenant-456",
|
||||
app_id: str = "app-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
status: str | WorkflowExecutionStatus = "paused",
|
||||
pause_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowRun object."""
|
||||
mock_run = MagicMock()
|
||||
mock_run.id = id
|
||||
mock_run.tenant_id = tenant_id
|
||||
mock_run.app_id = app_id
|
||||
mock_run.workflow_id = workflow_id
|
||||
mock_run.status = status
|
||||
mock_run.pause_id = pause_id
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_run, key, value)
|
||||
|
||||
return mock_run
|
||||
|
||||
@staticmethod
|
||||
def create_workflow_pause_mock(
|
||||
id: str = "pause-123",
|
||||
tenant_id: str = "tenant-456",
|
||||
app_id: str = "app-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
workflow_execution_id: str = "workflow-execution-123",
|
||||
state_file_id: str = "file-456",
|
||||
resumed_at: datetime | None = None,
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowPauseModel object."""
|
||||
mock_pause = MagicMock()
|
||||
mock_pause.id = id
|
||||
mock_pause.tenant_id = tenant_id
|
||||
mock_pause.app_id = app_id
|
||||
mock_pause.workflow_id = workflow_id
|
||||
mock_pause.workflow_execution_id = workflow_execution_id
|
||||
mock_pause.state_file_id = state_file_id
|
||||
mock_pause.resumed_at = resumed_at
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_pause, key, value)
|
||||
|
||||
return mock_pause
|
||||
|
||||
@staticmethod
|
||||
def create_upload_file_mock(
|
||||
id: str = "file-456",
|
||||
key: str = "upload_files/test/state.json",
|
||||
name: str = "state.json",
|
||||
tenant_id: str = "tenant-456",
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock UploadFile object."""
|
||||
mock_file = MagicMock()
|
||||
mock_file.id = id
|
||||
mock_file.key = key
|
||||
mock_file.name = name
|
||||
mock_file.tenant_id = tenant_id
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_file, key, value)
|
||||
|
||||
return mock_file
|
||||
|
||||
@staticmethod
|
||||
def create_pause_entity_mock(
|
||||
pause_model: MagicMock | None = None,
|
||||
upload_file: MagicMock | None = None,
|
||||
) -> _PrivateWorkflowPauseEntity:
|
||||
"""Create a mock _PrivateWorkflowPauseEntity object."""
|
||||
if pause_model is None:
|
||||
pause_model = TestDataFactory.create_workflow_pause_mock()
|
||||
if upload_file is None:
|
||||
upload_file = TestDataFactory.create_upload_file_mock()
|
||||
|
||||
return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file)
|
||||
|
||||
|
||||
class TestWorkflowRunService:
|
||||
"""Comprehensive unit tests for WorkflowRunService class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory(self):
|
||||
"""Create a mock session factory with proper session management."""
|
||||
mock_session = create_autospec(Session)
|
||||
|
||||
# Create a mock context manager for the session
|
||||
mock_session_cm = MagicMock()
|
||||
mock_session_cm.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session_cm.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Create a mock context manager for the transaction
|
||||
mock_transaction_cm = MagicMock()
|
||||
mock_transaction_cm.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_transaction_cm.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
mock_session.begin = MagicMock(return_value=mock_transaction_cm)
|
||||
|
||||
# Create mock factory that returns the context manager
|
||||
mock_factory = MagicMock(spec=sessionmaker)
|
||||
mock_factory.return_value = mock_session_cm
|
||||
|
||||
return mock_factory, mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow_run_repository(self):
|
||||
"""Create a mock APIWorkflowRunRepository."""
|
||||
mock_repo = create_autospec(APIWorkflowRunRepository)
|
||||
return mock_repo
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_run_service(self, mock_session_factory, mock_workflow_run_repository):
|
||||
"""Create WorkflowRunService instance with mocked dependencies."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(session_factory)
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_run_service_with_engine(self, mock_session_factory, mock_workflow_run_repository):
|
||||
"""Create WorkflowRunService instance with Engine input."""
|
||||
mock_engine = create_autospec(Engine)
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(mock_engine)
|
||||
return service
|
||||
|
||||
# ==================== Initialization Tests ====================
|
||||
|
||||
def test_init_with_session_factory(self, mock_session_factory, mock_workflow_run_repository):
|
||||
"""Test WorkflowRunService initialization with session_factory."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(session_factory)
|
||||
|
||||
assert service._session_factory == session_factory
|
||||
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
|
||||
|
||||
def test_init_with_engine(self, mock_session_factory, mock_workflow_run_repository):
|
||||
"""Test WorkflowRunService initialization with Engine (should convert to sessionmaker)."""
|
||||
mock_engine = create_autospec(Engine)
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
|
||||
service = WorkflowRunService(mock_engine)
|
||||
|
||||
mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
|
||||
assert service._session_factory == session_factory
|
||||
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
|
||||
|
||||
def test_init_with_default_dependencies(self, mock_session_factory):
|
||||
"""Test WorkflowRunService initialization with default dependencies."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
service = WorkflowRunService(session_factory)
|
||||
|
||||
assert service._session_factory == session_factory
|
||||
Reference in New Issue
Block a user