Merge branch 'main' into deploy/dev

# Conflicts:
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/app/apps/pipeline/pipeline_generator.py
#	api/core/entities/mcp_provider.py
#	api/core/helper/marketplace.py
#	api/models/workflow.py
#	api/services/tools/tools_transform_service.py
#	api/tasks/document_indexing_task.py
#	api/tests/test_containers_integration_tests/core/__init__.py
#	api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py
#	dev/start-worker
#	docker/.env.example
#	web/app/components/base/chat/embedded-chatbot/hooks.tsx
#	web/app/components/workflow/hooks/use-workflow.ts
#	web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx
#	web/global.d.ts
#	web/pnpm-lock.yaml
#	web/service/use-plugins.ts
Author: Stream
Date: 2025-11-06 15:58:41 +08:00
426 changed files with 21,485 additions and 5,531 deletions

View File

@ -0,0 +1,258 @@
app:
description: 'This workflow tests the iteration node with flatten_output=False.
It processes [1, 2, 3] and outputs [item, item*2] for each iteration.
With flatten_output=False, it should output nested arrays:
```
{"output": [[1, 2], [2, 4], [3, 6]]}
```'
icon: 🤖
icon_background: '#FFEAD5'
mode: workflow
name: test_iteration_flatten_disabled
use_icon_as_answer_icon: false
dependencies: []
kind: app
version: 0.3.1
workflow:
conversation_variables: []
environment_variables: []
features:
file_upload:
enabled: false
opening_statement: ''
retriever_resource:
enabled: true
sensitive_word_avoidance:
enabled: false
speech_to_text:
enabled: false
suggested_questions: []
suggested_questions_after_answer:
enabled: false
text_to_speech:
enabled: false
graph:
edges:
- data:
isInIteration: false
isInLoop: false
sourceType: start
targetType: code
id: start-source-code-target
source: start_node
sourceHandle: source
target: code_node
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: code
targetType: iteration
id: code-source-iteration-target
source: code_node
sourceHandle: source
target: iteration_node
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: true
isInLoop: false
iteration_id: iteration_node
sourceType: iteration-start
targetType: code
id: iteration-start-source-code-inner-target
source: iteration_nodestart
sourceHandle: source
target: code_inner_node
targetHandle: target
type: custom
zIndex: 1002
- data:
isInIteration: false
isInLoop: false
sourceType: iteration
targetType: end
id: iteration-source-end-target
source: iteration_node
sourceHandle: source
target: end_node
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
desc: ''
selected: false
title: Start
type: start
variables: []
height: 54
id: start_node
position:
x: 80
y: 282
positionAbsolute:
x: 80
y: 282
sourcePosition: right
targetPosition: left
type: custom
width: 244
- data:
code: "\ndef main() -> dict:\n return {\n \"result\": [1, 2, 3],\n\
\ }\n"
code_language: python3
desc: ''
outputs:
result:
children: null
type: array[number]
selected: false
title: Generate Array
type: code
variables: []
height: 54
id: code_node
position:
x: 384
y: 282
positionAbsolute:
x: 384
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 244
- data:
desc: ''
error_handle_mode: terminated
flatten_output: false
height: 178
is_parallel: false
iterator_input_type: array[number]
iterator_selector:
- code_node
- result
output_selector:
- code_inner_node
- result
output_type: array[array[number]]
parallel_nums: 10
selected: false
start_node_id: iteration_nodestart
title: Iteration with Flatten Disabled
type: iteration
width: 388
height: 178
id: iteration_node
position:
x: 684
y: 282
positionAbsolute:
x: 684
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 388
zIndex: 1
- data:
desc: ''
isInIteration: true
selected: false
title: ''
type: iteration-start
draggable: false
height: 48
id: iteration_nodestart
parentId: iteration_node
position:
x: 24
y: 68
positionAbsolute:
x: 708
y: 350
selectable: false
sourcePosition: right
targetPosition: left
type: custom-iteration-start
width: 44
zIndex: 1002
- data:
code: "\ndef main(arg1: int) -> dict:\n return {\n \"result\": [arg1,\
\ arg1 * 2],\n }\n"
code_language: python3
desc: ''
isInIteration: true
isInLoop: false
iteration_id: iteration_node
outputs:
result:
children: null
type: array[number]
selected: false
title: Generate Pair
type: code
variables:
- value_selector:
- iteration_node
- item
value_type: number
variable: arg1
height: 54
id: code_inner_node
parentId: iteration_node
position:
x: 128
y: 68
positionAbsolute:
x: 812
y: 350
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 244
zIndex: 1002
- data:
desc: ''
outputs:
- value_selector:
- iteration_node
- output
value_type: array[array[number]]
variable: output
selected: false
title: End
type: end
height: 90
id: end_node
position:
x: 1132
y: 282
positionAbsolute:
x: 1132
y: 282
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 244
viewport:
x: -476
y: 3
zoom: 1

View File

@ -0,0 +1,258 @@
app:
description: 'This workflow tests the iteration node with flatten_output=True.
It processes [1, 2, 3] and outputs [item, item*2] for each iteration.
With flatten_output=True (default), it should output:
```
{"output": [1, 2, 2, 4, 3, 6]}
```'
icon: 🤖
icon_background: '#FFEAD5'
mode: workflow
name: test_iteration_flatten_enabled
use_icon_as_answer_icon: false
dependencies: []
kind: app
version: 0.3.1
workflow:
conversation_variables: []
environment_variables: []
features:
file_upload:
enabled: false
opening_statement: ''
retriever_resource:
enabled: true
sensitive_word_avoidance:
enabled: false
speech_to_text:
enabled: false
suggested_questions: []
suggested_questions_after_answer:
enabled: false
text_to_speech:
enabled: false
graph:
edges:
- data:
isInIteration: false
isInLoop: false
sourceType: start
targetType: code
id: start-source-code-target
source: start_node
sourceHandle: source
target: code_node
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: code
targetType: iteration
id: code-source-iteration-target
source: code_node
sourceHandle: source
target: iteration_node
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: true
isInLoop: false
iteration_id: iteration_node
sourceType: iteration-start
targetType: code
id: iteration-start-source-code-inner-target
source: iteration_nodestart
sourceHandle: source
target: code_inner_node
targetHandle: target
type: custom
zIndex: 1002
- data:
isInIteration: false
isInLoop: false
sourceType: iteration
targetType: end
id: iteration-source-end-target
source: iteration_node
sourceHandle: source
target: end_node
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
desc: ''
selected: false
title: Start
type: start
variables: []
height: 54
id: start_node
position:
x: 80
y: 282
positionAbsolute:
x: 80
y: 282
sourcePosition: right
targetPosition: left
type: custom
width: 244
- data:
code: "\ndef main() -> dict:\n return {\n \"result\": [1, 2, 3],\n\
\ }\n"
code_language: python3
desc: ''
outputs:
result:
children: null
type: array[number]
selected: false
title: Generate Array
type: code
variables: []
height: 54
id: code_node
position:
x: 384
y: 282
positionAbsolute:
x: 384
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 244
- data:
desc: ''
error_handle_mode: terminated
flatten_output: true
height: 178
is_parallel: false
iterator_input_type: array[number]
iterator_selector:
- code_node
- result
output_selector:
- code_inner_node
- result
output_type: array[array[number]]
parallel_nums: 10
selected: false
start_node_id: iteration_nodestart
title: Iteration with Flatten Enabled
type: iteration
width: 388
height: 178
id: iteration_node
position:
x: 684
y: 282
positionAbsolute:
x: 684
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 388
zIndex: 1
- data:
desc: ''
isInIteration: true
selected: false
title: ''
type: iteration-start
draggable: false
height: 48
id: iteration_nodestart
parentId: iteration_node
position:
x: 24
y: 68
positionAbsolute:
x: 708
y: 350
selectable: false
sourcePosition: right
targetPosition: left
type: custom-iteration-start
width: 44
zIndex: 1002
- data:
code: "\ndef main(arg1: int) -> dict:\n return {\n \"result\": [arg1,\
\ arg1 * 2],\n }\n"
code_language: python3
desc: ''
isInIteration: true
isInLoop: false
iteration_id: iteration_node
outputs:
result:
children: null
type: array[number]
selected: false
title: Generate Pair
type: code
variables:
- value_selector:
- iteration_node
- item
value_type: number
variable: arg1
height: 54
id: code_inner_node
parentId: iteration_node
position:
x: 128
y: 68
positionAbsolute:
x: 812
y: 350
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 244
zIndex: 1002
- data:
desc: ''
outputs:
- value_selector:
- iteration_node
- output
value_type: array[number]
variable: output
selected: false
title: End
type: end
height: 90
id: end_node
position:
x: 1132
y: 282
positionAbsolute:
x: 1132
y: 282
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 244
viewport:
x: -476
y: 3
zoom: 1
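
Taken together, these two fixtures pin down the flatten_output contract: with it disabled the iteration node emits one nested array per iteration, and with it enabled those per-iteration arrays are concatenated into a single flat list. A minimal sketch of that behavior in plain Python (`flatten_iteration_outputs` is a hypothetical helper for illustration, not the engine's implementation):

```python
# Sketch of the flatten_output contract exercised by the two fixtures above.
def flatten_iteration_outputs(per_iteration: list[list[int]], flatten_output: bool) -> list:
    if flatten_output:
        # Concatenate each iteration's output array into one flat list.
        return [item for chunk in per_iteration for item in chunk]
    # Keep one nested array per iteration.
    return per_iteration

per_iteration = [[i, i * 2] for i in [1, 2, 3]]  # what "Generate Pair" produces
assert flatten_iteration_outputs(per_iteration, False) == [[1, 2], [2, 4], [3, 6]]
assert flatten_iteration_outputs(per_iteration, True) == [1, 2, 2, 4, 3, 6]
```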

View File

@ -1 +1 @@
-# Test containers integration tests for core RAG pipeline components
+# Core integration tests package

View File

@ -0,0 +1 @@
# App integration tests package

View File

@ -0,0 +1 @@
# Layers integration tests package

View File

@ -0,0 +1,520 @@
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class.
This test suite covers complete integration scenarios including:
- Real database interactions using containerized PostgreSQL
- Real storage operations using test storage backend
- Complete workflow: event -> state serialization -> database save -> storage save
- Testing with actual WorkflowRunService (not mocked)
- Real Workflow and WorkflowRun instances in database
- Database transactions and rollback behavior
- Actual file upload and retrieval through storage
- Workflow status transitions in database
- Error handling with real database constraints
- Multiple pause events in sequence
- Integration with real ReadOnlyGraphRuntimeState implementations
These tests use TestContainers to spin up real services for integration testing,
providing more reliable and realistic test scenarios than mocks.
"""
import json
import uuid
from time import time
import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_events.graph import GraphRunPausedEvent
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper
from core.workflow.runtime.variable_pool import SystemVariable, VariablePool
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models import Account
from models import WorkflowPause as WorkflowPauseModel
from models.model import UploadFile
from models.workflow import Workflow, WorkflowRun
from services.file_service import FileService
from services.workflow_run_service import WorkflowRunService
class _TestCommandChannelImpl:
"""Real implementation of CommandChannel for testing."""
def __init__(self):
self._commands: list[GraphEngineCommand] = []
def fetch_commands(self) -> list[GraphEngineCommand]:
"""Fetch pending commands for this GraphEngine instance."""
return self._commands.copy()
def send_command(self, command: GraphEngineCommand) -> None:
"""Send a command to be processed by this GraphEngine instance."""
self._commands.append(command)
class TestPauseStatePersistenceLayerTestContainers:
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class."""
@pytest.fixture
def engine(self, db_session_with_containers: Session):
"""Get database engine from TestContainers session."""
bind = db_session_with_containers.get_bind()
assert isinstance(bind, Engine)
return bind
@pytest.fixture
def file_service(self, engine: Engine):
"""Create FileService instance with TestContainers engine."""
return FileService(engine)
@pytest.fixture
def workflow_run_service(self, engine: Engine, file_service: FileService):
"""Create WorkflowRunService instance with TestContainers engine and FileService."""
return WorkflowRunService(engine)
@pytest.fixture(autouse=True)
def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service):
"""Set up test data for each test method using TestContainers."""
# Create test tenant and account
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
tenant = Tenant(
name="Test Tenant",
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
account = Account(
email="test@example.com",
name="Test User",
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant-account join
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(tenant_join)
db_session_with_containers.commit()
# Set test data
self.test_tenant_id = tenant.id
self.test_user_id = account.id
self.test_app_id = str(uuid.uuid4())
self.test_workflow_id = str(uuid.uuid4())
self.test_workflow_run_id = str(uuid.uuid4())
# Create test workflow
self.test_workflow = Workflow(
id=self.test_workflow_id,
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features='{"file_upload": {"enabled": false}}',
created_by=self.test_user_id,
created_at=naive_utc_now(),
)
# Create test workflow run
self.test_workflow_run = WorkflowRun(
id=self.test_workflow_run_id,
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
workflow_id=self.test_workflow_id,
type="workflow",
triggered_from="debugging",
version="draft",
status=WorkflowExecutionStatus.RUNNING,
created_by=self.test_user_id,
created_by_role="account",
created_at=naive_utc_now(),
)
# Store session and service instances
self.session = db_session_with_containers
self.file_service = file_service
self.workflow_run_service = workflow_run_service
# Save test data to database
self.session.add(self.test_workflow)
self.session.add(self.test_workflow_run)
self.session.commit()
yield
# Cleanup
self._cleanup_test_data()
def _cleanup_test_data(self):
"""Clean up test data after each test method."""
try:
# Clean up workflow pauses
self.session.execute(delete(WorkflowPauseModel))
# Clean up upload files
self.session.execute(
delete(UploadFile).where(
UploadFile.tenant_id == self.test_tenant_id,
)
)
# Clean up workflow runs
self.session.execute(
delete(WorkflowRun).where(
WorkflowRun.tenant_id == self.test_tenant_id,
WorkflowRun.app_id == self.test_app_id,
)
)
# Clean up workflows
self.session.execute(
delete(Workflow).where(
Workflow.tenant_id == self.test_tenant_id,
Workflow.app_id == self.test_app_id,
)
)
self.session.commit()
except Exception as e:
self.session.rollback()
raise e
def _create_graph_runtime_state(
self,
outputs: dict[str, object] | None = None,
total_tokens: int = 0,
node_run_steps: int = 0,
variables: dict[tuple[str, str], object] | None = None,
workflow_run_id: str | None = None,
) -> ReadOnlyGraphRuntimeState:
"""Create a real GraphRuntimeState for testing."""
start_at = time()
execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())
# Create variable pool
variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id))
if variables:
for (node_id, var_key), value in variables.items():
variable_pool.add([node_id, var_key], value)
# Create LLM usage
llm_usage = LLMUsage.empty_usage()
# Create graph runtime state
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=start_at,
total_tokens=total_tokens,
llm_usage=llm_usage,
outputs=outputs or {},
node_run_steps=node_run_steps,
)
return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state)
def _create_pause_state_persistence_layer(
self,
workflow_run: WorkflowRun | None = None,
workflow: Workflow | None = None,
state_owner_user_id: str | None = None,
) -> PauseStatePersistenceLayer:
"""Create PauseStatePersistenceLayer with real dependencies."""
owner_id = state_owner_user_id
if owner_id is None:
if workflow is not None and workflow.created_by:
owner_id = workflow.created_by
elif workflow_run is not None and workflow_run.created_by:
owner_id = workflow_run.created_by
else:
owner_id = getattr(self, "test_user_id", None)
assert owner_id is not None
owner_id = str(owner_id)
return PauseStatePersistenceLayer(
session_factory=self.session.get_bind(),
state_owner_user_id=owner_id,
)
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
"""Test complete pause flow: event -> state serialization -> database save -> storage save."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Create real graph runtime state with test data
test_outputs = {"result": "test_output", "step": "intermediate"}
test_variables = {
("node1", "var1"): "string_value",
("node2", "var2"): {"complex": "object"},
}
graph_runtime_state = self._create_graph_runtime_state(
outputs=test_outputs,
total_tokens=100,
node_run_steps=5,
variables=test_variables,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
# Create pause event
event = GraphRunPausedEvent(
reason=SchedulingPause(message="test pause"),
outputs={"intermediate": "result"},
)
# Act
layer.on_event(event)
# Assert - Verify pause state was saved to database
self.session.refresh(self.test_workflow_run)
workflow_run = self.session.get(WorkflowRun, self.test_workflow_run_id)
assert workflow_run is not None
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
# Verify pause state exists in database
pause_model = self.session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
).first()
assert pause_model is not None
assert pause_model.workflow_id == self.test_workflow_id
assert pause_model.workflow_run_id == self.test_workflow_run_id
assert pause_model.state_object_key != ""
assert pause_model.resumed_at is None
storage_content = storage.load(pause_model.state_object_key).decode()
expected_state = json.loads(graph_runtime_state.dumps())
actual_state = json.loads(storage_content)
assert actual_state == expected_state
def test_state_persistence_and_retrieval(self, db_session_with_containers):
"""Test that pause state can be persisted and retrieved correctly."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Create complex test data
complex_outputs = {
"nested": {"key": "value", "number": 42},
"list": [1, 2, 3, {"nested": "item"}],
"boolean": True,
"null_value": None,
}
complex_variables = {
("node1", "var1"): "string_value",
("node2", "var2"): {"complex": "object"},
("node3", "var3"): [1, 2, 3],
}
graph_runtime_state = self._create_graph_runtime_state(
outputs=complex_outputs,
total_tokens=250,
node_run_steps=10,
variables=complex_variables,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act - Save pause state
layer.on_event(event)
# Assert - Retrieve and verify
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
assert pause_entity is not None
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
state_bytes = pause_entity.get_state()
retrieved_state = json.loads(state_bytes.decode())
expected_state = json.loads(graph_runtime_state.dumps())
assert retrieved_state == expected_state
assert retrieved_state["outputs"] == complex_outputs
assert retrieved_state["total_tokens"] == 250
assert retrieved_state["node_run_steps"] == 10
def test_database_transaction_handling(self, db_session_with_containers):
"""Test that database transactions are handled correctly."""
# Arrange
layer = self._create_pause_state_persistence_layer()
graph_runtime_state = self._create_graph_runtime_state(
outputs={"test": "transaction"},
total_tokens=50,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act
layer.on_event(event)
# Assert - Verify data is committed and accessible in new session
with Session(bind=self.session.get_bind(), expire_on_commit=False) as new_session:
workflow_run = new_session.get(WorkflowRun, self.test_workflow_run_id)
assert workflow_run is not None
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
pause_model = new_session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
).first()
assert pause_model is not None
assert pause_model.workflow_run_id == self.test_workflow_run_id
assert pause_model.resumed_at is None
assert pause_model.state_object_key != ""
def test_file_storage_integration(self, db_session_with_containers):
"""Test integration with file storage system."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Create large state data to test storage
large_outputs = {"data": "x" * 10000} # 10KB of data
graph_runtime_state = self._create_graph_runtime_state(
outputs=large_outputs,
total_tokens=1000,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act
layer.on_event(event)
# Assert - Verify file was uploaded to storage
self.session.refresh(self.test_workflow_run)
pause_model = self.session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run.id)
).first()
assert pause_model is not None
assert pause_model.state_object_key != ""
# Verify content in storage
storage_content = storage.load(pause_model.state_object_key).decode()
assert storage_content == graph_runtime_state.dumps()
def test_workflow_with_different_creators(self, db_session_with_containers):
"""Test pause state with workflows created by different users."""
# Arrange - Create workflow with different creator
different_user_id = str(uuid.uuid4())
different_workflow = Workflow(
id=str(uuid.uuid4()),
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features='{"file_upload": {"enabled": false}}',
created_by=different_user_id,
created_at=naive_utc_now(),
)
different_workflow_run = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
workflow_id=different_workflow.id,
type="workflow",
triggered_from="debugging",
version="draft",
status=WorkflowExecutionStatus.RUNNING,
created_by=self.test_user_id,  # Run created by the test user, unlike the workflow's creator
created_by_role="account",
created_at=naive_utc_now(),
)
self.session.add(different_workflow)
self.session.add(different_workflow_run)
self.session.commit()
layer = self._create_pause_state_persistence_layer(
workflow_run=different_workflow_run,
workflow=different_workflow,
)
graph_runtime_state = self._create_graph_runtime_state(
outputs={"creator_test": "different_creator"},
workflow_run_id=different_workflow_run.id,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act
layer.on_event(event)
# Assert - Should use workflow creator (not run creator)
self.session.refresh(different_workflow_run)
pause_model = self.session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == different_workflow_run.id)
).first()
assert pause_model is not None
# Verify the state owner is the workflow creator
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
assert pause_entity is not None
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
"""Test that layer ignores non-pause events."""
# Arrange
layer = self._create_pause_state_persistence_layer()
graph_runtime_state = self._create_graph_runtime_state()
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
# Import other event types
from core.workflow.graph_events.graph import (
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
# Act - Send non-pause events
layer.on_event(GraphRunStartedEvent())
layer.on_event(GraphRunSucceededEvent(outputs={"result": "success"}))
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
# Assert - No pause state should be created
self.session.refresh(self.test_workflow_run)
assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING
pause_states = (
self.session.query(WorkflowPauseModel)
.filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
.all()
)
assert len(pause_states) == 0
def test_layer_requires_initialization(self, db_session_with_containers):
"""Test that layer requires proper initialization before handling events."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Don't initialize - graph_runtime_state should not be set
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act & Assert - Should raise AttributeError
with pytest.raises(AttributeError):
layer.on_event(event)
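
For orientation, the happy path these tests exercise reduces to a few calls. A condensed sketch built only from the identifiers used above (the fixtures supplying `engine`, `owner_id`, `graph_runtime_state`, and `command_channel` are assumed, so this is not a standalone program):

```python
# Condensed from the tests above: the layer is initialized with a runtime
# state and a command channel, then persists state when a pause event arrives.
layer = PauseStatePersistenceLayer(
    session_factory=engine,        # engine / session factory from the fixtures
    state_owner_user_id=owner_id,  # e.g. the workflow creator's account id
)
layer.initialize(graph_runtime_state, command_channel)
layer.on_event(GraphRunPausedEvent(reason=SchedulingPause(message="test pause")))
# Afterwards the WorkflowRun row is PAUSED, a WorkflowPause row exists with a
# non-empty state_object_key, and storage.load(state_object_key) round-trips
# graph_runtime_state.dumps().
```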

View File

@ -5,6 +5,7 @@ import pytest
from faker import Faker
from core.app.entities.app_invoke_entities import InvokeFrom
+from enums.cloud_plan import CloudPlan
from models.model import EndUser
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
@ -32,7 +33,7 @@ class TestAppGenerateService:
patch("services.app_generate_service.dify_config") as mock_dify_config,
):
# Setup default mock returns for billing service
-mock_billing_service.get_info.return_value = {"subscription": {"plan": "sandbox"}}
+mock_billing_service.get_info.return_value = {"subscription": {"plan": CloudPlan.SANDBOX}}
# Setup default mock returns for workflow service
mock_workflow_service_instance = mock_workflow_service.return_value
@ -430,7 +431,7 @@ class TestAppGenerateService:
# Setup billing service mock for sandbox plan
mock_external_service_dependencies["billing_service"].get_info.return_value = {
"subscription": {"plan": "sandbox"}
"subscription": {"plan": CloudPlan.SANDBOX}
}
# Set BILLING_ENABLED to True for this test
@ -461,7 +462,7 @@ class TestAppGenerateService:
# Setup billing service mock for sandbox plan
mock_external_service_dependencies["billing_service"].get_info.return_value = {
"subscription": {"plan": "sandbox"}
"subscription": {"plan": CloudPlan.SANDBOX}
}
# Set BILLING_ENABLED to True for this test

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
+from enums.cloud_plan import CloudPlan
from services.feature_service import FeatureModel, FeatureService, KnowledgeRateLimitModel, SystemFeatureModel
@ -173,7 +174,7 @@ class TestFeatureService:
# Set mock return value inside the patch context
mock_external_service_dependencies["billing_service"].get_info.return_value = {
"enabled": True,
"subscription": {"plan": "sandbox", "interval": "monthly", "education": False},
"subscription": {"plan": CloudPlan.SANDBOX, "interval": "monthly", "education": False},
"members": {"size": 1, "limit": 3},
"apps": {"size": 1, "limit": 5},
"vector_space": {"size": 1, "limit": 2},
@ -189,7 +190,7 @@ class TestFeatureService:
result = FeatureService.get_features(tenant_id)
# Assert: Verify sandbox-specific limitations
-assert result.billing.subscription.plan == "sandbox"
+assert result.billing.subscription.plan == CloudPlan.SANDBOX
assert result.education.activated is False
# Verify sandbox limitations

View File

@ -11,7 +11,7 @@ from configs import dify_config
from models import Account, Tenant
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
-from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
+from services.errors.file import BlockedFileExtensionError, FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
@ -943,3 +943,150 @@ class TestFileService:
# Should have the signed URL when source_url is empty
assert upload_file2.source_url == "https://example.com/signed-url"
# Test file extension blacklist
def test_upload_file_blocked_extension(
self, db_session_with_containers, engine, mock_external_service_dependencies
):
"""
Test file upload with blocked extension.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock blacklist configuration by patching the inner field
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", "exe,bat,sh"):
filename = "malware.exe"
content = b"test content"
mimetype = "application/x-msdownload"
with pytest.raises(BlockedFileExtensionError):
FileService(engine).upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
def test_upload_file_blocked_extension_case_insensitive(
self, db_session_with_containers, engine, mock_external_service_dependencies
):
"""
Test file upload with blocked extension (case insensitive).
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock blacklist configuration by patching the inner field
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", "exe,bat"):
# Test with uppercase extension
filename = "malware.EXE"
content = b"test content"
mimetype = "application/x-msdownload"
with pytest.raises(BlockedFileExtensionError):
FileService(engine).upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
def test_upload_file_not_in_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies):
"""
Test file upload with extension not in blacklist.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock blacklist configuration by patching the inner field
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", "exe,bat,sh"):
filename = "document.pdf"
content = b"test content"
mimetype = "application/pdf"
upload_file = FileService(engine).upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
assert upload_file is not None
assert upload_file.name == filename
assert upload_file.extension == "pdf"
def test_upload_file_empty_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies):
"""
Test file upload with empty blacklist (default behavior).
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock empty blacklist configuration by patching the inner field
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", ""):
# Should allow all file types when blacklist is empty
filename = "script.sh"
content = b"#!/bin/bash\necho test"
mimetype = "application/x-sh"
upload_file = FileService(engine).upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
assert upload_file is not None
assert upload_file.extension == "sh"
def test_upload_file_multiple_blocked_extensions(
self, db_session_with_containers, engine, mock_external_service_dependencies
):
"""
Test file upload with multiple blocked extensions.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock blacklist with multiple extensions by patching the inner field
blacklist_str = "exe,bat,cmd,com,scr,vbs,ps1,msi,dll"
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", blacklist_str):
for ext in blacklist_str.split(","):
filename = f"malware.{ext}"
content = b"test content"
mimetype = "application/octet-stream"
with pytest.raises(BlockedFileExtensionError):
FileService(engine).upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
def test_upload_file_no_extension_with_blacklist(
self, db_session_with_containers, engine, mock_external_service_dependencies
):
"""
Test file upload with no extension when blacklist is configured.
"""
fake = Faker()
account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies)
# Mock blacklist configuration by patching the inner field
with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", "exe,bat"):
# Files with no extension should not be blocked
filename = "README"
content = b"test content"
mimetype = "text/plain"
upload_file = FileService(engine).upload_file(
filename=filename,
content=content,
mimetype=mimetype,
user=account,
)
assert upload_file is not None
assert upload_file.extension == ""
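
The behavior these tests pin down (case-insensitive matching, an empty blacklist allowing everything, extensionless files passing) can be summarized in a few lines. A hypothetical sketch of the check, not FileService's actual implementation:

```python
# Hypothetical sketch of the extension-blacklist rule the tests above assert;
# not the actual FileService code.
def is_extension_blocked(filename: str, blacklist_csv: str) -> bool:
    blocked = {ext.strip().lower() for ext in blacklist_csv.split(",") if ext.strip()}
    if not blocked:
        return False  # empty blacklist: allow all files
    if "." not in filename:
        return False  # files with no extension are not blocked
    extension = filename.rsplit(".", 1)[1].lower()
    return extension in blocked  # case-insensitive comparison

assert is_extension_blocked("malware.EXE", "exe,bat")
assert not is_extension_blocked("document.pdf", "exe,bat,sh")
assert not is_extension_blocked("README", "exe,bat")
assert not is_extension_blocked("script.sh", "")
```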

View File

@ -35,9 +35,7 @@ class TestWebAppAuthService:
mock_enterprise_service.WebAppAuth.get_app_access_mode_by_id.return_value = type(
"MockWebAppAuth", (), {"access_mode": "private"}
)()
-mock_enterprise_service.WebAppAuth.get_app_access_mode_by_code.return_value = type(
-"MockWebAppAuth", (), {"access_mode": "private"}
-)()
+# Note: get_app_access_mode_by_code method was removed in refactoring
yield {
"passport_service": mock_passport_service,

View File

@ -5,6 +5,7 @@ import pytest
from faker import Faker
from core.entities.document_task import DocumentTask
+from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
@ -19,7 +20,7 @@ from tasks.document_indexing_task import (
class TestDocumentIndexingTasks:
"""Integration tests for document indexing tasks using testcontainers.
This test class covers:
- Core _document_indexing function
- Deprecated document_indexing_task function
@ -213,7 +214,7 @@ class TestDocumentIndexingTasks:
# Configure billing features
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
if billing_enabled:
-mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox"
+mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX
mock_external_service_dependencies["features"].vector_space.limit = 100
mock_external_service_dependencies["features"].vector_space.size = 50
@ -462,7 +463,7 @@ class TestDocumentIndexingTasks:
)
# Configure sandbox plan with batch limit
-mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox"
+mock_external_service_dependencies["features"].billing.subscription.plan = CloudPlan.SANDBOX
# Create more documents than sandbox plan allows (limit is 1)
fake = Faker()
@ -597,7 +598,7 @@ class TestDocumentIndexingTasks:
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_normal_document_indexing_task_success(
self, db_session_with_containers, mock_external_service_dependencies
):
@ -718,10 +719,10 @@ class TestDocumentIndexingTasks:
# Use real Redis for TenantSelfTaskQueue
from core.rag.pipeline.queue import TenantSelfTaskQueue
# Create real queue instance
queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
# Add waiting tasks to the real Redis queue
waiting_tasks = [
DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"]),
@ -740,7 +741,7 @@ class TestDocumentIndexingTasks:
# Verify task function was called for each waiting task
assert mock_task_func.delay.call_count == 1
# Verify correct parameters for each call
calls = mock_task_func.delay.call_args_list
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
@ -782,7 +783,7 @@ class TestDocumentIndexingTasks:
# Create real queue instance
queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
# Add waiting task to the real Redis queue
waiting_task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"])
queue.push_tasks([asdict(waiting_task)])
@ -804,7 +805,7 @@ class TestDocumentIndexingTasks:
# Verify waiting task was still processed despite core processing error
mock_task_func.delay.assert_called_once()
# Verify correct parameters for the call
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
@ -831,7 +832,7 @@ class TestDocumentIndexingTasks:
dataset2, documents2 = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
tenant1_id = dataset1.tenant_id
tenant2_id = dataset2.tenant_id
dataset1_id = dataset1.id
@ -845,15 +846,15 @@ class TestDocumentIndexingTasks:
# Use real Redis for TenantSelfTaskQueue
from core.rag.pipeline.queue import TenantSelfTaskQueue
# Create queue instances for both tenants
queue1 = TenantSelfTaskQueue(tenant1_id, "document_indexing")
queue2 = TenantSelfTaskQueue(tenant2_id, "document_indexing")
# Add waiting tasks to both queues
waiting_task1 = DocumentTask(tenant_id=tenant1_id, dataset_id=dataset1.id, document_ids=["tenant1-doc-1"])
waiting_task2 = DocumentTask(tenant_id=tenant2_id, dataset_id=dataset2.id, document_ids=["tenant2-doc-1"])
queue1.push_tasks([asdict(waiting_task1)])
queue2.push_tasks([asdict(waiting_task2)])
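
The queue interaction threaded through these tests is small. A condensed sketch using only the calls shown above (`tenant_id` and `dataset_id` are assumed to come from the surrounding fixtures):

```python
# Condensed from the tests above: per-tenant, Redis-backed task queue usage.
from dataclasses import asdict

from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantSelfTaskQueue

queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=["waiting-doc-1"])
queue.push_tasks([asdict(task)])  # enqueue the waiting task as a plain dict
```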

View File

@ -0,0 +1,948 @@
"""Comprehensive integration tests for workflow pause functionality.
This test suite covers complete workflow pause functionality including:
- Real database interactions using containerized PostgreSQL
- Real storage operations using the test storage backend
- Complete workflow: create -> pause -> resume -> delete
- Testing with actual FileService (not mocked)
- Database transactions and rollback behavior
- Actual file upload and retrieval through storage
- Workflow status transitions in the database
- Error handling with real database constraints
- Concurrent access scenarios
- Multi-tenant isolation
- Prune functionality
- File storage integration
These tests use TestContainers to spin up real services for integration testing,
providing more reliable and realistic test scenarios than mocks.
"""
import json
import uuid
from dataclasses import dataclass
from datetime import timedelta
import pytest
from sqlalchemy import delete, select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.entities import WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models import Account
from models import WorkflowPause as WorkflowPauseModel
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
from models.model import UploadFile
from models.workflow import Workflow, WorkflowRun
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_WorkflowRunError,
)
@dataclass
class PauseWorkflowSuccessCase:
"""Test case for successful pause workflow operations."""
name: str
initial_status: WorkflowExecutionStatus
description: str = ""
@dataclass
class PauseWorkflowFailureCase:
"""Test case for pause workflow failure scenarios."""
name: str
initial_status: WorkflowExecutionStatus
description: str = ""
@dataclass
class ResumeWorkflowSuccessCase:
"""Test case for successful resume workflow operations."""
name: str
initial_status: WorkflowExecutionStatus
description: str = ""
@dataclass
class ResumeWorkflowFailureCase:
"""Test case for resume workflow failure scenarios."""
name: str
initial_status: WorkflowExecutionStatus
pause_resumed: bool
set_running_status: bool = False
description: str = ""
@dataclass
class PrunePausesTestCase:
"""Test case for prune pauses operations."""
name: str
pause_age: timedelta
resume_age: timedelta | None
expected_pruned_count: int
description: str = ""
def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
"""Create test cases for pause workflow failure scenarios."""
return [
PauseWorkflowFailureCase(
name="pause_already_paused_workflow",
initial_status=WorkflowExecutionStatus.PAUSED,
description="Should fail to pause an already paused workflow",
),
PauseWorkflowFailureCase(
name="pause_completed_workflow",
initial_status=WorkflowExecutionStatus.SUCCEEDED,
description="Should fail to pause a completed workflow",
),
PauseWorkflowFailureCase(
name="pause_failed_workflow",
initial_status=WorkflowExecutionStatus.FAILED,
description="Should fail to pause a failed workflow",
),
]
def resume_workflow_success_cases() -> list[ResumeWorkflowSuccessCase]:
"""Create test cases for successful resume workflow operations."""
return [
ResumeWorkflowSuccessCase(
name="resume_paused_workflow",
initial_status=WorkflowExecutionStatus.PAUSED,
description="Should successfully resume a paused workflow",
),
]
def resume_workflow_failure_cases() -> list[ResumeWorkflowFailureCase]:
"""Create test cases for resume workflow failure scenarios."""
return [
ResumeWorkflowFailureCase(
name="resume_already_resumed_workflow",
initial_status=WorkflowExecutionStatus.PAUSED,
pause_resumed=True,
description="Should fail to resume an already resumed workflow",
),
ResumeWorkflowFailureCase(
name="resume_running_workflow",
initial_status=WorkflowExecutionStatus.RUNNING,
pause_resumed=False,
set_running_status=True,
description="Should fail to resume a running workflow",
),
]
def prune_pauses_test_cases() -> list[PrunePausesTestCase]:
"""Create test cases for prune pauses operations."""
return [
PrunePausesTestCase(
name="prune_old_active_pauses",
pause_age=timedelta(days=7),
resume_age=None,
expected_pruned_count=1,
description="Should prune old active pauses",
),
PrunePausesTestCase(
name="prune_old_resumed_pauses",
pause_age=timedelta(hours=12), # Created 12 hours ago (recent)
resume_age=timedelta(days=7),
expected_pruned_count=1,
description="Should prune old resumed pauses",
),
PrunePausesTestCase(
name="keep_recent_active_pauses",
pause_age=timedelta(hours=1),
resume_age=None,
expected_pruned_count=0,
description="Should keep recent active pauses",
),
PrunePausesTestCase(
name="keep_recent_resumed_pauses",
pause_age=timedelta(days=1),
resume_age=timedelta(hours=1),
expected_pruned_count=0,
description="Should keep recent resumed pauses",
),
]
class TestWorkflowPauseIntegration:
"""Comprehensive integration tests for workflow pause functionality."""
@pytest.fixture(autouse=True)
def setup_test_data(self, db_session_with_containers):
"""Set up test data for each test method using TestContainers."""
# Create test tenant and account
tenant = Tenant(
name="Test Tenant",
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
account = Account(
email="test@example.com",
name="Test User",
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant-account join
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(tenant_join)
db_session_with_containers.commit()
# Set test data
self.test_tenant_id = tenant.id
self.test_user_id = account.id
self.test_app_id = str(uuid.uuid4())
self.test_workflow_id = str(uuid.uuid4())
# Create test workflow
self.test_workflow = Workflow(
id=self.test_workflow_id,
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features='{"file_upload": {"enabled": false}}',
created_by=self.test_user_id,
created_at=naive_utc_now(),
)
# Store session instance
self.session = db_session_with_containers
# Save test data to database
self.session.add(self.test_workflow)
self.session.commit()
yield
# Cleanup
self._cleanup_test_data()
def _cleanup_test_data(self):
"""Clean up test data after each test method."""
# Clean up workflow pauses
self.session.execute(delete(WorkflowPauseModel))
# Clean up upload files
self.session.execute(
delete(UploadFile).where(
UploadFile.tenant_id == self.test_tenant_id,
)
)
# Clean up workflow runs
self.session.execute(
delete(WorkflowRun).where(
WorkflowRun.tenant_id == self.test_tenant_id,
WorkflowRun.app_id == self.test_app_id,
)
)
# Clean up workflows
self.session.execute(
delete(Workflow).where(
Workflow.tenant_id == self.test_tenant_id,
Workflow.app_id == self.test_app_id,
)
)
self.session.commit()
def _create_test_workflow_run(
self, status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
) -> WorkflowRun:
"""Create a test workflow run with specified status."""
workflow_run = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
workflow_id=self.test_workflow_id,
type="workflow",
triggered_from="debugging",
version="draft",
status=status,
created_by=self.test_user_id,
created_by_role="account",
created_at=naive_utc_now(),
)
self.session.add(workflow_run)
self.session.commit()
return workflow_run
def _create_test_state(self) -> str:
"""Create a test state string."""
return json.dumps(
{
"node_id": "test-node",
"node_type": "llm",
"status": "paused",
"data": {"key": "value"},
"timestamp": naive_utc_now().isoformat(),
}
)
def _get_workflow_run_repository(self):
"""Get workflow run repository instance for testing."""
# Create session factory from the test session
engine = self.session.get_bind()
session_factory = sessionmaker(bind=engine, expire_on_commit=False)
# Create a test-specific repository that implements the missing save method
class TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
"""Test-specific repository that implements the missing save method."""
def save(self, execution: WorkflowExecution):
"""Implement the missing save method for testing."""
# For testing purposes, we don't need to implement this method
# as it's not used in the pause functionality tests
pass
# Create and return repository instance
repository = TestWorkflowRunRepository(session_maker=session_factory)
return repository
# ==================== Complete Pause Workflow Tests ====================
def test_complete_pause_resume_workflow(self):
"""Test complete workflow: create -> pause -> resume -> delete."""
# Arrange
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
# Act - Create pause state
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Assert - Pause state created
assert pause_entity is not None
assert pause_entity.id is not None
assert pause_entity.workflow_execution_id == workflow_run.id
# Normalize the retrieved state from bytes to str before comparing
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
# Verify database state
query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
pause_model = self.session.scalars(query).first()
assert pause_model is not None
assert pause_model.resumed_at is None
assert pause_model.id == pause_entity.id
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
# Act - Get pause state
retrieved_entity = repository.get_workflow_pause(workflow_run.id)
# Assert - Pause state retrieved
assert retrieved_entity is not None
assert retrieved_entity.id == pause_entity.id
retrieved_state = retrieved_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
# Act - Resume workflow
resumed_entity = repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
# Assert - Workflow resumed
assert resumed_entity is not None
assert resumed_entity.id == pause_entity.id
assert resumed_entity.resumed_at is not None
# Verify database state
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
self.session.refresh(pause_model)
assert pause_model.resumed_at is not None
# Act - Delete pause state
repository.delete_workflow_pause(pause_entity)
# Assert - Pause state deleted
with Session(bind=self.session.get_bind()) as session:
deleted_pause = session.get(WorkflowPauseModel, pause_entity.id)
assert deleted_pause is None
def test_pause_workflow_success(self):
"""Test successful pause workflow scenarios."""
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
assert pause_entity is not None
assert pause_entity.workflow_execution_id == workflow_run.id
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
pause_model = self.session.scalars(pause_query).first()
assert pause_model is not None
assert pause_model.id == pause_entity.id
assert pause_model.resumed_at is None
@pytest.mark.parametrize("test_case", pause_workflow_failure_cases(), ids=lambda tc: tc.name)
def test_pause_workflow_failure(self, test_case: PauseWorkflowFailureCase):
"""Test pause workflow failure scenarios."""
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
with pytest.raises(_WorkflowRunError):
repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
@pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)
def test_resume_workflow_success(self, test_case: ResumeWorkflowSuccessCase):
"""Test successful resume workflow scenarios."""
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
workflow_run.status = WorkflowExecutionStatus.RUNNING
self.session.commit()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
resumed_entity = repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
assert resumed_entity is not None
assert resumed_entity.id == pause_entity.id
assert resumed_entity.resumed_at is not None
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
pause_model = self.session.scalars(pause_query).first()
assert pause_model is not None
assert pause_model.id == pause_entity.id
assert pause_model.resumed_at is not None
def test_resume_running_workflow(self):
"""Test resume workflow failure scenarios."""
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
self.session.refresh(workflow_run)
workflow_run.status = WorkflowExecutionStatus.RUNNING
self.session.add(workflow_run)
self.session.commit()
with pytest.raises(_WorkflowRunError):
repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
def test_resume_resumed_pause(self):
"""Test resume workflow failure scenarios."""
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
pause_model.resumed_at = naive_utc_now()
self.session.add(pause_model)
self.session.commit()
with pytest.raises(_WorkflowRunError):
repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
# ==================== Error Scenario Tests ====================
def test_pause_nonexistent_workflow_run(self):
"""Test pausing a non-existent workflow run."""
# Arrange
nonexistent_id = str(uuid.uuid4())
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
# Act & Assert
with pytest.raises(ValueError, match="WorkflowRun not found"):
repository.create_workflow_pause(
workflow_run_id=nonexistent_id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
def test_resume_nonexistent_workflow_run(self):
"""Test resuming a non-existent workflow run."""
# Arrange
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
nonexistent_id = str(uuid.uuid4())
# Act & Assert
with pytest.raises(ValueError, match="WorkflowRun not found"):
repository.resume_workflow_pause(
workflow_run_id=nonexistent_id,
pause_entity=pause_entity,
)
# ==================== Prune Functionality Tests ====================
@pytest.mark.parametrize("test_case", prune_pauses_test_cases(), ids=lambda tc: tc.name)
def test_prune_pauses_scenarios(self, test_case: PrunePausesTestCase):
"""Test various prune pauses scenarios."""
now = naive_utc_now()
# Create pause state
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Manually adjust timestamps for testing
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
pause_model.created_at = now - test_case.pause_age
if test_case.resume_age is not None:
# Resume pause and adjust resume time
repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
# Need to refresh to get the updated model
self.session.refresh(pause_model)
# Manually set the resumed_at to an older time for testing
pause_model.resumed_at = now - test_case.resume_age
self.session.commit() # Commit the resumed_at change
# Refresh again to ensure the change is persisted
self.session.refresh(pause_model)
self.session.commit()
# Act - Prune pauses
expiration_time = now - timedelta(days=1, seconds=1) # Expire pauses older than 1 day (plus 1 second)
resumption_time = now - timedelta(
days=7, seconds=1
) # Clean up pauses resumed more than 7 days ago (plus 1 second)
# Debug: Check pause state before pruning
self.session.refresh(pause_model)
print(f"Pause created_at: {pause_model.created_at}")
print(f"Pause resumed_at: {pause_model.resumed_at}")
print(f"Expiration time: {expiration_time}")
print(f"Resumption time: {resumption_time}")
# Force commit to ensure timestamps are saved
self.session.commit()
# Determine if the pause should be pruned based on timestamps
should_be_pruned = False
if test_case.resume_age is not None:
# If resumed, check if resumed_at is older than resumption_time
should_be_pruned = pause_model.resumed_at < resumption_time
else:
# If not resumed, check if created_at is older than expiration_time
should_be_pruned = pause_model.created_at < expiration_time
# Act - Prune pauses
pruned_ids = repository.prune_pauses(
expiration=expiration_time,
resumption_expiration=resumption_time,
)
# Assert - Check pruning results
if should_be_pruned:
assert len(pruned_ids) == test_case.expected_pruned_count
# Verify pause was actually deleted
# The pause should be in the pruned_ids list if it was pruned
assert pause_entity.id in pruned_ids
else:
assert len(pruned_ids) == 0
def test_prune_pauses_with_limit(self):
"""Test prune pauses with limit parameter."""
now = naive_utc_now()
# Create multiple pause states
pause_entities = []
repository = self._get_workflow_run_repository()
for i in range(5):
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
pause_entities.append(pause_entity)
# Make all pauses old enough to be pruned
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
pause_model.created_at = now - timedelta(days=7)
self.session.commit()
# Act - Prune with limit
expiration_time = now - timedelta(days=1)
resumption_time = now - timedelta(days=7)
pruned_ids = repository.prune_pauses(
expiration=expiration_time,
resumption_expiration=resumption_time,
limit=3,
)
# Assert
assert len(pruned_ids) == 3
# Verify only 3 were deleted
remaining_count = (
self.session.query(WorkflowPauseModel)
.filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]))
.count()
)
assert remaining_count == 2
# ==================== Multi-tenant Isolation Tests ====================
def test_multi_tenant_pause_isolation(self):
"""Test that pause states are properly isolated by tenant."""
# Arrange - Create second tenant
tenant2 = Tenant(
name="Test Tenant 2",
status="normal",
)
self.session.add(tenant2)
self.session.commit()
account2 = Account(
email="test2@example.com",
name="Test User 2",
interface_language="en-US",
status="active",
)
self.session.add(account2)
self.session.commit()
tenant2_join = TenantAccountJoin(
tenant_id=tenant2.id,
account_id=account2.id,
role=TenantAccountRole.OWNER,
current=True,
)
self.session.add(tenant2_join)
self.session.commit()
# Create workflow for tenant 2
workflow2 = Workflow(
id=str(uuid.uuid4()),
tenant_id=tenant2.id,
app_id=str(uuid.uuid4()),
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features='{"file_upload": {"enabled": false}}',
created_by=account2.id,
created_at=naive_utc_now(),
)
self.session.add(workflow2)
self.session.commit()
# Create workflow runs for both tenants
workflow_run1 = self._create_test_workflow_run()
workflow_run2 = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=tenant2.id,
app_id=workflow2.app_id,
workflow_id=workflow2.id,
type="workflow",
triggered_from="debugging",
version="draft",
status=WorkflowExecutionStatus.RUNNING,
created_by=account2.id,
created_by_role="account",
created_at=naive_utc_now(),
)
self.session.add(workflow_run2)
self.session.commit()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
# Act - Create pause for tenant 1
pause_entity1 = repository.create_workflow_pause(
workflow_run_id=workflow_run1.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
        # Look up tenant 2's run through the shared repository; it has no pause yet,
        # so the lookup should return None rather than leak tenant 1's pause
pause_entity2 = repository.get_workflow_pause(workflow_run2.id)
assert pause_entity2 is None # No pause for tenant 2 yet
# Create pause for tenant 2
pause_entity2 = repository.create_workflow_pause(
workflow_run_id=workflow_run2.id,
state_owner_user_id=account2.id,
state=test_state,
)
# Assert - Both pauses should exist and be separate
assert pause_entity1 is not None
assert pause_entity2 is not None
assert pause_entity1.id != pause_entity2.id
assert pause_entity1.workflow_execution_id != pause_entity2.workflow_execution_id
def test_cross_tenant_access_restriction(self):
"""Test that cross-tenant access is properly restricted."""
        # Full cross-tenant enforcement would require tenant-scoped repositories;
        # for now, verify that a pause entity stays scoped to its originating workflow
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Verify pause is properly scoped
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
assert pause_model.workflow_id == self.test_workflow_id
# ==================== File Storage Integration Tests ====================
def test_file_storage_integration(self):
"""Test that state files are properly stored and retrieved."""
# Arrange
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
# Act - Create pause state
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Assert - Verify file was uploaded to storage
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
assert pause_model.state_object_key != ""
# Verify file content in storage
file_key = pause_model.state_object_key
storage_content = storage.load(file_key).decode()
assert storage_content == test_state
# Verify retrieval through entity
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
def test_file_cleanup_on_pause_deletion(self):
"""Test that files are properly handled on pause deletion."""
# Arrange
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Get file info before deletion
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
file_key = pause_model.state_object_key
# Act - Delete pause state
repository.delete_workflow_pause(pause_entity)
# Assert - Pause record should be deleted
self.session.expire_all() # Clear session to ensure fresh query
deleted_pause = self.session.get(WorkflowPauseModel, pause_entity.id)
assert deleted_pause is None
        try:
            storage.load(file_key)
            pytest.fail("File should be deleted from storage after pause deletion")
except FileNotFoundError:
# This is expected - file should be deleted from storage
pass
except Exception as e:
pytest.fail(f"Unexpected error when checking file deletion: {e}")
def test_large_state_file_handling(self):
"""Test handling of large state files."""
# Arrange - Create a large state (1MB)
large_state = "x" * (1024 * 1024) # 1MB of data
large_state_json = json.dumps({"large_data": large_state})
workflow_run = self._create_test_workflow_run()
repository = self._get_workflow_run_repository()
# Act
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=large_state_json,
)
# Assert
assert pause_entity is not None
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == large_state_json
# Verify file size in database
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
assert pause_model.state_object_key != ""
loaded_state = storage.load(pause_model.state_object_key)
assert loaded_state.decode() == large_state_json
def test_multiple_pause_resume_cycles(self):
"""Test multiple pause/resume cycles on the same workflow run."""
# Arrange
workflow_run = self._create_test_workflow_run()
repository = self._get_workflow_run_repository()
# Act & Assert - Multiple cycles
for i in range(3):
state = json.dumps({"cycle": i, "data": f"state_{i}"})
# Reset workflow run status to RUNNING before each pause (after first cycle)
if i > 0:
self.session.refresh(workflow_run) # Refresh to get latest state from session
workflow_run.status = WorkflowExecutionStatus.RUNNING
self.session.commit()
self.session.refresh(workflow_run) # Refresh again after commit
# Pause
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=state,
)
assert pause_entity is not None
# Verify pause
self.session.expire_all() # Clear session to ensure fresh query
self.session.refresh(workflow_run)
# Use the test session directly to verify the pause
stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run.id)
workflow_run_with_pause = self.session.scalar(stmt)
pause_model = workflow_run_with_pause.pause
# Verify pause using test session directly
assert pause_model is not None
assert pause_model.id == pause_entity.id
assert pause_model.state_object_key != ""
# Load file content using storage directly
file_content = storage.load(pause_model.state_object_key)
if isinstance(file_content, bytes):
file_content = file_content.decode()
assert file_content == state
# Resume
resumed_entity = repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
assert resumed_entity is not None
assert resumed_entity.resumed_at is not None
# Verify resume - check that pause is marked as resumed
self.session.expire_all() # Clear session to ensure fresh query
stmt = select(WorkflowPauseModel).where(WorkflowPauseModel.id == pause_entity.id)
resumed_pause_model = self.session.scalar(stmt)
assert resumed_pause_model is not None
assert resumed_pause_model.resumed_at is not None
# Verify workflow run status
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.RUNNING

View File

@ -1,324 +0,0 @@
"""
Unit tests for WorkflowResponseConverter focusing on process_data truncation functionality.
"""
import uuid
from collections.abc import Mapping
from typing import Any
from unittest.mock import Mock
import pytest
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.workflow.enums import NodeType
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from models import Account
class TestWorkflowResponseConverterScenarios:
"""Test process_data truncation in WorkflowResponseConverter."""
def create_mock_generate_entity(self) -> WorkflowAppGenerateEntity:
"""Create a mock WorkflowAppGenerateEntity."""
mock_entity = Mock(spec=WorkflowAppGenerateEntity)
mock_app_config = Mock()
mock_app_config.tenant_id = "test-tenant-id"
mock_entity.app_config = mock_app_config
mock_entity.inputs = {}
return mock_entity
def create_workflow_response_converter(self) -> WorkflowResponseConverter:
"""Create a WorkflowResponseConverter for testing."""
mock_entity = self.create_mock_generate_entity()
mock_user = Mock(spec=Account)
mock_user.id = "test-user-id"
mock_user.name = "Test User"
mock_user.email = "test@example.com"
system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id")
return WorkflowResponseConverter(
application_generate_entity=mock_entity,
user=mock_user,
system_variables=system_variables,
)
def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent:
"""Create a QueueNodeStartedEvent for testing."""
return QueueNodeStartedEvent(
node_execution_id=node_execution_id or str(uuid.uuid4()),
node_id="test-node-id",
node_title="Test Node",
node_type=NodeType.CODE,
start_at=naive_utc_now(),
predecessor_node_id=None,
in_iteration_id=None,
in_loop_id=None,
provider_type="built-in",
provider_id="code",
)
def create_node_succeeded_event(
self,
*,
node_execution_id: str,
process_data: Mapping[str, Any] | None = None,
) -> QueueNodeSucceededEvent:
"""Create a QueueNodeSucceededEvent for testing."""
return QueueNodeSucceededEvent(
node_id="test-node-id",
node_type=NodeType.CODE,
node_execution_id=node_execution_id,
start_at=naive_utc_now(),
in_iteration_id=None,
in_loop_id=None,
inputs={},
process_data=process_data or {},
outputs={},
execution_metadata={},
)
def create_node_retry_event(
self,
*,
node_execution_id: str,
process_data: Mapping[str, Any] | None = None,
) -> QueueNodeRetryEvent:
"""Create a QueueNodeRetryEvent for testing."""
return QueueNodeRetryEvent(
inputs={"data": "inputs"},
outputs={"data": "outputs"},
process_data=process_data or {},
error="oops",
retry_index=1,
node_id="test-node-id",
node_type=NodeType.CODE,
node_title="test code",
provider_type="built-in",
provider_id="code",
node_execution_id=node_execution_id,
start_at=naive_utc_now(),
in_iteration_id=None,
in_loop_id=None,
)
def test_workflow_node_finish_response_uses_truncated_process_data(self):
"""Test that node finish response uses get_response_process_data()."""
converter = self.create_workflow_response_converter()
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_succeeded_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
if mapping == dict(original_data):
return truncated_data, True
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)
# Response should use truncated data, not original
assert response is not None
assert response.data.process_data == truncated_data
assert response.data.process_data != original_data
assert response.data.process_data_truncated is True
def test_workflow_node_finish_response_without_truncation(self):
"""Test node finish response when no truncation is applied."""
converter = self.create_workflow_response_converter()
original_data = {"small": "data"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_succeeded_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)
# Response should use original data
assert response is not None
assert response.data.process_data == original_data
assert response.data.process_data_truncated is False
def test_workflow_node_finish_response_with_none_process_data(self):
"""Test node finish response when process_data is None."""
converter = self.create_workflow_response_converter()
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_succeeded_event(
node_execution_id=start_event.node_execution_id,
process_data=None,
)
def fake_truncate(mapping):
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)
# Response should normalize missing process_data to an empty mapping
assert response is not None
assert response.data.process_data == {}
assert response.data.process_data_truncated is False
def test_workflow_node_retry_response_uses_truncated_process_data(self):
"""Test that node retry response uses get_response_process_data()."""
converter = self.create_workflow_response_converter()
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_retry_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
if mapping == dict(original_data):
return truncated_data, True
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_retry_to_stream_response(
event=event,
task_id="test-task-id",
)
# Response should use truncated data, not original
assert response is not None
assert response.data.process_data == truncated_data
assert response.data.process_data != original_data
assert response.data.process_data_truncated is True
def test_workflow_node_retry_response_without_truncation(self):
"""Test node retry response when no truncation is applied."""
converter = self.create_workflow_response_converter()
original_data = {"small": "data"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_retry_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_retry_to_stream_response(
event=event,
task_id="test-task-id",
)
assert response is not None
assert response.data.process_data == original_data
assert response.data.process_data_truncated is False
def test_iteration_and_loop_nodes_return_none(self):
"""Test that iteration and loop nodes return None (no streaming events)."""
converter = self.create_workflow_response_converter()
iteration_event = QueueNodeSucceededEvent(
node_id="iteration-node",
node_type=NodeType.ITERATION,
node_execution_id=str(uuid.uuid4()),
start_at=naive_utc_now(),
in_iteration_id=None,
in_loop_id=None,
inputs={},
process_data={},
outputs={},
execution_metadata={},
)
response = converter.workflow_node_finish_to_stream_response(
event=iteration_event,
task_id="test-task-id",
)
assert response is None
loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP})
response = converter.workflow_node_finish_to_stream_response(
event=loop_event,
task_id="test-task-id",
)
assert response is None
def test_finish_without_start_raises(self):
"""Ensure finish responses require a prior workflow start."""
converter = self.create_workflow_response_converter()
event = self.create_node_succeeded_event(
node_execution_id=str(uuid.uuid4()),
process_data={},
)
with pytest.raises(ValueError):
converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)

View File

@ -0,0 +1,810 @@
"""
Unit tests for WorkflowResponseConverter focusing on process_data truncation functionality.
"""
import uuid
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
from unittest.mock import Mock
import pytest
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueEvent,
QueueIterationStartEvent,
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.workflow.enums import NodeType
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from models import Account
from models.model import AppMode
class TestWorkflowResponseConverter:
"""Test truncation in WorkflowResponseConverter."""
def create_mock_generate_entity(self) -> WorkflowAppGenerateEntity:
"""Create a mock WorkflowAppGenerateEntity."""
mock_entity = Mock(spec=WorkflowAppGenerateEntity)
mock_app_config = Mock()
mock_app_config.tenant_id = "test-tenant-id"
mock_entity.invoke_from = InvokeFrom.WEB_APP
mock_entity.app_config = mock_app_config
mock_entity.inputs = {}
return mock_entity
def create_workflow_response_converter(self) -> WorkflowResponseConverter:
"""Create a WorkflowResponseConverter for testing."""
mock_entity = self.create_mock_generate_entity()
mock_user = Mock(spec=Account)
mock_user.id = "test-user-id"
mock_user.name = "Test User"
mock_user.email = "test@example.com"
system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id")
return WorkflowResponseConverter(
application_generate_entity=mock_entity,
user=mock_user,
system_variables=system_variables,
)
def create_node_started_event(self, *, node_execution_id: str | None = None) -> QueueNodeStartedEvent:
"""Create a QueueNodeStartedEvent for testing."""
return QueueNodeStartedEvent(
node_execution_id=node_execution_id or str(uuid.uuid4()),
node_id="test-node-id",
node_title="Test Node",
node_type=NodeType.CODE,
start_at=naive_utc_now(),
in_iteration_id=None,
in_loop_id=None,
provider_type="built-in",
provider_id="code",
)
def create_node_succeeded_event(
self,
*,
node_execution_id: str,
process_data: Mapping[str, Any] | None = None,
) -> QueueNodeSucceededEvent:
"""Create a QueueNodeSucceededEvent for testing."""
return QueueNodeSucceededEvent(
node_id="test-node-id",
node_type=NodeType.CODE,
node_execution_id=node_execution_id,
start_at=naive_utc_now(),
in_iteration_id=None,
in_loop_id=None,
inputs={},
process_data=process_data or {},
outputs={},
execution_metadata={},
)
def create_node_retry_event(
self,
*,
node_execution_id: str,
process_data: Mapping[str, Any] | None = None,
) -> QueueNodeRetryEvent:
"""Create a QueueNodeRetryEvent for testing."""
return QueueNodeRetryEvent(
inputs={"data": "inputs"},
outputs={"data": "outputs"},
process_data=process_data or {},
error="oops",
retry_index=1,
node_id="test-node-id",
node_type=NodeType.CODE,
node_title="test code",
provider_type="built-in",
provider_id="code",
node_execution_id=node_execution_id,
start_at=naive_utc_now(),
in_iteration_id=None,
in_loop_id=None,
)
def test_workflow_node_finish_response_uses_truncated_process_data(self):
"""Test that node finish response uses get_response_process_data()."""
converter = self.create_workflow_response_converter()
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_succeeded_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
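        # Stub the converter's truncator: report truncation only for the original
        # mapping, mirroring the (mapping, truncated) return shape of
        # truncate_variable_mapping.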
def fake_truncate(mapping):
if mapping == dict(original_data):
return truncated_data, True
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)
# Response should use truncated data, not original
assert response is not None
assert response.data.process_data == truncated_data
assert response.data.process_data != original_data
assert response.data.process_data_truncated is True
def test_workflow_node_finish_response_without_truncation(self):
"""Test node finish response when no truncation is applied."""
converter = self.create_workflow_response_converter()
original_data = {"small": "data"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_succeeded_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)
# Response should use original data
assert response is not None
assert response.data.process_data == original_data
assert response.data.process_data_truncated is False
def test_workflow_node_finish_response_with_none_process_data(self):
"""Test node finish response when process_data is None."""
converter = self.create_workflow_response_converter()
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_succeeded_event(
node_execution_id=start_event.node_execution_id,
process_data=None,
)
def fake_truncate(mapping):
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)
# Response should normalize missing process_data to an empty mapping
assert response is not None
assert response.data.process_data == {}
assert response.data.process_data_truncated is False
def test_workflow_node_retry_response_uses_truncated_process_data(self):
"""Test that node retry response uses get_response_process_data()."""
converter = self.create_workflow_response_converter()
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_retry_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
if mapping == dict(original_data):
return truncated_data, True
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_retry_to_stream_response(
event=event,
task_id="test-task-id",
)
# Response should use truncated data, not original
assert response is not None
assert response.data.process_data == truncated_data
assert response.data.process_data != original_data
assert response.data.process_data_truncated is True
def test_workflow_node_retry_response_without_truncation(self):
"""Test node retry response when no truncation is applied."""
converter = self.create_workflow_response_converter()
original_data = {"small": "data"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
task_id="test-task-id",
)
event = self.create_node_retry_event(
node_execution_id=start_event.node_execution_id,
process_data=original_data,
)
def fake_truncate(mapping):
return mapping, False
converter._truncator.truncate_variable_mapping = fake_truncate # type: ignore[assignment]
response = converter.workflow_node_retry_to_stream_response(
event=event,
task_id="test-task-id",
)
assert response is not None
assert response.data.process_data == original_data
assert response.data.process_data_truncated is False
def test_iteration_and_loop_nodes_return_none(self):
"""Test that iteration and loop nodes return None (no streaming events)."""
converter = self.create_workflow_response_converter()
iteration_event = QueueNodeSucceededEvent(
node_id="iteration-node",
node_type=NodeType.ITERATION,
node_execution_id=str(uuid.uuid4()),
start_at=naive_utc_now(),
in_iteration_id=None,
in_loop_id=None,
inputs={},
process_data={},
outputs={},
execution_metadata={},
)
response = converter.workflow_node_finish_to_stream_response(
event=iteration_event,
task_id="test-task-id",
)
assert response is None
loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP})
response = converter.workflow_node_finish_to_stream_response(
event=loop_event,
task_id="test-task-id",
)
assert response is None
def test_finish_without_start_raises(self):
"""Ensure finish responses require a prior workflow start."""
converter = self.create_workflow_response_converter()
event = self.create_node_succeeded_event(
node_execution_id=str(uuid.uuid4()),
process_data={},
)
with pytest.raises(ValueError):
converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)
@dataclass
class TestCase:
"""Test case data for table-driven tests."""
name: str
invoke_from: InvokeFrom
expected_truncation_enabled: bool
description: str
class TestWorkflowResponseConverterServiceApiTruncation:
"""Test class for Service API truncation functionality in WorkflowResponseConverter."""
def create_test_app_generate_entity(self, invoke_from: InvokeFrom) -> WorkflowAppGenerateEntity:
"""Create a test WorkflowAppGenerateEntity with specified invoke_from."""
# Create a minimal WorkflowUIBasedAppConfig for testing
app_config = WorkflowUIBasedAppConfig(
tenant_id="test_tenant",
app_id="test_app",
app_mode=AppMode.WORKFLOW,
workflow_id="test_workflow_id",
)
entity = WorkflowAppGenerateEntity(
task_id="test_task_id",
app_id="test_app_id",
app_config=app_config,
tenant_id="test_tenant",
app_mode=AppMode.WORKFLOW,
invoke_from=invoke_from,
inputs={"test_input": "test_value"},
user_id="test_user_id",
stream=True,
files=[],
workflow_execution_id="test_workflow_exec_id",
)
return entity
def create_test_user(self) -> Account:
"""Create a test user account."""
account = Account(
name="Test User",
email="test@example.com",
)
# Manually set the ID for testing purposes
account.id = "test_user_id"
return account
def create_test_system_variables(self) -> SystemVariable:
"""Create test system variables."""
return SystemVariable()
def create_test_converter(self, invoke_from: InvokeFrom) -> WorkflowResponseConverter:
"""Create WorkflowResponseConverter with specified invoke_from."""
entity = self.create_test_app_generate_entity(invoke_from)
user = self.create_test_user()
system_variables = self.create_test_system_variables()
converter = WorkflowResponseConverter(
application_generate_entity=entity,
user=user,
system_variables=system_variables,
)
        # Ensure `workflow_run_id` is set before emitting node-level responses.
converter.workflow_start_to_stream_response(
task_id="test-task-id",
workflow_run_id="test-workflow-run-id",
workflow_id="test-workflow-id",
)
return converter
@pytest.mark.parametrize(
"test_case",
[
TestCase(
name="service_api_truncation_disabled",
invoke_from=InvokeFrom.SERVICE_API,
expected_truncation_enabled=False,
description="Service API calls should have truncation disabled",
),
TestCase(
name="web_app_truncation_enabled",
invoke_from=InvokeFrom.WEB_APP,
expected_truncation_enabled=True,
description="Web app calls should have truncation enabled",
),
TestCase(
name="debugger_truncation_enabled",
invoke_from=InvokeFrom.DEBUGGER,
expected_truncation_enabled=True,
description="Debugger calls should have truncation enabled",
),
TestCase(
name="explore_truncation_enabled",
invoke_from=InvokeFrom.EXPLORE,
expected_truncation_enabled=True,
description="Explore calls should have truncation enabled",
),
TestCase(
name="published_truncation_enabled",
invoke_from=InvokeFrom.PUBLISHED,
expected_truncation_enabled=True,
description="Published app calls should have truncation enabled",
),
],
ids=lambda x: x.name,
)
def test_truncator_selection_based_on_invoke_from(self, test_case: TestCase):
"""Test that the correct truncator is selected based on invoke_from."""
converter = self.create_test_converter(test_case.invoke_from)
# Test truncation behavior instead of checking private attribute
# Create a test event with large data
large_value = {"key": ["x"] * 2000} # Large data that would be truncated
event = QueueNodeSucceededEvent(
node_execution_id="test_node_exec_id",
node_id="test_node",
node_type=NodeType.LLM,
start_at=naive_utc_now(),
inputs=large_value,
process_data=large_value,
outputs=large_value,
error=None,
execution_metadata=None,
in_iteration_id=None,
in_loop_id=None,
)
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test_task",
)
# Verify response is not None
assert response is not None
# Verify truncation behavior matches expectations
if test_case.expected_truncation_enabled:
# Truncation should be enabled for non-service-api calls
assert response.data.inputs_truncated
assert response.data.process_data_truncated
assert response.data.outputs_truncated
else:
# SERVICE_API should not truncate
assert not response.data.inputs_truncated
assert not response.data.process_data_truncated
assert not response.data.outputs_truncated
def test_service_api_truncator_no_op_mapping(self):
"""Test that Service API truncator doesn't truncate variable mappings."""
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
# Create a test event with large data
large_value: dict[str, Any] = {
"large_string": "x" * 10000, # Large string
"large_list": list(range(2000)), # Large array
"nested_data": {"deep_nested": {"very_deep": {"value": "x" * 5000}}},
}
event = QueueNodeSucceededEvent(
node_execution_id="test_node_exec_id",
node_id="test_node",
node_type=NodeType.LLM,
start_at=naive_utc_now(),
inputs=large_value,
process_data=large_value,
outputs=large_value,
error=None,
execution_metadata=None,
in_iteration_id=None,
in_loop_id=None,
)
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test_task",
)
        # Verify response is not None before reading its data
        assert response is not None
        data = response.data
assert data.inputs == large_value
assert data.process_data == large_value
assert data.outputs == large_value
# Service API should not truncate
assert data.inputs_truncated is False
assert data.process_data_truncated is False
assert data.outputs_truncated is False
def test_web_app_truncator_works_normally(self):
"""Test that web app truncator still works normally."""
converter = self.create_test_converter(InvokeFrom.WEB_APP)
# Create a test event with large data
large_value = {
"large_string": "x" * 10000, # Large string
"large_list": list(range(2000)), # Large array
}
event = QueueNodeSucceededEvent(
node_execution_id="test_node_exec_id",
node_id="test_node",
node_type=NodeType.LLM,
start_at=naive_utc_now(),
inputs=large_value,
process_data=large_value,
outputs=large_value,
error=None,
execution_metadata=None,
in_iteration_id=None,
in_loop_id=None,
)
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test_task",
)
# Verify response is not None
assert response is not None
# Web app should truncate
data = response.data
assert data.inputs != large_value
assert data.process_data != large_value
assert data.outputs != large_value
        # The truncated values depend on VariableTruncator's limits; assert only
        # that the truncation flags were set for each section
assert data.inputs_truncated is True
assert data.process_data_truncated is True
assert data.outputs_truncated is True
@staticmethod
def _create_event_by_type(
type_: QueueEvent, inputs: Mapping[str, Any], process_data: Mapping[str, Any], outputs: Mapping[str, Any]
) -> QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent:
if type_ == QueueEvent.NODE_SUCCEEDED:
return QueueNodeSucceededEvent(
node_execution_id="test_node_exec_id",
node_id="test_node",
node_type=NodeType.LLM,
start_at=naive_utc_now(),
inputs=inputs,
process_data=process_data,
outputs=outputs,
error=None,
execution_metadata=None,
in_iteration_id=None,
in_loop_id=None,
)
elif type_ == QueueEvent.NODE_FAILED:
return QueueNodeFailedEvent(
node_execution_id="test_node_exec_id",
node_id="test_node",
node_type=NodeType.LLM,
start_at=naive_utc_now(),
inputs=inputs,
process_data=process_data,
outputs=outputs,
error="oops",
execution_metadata=None,
in_iteration_id=None,
in_loop_id=None,
)
elif type_ == QueueEvent.NODE_EXCEPTION:
return QueueNodeExceptionEvent(
node_execution_id="test_node_exec_id",
node_id="test_node",
node_type=NodeType.LLM,
start_at=naive_utc_now(),
inputs=inputs,
process_data=process_data,
outputs=outputs,
error="oops",
execution_metadata=None,
in_iteration_id=None,
in_loop_id=None,
)
        else:
            raise ValueError(f"unsupported event type: {type_}")
@pytest.mark.parametrize(
"event_type",
[
QueueEvent.NODE_SUCCEEDED,
QueueEvent.NODE_FAILED,
QueueEvent.NODE_EXCEPTION,
],
)
def test_service_api_node_finish_event_no_truncation(self, event_type: QueueEvent):
"""Test that Service API doesn't truncate node finish events."""
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
# Create test event with large data
large_inputs = {"input1": "x" * 5000, "input2": list(range(2000))}
large_process_data = {"process1": "y" * 5000, "process2": {"nested": ["z"] * 2000}}
large_outputs = {"output1": "result" * 1000, "output2": list(range(2000))}
event = TestWorkflowResponseConverterServiceApiTruncation._create_event_by_type(
event_type, large_inputs, large_process_data, large_outputs
)
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test_task",
)
# Verify response is not None
assert response is not None
# Verify response contains full data (not truncated)
assert response.data.inputs == large_inputs
assert response.data.process_data == large_process_data
assert response.data.outputs == large_outputs
assert not response.data.inputs_truncated
assert not response.data.process_data_truncated
assert not response.data.outputs_truncated
def test_service_api_node_retry_event_no_truncation(self):
"""Test that Service API doesn't truncate node retry events."""
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
# Create test event with large data
large_inputs = {"retry_input": "x" * 5000}
large_process_data = {"retry_process": "y" * 5000}
large_outputs = {"retry_output": "z" * 5000}
# First, we need to store a snapshot by simulating a start event
start_event = QueueNodeStartedEvent(
node_execution_id="test_node_exec_id",
node_id="test_node",
node_type=NodeType.LLM,
node_title="Test Node",
node_run_index=1,
start_at=naive_utc_now(),
in_iteration_id=None,
in_loop_id=None,
agent_strategy=None,
provider_type="plugin",
provider_id="test/test_plugin",
)
converter.workflow_node_start_to_stream_response(event=start_event, task_id="test_task")
# Now create retry event
event = QueueNodeRetryEvent(
node_execution_id="test_node_exec_id",
node_id="test_node",
node_type=NodeType.LLM,
node_title="Test Node",
node_run_index=1,
start_at=naive_utc_now(),
inputs=large_inputs,
process_data=large_process_data,
outputs=large_outputs,
error="Retry error",
execution_metadata=None,
in_iteration_id=None,
in_loop_id=None,
retry_index=1,
provider_type="plugin",
provider_id="test/test_plugin",
)
response = converter.workflow_node_retry_to_stream_response(
event=event,
task_id="test_task",
)
# Verify response is not None
assert response is not None
# Verify response contains full data (not truncated)
assert response.data.inputs == large_inputs
assert response.data.process_data == large_process_data
assert response.data.outputs == large_outputs
assert not response.data.inputs_truncated
assert not response.data.process_data_truncated
assert not response.data.outputs_truncated
def test_service_api_iteration_events_no_truncation(self):
"""Test that Service API doesn't truncate iteration events."""
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
# Test iteration start event
large_value = {"iteration_input": ["x"] * 2000}
start_event = QueueIterationStartEvent(
node_execution_id="test_iter_exec_id",
node_id="test_iteration",
node_type=NodeType.ITERATION,
node_title="Test Iteration",
node_run_index=0,
start_at=naive_utc_now(),
inputs=large_value,
metadata={},
)
response = converter.workflow_iteration_start_to_stream_response(
task_id="test_task",
workflow_execution_id="test_workflow_exec_id",
event=start_event,
)
assert response is not None
assert response.data.inputs == large_value
assert not response.data.inputs_truncated
def test_service_api_loop_events_no_truncation(self):
"""Test that Service API doesn't truncate loop events."""
converter = self.create_test_converter(InvokeFrom.SERVICE_API)
# Test loop start event
large_inputs = {"loop_input": ["x"] * 2000}
start_event = QueueLoopStartEvent(
node_execution_id="test_loop_exec_id",
node_id="test_loop",
node_type=NodeType.LOOP,
node_title="Test Loop",
start_at=naive_utc_now(),
inputs=large_inputs,
metadata={},
node_run_index=0,
)
response = converter.workflow_loop_start_to_stream_response(
task_id="test_task",
workflow_execution_id="test_workflow_exec_id",
event=start_event,
)
assert response is not None
assert response.data.inputs == large_inputs
assert not response.data.inputs_truncated
def test_web_app_node_finish_event_truncation_works(self):
"""Test that web app still truncates node finish events."""
converter = self.create_test_converter(InvokeFrom.WEB_APP)
# Create test event with large data that should be truncated
large_inputs = {"input1": ["x"] * 2000}
large_process_data = {"process1": ["y"] * 2000}
large_outputs = {"output1": ["z"] * 2000}
event = QueueNodeSucceededEvent(
node_execution_id="test_node_exec_id",
node_id="test_node",
node_type=NodeType.LLM,
start_at=naive_utc_now(),
inputs=large_inputs,
process_data=large_process_data,
outputs=large_outputs,
error=None,
execution_metadata=None,
in_iteration_id=None,
in_loop_id=None,
)
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test_task",
)
# Verify response is not None
assert response is not None
        # The truncated values depend on VariableTruncator's limits; with payloads
        # this large, every section should be flagged as truncated
assert isinstance(response.data.inputs, dict)
assert response.data.inputs_truncated
assert isinstance(response.data.process_data, dict)
assert response.data.process_data_truncated
assert isinstance(response.data.outputs, dict)
assert response.data.outputs_truncated

View File

@ -0,0 +1,278 @@
import json
from time import time
from unittest.mock import Mock
import pytest
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.variables.segments import Segment
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_events.graph import (
GraphRunFailedEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
from repositories.factory import DifyAPIRepositoryFactory
class TestDataFactory:
"""Factory helpers for constructing graph events used in tests."""
@staticmethod
def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})
@staticmethod
def create_graph_run_started_event() -> GraphRunStartedEvent:
return GraphRunStartedEvent()
@staticmethod
def create_graph_run_succeeded_event(outputs: dict[str, object] | None = None) -> GraphRunSucceededEvent:
return GraphRunSucceededEvent(outputs=outputs or {})
@staticmethod
def create_graph_run_failed_event(
error: str = "Test error",
exceptions_count: int = 1,
) -> GraphRunFailedEvent:
return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count)
class MockSystemVariableReadOnlyView:
"""Minimal read-only system variable view for testing."""
def __init__(self, workflow_execution_id: str | None = None) -> None:
self._workflow_execution_id = workflow_execution_id
@property
def workflow_execution_id(self) -> str | None:
return self._workflow_execution_id
class MockReadOnlyVariablePool:
"""Mock implementation of ReadOnlyVariablePool for testing."""
def __init__(self, variables: dict[tuple[str, str], object] | None = None):
self._variables = variables or {}
def get(self, node_id: str, variable_key: str) -> Segment | None:
value = self._variables.get((node_id, variable_key))
if value is None:
return None
mock_segment = Mock(spec=Segment)
mock_segment.value = value
return mock_segment
def get_all_by_node(self, node_id: str) -> dict[str, object]:
return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
def get_by_prefix(self, prefix: str) -> dict[str, object]:
return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)}
class MockReadOnlyGraphRuntimeState:
"""Mock implementation of ReadOnlyGraphRuntimeState for testing."""
def __init__(
self,
start_at: float | None = None,
total_tokens: int = 0,
node_run_steps: int = 0,
ready_queue_size: int = 0,
exceptions_count: int = 0,
outputs: dict[str, object] | None = None,
variables: dict[tuple[str, str], object] | None = None,
workflow_execution_id: str | None = None,
):
self._start_at = start_at or time()
self._total_tokens = total_tokens
self._node_run_steps = node_run_steps
self._ready_queue_size = ready_queue_size
self._exceptions_count = exceptions_count
self._outputs = outputs or {}
self._variable_pool = MockReadOnlyVariablePool(variables)
self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id)
@property
def system_variable(self) -> MockSystemVariableReadOnlyView:
return self._system_variable
@property
def variable_pool(self) -> ReadOnlyVariablePool:
return self._variable_pool
@property
def start_at(self) -> float:
return self._start_at
@property
def total_tokens(self) -> int:
return self._total_tokens
@property
def node_run_steps(self) -> int:
return self._node_run_steps
@property
def ready_queue_size(self) -> int:
return self._ready_queue_size
@property
def exceptions_count(self) -> int:
return self._exceptions_count
@property
def outputs(self) -> dict[str, object]:
return self._outputs.copy()
@property
def llm_usage(self):
mock_usage = Mock()
mock_usage.prompt_tokens = 10
mock_usage.completion_tokens = 20
mock_usage.total_tokens = 30
return mock_usage
def get_output(self, key: str, default: object = None) -> object:
return self._outputs.get(key, default)
def dumps(self) -> str:
return json.dumps(
{
"start_at": self._start_at,
"total_tokens": self._total_tokens,
"node_run_steps": self._node_run_steps,
"ready_queue_size": self._ready_queue_size,
"exceptions_count": self._exceptions_count,
"outputs": self._outputs,
"variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()},
"workflow_execution_id": self._system_variable.workflow_execution_id,
}
)
class MockCommandChannel:
"""Mock implementation of CommandChannel for testing."""
def __init__(self):
self._commands: list[GraphEngineCommand] = []
def fetch_commands(self) -> list[GraphEngineCommand]:
return self._commands.copy()
def send_command(self, command: GraphEngineCommand) -> None:
self._commands.append(command)
class TestPauseStatePersistenceLayer:
"""Unit tests for PauseStatePersistenceLayer."""
def test_init_with_dependency_injection(self):
session_factory = Mock(name="session_factory")
state_owner_user_id = "user-123"
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id=state_owner_user_id,
)
assert layer._session_maker is session_factory
assert layer._state_owner_user_id == state_owner_user_id
assert not hasattr(layer, "graph_runtime_state")
assert not hasattr(layer, "command_channel")
def test_initialize_sets_dependencies(self):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner")
graph_runtime_state = MockReadOnlyGraphRuntimeState()
command_channel = MockCommandChannel()
layer.initialize(graph_runtime_state, command_channel)
assert layer.graph_runtime_state is graph_runtime_state
assert layer.command_channel is command_channel
def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo)
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
graph_runtime_state = MockReadOnlyGraphRuntimeState(
outputs={"result": "test_output"},
total_tokens=100,
workflow_execution_id="run-123",
)
command_channel = MockCommandChannel()
layer.initialize(graph_runtime_state, command_channel)
event = TestDataFactory.create_graph_run_paused_event(outputs={"intermediate": "result"})
expected_state = graph_runtime_state.dumps()
layer.on_event(event)
mock_factory.assert_called_once_with(session_factory)
mock_repo.create_workflow_pause.assert_called_once_with(
workflow_run_id="run-123",
state_owner_user_id="owner-123",
state=expected_state,
)
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo)
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
graph_runtime_state = MockReadOnlyGraphRuntimeState()
command_channel = MockCommandChannel()
layer.initialize(graph_runtime_state, command_channel)
events = [
TestDataFactory.create_graph_run_started_event(),
TestDataFactory.create_graph_run_succeeded_event(),
TestDataFactory.create_graph_run_failed_event(),
]
for event in events:
layer.on_event(event)
mock_factory.assert_not_called()
mock_repo.create_workflow_pause.assert_not_called()
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
event = TestDataFactory.create_graph_run_paused_event()
with pytest.raises(AttributeError):
layer.on_event(event)
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo)
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
graph_runtime_state = MockReadOnlyGraphRuntimeState(workflow_execution_id=None)
command_channel = MockCommandChannel()
layer.initialize(graph_runtime_state, command_channel)
event = TestDataFactory.create_graph_run_paused_event()
with pytest.raises(AssertionError):
layer.on_event(event)
mock_factory.assert_not_called()
mock_repo.create_workflow_pause.assert_not_called()

View File

@ -23,3 +23,32 @@ def test_file():
assert file.extension == ".png"
assert file.mime_type == "image/png"
assert file.size == 67
def test_file_model_validate_with_legacy_fields():
"""Test `File` model can handle data containing compatibility fields."""
data = {
"id": "test-file",
"tenant_id": "test-tenant-id",
"type": "image",
"transfer_method": "tool_file",
"related_id": "test-related-id",
"filename": "image.png",
"extension": ".png",
"mime_type": "image/png",
"size": 67,
"storage_key": "test-storage-key",
"url": "https://example.com/image.png",
# Extra legacy fields
"tool_file_id": "tool-file-123",
"upload_file_id": "upload-file-456",
"datasource_file_id": "datasource-file-789",
}
# Should be able to create `File` object without raising an exception
file = File.model_validate(data)
    # The File model does not define tool_file_id, upload_file_id, or datasource_file_id
    # as attributes, so the validated object must not expose these legacy fields.
for legacy_field in ("tool_file_id", "upload_file_id", "datasource_file_id"):
assert not hasattr(file, legacy_field)

View File

@ -0,0 +1,12 @@
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
def test_get_runner_script():
code = JavascriptCodeProvider.get_default_code()
inputs = {"arg1": "hello, ", "arg2": "world!"}
script = NodeJsTemplateTransformer.assemble_runner_script(code, inputs)
script_lines = script.splitlines()
code_lines = code.splitlines()
    # Check that the first lines of the script are exactly the original code
assert script_lines[: len(code_lines)] == code_lines

View File

@ -0,0 +1,12 @@
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
def test_get_runner_script():
code = Python3CodeProvider.get_default_code()
inputs = {"arg1": "hello, ", "arg2": "world!"}
script = Python3TemplateTransformer.assemble_runner_script(code, inputs)
script_lines = script.splitlines()
code_lines = code.splitlines()
    # Check that the first lines of the script are exactly the original code
assert script_lines[: len(code_lines)] == code_lines

View File

@ -395,9 +395,6 @@ def test_client_capabilities_default():
# Assert default capabilities
assert received_capabilities is not None
assert received_capabilities.sampling is not None
assert received_capabilities.roots is not None
assert received_capabilities.roots.listChanged is True
def test_client_capabilities_with_custom_callbacks():

View File

@ -0,0 +1,171 @@
"""Tests for _PrivateWorkflowPauseEntity implementation."""
from datetime import datetime
from unittest.mock import MagicMock, patch
from models.workflow import WorkflowPause as WorkflowPauseModel
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
class TestPrivateWorkflowPauseEntity:
"""Test _PrivateWorkflowPauseEntity implementation."""
def test_entity_initialization(self):
"""Test entity initialization with required parameters."""
# Create mock models
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.id = "pause-123"
mock_pause_model.workflow_run_id = "execution-456"
mock_pause_model.resumed_at = None
# Create entity
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
# Verify initialization
assert entity._pause_model is mock_pause_model
assert entity._cached_state is None
def test_from_models_classmethod(self):
"""Test from_models class method."""
# Create mock models
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.id = "pause-123"
mock_pause_model.workflow_run_id = "execution-456"
# Create entity using from_models
entity = _PrivateWorkflowPauseEntity.from_models(
workflow_pause_model=mock_pause_model,
)
# Verify entity creation
assert isinstance(entity, _PrivateWorkflowPauseEntity)
assert entity._pause_model is mock_pause_model
def test_id_property(self):
"""Test id property returns pause model ID."""
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.id = "pause-123"
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
assert entity.id == "pause-123"
def test_workflow_execution_id_property(self):
"""Test workflow_execution_id property returns workflow run ID."""
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.workflow_run_id = "execution-456"
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
assert entity.workflow_execution_id == "execution-456"
def test_resumed_at_property(self):
"""Test resumed_at property returns pause model resumed_at."""
resumed_at = datetime(2023, 12, 25, 15, 30, 45)
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.resumed_at = resumed_at
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
assert entity.resumed_at == resumed_at
def test_resumed_at_property_none(self):
"""Test resumed_at property returns None when not set."""
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.resumed_at = None
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
assert entity.resumed_at is None
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
def test_get_state_first_call(self, mock_storage):
"""Test get_state loads from storage on first call."""
state_data = b'{"test": "data", "step": 5}'
mock_storage.load.return_value = state_data
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.state_object_key = "test-state-key"
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
# First call should load from storage
result = entity.get_state()
assert result == state_data
mock_storage.load.assert_called_once_with("test-state-key")
assert entity._cached_state == state_data
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
def test_get_state_cached_call(self, mock_storage):
"""Test get_state returns cached data on subsequent calls."""
state_data = b'{"test": "data", "step": 5}'
mock_storage.load.return_value = state_data
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.state_object_key = "test-state-key"
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
# First call
result1 = entity.get_state()
# Second call should use cache
result2 = entity.get_state()
assert result1 == state_data
assert result2 == state_data
# Storage should only be called once
mock_storage.load.assert_called_once_with("test-state-key")
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
def test_get_state_with_pre_cached_data(self, mock_storage):
"""Test get_state returns pre-cached data."""
state_data = b'{"test": "data", "step": 5}'
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
# Pre-cache data
entity._cached_state = state_data
# Should return cached data without calling storage
result = entity.get_state()
assert result == state_data
mock_storage.load.assert_not_called()
def test_entity_with_binary_state_data(self):
"""Test entity with binary state data."""
# Test with binary data that's not valid JSON
binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
mock_storage.load.return_value = binary_data
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
result = entity.get_state()
assert result == binary_data
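# Editor's sketch (assumed shape): the lazy-load-and-cache pattern these tests
# exercise loads the state bytes from storage once, then serves the cached
# copy on every later call; names here are illustrative.
class _CachedStateSketch:
    def __init__(self, state_object_key: str, storage) -> None:
        self._key = state_object_key
        self._storage = storage
        self._cached_state: bytes | None = None
    def get_state(self) -> bytes:
        if self._cached_state is None:  # first call hits storage
            self._cached_state = self._storage.load(self._key)
        return self._cached_state  # later calls return the cached bytes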

View File

@ -3,6 +3,7 @@
import time
from unittest.mock import MagicMock
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
@ -149,8 +150,8 @@ def test_pause_command():
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
assert len(pause_events) == 1
assert pause_events[0].reason == "User requested pause"
assert pause_events[0].reason == SchedulingPause(message="User requested pause")
graph_execution = engine.graph_runtime_state.graph_execution
assert graph_execution.is_paused
assert graph_execution.pause_reason == "User requested pause"
assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")

View File

@ -0,0 +1,96 @@
"""
Test cases for the Iteration node's flatten_output functionality.
This module tests the iteration node's ability to:
1. Flatten array outputs when flatten_output=True (default)
2. Preserve nested array structure when flatten_output=False
"""
from .test_table_runner import TableTestRunner, WorkflowTestCase
def test_iteration_with_flatten_output_enabled():
"""
Test iteration node with flatten_output=True (default behavior).
The fixture implements an iteration that:
1. Iterates over [1, 2, 3]
2. For each item, outputs [item, item*2]
3. With flatten_output=True, should output [1, 2, 2, 4, 3, 6]
"""
runner = TableTestRunner()
test_case = WorkflowTestCase(
fixture_path="iteration_flatten_output_enabled_workflow",
inputs={},
expected_outputs={"output": [1, 2, 2, 4, 3, 6]},
description="Iteration with flatten_output=True flattens nested arrays",
use_auto_mock=False, # Run code nodes directly
)
result = runner.run_test_case(test_case)
assert result.success, f"Test failed: {result.error}"
assert result.actual_outputs is not None, "Should have outputs"
assert result.actual_outputs == {"output": [1, 2, 2, 4, 3, 6]}, (
f"Expected flattened output [1, 2, 2, 4, 3, 6], got {result.actual_outputs}"
)
def test_iteration_with_flatten_output_disabled():
"""
Test iteration node with flatten_output=False.
The fixture implements an iteration that:
1. Iterates over [1, 2, 3]
2. For each item, outputs [item, item*2]
3. With flatten_output=False, should output [[1, 2], [2, 4], [3, 6]]
"""
runner = TableTestRunner()
test_case = WorkflowTestCase(
fixture_path="iteration_flatten_output_disabled_workflow",
inputs={},
expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]},
description="Iteration with flatten_output=False preserves nested structure",
use_auto_mock=False, # Run code nodes directly
)
result = runner.run_test_case(test_case)
assert result.success, f"Test failed: {result.error}"
assert result.actual_outputs is not None, "Should have outputs"
assert result.actual_outputs == {"output": [[1, 2], [2, 4], [3, 6]]}, (
f"Expected nested output [[1, 2], [2, 4], [3, 6]], got {result.actual_outputs}"
)
def test_iteration_flatten_output_comparison():
"""
Run both flatten_output configurations in parallel to verify the difference.
"""
runner = TableTestRunner()
test_cases = [
WorkflowTestCase(
fixture_path="iteration_flatten_output_enabled_workflow",
inputs={},
expected_outputs={"output": [1, 2, 2, 4, 3, 6]},
description="flatten_output=True: Flattened output",
use_auto_mock=False, # Run code nodes directly
),
WorkflowTestCase(
fixture_path="iteration_flatten_output_disabled_workflow",
inputs={},
expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]},
description="flatten_output=False: Nested output",
use_auto_mock=False, # Run code nodes directly
),
]
suite_result = runner.run_table_tests(test_cases, parallel=True)
# Assert all tests passed
assert suite_result.passed_tests == 2, f"Expected 2 passed tests, got {suite_result.passed_tests}"
assert suite_result.failed_tests == 0, f"Expected 0 failed tests, got {suite_result.failed_tests}"
assert suite_result.success_rate == 100.0, f"Expected 100% success rate, got {suite_result.success_rate}"

View File

@ -0,0 +1,32 @@
"""Tests for workflow pause related enums and constants."""
from core.workflow.enums import (
WorkflowExecutionStatus,
)
class TestWorkflowExecutionStatus:
"""Test WorkflowExecutionStatus enum."""
def test_is_ended_method(self):
"""Test is_ended method for different statuses."""
# Test ended statuses
ended_statuses = [
WorkflowExecutionStatus.SUCCEEDED,
WorkflowExecutionStatus.FAILED,
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
WorkflowExecutionStatus.STOPPED,
]
for status in ended_statuses:
assert status.is_ended(), f"{status} should be considered ended"
# Test non-ended statuses
non_ended_statuses = [
WorkflowExecutionStatus.SCHEDULED,
WorkflowExecutionStatus.RUNNING,
WorkflowExecutionStatus.PAUSED,
]
for status in non_ended_statuses:
assert not status.is_ended(), f"{status} should not be considered ended"
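# Editor's sketch (hypothetical enum; the real one lives in core.workflow.enums):
# is_ended() is just a membership check against the terminal statuses.
from enum import Enum
class _StatusSketch(Enum):
    SCHEDULED = "scheduled"
    RUNNING = "running"
    PAUSED = "paused"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    PARTIAL_SUCCEEDED = "partial-succeeded"
    STOPPED = "stopped"
    def is_ended(self) -> bool:
        return self in {
            _StatusSketch.SUCCEEDED,
            _StatusSketch.FAILED,
            _StatusSketch.PARTIAL_SUCCEEDED,
            _StatusSketch.STOPPED,
        }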

View File

@ -0,0 +1,202 @@
from typing import cast
import pytest
from core.file.models import File, FileTransferMethod, FileType
from core.workflow.system_variable import SystemVariable, SystemVariableReadOnlyView
class TestSystemVariableReadOnlyView:
"""Test cases for SystemVariableReadOnlyView class."""
def test_read_only_property_access(self):
"""Test that all properties return correct values from wrapped instance."""
# Create test data
test_file = File(
id="file-123",
tenant_id="tenant-123",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related-123",
)
datasource_info = {"key": "value", "nested": {"data": 42}}
# Create SystemVariable with all fields
system_var = SystemVariable(
user_id="user-123",
app_id="app-123",
workflow_id="workflow-123",
files=[test_file],
workflow_execution_id="exec-123",
query="test query",
conversation_id="conv-123",
dialogue_count=5,
document_id="doc-123",
original_document_id="orig-doc-123",
dataset_id="dataset-123",
batch="batch-123",
datasource_type="type-123",
datasource_info=datasource_info,
invoke_from="invoke-123",
)
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Test all properties
assert read_only_view.user_id == "user-123"
assert read_only_view.app_id == "app-123"
assert read_only_view.workflow_id == "workflow-123"
assert read_only_view.workflow_execution_id == "exec-123"
assert read_only_view.query == "test query"
assert read_only_view.conversation_id == "conv-123"
assert read_only_view.dialogue_count == 5
assert read_only_view.document_id == "doc-123"
assert read_only_view.original_document_id == "orig-doc-123"
assert read_only_view.dataset_id == "dataset-123"
assert read_only_view.batch == "batch-123"
assert read_only_view.datasource_type == "type-123"
assert read_only_view.invoke_from == "invoke-123"
def test_defensive_copying_of_mutable_objects(self):
"""Test that mutable objects are defensively copied."""
# Create test data
test_file = File(
id="file-123",
tenant_id="tenant-123",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related-123",
)
datasource_info = {"key": "original_value"}
# Create SystemVariable
system_var = SystemVariable(
files=[test_file], datasource_info=datasource_info, workflow_execution_id="exec-123"
)
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Test files defensive copying
files_copy = read_only_view.files
assert isinstance(files_copy, tuple) # Should be immutable tuple
assert len(files_copy) == 1
assert files_copy[0].id == "file-123"
# Verify the view hands back an immutable snapshot (the original cannot be modified through it)
assert isinstance(files_copy, tuple)
# tuples have no append method, so the snapshot cannot be mutated in place
# Test datasource_info defensive copying
datasource_copy = read_only_view.datasource_info
assert datasource_copy is not None
assert datasource_copy["key"] == "original_value"
datasource_copy = cast(dict, datasource_copy)
with pytest.raises(TypeError):
datasource_copy["key"] = "modified value"
# Verify original is unchanged
assert system_var.datasource_info is not None
assert system_var.datasource_info["key"] == "original_value"
assert read_only_view.datasource_info is not None
assert read_only_view.datasource_info["key"] == "original_value"
def test_always_accesses_latest_data(self):
"""Test that properties always return the latest data from wrapped instance."""
# Create SystemVariable
system_var = SystemVariable(user_id="original-user", workflow_execution_id="exec-123")
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Verify initial value
assert read_only_view.user_id == "original-user"
# Modify the wrapped instance
system_var.user_id = "modified-user"
# Verify view returns the new value
assert read_only_view.user_id == "modified-user"
def test_repr_method(self):
"""Test the __repr__ method."""
# Create SystemVariable
system_var = SystemVariable(workflow_execution_id="exec-123")
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Test repr
repr_str = repr(read_only_view)
assert "SystemVariableReadOnlyView" in repr_str
assert "system_variable=" in repr_str
def test_none_value_handling(self):
"""Test that None values are properly handled."""
# Create SystemVariable with all None values except workflow_execution_id
system_var = SystemVariable(
user_id=None,
app_id=None,
workflow_id=None,
workflow_execution_id="exec-123",
query=None,
conversation_id=None,
dialogue_count=None,
document_id=None,
original_document_id=None,
dataset_id=None,
batch=None,
datasource_type=None,
datasource_info=None,
invoke_from=None,
)
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Test all None values
assert read_only_view.user_id is None
assert read_only_view.app_id is None
assert read_only_view.workflow_id is None
assert read_only_view.query is None
assert read_only_view.conversation_id is None
assert read_only_view.dialogue_count is None
assert read_only_view.document_id is None
assert read_only_view.original_document_id is None
assert read_only_view.dataset_id is None
assert read_only_view.batch is None
assert read_only_view.datasource_type is None
assert read_only_view.datasource_info is None
assert read_only_view.invoke_from is None
# files should be an empty tuple even when the default files list is empty
assert read_only_view.files == ()
def test_empty_files_handling(self):
"""Test that empty files list is handled correctly."""
# Create SystemVariable with empty files
system_var = SystemVariable(files=[], workflow_execution_id="exec-123")
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Test files handling
assert read_only_view.files == ()
assert isinstance(read_only_view.files, tuple)
def test_empty_datasource_info_handling(self):
"""Test that empty datasource_info is handled correctly."""
# Create SystemVariable with empty datasource_info
system_var = SystemVariable(datasource_info={}, workflow_execution_id="exec-123")
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Test datasource_info handling
assert read_only_view.datasource_info == {}
# Should be a copy, not the same object
assert read_only_view.datasource_info is not system_var.datasource_info
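# Editor's sketch (assumed shape) of the read-only view pattern under test:
# properties delegate to the wrapped instance on every access and hand back
# defensive copies of the mutable fields. MappingProxyType raises TypeError
# on item assignment, which matches the defensive-copy test above.
from types import MappingProxyType
class _ReadOnlyViewSketch:
    def __init__(self, wrapped) -> None:
        self._wrapped = wrapped
    @property
    def files(self) -> tuple:
        return tuple(self._wrapped.files)  # immutable snapshot
    @property
    def datasource_info(self):
        info = self._wrapped.datasource_info
        return None if info is None else MappingProxyType(dict(info))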

View File

@ -1,8 +1,10 @@
import datetime
from unittest.mock import patch
import pytest
import pytz
from libs.datetime_utils import naive_utc_now
from libs.datetime_utils import naive_utc_now, parse_time_range
def test_naive_utc_now(monkeypatch: pytest.MonkeyPatch):
@ -20,3 +22,247 @@ def test_naive_utc_now(monkeypatch: pytest.MonkeyPatch):
naive_time = naive_datetime.time()
utc_time = tz_aware_utc_now.time()
assert naive_time == utc_time
class TestParseTimeRange:
"""Test cases for parse_time_range function."""
def test_parse_time_range_basic(self):
"""Test basic time range parsing."""
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "UTC")
assert start is not None
assert end is not None
assert start < end
assert start.tzinfo == pytz.UTC
assert end.tzinfo == pytz.UTC
def test_parse_time_range_start_only(self):
"""Test parsing with only start time."""
start, end = parse_time_range("2024-01-01 10:00", None, "UTC")
assert start is not None
assert end is None
assert start.tzinfo == pytz.UTC
def test_parse_time_range_end_only(self):
"""Test parsing with only end time."""
start, end = parse_time_range(None, "2024-01-01 18:00", "UTC")
assert start is None
assert end is not None
assert end.tzinfo == pytz.UTC
def test_parse_time_range_both_none(self):
"""Test parsing with both times None."""
start, end = parse_time_range(None, None, "UTC")
assert start is None
assert end is None
def test_parse_time_range_different_timezones(self):
"""Test parsing with different timezones."""
# Test with US/Eastern timezone
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "US/Eastern")
assert start is not None
assert end is not None
assert start.tzinfo == pytz.UTC
assert end.tzinfo == pytz.UTC
# Verify the times are correctly converted to UTC
assert start.hour == 15 # 10 AM EST = 3 PM UTC (in January)
assert end.hour == 23 # 6 PM EST = 11 PM UTC (in January)
def test_parse_time_range_invalid_start_format(self):
"""Test parsing with invalid start time format."""
with pytest.raises(ValueError, match="time data.*does not match format"):
parse_time_range("invalid-date", "2024-01-01 18:00", "UTC")
def test_parse_time_range_invalid_end_format(self):
"""Test parsing with invalid end time format."""
with pytest.raises(ValueError, match="time data.*does not match format"):
parse_time_range("2024-01-01 10:00", "invalid-date", "UTC")
def test_parse_time_range_invalid_timezone(self):
"""Test parsing with invalid timezone."""
with pytest.raises(pytz.exceptions.UnknownTimeZoneError):
parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "Invalid/Timezone")
def test_parse_time_range_start_after_end(self):
"""Test parsing with start time after end time."""
with pytest.raises(ValueError, match="start must be earlier than or equal to end"):
parse_time_range("2024-01-01 18:00", "2024-01-01 10:00", "UTC")
def test_parse_time_range_start_equals_end(self):
"""Test parsing with start time equal to end time."""
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 10:00", "UTC")
assert start is not None
assert end is not None
assert start == end
def test_parse_time_range_dst_ambiguous_time(self):
"""Test parsing during DST ambiguous time (fall back)."""
# This test simulates DST fall back where 2:30 AM occurs twice
with patch("pytz.timezone") as mock_timezone:
# Mock timezone that raises AmbiguousTimeError
mock_tz = mock_timezone.return_value
# Create a mock datetime object for the return value
mock_dt = datetime.datetime(2024, 1, 1, 10, 0, 0)
mock_utc_dt = mock_dt.replace(tzinfo=pytz.UTC)
# Create a proper mock for the localized datetime
from unittest.mock import MagicMock
mock_localized_dt = MagicMock()
mock_localized_dt.astimezone.return_value = mock_utc_dt
# Set up side effects: first call raises exception, second call succeeds
mock_tz.localize.side_effect = [
pytz.AmbiguousTimeError("Ambiguous time"), # First call for start
mock_localized_dt, # Second call for start (with is_dst=False)
pytz.AmbiguousTimeError("Ambiguous time"), # First call for end
mock_localized_dt, # Second call for end (with is_dst=False)
]
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "US/Eastern")
# Should use is_dst=False for ambiguous times
assert mock_tz.localize.call_count == 4 # 2 calls per time (first fails, second succeeds)
assert start is not None
assert end is not None
def test_parse_time_range_dst_nonexistent_time(self):
"""Test parsing during DST nonexistent time (spring forward)."""
with patch("pytz.timezone") as mock_timezone:
# Mock timezone that raises NonExistentTimeError
mock_tz = mock_timezone.return_value
# Create a mock datetime object for the return value
mock_dt = datetime.datetime(2024, 1, 1, 10, 0, 0)
mock_utc_dt = mock_dt.replace(tzinfo=pytz.UTC)
# Create a proper mock for the localized datetime
from unittest.mock import MagicMock
mock_localized_dt = MagicMock()
mock_localized_dt.astimezone.return_value = mock_utc_dt
# Set up side effects: first call raises exception, second call succeeds
mock_tz.localize.side_effect = [
pytz.NonExistentTimeError("Non-existent time"), # First call for start
mock_localized_dt, # Second call for start (with adjusted time)
pytz.NonExistentTimeError("Non-existent time"), # First call for end
mock_localized_dt, # Second call for end (with adjusted time)
]
start, end = parse_time_range("2024-01-01 10:00", "2024-01-01 18:00", "US/Eastern")
# Should adjust time forward by 1 hour for nonexistent times
assert mock_tz.localize.call_count == 4 # 2 calls per time (first fails, second succeeds)
assert start is not None
assert end is not None
def test_parse_time_range_edge_cases(self):
"""Test edge cases for time parsing."""
# Test with midnight times
start, end = parse_time_range("2024-01-01 00:00", "2024-01-01 23:59", "UTC")
assert start is not None
assert end is not None
assert start.hour == 0
assert start.minute == 0
assert end.hour == 23
assert end.minute == 59
def test_parse_time_range_different_dates(self):
"""Test parsing with different dates."""
start, end = parse_time_range("2024-01-01 10:00", "2024-01-02 10:00", "UTC")
assert start is not None
assert end is not None
assert start.date() != end.date()
assert (end - start).days == 1
def test_parse_time_range_seconds_handling(self):
"""Test that seconds are properly set to 0."""
start, end = parse_time_range("2024-01-01 10:30", "2024-01-01 18:45", "UTC")
assert start is not None
assert end is not None
assert start.second == 0
assert end.second == 0
def test_parse_time_range_timezone_conversion_accuracy(self):
"""Test accurate timezone conversion."""
# Test with a known timezone conversion
start, end = parse_time_range("2024-01-01 12:00", "2024-01-01 12:00", "Asia/Tokyo")
assert start is not None
assert end is not None
assert start.tzinfo == pytz.UTC
assert end.tzinfo == pytz.UTC
# Tokyo is UTC+9, so 12:00 JST = 03:00 UTC
assert start.hour == 3
assert end.hour == 3
def test_parse_time_range_summer_time(self):
"""Test parsing during summer time (DST)."""
# Test with US/Eastern during summer (EDT = UTC-4)
start, end = parse_time_range("2024-07-01 12:00", "2024-07-01 12:00", "US/Eastern")
assert start is not None
assert end is not None
assert start.tzinfo == pytz.UTC
assert end.tzinfo == pytz.UTC
# 12:00 EDT = 16:00 UTC
assert start.hour == 16
assert end.hour == 16
def test_parse_time_range_winter_time(self):
"""Test parsing during winter time (standard time)."""
# Test with US/Eastern during winter (EST = UTC-5)
start, end = parse_time_range("2024-01-01 12:00", "2024-01-01 12:00", "US/Eastern")
assert start is not None
assert end is not None
assert start.tzinfo == pytz.UTC
assert end.tzinfo == pytz.UTC
# 12:00 EST = 17:00 UTC
assert start.hour == 17
assert end.hour == 17
def test_parse_time_range_empty_strings(self):
"""Test parsing with empty strings."""
# Empty strings are treated as None, so they should not raise errors
start, end = parse_time_range("", "2024-01-01 18:00", "UTC")
assert start is None
assert end is not None
start, end = parse_time_range("2024-01-01 10:00", "", "UTC")
assert start is not None
assert end is None
def test_parse_time_range_malformed_datetime(self):
"""Test parsing with malformed datetime strings."""
with pytest.raises(ValueError, match="time data.*does not match format"):
parse_time_range("2024-13-01 10:00", "2024-01-01 18:00", "UTC")
with pytest.raises(ValueError, match="time data.*does not match format"):
parse_time_range("2024-01-01 10:00", "2024-01-32 18:00", "UTC")
def test_parse_time_range_very_long_time_range(self):
"""Test parsing with very long time range."""
start, end = parse_time_range("2020-01-01 00:00", "2030-12-31 23:59", "UTC")
assert start is not None
assert end is not None
assert start < end
assert (end - start).days > 3000 # More than 8 years
def test_parse_time_range_negative_timezone(self):
"""Test parsing with negative timezone offset."""
start, end = parse_time_range("2024-01-01 12:00", "2024-01-01 12:00", "America/New_York")
assert start is not None
assert end is not None
assert start.tzinfo == pytz.UTC
assert end.tzinfo == pytz.UTC
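# Editor's sketch of the DST fallback the mocked tests above assume (inferred
# from the side effects, not the actual parse_time_range internals): localize
# strictly, resolve ambiguous wall times with is_dst=False, and push
# nonexistent times forward one hour before retrying.
import datetime as _dt
import pytz as _pytz
def _localize_sketch(tz: _pytz.BaseTzInfo, naive: _dt.datetime) -> _dt.datetime:
    try:
        return tz.localize(naive, is_dst=None)  # raises on ambiguous/nonexistent
    except _pytz.AmbiguousTimeError:
        return tz.localize(naive, is_dst=False)  # fall-back hour
    except _pytz.NonExistentTimeError:
        return tz.localize(naive + _dt.timedelta(hours=1), is_dst=False)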

View File

@ -1,5 +1,10 @@
from unittest.mock import MagicMock
from werkzeug.wrappers import Response
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_WEBAPP_ACCESS_TOKEN
from libs.token import extract_access_token, extract_webapp_access_token
from libs import token
from libs.token import extract_access_token, extract_webapp_access_token, set_csrf_token_to_cookie
class MockRequest:
@ -23,3 +28,35 @@ def test_extract_access_token():
for request, expected_console, expected_webapp in test_cases:
assert extract_access_token(request) == expected_console # pyright: ignore[reportArgumentType]
assert extract_webapp_access_token(request) == expected_webapp # pyright: ignore[reportArgumentType]
def test_real_cookie_name_uses_host_prefix_without_domain(monkeypatch):
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", "", raising=False)
assert token._real_cookie_name("csrf_token") == "__Host-csrf_token"
def test_real_cookie_name_without_host_prefix_when_domain_present(monkeypatch):
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", ".example.com", raising=False)
assert token._real_cookie_name("csrf_token") == "csrf_token"
def test_set_csrf_cookie_includes_domain_when_configured(monkeypatch):
monkeypatch.setattr(token.dify_config, "CONSOLE_WEB_URL", "https://console.example.com", raising=False)
monkeypatch.setattr(token.dify_config, "CONSOLE_API_URL", "https://api.example.com", raising=False)
monkeypatch.setattr(token.dify_config, "COOKIE_DOMAIN", ".example.com", raising=False)
response = Response()
request = MagicMock()
set_csrf_token_to_cookie(request, response, "abc123")
cookies = response.headers.getlist("Set-Cookie")
assert any("csrf_token=abc123" in c for c in cookies)
assert any("Domain=example.com" in c for c in cookies)
assert all("__Host-" not in c for c in cookies)

View File

@ -0,0 +1,11 @@
from models.base import DefaultFieldsMixin
class FooModel(DefaultFieldsMixin):
def __init__(self, id: str):
self.id = id
def test_repr():
foo_model = FooModel(id="test-id")
assert repr(foo_model) == "<FooModel(id=test-id)>"
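# Editor's sketch of the __repr__ this test pins down (hypothetical mixin
# body; type(self).__name__ yields the subclass name, e.g. FooModel):
class _ReprMixinSketch:
    id: str
    def __repr__(self) -> str:
        return f"<{type(self).__name__}(id={self.id})>"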

View File

@ -0,0 +1,370 @@
"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""
from datetime import UTC, datetime
from unittest.mock import Mock, patch
import pytest
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.enums import WorkflowExecutionStatus
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_PrivateWorkflowPauseEntity,
_WorkflowRunError,
)
class TestDifyAPISQLAlchemyWorkflowRunRepository:
"""Test DifyAPISQLAlchemyWorkflowRunRepository implementation."""
@pytest.fixture
def mock_session(self):
"""Create a mock session."""
return Mock(spec=Session)
@pytest.fixture
def mock_session_maker(self, mock_session):
"""Create a mock sessionmaker."""
session_maker = Mock(spec=sessionmaker)
# Create a context manager mock
context_manager = Mock()
context_manager.__enter__ = Mock(return_value=mock_session)
context_manager.__exit__ = Mock(return_value=None)
session_maker.return_value = context_manager
# Mock session.begin() context manager
begin_context_manager = Mock()
begin_context_manager.__enter__ = Mock(return_value=None)
begin_context_manager.__exit__ = Mock(return_value=None)
mock_session.begin = Mock(return_value=begin_context_manager)
# Add missing session methods
mock_session.commit = Mock()
mock_session.rollback = Mock()
mock_session.add = Mock()
mock_session.delete = Mock()
mock_session.get = Mock()
mock_session.scalar = Mock()
mock_session.scalars = Mock()
# Also support expire_on_commit parameter
def make_session(expire_on_commit=None):
cm = Mock()
cm.__enter__ = Mock(return_value=mock_session)
cm.__exit__ = Mock(return_value=None)
return cm
session_maker.side_effect = make_session
return session_maker
@pytest.fixture
def repository(self, mock_session_maker):
"""Create repository instance with mocked dependencies."""
# Create a testable subclass that implements the save method
class TestableDifyAPISQLAlchemyWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
def __init__(self, session_maker):
# Skip the parent __init__ to avoid its instantiation side effects in tests
self._session_maker = session_maker
def save(self, execution):
"""Mock implementation of save method."""
return None
# Create repository instance
repo = TestableDifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
return repo
@pytest.fixture
def sample_workflow_run(self):
"""Create a sample WorkflowRun model."""
workflow_run = Mock(spec=WorkflowRun)
workflow_run.id = "workflow-run-123"
workflow_run.tenant_id = "tenant-123"
workflow_run.app_id = "app-123"
workflow_run.workflow_id = "workflow-123"
workflow_run.status = WorkflowExecutionStatus.RUNNING
return workflow_run
@pytest.fixture
def sample_workflow_pause(self):
"""Create a sample WorkflowPauseModel."""
pause = Mock(spec=WorkflowPauseModel)
pause.id = "pause-123"
pause.workflow_id = "workflow-123"
pause.workflow_run_id = "workflow-run-123"
pause.state_object_key = "workflow-state-123.json"
pause.resumed_at = None
pause.created_at = datetime.now(UTC)
return pause
class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test create_workflow_pause method."""
def test_create_workflow_pause_success(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_run: Mock,
):
"""Test successful workflow pause creation."""
# Arrange
workflow_run_id = "workflow-run-123"
state_owner_user_id = "user-123"
state = '{"test": "state"}'
mock_session.get.return_value = sample_workflow_run
with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7:
mock_uuidv7.side_effect = ["pause-123"]
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
# Act
result = repository.create_workflow_pause(
workflow_run_id=workflow_run_id,
state_owner_user_id=state_owner_user_id,
state=state,
)
# Assert
assert isinstance(result, _PrivateWorkflowPauseEntity)
assert result.id == "pause-123"
assert result.workflow_execution_id == workflow_run_id
# Verify database interactions
mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
mock_storage.save.assert_called_once()
mock_session.add.assert_called()
# When using the session.begin() context manager, the commit is handled automatically,
# so no explicit commit call is expected
def test_create_workflow_pause_not_found(
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
):
"""Test workflow pause creation when workflow run not found."""
# Arrange
mock_session.get.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="WorkflowRun not found: workflow-run-123"):
repository.create_workflow_pause(
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",
state='{"test": "state"}',
)
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
def test_create_workflow_pause_invalid_status(
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock, sample_workflow_run: Mock
):
"""Test workflow pause creation when workflow not in RUNNING status."""
# Arrange
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
mock_session.get.return_value = sample_workflow_run
# Act & Assert
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
repository.create_workflow_pause(
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",
state='{"test": "state"}',
)
class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test resume_workflow_pause method."""
def test_resume_workflow_pause_success(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_run: Mock,
sample_workflow_pause: Mock,
):
"""Test successful workflow pause resume."""
# Arrange
workflow_run_id = "workflow-run-123"
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-123"
# Setup workflow run and pause
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
sample_workflow_run.pause = sample_workflow_pause
sample_workflow_pause.resumed_at = None
mock_session.scalar.return_value = sample_workflow_run
with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
mock_now.return_value = datetime.now(UTC)
# Act
result = repository.resume_workflow_pause(
workflow_run_id=workflow_run_id,
pause_entity=pause_entity,
)
# Assert
assert isinstance(result, _PrivateWorkflowPauseEntity)
assert result.id == "pause-123"
# Verify state transitions
assert sample_workflow_pause.resumed_at is not None
assert sample_workflow_run.status == WorkflowExecutionStatus.RUNNING
# Verify database interactions
mock_session.add.assert_called()
# When using the session.begin() context manager, the commit is handled automatically,
# so no explicit commit call is expected
def test_resume_workflow_pause_not_paused(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_run: Mock,
):
"""Test resume when workflow is not paused."""
# Arrange
workflow_run_id = "workflow-run-123"
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-123"
sample_workflow_run.status = WorkflowExecutionStatus.RUNNING
mock_session.scalar.return_value = sample_workflow_run
# Act & Assert
with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
repository.resume_workflow_pause(
workflow_run_id=workflow_run_id,
pause_entity=pause_entity,
)
def test_resume_workflow_pause_id_mismatch(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_run: Mock,
sample_workflow_pause: Mock,
):
"""Test resume when pause ID doesn't match."""
# Arrange
workflow_run_id = "workflow-run-123"
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-456" # Different ID
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
sample_workflow_pause.id = "pause-123"
sample_workflow_run.pause = sample_workflow_pause
mock_session.scalar.return_value = sample_workflow_run
# Act & Assert
with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
repository.resume_workflow_pause(
workflow_run_id=workflow_run_id,
pause_entity=pause_entity,
)
class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test delete_workflow_pause method."""
def test_delete_workflow_pause_success(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_pause: Mock,
):
"""Test successful workflow pause deletion."""
# Arrange
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-123"
mock_session.get.return_value = sample_workflow_pause
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
# Act
repository.delete_workflow_pause(pause_entity=pause_entity)
# Assert
mock_storage.delete.assert_called_once_with(sample_workflow_pause.state_object_key)
mock_session.delete.assert_called_once_with(sample_workflow_pause)
# When using the session.begin() context manager, the commit is handled automatically,
# so no explicit commit call is expected
def test_delete_workflow_pause_not_found(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
):
"""Test delete when pause not found."""
# Arrange
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-123"
mock_session.get.return_value = None
# Act & Assert
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found: pause-123"):
repository.delete_workflow_pause(pause_entity=pause_entity)
class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test _PrivateWorkflowPauseEntity class."""
def test_from_models(self, sample_workflow_pause: Mock):
"""Test creating _PrivateWorkflowPauseEntity from models."""
# Act
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
# Assert
assert isinstance(entity, _PrivateWorkflowPauseEntity)
assert entity._pause_model == sample_workflow_pause
def test_properties(self, sample_workflow_pause: Mock):
"""Test entity properties."""
# Arrange
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
# Act & Assert
assert entity.id == sample_workflow_pause.id
assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
assert entity.resumed_at == sample_workflow_pause.resumed_at
def test_get_state(self, sample_workflow_pause: Mock):
"""Test getting state from storage."""
# Arrange
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
mock_storage.load.return_value = expected_state
# Act
result = entity.get_state()
# Assert
assert result == expected_state
mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
def test_get_state_caching(self, sample_workflow_pause: Mock):
"""Test state caching in get_state method."""
# Arrange
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
mock_storage.load.return_value = expected_state
# Act
result1 = entity.get_state()
result2 = entity.get_state() # Should use cache
# Assert
assert result1 == expected_state
assert result2 == expected_state
mock_storage.load.assert_called_once() # Only called once due to caching
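# Editor's sketch of the transaction pattern the comments above refer to:
# sessionmaker plus session.begin() commits on a clean exit and rolls back on
# error, which is why these tests expect no explicit commit call.
from sqlalchemy.orm import sessionmaker as _sessionmaker
def _save_in_tx_sketch(engine, obj) -> None:
    maker = _sessionmaker(bind=engine, expire_on_commit=False)
    with maker() as session, session.begin():  # commit happens on clean exit
        session.add(obj)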

View File

@ -21,6 +21,7 @@ from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.variables.segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArraySegment,
FileSegment,
FloatSegment,
@ -30,6 +31,7 @@ from core.variables.segments import (
StringSegment,
)
from services.variable_truncator import (
DummyVariableTruncator,
MaxDepthExceededError,
TruncationResult,
UnknownTypeError,
@ -596,3 +598,32 @@ class TestIntegrationScenarios:
truncated_mapping, truncated = truncator.truncate_variable_mapping(mapping)
assert truncated is False
assert truncated_mapping == mapping
def test_dummy_variable_truncator_methods():
"""Test DummyVariableTruncator methods work correctly."""
truncator = DummyVariableTruncator()
# Test truncate_variable_mapping
test_data: dict[str, Any] = {
"key1": "value1",
"key2": ["item1", "item2"],
"large_array": list(range(2000)),
}
result, is_truncated = truncator.truncate_variable_mapping(test_data)
assert result == test_data
assert not is_truncated
# Test truncate method
segment = StringSegment(value="test string")
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.result == segment
assert result.truncated is False
segment = ArrayNumberSegment(value=list(range(2000)))
result = truncator.truncate(segment)
assert isinstance(result, TruncationResult)
assert result.result == segment
assert result.truncated is False
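# Editor's sketch of a pass-through truncator matching the behavior asserted
# above (names are illustrative; the keyword constructor of TruncationResult
# is an assumption):
class _NoOpTruncatorSketch:
    def truncate_variable_mapping(self, mapping: dict) -> tuple[dict, bool]:
        return mapping, False  # unchanged, never truncated
    def truncate(self, segment) -> TruncationResult:
        return TruncationResult(result=segment, truncated=False)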

View File

@ -0,0 +1,200 @@
"""Comprehensive unit tests for WorkflowRunService class.
This test suite covers all pause state management operations including:
- Retrieving pause state for workflow runs
- Saving pause state with file uploads
- Marking paused workflows as resumed
- Error handling and edge cases
- Database transaction management
- Repository-based approach testing
"""
from datetime import datetime
from unittest.mock import MagicMock, create_autospec, patch
import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.enums import WorkflowExecutionStatus
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
from services.workflow_run_service import (
WorkflowRunService,
)
class TestDataFactory:
"""Factory class for creating test data objects."""
@staticmethod
def create_workflow_run_mock(
id: str = "workflow-run-123",
tenant_id: str = "tenant-456",
app_id: str = "app-789",
workflow_id: str = "workflow-101",
status: str | WorkflowExecutionStatus = "paused",
pause_id: str | None = None,
**kwargs,
) -> MagicMock:
"""Create a mock WorkflowRun object."""
mock_run = MagicMock()
mock_run.id = id
mock_run.tenant_id = tenant_id
mock_run.app_id = app_id
mock_run.workflow_id = workflow_id
mock_run.status = status
mock_run.pause_id = pause_id
for key, value in kwargs.items():
setattr(mock_run, key, value)
return mock_run
@staticmethod
def create_workflow_pause_mock(
id: str = "pause-123",
tenant_id: str = "tenant-456",
app_id: str = "app-789",
workflow_id: str = "workflow-101",
workflow_execution_id: str = "workflow-execution-123",
state_file_id: str = "file-456",
resumed_at: datetime | None = None,
**kwargs,
) -> MagicMock:
"""Create a mock WorkflowPauseModel object."""
mock_pause = MagicMock()
mock_pause.id = id
mock_pause.tenant_id = tenant_id
mock_pause.app_id = app_id
mock_pause.workflow_id = workflow_id
mock_pause.workflow_execution_id = workflow_execution_id
mock_pause.state_file_id = state_file_id
mock_pause.resumed_at = resumed_at
for key, value in kwargs.items():
setattr(mock_pause, key, value)
return mock_pause
@staticmethod
def create_upload_file_mock(
id: str = "file-456",
key: str = "upload_files/test/state.json",
name: str = "state.json",
tenant_id: str = "tenant-456",
**kwargs,
) -> MagicMock:
"""Create a mock UploadFile object."""
mock_file = MagicMock()
mock_file.id = id
mock_file.key = key
mock_file.name = name
mock_file.tenant_id = tenant_id
for key, value in kwargs.items():
setattr(mock_file, key, value)
return mock_file
@staticmethod
def create_pause_entity_mock(
pause_model: MagicMock | None = None,
upload_file: MagicMock | None = None,
) -> _PrivateWorkflowPauseEntity:
"""Create a mock _PrivateWorkflowPauseEntity object."""
if pause_model is None:
pause_model = TestDataFactory.create_workflow_pause_mock()
if upload_file is None:
upload_file = TestDataFactory.create_upload_file_mock()
return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file)
class TestWorkflowRunService:
"""Comprehensive unit tests for WorkflowRunService class."""
@pytest.fixture
def mock_session_factory(self):
"""Create a mock session factory with proper session management."""
mock_session = create_autospec(Session)
# Create a mock context manager for the session
mock_session_cm = MagicMock()
mock_session_cm.__enter__ = MagicMock(return_value=mock_session)
mock_session_cm.__exit__ = MagicMock(return_value=None)
# Create a mock context manager for the transaction
mock_transaction_cm = MagicMock()
mock_transaction_cm.__enter__ = MagicMock(return_value=mock_session)
mock_transaction_cm.__exit__ = MagicMock(return_value=None)
mock_session.begin = MagicMock(return_value=mock_transaction_cm)
# Create mock factory that returns the context manager
mock_factory = MagicMock(spec=sessionmaker)
mock_factory.return_value = mock_session_cm
return mock_factory, mock_session
@pytest.fixture
def mock_workflow_run_repository(self):
"""Create a mock APIWorkflowRunRepository."""
mock_repo = create_autospec(APIWorkflowRunRepository)
return mock_repo
@pytest.fixture
def workflow_run_service(self, mock_session_factory, mock_workflow_run_repository):
"""Create WorkflowRunService instance with mocked dependencies."""
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(session_factory)
return service
@pytest.fixture
def workflow_run_service_with_engine(self, mock_session_factory, mock_workflow_run_repository):
"""Create WorkflowRunService instance with Engine input."""
mock_engine = create_autospec(Engine)
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(mock_engine)
return service
# ==================== Initialization Tests ====================
def test_init_with_session_factory(self, mock_session_factory, mock_workflow_run_repository):
"""Test WorkflowRunService initialization with session_factory."""
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(session_factory)
assert service._session_factory == session_factory
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
def test_init_with_engine(self, mock_session_factory, mock_workflow_run_repository):
"""Test WorkflowRunService initialization with Engine (should convert to sessionmaker)."""
mock_engine = create_autospec(Engine)
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
service = WorkflowRunService(mock_engine)
mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
assert service._session_factory == session_factory
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
def test_init_with_default_dependencies(self, mock_session_factory):
"""Test WorkflowRunService initialization with default dependencies."""
session_factory, _ = mock_session_factory
service = WorkflowRunService(session_factory)
assert service._session_factory == session_factory