refactor: migrate workflow run repository unit tests from mocks to te… (#33843)

This commit is contained in:
Desel72
2026-03-21 05:54:56 -05:00
committed by GitHub
parent 097773c9f5
commit 2ce2fbc2d4
2 changed files with 223 additions and 136 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import secrets
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from unittest.mock import Mock
@ -12,15 +13,26 @@ from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session, sessionmaker
from dify_graph.entities import WorkflowExecution
from dify_graph.entities.pause_reason import PauseReasonType
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction
from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.human_input import (
BackstageRecipientPayload,
HumanInputDelivery,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
)
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_build_human_input_required_reason,
_PrivateWorkflowPauseEntity,
_WorkflowRunError,
)
@ -90,6 +102,19 @@ def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
WorkflowRun.app_id == scope.app_id,
)
)
form_ids_subquery = select(HumanInputForm.id).where(
HumanInputForm.tenant_id == scope.tenant_id,
HumanInputForm.app_id == scope.app_id,
)
session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery)))
session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery)))
session.execute(
delete(HumanInputForm).where(
HumanInputForm.tenant_id == scope.tenant_id,
HumanInputForm.app_id == scope.app_id,
)
)
session.commit()
for state_key in scope.state_keys:
@ -504,3 +529,200 @@ class TestDeleteWorkflowPause:
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"):
repository.delete_workflow_pause(pause_entity=pause_entity)
class TestPrivateWorkflowPauseEntity:
"""Integration tests for _PrivateWorkflowPauseEntity using real DB models."""
def test_properties(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Entity properties delegate to the persisted WorkflowPause model."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
pause = WorkflowPause(
id=str(uuid4()),
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=f"workflow-state-{uuid4()}.json",
)
db_session_with_containers.add(pause)
db_session_with_containers.commit()
db_session_with_containers.refresh(pause)
test_scope.state_keys.add(pause.state_object_key)
entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
assert entity.id == pause.id
assert entity.workflow_execution_id == workflow_run.id
assert entity.resumed_at is None
def test_get_state(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""get_state loads state data from storage using the persisted state_object_key."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
state_key = f"workflow-state-{uuid4()}.json"
pause = WorkflowPause(
id=str(uuid4()),
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=state_key,
)
db_session_with_containers.add(pause)
db_session_with_containers.commit()
db_session_with_containers.refresh(pause)
test_scope.state_keys.add(state_key)
expected_state = b'{"test": "state"}'
storage.save(state_key, expected_state)
entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
result = entity.get_state()
assert result == expected_state
def test_get_state_caching(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""get_state caches the result so storage is only accessed once."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
state_key = f"workflow-state-{uuid4()}.json"
pause = WorkflowPause(
id=str(uuid4()),
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=state_key,
)
db_session_with_containers.add(pause)
db_session_with_containers.commit()
db_session_with_containers.refresh(pause)
test_scope.state_keys.add(state_key)
expected_state = b'{"test": "state"}'
storage.save(state_key, expected_state)
entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
result1 = entity.get_state()
# Delete from storage to prove second call uses cache
storage.delete(state_key)
test_scope.state_keys.discard(state_key)
result2 = entity.get_state()
assert result1 == expected_state
assert result2 == expected_state
class TestBuildHumanInputRequiredReason:
"""Integration tests for _build_human_input_required_reason using real DB models."""
def test_prefers_backstage_token_when_available(
self,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Use backstage token when multiple recipient types may exist."""
expiration_time = naive_utc_now()
form_definition = FormDefinition(
form_content="content",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
user_actions=[UserAction(id="approve", title="Approve")],
rendered_content="rendered",
expiration_time=expiration_time,
default_values={"name": "Alice"},
node_title="Ask Name",
display_in_ui=True,
)
form_model = HumanInputForm(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
workflow_run_id=str(uuid4()),
node_id="node-1",
form_definition=form_definition.model_dump_json(),
rendered_content="rendered",
status=HumanInputFormStatus.WAITING,
expiration_time=expiration_time,
)
db_session_with_containers.add(form_model)
db_session_with_containers.flush()
delivery = HumanInputDelivery(
form_id=form_model.id,
delivery_method_type=DeliveryMethodType.WEBAPP,
channel_payload="{}",
)
db_session_with_containers.add(delivery)
db_session_with_containers.flush()
access_token = secrets.token_urlsafe(8)
recipient = HumanInputFormRecipient(
form_id=form_model.id,
delivery_id=delivery.id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload().model_dump_json(),
access_token=access_token,
)
db_session_with_containers.add(recipient)
db_session_with_containers.flush()
# Create a pause so the reason has a valid pause_id
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
pause = WorkflowPause(
id=str(uuid4()),
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=f"workflow-state-{uuid4()}.json",
)
db_session_with_containers.add(pause)
db_session_with_containers.flush()
test_scope.state_keys.add(pause.state_object_key)
reason_model = WorkflowPauseReason(
pause_id=pause.id,
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
form_id=form_model.id,
node_id="node-1",
message="",
)
db_session_with_containers.add(reason_model)
db_session_with_containers.commit()
# Refresh to ensure we have DB-round-tripped objects
db_session_with_containers.refresh(form_model)
db_session_with_containers.refresh(reason_model)
db_session_with_containers.refresh(recipient)
reason = _build_human_input_required_reason(reason_model, form_model, [recipient])
assert isinstance(reason, HumanInputRequired)
assert reason.form_token == access_token
assert reason.node_title == "Ask Name"
assert reason.form_content == "content"
assert reason.inputs[0].output_variable_name == "name"
assert reason.actions[0].id == "approve"

View File

@ -1,135 +0,0 @@
"""Unit tests for non-SQL helper logic in workflow run repository."""
import secrets
from datetime import UTC, datetime
from unittest.mock import Mock, patch
import pytest
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType
from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction
from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormStatus
from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowPauseReason
from repositories.sqlalchemy_api_workflow_run_repository import (
_build_human_input_required_reason,
_PrivateWorkflowPauseEntity,
)
@pytest.fixture
def sample_workflow_pause() -> Mock:
"""Create a sample WorkflowPause model."""
pause = Mock(spec=WorkflowPauseModel)
pause.id = "pause-123"
pause.workflow_id = "workflow-123"
pause.workflow_run_id = "workflow-run-123"
pause.state_object_key = "workflow-state-123.json"
pause.resumed_at = None
pause.created_at = datetime.now(UTC)
return pause
class TestPrivateWorkflowPauseEntity:
"""Test _PrivateWorkflowPauseEntity class."""
def test_properties(self, sample_workflow_pause: Mock) -> None:
"""Test entity properties."""
# Arrange
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
# Assert
assert entity.id == sample_workflow_pause.id
assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
assert entity.resumed_at == sample_workflow_pause.resumed_at
def test_get_state(self, sample_workflow_pause: Mock) -> None:
"""Test getting state from storage."""
# Arrange
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
mock_storage.load.return_value = expected_state
# Act
result = entity.get_state()
# Assert
assert result == expected_state
mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
def test_get_state_caching(self, sample_workflow_pause: Mock) -> None:
"""Test state caching in get_state method."""
# Arrange
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
mock_storage.load.return_value = expected_state
# Act
result1 = entity.get_state()
result2 = entity.get_state()
# Assert
assert result1 == expected_state
assert result2 == expected_state
mock_storage.load.assert_called_once()
class TestBuildHumanInputRequiredReason:
"""Test helper that builds HumanInputRequired pause reasons."""
def test_prefers_backstage_token_when_available(self) -> None:
"""Use backstage token when multiple recipient types may exist."""
# Arrange
expiration_time = datetime.now(UTC)
form_definition = FormDefinition(
form_content="content",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
user_actions=[UserAction(id="approve", title="Approve")],
rendered_content="rendered",
expiration_time=expiration_time,
default_values={"name": "Alice"},
node_title="Ask Name",
display_in_ui=True,
)
form_model = HumanInputForm(
id="form-1",
tenant_id="tenant-1",
app_id="app-1",
workflow_run_id="run-1",
node_id="node-1",
form_definition=form_definition.model_dump_json(),
rendered_content="rendered",
status=HumanInputFormStatus.WAITING,
expiration_time=expiration_time,
)
reason_model = WorkflowPauseReason(
pause_id="pause-1",
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
form_id="form-1",
node_id="node-1",
message="",
)
access_token = secrets.token_urlsafe(8)
backstage_recipient = HumanInputFormRecipient(
form_id="form-1",
delivery_id="delivery-1",
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload().model_dump_json(),
access_token=access_token,
)
# Act
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient])
# Assert
assert isinstance(reason, HumanInputRequired)
assert reason.form_token == access_token
assert reason.node_title == "Ask Name"
assert reason.form_content == "content"
assert reason.inputs[0].output_variable_name == "name"
assert reason.actions[0].id == "approve"