feat: Human Input Node (#32060)

Adds the frontend and backend implementation for the human input node.

Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
QuantumGhost
2026-02-09 14:57:23 +08:00
committed by GitHub
parent 56e3a55023
commit a1fc280102
474 changed files with 32667 additions and 2050 deletions

View File

@ -0,0 +1,166 @@
"""TestContainers integration tests for ChatConversationApi status_count behavior."""
import json
import uuid
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HEADER_NAME_CSRF_TOKEN
from core.workflow.enums import WorkflowExecutionStatus
from libs.datetime_utils import naive_utc_now
from libs.token import _real_cookie_name, generate_csrf_token
from models import Account, DifySetup, Tenant, TenantAccountJoin
from models.account import AccountStatus, TenantAccountRole
from models.enums import CreatorUserRole
from models.model import App, AppMode, Conversation, Message
from models.workflow import WorkflowRun
from services.account_service import AccountService
def _create_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]:
    """Persist a fresh owner account, its tenant, and the DifySetup marker.

    Returns the (account, tenant) pair with the account bound to the tenant
    and its timezone normalized to UTC.
    """
    owner = Account(
        email=f"test-{uuid.uuid4()}@example.com",
        name="Test User",
        interface_language="en-US",
        status=AccountStatus.ACTIVE,
    )
    owner.initialized_at = naive_utc_now()
    db_session.add(owner)
    db_session.commit()

    workspace = Tenant(name="Test Tenant", status="normal")
    db_session.add(workspace)
    db_session.commit()

    membership = TenantAccountJoin(
        tenant_id=workspace.id,
        account_id=owner.id,
        role=TenantAccountRole.OWNER,
        current=True,
    )
    db_session.add(membership)
    db_session.commit()

    # Bind the account to its workspace and pin a deterministic timezone.
    owner.set_tenant_id(workspace.id)
    owner.timezone = "UTC"
    db_session.commit()

    # Mark the installation as initialized so console endpoints accept requests.
    db_session.add(DifySetup(version=dify_config.project.version))
    db_session.commit()
    return owner, workspace
def _create_app(db_session: Session, tenant_id: str, account_id: str) -> App:
    """Create and persist a minimal chat-mode App owned by *account_id*."""
    chat_app = App(
        tenant_id=tenant_id,
        created_by=account_id,
        name="Test Chat App",
        mode=AppMode.CHAT,
        enable_site=True,
        enable_api=True,
    )
    db_session.add(chat_app)
    db_session.commit()
    return chat_app
def _create_conversation(db_session: Session, app_id: str, account_id: str) -> Conversation:
    """Persist a normal-status chat Conversation started by a console account."""
    record = Conversation(
        app_id=app_id,
        from_account_id=account_id,
        from_source=CreatorUserRole.ACCOUNT,
        name="Test Conversation",
        inputs={},
        status="normal",
        mode=AppMode.CHAT,
    )
    db_session.add(record)
    db_session.commit()
    return record
def _create_workflow_run(db_session: Session, app_id: str, tenant_id: str, account_id: str) -> WorkflowRun:
    """Persist a PAUSED chat WorkflowRun so the listing has a paused entry to count."""
    run = WorkflowRun(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=str(uuid.uuid4()),
        type="chat",
        triggered_from="app-run",
        version="1.0.0",
        graph=json.dumps({"nodes": [], "edges": []}),
        inputs=json.dumps({"query": "test"}),
        # PAUSED is the state under test: status_count must include it.
        status=WorkflowExecutionStatus.PAUSED,
        outputs=json.dumps({}),
        elapsed_time=0.0,
        total_tokens=0,
        total_steps=0,
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=account_id,
        created_at=naive_utc_now(),
    )
    db_session.add(run)
    db_session.commit()
    return run
def _create_message(
    db_session: Session, app_id: str, conversation_id: str, workflow_run_id: str, account_id: str
) -> Message:
    """Persist a Message linked to both the conversation and the workflow run."""
    record = Message(
        app_id=app_id,
        conversation_id=conversation_id,
        workflow_run_id=workflow_run_id,
        from_source=CreatorUserRole.ACCOUNT,
        from_account_id=account_id,
        query="Hello",
        message={"type": "text", "content": "Hello"},
        answer="Hi there",
        inputs={"query": "Hello"},
        message_tokens=1,
        answer_tokens=1,
        message_unit_price=0.001,
        answer_unit_price=0.001,
        message_price_unit=0.001,
        answer_price_unit=0.001,
        currency="USD",
        status="normal",
    )
    db_session.add(record)
    db_session.commit()
    return record
def test_chat_conversation_status_count_includes_paused(
    db_session_with_containers: Session,
    test_client_with_containers: FlaskClient,
):
    """The chat-conversations listing must report paused workflow runs in status_count."""
    account, tenant = _create_account_and_tenant(db_session_with_containers)
    app = _create_app(db_session_with_containers, tenant.id, account.id)
    conversation = _create_conversation(db_session_with_containers, app.id, account.id)
    expected_conversation_id = conversation.id
    run = _create_workflow_run(db_session_with_containers, app.id, tenant.id, account.id)
    _create_message(db_session_with_containers, app.id, conversation.id, run.id, account.id)

    # Authenticate as the owner: JWT bearer token plus CSRF cookie/header pair.
    access_token = AccountService.get_account_jwt_token(account)
    csrf_token = generate_csrf_token(account.id)
    test_client_with_containers.set_cookie(_real_cookie_name("csrf_token"), csrf_token, domain="localhost")

    response = test_client_with_containers.get(
        f"/console/api/apps/{app.id}/chat-conversations",
        headers={
            "Authorization": f"Bearer {access_token}",
            HEADER_NAME_CSRF_TOKEN: csrf_token,
        },
    )

    assert response.status_code == 200
    payload = response.get_json()
    assert payload is not None
    assert payload["total"] == 1
    assert payload["data"][0]["id"] == expected_conversation_id
    assert payload["data"][0]["status_count"]["paused"] == 1

View File

@ -0,0 +1,240 @@
"""TestContainers integration tests for HumanInputFormRepositoryImpl."""
from __future__ import annotations
from uuid import uuid4
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.nodes.human_input.entities import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
FormDefinition,
HumanInputNodeData,
MemberRecipient,
UserAction,
WebAppDeliveryMethod,
)
from core.workflow.repositories.human_input_form_repository import FormCreateParams
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.human_input import (
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
)
def _create_tenant_with_members(session: Session, member_emails: list[str]) -> tuple[Tenant, list[Account]]:
    """Create one tenant plus a NORMAL-role member account per email address."""
    tenant = Tenant(name="Test Tenant", status="normal")
    session.add(tenant)
    session.flush()

    accounts: list[Account] = []
    for position, address in enumerate(member_emails):
        member = Account(
            email=address,
            name=f"Member {position}",
            interface_language="en-US",
            status="active",
        )
        session.add(member)
        session.flush()  # assign member.id before building the join row
        session.add(
            TenantAccountJoin(
                tenant_id=tenant.id,
                account_id=member.id,
                role=TenantAccountRole.NORMAL,
                current=True,
            )
        )
        accounts.append(member)
    session.commit()
    return tenant, accounts
def _build_form_params(delivery_methods: list[DeliveryChannelConfig]) -> FormCreateParams:
    """Assemble FormCreateParams around a one-action approval form."""
    node_data = HumanInputNodeData(
        title="Human Approval",
        delivery_methods=delivery_methods,
        form_content="<p>Approve?</p>",
        user_actions=[UserAction(id="approve", title="Approve")],
    )
    return FormCreateParams(
        app_id=str(uuid4()),
        workflow_execution_id=str(uuid4()),
        node_id="human-input-node",
        form_config=node_data,
        rendered_content="<p>Approve?</p>",
        delivery_methods=delivery_methods,
        display_in_ui=False,
        resolved_default_values={},
    )
def _build_email_delivery(
    whole_workspace: bool, recipients: list[MemberRecipient | ExternalRecipient]
) -> EmailDeliveryMethod:
    """Wrap *recipients* in an email delivery channel with a fixed subject/body."""
    email_config = EmailDeliveryConfig(
        recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients),
        subject="Approval Needed",
        body="Please review",
    )
    return EmailDeliveryMethod(config=email_config)
class TestHumanInputFormRepositoryImplWithContainers:
    """Integration tests for HumanInputFormRepositoryImpl against a real database.

    Each test seeds a tenant (with member accounts) through the shared session,
    creates a form via the repository under test, then re-reads the persisted
    rows through an independent verification Session so assertions see exactly
    what was committed.
    """

    def test_create_form_with_whole_workspace_recipients(self, db_session_with_containers: Session) -> None:
        # whole_workspace=True must fan out to one EMAIL_MEMBER recipient row
        # per tenant member, even with an empty explicit recipient list.
        engine = db_session_with_containers.get_bind()
        assert isinstance(engine, Engine)
        tenant, members = _create_tenant_with_members(
            db_session_with_containers,
            member_emails=["member1@example.com", "member2@example.com"],
        )
        repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
        params = _build_form_params(
            delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])],
        )
        form_entity = repository.create_form(params)
        # Verify through a fresh session to bypass identity-map caching.
        with Session(engine) as verification_session:
            recipients = verification_session.scalars(
                select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id)
            ).all()
            assert len(recipients) == len(members)
            member_payloads = [
                EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload)
                for recipient in recipients
                if recipient.recipient_type == RecipientType.EMAIL_MEMBER
            ]
            member_emails = {payload.email for payload in member_payloads}
            assert member_emails == {member.email for member in members}

    def test_create_form_with_specific_members_and_external(self, db_session_with_containers: Session) -> None:
        # With whole_workspace=False only the explicitly listed recipients are
        # stored: one EMAIL_MEMBER payload and one EMAIL_EXTERNAL payload.
        engine = db_session_with_containers.get_bind()
        assert isinstance(engine, Engine)
        tenant, members = _create_tenant_with_members(
            db_session_with_containers,
            member_emails=["primary@example.com", "secondary@example.com"],
        )
        repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
        params = _build_form_params(
            delivery_methods=[
                _build_email_delivery(
                    whole_workspace=False,
                    recipients=[
                        MemberRecipient(user_id=members[0].id),
                        ExternalRecipient(email="external@example.com"),
                    ],
                )
            ],
        )
        form_entity = repository.create_form(params)
        with Session(engine) as verification_session:
            recipients = verification_session.scalars(
                select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id)
            ).all()
            member_recipient_payloads = [
                EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload)
                for recipient in recipients
                if recipient.recipient_type == RecipientType.EMAIL_MEMBER
            ]
            assert len(member_recipient_payloads) == 1
            assert member_recipient_payloads[0].user_id == members[0].id
            external_payloads = [
                EmailExternalRecipientPayload.model_validate_json(recipient.recipient_payload)
                for recipient in recipients
                if recipient.recipient_type == RecipientType.EMAIL_EXTERNAL
            ]
            assert len(external_payloads) == 1
            assert external_payloads[0].email == "external@example.com"

    def test_create_form_persists_default_values(self, db_session_with_containers: Session) -> None:
        # resolved_default_values passed at creation time must round-trip into
        # the FormDefinition JSON stored on the form row.
        engine = db_session_with_containers.get_bind()
        assert isinstance(engine, Engine)
        tenant, _ = _create_tenant_with_members(
            db_session_with_containers,
            member_emails=["prefill@example.com"],
        )
        repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
        resolved_values = {"greeting": "Hello!"}
        params = FormCreateParams(
            app_id=str(uuid4()),
            workflow_execution_id=str(uuid4()),
            node_id="human-input-node",
            form_config=HumanInputNodeData(
                title="Human Approval",
                form_content="<p>Approve?</p>",
                inputs=[],
                user_actions=[UserAction(id="approve", title="Approve")],
            ),
            rendered_content="<p>Approve?</p>",
            delivery_methods=[],
            display_in_ui=False,
            resolved_default_values=resolved_values,
        )
        form_entity = repository.create_form(params)
        with Session(engine) as verification_session:
            form_model = verification_session.scalars(
                select(HumanInputForm).where(HumanInputForm.id == form_entity.id)
            ).first()
            assert form_model is not None
            definition = FormDefinition.model_validate_json(form_model.form_definition)
            assert definition.default_values == resolved_values

    def test_create_form_persists_display_in_ui(self, db_session_with_containers: Session) -> None:
        # display_in_ui=True must survive serialization into the stored definition.
        engine = db_session_with_containers.get_bind()
        assert isinstance(engine, Engine)
        tenant, _ = _create_tenant_with_members(
            db_session_with_containers,
            member_emails=["ui@example.com"],
        )
        repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
        params = FormCreateParams(
            app_id=str(uuid4()),
            workflow_execution_id=str(uuid4()),
            node_id="human-input-node",
            form_config=HumanInputNodeData(
                title="Human Approval",
                form_content="<p>Approve?</p>",
                inputs=[],
                user_actions=[UserAction(id="approve", title="Approve")],
                delivery_methods=[WebAppDeliveryMethod()],
            ),
            rendered_content="<p>Approve?</p>",
            delivery_methods=[WebAppDeliveryMethod()],
            display_in_ui=True,
            resolved_default_values={},
        )
        form_entity = repository.create_form(params)
        with Session(engine) as verification_session:
            form_model = verification_session.scalars(
                select(HumanInputForm).where(HumanInputForm.id == form_entity.id)
            ).first()
            assert form_model is not None
            definition = FormDefinition.model_validate_json(form_model.form_definition)
            assert definition.display_in_ui is True

View File

@ -0,0 +1,336 @@
import time
import uuid
from datetime import timedelta
from unittest.mock import MagicMock
import pytest
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from models import Account
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import App, AppMode, IconType
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
def _mock_form_repository_without_submission() -> HumanInputFormRepository:
    """Mock repository for the first (pausing) run: a form is created but never submitted."""
    entity = MagicMock(spec=HumanInputFormEntity)
    entity.id = "test-form-id"
    entity.web_app_token = "test-form-token"
    entity.recipients = []
    entity.rendered_content = "rendered"
    entity.submitted = False

    repository = MagicMock(spec=HumanInputFormRepository)
    repository.create_form.return_value = entity
    # get_form returning None means no submission exists yet.
    repository.get_form.return_value = None
    return repository
def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
    """Mock repository for the resumed run: get_form yields a submitted, unexpired form."""
    entity = MagicMock(spec=HumanInputFormEntity)
    entity.id = "test-form-id"
    entity.web_app_token = "test-form-token"
    entity.recipients = []
    entity.rendered_content = "rendered"
    entity.submitted = True
    entity.selected_action_id = action_id
    entity.submitted_data = {}
    entity.status = HumanInputFormStatus.WAITING
    # Expiry one hour out so the form is still considered live on resume.
    entity.expiration_time = naive_utc_now() + timedelta(hours=1)

    repository = MagicMock(spec=HumanInputFormRepository)
    repository.get_form.return_value = entity
    return repository
def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState:
    """Create a GraphRuntimeState whose pool carries the workflow system variables."""
    system_vars = SystemVariable(
        workflow_execution_id=workflow_execution_id,
        app_id=app_id,
        workflow_id=workflow_id,
        user_id=user_id,
    )
    pool = VariablePool(
        system_variables=system_vars,
        user_inputs={},
        conversation_variables=[],
    )
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
def _build_graph(
    runtime_state: GraphRuntimeState,
    tenant_id: str,
    app_id: str,
    workflow_id: str,
    user_id: str,
    form_repository: HumanInputFormRepository,
) -> Graph:
    """Build a start -> human-input -> end graph wired to *form_repository*.

    The edge from the human node to the end node is keyed on the "continue"
    action id via source_handle.
    """
    init_params = GraphInitParams(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_id,
        graph_config={"nodes": [], "edges": []},
        user_id=user_id,
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    start_node = StartNode(
        id="start",
        config={"id": "start", "data": StartNodeData(title="start", variables=[]).model_dump()},
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
    )
    approval_data = HumanInputNodeData(
        title="human",
        form_content="Awaiting human input",
        inputs=[],
        user_actions=[
            UserAction(id="continue", title="Continue"),
        ],
    )
    human_node = HumanInputNode(
        id="human",
        config={"id": "human", "data": approval_data.model_dump()},
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
        form_repository=form_repository,
    )
    end_node = EndNode(
        id="end",
        config={"id": "end", "data": EndNodeData(title="end", outputs=[], desc=None).model_dump()},
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
    )
    builder = Graph.new().add_root(start_node)
    builder = builder.add_node(human_node)
    builder = builder.add_node(end_node, from_node_id="human", source_handle="continue")
    return builder.build()
def _build_generate_entity(
    tenant_id: str,
    app_id: str,
    workflow_id: str,
    workflow_execution_id: str,
    user_id: str,
) -> WorkflowAppGenerateEntity:
    """Create a debugger-invoked WorkflowAppGenerateEntity for the persistence layer."""
    return WorkflowAppGenerateEntity(
        task_id=str(uuid.uuid4()),
        app_config=WorkflowUIBasedAppConfig(
            tenant_id=tenant_id,
            app_id=app_id,
            app_mode=AppMode.WORKFLOW,
            workflow_id=workflow_id,
        ),
        inputs={},
        files=[],
        user_id=user_id,
        stream=False,
        invoke_from=InvokeFrom.DEBUGGER,
        workflow_execution_id=workflow_execution_id,
    )
class TestHumanInputResumeNodeExecutionIntegration:
    """Verify that resuming a paused human-input workflow updates the existing
    node-execution row instead of inserting a duplicate."""

    @pytest.fixture(autouse=True)
    def setup_test_data(self, db_session_with_containers: Session):
        """Persist tenant/account/app/workflow rows; delete them after the test."""
        tenant = Tenant(
            name="Test Tenant",
            status="normal",
        )
        db_session_with_containers.add(tenant)
        db_session_with_containers.commit()
        account = Account(
            email="test@example.com",
            name="Test User",
            interface_language="en-US",
            status="active",
        )
        db_session_with_containers.add(account)
        db_session_with_containers.commit()
        tenant_join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db_session_with_containers.add(tenant_join)
        db_session_with_containers.commit()
        account.current_tenant = tenant
        app = App(
            tenant_id=tenant.id,
            name="Test App",
            description="",
            mode=AppMode.WORKFLOW.value,
            icon_type=IconType.EMOJI.value,
            icon="rocket",
            icon_background="#4ECDC4",
            enable_site=False,
            enable_api=False,
            api_rpm=0,
            api_rph=0,
            is_demo=False,
            is_public=False,
            is_universal=False,
            max_active_requests=None,
            created_by=account.id,
            updated_by=account.id,
        )
        db_session_with_containers.add(app)
        db_session_with_containers.commit()
        workflow = Workflow(
            tenant_id=tenant.id,
            app_id=app.id,
            type="workflow",
            version="draft",
            graph='{"nodes": [], "edges": []}',
            features='{"file_upload": {"enabled": false}}',
            created_by=account.id,
            created_at=naive_utc_now(),
        )
        db_session_with_containers.add(workflow)
        db_session_with_containers.commit()
        self.session = db_session_with_containers
        self.tenant = tenant
        self.account = account
        self.app = app
        self.workflow = workflow
        yield
        # Teardown in reverse dependency order so FK constraints are satisfied.
        self.session.execute(delete(WorkflowNodeExecutionModel))
        self.session.execute(delete(WorkflowRun))
        self.session.execute(delete(Workflow).where(Workflow.id == self.workflow.id))
        self.session.execute(delete(App).where(App.id == self.app.id))
        self.session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == self.tenant.id))
        self.session.execute(delete(Account).where(Account.id == self.account.id))
        self.session.execute(delete(Tenant).where(Tenant.id == self.tenant.id))
        self.session.commit()

    def _build_persistence_layer(self, execution_id: str) -> WorkflowPersistenceLayer:
        """Assemble a persistence layer backed by the real SQLAlchemy repositories."""
        generate_entity = _build_generate_entity(
            tenant_id=self.tenant.id,
            app_id=self.app.id,
            workflow_id=self.workflow.id,
            workflow_execution_id=execution_id,
            user_id=self.account.id,
        )
        execution_repo = SQLAlchemyWorkflowExecutionRepository(
            session_factory=self.session.get_bind(),
            user=self.account,
            app_id=self.app.id,
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        )
        node_execution_repo = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=self.session.get_bind(),
            user=self.account,
            app_id=self.app.id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )
        return WorkflowPersistenceLayer(
            application_generate_entity=generate_entity,
            workflow_info=PersistenceWorkflowInfo(
                workflow_id=self.workflow.id,
                workflow_type=WorkflowType.WORKFLOW,
                version=self.workflow.version,
                graph_data=self.workflow.graph_dict,
            ),
            workflow_execution_repository=execution_repo,
            workflow_node_execution_repository=node_execution_repo,
        )

    def _run_graph(self, graph: Graph, runtime_state: GraphRuntimeState, execution_id: str) -> None:
        """Run *graph* until it finishes or pauses, persisting through the layer."""
        engine = GraphEngine(
            workflow_id=self.workflow.id,
            graph=graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        engine.layer(self._build_persistence_layer(execution_id))
        # Drain the event stream; individual events are not asserted here.
        for _ in engine.run():
            continue

    def test_resume_human_input_does_not_create_duplicate_node_execution(self):
        """First run pauses at the human node; the resumed run (submitted form)
        must leave exactly one non-paused node-execution row for node "human"."""
        execution_id = str(uuid.uuid4())
        runtime_state = _build_runtime_state(
            workflow_execution_id=execution_id,
            app_id=self.app.id,
            workflow_id=self.workflow.id,
            user_id=self.account.id,
        )
        pause_repo = _mock_form_repository_without_submission()
        paused_graph = _build_graph(
            runtime_state,
            self.tenant.id,
            self.app.id,
            self.workflow.id,
            self.account.id,
            pause_repo,
        )
        self._run_graph(paused_graph, runtime_state, execution_id)
        # Snapshot/restore simulates resuming the run in a fresh process.
        snapshot = runtime_state.dumps()
        resumed_state = GraphRuntimeState.from_snapshot(snapshot)
        resume_repo = _mock_form_repository_with_submission(action_id="continue")
        resumed_graph = _build_graph(
            resumed_state,
            self.tenant.id,
            self.app.id,
            self.workflow.id,
            self.account.id,
            resume_repo,
        )
        self._run_graph(resumed_graph, resumed_state, execution_id)
        stmt = select(WorkflowNodeExecutionModel).where(
            WorkflowNodeExecutionModel.workflow_run_id == execution_id,
            WorkflowNodeExecutionModel.node_id == "human",
        )
        records = self.session.execute(stmt).scalars().all()
        assert len(records) == 1
        assert records[0].status != "paused"
        assert records[0].triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
        assert records[0].created_by_role == CreatorUserRole.ACCOUNT

View File

@ -0,0 +1 @@
"""Helper utilities for integration tests."""

View File

@ -0,0 +1,154 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timedelta
from decimal import Decimal
from uuid import uuid4
from core.workflow.nodes.human_input.entities import FormDefinition, UserAction
from models.account import Account, Tenant, TenantAccountJoin
from models.execution_extra_content import HumanInputContent
from models.human_input import HumanInputForm, HumanInputFormStatus
from models.model import App, Conversation, Message
@dataclass
class HumanInputMessageFixture:
    """Bundle of the ORM rows created by create_human_input_message_fixture."""

    app: App
    account: Account
    conversation: Conversation
    message: Message
    form: HumanInputForm
    action_id: str  # id of the action the form was submitted with
    action_text: str  # human-readable title of that action
    node_title: str  # title of the human-input node in the form definition
def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
    """Persist a tenant/app/conversation/message graph plus a submitted human-input form.

    All workflow-scoped rows share one generated workflow_run_id so the
    HumanInputContent row links the message to the form. The session is
    committed before returning.

    Args:
        db_session: an open SQLAlchemy session bound to the test database.

    Returns:
        HumanInputMessageFixture with the persisted objects and the action
        metadata the repository tests assert on.
    """
    tenant = Tenant(name=f"Tenant {uuid4()}")
    db_session.add(tenant)
    db_session.flush()
    account = Account(
        name=f"Account {uuid4()}",
        email=f"human_input_{uuid4()}@example.com",
        password="hashed-password",
        password_salt="salt",
        interface_language="en-US",
        timezone="UTC",
    )
    db_session.add(account)
    db_session.flush()
    tenant_join = TenantAccountJoin(
        tenant_id=tenant.id,
        account_id=account.id,
        role="owner",
        current=True,
    )
    db_session.add(tenant_join)
    db_session.flush()
    app = App(
        tenant_id=tenant.id,
        name=f"App {uuid4()}",
        description="",
        mode="chat",
        icon_type="emoji",
        icon="🤖",
        icon_background="#FFFFFF",
        enable_site=False,
        enable_api=True,
        api_rpm=100,
        api_rph=100,
        is_demo=False,
        is_public=False,
        is_universal=False,
        created_by=account.id,
        updated_by=account.id,
    )
    db_session.add(app)
    db_session.flush()
    conversation = Conversation(
        app_id=app.id,
        mode="chat",
        name="Test Conversation",
        summary="",
        introduction="",
        system_instruction="",
        status="normal",
        invoke_from="console",
        from_source="console",
        from_account_id=account.id,
        from_end_user_id=None,
    )
    conversation.inputs = {}
    db_session.add(conversation)
    db_session.flush()
    workflow_run_id = str(uuid4())
    message = Message(
        app_id=app.id,
        conversation_id=conversation.id,
        inputs={},
        query="Human input query",
        message={"messages": []},
        answer="Human input answer",
        message_tokens=50,
        message_unit_price=Decimal("0.001"),
        answer_tokens=80,
        answer_unit_price=Decimal("0.001"),
        provider_response_latency=0.5,
        currency="USD",
        from_source="console",
        from_account_id=account.id,
        workflow_run_id=workflow_run_id,
    )
    db_session.add(message)
    db_session.flush()
    action_id = "approve"
    action_text = "Approve request"
    node_title = "Approval"
    # Compute the expiry once so the serialized FormDefinition and the
    # HumanInputForm column hold the *same* timestamp; two separate
    # datetime.utcnow() calls would differ by microseconds.
    # NOTE(review): datetime.utcnow() is deprecated since Python 3.12 —
    # consider switching to a shared naive-UTC helper.
    expires_at = datetime.utcnow() + timedelta(days=1)
    form_definition = FormDefinition(
        form_content="content",
        inputs=[],
        user_actions=[UserAction(id=action_id, title=action_text)],
        rendered_content="Rendered block",
        expiration_time=expires_at,
        node_title=node_title,
        display_in_ui=True,
    )
    form = HumanInputForm(
        tenant_id=tenant.id,
        app_id=app.id,
        workflow_run_id=workflow_run_id,
        node_id="node-id",
        form_definition=form_definition.model_dump_json(),
        rendered_content="Rendered block",
        status=HumanInputFormStatus.SUBMITTED,
        expiration_time=expires_at,
        selected_action_id=action_id,
    )
    db_session.add(form)
    db_session.flush()
    content = HumanInputContent(
        workflow_run_id=workflow_run_id,
        message_id=message.id,
        form_id=form.id,
    )
    db_session.add(content)
    db_session.commit()
    return HumanInputMessageFixture(
        app=app,
        account=account,
        conversation=conversation,
        message=message,
        form=form,
        action_id=action_id,
        action_text=action_text,
        node_title=node_title,
    )

View File

@ -16,6 +16,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
import redis
from redis.cluster import RedisCluster
from testcontainers.redis import RedisContainer
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
@ -332,3 +333,95 @@ class TestShardedRedisBroadcastChannelIntegration:
# Verify subscriptions are cleaned up
topic_subscribers_after = self._get_sharded_numsub(redis_client, topic_name)
assert topic_subscribers_after == 0
class TestShardedRedisBroadcastChannelClusterIntegration:
    """Integration tests for sharded pub/sub with RedisCluster client."""

    @pytest.fixture(scope="class")
    def redis_cluster_container(self) -> Iterator[RedisContainer]:
        """Create a Redis 7 container with cluster mode enabled."""
        command = (
            "redis-server --port 6379 "
            "--cluster-enabled yes "
            "--cluster-config-file nodes.conf "
            "--cluster-node-timeout 5000 "
            "--appendonly no "
            "--protected-mode no"
        )
        with RedisContainer(image="redis:7-alpine").with_command(command) as container:
            yield container

    @classmethod
    def _get_test_topic_name(cls) -> str:
        """Return a unique topic name so repeated runs never collide."""
        return f"test_sharded_cluster_topic_{uuid.uuid4()}"

    @staticmethod
    def _ensure_single_node_cluster(host: str, port: int) -> None:
        """Bootstrap the containerized node into a usable one-node cluster.

        Assigns the full 0-16383 hash-slot range if none is assigned, then
        polls CLUSTER INFO until cluster_state:ok, raising after ~5 seconds.
        """
        client = redis.Redis(host=host, port=port, decode_responses=False)
        # Announce the externally mapped address so cluster clients can reach it.
        client.config_set("cluster-announce-ip", host)
        client.config_set("cluster-announce-port", port)
        slots = client.execute_command("CLUSTER", "SLOTS")
        if not slots:
            client.execute_command("CLUSTER", "ADDSLOTSRANGE", 0, 16383)
        deadline = time.time() + 5.0
        while time.time() < deadline:
            info = client.execute_command("CLUSTER", "INFO")
            # CLUSTER INFO may come back as bytes depending on client config.
            info_text = info.decode("utf-8") if isinstance(info, (bytes, bytearray)) else str(info)
            if "cluster_state:ok" in info_text:
                return
            time.sleep(0.05)
        raise RuntimeError("Redis cluster did not become ready in time")

    @pytest.fixture(scope="class")
    def redis_cluster_client(self, redis_cluster_container: RedisContainer) -> RedisCluster:
        """Build a RedisCluster client pointed at the bootstrapped single node."""
        host = redis_cluster_container.get_container_host_ip()
        port = int(redis_cluster_container.get_exposed_port(6379))
        self._ensure_single_node_cluster(host, port)
        return RedisCluster(host=host, port=port, decode_responses=False)

    @pytest.fixture
    def broadcast_channel(self, redis_cluster_client: RedisCluster) -> BroadcastChannel:
        """Channel under test, wrapping the cluster client."""
        return ShardedRedisBroadcastChannel(redis_cluster_client)

    def test_cluster_sharded_pubsub_delivers_message(self, broadcast_channel: BroadcastChannel):
        """Ensure sharded subscription receives messages when using RedisCluster client."""
        topic_name = self._get_test_topic_name()
        message = b"cluster sharded message"
        topic = broadcast_channel.topic(topic_name)
        producer = topic.as_producer()
        subscription = topic.subscribe()
        ready_event = threading.Event()

        def consumer_thread() -> list[bytes]:
            received = []
            # The first receive() primes the subscription; only then signal
            # readiness so the producer cannot publish before we are subscribed.
            try:
                _ = subscription.receive(0.01)
            except SubscriptionClosedError:
                return received
            ready_event.set()
            deadline = time.time() + 5.0
            while time.time() < deadline:
                msg = subscription.receive(timeout=0.1)
                if msg is None:
                    continue
                received.append(msg)
                break
            subscription.close()
            return received

        def producer_thread():
            if not ready_event.wait(timeout=2.0):
                pytest.fail("subscriber did not become ready before publish")
            producer.publish(message)

        with ThreadPoolExecutor(max_workers=2) as executor:
            consumer_future = executor.submit(consumer_thread)
            producer_future = executor.submit(producer_thread)
            producer_future.result(timeout=5.0)
            received_messages = consumer_future.result(timeout=5.0)
        assert received_messages == [message]

View File

@ -0,0 +1,25 @@
"""
Integration tests for RateLimiter using testcontainers Redis.
"""
import uuid
import pytest
from extensions.ext_redis import redis_client
from libs import helper as helper_module
@pytest.mark.usefixtures("flask_app_with_containers")
def test_rate_limiter_counts_multiple_attempts_in_same_second(monkeypatch):
    """Two increments within one frozen second must trip a max_attempts=2 limiter."""
    prefix = f"test_rate_limit:{uuid.uuid4().hex}"
    limiter = helper_module.RateLimiter(prefix=prefix, max_attempts=2, time_window=60)
    client_ip = "203.0.113.10"

    # Start from a clean Redis key for this unique prefix.
    redis_client.delete(limiter._get_key(client_ip))

    # Freeze time so both attempts land on the identical timestamp.
    monkeypatch.setattr(helper_module.time, "time", lambda: 1_700_000_000)
    limiter.increment_rate_limit(client_ip)
    limiter.increment_rate_limit(client_ip)

    assert limiter.is_rate_limited(client_ip) is True

View File

@ -0,0 +1,79 @@
# import secrets
# import pytest
# from sqlalchemy import select
# from sqlalchemy.orm import Session
# from sqlalchemy.orm.exc import DetachedInstanceError
# from libs.datetime_utils import naive_utc_now
# from models.account import Account, Tenant, TenantAccountJoin
# @pytest.fixture
# def session(db_session_with_containers):
# with Session(db_session_with_containers.get_bind()) as session:
# yield session
# @pytest.fixture
# def account(session):
# account = Account(
# name="test account",
# email=f"test_{secrets.token_hex(8)}@example.com",
# )
# session.add(account)
# session.commit()
# return account
# @pytest.fixture
# def tenant(session):
# tenant = Tenant(name="test tenant")
# session.add(tenant)
# session.commit()
# return tenant
# @pytest.fixture
# def tenant_account_join(session, account, tenant):
# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id)
# session.add(tenant_join)
# session.commit()
# yield tenant_join
# session.delete(tenant_join)
# session.commit()
# class TestAccountTenant:
# def test_set_current_tenant_should_reload_tenant(
# self,
# db_session_with_containers,
# account,
# tenant,
# tenant_account_join,
# ):
# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session:
# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one()
# account.current_tenant = scoped_tenant
# scoped_tenant.created_at = naive_utc_now()
# # session.commit()
# # Ensure the tenant used in assignment is detached.
# with pytest.raises(DetachedInstanceError):
# _ = scoped_tenant.name
# assert account._current_tenant.id == tenant.id
# assert account._current_tenant.id == tenant.id
# def test_set_tenant_id_should_load_tenant_as_not_expire(
# self,
# flask_app_with_containers,
# account,
# tenant,
# tenant_account_join,
# ):
# with flask_app_with_containers.test_request_context():
# account.set_tenant_id(tenant.id)
# assert account._current_tenant.id == tenant.id
# assert account._current_tenant.id == tenant.id

View File

@ -0,0 +1,27 @@
from __future__ import annotations
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
create_human_input_message_fixture,
)
def test_get_by_message_ids_returns_human_input_content(db_session_with_containers):
    """The repository returns the submitted human-input content for a message id."""
    human_input_fixture = create_human_input_message_fixture(db_session_with_containers)
    repo = SQLAlchemyExecutionExtraContentRepository(
        session_maker=sessionmaker(bind=db.engine, expire_on_commit=False)
    )

    grouped = repo.get_by_message_ids([human_input_fixture.message.id])

    # One group per queried message id, holding a single extra-content entry.
    assert len(grouped) == 1
    assert len(grouped[0]) == 1
    entry = grouped[0][0]
    assert entry.submitted is True
    submission = entry.form_submission_data
    assert submission is not None
    assert submission.action_id == human_input_fixture.action_id
    assert submission.action_text == human_input_fixture.action_text
    assert submission.rendered_content == human_input_fixture.form.rendered_content

View File

@ -2309,6 +2309,12 @@ class TestRegisterService:
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
from extensions.ext_database import db
from models.model import DifySetup
db.session.query(DifySetup).delete()
db.session.commit()
# Execute setup
RegisterService.setup(
email=admin_email,
@ -2319,9 +2325,7 @@ class TestRegisterService:
)
# Verify account was created
from extensions.ext_database import db
from models import Account
from models.model import DifySetup
account = db.session.query(Account).filter_by(email=admin_email).first()
assert account is not None

View File

@ -1,5 +1,5 @@
import uuid
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from faker import Faker
@ -26,6 +26,7 @@ class TestAppGenerateService:
patch("services.app_generate_service.AgentChatAppGenerator") as mock_agent_chat_generator,
patch("services.app_generate_service.AdvancedChatAppGenerator") as mock_advanced_chat_generator,
patch("services.app_generate_service.WorkflowAppGenerator") as mock_workflow_generator,
patch("services.app_generate_service.MessageBasedAppGenerator") as mock_message_based_generator,
patch("services.account_service.FeatureService") as mock_account_feature_service,
patch("services.app_generate_service.dify_config") as mock_dify_config,
patch("configs.dify_config") as mock_global_dify_config,
@ -38,9 +39,13 @@ class TestAppGenerateService:
# Setup default mock returns for workflow service
mock_workflow_service_instance = mock_workflow_service.return_value
mock_workflow_service_instance.get_published_workflow.return_value = MagicMock(spec=Workflow)
mock_workflow_service_instance.get_draft_workflow.return_value = MagicMock(spec=Workflow)
mock_workflow_service_instance.get_published_workflow_by_id.return_value = MagicMock(spec=Workflow)
mock_published_workflow = MagicMock(spec=Workflow)
mock_published_workflow.id = str(uuid.uuid4())
mock_workflow_service_instance.get_published_workflow.return_value = mock_published_workflow
mock_draft_workflow = MagicMock(spec=Workflow)
mock_draft_workflow.id = str(uuid.uuid4())
mock_workflow_service_instance.get_draft_workflow.return_value = mock_draft_workflow
mock_workflow_service_instance.get_published_workflow_by_id.return_value = mock_published_workflow
# Setup default mock returns for rate limiting
mock_rate_limit_instance = mock_rate_limit.return_value
@ -66,6 +71,8 @@ class TestAppGenerateService:
mock_advanced_chat_generator_instance.generate.return_value = ["advanced_chat_response"]
mock_advanced_chat_generator_instance.single_iteration_generate.return_value = ["single_iteration_response"]
mock_advanced_chat_generator_instance.single_loop_generate.return_value = ["single_loop_response"]
mock_advanced_chat_generator_instance.retrieve_events.return_value = ["advanced_chat_events"]
mock_advanced_chat_generator_instance.convert_to_event_stream.return_value = ["advanced_chat_stream"]
mock_advanced_chat_generator.convert_to_event_stream.return_value = ["advanced_chat_stream"]
mock_workflow_generator_instance = mock_workflow_generator.return_value
@ -76,6 +83,8 @@ class TestAppGenerateService:
mock_workflow_generator_instance.single_loop_generate.return_value = ["workflow_single_loop_response"]
mock_workflow_generator.convert_to_event_stream.return_value = ["workflow_stream"]
mock_message_based_generator.retrieve_events.return_value = ["workflow_events"]
# Setup default mock returns for account service
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
@ -88,6 +97,7 @@ class TestAppGenerateService:
mock_global_dify_config.BILLING_ENABLED = False
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
mock_global_dify_config.HOSTED_POOL_CREDITS = 1000
yield {
"billing_service": mock_billing_service,
@ -98,6 +108,7 @@ class TestAppGenerateService:
"agent_chat_generator": mock_agent_chat_generator,
"advanced_chat_generator": mock_advanced_chat_generator,
"workflow_generator": mock_workflow_generator,
"message_based_generator": mock_message_based_generator,
"account_feature_service": mock_account_feature_service,
"dify_config": mock_dify_config,
"global_dify_config": mock_global_dify_config,
@ -280,8 +291,10 @@ class TestAppGenerateService:
assert result == ["test_response"]
# Verify advanced chat generator was called
mock_external_service_dependencies["advanced_chat_generator"].return_value.generate.assert_called_once()
mock_external_service_dependencies["advanced_chat_generator"].convert_to_event_stream.assert_called_once()
mock_external_service_dependencies["advanced_chat_generator"].return_value.retrieve_events.assert_called_once()
mock_external_service_dependencies[
"advanced_chat_generator"
].return_value.convert_to_event_stream.assert_called_once()
def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -304,7 +317,7 @@ class TestAppGenerateService:
assert result == ["test_response"]
# Verify workflow generator was called
mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once()
mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once()
mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once()
def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies):
@ -970,14 +983,27 @@ class TestAppGenerateService:
}
# Execute the method under test
result = AppGenerateService.generate(
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
)
with patch("services.app_generate_service.AppExecutionParams") as mock_exec_params:
mock_payload = MagicMock()
mock_payload.workflow_run_id = fake.uuid4()
mock_payload.model_dump_json.return_value = "{}"
mock_exec_params.new.return_value = mock_payload
result = AppGenerateService.generate(
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
)
# Verify the result
assert result == ["test_response"]
# Verify workflow generator was called with complex args
mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once()
call_args = mock_external_service_dependencies["workflow_generator"].return_value.generate.call_args
assert call_args[1]["args"] == args
# Verify payload was built with complex args
mock_exec_params.new.assert_called_once()
call_kwargs = mock_exec_params.new.call_args.kwargs
assert call_kwargs["args"] == args
# Verify workflow streaming event retrieval was used
mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once_with(
ANY,
mock_payload.workflow_run_id,
on_subscribe=ANY,
)

View File

@ -0,0 +1,112 @@
import json
import uuid
from unittest.mock import MagicMock
import pytest
from core.workflow.enums import NodeType
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
HumanInputNodeData,
)
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import App, AppMode
from models.workflow import Workflow, WorkflowType
from services.workflow_service import WorkflowService
def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) -> tuple[App, Account]:
    """Persist a tenant, owner account, app, and a draft workflow with one human-input node.

    The workflow graph contains a single node (id ``human-node``) configured with
    an email delivery method identified by *delivery_method_id*, targeting the
    external recipient ``recipient@example.com``.  Returns ``(app, account)``.
    """
    tenant = Tenant(name="Test Tenant")
    account = Account(name="Tester", email="tester@example.com")
    session.add_all([tenant, account])
    # flush (not commit) so generated ids are available for the rows below
    session.flush()
    session.add(
        TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            current=True,
            role=TenantAccountRole.OWNER.value,
        )
    )
    app = App(
        tenant_id=tenant.id,
        name="Test App",
        description="",
        mode=AppMode.WORKFLOW.value,
        icon_type="emoji",
        icon="app",
        icon_background="#ffffff",
        enable_site=True,
        enable_api=True,
        created_by=account.id,
        updated_by=account.id,
    )
    session.add(app)
    session.flush()
    email_method = EmailDeliveryMethod(
        id=delivery_method_id,
        enabled=True,
        config=EmailDeliveryConfig(
            recipients=EmailRecipients(
                whole_workspace=False,
                items=[ExternalRecipient(email="recipient@example.com")],
            ),
            # subject/body carry template placeholders; they are stored as-is here
            subject="Test {{recipient_email}}",
            body="Body {{#url#}} {{form_content}}",
        ),
    )
    node_data = HumanInputNodeData(
        title="Human Input",
        delivery_methods=[email_method],
        form_content="Hello Human Input",
        inputs=[],
        user_actions=[],
    ).model_dump(mode="json")
    # the graph node payload carries the node type alongside the dumped config
    node_data["type"] = NodeType.HUMAN_INPUT.value
    graph = json.dumps({"nodes": [{"id": "human-node", "data": node_data}], "edges": []})
    workflow = Workflow.new(
        tenant_id=tenant.id,
        app_id=app.id,
        type=WorkflowType.WORKFLOW.value,
        version=Workflow.VERSION_DRAFT,
        graph=graph,
        features=json.dumps({}),
        created_by=account.id,
        environment_variables=[],
        conversation_variables=[],
        rag_pipeline_variables=[],
    )
    session.add(workflow)
    session.commit()
    return app, account
def test_human_input_delivery_test_sends_email(
    db_session_with_containers,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Running a delivery test for a human-input node sends exactly one email."""
    method_id = uuid.uuid4()
    app, account = _create_app_with_draft_workflow(
        db_session_with_containers, delivery_method_id=method_id
    )

    # Pretend the mail extension is configured and capture outgoing sends.
    mail_send = MagicMock()
    monkeypatch.setattr("services.human_input_delivery_test_service.mail.is_inited", lambda: True)
    monkeypatch.setattr("services.human_input_delivery_test_service.mail.send", mail_send)

    WorkflowService().test_human_input_delivery(
        app_model=app,
        account=account,
        node_id="human-node",
        delivery_method_id=str(method_id),
    )

    assert mail_send.call_count == 1
    assert mail_send.call_args.kwargs["to"] == "recipient@example.com"

View File

@ -0,0 +1,38 @@
from __future__ import annotations
import pytest
from services.message_service import MessageService
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
create_human_input_message_fixture,
)
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
def test_pagination_returns_extra_contents(db_session_with_containers):
    """Message pagination exposes the human-input extra content on each message."""
    fixture = create_human_input_message_fixture(db_session_with_containers)

    page = MessageService.pagination_by_first_id(
        app_model=fixture.app,
        user=fixture.account,
        conversation_id=fixture.conversation.id,
        first_id=None,
        limit=10,
    )

    assert page.data
    expected_extra_content = {
        "type": "human_input",
        "workflow_run_id": fixture.message.workflow_run_id,
        "submitted": True,
        "form_submission_data": {
            "node_id": fixture.form.node_id,
            "node_title": fixture.node_title,
            "rendered_content": fixture.form.rendered_content,
            "action_id": fixture.action_id,
            "action_text": fixture.action_text,
        },
    }
    assert page.data[0].extra_contents == [expected_extra_content]

View File

@ -465,6 +465,27 @@ class TestWorkflowRunService:
db.session.add(node_execution)
node_executions.append(node_execution)
paused_node_execution = WorkflowNodeExecutionModel(
tenant_id=app.tenant_id,
app_id=app.id,
workflow_id=workflow_run.workflow_id,
triggered_from="workflow-run",
workflow_run_id=workflow_run.id,
index=99,
node_id="node_paused",
node_type="human_input",
title="Paused Node",
inputs=json.dumps({"input": "paused"}),
process_data=json.dumps({"process": "paused"}),
status="paused",
elapsed_time=0.5,
execution_metadata=json.dumps({"tokens": 0}),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(paused_node_execution)
db.session.commit()
# Act: Execute the method under test
@ -473,16 +494,19 @@ class TestWorkflowRunService:
# Assert: Verify the expected outcomes
assert result is not None
assert len(result) == 3
assert len(result) == 4
# Verify node execution properties
statuses = [node_execution.status for node_execution in result]
assert "paused" in statuses
assert statuses.count("succeeded") == 3
assert statuses.count("paused") == 1
for node_execution in result:
assert node_execution.tenant_id == app.tenant_id
assert node_execution.app_id == app.id
assert node_execution.workflow_run_id == workflow_run.id
assert node_execution.index in [0, 1, 2] # Check that index is one of the expected values
assert node_execution.node_id.startswith("node_") # Check that node_id starts with "node_"
assert node_execution.status == "succeeded"
assert node_execution.node_id.startswith("node_")
def test_get_workflow_run_node_executions_empty(
self, db_session_with_containers, mock_external_service_dependencies

View File

@ -6,6 +6,7 @@ from faker import Faker
from pydantic import ValidationError
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from models.tools import WorkflowToolProvider
from models.workflow import Workflow as WorkflowModel
from services.account_service import AccountService, TenantService
@ -513,6 +514,62 @@ class TestWorkflowToolManageService:
assert tool_count == 0
def test_create_workflow_tool_human_input_node_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test workflow tool creation fails when workflow contains human input nodes.

    This test verifies:
    - Human input nodes prevent workflow tool publishing
    - Correct error message
    - No database changes when workflow is invalid
    """
    fake = Faker()
    # Create test data
    app, account, workflow = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies
    )
    # Rewrite the draft graph so it contains a single human-input node.
    workflow.graph = json.dumps(
        {
            "nodes": [
                {
                    "id": "human_input_node",
                    "data": {"type": "human-input"},
                }
            ]
        }
    )
    tool_parameters = self._create_test_workflow_tool_parameters()
    with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
        WorkflowToolManageService.create_workflow_tool(
            user_id=account.id,
            tenant_id=account.current_tenant.id,
            workflow_app_id=app.id,
            name=fake.word(),
            label=fake.word(),
            icon={"type": "emoji", "emoji": "🔧"},
            description=fake.text(max_nb_chars=200),
            parameters=tool_parameters,
        )
    assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
    from extensions.ext_database import db

    # No provider row may have been written for the tenant.
    tool_count = (
        db.session.query(WorkflowToolProvider)
        .where(
            WorkflowToolProvider.tenant_id == account.current_tenant.id,
        )
        .count()
    )
    assert tool_count == 0
def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful workflow tool update with valid parameters.
@ -600,6 +657,80 @@ class TestWorkflowToolManageService:
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called()
mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called()
def test_update_workflow_tool_human_input_node_error(
    self, db_session_with_containers, mock_external_service_dependencies
):
    """
    Test workflow tool update fails when workflow contains human input nodes.

    This test verifies:
    - Human input nodes prevent workflow tool updates
    - Correct error message
    - Existing tool data remains unchanged
    """
    fake = Faker()
    # Create test data
    app, account, workflow = self._create_test_app_and_account(
        db_session_with_containers, mock_external_service_dependencies
    )
    # Create initial workflow tool (with a valid graph) so there is a row to update.
    initial_tool_name = fake.word()
    initial_tool_parameters = self._create_test_workflow_tool_parameters()
    WorkflowToolManageService.create_workflow_tool(
        user_id=account.id,
        tenant_id=account.current_tenant.id,
        workflow_app_id=app.id,
        name=initial_tool_name,
        label=fake.word(),
        icon={"type": "emoji", "emoji": "🔧"},
        description=fake.text(max_nb_chars=200),
        parameters=initial_tool_parameters,
    )
    from extensions.ext_database import db

    created_tool = (
        db.session.query(WorkflowToolProvider)
        .where(
            WorkflowToolProvider.tenant_id == account.current_tenant.id,
            WorkflowToolProvider.app_id == app.id,
        )
        .first()
    )
    original_name = created_tool.name
    # Now rewrite the graph so it contains a human-input node and persist it.
    workflow.graph = json.dumps(
        {
            "nodes": [
                {
                    "id": "human_input_node",
                    "data": {"type": "human-input"},
                }
            ]
        }
    )
    db.session.commit()
    with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
        WorkflowToolManageService.update_workflow_tool(
            user_id=account.id,
            tenant_id=account.current_tenant.id,
            workflow_tool_id=created_tool.id,
            name=fake.word(),
            label=fake.word(),
            icon={"type": "emoji", "emoji": "⚙️"},
            description=fake.text(max_nb_chars=200),
            parameters=initial_tool_parameters,
        )
    assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
    # The existing tool row must be unchanged by the failed update.
    db.session.refresh(created_tool)
    assert created_tool.name == original_name
def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test workflow tool update fails when tool does not exist.

View File

@ -0,0 +1,214 @@
import uuid
from datetime import UTC, datetime
from unittest.mock import patch
import pytest
from configs import dify_config
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
HumanInputNodeData,
MemberRecipient,
)
from core.workflow.runtime import GraphRuntimeState, VariablePool
from extensions.ext_storage import storage
from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient
from models.model import AppMode
from models.workflow import WorkflowPause, WorkflowRun, WorkflowType
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
@pytest.fixture(autouse=True)
def cleanup_database(db_session_with_containers):
    """Wipe every table this module touches before each test (autouse).

    Deletion order is children before parents so foreign-key constraints
    are not violated — do not reorder these statements.
    """
    db_session_with_containers.query(HumanInputFormRecipient).delete()
    db_session_with_containers.query(HumanInputDelivery).delete()
    db_session_with_containers.query(HumanInputForm).delete()
    db_session_with_containers.query(WorkflowPause).delete()
    db_session_with_containers.query(WorkflowRun).delete()
    db_session_with_containers.query(TenantAccountJoin).delete()
    db_session_with_containers.query(Tenant).delete()
    db_session_with_containers.query(Account).delete()
    db_session_with_containers.commit()
def _create_workspace_member(db_session_with_containers):
    """Persist an active account joined to a fresh tenant as OWNER.

    Returns ``(tenant, account)``.  Each object is committed and refreshed
    before the next is created so its generated id is usable.
    """
    account = Account(
        email="owner@example.com",
        name="Owner",
        password="password",
        interface_language="en-US",
        status=AccountStatus.ACTIVE,
    )
    account.created_at = datetime.now(UTC)
    account.updated_at = datetime.now(UTC)
    db_session_with_containers.add(account)
    db_session_with_containers.commit()
    db_session_with_containers.refresh(account)
    tenant = Tenant(name="Test Tenant")
    tenant.created_at = datetime.now(UTC)
    tenant.updated_at = datetime.now(UTC)
    db_session_with_containers.add(tenant)
    db_session_with_containers.commit()
    db_session_with_containers.refresh(tenant)
    tenant_join = TenantAccountJoin(
        tenant_id=tenant.id,
        account_id=account.id,
        role=TenantAccountRole.OWNER,
    )
    tenant_join.created_at = datetime.now(UTC)
    tenant_join.updated_at = datetime.now(UTC)
    db_session_with_containers.add(tenant_join)
    db_session_with_containers.commit()
    return tenant, account
def _build_form(db_session_with_containers, tenant, account, *, app_id: str, workflow_execution_id: str):
    """Create a human-input form with an email delivery method via the repository.

    Recipients are one workspace member plus one external address.  The
    subject/body templates mix ``{{ ... }}`` placeholders with ``{{#...#}}``
    variable selectors so the test can assert which get substituted.
    Returns the created form entity.
    """
    delivery_method = EmailDeliveryMethod(
        config=EmailDeliveryConfig(
            recipients=EmailRecipients(
                whole_workspace=False,
                items=[
                    MemberRecipient(user_id=account.id),
                    ExternalRecipient(email="external@example.com"),
                ],
            ),
            subject="Action needed {{ node_title }} {{#node1.value#}}",
            body="Token {{ form_token }} link {{#url#}} content {{#node1.value#}}",
        )
    )
    node_data = HumanInputNodeData(
        title="Review",
        form_content="Form content",
        delivery_methods=[delivery_method],
    )
    # NOTE(review): the repository is handed the raw engine as session_factory,
    # so it presumably manages its own sessions — confirm against the impl.
    engine = db_session_with_containers.get_bind()
    repo = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
    params = FormCreateParams(
        app_id=app_id,
        workflow_execution_id=workflow_execution_id,
        node_id="node-1",
        form_config=node_data,
        rendered_content="Rendered",
        delivery_methods=node_data.delivery_methods,
        display_in_ui=False,
        resolved_default_values={},
    )
    return repo.create_form(params)
def _create_workflow_pause_state(
    db_session_with_containers,
    *,
    workflow_run_id: str,
    workflow_id: str,
    tenant_id: str,
    app_id: str,
    account_id: str,
    variable_pool: VariablePool,
):
    """Persist a PAUSED WorkflowRun plus the serialized state needed to resume it.

    Writes the resumption context (generate entity + serialized graph runtime
    state built from *variable_pool*) to storage under
    ``workflow_pause_states/<run_id>.json`` and records a WorkflowPause row
    pointing at that object key.
    """
    workflow_run = WorkflowRun(
        id=workflow_run_id,
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_id,
        type=WorkflowType.WORKFLOW,
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        version="1",
        graph="{}",
        inputs="{}",
        status=WorkflowExecutionStatus.PAUSED,
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=account_id,
        created_at=datetime.now(UTC),
    )
    db_session_with_containers.add(workflow_run)
    runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
    resumption_context = WorkflowResumptionContext(
        generate_entity={
            "type": AppMode.WORKFLOW,
            "entity": WorkflowAppGenerateEntity(
                task_id=str(uuid.uuid4()),
                app_config=WorkflowUIBasedAppConfig(
                    tenant_id=tenant_id,
                    app_id=app_id,
                    app_mode=AppMode.WORKFLOW,
                    workflow_id=workflow_id,
                ),
                inputs={},
                files=[],
                user_id=account_id,
                stream=False,
                invoke_from=InvokeFrom.WEB_APP,
                workflow_execution_id=workflow_run_id,
            ),
        },
        serialized_graph_runtime_state=runtime_state.dumps(),
    )
    state_object_key = f"workflow_pause_states/{workflow_run_id}.json"
    # Save the serialized context before the pause row that references it.
    storage.save(state_object_key, resumption_context.dumps().encode())
    pause_state = WorkflowPause(
        workflow_id=workflow_id,
        workflow_run_id=workflow_run_id,
        state_object_key=state_object_key,
    )
    db_session_with_containers.add(pause_state)
    db_session_with_containers.commit()
def test_dispatch_human_input_email_task_integration(monkeypatch: pytest.MonkeyPatch, db_session_with_containers):
    """Dispatching the human-input email task sends one mail per recipient.

    Sets up a paused workflow run whose variable pool resolves
    ``{{#node1.value#}}`` to ``"OK"``, builds a form with one member and one
    external recipient, then runs the task with the mail extension patched.
    """
    tenant, account = _create_workspace_member(db_session_with_containers)
    workflow_run_id = str(uuid.uuid4())
    workflow_id = str(uuid.uuid4())
    app_id = str(uuid.uuid4())
    variable_pool = VariablePool()
    variable_pool.add(["node1", "value"], "OK")
    _create_workflow_pause_state(
        db_session_with_containers,
        workflow_run_id=workflow_run_id,
        workflow_id=workflow_id,
        tenant_id=tenant.id,
        app_id=app_id,
        account_id=account.id,
        variable_pool=variable_pool,
    )
    form_entity = _build_form(
        db_session_with_containers,
        tenant,
        account,
        app_id=app_id,
        workflow_execution_id=workflow_run_id,
    )
    # Form links in the body are built from APP_WEB_URL.
    monkeypatch.setattr(dify_config, "APP_WEB_URL", "https://app.example.com")
    with patch("tasks.mail_human_input_delivery_task.mail") as mock_mail:
        mock_mail.is_inited.return_value = True
        dispatch_human_input_email_task(form_id=form_entity.id, node_title="Approval")
    # One email per recipient: the workspace member and the external address.
    assert mock_mail.send.call_count == 2
    send_args = [call.kwargs for call in mock_mail.send.call_args_list]
    recipients = {kwargs["to"] for kwargs in send_args}
    assert recipients == {"owner@example.com", "external@example.com"}
    # Subject is sent verbatim — its placeholders are not substituted.
    assert all(kwargs["subject"] == "Action needed {{ node_title }} {{#node1.value#}}" for kwargs in send_args)
    # Body rendering: {{#url#}} becomes a form link and {{#node1.value#}}
    # resolves from the pool, while {{ form_token }} stays literal.
    assert all("app.example.com/form/" in kwargs["html"] for kwargs in send_args)
    assert all("content OK" in kwargs["html"] for kwargs in send_args)
    assert all("{{ form_token }}" in kwargs["html"] for kwargs in send_args)

View File

@ -94,11 +94,6 @@ class PrunePausesTestCase:
def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
"""Create test cases for pause workflow failure scenarios."""
return [
PauseWorkflowFailureCase(
name="pause_already_paused_workflow",
initial_status=WorkflowExecutionStatus.PAUSED,
description="Should fail to pause an already paused workflow",
),
PauseWorkflowFailureCase(
name="pause_completed_workflow",
initial_status=WorkflowExecutionStatus.SUCCEEDED,