mirror of
https://github.com/langgenius/dify.git
synced 2026-03-21 22:38:26 +08:00
refactor: migrate workflow run repository unit tests from mocks to te… (#33843)
This commit is contained in:
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock
|
||||
@ -12,15 +13,26 @@ from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from dify_graph.entities import WorkflowExecution
|
||||
from dify_graph.entities.pause_reason import PauseReasonType
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction
|
||||
from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.human_input import (
|
||||
BackstageRecipientPayload,
|
||||
HumanInputDelivery,
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
_build_human_input_required_reason,
|
||||
_PrivateWorkflowPauseEntity,
|
||||
_WorkflowRunError,
|
||||
)
|
||||
|
||||
@ -90,6 +102,19 @@ def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
|
||||
WorkflowRun.app_id == scope.app_id,
|
||||
)
|
||||
)
|
||||
|
||||
form_ids_subquery = select(HumanInputForm.id).where(
|
||||
HumanInputForm.tenant_id == scope.tenant_id,
|
||||
HumanInputForm.app_id == scope.app_id,
|
||||
)
|
||||
session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery)))
|
||||
session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery)))
|
||||
session.execute(
|
||||
delete(HumanInputForm).where(
|
||||
HumanInputForm.tenant_id == scope.tenant_id,
|
||||
HumanInputForm.app_id == scope.app_id,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
for state_key in scope.state_keys:
|
||||
@ -504,3 +529,200 @@ class TestDeleteWorkflowPause:
|
||||
|
||||
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"):
|
||||
repository.delete_workflow_pause(pause_entity=pause_entity)
|
||||
|
||||
|
||||
class TestPrivateWorkflowPauseEntity:
|
||||
"""Integration tests for _PrivateWorkflowPauseEntity using real DB models."""
|
||||
|
||||
def test_properties(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Entity properties delegate to the persisted WorkflowPause model."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=f"workflow-state-{uuid4()}.json",
|
||||
)
|
||||
db_session_with_containers.add(pause)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.refresh(pause)
|
||||
test_scope.state_keys.add(pause.state_object_key)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
|
||||
|
||||
assert entity.id == pause.id
|
||||
assert entity.workflow_execution_id == workflow_run.id
|
||||
assert entity.resumed_at is None
|
||||
|
||||
def test_get_state(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""get_state loads state data from storage using the persisted state_object_key."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
state_key = f"workflow-state-{uuid4()}.json"
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=state_key,
|
||||
)
|
||||
db_session_with_containers.add(pause)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.refresh(pause)
|
||||
test_scope.state_keys.add(state_key)
|
||||
|
||||
expected_state = b'{"test": "state"}'
|
||||
storage.save(state_key, expected_state)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
|
||||
result = entity.get_state()
|
||||
|
||||
assert result == expected_state
|
||||
|
||||
def test_get_state_caching(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""get_state caches the result so storage is only accessed once."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
state_key = f"workflow-state-{uuid4()}.json"
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=state_key,
|
||||
)
|
||||
db_session_with_containers.add(pause)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.refresh(pause)
|
||||
test_scope.state_keys.add(state_key)
|
||||
|
||||
expected_state = b'{"test": "state"}'
|
||||
storage.save(state_key, expected_state)
|
||||
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[])
|
||||
result1 = entity.get_state()
|
||||
# Delete from storage to prove second call uses cache
|
||||
storage.delete(state_key)
|
||||
test_scope.state_keys.discard(state_key)
|
||||
result2 = entity.get_state()
|
||||
|
||||
assert result1 == expected_state
|
||||
assert result2 == expected_state
|
||||
|
||||
|
||||
class TestBuildHumanInputRequiredReason:
|
||||
"""Integration tests for _build_human_input_required_reason using real DB models."""
|
||||
|
||||
def test_prefers_backstage_token_when_available(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Use backstage token when multiple recipient types may exist."""
|
||||
|
||||
expiration_time = naive_utc_now()
|
||||
form_definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
default_values={"name": "Alice"},
|
||||
node_title="Ask Name",
|
||||
display_in_ui=True,
|
||||
)
|
||||
|
||||
form_model = HumanInputForm(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
workflow_run_id=str(uuid4()),
|
||||
node_id="node-1",
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content="rendered",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
expiration_time=expiration_time,
|
||||
)
|
||||
db_session_with_containers.add(form_model)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
delivery = HumanInputDelivery(
|
||||
form_id=form_model.id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
channel_payload="{}",
|
||||
)
|
||||
db_session_with_containers.add(delivery)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
access_token = secrets.token_urlsafe(8)
|
||||
recipient = HumanInputFormRecipient(
|
||||
form_id=form_model.id,
|
||||
delivery_id=delivery.id,
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
recipient_payload=BackstageRecipientPayload().model_dump_json(),
|
||||
access_token=access_token,
|
||||
)
|
||||
db_session_with_containers.add(recipient)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create a pause so the reason has a valid pause_id
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=f"workflow-state-{uuid4()}.json",
|
||||
)
|
||||
db_session_with_containers.add(pause)
|
||||
db_session_with_containers.flush()
|
||||
test_scope.state_keys.add(pause.state_object_key)
|
||||
|
||||
reason_model = WorkflowPauseReason(
|
||||
pause_id=pause.id,
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
form_id=form_model.id,
|
||||
node_id="node-1",
|
||||
message="",
|
||||
)
|
||||
db_session_with_containers.add(reason_model)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh to ensure we have DB-round-tripped objects
|
||||
db_session_with_containers.refresh(form_model)
|
||||
db_session_with_containers.refresh(reason_model)
|
||||
db_session_with_containers.refresh(recipient)
|
||||
|
||||
reason = _build_human_input_required_reason(reason_model, form_model, [recipient])
|
||||
|
||||
assert isinstance(reason, HumanInputRequired)
|
||||
assert reason.form_token == access_token
|
||||
assert reason.node_title == "Ask Name"
|
||||
assert reason.form_content == "content"
|
||||
assert reason.inputs[0].output_variable_name == "name"
|
||||
assert reason.actions[0].id == "approve"
|
||||
|
||||
@ -1,135 +0,0 @@
|
||||
"""Unit tests for non-SQL helper logic in workflow run repository."""
|
||||
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction
|
||||
from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormStatus
|
||||
from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from models.workflow import WorkflowPauseReason
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
_build_human_input_required_reason,
|
||||
_PrivateWorkflowPauseEntity,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_pause() -> Mock:
|
||||
"""Create a sample WorkflowPause model."""
|
||||
pause = Mock(spec=WorkflowPauseModel)
|
||||
pause.id = "pause-123"
|
||||
pause.workflow_id = "workflow-123"
|
||||
pause.workflow_run_id = "workflow-run-123"
|
||||
pause.state_object_key = "workflow-state-123.json"
|
||||
pause.resumed_at = None
|
||||
pause.created_at = datetime.now(UTC)
|
||||
return pause
|
||||
|
||||
|
||||
class TestPrivateWorkflowPauseEntity:
|
||||
"""Test _PrivateWorkflowPauseEntity class."""
|
||||
|
||||
def test_properties(self, sample_workflow_pause: Mock) -> None:
|
||||
"""Test entity properties."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
|
||||
# Assert
|
||||
assert entity.id == sample_workflow_pause.id
|
||||
assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
|
||||
assert entity.resumed_at == sample_workflow_pause.resumed_at
|
||||
|
||||
def test_get_state(self, sample_workflow_pause: Mock) -> None:
|
||||
"""Test getting state from storage."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
expected_state = b'{"test": "state"}'
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
mock_storage.load.return_value = expected_state
|
||||
|
||||
# Act
|
||||
result = entity.get_state()
|
||||
|
||||
# Assert
|
||||
assert result == expected_state
|
||||
mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
|
||||
|
||||
def test_get_state_caching(self, sample_workflow_pause: Mock) -> None:
|
||||
"""Test state caching in get_state method."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
expected_state = b'{"test": "state"}'
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
mock_storage.load.return_value = expected_state
|
||||
|
||||
# Act
|
||||
result1 = entity.get_state()
|
||||
result2 = entity.get_state()
|
||||
|
||||
# Assert
|
||||
assert result1 == expected_state
|
||||
assert result2 == expected_state
|
||||
mock_storage.load.assert_called_once()
|
||||
|
||||
|
||||
class TestBuildHumanInputRequiredReason:
|
||||
"""Test helper that builds HumanInputRequired pause reasons."""
|
||||
|
||||
def test_prefers_backstage_token_when_available(self) -> None:
|
||||
"""Use backstage token when multiple recipient types may exist."""
|
||||
# Arrange
|
||||
expiration_time = datetime.now(UTC)
|
||||
form_definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
default_values={"name": "Alice"},
|
||||
node_title="Ask Name",
|
||||
display_in_ui=True,
|
||||
)
|
||||
form_model = HumanInputForm(
|
||||
id="form-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content="rendered",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
expiration_time=expiration_time,
|
||||
)
|
||||
reason_model = WorkflowPauseReason(
|
||||
pause_id="pause-1",
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
message="",
|
||||
)
|
||||
access_token = secrets.token_urlsafe(8)
|
||||
backstage_recipient = HumanInputFormRecipient(
|
||||
form_id="form-1",
|
||||
delivery_id="delivery-1",
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
recipient_payload=BackstageRecipientPayload().model_dump_json(),
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
# Act
|
||||
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient])
|
||||
|
||||
# Assert
|
||||
assert isinstance(reason, HumanInputRequired)
|
||||
assert reason.form_token == access_token
|
||||
assert reason.node_title == "Ask Name"
|
||||
assert reason.form_content == "content"
|
||||
assert reason.inputs[0].output_variable_name == "name"
|
||||
assert reason.actions[0].id == "approve"
|
||||
Reference in New Issue
Block a user