diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 4436773d25..324dd059d1 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -488,6 +488,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, WorkflowNodeExecutionModel.tenant_id == self._tenant_id, WorkflowNodeExecutionModel.triggered_from == triggered_from, + WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED, ) if self._app_id: diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 09fa9de686..8169b30988 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -260,10 +260,33 @@ class Node(Generic[NodeDataT]): return self._node_execution_id def ensure_execution_id(self) -> str: - if not self._node_execution_id: - self._node_execution_id = str(uuid4()) + if self._node_execution_id: + return self._node_execution_id + + resumed_execution_id = self._restore_execution_id_from_runtime_state() + if resumed_execution_id: + self._node_execution_id = resumed_execution_id + return self._node_execution_id + + self._node_execution_id = str(uuid4()) return self._node_execution_id + def _restore_execution_id_from_runtime_state(self) -> str | None: + graph_execution = self.graph_runtime_state.graph_execution + try: + node_executions = graph_execution.node_executions + except AttributeError: + return None + if not isinstance(node_executions, dict): + return None + node_execution = node_executions.get(self._node_id) + if node_execution is None: + return None + execution_id = node_execution.execution_id + if not execution_id: + return None + return str(execution_id) + def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT: return cast(NodeDataT, self._node_data_type.model_validate(data)) diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/core/workflow/nodes/human_input/human_input_node.py index dff8b75f57..79be9015b5 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/core/workflow/nodes/human_input/human_input_node.py @@ -8,6 +8,7 @@ from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import HumanInputFormFilledEvent, NodeRunResult, PauseRequestedEvent from core.workflow.node_events.base import NodeEventBase +from core.workflow.node_events.node import StreamCompletedEvent from core.workflow.nodes.base.node import Node from core.workflow.repositories.human_input_form_repository import ( FormCreateParams, @@ -166,34 +167,7 @@ class HumanInputNode(Node[HumanInputNodeData]): resolved_placeholder_values=resolved_placeholder_values, ) - def _create_form(self) -> Generator[NodeEventBase, None, None] | NodeRunResult: - try: - params = FormCreateParams( - workflow_execution_id=self._workflow_execution_id, - node_id=self.id, - form_config=self._node_data, - rendered_content=self._render_form_content(), - resolved_placeholder_values=self._resolve_inputs(), - ) - form_entity = self._form_repository.create_form(params) - # Create human input required event - - logger.info( - "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s", - self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, - self.id, - form_entity.id, - ) - yield self._form_to_pause_event(form_entity) - except Exception as e: - logger.exception("Human Input node failed to execute, node_id=%s", self.id) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - error_type="HumanInputNodeError", - ) - - def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: + def _run(self) -> Generator[NodeEventBase, None, None]: """ Execute the human input node. @@ -208,56 +182,69 @@ class HumanInputNode(Node[HumanInputNodeData]): repo = self._form_repository form = repo.get_form(self._workflow_execution_id, self.id) if form is None: - return self._create_form() - - if form.submitted: - selected_action_id = form.selected_action_id - if selected_action_id is None: - raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}") - submitted_data = form.submitted_data or {} - outputs: dict[str, Any] = dict(submitted_data) - outputs["__action_id"] = selected_action_id - rendered_content = self._render_form_content_with_outputs( - form.rendered_content, - outputs, - self._node_data.outputs_field_names(), + params = FormCreateParams( + workflow_execution_id=self._workflow_execution_id, + node_id=self.id, + form_config=self._node_data, + rendered_content=self._render_form_content_before_submission(), + resolved_placeholder_values=self._resolve_inputs(), ) - outputs["__rendered_content"] = rendered_content + form_entity = self._form_repository.create_form(params) + # Create human input required event - action_text = self._node_data.find_action_text(selected_action_id) - - yield HumanInputFormFilledEvent( - rendered_content=rendered_content, - action_id=selected_action_id, - action_text=action_text, + logger.info( + "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s", + self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, + self.id, + form_entity.id, ) + yield self._form_to_pause_event(form_entity) + return - return NodeRunResult( + if form.status == HumanInputFormStatus.TIMEOUT or form.expiration_time <= naive_utc_now(): + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={}, + edge_source_handle="__timeout", + ) + ) + return + + if not form.submitted: + yield self._form_to_pause_event(form) + return + + selected_action_id = form.selected_action_id + if selected_action_id is None: + raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}") + submitted_data = form.submitted_data or {} + outputs: dict[str, Any] = dict(submitted_data) + outputs["__action_id"] = selected_action_id + rendered_content = self._render_form_content_with_outputs( + form.rendered_content, + outputs, + self._node_data.outputs_field_names(), + ) + outputs["__rendered_content"] = rendered_content + + action_text = self._node_data.find_action_text(selected_action_id) + + yield HumanInputFormFilledEvent( + rendered_content=rendered_content, + action_id=selected_action_id, + action_text=action_text, + ) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, edge_source_handle=selected_action_id, ) + ) - if form.status == HumanInputFormStatus.TIMEOUT or form.expiration_time <= naive_utc_now(): - outputs: dict[str, Any] = { - "__rendered_content": self._render_form_content_with_outputs( - form.rendered_content, - {}, - self._node_data.outputs_field_names(), - ) - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - edge_source_handle="__timeout", - ) - - return self._pause_with_form(form) - - def _pause_with_form(self, form_entity: HumanInputFormEntity) -> Generator[NodeEventBase, None, None]: - yield self._form_to_pause_event(form_entity) - - def _render_form_content(self) -> str: + def _render_form_content_before_submission(self) -> str: """ Process form content by substituting variables. diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 8c804d6bb5..1cae14b726 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -13,6 +13,7 @@ from typing import Any from sqlalchemy.orm import sessionmaker +from core.workflow.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from models.workflow import WorkflowNodeExecutionModel from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository @@ -199,8 +200,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep reverse=True, ) - if deduplicated_results: - return _dict_to_workflow_node_execution_model(deduplicated_results[0]) + for row in deduplicated_results: + model = _dict_to_workflow_node_execution_model(row) + if model.status != WorkflowNodeExecutionStatus.PAUSED: + return model return None @@ -293,6 +296,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep if model and model.id: # Ensure model is valid models.append(model) + models = [model for model in models if model.status != WorkflowNodeExecutionStatus.PAUSED] + # Sort by index DESC for trace visualization models.sort(key=lambda x: x.index, reverse=True) diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 7e2173acdd..b65206cc52 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -13,6 +13,7 @@ from sqlalchemy import asc, delete, desc, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker +from core.workflow.enums import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository @@ -76,6 +77,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut WorkflowNodeExecutionModel.app_id == app_id, WorkflowNodeExecutionModel.workflow_id == workflow_id, WorkflowNodeExecutionModel.node_id == node_id, + WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED, ) .order_by(desc(WorkflowNodeExecutionModel.created_at)) .limit(1) @@ -109,6 +111,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut WorkflowNodeExecutionModel.tenant_id == tenant_id, WorkflowNodeExecutionModel.app_id == app_id, WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED, ).order_by(asc(WorkflowNodeExecutionModel.created_at)) with self._session_maker() as session: diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py new file mode 100644 index 0000000000..9a117cba10 --- /dev/null +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -0,0 +1,336 @@ +import time +import uuid +from datetime import timedelta +from unittest.mock import MagicMock + +import pytest +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities import GraphInitParams +from core.workflow.enums import WorkflowType +from core.workflow.graph import Graph +from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction +from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now +from models import Account +from models.account import Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.model import App, AppMode, IconType +from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun + + +def _mock_form_repository_without_submission() -> HumanInputFormRepository: + repo = MagicMock(spec=HumanInputFormRepository) + form_entity = MagicMock(spec=HumanInputFormEntity) + form_entity.id = "test-form-id" + form_entity.web_app_token = "test-form-token" + form_entity.recipients = [] + form_entity.rendered_content = "rendered" + form_entity.submitted = False + repo.create_form.return_value = form_entity + repo.get_form.return_value = None + return repo + + +def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: + repo = MagicMock(spec=HumanInputFormRepository) + form_entity = MagicMock(spec=HumanInputFormEntity) + form_entity.id = "test-form-id" + form_entity.web_app_token = "test-form-token" + form_entity.recipients = [] + form_entity.rendered_content = "rendered" + form_entity.submitted = True + form_entity.selected_action_id = action_id + form_entity.submitted_data = {} + form_entity.status = HumanInputFormStatus.WAITING + form_entity.expiration_time = naive_utc_now() + timedelta(hours=1) + repo.get_form.return_value = form_entity + return repo + + +def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: + variable_pool = VariablePool( + system_variables=SystemVariable( + workflow_execution_id=workflow_execution_id, + app_id=app_id, + workflow_id=workflow_id, + user_id=user_id, + ), + user_inputs={}, + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +def _build_graph( + runtime_state: GraphRuntimeState, + tenant_id: str, + app_id: str, + workflow_id: str, + user_id: str, + form_repository: HumanInputFormRepository, +) -> Graph: + graph_config: dict[str, object] = {"nodes": [], "edges": []} + params = GraphInitParams( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + graph_config=graph_config, + user_id=user_id, + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + + start_data = StartNodeData(title="start", variables=[]) + start_node = StartNode( + id="start", + config={"id": "start", "data": start_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + + human_data = HumanInputNodeData( + title="human", + form_content="Awaiting human input", + inputs=[], + user_actions=[ + UserAction(id="continue", title="Continue"), + ], + ) + human_node = HumanInputNode( + id="human", + config={"id": "human", "data": human_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + form_repository=form_repository, + ) + + end_data = EndNodeData( + title="end", + outputs=[], + desc=None, + ) + end_node = EndNode( + id="end", + config={"id": "end", "data": end_data.model_dump()}, + graph_init_params=params, + graph_runtime_state=runtime_state, + ) + + return ( + Graph.new() + .add_root(start_node) + .add_node(human_node) + .add_node(end_node, from_node_id="human", source_handle="continue") + .build() + ) + + +def _build_generate_entity( + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_execution_id: str, + user_id: str, +) -> WorkflowAppGenerateEntity: + app_config = WorkflowUIBasedAppConfig( + tenant_id=tenant_id, + app_id=app_id, + app_mode=AppMode.WORKFLOW, + workflow_id=workflow_id, + ) + return WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + inputs={}, + files=[], + user_id=user_id, + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + workflow_execution_id=workflow_execution_id, + ) + + +class TestHumanInputResumeNodeExecutionIntegration: + @pytest.fixture(autouse=True) + def setup_test_data(self, db_session_with_containers: Session): + tenant = Tenant( + name="Test Tenant", + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + account = Account( + email="test@example.com", + name="Test User", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(tenant_join) + db_session_with_containers.commit() + + account.current_tenant = tenant + + app = App( + tenant_id=tenant.id, + name="Test App", + description="", + mode=AppMode.WORKFLOW.value, + icon_type=IconType.EMOJI.value, + icon="rocket", + icon_background="#4ECDC4", + enable_site=False, + enable_api=False, + api_rpm=0, + api_rph=0, + is_demo=False, + is_public=False, + is_universal=False, + max_active_requests=None, + created_by=account.id, + updated_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + + workflow = Workflow( + tenant_id=tenant.id, + app_id=app.id, + type="workflow", + version="draft", + graph='{"nodes": [], "edges": []}', + features='{"file_upload": {"enabled": false}}', + created_by=account.id, + created_at=naive_utc_now(), + ) + db_session_with_containers.add(workflow) + db_session_with_containers.commit() + + self.session = db_session_with_containers + self.tenant = tenant + self.account = account + self.app = app + self.workflow = workflow + + yield + + self.session.execute(delete(WorkflowNodeExecutionModel)) + self.session.execute(delete(WorkflowRun)) + self.session.execute(delete(Workflow).where(Workflow.id == self.workflow.id)) + self.session.execute(delete(App).where(App.id == self.app.id)) + self.session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == self.tenant.id)) + self.session.execute(delete(Account).where(Account.id == self.account.id)) + self.session.execute(delete(Tenant).where(Tenant.id == self.tenant.id)) + self.session.commit() + + def _build_persistence_layer(self, execution_id: str) -> WorkflowPersistenceLayer: + generate_entity = _build_generate_entity( + tenant_id=self.tenant.id, + app_id=self.app.id, + workflow_id=self.workflow.id, + workflow_execution_id=execution_id, + user_id=self.account.id, + ) + execution_repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=self.session.get_bind(), + user=self.account, + app_id=self.app.id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + node_execution_repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=self.session.get_bind(), + user=self.account, + app_id=self.app.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + return WorkflowPersistenceLayer( + application_generate_entity=generate_entity, + workflow_info=PersistenceWorkflowInfo( + workflow_id=self.workflow.id, + workflow_type=WorkflowType.WORKFLOW, + version=self.workflow.version, + graph_data=self.workflow.graph_dict, + ), + workflow_execution_repository=execution_repo, + workflow_node_execution_repository=node_execution_repo, + ) + + def _run_graph(self, graph: Graph, runtime_state: GraphRuntimeState, execution_id: str) -> None: + engine = GraphEngine( + workflow_id=self.workflow.id, + graph=graph, + graph_runtime_state=runtime_state, + command_channel=InMemoryChannel(), + ) + engine.layer(self._build_persistence_layer(execution_id)) + for _ in engine.run(): + continue + + def test_resume_human_input_does_not_create_duplicate_node_execution(self): + execution_id = str(uuid.uuid4()) + runtime_state = _build_runtime_state( + workflow_execution_id=execution_id, + app_id=self.app.id, + workflow_id=self.workflow.id, + user_id=self.account.id, + ) + pause_repo = _mock_form_repository_without_submission() + paused_graph = _build_graph( + runtime_state, + self.tenant.id, + self.app.id, + self.workflow.id, + self.account.id, + pause_repo, + ) + self._run_graph(paused_graph, runtime_state, execution_id) + + snapshot = runtime_state.dumps() + resumed_state = GraphRuntimeState.from_snapshot(snapshot) + resume_repo = _mock_form_repository_with_submission(action_id="continue") + resumed_graph = _build_graph( + resumed_state, + self.tenant.id, + self.app.id, + self.workflow.id, + self.account.id, + resume_repo, + ) + self._run_graph(resumed_graph, resumed_state, execution_id) + + stmt = select(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.workflow_run_id == execution_id, + WorkflowNodeExecutionModel.node_id == "human", + ) + records = self.session.execute(stmt).scalars().all() + assert len(records) == 1 + assert records[0].status != "paused" + assert records[0].triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + assert records[0].created_by_role == CreatorUserRole.ACCOUNT diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 23c4eeb82f..f13688b14b 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -465,6 +465,27 @@ class TestWorkflowRunService: db.session.add(node_execution) node_executions.append(node_execution) + paused_node_execution = WorkflowNodeExecutionModel( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow_run.workflow_id, + triggered_from="workflow-run", + workflow_run_id=workflow_run.id, + index=99, + node_id="node_paused", + node_type="human_input", + title="Paused Node", + inputs=json.dumps({"input": "paused"}), + process_data=json.dumps({"process": "paused"}), + status="paused", + elapsed_time=0.5, + execution_metadata=json.dumps({"tokens": 0}), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=datetime.now(UTC), + ) + db.session.add(paused_node_execution) + db.session.commit() # Act: Execute the method under test @@ -477,6 +498,7 @@ class TestWorkflowRunService: # Verify node execution properties for node_execution in result: + assert node_execution.status != "paused" assert node_execution.tenant_id == app.tenant_id assert node_execution.app_id == app.id assert node_execution.workflow_run_id == workflow_run.id diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 999275e5a3..889d1389b1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -9,7 +9,7 @@ import pytest from pydantic import ValidationError from core.workflow.entities import GraphInitParams -from core.workflow.node_events import NodeRunResult, PauseRequestedEvent +from core.workflow.node_events import PauseRequestedEvent from core.workflow.node_events.node import StreamCompletedEvent from core.workflow.nodes.human_input.entities import ( EmailDeliveryConfig, diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py index 32d2f8b7e0..0538ff2581 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -5,6 +5,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.workflow.enums import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel from repositories.sqlalchemy_api_workflow_node_execution_repository import ( DifyAPISQLAlchemyWorkflowNodeExecutionRepository, @@ -52,6 +53,9 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: call_args = mock_session.scalar.call_args[0][0] assert hasattr(call_args, "compile") # It's a SQLAlchemy statement + compiled = call_args.compile() + assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values() + def test_get_node_last_execution_not_found(self, repository): """Test getting the last execution for a node when it doesn't exist.""" # Arrange @@ -93,6 +97,9 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: call_args = mock_session.execute.call_args[0][0] assert hasattr(call_args, "compile") # It's a SQLAlchemy statement + compiled = call_args.compile() + assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values() + def test_get_executions_by_workflow_run_empty(self, repository): """Test getting executions for a workflow run when none exist.""" # Arrange