mirror of
https://github.com/langgenius/dify.git
synced 2026-03-12 10:38:54 +08:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
@ -0,0 +1,166 @@
|
||||
"""TestContainers integration tests for ChatConversationApi status_count behavior."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HEADER_NAME_CSRF_TOKEN
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.token import _real_cookie_name, generate_csrf_token
|
||||
from models import Account, DifySetup, Tenant, TenantAccountJoin
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, Conversation, Message
|
||||
from models.workflow import WorkflowRun
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
def _create_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]:
|
||||
account = Account(
|
||||
email=f"test-{uuid.uuid4()}@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.initialized_at = naive_utc_now()
|
||||
db_session.add(account)
|
||||
db_session.commit()
|
||||
|
||||
tenant = Tenant(name="Test Tenant", status="normal")
|
||||
db_session.add(tenant)
|
||||
db_session.commit()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session.add(join)
|
||||
db_session.commit()
|
||||
|
||||
account.set_tenant_id(tenant.id)
|
||||
account.timezone = "UTC"
|
||||
db_session.commit()
|
||||
|
||||
dify_setup = DifySetup(version=dify_config.project.version)
|
||||
db_session.add(dify_setup)
|
||||
db_session.commit()
|
||||
|
||||
return account, tenant
|
||||
|
||||
|
||||
def _create_app(db_session: Session, tenant_id: str, account_id: str) -> App:
|
||||
app = App(
|
||||
tenant_id=tenant_id,
|
||||
name="Test Chat App",
|
||||
mode=AppMode.CHAT,
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
created_by=account_id,
|
||||
)
|
||||
db_session.add(app)
|
||||
db_session.commit()
|
||||
return app
|
||||
|
||||
|
||||
def _create_conversation(db_session: Session, app_id: str, account_id: str) -> Conversation:
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
name="Test Conversation",
|
||||
inputs={},
|
||||
status="normal",
|
||||
mode=AppMode.CHAT,
|
||||
from_source=CreatorUserRole.ACCOUNT,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
db_session.add(conversation)
|
||||
db_session.commit()
|
||||
return conversation
|
||||
|
||||
|
||||
def _create_workflow_run(db_session: Session, app_id: str, tenant_id: str, account_id: str) -> WorkflowRun:
|
||||
workflow_run = WorkflowRun(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=str(uuid.uuid4()),
|
||||
type="chat",
|
||||
triggered_from="app-run",
|
||||
version="1.0.0",
|
||||
graph=json.dumps({"nodes": [], "edges": []}),
|
||||
inputs=json.dumps({"query": "test"}),
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
outputs=json.dumps({}),
|
||||
elapsed_time=0.0,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account_id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
db_session.add(workflow_run)
|
||||
db_session.commit()
|
||||
return workflow_run
|
||||
|
||||
|
||||
def _create_message(
|
||||
db_session: Session, app_id: str, conversation_id: str, workflow_run_id: str, account_id: str
|
||||
) -> Message:
|
||||
message = Message(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
query="Hello",
|
||||
message={"type": "text", "content": "Hello"},
|
||||
answer="Hi there",
|
||||
message_tokens=1,
|
||||
answer_tokens=1,
|
||||
message_unit_price=0.001,
|
||||
answer_unit_price=0.001,
|
||||
message_price_unit=0.001,
|
||||
answer_price_unit=0.001,
|
||||
currency="USD",
|
||||
status="normal",
|
||||
from_source=CreatorUserRole.ACCOUNT,
|
||||
from_account_id=account_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
inputs={"query": "Hello"},
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
return message
|
||||
|
||||
|
||||
def test_chat_conversation_status_count_includes_paused(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
):
|
||||
account, tenant = _create_account_and_tenant(db_session_with_containers)
|
||||
app = _create_app(db_session_with_containers, tenant.id, account.id)
|
||||
conversation = _create_conversation(db_session_with_containers, app.id, account.id)
|
||||
conversation_id = conversation.id
|
||||
workflow_run = _create_workflow_run(db_session_with_containers, app.id, tenant.id, account.id)
|
||||
_create_message(db_session_with_containers, app.id, conversation.id, workflow_run.id, account.id)
|
||||
|
||||
access_token = AccountService.get_account_jwt_token(account)
|
||||
csrf_token = generate_csrf_token(account.id)
|
||||
cookie_name = _real_cookie_name("csrf_token")
|
||||
|
||||
test_client_with_containers.set_cookie(cookie_name, csrf_token, domain="localhost")
|
||||
response = test_client_with_containers.get(
|
||||
f"/console/api/apps/{app.id}/chat-conversations",
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
HEADER_NAME_CSRF_TOKEN: csrf_token,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload is not None
|
||||
assert payload["total"] == 1
|
||||
assert payload["data"][0]["id"] == conversation_id
|
||||
assert payload["data"][0]["status_count"]["paused"] == 1
|
||||
@ -0,0 +1,240 @@
|
||||
"""TestContainers integration tests for HumanInputFormRepositoryImpl."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
FormDefinition,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
UserAction,
|
||||
WebAppDeliveryMethod,
|
||||
)
|
||||
from core.workflow.repositories.human_input_form_repository import FormCreateParams
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.human_input import (
|
||||
EmailExternalRecipientPayload,
|
||||
EmailMemberRecipientPayload,
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
|
||||
|
||||
def _create_tenant_with_members(session: Session, member_emails: list[str]) -> tuple[Tenant, list[Account]]:
|
||||
tenant = Tenant(name="Test Tenant", status="normal")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
members: list[Account] = []
|
||||
for index, email in enumerate(member_emails):
|
||||
account = Account(
|
||||
email=email,
|
||||
name=f"Member {index}",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
session.add(account)
|
||||
session.flush()
|
||||
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.NORMAL,
|
||||
current=True,
|
||||
)
|
||||
session.add(tenant_join)
|
||||
members.append(account)
|
||||
|
||||
session.commit()
|
||||
return tenant, members
|
||||
|
||||
|
||||
def _build_form_params(delivery_methods: list[DeliveryChannelConfig]) -> FormCreateParams:
|
||||
form_config = HumanInputNodeData(
|
||||
title="Human Approval",
|
||||
delivery_methods=delivery_methods,
|
||||
form_content="<p>Approve?</p>",
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
return FormCreateParams(
|
||||
app_id=str(uuid4()),
|
||||
workflow_execution_id=str(uuid4()),
|
||||
node_id="human-input-node",
|
||||
form_config=form_config,
|
||||
rendered_content="<p>Approve?</p>",
|
||||
delivery_methods=delivery_methods,
|
||||
display_in_ui=False,
|
||||
resolved_default_values={},
|
||||
)
|
||||
|
||||
|
||||
def _build_email_delivery(
|
||||
whole_workspace: bool, recipients: list[MemberRecipient | ExternalRecipient]
|
||||
) -> EmailDeliveryMethod:
|
||||
return EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients),
|
||||
subject="Approval Needed",
|
||||
body="Please review",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestHumanInputFormRepositoryImplWithContainers:
|
||||
def test_create_form_with_whole_workspace_recipients(self, db_session_with_containers: Session) -> None:
|
||||
engine = db_session_with_containers.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
tenant, members = _create_tenant_with_members(
|
||||
db_session_with_containers,
|
||||
member_emails=["member1@example.com", "member2@example.com"],
|
||||
)
|
||||
|
||||
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
|
||||
params = _build_form_params(
|
||||
delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])],
|
||||
)
|
||||
|
||||
form_entity = repository.create_form(params)
|
||||
|
||||
with Session(engine) as verification_session:
|
||||
recipients = verification_session.scalars(
|
||||
select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id)
|
||||
).all()
|
||||
|
||||
assert len(recipients) == len(members)
|
||||
member_payloads = [
|
||||
EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload)
|
||||
for recipient in recipients
|
||||
if recipient.recipient_type == RecipientType.EMAIL_MEMBER
|
||||
]
|
||||
member_emails = {payload.email for payload in member_payloads}
|
||||
assert member_emails == {member.email for member in members}
|
||||
|
||||
def test_create_form_with_specific_members_and_external(self, db_session_with_containers: Session) -> None:
|
||||
engine = db_session_with_containers.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
tenant, members = _create_tenant_with_members(
|
||||
db_session_with_containers,
|
||||
member_emails=["primary@example.com", "secondary@example.com"],
|
||||
)
|
||||
|
||||
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
|
||||
params = _build_form_params(
|
||||
delivery_methods=[
|
||||
_build_email_delivery(
|
||||
whole_workspace=False,
|
||||
recipients=[
|
||||
MemberRecipient(user_id=members[0].id),
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
form_entity = repository.create_form(params)
|
||||
|
||||
with Session(engine) as verification_session:
|
||||
recipients = verification_session.scalars(
|
||||
select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_entity.id)
|
||||
).all()
|
||||
|
||||
member_recipient_payloads = [
|
||||
EmailMemberRecipientPayload.model_validate_json(recipient.recipient_payload)
|
||||
for recipient in recipients
|
||||
if recipient.recipient_type == RecipientType.EMAIL_MEMBER
|
||||
]
|
||||
assert len(member_recipient_payloads) == 1
|
||||
assert member_recipient_payloads[0].user_id == members[0].id
|
||||
|
||||
external_payloads = [
|
||||
EmailExternalRecipientPayload.model_validate_json(recipient.recipient_payload)
|
||||
for recipient in recipients
|
||||
if recipient.recipient_type == RecipientType.EMAIL_EXTERNAL
|
||||
]
|
||||
assert len(external_payloads) == 1
|
||||
assert external_payloads[0].email == "external@example.com"
|
||||
|
||||
def test_create_form_persists_default_values(self, db_session_with_containers: Session) -> None:
|
||||
engine = db_session_with_containers.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
tenant, _ = _create_tenant_with_members(
|
||||
db_session_with_containers,
|
||||
member_emails=["prefill@example.com"],
|
||||
)
|
||||
|
||||
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
|
||||
resolved_values = {"greeting": "Hello!"}
|
||||
params = FormCreateParams(
|
||||
app_id=str(uuid4()),
|
||||
workflow_execution_id=str(uuid4()),
|
||||
node_id="human-input-node",
|
||||
form_config=HumanInputNodeData(
|
||||
title="Human Approval",
|
||||
form_content="<p>Approve?</p>",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
),
|
||||
rendered_content="<p>Approve?</p>",
|
||||
delivery_methods=[],
|
||||
display_in_ui=False,
|
||||
resolved_default_values=resolved_values,
|
||||
)
|
||||
|
||||
form_entity = repository.create_form(params)
|
||||
|
||||
with Session(engine) as verification_session:
|
||||
form_model = verification_session.scalars(
|
||||
select(HumanInputForm).where(HumanInputForm.id == form_entity.id)
|
||||
).first()
|
||||
|
||||
assert form_model is not None
|
||||
definition = FormDefinition.model_validate_json(form_model.form_definition)
|
||||
assert definition.default_values == resolved_values
|
||||
|
||||
def test_create_form_persists_display_in_ui(self, db_session_with_containers: Session) -> None:
|
||||
engine = db_session_with_containers.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
tenant, _ = _create_tenant_with_members(
|
||||
db_session_with_containers,
|
||||
member_emails=["ui@example.com"],
|
||||
)
|
||||
|
||||
repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
|
||||
params = FormCreateParams(
|
||||
app_id=str(uuid4()),
|
||||
workflow_execution_id=str(uuid4()),
|
||||
node_id="human-input-node",
|
||||
form_config=HumanInputNodeData(
|
||||
title="Human Approval",
|
||||
form_content="<p>Approve?</p>",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
delivery_methods=[WebAppDeliveryMethod()],
|
||||
),
|
||||
rendered_content="<p>Approve?</p>",
|
||||
delivery_methods=[WebAppDeliveryMethod()],
|
||||
display_in_ui=True,
|
||||
resolved_default_values={},
|
||||
)
|
||||
|
||||
form_entity = repository.create_form(params)
|
||||
|
||||
with Session(engine) as verification_session:
|
||||
form_model = verification_session.scalars(
|
||||
select(HumanInputForm).where(HumanInputForm.id == form_entity.id)
|
||||
).first()
|
||||
|
||||
assert form_model is not None
|
||||
definition = FormDefinition.model_validate_json(form_model.form_definition)
|
||||
assert definition.display_in_ui is True
|
||||
@ -0,0 +1,336 @@
|
||||
import time
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.model import App, AppMode, IconType
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
|
||||
|
||||
|
||||
def _mock_form_repository_without_submission() -> HumanInputFormRepository:
|
||||
repo = MagicMock(spec=HumanInputFormRepository)
|
||||
form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
form_entity.rendered_content = "rendered"
|
||||
form_entity.submitted = False
|
||||
repo.create_form.return_value = form_entity
|
||||
repo.get_form.return_value = None
|
||||
return repo
|
||||
|
||||
|
||||
def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
|
||||
repo = MagicMock(spec=HumanInputFormRepository)
|
||||
form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
form_entity.rendered_content = "rendered"
|
||||
form_entity.submitted = True
|
||||
form_entity.selected_action_id = action_id
|
||||
form_entity.submitted_data = {}
|
||||
form_entity.status = HumanInputFormStatus.WAITING
|
||||
form_entity.expiration_time = naive_utc_now() + timedelta(hours=1)
|
||||
repo.get_form.return_value = form_entity
|
||||
return repo
|
||||
|
||||
|
||||
def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
user_id=user_id,
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_graph(
|
||||
runtime_state: GraphRuntimeState,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
form_repository: HumanInputFormRepository,
|
||||
) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": start_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="human",
|
||||
form_content="Awaiting human input",
|
||||
inputs=[],
|
||||
user_actions=[
|
||||
UserAction(id="continue", title="Continue"),
|
||||
],
|
||||
)
|
||||
human_node = HumanInputNode(
|
||||
id="human",
|
||||
config={"id": "human", "data": human_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
end_data = EndNodeData(
|
||||
title="end",
|
||||
outputs=[],
|
||||
desc=None,
|
||||
)
|
||||
end_node = EndNode(
|
||||
id="end",
|
||||
config={"id": "end", "data": end_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
return (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(human_node)
|
||||
.add_node(end_node, from_node_id="human", source_handle="continue")
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def _build_generate_entity(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
workflow_execution_id: str,
|
||||
user_id: str,
|
||||
) -> WorkflowAppGenerateEntity:
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
return WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user_id,
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
)
|
||||
|
||||
|
||||
class TestHumanInputResumeNodeExecutionIntegration:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_data(self, db_session_with_containers: Session):
|
||||
tenant = Tenant(
|
||||
name="Test Tenant",
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account = Account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App",
|
||||
description="",
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
icon_type=IconType.EMOJI.value,
|
||||
icon="rocket",
|
||||
icon_background="#4ECDC4",
|
||||
enable_site=False,
|
||||
enable_api=False,
|
||||
api_rpm=0,
|
||||
api_rph=0,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
is_universal=False,
|
||||
max_active_requests=None,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow = Workflow(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=account.id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
self.session = db_session_with_containers
|
||||
self.tenant = tenant
|
||||
self.account = account
|
||||
self.app = app
|
||||
self.workflow = workflow
|
||||
|
||||
yield
|
||||
|
||||
self.session.execute(delete(WorkflowNodeExecutionModel))
|
||||
self.session.execute(delete(WorkflowRun))
|
||||
self.session.execute(delete(Workflow).where(Workflow.id == self.workflow.id))
|
||||
self.session.execute(delete(App).where(App.id == self.app.id))
|
||||
self.session.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == self.tenant.id))
|
||||
self.session.execute(delete(Account).where(Account.id == self.account.id))
|
||||
self.session.execute(delete(Tenant).where(Tenant.id == self.tenant.id))
|
||||
self.session.commit()
|
||||
|
||||
def _build_persistence_layer(self, execution_id: str) -> WorkflowPersistenceLayer:
|
||||
generate_entity = _build_generate_entity(
|
||||
tenant_id=self.tenant.id,
|
||||
app_id=self.app.id,
|
||||
workflow_id=self.workflow.id,
|
||||
workflow_execution_id=execution_id,
|
||||
user_id=self.account.id,
|
||||
)
|
||||
execution_repo = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=self.session.get_bind(),
|
||||
user=self.account,
|
||||
app_id=self.app.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
)
|
||||
node_execution_repo = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=self.session.get_bind(),
|
||||
user=self.account,
|
||||
app_id=self.app.id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
return WorkflowPersistenceLayer(
|
||||
application_generate_entity=generate_entity,
|
||||
workflow_info=PersistenceWorkflowInfo(
|
||||
workflow_id=self.workflow.id,
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
version=self.workflow.version,
|
||||
graph_data=self.workflow.graph_dict,
|
||||
),
|
||||
workflow_execution_repository=execution_repo,
|
||||
workflow_node_execution_repository=node_execution_repo,
|
||||
)
|
||||
|
||||
def _run_graph(self, graph: Graph, runtime_state: GraphRuntimeState, execution_id: str) -> None:
|
||||
engine = GraphEngine(
|
||||
workflow_id=self.workflow.id,
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
engine.layer(self._build_persistence_layer(execution_id))
|
||||
for _ in engine.run():
|
||||
continue
|
||||
|
||||
def test_resume_human_input_does_not_create_duplicate_node_execution(self):
|
||||
execution_id = str(uuid.uuid4())
|
||||
runtime_state = _build_runtime_state(
|
||||
workflow_execution_id=execution_id,
|
||||
app_id=self.app.id,
|
||||
workflow_id=self.workflow.id,
|
||||
user_id=self.account.id,
|
||||
)
|
||||
pause_repo = _mock_form_repository_without_submission()
|
||||
paused_graph = _build_graph(
|
||||
runtime_state,
|
||||
self.tenant.id,
|
||||
self.app.id,
|
||||
self.workflow.id,
|
||||
self.account.id,
|
||||
pause_repo,
|
||||
)
|
||||
self._run_graph(paused_graph, runtime_state, execution_id)
|
||||
|
||||
snapshot = runtime_state.dumps()
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
resume_repo = _mock_form_repository_with_submission(action_id="continue")
|
||||
resumed_graph = _build_graph(
|
||||
resumed_state,
|
||||
self.tenant.id,
|
||||
self.app.id,
|
||||
self.workflow.id,
|
||||
self.account.id,
|
||||
resume_repo,
|
||||
)
|
||||
self._run_graph(resumed_graph, resumed_state, execution_id)
|
||||
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == execution_id,
|
||||
WorkflowNodeExecutionModel.node_id == "human",
|
||||
)
|
||||
records = self.session.execute(stmt).scalars().all()
|
||||
assert len(records) == 1
|
||||
assert records[0].status != "paused"
|
||||
assert records[0].triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
assert records[0].created_by_role == CreatorUserRole.ACCOUNT
|
||||
@ -0,0 +1 @@
|
||||
"""Helper utilities for integration tests."""
|
||||
@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from uuid import uuid4
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition, UserAction
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.human_input import HumanInputForm, HumanInputFormStatus
|
||||
from models.model import App, Conversation, Message
|
||||
|
||||
|
||||
@dataclass
|
||||
class HumanInputMessageFixture:
|
||||
app: App
|
||||
account: Account
|
||||
conversation: Conversation
|
||||
message: Message
|
||||
form: HumanInputForm
|
||||
action_id: str
|
||||
action_text: str
|
||||
node_title: str
|
||||
|
||||
|
||||
def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture:
|
||||
tenant = Tenant(name=f"Tenant {uuid4()}")
|
||||
db_session.add(tenant)
|
||||
db_session.flush()
|
||||
|
||||
account = Account(
|
||||
name=f"Account {uuid4()}",
|
||||
email=f"human_input_{uuid4()}@example.com",
|
||||
password="hashed-password",
|
||||
password_salt="salt",
|
||||
interface_language="en-US",
|
||||
timezone="UTC",
|
||||
)
|
||||
db_session.add(account)
|
||||
db_session.flush()
|
||||
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role="owner",
|
||||
current=True,
|
||||
)
|
||||
db_session.add(tenant_join)
|
||||
db_session.flush()
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=f"App {uuid4()}",
|
||||
description="",
|
||||
mode="chat",
|
||||
icon_type="emoji",
|
||||
icon="🤖",
|
||||
icon_background="#FFFFFF",
|
||||
enable_site=False,
|
||||
enable_api=True,
|
||||
api_rpm=100,
|
||||
api_rph=100,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
is_universal=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db_session.add(app)
|
||||
db_session.flush()
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
mode="chat",
|
||||
name="Test Conversation",
|
||||
summary="",
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
status="normal",
|
||||
invoke_from="console",
|
||||
from_source="console",
|
||||
from_account_id=account.id,
|
||||
from_end_user_id=None,
|
||||
)
|
||||
conversation.inputs = {}
|
||||
db_session.add(conversation)
|
||||
db_session.flush()
|
||||
|
||||
workflow_run_id = str(uuid4())
|
||||
message = Message(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
inputs={},
|
||||
query="Human input query",
|
||||
message={"messages": []},
|
||||
answer="Human input answer",
|
||||
message_tokens=50,
|
||||
message_unit_price=Decimal("0.001"),
|
||||
answer_tokens=80,
|
||||
answer_unit_price=Decimal("0.001"),
|
||||
provider_response_latency=0.5,
|
||||
currency="USD",
|
||||
from_source="console",
|
||||
from_account_id=account.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.flush()
|
||||
|
||||
action_id = "approve"
|
||||
action_text = "Approve request"
|
||||
node_title = "Approval"
|
||||
form_definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id=action_id, title=action_text)],
|
||||
rendered_content="Rendered block",
|
||||
expiration_time=datetime.utcnow() + timedelta(days=1),
|
||||
node_title=node_title,
|
||||
display_in_ui=True,
|
||||
)
|
||||
form = HumanInputForm(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_id="node-id",
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content="Rendered block",
|
||||
status=HumanInputFormStatus.SUBMITTED,
|
||||
expiration_time=datetime.utcnow() + timedelta(days=1),
|
||||
selected_action_id=action_id,
|
||||
)
|
||||
db_session.add(form)
|
||||
db_session.flush()
|
||||
|
||||
content = HumanInputContent(
|
||||
workflow_run_id=workflow_run_id,
|
||||
message_id=message.id,
|
||||
form_id=form.id,
|
||||
)
|
||||
db_session.add(content)
|
||||
db_session.commit()
|
||||
|
||||
return HumanInputMessageFixture(
|
||||
app=app,
|
||||
account=account,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
form=form,
|
||||
action_id=action_id,
|
||||
action_text=action_text,
|
||||
node_title=node_title,
|
||||
)
|
||||
@ -16,6 +16,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
from redis.cluster import RedisCluster
|
||||
from testcontainers.redis import RedisContainer
|
||||
|
||||
from libs.broadcast_channel.channel import BroadcastChannel, Subscription, Topic
|
||||
@ -332,3 +333,95 @@ class TestShardedRedisBroadcastChannelIntegration:
|
||||
# Verify subscriptions are cleaned up
|
||||
topic_subscribers_after = self._get_sharded_numsub(redis_client, topic_name)
|
||||
assert topic_subscribers_after == 0
|
||||
|
||||
|
||||
class TestShardedRedisBroadcastChannelClusterIntegration:
|
||||
"""Integration tests for sharded pub/sub with RedisCluster client."""
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def redis_cluster_container(self) -> Iterator[RedisContainer]:
|
||||
"""Create a Redis 7 container with cluster mode enabled."""
|
||||
command = (
|
||||
"redis-server --port 6379 "
|
||||
"--cluster-enabled yes "
|
||||
"--cluster-config-file nodes.conf "
|
||||
"--cluster-node-timeout 5000 "
|
||||
"--appendonly no "
|
||||
"--protected-mode no"
|
||||
)
|
||||
with RedisContainer(image="redis:7-alpine").with_command(command) as container:
|
||||
yield container
|
||||
|
||||
@classmethod
|
||||
def _get_test_topic_name(cls) -> str:
|
||||
return f"test_sharded_cluster_topic_{uuid.uuid4()}"
|
||||
|
||||
@staticmethod
|
||||
def _ensure_single_node_cluster(host: str, port: int) -> None:
|
||||
client = redis.Redis(host=host, port=port, decode_responses=False)
|
||||
client.config_set("cluster-announce-ip", host)
|
||||
client.config_set("cluster-announce-port", port)
|
||||
slots = client.execute_command("CLUSTER", "SLOTS")
|
||||
if not slots:
|
||||
client.execute_command("CLUSTER", "ADDSLOTSRANGE", 0, 16383)
|
||||
|
||||
deadline = time.time() + 5.0
|
||||
while time.time() < deadline:
|
||||
info = client.execute_command("CLUSTER", "INFO")
|
||||
info_text = info.decode("utf-8") if isinstance(info, (bytes, bytearray)) else str(info)
|
||||
if "cluster_state:ok" in info_text:
|
||||
return
|
||||
time.sleep(0.05)
|
||||
raise RuntimeError("Redis cluster did not become ready in time")
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def redis_cluster_client(self, redis_cluster_container: RedisContainer) -> RedisCluster:
|
||||
host = redis_cluster_container.get_container_host_ip()
|
||||
port = int(redis_cluster_container.get_exposed_port(6379))
|
||||
self._ensure_single_node_cluster(host, port)
|
||||
return RedisCluster(host=host, port=port, decode_responses=False)
|
||||
|
||||
@pytest.fixture
|
||||
def broadcast_channel(self, redis_cluster_client: RedisCluster) -> BroadcastChannel:
|
||||
return ShardedRedisBroadcastChannel(redis_cluster_client)
|
||||
|
||||
def test_cluster_sharded_pubsub_delivers_message(self, broadcast_channel: BroadcastChannel):
|
||||
"""Ensure sharded subscription receives messages when using RedisCluster client."""
|
||||
topic_name = self._get_test_topic_name()
|
||||
message = b"cluster sharded message"
|
||||
|
||||
topic = broadcast_channel.topic(topic_name)
|
||||
producer = topic.as_producer()
|
||||
subscription = topic.subscribe()
|
||||
ready_event = threading.Event()
|
||||
|
||||
def consumer_thread() -> list[bytes]:
|
||||
received = []
|
||||
try:
|
||||
_ = subscription.receive(0.01)
|
||||
except SubscriptionClosedError:
|
||||
return received
|
||||
ready_event.set()
|
||||
deadline = time.time() + 5.0
|
||||
while time.time() < deadline:
|
||||
msg = subscription.receive(timeout=0.1)
|
||||
if msg is None:
|
||||
continue
|
||||
received.append(msg)
|
||||
break
|
||||
subscription.close()
|
||||
return received
|
||||
|
||||
def producer_thread():
|
||||
if not ready_event.wait(timeout=2.0):
|
||||
pytest.fail("subscriber did not become ready before publish")
|
||||
producer.publish(message)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
consumer_future = executor.submit(consumer_thread)
|
||||
producer_future = executor.submit(producer_thread)
|
||||
|
||||
producer_future.result(timeout=5.0)
|
||||
received_messages = consumer_future.result(timeout=5.0)
|
||||
|
||||
assert received_messages == [message]
|
||||
|
||||
@ -0,0 +1,25 @@
|
||||
"""
|
||||
Integration tests for RateLimiter using testcontainers Redis.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import helper as helper_module
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_app_with_containers")
|
||||
def test_rate_limiter_counts_multiple_attempts_in_same_second(monkeypatch):
|
||||
prefix = f"test_rate_limit:{uuid.uuid4().hex}"
|
||||
limiter = helper_module.RateLimiter(prefix=prefix, max_attempts=2, time_window=60)
|
||||
key = limiter._get_key("203.0.113.10")
|
||||
|
||||
redis_client.delete(key)
|
||||
monkeypatch.setattr(helper_module.time, "time", lambda: 1_700_000_000)
|
||||
|
||||
limiter.increment_rate_limit("203.0.113.10")
|
||||
limiter.increment_rate_limit("203.0.113.10")
|
||||
|
||||
assert limiter.is_rate_limited("203.0.113.10") is True
|
||||
@ -0,0 +1,79 @@
|
||||
# import secrets
|
||||
|
||||
# import pytest
|
||||
# from sqlalchemy import select
|
||||
# from sqlalchemy.orm import Session
|
||||
# from sqlalchemy.orm.exc import DetachedInstanceError
|
||||
|
||||
# from libs.datetime_utils import naive_utc_now
|
||||
# from models.account import Account, Tenant, TenantAccountJoin
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def session(db_session_with_containers):
|
||||
# with Session(db_session_with_containers.get_bind()) as session:
|
||||
# yield session
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def account(session):
|
||||
# account = Account(
|
||||
# name="test account",
|
||||
# email=f"test_{secrets.token_hex(8)}@example.com",
|
||||
# )
|
||||
# session.add(account)
|
||||
# session.commit()
|
||||
# return account
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def tenant(session):
|
||||
# tenant = Tenant(name="test tenant")
|
||||
# session.add(tenant)
|
||||
# session.commit()
|
||||
# return tenant
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def tenant_account_join(session, account, tenant):
|
||||
# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id)
|
||||
# session.add(tenant_join)
|
||||
# session.commit()
|
||||
# yield tenant_join
|
||||
# session.delete(tenant_join)
|
||||
# session.commit()
|
||||
|
||||
|
||||
# class TestAccountTenant:
|
||||
# def test_set_current_tenant_should_reload_tenant(
|
||||
# self,
|
||||
# db_session_with_containers,
|
||||
# account,
|
||||
# tenant,
|
||||
# tenant_account_join,
|
||||
# ):
|
||||
# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session:
|
||||
# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one()
|
||||
# account.current_tenant = scoped_tenant
|
||||
# scoped_tenant.created_at = naive_utc_now()
|
||||
# # session.commit()
|
||||
|
||||
# # Ensure the tenant used in assignment is detached.
|
||||
# with pytest.raises(DetachedInstanceError):
|
||||
# _ = scoped_tenant.name
|
||||
|
||||
# assert account._current_tenant.id == tenant.id
|
||||
# assert account._current_tenant.id == tenant.id
|
||||
|
||||
# def test_set_tenant_id_should_load_tenant_as_not_expire(
|
||||
# self,
|
||||
# flask_app_with_containers,
|
||||
# account,
|
||||
# tenant,
|
||||
# tenant_account_join,
|
||||
# ):
|
||||
# with flask_app_with_containers.test_request_context():
|
||||
# account.set_tenant_id(tenant.id)
|
||||
|
||||
# assert account._current_tenant.id == tenant.id
|
||||
# assert account._current_tenant.id == tenant.id
|
||||
@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
|
||||
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
|
||||
create_human_input_message_fixture,
|
||||
)
|
||||
|
||||
|
||||
def test_get_by_message_ids_returns_human_input_content(db_session_with_containers):
|
||||
fixture = create_human_input_message_fixture(db_session_with_containers)
|
||||
repository = SQLAlchemyExecutionExtraContentRepository(
|
||||
session_maker=sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
)
|
||||
|
||||
results = repository.get_by_message_ids([fixture.message.id])
|
||||
|
||||
assert len(results) == 1
|
||||
assert len(results[0]) == 1
|
||||
content = results[0][0]
|
||||
assert content.submitted is True
|
||||
assert content.form_submission_data is not None
|
||||
assert content.form_submission_data.action_id == fixture.action_id
|
||||
assert content.form_submission_data.action_text == fixture.action_text
|
||||
assert content.form_submission_data.rendered_content == fixture.form.rendered_content
|
||||
@ -2309,6 +2309,12 @@ class TestRegisterService:
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import DifySetup
|
||||
|
||||
db.session.query(DifySetup).delete()
|
||||
db.session.commit()
|
||||
|
||||
# Execute setup
|
||||
RegisterService.setup(
|
||||
email=admin_email,
|
||||
@ -2319,9 +2325,7 @@ class TestRegisterService:
|
||||
)
|
||||
|
||||
# Verify account was created
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.model import DifySetup
|
||||
|
||||
account = db.session.query(Account).filter_by(email=admin_email).first()
|
||||
assert account is not None
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@ -26,6 +26,7 @@ class TestAppGenerateService:
|
||||
patch("services.app_generate_service.AgentChatAppGenerator") as mock_agent_chat_generator,
|
||||
patch("services.app_generate_service.AdvancedChatAppGenerator") as mock_advanced_chat_generator,
|
||||
patch("services.app_generate_service.WorkflowAppGenerator") as mock_workflow_generator,
|
||||
patch("services.app_generate_service.MessageBasedAppGenerator") as mock_message_based_generator,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
patch("services.app_generate_service.dify_config") as mock_dify_config,
|
||||
patch("configs.dify_config") as mock_global_dify_config,
|
||||
@ -38,9 +39,13 @@ class TestAppGenerateService:
|
||||
|
||||
# Setup default mock returns for workflow service
|
||||
mock_workflow_service_instance = mock_workflow_service.return_value
|
||||
mock_workflow_service_instance.get_published_workflow.return_value = MagicMock(spec=Workflow)
|
||||
mock_workflow_service_instance.get_draft_workflow.return_value = MagicMock(spec=Workflow)
|
||||
mock_workflow_service_instance.get_published_workflow_by_id.return_value = MagicMock(spec=Workflow)
|
||||
mock_published_workflow = MagicMock(spec=Workflow)
|
||||
mock_published_workflow.id = str(uuid.uuid4())
|
||||
mock_workflow_service_instance.get_published_workflow.return_value = mock_published_workflow
|
||||
mock_draft_workflow = MagicMock(spec=Workflow)
|
||||
mock_draft_workflow.id = str(uuid.uuid4())
|
||||
mock_workflow_service_instance.get_draft_workflow.return_value = mock_draft_workflow
|
||||
mock_workflow_service_instance.get_published_workflow_by_id.return_value = mock_published_workflow
|
||||
|
||||
# Setup default mock returns for rate limiting
|
||||
mock_rate_limit_instance = mock_rate_limit.return_value
|
||||
@ -66,6 +71,8 @@ class TestAppGenerateService:
|
||||
mock_advanced_chat_generator_instance.generate.return_value = ["advanced_chat_response"]
|
||||
mock_advanced_chat_generator_instance.single_iteration_generate.return_value = ["single_iteration_response"]
|
||||
mock_advanced_chat_generator_instance.single_loop_generate.return_value = ["single_loop_response"]
|
||||
mock_advanced_chat_generator_instance.retrieve_events.return_value = ["advanced_chat_events"]
|
||||
mock_advanced_chat_generator_instance.convert_to_event_stream.return_value = ["advanced_chat_stream"]
|
||||
mock_advanced_chat_generator.convert_to_event_stream.return_value = ["advanced_chat_stream"]
|
||||
|
||||
mock_workflow_generator_instance = mock_workflow_generator.return_value
|
||||
@ -76,6 +83,8 @@ class TestAppGenerateService:
|
||||
mock_workflow_generator_instance.single_loop_generate.return_value = ["workflow_single_loop_response"]
|
||||
mock_workflow_generator.convert_to_event_stream.return_value = ["workflow_stream"]
|
||||
|
||||
mock_message_based_generator.retrieve_events.return_value = ["workflow_events"]
|
||||
|
||||
# Setup default mock returns for account service
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
|
||||
@ -88,6 +97,7 @@ class TestAppGenerateService:
|
||||
mock_global_dify_config.BILLING_ENABLED = False
|
||||
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
|
||||
mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000
|
||||
mock_global_dify_config.HOSTED_POOL_CREDITS = 1000
|
||||
|
||||
yield {
|
||||
"billing_service": mock_billing_service,
|
||||
@ -98,6 +108,7 @@ class TestAppGenerateService:
|
||||
"agent_chat_generator": mock_agent_chat_generator,
|
||||
"advanced_chat_generator": mock_advanced_chat_generator,
|
||||
"workflow_generator": mock_workflow_generator,
|
||||
"message_based_generator": mock_message_based_generator,
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
"dify_config": mock_dify_config,
|
||||
"global_dify_config": mock_global_dify_config,
|
||||
@ -280,8 +291,10 @@ class TestAppGenerateService:
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify advanced chat generator was called
|
||||
mock_external_service_dependencies["advanced_chat_generator"].return_value.generate.assert_called_once()
|
||||
mock_external_service_dependencies["advanced_chat_generator"].convert_to_event_stream.assert_called_once()
|
||||
mock_external_service_dependencies["advanced_chat_generator"].return_value.retrieve_events.assert_called_once()
|
||||
mock_external_service_dependencies[
|
||||
"advanced_chat_generator"
|
||||
].return_value.convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
@ -304,7 +317,7 @@ class TestAppGenerateService:
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify workflow generator was called
|
||||
mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once()
|
||||
mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once()
|
||||
mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
@ -970,14 +983,27 @@ class TestAppGenerateService:
|
||||
}
|
||||
|
||||
# Execute the method under test
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
with patch("services.app_generate_service.AppExecutionParams") as mock_exec_params:
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.workflow_run_id = fake.uuid4()
|
||||
mock_payload.model_dump_json.return_value = "{}"
|
||||
mock_exec_params.new.return_value = mock_payload
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
# Verify workflow generator was called with complex args
|
||||
mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once()
|
||||
call_args = mock_external_service_dependencies["workflow_generator"].return_value.generate.call_args
|
||||
assert call_args[1]["args"] == args
|
||||
# Verify payload was built with complex args
|
||||
mock_exec_params.new.assert_called_once()
|
||||
call_kwargs = mock_exec_params.new.call_args.kwargs
|
||||
assert call_kwargs["args"] == args
|
||||
|
||||
# Verify workflow streaming event retrieval was used
|
||||
mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once_with(
|
||||
ANY,
|
||||
mock_payload.workflow_run_id,
|
||||
on_subscribe=ANY,
|
||||
)
|
||||
|
||||
@ -0,0 +1,112 @@
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
HumanInputNodeData,
|
||||
)
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import App, AppMode
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) -> tuple[App, Account]:
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
account = Account(name="Tester", email="tester@example.com")
|
||||
session.add_all([tenant, account])
|
||||
session.flush()
|
||||
|
||||
session.add(
|
||||
TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
current=True,
|
||||
role=TenantAccountRole.OWNER.value,
|
||||
)
|
||||
)
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App",
|
||||
description="",
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
icon_type="emoji",
|
||||
icon="app",
|
||||
icon_background="#ffffff",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
session.add(app)
|
||||
session.flush()
|
||||
|
||||
email_method = EmailDeliveryMethod(
|
||||
id=delivery_method_id,
|
||||
enabled=True,
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[ExternalRecipient(email="recipient@example.com")],
|
||||
),
|
||||
subject="Test {{recipient_email}}",
|
||||
body="Body {{#url#}} {{form_content}}",
|
||||
),
|
||||
)
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
delivery_methods=[email_method],
|
||||
form_content="Hello Human Input",
|
||||
inputs=[],
|
||||
user_actions=[],
|
||||
).model_dump(mode="json")
|
||||
node_data["type"] = NodeType.HUMAN_INPUT.value
|
||||
graph = json.dumps({"nodes": [{"id": "human-node", "data": node_data}], "edges": []})
|
||||
|
||||
workflow = Workflow.new(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
type=WorkflowType.WORKFLOW.value,
|
||||
version=Workflow.VERSION_DRAFT,
|
||||
graph=graph,
|
||||
features=json.dumps({}),
|
||||
created_by=account.id,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
session.add(workflow)
|
||||
session.commit()
|
||||
|
||||
return app, account
|
||||
|
||||
|
||||
def test_human_input_delivery_test_sends_email(
|
||||
db_session_with_containers,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
delivery_method_id = uuid.uuid4()
|
||||
app, account = _create_app_with_draft_workflow(db_session_with_containers, delivery_method_id=delivery_method_id)
|
||||
|
||||
send_mock = MagicMock()
|
||||
monkeypatch.setattr("services.human_input_delivery_test_service.mail.is_inited", lambda: True)
|
||||
monkeypatch.setattr("services.human_input_delivery_test_service.mail.send", send_mock)
|
||||
|
||||
service = WorkflowService()
|
||||
service.test_human_input_delivery(
|
||||
app_model=app,
|
||||
account=account,
|
||||
node_id="human-node",
|
||||
delivery_method_id=str(delivery_method_id),
|
||||
)
|
||||
|
||||
assert send_mock.call_count == 1
|
||||
assert send_mock.call_args.kwargs["to"] == "recipient@example.com"
|
||||
@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from services.message_service import MessageService
|
||||
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
|
||||
create_human_input_message_fixture,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
|
||||
def test_pagination_returns_extra_contents(db_session_with_containers):
|
||||
fixture = create_human_input_message_fixture(db_session_with_containers)
|
||||
|
||||
pagination = MessageService.pagination_by_first_id(
|
||||
app_model=fixture.app,
|
||||
user=fixture.account,
|
||||
conversation_id=fixture.conversation.id,
|
||||
first_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
assert pagination.data
|
||||
message = pagination.data[0]
|
||||
assert message.extra_contents == [
|
||||
{
|
||||
"type": "human_input",
|
||||
"workflow_run_id": fixture.message.workflow_run_id,
|
||||
"submitted": True,
|
||||
"form_submission_data": {
|
||||
"node_id": fixture.form.node_id,
|
||||
"node_title": fixture.node_title,
|
||||
"rendered_content": fixture.form.rendered_content,
|
||||
"action_id": fixture.action_id,
|
||||
"action_text": fixture.action_text,
|
||||
},
|
||||
}
|
||||
]
|
||||
@ -465,6 +465,27 @@ class TestWorkflowRunService:
|
||||
db.session.add(node_execution)
|
||||
node_executions.append(node_execution)
|
||||
|
||||
paused_node_execution = WorkflowNodeExecutionModel(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
triggered_from="workflow-run",
|
||||
workflow_run_id=workflow_run.id,
|
||||
index=99,
|
||||
node_id="node_paused",
|
||||
node_type="human_input",
|
||||
title="Paused Node",
|
||||
inputs=json.dumps({"input": "paused"}),
|
||||
process_data=json.dumps({"process": "paused"}),
|
||||
status="paused",
|
||||
elapsed_time=0.5,
|
||||
execution_metadata=json.dumps({"tokens": 0}),
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(paused_node_execution)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
@ -473,16 +494,19 @@ class TestWorkflowRunService:
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
assert len(result) == 3
|
||||
assert len(result) == 4
|
||||
|
||||
# Verify node execution properties
|
||||
statuses = [node_execution.status for node_execution in result]
|
||||
assert "paused" in statuses
|
||||
assert statuses.count("succeeded") == 3
|
||||
assert statuses.count("paused") == 1
|
||||
|
||||
for node_execution in result:
|
||||
assert node_execution.tenant_id == app.tenant_id
|
||||
assert node_execution.app_id == app.id
|
||||
assert node_execution.workflow_run_id == workflow_run.id
|
||||
assert node_execution.index in [0, 1, 2] # Check that index is one of the expected values
|
||||
assert node_execution.node_id.startswith("node_") # Check that node_id starts with "node_"
|
||||
assert node_execution.status == "succeeded"
|
||||
assert node_execution.node_id.startswith("node_")
|
||||
|
||||
def test_get_workflow_run_node_executions_empty(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
|
||||
@ -6,6 +6,7 @@ from faker import Faker
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow as WorkflowModel
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -513,6 +514,62 @@ class TestWorkflowToolManageService:
|
||||
|
||||
assert tool_count == 0
|
||||
|
||||
def test_create_workflow_tool_human_input_node_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation fails when workflow contains human input nodes.
|
||||
|
||||
This test verifies:
|
||||
- Human input nodes prevent workflow tool publishing
|
||||
- Correct error message
|
||||
- No database changes when workflow is invalid
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
workflow.graph = json.dumps(
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "human_input_node",
|
||||
"data": {"type": "human-input"},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=fake.word(),
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=tool_parameters,
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
assert tool_count == 0
|
||||
|
||||
def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful workflow tool update with valid parameters.
|
||||
@ -600,6 +657,80 @@ class TestWorkflowToolManageService:
|
||||
mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_called()
|
||||
mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called()
|
||||
|
||||
def test_update_workflow_tool_human_input_node_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool update fails when workflow contains human input nodes.
|
||||
|
||||
This test verifies:
|
||||
- Human input nodes prevent workflow tool updates
|
||||
- Correct error message
|
||||
- Existing tool data remains unchanged
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create test data
|
||||
app, account, workflow = self._create_test_app_and_account(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
# Create initial workflow tool
|
||||
initial_tool_name = fake.word()
|
||||
initial_tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_app_id=app.id,
|
||||
name=initial_tool_name,
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "🔧"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=initial_tool_parameters,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
created_tool = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
WorkflowToolProvider.app_id == app.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
original_name = created_tool.name
|
||||
|
||||
workflow.graph = json.dumps(
|
||||
{
|
||||
"nodes": [
|
||||
{
|
||||
"id": "human_input_node",
|
||||
"data": {"type": "human-input"},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
|
||||
WorkflowToolManageService.update_workflow_tool(
|
||||
user_id=account.id,
|
||||
tenant_id=account.current_tenant.id,
|
||||
workflow_tool_id=created_tool.id,
|
||||
name=fake.word(),
|
||||
label=fake.word(),
|
||||
icon={"type": "emoji", "emoji": "⚙️"},
|
||||
description=fake.text(max_nb_chars=200),
|
||||
parameters=initial_tool_parameters,
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
|
||||
|
||||
db.session.refresh(created_tool)
|
||||
assert created_tool.name == original_name
|
||||
|
||||
def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test workflow tool update fails when tool does not exist.
|
||||
|
||||
@ -0,0 +1,214 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
|
||||
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from extensions.ext_storage import storage
|
||||
from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowPause, WorkflowRun, WorkflowType
|
||||
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(db_session_with_containers):
|
||||
db_session_with_containers.query(HumanInputFormRecipient).delete()
|
||||
db_session_with_containers.query(HumanInputDelivery).delete()
|
||||
db_session_with_containers.query(HumanInputForm).delete()
|
||||
db_session_with_containers.query(WorkflowPause).delete()
|
||||
db_session_with_containers.query(WorkflowRun).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
def _create_workspace_member(db_session_with_containers):
|
||||
account = Account(
|
||||
email="owner@example.com",
|
||||
name="Owner",
|
||||
password="password",
|
||||
interface_language="en-US",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.created_at = datetime.now(UTC)
|
||||
account.updated_at = datetime.now(UTC)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.refresh(account)
|
||||
|
||||
tenant = Tenant(name="Test Tenant")
|
||||
tenant.created_at = datetime.now(UTC)
|
||||
tenant.updated_at = datetime.now(UTC)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.refresh(tenant)
|
||||
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
)
|
||||
tenant_join.created_at = datetime.now(UTC)
|
||||
tenant_join.updated_at = datetime.now(UTC)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return tenant, account
|
||||
|
||||
|
||||
def _build_form(db_session_with_containers, tenant, account, *, app_id: str, workflow_execution_id: str):
|
||||
delivery_method = EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(user_id=account.id),
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
],
|
||||
),
|
||||
subject="Action needed {{ node_title }} {{#node1.value#}}",
|
||||
body="Token {{ form_token }} link {{#url#}} content {{#node1.value#}}",
|
||||
)
|
||||
)
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Review",
|
||||
form_content="Form content",
|
||||
delivery_methods=[delivery_method],
|
||||
)
|
||||
|
||||
engine = db_session_with_containers.get_bind()
|
||||
repo = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id)
|
||||
params = FormCreateParams(
|
||||
app_id=app_id,
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
node_id="node-1",
|
||||
form_config=node_data,
|
||||
rendered_content="Rendered",
|
||||
delivery_methods=node_data.delivery_methods,
|
||||
display_in_ui=False,
|
||||
resolved_default_values={},
|
||||
)
|
||||
return repo.create_form(params)
|
||||
|
||||
|
||||
def _create_workflow_pause_state(
|
||||
db_session_with_containers,
|
||||
*,
|
||||
workflow_run_id: str,
|
||||
workflow_id: str,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
account_id: str,
|
||||
variable_pool: VariablePool,
|
||||
):
|
||||
workflow_run = WorkflowRun(
|
||||
id=workflow_run_id,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
type=WorkflowType.WORKFLOW,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
version="1",
|
||||
graph="{}",
|
||||
inputs="{}",
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account_id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db_session_with_containers.add(workflow_run)
|
||||
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
resumption_context = WorkflowResumptionContext(
|
||||
generate_entity={
|
||||
"type": AppMode.WORKFLOW,
|
||||
"entity": WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=WorkflowUIBasedAppConfig(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id=workflow_id,
|
||||
),
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=account_id,
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
workflow_execution_id=workflow_run_id,
|
||||
),
|
||||
},
|
||||
serialized_graph_runtime_state=runtime_state.dumps(),
|
||||
)
|
||||
|
||||
state_object_key = f"workflow_pause_states/{workflow_run_id}.json"
|
||||
storage.save(state_object_key, resumption_context.dumps().encode())
|
||||
|
||||
pause_state = WorkflowPause(
|
||||
workflow_id=workflow_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
state_object_key=state_object_key,
|
||||
)
|
||||
db_session_with_containers.add(pause_state)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
|
||||
def test_dispatch_human_input_email_task_integration(monkeypatch: pytest.MonkeyPatch, db_session_with_containers):
|
||||
tenant, account = _create_workspace_member(db_session_with_containers)
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
workflow_id = str(uuid.uuid4())
|
||||
app_id = str(uuid.uuid4())
|
||||
variable_pool = VariablePool()
|
||||
variable_pool.add(["node1", "value"], "OK")
|
||||
_create_workflow_pause_state(
|
||||
db_session_with_containers,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow_id,
|
||||
tenant_id=tenant.id,
|
||||
app_id=app_id,
|
||||
account_id=account.id,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
form_entity = _build_form(
|
||||
db_session_with_containers,
|
||||
tenant,
|
||||
account,
|
||||
app_id=app_id,
|
||||
workflow_execution_id=workflow_run_id,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(dify_config, "APP_WEB_URL", "https://app.example.com")
|
||||
|
||||
with patch("tasks.mail_human_input_delivery_task.mail") as mock_mail:
|
||||
mock_mail.is_inited.return_value = True
|
||||
|
||||
dispatch_human_input_email_task(form_id=form_entity.id, node_title="Approval")
|
||||
|
||||
assert mock_mail.send.call_count == 2
|
||||
send_args = [call.kwargs for call in mock_mail.send.call_args_list]
|
||||
recipients = {kwargs["to"] for kwargs in send_args}
|
||||
assert recipients == {"owner@example.com", "external@example.com"}
|
||||
assert all(kwargs["subject"] == "Action needed {{ node_title }} {{#node1.value#}}" for kwargs in send_args)
|
||||
assert all("app.example.com/form/" in kwargs["html"] for kwargs in send_args)
|
||||
assert all("content OK" in kwargs["html"] for kwargs in send_args)
|
||||
assert all("{{ form_token }}" in kwargs["html"] for kwargs in send_args)
|
||||
@ -94,11 +94,6 @@ class PrunePausesTestCase:
|
||||
def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
|
||||
"""Create test cases for pause workflow failure scenarios."""
|
||||
return [
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_already_paused_workflow",
|
||||
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||
description="Should fail to pause an already paused workflow",
|
||||
),
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_completed_workflow",
|
||||
initial_status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
|
||||
Reference in New Issue
Block a user