Merge branch 'feat/model-plugins-implementing' into deploy/dev

This commit is contained in:
yyh
2026-03-10 12:12:46 +08:00
64 changed files with 9937 additions and 408 deletions

View File

@ -45,7 +45,6 @@ allow_indirect_imports = True
ignore_imports =
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
@ -111,7 +110,6 @@ ignore_imports =
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.tool.tool_node -> models
dify_graph.nodes.agent.agent_node -> models.model
dify_graph.nodes.llm.node -> core.helper.code_executor
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
@ -134,7 +132,6 @@ ignore_imports =
dify_graph.nodes.tool.tool_node -> core.tools.errors
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.nodes.agent.agent_node -> models
dify_graph.nodes.llm.node -> models.model
dify_graph.nodes.agent.agent_node -> services

View File

@ -1,8 +1,9 @@
import json
from contextlib import ExitStack
from typing import Self
from uuid import UUID
from flask import request
from flask import request, send_file
from flask_restx import marshal
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import desc, select
@ -100,6 +101,15 @@ class DocumentListQuery(BaseModel):
status: str | None = Field(default=None, description="Document status filter")
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading uploaded documents as a ZIP archive."""
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
register_enum_models(service_api_ns, RetrievalMethod)
register_schema_models(
@ -109,6 +119,7 @@ register_schema_models(
DocumentTextCreatePayload,
DocumentTextUpdate,
DocumentListQuery,
DocumentBatchDownloadZipPayload,
Rule,
PreProcessingRule,
Segmentation,
@ -540,6 +551,46 @@ class DocumentListApi(DatasetApiResource):
return response
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/download-zip")
class DocumentBatchDownloadZipApi(DatasetApiResource):
"""Download multiple uploaded-file documents as a single ZIP archive."""
@service_api_ns.expect(service_api_ns.models[DocumentBatchDownloadZipPayload.__name__])
@service_api_ns.doc("download_documents_as_zip")
@service_api_ns.doc(description="Download selected uploaded documents as a single ZIP archive")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@service_api_ns.doc(
responses={
200: "ZIP archive generated successfully",
401: "Unauthorized - invalid API token",
403: "Forbidden - insufficient permissions",
404: "Document or dataset not found",
}
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
payload = DocumentBatchDownloadZipPayload.model_validate(service_api_ns.payload or {})
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
dataset_id=str(dataset_id),
document_ids=[str(document_id) for document_id in payload.document_ids],
tenant_id=str(tenant_id),
current_user=current_user,
)
with ExitStack() as stack:
zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files))
response = send_file(
zip_path,
mimetype="application/zip",
as_attachment=True,
download_name=download_name,
)
cleanup = stack.pop_all()
response.call_on_close(cleanup.close)
return response
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
class DocumentIndexingStatusApi(DatasetApiResource):
@service_api_ns.doc("get_document_indexing_status")
@ -600,6 +651,35 @@ class DocumentIndexingStatusApi(DatasetApiResource):
return data
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
class DocumentDownloadApi(DatasetApiResource):
"""Return a signed download URL for a document's original uploaded file."""
@service_api_ns.doc("get_document_download_url")
@service_api_ns.doc(description="Get a signed download URL for a document's original uploaded file")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
200: "Download URL generated successfully",
401: "Unauthorized - invalid API token",
403: "Forbidden - insufficient permissions",
404: "Document or upload file not found",
}
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def get(self, tenant_id, dataset_id, document_id):
dataset = self.get_dataset(str(dataset_id), str(tenant_id))
document = DocumentService.get_document(dataset.id, str(document_id))
if not document:
raise NotFound("Document not found.")
if document.tenant_id != str(tenant_id):
raise Forbidden("No permission.")
return {"url": DocumentService.get_document_download_url(document)}
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
class DocumentApi(DatasetApiResource):
METADATA_CHOICES = {"all", "only", "without"}

View File

@ -14,6 +14,7 @@ import httpx
from configs import dify_config
from core.db.session_factory import session_factory
from core.helper import ssrf_proxy
from dify_graph.file.models import ToolFile as ToolFilePydanticModel
from extensions.ext_storage import storage
from models.model import MessageFile
from models.tools import ToolFile
@ -207,7 +208,9 @@ class ToolFileManager:
return blob, tool_file.mimetype
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]:
def get_file_generator_by_tool_file_id(
self, tool_file_id: str
) -> tuple[Generator | None, ToolFilePydanticModel | None]:
"""
get file binary
@ -229,7 +232,7 @@ class ToolFileManager:
stream = storage.load_stream(tool_file.file_key)
return stream, tool_file
return stream, ToolFilePydanticModel.model_validate(tool_file)
# init tool_file_parser

View File

@ -50,6 +50,7 @@ from dify_graph.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
from dify_graph.nodes.tool.tool_node import ToolNode
from dify_graph.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
@ -310,6 +311,15 @@ class DifyNodeFactory(NodeFactory):
memory=memory,
)
if node_type == NodeType.TOOL:
return ToolNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
tool_file_manager_factory=self._http_request_tool_file_manager_factory(),
)
return node_class(
id=node_id,
config=node_config,

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
from uuid import UUID, uuid4
from pydantic import BaseModel, Field, model_validator
@ -43,6 +44,24 @@ class FileUploadConfig(BaseModel):
number_limits: int = 0
class ToolFile(BaseModel):
id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file")
user_id: UUID = Field(..., description="ID of the user who owns this file")
tenant_id: UUID = Field(..., description="ID of the tenant/organization")
conversation_id: UUID | None = Field(None, description="ID of the associated conversation")
file_key: str = Field(..., max_length=255, description="Storage key for the file")
mimetype: str = Field(..., max_length=255, description="MIME type of the file")
original_url: str | None = Field(
None, max_length=2048, description="Original URL if file was fetched from external source"
)
name: str = Field(default="", max_length=255, description="Display name of the file")
size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)")
class Config:
from_attributes = True # Enable ORM mode for SQLAlchemy compatibility
populate_by_name = True
class File(BaseModel):
# NOTE: dify_model_identity is a special identifier used to distinguish between
# new and old data formats during serialization and deserialization.

View File

@ -1,8 +1,10 @@
from collections.abc import Generator
from typing import Any, Protocol
import httpx
from dify_graph.file import File
from dify_graph.file.models import ToolFile
class HttpClientProtocol(Protocol):
@ -40,3 +42,5 @@ class ToolFileManagerProtocol(Protocol):
mimetype: str,
filename: str | None = None,
) -> Any: ...
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ...

View File

@ -1,9 +1,6 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -21,11 +18,10 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.protocols import ToolFileManagerProtocol
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
from dify_graph.variables.variables import ArrayAnyVariable
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import ToolNodeData
@ -36,7 +32,8 @@ from .exc import (
)
if TYPE_CHECKING:
from dify_graph.runtime import VariablePool
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState, VariablePool
class ToolNode(Node[ToolNodeData]):
@ -46,6 +43,23 @@ class ToolNode(Node[ToolNodeData]):
node_type = NodeType.TOOL
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
tool_file_manager_factory: ToolFileManagerProtocol,
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._tool_file_manager_factory = tool_file_manager_factory
@classmethod
def version(cls) -> str:
return "1"
@ -271,11 +285,9 @@ class ToolNode(Node[ToolNodeData]):
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
if not tool_file:
raise ToolFileError(f"tool file {tool_file_id} not found")
mapping = {
"tool_file_id": tool_file_id,
@ -294,11 +306,9 @@ class ToolNode(Node[ToolNodeData]):
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"tool file {tool_file_id} not exists")
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
if not tool_file:
raise ToolFileError(f"tool file {tool_file_id} not exists")
mapping = {
"tool_file_id": tool_file_id,

View File

@ -8,6 +8,7 @@ from core.workflow.node_factory import DifyNodeFactory
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.graph import Graph
from dify_graph.node_events import StreamCompletedEvent
from dify_graph.nodes.protocols import ToolFileManagerProtocol
from dify_graph.nodes.tool.tool_node import ToolNode
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
@ -55,11 +56,14 @@ def init_tool_node(config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
node = ToolNode(
id=str(uuid.uuid4()),
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
)
return node

View File

@ -0,0 +1,92 @@
from __future__ import annotations
from controllers.console.app import annotation as annotation_module
def test_annotation_reply_payload_valid():
"""Test AnnotationReplyPayload with valid data."""
payload = annotation_module.AnnotationReplyPayload(
score_threshold=0.5,
embedding_provider_name="openai",
embedding_model_name="text-embedding-3-small",
)
assert payload.score_threshold == 0.5
assert payload.embedding_provider_name == "openai"
assert payload.embedding_model_name == "text-embedding-3-small"
def test_annotation_setting_update_payload_valid():
"""Test AnnotationSettingUpdatePayload with valid data."""
payload = annotation_module.AnnotationSettingUpdatePayload(
score_threshold=0.75,
)
assert payload.score_threshold == 0.75
def test_annotation_list_query_defaults():
"""Test AnnotationListQuery with default parameters."""
query = annotation_module.AnnotationListQuery()
assert query.page == 1
assert query.limit == 20
assert query.keyword == ""
def test_annotation_list_query_custom_page():
"""Test AnnotationListQuery with custom page."""
query = annotation_module.AnnotationListQuery(page=3, limit=50)
assert query.page == 3
assert query.limit == 50
def test_annotation_list_query_with_keyword():
"""Test AnnotationListQuery with keyword."""
query = annotation_module.AnnotationListQuery(keyword="test")
assert query.keyword == "test"
def test_create_annotation_payload_with_message_id():
"""Test CreateAnnotationPayload with message ID."""
payload = annotation_module.CreateAnnotationPayload(
message_id="550e8400-e29b-41d4-a716-446655440000",
question="What is AI?",
)
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"
assert payload.question == "What is AI?"
def test_create_annotation_payload_with_text():
"""Test CreateAnnotationPayload with text content."""
payload = annotation_module.CreateAnnotationPayload(
question="What is ML?",
answer="Machine learning is...",
)
assert payload.question == "What is ML?"
assert payload.answer == "Machine learning is..."
def test_update_annotation_payload():
"""Test UpdateAnnotationPayload."""
payload = annotation_module.UpdateAnnotationPayload(
question="Updated question",
answer="Updated answer",
)
assert payload.question == "Updated question"
assert payload.answer == "Updated answer"
def test_annotation_reply_status_query_enable():
"""Test AnnotationReplyStatusQuery with enable action."""
query = annotation_module.AnnotationReplyStatusQuery(action="enable")
assert query.action == "enable"
def test_annotation_reply_status_query_disable():
"""Test AnnotationReplyStatusQuery with disable action."""
query = annotation_module.AnnotationReplyStatusQuery(action="disable")
assert query.action == "disable"
def test_annotation_file_payload_valid():
"""Test AnnotationFilePayload with valid message ID."""
payload = annotation_module.AnnotationFilePayload(message_id="550e8400-e29b-41d4-a716-446655440000")
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"

View File

@ -13,6 +13,9 @@ from pandas.errors import ParserError
from werkzeug.datastructures import FileStorage
from configs import dify_config
from controllers.console.wraps import annotation_import_concurrency_limit, annotation_import_rate_limit
from services.annotation_service import AppAnnotationService
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
class TestAnnotationImportRateLimiting:
@ -33,8 +36,6 @@ class TestAnnotationImportRateLimiting:
def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account):
"""Test that per-minute rate limit is enforced."""
from controllers.console.wraps import annotation_import_rate_limit
# Simulate exceeding per-minute limit
mock_redis.zcard.side_effect = [
dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check
@ -54,7 +55,6 @@ class TestAnnotationImportRateLimiting:
def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account):
"""Test that per-hour rate limit is enforced."""
from controllers.console.wraps import annotation_import_rate_limit
# Simulate exceeding per-hour limit
mock_redis.zcard.side_effect = [
@ -74,7 +74,6 @@ class TestAnnotationImportRateLimiting:
def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account):
"""Test that requests within limits are allowed."""
from controllers.console.wraps import annotation_import_rate_limit
# Simulate being under both limits
mock_redis.zcard.return_value = 2
@ -110,7 +109,6 @@ class TestAnnotationImportConcurrencyControl:
def test_concurrency_limit_enforced(self, mock_redis, mock_current_account):
"""Test that concurrent task limit is enforced."""
from controllers.console.wraps import annotation_import_concurrency_limit
# Simulate max concurrent tasks already running
mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT
@ -127,7 +125,6 @@ class TestAnnotationImportConcurrencyControl:
def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account):
"""Test that requests within concurrency limits are allowed."""
from controllers.console.wraps import annotation_import_concurrency_limit
# Simulate being under concurrent task limit
mock_redis.zcard.return_value = 1
@ -142,7 +139,6 @@ class TestAnnotationImportConcurrencyControl:
def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account):
"""Test that old/stale job entries are removed."""
from controllers.console.wraps import annotation_import_concurrency_limit
mock_redis.zcard.return_value = 0
@ -203,7 +199,6 @@ class TestAnnotationImportServiceValidation:
def test_max_records_limit_enforced(self, mock_app, mock_db_session):
"""Test that files with too many records are rejected."""
from services.annotation_service import AppAnnotationService
# Create CSV with too many records
max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
@ -229,7 +224,6 @@ class TestAnnotationImportServiceValidation:
def test_min_records_limit_enforced(self, mock_app, mock_db_session):
"""Test that files with too few valid records are rejected."""
from services.annotation_service import AppAnnotationService
# Create CSV with only header (no data rows)
csv_content = "question,answer\n"
@ -249,7 +243,6 @@ class TestAnnotationImportServiceValidation:
def test_invalid_csv_format_handled(self, mock_app, mock_db_session):
"""Test that invalid CSV format is handled gracefully."""
from services.annotation_service import AppAnnotationService
# Any content is fine once we force ParserError
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
@ -270,7 +263,6 @@ class TestAnnotationImportServiceValidation:
def test_valid_import_succeeds(self, mock_app, mock_db_session):
"""Test that valid import request succeeds."""
from services.annotation_service import AppAnnotationService
# Create valid CSV
csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n"
@ -300,18 +292,10 @@ class TestAnnotationImportServiceValidation:
class TestAnnotationImportTaskOptimization:
"""Test optimizations in batch import task."""
def test_task_has_timeout_configured(self):
"""Test that task has proper timeout configuration."""
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
# Verify task configuration
assert hasattr(batch_import_annotations_task, "time_limit")
assert hasattr(batch_import_annotations_task, "soft_time_limit")
# Check timeout values are reasonable
# Hard limit should be 6 minutes (360s)
# Soft limit should be 5 minutes (300s)
# Note: actual values depend on Celery configuration
def test_task_is_registered_with_queue(self):
"""Test that task is registered with the correct queue."""
assert hasattr(batch_import_annotations_task, "apply_async")
assert hasattr(batch_import_annotations_task, "delay")
class TestConfigurationValues:

View File

@ -0,0 +1,585 @@
"""
Additional tests to improve coverage for low-coverage modules in controllers/console/app.
Target: increase coverage for files with <75% coverage.
"""
from __future__ import annotations
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.app import (
annotation as annotation_module,
)
from controllers.console.app import (
completion as completion_module,
)
from controllers.console.app import (
message as message_module,
)
from controllers.console.app import (
ops_trace as ops_trace_module,
)
from controllers.console.app import (
site as site_module,
)
from controllers.console.app import (
statistic as statistic_module,
)
from controllers.console.app import (
workflow_app_log as workflow_app_log_module,
)
from controllers.console.app import (
workflow_draft_variable as workflow_draft_variable_module,
)
from controllers.console.app import (
workflow_statistic as workflow_statistic_module,
)
from controllers.console.app import (
workflow_trigger as workflow_trigger_module,
)
from controllers.console.app import (
wraps as wraps_module,
)
from controllers.console.app.completion import ChatMessagePayload, CompletionMessagePayload
from controllers.console.app.mcp_server import MCPServerCreatePayload, MCPServerUpdatePayload
from controllers.console.app.ops_trace import TraceConfigPayload, TraceProviderQuery
from controllers.console.app.site import AppSiteUpdatePayload
from controllers.console.app.workflow import AdvancedChatWorkflowRunPayload, SyncDraftWorkflowPayload
from controllers.console.app.workflow_app_log import WorkflowAppLogQuery
from controllers.console.app.workflow_draft_variable import WorkflowDraftVariableUpdatePayload
from controllers.console.app.workflow_statistic import WorkflowStatisticQuery
from controllers.console.app.workflow_trigger import Parser, ParserEnable
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _ConnContext:
def __init__(self, rows):
self._rows = rows
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _query, _args):
return self._rows
# ========== Completion Tests ==========
class TestCompletionEndpoints:
"""Tests for completion API endpoints."""
def test_completion_create_payload(self):
"""Test completion creation payload."""
payload = CompletionMessagePayload(inputs={"prompt": "test"}, model_config={})
assert payload.inputs == {"prompt": "test"}
def test_chat_message_payload_uuid_validation(self):
payload = ChatMessagePayload(
inputs={},
model_config={},
query="hi",
conversation_id=str(uuid.uuid4()),
parent_message_id=str(uuid.uuid4()),
)
assert payload.query == "hi"
def test_completion_api_success(self, app, monkeypatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
class DummyAccount:
pass
dummy_account = DummyAccount()
monkeypatch.setattr(completion_module, "current_user", dummy_account)
monkeypatch.setattr(completion_module, "Account", DummyAccount)
monkeypatch.setattr(
completion_module.AppGenerateService,
"generate",
lambda **_kwargs: {"text": "ok"},
)
monkeypatch.setattr(
completion_module.helper,
"compact_generate_response",
lambda response: {"result": response},
)
with app.test_request_context(
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
resp = method(app_model=MagicMock(id="app-1"))
assert resp == {"result": {"text": "ok"}}
def test_completion_api_conversation_not_exists(self, app, monkeypatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
class DummyAccount:
pass
dummy_account = DummyAccount()
monkeypatch.setattr(completion_module, "current_user", dummy_account)
monkeypatch.setattr(completion_module, "Account", DummyAccount)
monkeypatch.setattr(
completion_module.AppGenerateService,
"generate",
lambda **_kwargs: (_ for _ in ()).throw(
completion_module.services.errors.conversation.ConversationNotExistsError()
),
)
with app.test_request_context(
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
with pytest.raises(NotFound):
method(app_model=MagicMock(id="app-1"))
def test_completion_api_provider_not_initialized(self, app, monkeypatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
class DummyAccount:
pass
dummy_account = DummyAccount()
monkeypatch.setattr(completion_module, "current_user", dummy_account)
monkeypatch.setattr(completion_module, "Account", DummyAccount)
monkeypatch.setattr(
completion_module.AppGenerateService,
"generate",
lambda **_kwargs: (_ for _ in ()).throw(completion_module.ProviderTokenNotInitError("x")),
)
with app.test_request_context(
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
with pytest.raises(completion_module.ProviderNotInitializeError):
method(app_model=MagicMock(id="app-1"))
def test_completion_api_quota_exceeded(self, app, monkeypatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
class DummyAccount:
pass
dummy_account = DummyAccount()
monkeypatch.setattr(completion_module, "current_user", dummy_account)
monkeypatch.setattr(completion_module, "Account", DummyAccount)
monkeypatch.setattr(
completion_module.AppGenerateService,
"generate",
lambda **_kwargs: (_ for _ in ()).throw(completion_module.QuotaExceededError()),
)
with app.test_request_context(
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
with pytest.raises(completion_module.ProviderQuotaExceededError):
method(app_model=MagicMock(id="app-1"))
# ========== OpsTrace Tests ==========
class TestOpsTraceEndpoints:
"""Tests for ops_trace endpoint."""
def test_ops_trace_query_basic(self):
"""Test ops_trace query."""
query = TraceProviderQuery(tracing_provider="langfuse")
assert query.tracing_provider == "langfuse"
def test_ops_trace_config_payload(self):
payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"})
assert payload.tracing_config["api_key"] == "k"
def test_trace_app_config_get_empty(self, app, monkeypatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.get)
monkeypatch.setattr(
ops_trace_module.OpsService,
"get_tracing_app_config",
lambda **_kwargs: None,
)
with app.test_request_context("/?tracing_provider=langfuse"):
result = method(app_id="app-1")
assert result == {"has_not_configured": True}
def test_trace_app_config_post_invalid(self, app, monkeypatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.post)
monkeypatch.setattr(
ops_trace_module.OpsService,
"create_tracing_app_config",
lambda **_kwargs: {"error": True},
)
with app.test_request_context(
"/",
json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}},
):
with pytest.raises(BadRequest):
method(app_id="app-1")
def test_trace_app_config_delete_not_found(self, app, monkeypatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.delete)
monkeypatch.setattr(
ops_trace_module.OpsService,
"delete_tracing_app_config",
lambda **_kwargs: False,
)
with app.test_request_context("/?tracing_provider=langfuse"):
with pytest.raises(BadRequest):
method(app_id="app-1")
# ========== Site Tests ==========
class TestSiteEndpoints:
"""Tests for site endpoint."""
def test_site_response_structure(self):
"""Test site response structure."""
payload = AppSiteUpdatePayload(title="My Site", description="Test site")
assert payload.title == "My Site"
def test_site_default_language_validation(self):
payload = AppSiteUpdatePayload(default_language="en-US")
assert payload.default_language == "en-US"
def test_app_site_update_post(self, app, monkeypatch):
api = site_module.AppSite()
method = _unwrap(api.post)
site = MagicMock()
query = MagicMock()
query.where.return_value.first.return_value = site
monkeypatch.setattr(
site_module.db,
"session",
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
)
monkeypatch.setattr(
site_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
with app.test_request_context("/", json={"title": "My Site"}):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is site
def test_app_site_access_token_reset(self, app, monkeypatch):
api = site_module.AppSiteAccessTokenReset()
method = _unwrap(api.post)
site = MagicMock()
query = MagicMock()
query.where.return_value.first.return_value = site
monkeypatch.setattr(
site_module.db,
"session",
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
)
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
monkeypatch.setattr(
site_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
with app.test_request_context("/"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is site
# ========== Workflow Tests ==========
class TestWorkflowEndpoints:
"""Tests for workflow endpoints."""
def test_workflow_copy_payload(self):
"""Test workflow copy payload."""
payload = SyncDraftWorkflowPayload(graph={}, features={})
assert payload.graph == {}
def test_workflow_mode_query(self):
"""Test workflow mode query."""
payload = AdvancedChatWorkflowRunPayload(inputs={}, query="hi")
assert payload.query == "hi"
# ========== Workflow App Log Tests ==========
class TestWorkflowAppLogEndpoints:
"""Tests for workflow app log endpoints."""
def test_workflow_app_log_query(self):
"""Test workflow app log query."""
query = WorkflowAppLogQuery(keyword="test", page=1, limit=20)
assert query.keyword == "test"
def test_workflow_app_log_query_detail_bool(self):
query = WorkflowAppLogQuery(detail="true")
assert query.detail is True
def test_workflow_app_log_api_get(self, app, monkeypatch):
api = workflow_app_log_module.WorkflowAppLogApi()
method = _unwrap(api.get)
monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock()))
class DummySession:
def __enter__(self):
return "session"
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(workflow_app_log_module, "Session", lambda *args, **kwargs: DummySession())
def fake_get_paginate(self, **_kwargs):
return {"items": [], "total": 0}
monkeypatch.setattr(
workflow_app_log_module.WorkflowAppService,
"get_paginate_workflow_app_logs",
fake_get_paginate,
)
with app.test_request_context("/?page=1&limit=20"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result == {"items": [], "total": 0}
# ========== Workflow Draft Variable Tests ==========
class TestWorkflowDraftVariableEndpoints:
"""Tests for workflow draft variable endpoints."""
def test_workflow_variable_creation(self):
"""Test workflow variable creation."""
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
assert payload.name == "var1"
def test_workflow_variable_collection_get(self, app, monkeypatch):
api = workflow_draft_variable_module.WorkflowVariableCollectionApi()
method = _unwrap(api.get)
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
class DummySession:
def __enter__(self):
return "session"
def __exit__(self, exc_type, exc, tb):
return False
class DummyDraftService:
def __init__(self, session):
self.session = session
def list_variables_without_values(self, **_kwargs):
return {"items": [], "total": 0}
monkeypatch.setattr(workflow_draft_variable_module, "Session", lambda *args, **kwargs: DummySession())
class DummyWorkflowService:
def is_workflow_exist(self, *args, **kwargs):
return True
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowDraftVariableService", DummyDraftService)
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService)
with app.test_request_context("/?page=1&limit=20"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result == {"items": [], "total": 0}
# ========== Workflow Statistic Tests ==========
class TestWorkflowStatisticEndpoints:
"""Tests for workflow statistic endpoints."""
def test_workflow_statistic_time_range(self):
"""Test workflow statistic time range query."""
query = WorkflowStatisticQuery(start="2024-01-01", end="2024-12-31")
assert query.start == "2024-01-01"
def test_workflow_statistic_blank_to_none(self):
query = WorkflowStatisticQuery(start="", end="")
assert query.start is None
assert query.end is None
def test_workflow_daily_runs_statistic(self, app, monkeypatch):
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(
workflow_statistic_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: SimpleNamespace(get_daily_runs_statistics=lambda **_kw: [{"date": "2024-01-01"}]),
)
monkeypatch.setattr(
workflow_statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
workflow_statistic_module,
"parse_time_range",
lambda *_args, **_kwargs: (None, None),
)
api = workflow_statistic_module.WorkflowDailyRunsStatistic()
method = _unwrap(api.get)
with app.test_request_context("/"):
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
def test_workflow_daily_terminals_statistic(self, app, monkeypatch):
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(
workflow_statistic_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: SimpleNamespace(
get_daily_terminals_statistics=lambda **_kw: [{"date": "2024-01-02"}]
),
)
monkeypatch.setattr(
workflow_statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
workflow_statistic_module,
"parse_time_range",
lambda *_args, **_kwargs: (None, None),
)
api = workflow_statistic_module.WorkflowDailyTerminalsStatistic()
method = _unwrap(api.get)
with app.test_request_context("/"):
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
# ========== Workflow Trigger Tests ==========
class TestWorkflowTriggerEndpoints:
"""Tests for workflow trigger endpoints."""
def test_webhook_trigger_payload(self):
"""Test webhook trigger payload."""
payload = Parser(node_id="node-1")
assert payload.node_id == "node-1"
enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True)
assert enable_payload.enable_trigger is True
def test_webhook_trigger_api_get(self, app, monkeypatch):
api = workflow_trigger_module.WebhookTriggerApi()
method = _unwrap(api.get)
monkeypatch.setattr(workflow_trigger_module, "db", SimpleNamespace(engine=MagicMock()))
trigger = MagicMock()
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = trigger
class DummySession:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(workflow_trigger_module, "Session", lambda *_args, **_kwargs: DummySession())
with app.test_request_context("/?node_id=node-1"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is trigger
# ========== Wraps Tests ==========
class TestWrapsEndpoints:
"""Tests for wraps utility functions."""
def test_get_app_model_context(self):
"""Test get_app_model wrapper context."""
# These are decorator functions, so we test their availability
assert hasattr(wraps_module, "get_app_model")
# ========== MCP Server Tests ==========
class TestMCPServerEndpoints:
"""Tests for MCP server endpoints."""
def test_mcp_server_connection(self):
"""Test MCP server connection."""
payload = MCPServerCreatePayload(parameters={"url": "http://localhost:3000"})
assert payload.parameters["url"] == "http://localhost:3000"
def test_mcp_server_update_payload(self):
payload = MCPServerUpdatePayload(id="server-1", parameters={"timeout": 30}, status="active")
assert payload.status == "active"
# ========== Error Handling Tests ==========
class TestErrorHandling:
"""Tests for error handling in various endpoints."""
def test_annotation_list_query_validation(self):
"""Test annotation list query validation."""
with pytest.raises(ValueError):
annotation_module.AnnotationListQuery(page=0)
# ========== Integration-like Tests ==========
class TestPayloadIntegration:
"""Integration tests for payload handling."""
def test_multiple_payload_types(self):
"""Test handling of multiple payload types."""
payloads = [
annotation_module.AnnotationReplyPayload(
score_threshold=0.5, embedding_provider_name="openai", embedding_model_name="text-embedding-3-small"
),
message_module.MessageFeedbackPayload(message_id=str(uuid.uuid4()), rating="like"),
statistic_module.StatisticTimeRangeQuery(start="2024-01-01"),
]
assert len(payloads) == 3
assert all(p is not None for p in payloads)

View File

@ -0,0 +1,157 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from controllers.console.app import app_import as app_import_module
from services.app_dsl_service import ImportStatus
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _Result:
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
self.status = status
self.app_id = app_id
def model_dump(self, mode: str = "json"):
return {"status": self.status, "app_id": self.app_id}
class _SessionContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return False
def _install_session(monkeypatch: pytest.MonkeyPatch, session: MagicMock) -> None:
monkeypatch.setattr(app_import_module, "Session", lambda *_: _SessionContext(session))
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
def test_import_post_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once()
assert status == 400
assert response["status"] == ImportStatus.FAILED
def test_import_post_returns_pending_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once()
assert status == 202
assert response["status"] == ImportStatus.PENDING
def test_import_post_updates_webapp_auth_when_enabled(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
_install_features(monkeypatch, enabled=True)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
)
update_access = MagicMock()
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once()
update_access.assert_called_once_with("app-123", "private")
assert status == 200
assert response["status"] == ImportStatus.COMPLETED
def test_import_confirm_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportConfirmApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
monkeypatch.setattr(
app_import_module.AppDslService,
"confirm_import",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
response, status = method(import_id="import-1")
session.commit.assert_called_once()
assert status == 400
assert response["status"] == ImportStatus.FAILED
def test_import_check_dependencies_returns_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportCheckDependenciesApi()
method = _unwrap(api.get)
session = MagicMock()
_install_session(monkeypatch, session)
monkeypatch.setattr(
app_import_module.AppDslService,
"check_dependencies",
lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}),
)
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
response, status = method(app_model=SimpleNamespace(id="app-1"))
assert status == 200
assert response["leaked_dependencies"] == []

View File

@ -0,0 +1,292 @@
from __future__ import annotations
import io
from types import SimpleNamespace
import pytest
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import InternalServerError
from controllers.console.app.audio import ChatMessageAudioApi, ChatMessageTextApi, TextModesApi
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from dify_graph.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
ProviderNotSupportTextToSpeechLanageServiceError,
UnsupportedAudioTypeServiceError,
)
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _file_data():
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
api = ChatMessageAudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
response = handler(app_model=app_model)
assert response == {"text": "ok"}
@pytest.mark.parametrize(
("exc", "expected"),
[
(AppModelConfigBrokenError(), AppUnavailableError),
(NoAudioUploadedServiceError(), NoAudioUploadedError),
(AudioTooLargeServiceError("too big"), AudioTooLargeError),
(UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError),
(ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError),
(ProviderTokenNotInitError("token"), ProviderNotInitializeError),
(QuotaExceededError(), ProviderQuotaExceededError),
(ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError),
(InvokeError("invoke"), CompletionRequestError),
],
)
def test_console_audio_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
api = ChatMessageAudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
with pytest.raises(expected):
handler(app_model=app_model)
def test_console_audio_api_unhandled_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
api = ChatMessageAudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
with pytest.raises(InternalServerError):
handler(app_model=app_model)
def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
api = ChatMessageTextApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context(
"/console/api/apps/app/text-to-audio",
method="POST",
json={"text": "hello", "voice": "v"},
):
response = handler(app_model=app_model)
assert response == {"audio": "ok"}
def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError()))
api = ChatMessageTextApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context(
"/console/api/apps/app/text-to-audio",
method="POST",
json={"text": "hello"},
):
with pytest.raises(ProviderQuotaExceededError):
handler(app_model=app_model)
def test_console_text_modes_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
api = TextModesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(tenant_id="t1")
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
response = handler(app_model=app_model)
assert response == ["voice-1"]
def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService,
"transcript_tts_voices",
lambda **_kwargs: (_ for _ in ()).throw(ProviderNotSupportTextToSpeechLanageServiceError()),
)
api = TextModesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(tenant_id="t1")
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
with pytest.raises(AppUnavailableError):
handler(app_model=app_model)
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageAudioApi()
method = _unwrap(api.post)
response_payload = {"text": "hello"}
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: response_payload)
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
response = method(app_model=app_model)
assert response == response_payload
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageAudioApi()
method = _unwrap(api.post)
monkeypatch.setattr(
AudioService,
"transcript_asr",
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
)
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
with pytest.raises(AudioTooLargeError):
method(app_model=app_model)
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageTextApi()
method = _unwrap(api.post)
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
app_model = SimpleNamespace(id="app-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio",
method="POST",
json={"text": "hello"},
):
response = method(app_model=app_model)
assert response == {"audio": "ok"}
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = TextModesApi()
method = _unwrap(api.get)
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
app_model = SimpleNamespace(tenant_id="tenant-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio/voices",
method="GET",
query_string={"language": "en-US"},
):
response = method(app_model=app_model)
assert response == ["voice-1"]
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageAudioApi()
method = _unwrap(api.post)
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
# Should not raise, AudioService is mocked
response = method(app_model=app_model)
assert response == {"text": "test"}
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageTextApi()
method = _unwrap(api.post)
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
app_model = SimpleNamespace(id="app-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio",
method="POST",
json={"text": "hello", "language": "en-US"},
):
response = method(app_model=app_model)
assert response == {"audio": "test"}
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = TextModesApi()
method = _unwrap(api.get)
monkeypatch.setattr(
AudioService,
"transcript_tts_voices",
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
)
app_model = SimpleNamespace(tenant_id="tenant-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
method="GET",
):
response = method(app_model=app_model)
assert isinstance(response, list)

View File

@ -0,0 +1,156 @@
from __future__ import annotations
import io
from types import SimpleNamespace
import pytest
from controllers.console.app import audio as audio_module
from controllers.console.app.error import AudioTooLargeError
from services.errors.audio import AudioTooLargeServiceError
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageAudioApi()
method = _unwrap(api.post)
response_payload = {"text": "hello"}
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: response_payload)
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
response = method(app_model=app_model)
assert response == response_payload
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageAudioApi()
method = _unwrap(api.post)
monkeypatch.setattr(
audio_module.AudioService,
"transcript_asr",
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
)
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
with pytest.raises(AudioTooLargeError):
method(app_model=app_model)
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageTextApi()
method = _unwrap(api.post)
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
app_model = SimpleNamespace(id="app-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio",
method="POST",
json={"text": "hello"},
):
response = method(app_model=app_model)
assert response == {"audio": "ok"}
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.TextModesApi()
method = _unwrap(api.get)
monkeypatch.setattr(audio_module.AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
app_model = SimpleNamespace(tenant_id="tenant-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio/voices",
method="GET",
query_string={"language": "en-US"},
):
response = method(app_model=app_model)
assert response == ["voice-1"]
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageAudioApi()
method = _unwrap(api.post)
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
# Should not raise, AudioService is mocked
response = method(app_model=app_model)
assert response == {"text": "test"}
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageTextApi()
method = _unwrap(api.post)
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
app_model = SimpleNamespace(id="app-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio",
method="POST",
json={"text": "hello", "language": "en-US"},
):
response = method(app_model=app_model)
assert response == {"audio": "test"}
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.TextModesApi()
method = _unwrap(api.get)
monkeypatch.setattr(
audio_module.AudioService,
"transcript_tts_voices",
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
)
app_model = SimpleNamespace(tenant_id="tenant-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
method="GET",
):
response = method(app_model=app_model)
assert isinstance(response, list)

View File

@ -0,0 +1,130 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.app import conversation as conversation_module
from models.model import AppMode
from services.errors.conversation import ConversationNotExistsError
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _make_account():
return SimpleNamespace(timezone="UTC", id="u1")
def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationApi()
method = _unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
paginate_result = MagicMock()
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response is paginate_result
def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationApi()
method = _unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
monkeypatch.setattr(
conversation_module,
"parse_time_range",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad range")),
)
with app.test_request_context(
"/console/api/apps/app-1/completion-conversations",
method="GET",
query_string={"start": "bad"},
):
with pytest.raises(BadRequest):
method(app_model=SimpleNamespace(id="app-1"))
def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.ChatConversationApi()
method = _unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
paginate_result = MagicMock()
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
assert response is paginate_result
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
conversation = SimpleNamespace(id="c1", app_id="app-1")
query = MagicMock()
query.where.return_value = query
query.first.return_value = conversation
session = MagicMock()
session.query.return_value = query
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(conversation_module.db, "session", session)
result = conversation_module._get_conversation(SimpleNamespace(id="app-1"), "c1")
assert result is conversation
session.execute.assert_called_once()
session.commit.assert_called_once()
session.refresh.assert_called_once_with(conversation)
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
session = MagicMock()
session.query.return_value = query
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(conversation_module.db, "session", session)
with pytest.raises(NotFound):
conversation_module._get_conversation(SimpleNamespace(id="app-1"), "missing")
def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationDetailApi()
method = _unwrap(api.delete)
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(
conversation_module.ConversationService,
"delete",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
with pytest.raises(NotFound):
method(app_model=SimpleNamespace(id="app-1"), conversation_id="c1")

View File

@ -0,0 +1,260 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
from controllers.console.app import generator as generator_module
from controllers.console.app.error import ProviderNotInitializeError
from core.errors.error import ProviderTokenNotInitError
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _model_config_payload():
return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow):
class _Service:
def get_draft_workflow(self, app_model):
return workflow
monkeypatch.setattr(generator_module, "WorkflowService", lambda: _Service())
def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.RuleGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []})
with app.test_request_context(
"/console/api/rule-generate",
method="POST",
json={"instruction": "do it", "model_config": _model_config_payload()},
):
response = method()
assert response == {"rules": []}
def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.RuleCodeGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
def _raise(*_args, **_kwargs):
raise ProviderTokenNotInitError("missing token")
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", _raise)
with app.test_request_context(
"/console/api/rule-code-generate",
method="POST",
json={"instruction": "do it", "model_config": _model_config_payload()},
):
with pytest.raises(ProviderNotInitializeError):
method()
def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "node-1",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response, status = method()
assert status == 400
assert response["error"] == "app app-1 not found"
def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
_install_workflow_service(monkeypatch, workflow=None)
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "node-1",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response, status = method()
assert status == 400
assert response["error"] == "workflow app-1 not found"
def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
workflow = SimpleNamespace(graph_dict={"nodes": []})
_install_workflow_service(monkeypatch, workflow=workflow)
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "node-1",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response, status = method()
assert status == 400
assert response["error"] == "node node-1 not found"
def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
workflow = SimpleNamespace(
graph_dict={
"nodes": [
{"id": "node-1", "data": {"type": "code"}},
]
}
)
_install_workflow_service(monkeypatch, workflow=workflow)
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", lambda **_kwargs: {"code": "x"})
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "node-1",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response = method()
assert response == {"code": "x"}
def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(
generator_module.LLMGenerator,
"instruction_modify_legacy",
lambda **_kwargs: {"instruction": "ok"},
)
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "",
"current": "old",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response = method()
assert response == {"instruction": "ok"}
def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "",
"current": "",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response, status = method()
assert status == 400
assert response["error"] == "incompatible parameters"
def test_instruction_template_prompt(app) -> None:
api = generator_module.InstructionGenerationTemplateApi()
method = _unwrap(api.post)
with app.test_request_context(
"/console/api/instruction-generate/template",
method="POST",
json={"type": "prompt"},
):
response = method()
assert "data" in response
def test_instruction_template_invalid_type(app) -> None:
api = generator_module.InstructionGenerationTemplateApi()
method = _unwrap(api.post)
with app.test_request_context(
"/console/api/instruction-generate/template",
method="POST",
json={"type": "unknown"},
):
with pytest.raises(ValueError):
method()

View File

@ -0,0 +1,122 @@
from __future__ import annotations
import pytest
from controllers.console.app import message as message_module
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test valid ChatMessagesQuery with all fields."""
query = message_module.ChatMessagesQuery(
conversation_id="550e8400-e29b-41d4-a716-446655440000",
first_id="550e8400-e29b-41d4-a716-446655440001",
limit=50,
)
assert query.limit == 50
def test_chat_messages_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test ChatMessagesQuery with defaults."""
query = message_module.ChatMessagesQuery(conversation_id="550e8400-e29b-41d4-a716-446655440000")
assert query.first_id is None
assert query.limit == 20
def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test ChatMessagesQuery converts empty first_id to None."""
query = message_module.ChatMessagesQuery(
conversation_id="550e8400-e29b-41d4-a716-446655440000",
first_id="",
)
assert query.first_id is None
def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageFeedbackPayload with like rating."""
payload = message_module.MessageFeedbackPayload(
message_id="550e8400-e29b-41d4-a716-446655440000",
rating="like",
content="Good answer",
)
assert payload.rating == "like"
assert payload.content == "Good answer"
def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageFeedbackPayload with dislike rating."""
payload = message_module.MessageFeedbackPayload(
message_id="550e8400-e29b-41d4-a716-446655440000",
rating="dislike",
)
assert payload.rating == "dislike"
def test_message_feedback_payload_no_rating(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageFeedbackPayload without rating."""
payload = message_module.MessageFeedbackPayload(message_id="550e8400-e29b-41d4-a716-446655440000")
assert payload.rating is None
def test_feedback_export_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with default format."""
query = message_module.FeedbackExportQuery()
assert query.format == "csv"
assert query.from_source is None
def test_feedback_export_query_json_format(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with JSON format."""
query = message_module.FeedbackExportQuery(format="json")
assert query.format == "json"
def test_feedback_export_query_has_comment_true(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as true string."""
query = message_module.FeedbackExportQuery(has_comment="true")
assert query.has_comment is True
def test_feedback_export_query_has_comment_false(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as false string."""
query = message_module.FeedbackExportQuery(has_comment="false")
assert query.has_comment is False
def test_feedback_export_query_has_comment_1(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as 1."""
query = message_module.FeedbackExportQuery(has_comment="1")
assert query.has_comment is True
def test_feedback_export_query_has_comment_0(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as 0."""
query = message_module.FeedbackExportQuery(has_comment="0")
assert query.has_comment is False
def test_feedback_export_query_rating_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with rating filter."""
query = message_module.FeedbackExportQuery(rating="like")
assert query.rating == "like"
def test_annotation_count_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test AnnotationCountResponse creation."""
response = message_module.AnnotationCountResponse(count=10)
assert response.count == 10
def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test SuggestedQuestionsResponse creation."""
response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
assert len(response.data) == 2
assert response.data[0] == "What is AI?"

View File

@ -0,0 +1,151 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from controllers.console.app import model_config as model_config_module
from models.model import AppMode, AppModelConfig
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = model_config_module.ModelConfigResource()
method = _unwrap(api.post)
app_model = SimpleNamespace(
id="app-1",
mode=AppMode.CHAT.value,
is_agent=False,
app_model_config_id=None,
updated_by=None,
updated_at=None,
)
monkeypatch.setattr(
model_config_module.AppModelConfigService,
"validate_configuration",
lambda **_kwargs: {"pre_prompt": "hi"},
)
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
session = MagicMock()
monkeypatch.setattr(model_config_module.db, "session", session)
def _from_model_config_dict(self, model_config):
self.pre_prompt = model_config["pre_prompt"]
self.id = "config-1"
return self
monkeypatch.setattr(AppModelConfig, "from_model_config_dict", _from_model_config_dict)
send_mock = MagicMock()
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
response = method(app_model=app_model)
session.add.assert_called_once()
session.flush.assert_called_once()
session.commit.assert_called_once()
send_mock.assert_called_once()
assert app_model.app_model_config_id == "config-1"
assert response["result"] == "success"
def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = model_config_module.ModelConfigResource()
method = _unwrap(api.post)
app_model = SimpleNamespace(
id="app-1",
mode=AppMode.AGENT_CHAT.value,
is_agent=True,
app_model_config_id="config-0",
updated_by=None,
updated_at=None,
)
original_config = AppModelConfig(app_id="app-1", created_by="u1", updated_by="u1")
original_config.agent_mode = json.dumps(
{
"enabled": True,
"strategy": "function-calling",
"tools": [
{
"provider_id": "provider",
"provider_type": "builtin",
"tool_name": "tool",
"tool_parameters": {"secret": "masked"},
}
],
"prompt": None,
}
)
session = MagicMock()
query = MagicMock()
query.where.return_value = query
query.first.return_value = original_config
session.query.return_value = query
monkeypatch.setattr(model_config_module.db, "session", session)
monkeypatch.setattr(
model_config_module.AppModelConfigService,
"validate_configuration",
lambda **_kwargs: {
"pre_prompt": "hi",
"agent_mode": {
"enabled": True,
"strategy": "function-calling",
"tools": [
{
"provider_id": "provider",
"provider_type": "builtin",
"tool_name": "tool",
"tool_parameters": {"secret": "masked"},
}
],
"prompt": None,
},
},
)
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
monkeypatch.setattr(model_config_module.ToolManager, "get_agent_tool_runtime", lambda **_kwargs: object())
class _ParamManager:
def __init__(self, **_kwargs):
self.delete_called = False
def decrypt_tool_parameters(self, _value):
return {"secret": "decrypted"}
def mask_tool_parameters(self, _value):
return {"secret": "masked"}
def encrypt_tool_parameters(self, _value):
return {"secret": "encrypted"}
def delete_tool_parameters_cache(self):
self.delete_called = True
monkeypatch.setattr(model_config_module, "ToolParameterConfigurationManager", _ParamManager)
send_mock = MagicMock()
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
response = method(app_model=app_model)
stored_config = session.add.call_args[0][0]
stored_agent_mode = json.loads(stored_config.agent_mode)
assert stored_agent_mode["tools"][0]["tool_parameters"]["secret"] == "encrypted"
assert response["result"] == "success"

View File

@ -0,0 +1,215 @@
from __future__ import annotations
from decimal import Decimal
from types import SimpleNamespace
import pytest
from werkzeug.exceptions import BadRequest
from controllers.console.app import statistic as statistic_module
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _ConnContext:
def __init__(self, rows):
self._rows = rows
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _query, _args):
return self._rows
def _install_db(monkeypatch: pytest.MonkeyPatch, rows) -> None:
engine = SimpleNamespace(begin=lambda: _ConnContext(rows))
monkeypatch.setattr(statistic_module, "db", SimpleNamespace(engine=engine))
def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
statistic_module,
"parse_time_range",
lambda *_args, **_kwargs: (None, None),
)
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-01", message_count=3)]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyConversationStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTokenCostStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-03", token_count=10, total_price=0.25, currency="USD")]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 1
assert data["data"][0]["date"] == "2024-01-03"
assert data["data"][0]["token_count"] == 10
assert data["data"][0]["total_price"] == 0.25
def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTerminalsStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-04", terminal_count=7)]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that AverageSessionInteractionStatistic is limited to chat/agent modes."""
# This just verifies the decorator is applied correctly
# Actual endpoint testing would require complex JOIN mocking
api = statistic_module.AverageSessionInteractionStatistic()
method = _unwrap(api.get)
assert callable(method)
def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
def mock_parse(*args, **kwargs):
raise ValueError("Invalid time range")
_install_db(monkeypatch, [])
monkeypatch.setattr(
statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(statistic_module, "parse_time_range", mock_parse)
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
with pytest.raises(BadRequest):
method(app_model=SimpleNamespace(id="app-1"))
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
rows = [
SimpleNamespace(date="2024-01-01", message_count=10),
SimpleNamespace(date="2024-01-02", message_count=15),
SimpleNamespace(date="2024-01-03", message_count=12),
]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 3
def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
_install_common(monkeypatch)
_install_db(monkeypatch, [])
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": []}
def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyConversationStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
_install_db(monkeypatch, rows)
monkeypatch.setattr(
statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
statistic_module,
"parse_time_range",
lambda *_args, **_kwargs: ("s", "e"),
)
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTokenCostStatistic()
method = _unwrap(api.get)
rows = [
SimpleNamespace(date="2024-01-01", token_count=100, total_price=Decimal("0.50"), currency="USD"),
SimpleNamespace(date="2024-01-02", token_count=200, total_price=Decimal("1.00"), currency="USD"),
]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 2

View File

@ -0,0 +1,163 @@
from __future__ import annotations
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from werkzeug.exceptions import HTTPException, NotFound
from controllers.console.app import workflow as workflow_module
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.file.models import File
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def test_parse_file_no_config(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: None)
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
assert workflow_module._parse_file(workflow, files=[{"id": "f"}]) == []
def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
config = object()
file_list = [
File(
tenant_id="t1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="http://u",
)
]
build_mock = Mock(return_value=file_list)
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: config)
monkeypatch.setattr(workflow_module.file_factory, "build_from_mappings", build_mock)
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
result = workflow_module._parse_file(workflow, files=[{"id": "f"}])
assert result == file_list
build_mock.assert_called_once()
def test_sync_draft_workflow_invalid_content_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
with app.test_request_context("/apps/app/workflows/draft", method="POST", data="x", content_type="text/html"):
with pytest.raises(HTTPException) as exc:
handler(api, app_model=SimpleNamespace(id="app"))
assert exc.value.code == 415
def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
with app.test_request_context(
"/apps/app/workflows/draft",
method="POST",
data="[]",
content_type="application/json",
):
response, status = handler(api, app_model=SimpleNamespace(id="app"))
assert status == 400
assert response["message"] == "Invalid JSON data"
def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow = SimpleNamespace(
unique_hash="h",
updated_at=None,
created_at=datetime(2024, 1, 1),
)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
monkeypatch.setattr(
workflow_module.variable_factory, "build_environment_variable_from_mapping", lambda *_args: "env"
)
monkeypatch.setattr(
workflow_module.variable_factory, "build_conversation_variable_from_mapping", lambda *_args: "conv"
)
service = SimpleNamespace(sync_draft_workflow=lambda **_kwargs: workflow)
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/draft",
method="POST",
json={"graph": {}, "features": {}, "hash": "h"},
):
response = handler(api, app_model=SimpleNamespace(id="app"))
assert response["result"] == "success"
def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
def _raise(*_args, **_kwargs):
raise workflow_module.WorkflowHashNotEqualError()
service = SimpleNamespace(sync_draft_workflow=_raise)
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/draft",
method="POST",
json={"graph": {}, "features": {}, "hash": "h"},
):
with pytest.raises(DraftWorkflowNotSync):
handler(api, app_model=SimpleNamespace(id="app"))
def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)
)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.get)
with pytest.raises(DraftWorkflowNotExist):
handler(api, app_model=SimpleNamespace(id="app"))
def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
workflow_module.AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(
workflow_module.services.errors.conversation.ConversationNotExistsError()
),
)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
api = workflow_module.AdvancedChatDraftWorkflowRunApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/advanced-chat/workflows/draft/run",
method="POST",
json={"inputs": {}},
):
with pytest.raises(NotFound):
handler(api, app_model=SimpleNamespace(id="app"))

View File

@ -0,0 +1,47 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
from controllers.console.app import wraps as wraps_module
from controllers.console.app.error import AppNotFoundError
from models.model import AppMode
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
@wraps_module.get_app_model
def handler(app_model):
return app_model.id
assert handler(app_id="app-1") == "app-1"
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
def handler(app_model):
return app_model.id
with pytest.raises(AppNotFoundError):
handler(app_id="app-1")
def test_get_app_model_requires_app_id() -> None:
@wraps_module.get_app_model
def handler(app_model):
return app_model.id
with pytest.raises(ValueError):
handler()

View File

@ -0,0 +1,402 @@
from io import BytesIO
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import InternalServerError
import controllers.console.explore.audio as audio_module
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from dify_graph.model_runtime.errors.invoke import InvokeError
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
)
def unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
@pytest.fixture
def installed_app():
app = MagicMock()
app.app = MagicMock()
return app
@pytest.fixture
def audio_file():
return (BytesIO(b"audio"), "audio.wav")
class TestChatAudioApi:
def setup_method(self):
self.api = audio_module.ChatAudioApi()
self.method = unwrap(self.api.post)
def test_post_success(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
return_value={"text": "ok"},
),
):
resp = self.method(installed_app)
assert resp == {"text": "ok"}
def test_app_unavailable(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=audio_module.services.errors.app_model_config.AppModelConfigBrokenError(),
),
):
with pytest.raises(AppUnavailableError):
self.method(installed_app)
def test_no_audio_uploaded(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=NoAudioUploadedServiceError(),
),
):
with pytest.raises(NoAudioUploadedError):
self.method(installed_app)
def test_audio_too_large(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=AudioTooLargeServiceError("too big"),
),
):
with pytest.raises(AudioTooLargeError):
self.method(installed_app)
def test_provider_quota_exceeded(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=QuotaExceededError(),
),
):
with pytest.raises(ProviderQuotaExceededError):
self.method(installed_app)
def test_unknown_exception(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=Exception("boom"),
),
):
with pytest.raises(InternalServerError):
self.method(installed_app)
def test_unsupported_audio_type(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=audio_module.UnsupportedAudioTypeServiceError(),
),
):
with pytest.raises(audio_module.UnsupportedAudioTypeError):
self.method(installed_app)
def test_provider_not_support_speech_to_text(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=audio_module.ProviderNotSupportSpeechToTextServiceError(),
),
):
with pytest.raises(audio_module.ProviderNotSupportSpeechToTextError):
self.method(installed_app)
def test_provider_not_initialized(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=ProviderTokenNotInitError("not init"),
),
):
with pytest.raises(ProviderNotInitializeError):
self.method(installed_app)
def test_model_currently_not_supported(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=ModelCurrentlyNotSupportError(),
),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
self.method(installed_app)
def test_invoke_error_asr(self, app, installed_app, audio_file):
with (
app.test_request_context(
"/",
data={"file": audio_file},
content_type="multipart/form-data",
),
patch.object(
audio_module.AudioService,
"transcript_asr",
side_effect=InvokeError("invoke failed"),
),
):
with pytest.raises(CompletionRequestError):
self.method(installed_app)
class TestChatTextApi:
def setup_method(self):
self.api = audio_module.ChatTextApi()
self.method = unwrap(self.api.post)
def test_post_success(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"message_id": "m1", "text": "hello", "voice": "v1"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
return_value={"audio": "ok"},
),
):
resp = self.method(installed_app)
assert resp == {"audio": "ok"}
def test_provider_not_initialized(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=ProviderTokenNotInitError("not init"),
),
):
with pytest.raises(ProviderNotInitializeError):
self.method(installed_app)
def test_model_not_supported(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=ModelCurrentlyNotSupportError(),
),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
self.method(installed_app)
def test_invoke_error(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=InvokeError("invoke failed"),
),
):
with pytest.raises(CompletionRequestError):
self.method(installed_app)
def test_unknown_exception(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=Exception("boom"),
),
):
with pytest.raises(InternalServerError):
self.method(installed_app)
def test_app_unavailable_tts(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=audio_module.services.errors.app_model_config.AppModelConfigBrokenError(),
),
):
with pytest.raises(AppUnavailableError):
self.method(installed_app)
def test_no_audio_uploaded_tts(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=NoAudioUploadedServiceError(),
),
):
with pytest.raises(NoAudioUploadedError):
self.method(installed_app)
def test_audio_too_large_tts(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=AudioTooLargeServiceError("too big"),
),
):
with pytest.raises(AudioTooLargeError):
self.method(installed_app)
def test_unsupported_audio_type_tts(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=audio_module.UnsupportedAudioTypeServiceError(),
),
):
with pytest.raises(audio_module.UnsupportedAudioTypeError):
self.method(installed_app)
def test_provider_not_support_speech_to_text_tts(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=audio_module.ProviderNotSupportSpeechToTextServiceError(),
),
):
with pytest.raises(audio_module.ProviderNotSupportSpeechToTextError):
self.method(installed_app)
def test_quota_exceeded_tts(self, app, installed_app):
with (
app.test_request_context(
"/",
json={"text": "hi"},
),
patch.object(
audio_module.AudioService,
"transcript_tts",
side_effect=QuotaExceededError(),
),
):
with pytest.raises(ProviderQuotaExceededError):
self.method(installed_app)

View File

@ -0,0 +1,100 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
import controllers.console.explore.banner as banner_module
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestBannerApi:
def test_get_banners_with_requested_language(self, app):
api = banner_module.BannerApi()
method = unwrap(api.get)
banner = MagicMock()
banner.id = "b1"
banner.content = {"text": "hello"}
banner.link = "https://example.com"
banner.sort = 1
banner.status = "enabled"
banner.created_at = datetime(2024, 1, 1)
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.return_value = [banner]
session = MagicMock()
session.query.return_value = query
with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session):
result = method(api)
assert result == [
{
"id": "b1",
"content": {"text": "hello"},
"link": "https://example.com",
"sort": 1,
"status": "enabled",
"created_at": "2024-01-01T00:00:00",
}
]
def test_get_banners_fallback_to_en_us(self, app):
api = banner_module.BannerApi()
method = unwrap(api.get)
banner = MagicMock()
banner.id = "b2"
banner.content = {"text": "fallback"}
banner.link = None
banner.sort = 1
banner.status = "enabled"
banner.created_at = None
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.side_effect = [
[],
[banner],
]
session = MagicMock()
session.query.return_value = query
with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session):
result = method(api)
assert result == [
{
"id": "b2",
"content": {"text": "fallback"},
"link": None,
"sort": 1,
"status": "enabled",
"created_at": None,
}
]
def test_get_banners_default_language_en_us(self, app):
api = banner_module.BannerApi()
method = unwrap(api.get)
query = MagicMock()
query.where.return_value = query
query.order_by.return_value = query
query.all.return_value = []
session = MagicMock()
session.query.return_value = query
with app.test_request_context("/"), patch.object(banner_module.db, "session", session):
result = method(api)
assert result == []

View File

@ -0,0 +1,459 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from werkzeug.exceptions import InternalServerError
import controllers.console.explore.completion as completion_module
from controllers.console.app.error import (
ConversationCompletedError,
)
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from models import Account
from models.model import AppMode
from services.errors.llm import InvokeRateLimitError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def user():
return MagicMock(spec=Account)
@pytest.fixture
def completion_app():
return MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
@pytest.fixture
def chat_app():
return MagicMock(app=MagicMock(mode=AppMode.CHAT))
@pytest.fixture
def payload_data():
return {"inputs": {}, "query": "hi"}
@pytest.fixture
def payload_patch(payload_data):
return patch.object(
type(completion_module.console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload_data,
)
class TestCompletionApi:
def test_post_success(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
return_value={"ok": True},
),
patch.object(
completion_module.helper,
"compact_generate_response",
return_value=("ok", 200),
),
):
result = method(completion_app)
assert result == ("ok", 200)
def test_post_wrong_app_mode(self):
api = completion_module.CompletionApi()
method = unwrap(api.post)
installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT))
with pytest.raises(NotCompletionAppError):
method(installed_app)
def test_conversation_completed(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.services.errors.conversation.ConversationCompletedError(),
),
):
with pytest.raises(ConversationCompletedError):
method(completion_app)
def test_internal_error(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=Exception("boom"),
),
):
with pytest.raises(InternalServerError):
method(completion_app)
def test_conversation_not_exists(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.services.errors.conversation.ConversationNotExistsError(),
),
):
with pytest.raises(completion_module.NotFound):
method(completion_app)
def test_app_unavailable(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.services.errors.app_model_config.AppModelConfigBrokenError(),
),
):
with pytest.raises(completion_module.AppUnavailableError):
method(completion_app)
def test_provider_not_initialized(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.ProviderTokenNotInitError("not init"),
),
):
with pytest.raises(completion_module.ProviderNotInitializeError):
method(completion_app)
def test_quota_exceeded(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.QuotaExceededError(),
),
):
with pytest.raises(completion_module.ProviderQuotaExceededError):
method(completion_app)
def test_model_not_supported(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.ModelCurrentlyNotSupportError(),
),
):
with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError):
method(completion_app)
def test_invoke_error(self, app, completion_app, user, payload_patch):
api = completion_module.CompletionApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.InvokeError("invoke failed"),
),
):
with pytest.raises(completion_module.CompletionRequestError):
method(completion_app)
class TestCompletionStopApi:
def test_stop_success(self, completion_app, user):
api = completion_module.CompletionStopApi()
method = unwrap(api.post)
user.id = "u1"
with (
patch.object(completion_module, "current_user", user),
patch.object(completion_module.AppTaskService, "stop_task"),
):
resp, status = method(completion_app, "task-1")
assert status == 200
assert resp == {"result": "success"}
def test_stop_wrong_app_mode(self):
api = completion_module.CompletionStopApi()
method = unwrap(api.post)
installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT))
with pytest.raises(NotCompletionAppError):
method(installed_app, "task")
class TestChatApi:
def test_post_success(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
return_value={"ok": True},
),
patch.object(
completion_module.helper,
"compact_generate_response",
return_value=("ok", 200),
),
):
result = method(chat_app)
assert result == ("ok", 200)
def test_post_not_chat_app(self):
api = completion_module.ChatApi()
method = unwrap(api.post)
installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
with pytest.raises(NotChatAppError):
method(installed_app)
def test_rate_limit_error(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=InvokeRateLimitError("limit"),
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(chat_app)
def test_conversation_completed_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.services.errors.conversation.ConversationCompletedError(),
),
):
with pytest.raises(ConversationCompletedError):
method(chat_app)
def test_conversation_not_exists_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.services.errors.conversation.ConversationNotExistsError(),
),
):
with pytest.raises(completion_module.NotFound):
method(chat_app)
def test_app_unavailable_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.services.errors.app_model_config.AppModelConfigBrokenError(),
),
):
with pytest.raises(completion_module.AppUnavailableError):
method(chat_app)
def test_provider_not_initialized_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.ProviderTokenNotInitError("not init"),
),
):
with pytest.raises(completion_module.ProviderNotInitializeError):
method(chat_app)
def test_quota_exceeded_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.QuotaExceededError(),
),
):
with pytest.raises(completion_module.ProviderQuotaExceededError):
method(chat_app)
def test_model_not_supported_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.ModelCurrentlyNotSupportError(),
),
):
with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError):
method(chat_app)
def test_invoke_error_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=completion_module.InvokeError("invoke failed"),
),
):
with pytest.raises(completion_module.CompletionRequestError):
method(chat_app)
def test_internal_error_chat(self, app, chat_app, user, payload_patch):
api = completion_module.ChatApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
payload_patch,
patch.object(completion_module, "current_user", user),
patch.object(
completion_module.AppGenerateService,
"generate",
side_effect=Exception("boom"),
),
):
with pytest.raises(InternalServerError):
method(chat_app)
class TestChatStopApi:
def test_stop_success(self, chat_app, user):
api = completion_module.ChatStopApi()
method = unwrap(api.post)
user.id = "u1"
with (
patch.object(completion_module, "current_user", user),
patch.object(completion_module.AppTaskService, "stop_task"),
):
resp, status = method(chat_app, "task-1")
assert status == 200
assert resp == {"result": "success"}
def test_stop_not_chat_app(self):
api = completion_module.ChatStopApi()
method = unwrap(api.post)
installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
with pytest.raises(NotChatAppError):
method(installed_app, "task")

View File

@ -0,0 +1,232 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
import controllers.console.explore.conversation as conversation_module
from controllers.console.explore.error import NotChatAppError
from models import Account
from models.model import AppMode
from services.errors.conversation import (
ConversationNotExistsError,
LastConversationNotExistsError,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class FakeConversation:
def __init__(self, cid):
self.id = cid
self.name = "test"
self.inputs = {}
self.status = "normal"
self.introduction = ""
@pytest.fixture
def chat_app():
app_model = MagicMock(mode=AppMode.CHAT, id="app-id")
return MagicMock(app=app_model)
@pytest.fixture
def non_chat_app():
app_model = MagicMock(mode=AppMode.COMPLETION)
return MagicMock(app=app_model)
@pytest.fixture
def user():
user = MagicMock(spec=Account)
user.id = "uid"
return user
@pytest.fixture(autouse=True)
def mock_db_and_session():
with (
patch.object(
conversation_module,
"db",
MagicMock(session=MagicMock(), engine=MagicMock()),
),
patch(
"controllers.console.explore.conversation.Session",
MagicMock(),
),
):
yield
class TestConversationListApi:
def test_get_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
pagination = MagicMock(
limit=20,
has_more=False,
data=[FakeConversation("c1"), FakeConversation("c2")],
)
with (
app.test_request_context("/?limit=20"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.WebConversationService,
"pagination_by_last_id",
return_value=pagination,
),
):
result = method(chat_app)
assert result["limit"] == 20
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.WebConversationService,
"pagination_by_last_id",
side_effect=LastConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(chat_app)
def test_wrong_app_mode(self, app: Flask, non_chat_app):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(NotChatAppError):
method(non_chat_app)
class TestConversationApi:
def test_delete_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.ConversationService,
"delete",
),
):
result = method(chat_app, "cid")
body, status = result
assert status == 204
assert body["result"] == "success"
def test_delete_not_found(self, app: Flask, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.ConversationService,
"delete",
side_effect=ConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(chat_app, "cid")
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
with app.test_request_context("/"):
with pytest.raises(NotChatAppError):
method(non_chat_app, "cid")
class TestConversationRenameApi:
def test_rename_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
conversation = FakeConversation("cid")
with (
app.test_request_context("/", json={"name": "new"}),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.ConversationService,
"rename",
return_value=conversation,
),
):
result = method(chat_app, "cid")
assert result["id"] == "cid"
def test_rename_not_found(self, app: Flask, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "new"}),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.ConversationService,
"rename",
side_effect=ConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(chat_app, "cid")
class TestConversationPinApi:
def test_pin_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationPinApi()
method = unwrap(api.patch)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.WebConversationService,
"pin",
),
):
result = method(chat_app, "cid")
assert result == {"result": "success"}
class TestConversationUnPinApi:
def test_unpin_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationUnPinApi()
method = unwrap(api.patch)
with (
app.test_request_context("/"),
patch.object(conversation_module, "current_user", user),
patch.object(
conversation_module.WebConversationService,
"unpin",
),
):
result = method(chat_app, "cid")
assert result == {"result": "success"}

View File

@ -0,0 +1,363 @@
from datetime import datetime
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
import controllers.console.explore.installed_app as module
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def tenant_id():
return "t1"
@pytest.fixture
def current_user(tenant_id):
user = MagicMock()
user.id = "u1"
user.current_tenant = MagicMock(id=tenant_id)
return user
@pytest.fixture
def installed_app():
app = MagicMock()
app.id = "ia1"
app.app = MagicMock(id="a1")
app.app_owner_tenant_id = "t2"
app.is_pinned = False
app.last_used_at = datetime(2024, 1, 1)
return app
@pytest.fixture
def payload_patch():
def _patch(payload):
return patch.object(
type(module.console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
)
return _patch
class TestInstalledAppsListApi:
def test_get_installed_apps(self, app, current_user, tenant_id, installed_app):
api = module.InstalledAppsListApi()
method = unwrap(api.get)
session = MagicMock()
session.scalars.return_value.all.return_value = [installed_app]
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="owner"),
patch.object(
module.FeatureService,
"get_system_features",
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
),
):
result = method(api)
assert "installed_apps" in result
assert result["installed_apps"][0]["editable"] is True
assert result["installed_apps"][0]["uninstallable"] is False
def test_get_installed_apps_with_app_id_filter(self, app, current_user, tenant_id):
api = module.InstalledAppsListApi()
method = unwrap(api.get)
session = MagicMock()
session.scalars.return_value.all.return_value = []
with (
app.test_request_context("/?app_id=a1"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="member"),
patch.object(
module.FeatureService,
"get_system_features",
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
),
):
result = method(api)
assert result == {"installed_apps": []}
def test_get_installed_apps_with_webapp_auth_enabled(self, app, current_user, tenant_id, installed_app):
"""Test filtering when webapp_auth is enabled."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
session = MagicMock()
session.scalars.return_value.all.return_value = [installed_app]
mock_webapp_setting = MagicMock()
mock_webapp_setting.access_mode = "restricted"
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="owner"),
patch.object(
module.FeatureService,
"get_system_features",
return_value=MagicMock(webapp_auth=MagicMock(enabled=True)),
),
patch.object(
module.EnterpriseService.WebAppAuth,
"batch_get_app_access_mode_by_id",
return_value={"a1": mock_webapp_setting},
),
patch.object(
module.EnterpriseService.WebAppAuth,
"batch_is_user_allowed_to_access_webapps",
return_value={"a1": True},
),
):
result = method(api)
assert len(result["installed_apps"]) == 1
def test_get_installed_apps_with_webapp_auth_user_denied(self, app, current_user, tenant_id, installed_app):
"""Test filtering when user doesn't have access."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
session = MagicMock()
session.scalars.return_value.all.return_value = [installed_app]
mock_webapp_setting = MagicMock()
mock_webapp_setting.access_mode = "restricted"
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="member"),
patch.object(
module.FeatureService,
"get_system_features",
return_value=MagicMock(webapp_auth=MagicMock(enabled=True)),
),
patch.object(
module.EnterpriseService.WebAppAuth,
"batch_get_app_access_mode_by_id",
return_value={"a1": mock_webapp_setting},
),
patch.object(
module.EnterpriseService.WebAppAuth,
"batch_is_user_allowed_to_access_webapps",
return_value={"a1": False},
),
):
result = method(api)
assert result["installed_apps"] == []
def test_get_installed_apps_with_sso_verified_access(self, app, current_user, tenant_id, installed_app):
"""Test that sso_verified access mode apps are skipped in filtering."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
session = MagicMock()
session.scalars.return_value.all.return_value = [installed_app]
mock_webapp_setting = MagicMock()
mock_webapp_setting.access_mode = "sso_verified"
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="owner"),
patch.object(
module.FeatureService,
"get_system_features",
return_value=MagicMock(webapp_auth=MagicMock(enabled=True)),
),
patch.object(
module.EnterpriseService.WebAppAuth,
"batch_get_app_access_mode_by_id",
return_value={"a1": mock_webapp_setting},
),
):
result = method(api)
assert len(result["installed_apps"]) == 0
def test_get_installed_apps_filters_null_apps(self, app, current_user, tenant_id):
"""Test that installed apps with null app are filtered out."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
installed_app_with_null = MagicMock()
installed_app_with_null.app = None
session = MagicMock()
session.scalars.return_value.all.return_value = [installed_app_with_null]
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="owner"),
patch.object(
module.FeatureService,
"get_system_features",
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
),
):
result = method(api)
assert result["installed_apps"] == []
def test_get_installed_apps_current_tenant_none(self, app, tenant_id, installed_app):
"""Test error when current_user.current_tenant is None."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
current_user = MagicMock()
current_user.current_tenant = None
session = MagicMock()
session.scalars.return_value.all.return_value = [installed_app]
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
):
with pytest.raises(ValueError, match="current_user.current_tenant must not be None"):
method(api)
class TestInstalledAppsCreateApi:
def test_post_success(self, app, tenant_id, payload_patch):
api = module.InstalledAppsListApi()
method = unwrap(api.post)
recommended = MagicMock()
recommended.install_count = 0
app_entity = MagicMock()
app_entity.id = "a1"
app_entity.is_public = True
app_entity.tenant_id = "t2"
session = MagicMock()
session.query.return_value.where.return_value.first.side_effect = [
recommended,
app_entity,
None,
]
with (
app.test_request_context("/", json={"app_id": "a1"}),
payload_patch({"app_id": "a1"}),
patch.object(module.db, "session", session),
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
):
result = method(api)
assert result == {"message": "App installed successfully"}
assert recommended.install_count == 1
def test_post_recommended_not_found(self, app, payload_patch):
api = module.InstalledAppsListApi()
method = unwrap(api.post)
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = None
with (
app.test_request_context("/", json={"app_id": "a1"}),
payload_patch({"app_id": "a1"}),
patch.object(module.db, "session", session),
):
with pytest.raises(NotFound):
method(api)
def test_post_app_not_public(self, app, tenant_id, payload_patch):
api = module.InstalledAppsListApi()
method = unwrap(api.post)
recommended = MagicMock()
app_entity = MagicMock(is_public=False)
session = MagicMock()
session.query.return_value.where.return_value.first.side_effect = [
recommended,
app_entity,
]
with (
app.test_request_context("/", json={"app_id": "a1"}),
payload_patch({"app_id": "a1"}),
patch.object(module.db, "session", session),
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
):
with pytest.raises(Forbidden):
method(api)
class TestInstalledAppApi:
def test_delete_success(self, tenant_id, installed_app):
api = module.InstalledAppApi()
method = unwrap(api.delete)
with (
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
patch.object(module.db, "session"),
):
resp, status = method(installed_app)
assert status == 204
assert resp["result"] == "success"
def test_delete_owned_by_current_tenant(self, tenant_id):
api = module.InstalledAppApi()
method = unwrap(api.delete)
installed_app = MagicMock(app_owner_tenant_id=tenant_id)
with patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)):
with pytest.raises(BadRequest):
method(installed_app)
def test_patch_update_pin(self, app, payload_patch, installed_app):
api = module.InstalledAppApi()
method = unwrap(api.patch)
with (
app.test_request_context("/", json={"is_pinned": True}),
payload_patch({"is_pinned": True}),
patch.object(module.db, "session"),
):
result = method(installed_app)
assert installed_app.is_pinned is True
assert result["result"] == "success"
def test_patch_no_change(self, app, payload_patch, installed_app):
api = module.InstalledAppApi()
method = unwrap(api.patch)
with app.test_request_context("/", json={}), payload_patch({}), patch.object(module.db, "session"):
result = method(installed_app)
assert result["result"] == "success"

View File

@ -0,0 +1,552 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import InternalServerError, NotFound
import controllers.console.explore.message as module
from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.explore.error import (
AppSuggestedQuestionsAfterAnswerDisabledError,
NotChatAppError,
NotCompletionAppError,
)
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from dify_graph.model_runtime.errors.invoke import InvokeError
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import (
FirstMessageNotExistsError,
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
def unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def make_message():
msg = MagicMock()
msg.id = "m1"
msg.conversation_id = "11111111-1111-1111-1111-111111111111"
msg.parent_message_id = None
msg.inputs = {}
msg.query = "hello"
msg.re_sign_file_url_answer = ""
msg.user_feedback = MagicMock(rating=None)
msg.status = "success"
msg.error = None
return msg
class TestMessageListApi:
def test_get_success(self, app):
api = module.MessageListApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
pagination = MagicMock(
limit=20,
has_more=False,
data=[make_message(), make_message()],
)
with (
app.test_request_context(
"/",
query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"pagination_by_first_id",
return_value=pagination,
),
):
result = method(installed_app)
assert result["limit"] == 20
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_get_not_chat_app(self):
api = module.MessageListApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
with pytest.raises(NotChatAppError):
method(installed_app)
def test_conversation_not_exists(self, app):
api = module.MessageListApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
app.test_request_context(
"/",
query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"pagination_by_first_id",
side_effect=ConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app)
def test_first_message_not_exists(self, app):
api = module.MessageListApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
app.test_request_context(
"/",
query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"pagination_by_first_id",
side_effect=FirstMessageNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app)
class TestMessageFeedbackApi:
def test_post_success(self, app):
api = module.MessageFeedbackApi()
method = unwrap(api.post)
installed_app = MagicMock()
installed_app.app = MagicMock()
with (
app.test_request_context("/", json={"rating": "like"}),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"create_feedback",
),
):
result = method(installed_app, "mid")
assert result["result"] == "success"
def test_message_not_exists(self, app):
api = module.MessageFeedbackApi()
method = unwrap(api.post)
installed_app = MagicMock()
installed_app.app = MagicMock()
with (
app.test_request_context("/", json={}),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"create_feedback",
side_effect=MessageNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app, "mid")
class TestMessageMoreLikeThisApi:
def test_get_success(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
return_value={"ok": True},
),
patch.object(
module.helper,
"compact_generate_response",
return_value=("ok", 200),
),
):
resp = method(installed_app, "mid")
assert resp == ("ok", 200)
def test_not_completion_app(self):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
with pytest.raises(NotCompletionAppError):
method(installed_app, "mid")
def test_more_like_this_disabled(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=module.MoreLikeThisDisabledError(),
),
):
with pytest.raises(AppMoreLikeThisDisabledError):
method(installed_app, "mid")
def test_message_not_exists_more_like_this(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=MessageNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app, "mid")
def test_provider_not_init_more_like_this(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=ProviderTokenNotInitError("test"),
),
):
with pytest.raises(ProviderNotInitializeError):
method(installed_app, "mid")
def test_quota_exceeded_more_like_this(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=QuotaExceededError(),
),
):
with pytest.raises(ProviderQuotaExceededError):
method(installed_app, "mid")
def test_model_not_support_more_like_this(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=ModelCurrentlyNotSupportError(),
),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
method(installed_app, "mid")
def test_invoke_error_more_like_this(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=InvokeError("test error"),
),
):
with pytest.raises(CompletionRequestError):
method(installed_app, "mid")
def test_unexpected_error_more_like_this(self, app):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
app.test_request_context(
"/",
query_string={"response_mode": "blocking"},
),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.AppGenerateService,
"generate_more_like_this",
side_effect=Exception("unexpected"),
),
):
with pytest.raises(InternalServerError):
method(installed_app, "mid")
class TestMessageSuggestedQuestionApi:
def test_get_success(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
return_value=["q1", "q2"],
),
):
result = method(installed_app, "mid")
assert result["data"] == ["q1", "q2"]
def test_not_chat_app(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
with pytest.raises(NotChatAppError):
method(installed_app, "mid")
def test_disabled(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=SuggestedQuestionsAfterAnswerDisabledError(),
),
):
with pytest.raises(AppSuggestedQuestionsAfterAnswerDisabledError):
method(installed_app, "mid")
def test_message_not_exists_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=MessageNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app, "mid")
def test_conversation_not_exists_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=ConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app, "mid")
def test_provider_not_init_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=ProviderTokenNotInitError("test"),
),
):
with pytest.raises(ProviderNotInitializeError):
method(installed_app, "mid")
def test_quota_exceeded_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=QuotaExceededError(),
),
):
with pytest.raises(ProviderQuotaExceededError):
method(installed_app, "mid")
def test_model_not_support_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=ModelCurrentlyNotSupportError(),
),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
method(installed_app, "mid")
def test_invoke_error_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=InvokeError("test error"),
),
):
with pytest.raises(CompletionRequestError):
method(installed_app, "mid")
def test_unexpected_error_suggested_question(self):
api = module.MessageSuggestedQuestionApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.MessageService,
"get_suggested_questions_after_answer",
side_effect=Exception("unexpected"),
),
):
with pytest.raises(InternalServerError):
method(installed_app, "mid")

View File

@ -0,0 +1,140 @@
from unittest.mock import MagicMock, patch
import pytest
import controllers.console.explore.parameter as module
from controllers.console.app.error import AppUnavailableError
from models.model import AppMode
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestAppParameterApi:
def test_get_app_none(self):
api = module.AppParameterApi()
method = unwrap(api.get)
installed_app = MagicMock(app=None)
with pytest.raises(AppUnavailableError):
method(installed_app)
def test_get_advanced_chat_workflow(self):
api = module.AppParameterApi()
method = unwrap(api.get)
workflow = MagicMock()
workflow.features_dict = {"f": "v"}
workflow.user_input_form.return_value = [{"name": "x"}]
app = MagicMock(
mode=AppMode.ADVANCED_CHAT,
workflow=workflow,
)
installed_app = MagicMock(app=app)
with (
patch.object(
module,
"get_parameters_from_feature_dict",
return_value={"any": "thing"},
),
patch.object(
module.fields.Parameters,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: {"ok": True}),
),
):
result = method(installed_app)
assert result == {"ok": True}
def test_get_advanced_chat_workflow_missing(self):
api = module.AppParameterApi()
method = unwrap(api.get)
app = MagicMock(
mode=AppMode.ADVANCED_CHAT,
workflow=None,
)
installed_app = MagicMock(app=app)
with pytest.raises(AppUnavailableError):
method(installed_app)
def test_get_non_workflow_app(self):
api = module.AppParameterApi()
method = unwrap(api.get)
app_model_config = MagicMock()
app_model_config.to_dict.return_value = {"user_input_form": [{"name": "y"}]}
app = MagicMock(
mode=AppMode.CHAT,
app_model_config=app_model_config,
)
installed_app = MagicMock(app=app)
with (
patch.object(
module,
"get_parameters_from_feature_dict",
return_value={"whatever": 123},
),
patch.object(
module.fields.Parameters,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: {"ok": True}),
),
):
result = method(installed_app)
assert result == {"ok": True}
def test_get_non_workflow_missing_config(self):
api = module.AppParameterApi()
method = unwrap(api.get)
app = MagicMock(
mode=AppMode.CHAT,
app_model_config=None,
)
installed_app = MagicMock(app=app)
with pytest.raises(AppUnavailableError):
method(installed_app)
class TestExploreAppMetaApi:
def test_get_meta_success(self):
api = module.ExploreAppMetaApi()
method = unwrap(api.get)
app = MagicMock()
installed_app = MagicMock(app=app)
with patch.object(
module.AppService,
"get_app_meta",
return_value={"meta": "ok"},
):
result = method(installed_app)
assert result == {"meta": "ok"}
def test_get_meta_app_missing(self):
api = module.ExploreAppMetaApi()
method = unwrap(api.get)
installed_app = MagicMock(app=None)
with pytest.raises(ValueError):
method(installed_app)

View File

@ -0,0 +1,92 @@
from unittest.mock import MagicMock, patch
import controllers.console.explore.recommended_app as module
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestRecommendedAppListApi:
def test_get_with_language_param(self, app):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
result_data = {"recommended_apps": [], "categories": []}
with (
app.test_request_context("/", query_string={"language": "en-US"}),
patch.object(module, "current_user", MagicMock(interface_language="fr-FR")),
patch.object(
module.RecommendedAppService,
"get_recommended_apps_and_categories",
return_value=result_data,
) as service_mock,
):
result = method(api)
service_mock.assert_called_once_with("en-US")
assert result == result_data
def test_get_fallback_to_user_language(self, app):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
result_data = {"recommended_apps": [], "categories": []}
with (
app.test_request_context("/", query_string={"language": "invalid"}),
patch.object(module, "current_user", MagicMock(interface_language="fr-FR")),
patch.object(
module.RecommendedAppService,
"get_recommended_apps_and_categories",
return_value=result_data,
) as service_mock,
):
result = method(api)
service_mock.assert_called_once_with("fr-FR")
assert result == result_data
def test_get_fallback_to_default_language(self, app):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
result_data = {"recommended_apps": [], "categories": []}
with (
app.test_request_context("/"),
patch.object(module, "current_user", MagicMock(interface_language=None)),
patch.object(
module.RecommendedAppService,
"get_recommended_apps_and_categories",
return_value=result_data,
) as service_mock,
):
result = method(api)
service_mock.assert_called_once_with(module.languages[0])
assert result == result_data
class TestRecommendedAppApi:
def test_get_success(self, app):
api = module.RecommendedAppApi()
method = unwrap(api.get)
result_data = {"id": "app1"}
with (
app.test_request_context("/"),
patch.object(
module.RecommendedAppService,
"get_recommend_app_detail",
return_value=result_data,
) as service_mock,
):
result = method(api, "11111111-1111-1111-1111-111111111111")
service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111")
assert result == result_data

View File

@ -0,0 +1,154 @@
from unittest.mock import MagicMock, PropertyMock, patch
from uuid import uuid4
import pytest
from werkzeug.exceptions import NotFound
import controllers.console.explore.saved_message as module
from controllers.console.explore.error import NotCompletionAppError
from services.errors.message import MessageNotExistsError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def make_saved_message():
msg = MagicMock()
msg.id = str(uuid4())
msg.message_id = str(uuid4())
msg.app_id = str(uuid4())
msg.inputs = {}
msg.query = "hello"
msg.answer = "world"
msg.user_feedback = MagicMock(rating="like")
msg.created_at = None
return msg
@pytest.fixture
def payload_patch():
def _patch(payload):
return patch.object(
type(module.console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
)
return _patch
class TestSavedMessageListApi:
def test_get_success(self, app):
api = module.SavedMessageListApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
pagination = MagicMock(
limit=20,
has_more=False,
data=[make_saved_message(), make_saved_message()],
)
with (
app.test_request_context("/", query_string={}),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.SavedMessageService,
"pagination_by_last_id",
return_value=pagination,
),
):
result = method(installed_app)
assert result["limit"] == 20
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_get_not_completion_app(self):
api = module.SavedMessageListApi()
method = unwrap(api.get)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
with pytest.raises(NotCompletionAppError):
method(installed_app)
def test_post_success(self, app, payload_patch):
api = module.SavedMessageListApi()
method = unwrap(api.post)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
payload = {"message_id": str(uuid4())}
with (
app.test_request_context("/", json=payload),
payload_patch(payload),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(module.SavedMessageService, "save") as save_mock,
):
result = method(installed_app)
save_mock.assert_called_once()
assert result == {"result": "success"}
def test_post_message_not_exists(self, app, payload_patch):
api = module.SavedMessageListApi()
method = unwrap(api.post)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
payload = {"message_id": str(uuid4())}
with (
app.test_request_context("/", json=payload),
payload_patch(payload),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.SavedMessageService,
"save",
side_effect=MessageNotExistsError(),
),
):
with pytest.raises(NotFound):
method(installed_app)
class TestSavedMessageApi:
def test_delete_success(self):
api = module.SavedMessageApi()
method = unwrap(api.delete)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(module.SavedMessageService, "delete") as delete_mock,
):
result, status = method(installed_app, str(uuid4()))
delete_mock.assert_called_once()
assert status == 204
assert result == {"result": "success"}
def test_delete_not_completion_app(self):
api = module.SavedMessageApi()
method = unwrap(api.delete)
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
with pytest.raises(NotCompletionAppError):
method(installed_app, str(uuid4()))

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,151 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import InternalServerError
from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.workflow import (
InstalledAppWorkflowRunApi,
InstalledAppWorkflowTaskStopApi,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from models.model import AppMode
from services.errors.llm import InvokeRateLimitError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def user():
return MagicMock()
@pytest.fixture
def workflow_app():
app = MagicMock()
app.mode = AppMode.WORKFLOW
return app
@pytest.fixture
def installed_workflow_app(workflow_app):
return MagicMock(app=workflow_app)
@pytest.fixture
def non_workflow_installed_app():
app = MagicMock()
app.mode = AppMode.CHAT
return MagicMock(app=app)
@pytest.fixture
def payload():
return {"inputs": {"a": 1}}
class TestInstalledAppWorkflowRunApi:
def test_not_workflow_app(self, app, non_workflow_installed_app):
api = InstalledAppWorkflowRunApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch(
"controllers.console.explore.workflow.current_account_with_tenant",
return_value=(MagicMock(), None),
),
):
with pytest.raises(NotWorkflowAppError):
method(non_workflow_installed_app)
def test_success(self, app, installed_workflow_app, user, payload):
api = InstalledAppWorkflowRunApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.explore.workflow.current_account_with_tenant",
return_value=(user, None),
),
patch(
"controllers.console.explore.workflow.AppGenerateService.generate",
return_value=MagicMock(),
) as generate_mock,
):
result = method(installed_workflow_app)
generate_mock.assert_called_once()
assert result is not None
def test_rate_limit_error(self, app, installed_workflow_app, user, payload):
api = InstalledAppWorkflowRunApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.explore.workflow.current_account_with_tenant",
return_value=(user, None),
),
patch(
"controllers.console.explore.workflow.AppGenerateService.generate",
side_effect=InvokeRateLimitError("rate limit"),
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(installed_workflow_app)
def test_unexpected_exception(self, app, installed_workflow_app, user, payload):
api = InstalledAppWorkflowRunApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.explore.workflow.current_account_with_tenant",
return_value=(user, None),
),
patch(
"controllers.console.explore.workflow.AppGenerateService.generate",
side_effect=Exception("boom"),
),
):
with pytest.raises(InternalServerError):
method(installed_workflow_app)
class TestInstalledAppWorkflowTaskStopApi:
def test_not_workflow_app(self, non_workflow_installed_app):
api = InstalledAppWorkflowTaskStopApi()
method = unwrap(api.post)
with pytest.raises(NotWorkflowAppError):
method(non_workflow_installed_app, "task-1")
def test_success(self, installed_workflow_app):
api = InstalledAppWorkflowTaskStopApi()
method = unwrap(api.post)
with (
patch("controllers.console.explore.workflow.AppQueueManager.set_stop_flag_no_user_check") as stop_flag,
patch("controllers.console.explore.workflow.GraphEngineManager.send_stop_command") as send_stop,
):
result = method(installed_workflow_app, "task-1")
stop_flag.assert_called_once_with("task-1")
send_stop.assert_called_once_with("task-1")
assert result == {"result": "success"}

View File

@ -0,0 +1,244 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, NotFound
from controllers.console.explore.error import (
AppAccessDeniedError,
TrialAppLimitExceeded,
TrialAppNotAllowed,
)
from controllers.console.explore.wraps import (
InstalledAppResource,
TrialAppResource,
installed_app_required,
trial_app_required,
trial_feature_enable,
user_allowed_to_access_app,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def test_installed_app_required_not_found():
@installed_app_required
def view(installed_app):
return "ok"
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
):
q.return_value.where.return_value.first.return_value = None
with pytest.raises(NotFound):
view("app-id")
def test_installed_app_required_app_deleted():
installed_app = MagicMock(app=None)
@installed_app_required
def view(installed_app):
return "ok"
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
patch("controllers.console.explore.wraps.db.session.delete"),
patch("controllers.console.explore.wraps.db.session.commit"),
):
q.return_value.where.return_value.first.return_value = installed_app
with pytest.raises(NotFound):
view("app-id")
def test_installed_app_required_success():
installed_app = MagicMock(app=MagicMock())
@installed_app_required
def view(installed_app):
return installed_app
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
):
q.return_value.where.return_value.first.return_value = installed_app
result = view("app-id")
assert result == installed_app
def test_user_allowed_to_access_app_denied():
installed_app = MagicMock(app_id="app-1")
@user_allowed_to_access_app
def view(installed_app):
return "ok"
feature = MagicMock()
feature.webapp_auth.enabled = True
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch(
"controllers.console.explore.wraps.FeatureService.get_system_features",
return_value=feature,
),
patch(
"controllers.console.explore.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp",
return_value=False,
),
):
with pytest.raises(AppAccessDeniedError):
view(installed_app)
def test_user_allowed_to_access_app_success():
installed_app = MagicMock(app_id="app-1")
@user_allowed_to_access_app
def view(installed_app):
return "ok"
feature = MagicMock()
feature.webapp_auth.enabled = True
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch(
"controllers.console.explore.wraps.FeatureService.get_system_features",
return_value=feature,
),
patch(
"controllers.console.explore.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp",
return_value=True,
),
):
assert view(installed_app) == "ok"
def test_trial_app_required_not_allowed():
@trial_app_required
def view(app):
return "ok"
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
):
q.return_value.where.return_value.first.return_value = None
with pytest.raises(TrialAppNotAllowed):
view("app-id")
def test_trial_app_required_limit_exceeded():
trial_app = MagicMock(trial_limit=1, app=MagicMock())
record = MagicMock(count=1)
@trial_app_required
def view(app):
return "ok"
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
):
q.return_value.where.return_value.first.side_effect = [
trial_app,
record,
]
with pytest.raises(TrialAppLimitExceeded):
view("app-id")
def test_trial_app_required_success():
trial_app = MagicMock(trial_limit=2, app=MagicMock())
record = MagicMock(count=1)
@trial_app_required
def view(app):
return app
with (
patch(
"controllers.console.explore.wraps.current_account_with_tenant",
return_value=(MagicMock(id="user-1"), None),
),
patch("controllers.console.explore.wraps.db.session.query") as q,
):
q.return_value.where.return_value.first.side_effect = [
trial_app,
record,
]
result = view("app-id")
assert result == trial_app.app
def test_trial_feature_enable_disabled():
@trial_feature_enable
def view():
return "ok"
features = MagicMock(enable_trial_app=False)
with patch(
"controllers.console.explore.wraps.FeatureService.get_system_features",
return_value=features,
):
with pytest.raises(Forbidden):
view()
def test_trial_feature_enable_enabled():
@trial_feature_enable
def view():
return "ok"
features = MagicMock(enable_trial_app=True)
with patch(
"controllers.console.explore.wraps.FeatureService.get_system_features",
return_value=features,
):
assert view() == "ok"
def test_installed_app_resource_decorators():
decorators = InstalledAppResource.method_decorators
assert len(decorators) == 4
def test_trial_app_resource_decorators():
decorators = TrialAppResource.method_decorators
assert len(decorators) == 3

View File

@ -0,0 +1,278 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.tag.tags import (
TagBindingCreateApi,
TagBindingDeleteApi,
TagListApi,
TagUpdateDeleteApi,
)
def unwrap(func):
"""
Recursively unwrap decorated functions.
"""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_tag")
app.config["TESTING"] = True
return app
@pytest.fixture
def admin_user():
return MagicMock(
id="user-1",
has_edit_permission=True,
is_dataset_editor=True,
)
@pytest.fixture
def readonly_user():
return MagicMock(
id="user-2",
has_edit_permission=False,
is_dataset_editor=False,
)
@pytest.fixture
def tag():
tag = MagicMock()
tag.id = "tag-1"
tag.name = "test-tag"
tag.type = "knowledge"
return tag
@pytest.fixture
def payload_patch():
def _patch(payload):
return patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
)
return _patch
class TestTagListApi:
def test_get_success(self, app):
api = TagListApi()
method = unwrap(api.get)
with app.test_request_context("/?type=knowledge"):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.tag.tags.TagService.get_tags",
return_value=[{"id": "1", "name": "tag"}],
),
):
result, status = method(api)
assert status == 200
assert isinstance(result, list)
def test_post_success(self, app, admin_user, tag, payload_patch):
api = TagListApi()
method = unwrap(api.post)
payload = {"name": "test-tag", "type": "knowledge"}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch(
"controllers.console.tag.tags.TagService.save_tags",
return_value=tag,
),
):
result, status = method(api)
assert status == 200
assert result["name"] == "test-tag"
def test_post_forbidden(self, app, readonly_user, payload_patch):
api = TagListApi()
method = unwrap(api.post)
payload = {"name": "x"}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch(payload),
):
with pytest.raises(Forbidden):
method(api)
class TestTagUpdateDeleteApi:
def test_patch_success(self, app, admin_user, tag, payload_patch):
api = TagUpdateDeleteApi()
method = unwrap(api.patch)
payload = {"name": "updated", "type": "knowledge"}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch(
"controllers.console.tag.tags.TagService.update_tags",
return_value=tag,
),
patch(
"controllers.console.tag.tags.TagService.get_tag_binding_count",
return_value=3,
),
):
result, status = method(api, "tag-1")
assert status == 200
assert result["binding_count"] == 3
def test_patch_forbidden(self, app, readonly_user, payload_patch):
api = TagUpdateDeleteApi()
method = unwrap(api.patch)
payload = {"name": "x"}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch(payload),
):
with pytest.raises(Forbidden):
method(api, "tag-1")
def test_delete_success(self, app, admin_user):
api = TagUpdateDeleteApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, "tenant-1"),
),
patch("controllers.console.tag.tags.TagService.delete_tag") as delete_mock,
):
result, status = method(api, "tag-1")
delete_mock.assert_called_once_with("tag-1")
assert status == 204
class TestTagBindingCreateApi:
def test_create_success(self, app, admin_user, payload_patch):
api = TagBindingCreateApi()
method = unwrap(api.post)
payload = {
"tag_ids": ["tag-1"],
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
):
result, status = method(api)
save_mock.assert_called_once()
assert status == 200
assert result["result"] == "success"
def test_create_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingCreateApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch({}),
):
with pytest.raises(Forbidden):
method(api)
class TestTagBindingDeleteApi:
def test_remove_success(self, app, admin_user, payload_patch):
api = TagBindingDeleteApi()
method = unwrap(api.post)
payload = {
"tag_id": "tag-1",
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
):
result, status = method(api)
delete_mock.assert_called_once()
assert status == 200
assert result["result"] == "success"
def test_remove_forbidden(self, app, readonly_user, payload_patch):
api = TagBindingDeleteApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch({}),
):
with pytest.raises(Forbidden):
method(api)

View File

@ -1,13 +1,483 @@
"""Final working unit tests for admin endpoints - tests business logic directly."""
import uuid
from unittest.mock import Mock, patch
from unittest.mock import Mock, PropertyMock, patch
import pytest
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.console.admin import InsertExploreAppPayload
from models.model import App, RecommendedApp
from controllers.console.admin import (
DeleteExploreBannerApi,
InsertExploreAppApi,
InsertExploreAppListApi,
InsertExploreAppPayload,
InsertExploreBannerApi,
InsertExploreBannerPayload,
)
from models.model import App, InstalledApp, RecommendedApp
@pytest.fixture(autouse=True)
def bypass_only_edition_cloud(mocker):
"""
Bypass only_edition_cloud decorator by setting EDITION to "CLOUD".
"""
mocker.patch(
"controllers.console.wraps.dify_config.EDITION",
new="CLOUD",
)
@pytest.fixture
def mock_admin_auth(mocker):
"""
Provide valid admin authentication for controller tests.
"""
mocker.patch(
"controllers.console.admin.dify_config.ADMIN_API_KEY",
"test-admin-key",
)
mocker.patch(
"controllers.console.admin.extract_access_token",
return_value="test-admin-key",
)
@pytest.fixture
def mock_console_payload(mocker):
payload = {
"app_id": str(uuid.uuid4()),
"language": "en-US",
"category": "Productivity",
"position": 1,
}
mocker.patch(
"flask_restx.namespace.Namespace.payload",
new_callable=PropertyMock,
return_value=payload,
)
return payload
@pytest.fixture
def mock_banner_payload(mocker):
mocker.patch(
"flask_restx.namespace.Namespace.payload",
new_callable=PropertyMock,
return_value={
"title": "Test Banner",
"description": "Banner description",
"img-src": "https://example.com/banner.png",
"link": "https://example.com",
"sort": 1,
"category": "homepage",
},
)
@pytest.fixture
def mock_session_factory(mocker):
mock_session = Mock()
mock_session.execute = Mock()
mock_session.add = Mock()
mock_session.commit = Mock()
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: mock_session,
__exit__=Mock(return_value=False),
),
)
class TestDeleteExploreBannerApi:
def setup_method(self):
self.api = DeleteExploreBannerApi()
def test_delete_banner_not_found(self, mocker, mock_admin_auth):
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: None),
)
with pytest.raises(NotFound, match="is not found"):
self.api.delete(uuid.uuid4())
def test_delete_banner_success(self, mocker, mock_admin_auth):
mock_banner = Mock()
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: mock_banner),
)
mocker.patch("controllers.console.admin.db.session.delete")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.delete(uuid.uuid4())
assert status == 204
assert response["result"] == "success"
class TestInsertExploreBannerApi:
def setup_method(self):
self.api = InsertExploreBannerApi()
def test_insert_banner_success(self, mocker, mock_admin_auth, mock_banner_payload):
mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 201
assert response["result"] == "success"
def test_banner_payload_valid_language(self):
payload = {
"title": "Test Banner",
"description": "Banner description",
"img-src": "https://example.com/banner.png",
"link": "https://example.com",
"sort": 1,
"category": "homepage",
"language": "en-US",
}
model = InsertExploreBannerPayload.model_validate(payload)
assert model.language == "en-US"
def test_banner_payload_invalid_language(self):
payload = {
"title": "Test Banner",
"description": "Banner description",
"img-src": "https://example.com/banner.png",
"link": "https://example.com",
"sort": 1,
"category": "homepage",
"language": "invalid-lang",
}
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
InsertExploreBannerPayload.model_validate(payload)
class TestInsertExploreAppApiDelete:
def setup_method(self):
self.api = InsertExploreAppApi()
def test_delete_when_not_in_explore(self, mocker, mock_admin_auth):
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: s,
__exit__=Mock(return_value=False),
execute=lambda *_: Mock(scalar_one_or_none=lambda: None),
),
)
response, status = self.api.delete(uuid.uuid4())
assert status == 204
assert response["result"] == "success"
def test_delete_when_in_explore_with_trial_app(self, mocker, mock_admin_auth):
"""Test deleting an app from explore that has a trial app."""
app_id = uuid.uuid4()
mock_recommended = Mock(spec=RecommendedApp)
mock_recommended.app_id = "app-123"
mock_app = Mock(spec=App)
mock_app.is_public = True
mock_trial = Mock()
# Mock session context manager and its execute
mock_session = Mock()
mock_session.execute = Mock()
mock_session.delete = Mock()
# Set up side effects for execute calls
mock_session.execute.side_effect = [
Mock(scalar_one_or_none=lambda: mock_recommended),
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalars=Mock(return_value=Mock(all=lambda: []))),
Mock(scalar_one_or_none=lambda: mock_trial),
]
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: mock_session,
__exit__=Mock(return_value=False),
),
)
mocker.patch("controllers.console.admin.db.session.delete")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.delete(app_id)
assert status == 204
assert response["result"] == "success"
assert mock_app.is_public is False
def test_delete_with_installed_apps(self, mocker, mock_admin_auth):
"""Test deleting an app that has installed apps in other tenants."""
app_id = uuid.uuid4()
mock_recommended = Mock(spec=RecommendedApp)
mock_recommended.app_id = "app-123"
mock_app = Mock(spec=App)
mock_app.is_public = True
mock_installed_app = Mock(spec=InstalledApp)
# Mock session
mock_session = Mock()
mock_session.execute = Mock()
mock_session.delete = Mock()
mock_session.execute.side_effect = [
Mock(scalar_one_or_none=lambda: mock_recommended),
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalars=Mock(return_value=Mock(all=lambda: [mock_installed_app]))),
Mock(scalar_one_or_none=lambda: None),
]
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: mock_session,
__exit__=Mock(return_value=False),
),
)
mocker.patch("controllers.console.admin.db.session.delete")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.delete(app_id)
assert status == 204
assert mock_session.delete.called
class TestInsertExploreAppListApi:
def setup_method(self):
self.api = InsertExploreAppListApi()
def test_app_not_found(self, mocker, mock_admin_auth, mock_console_payload):
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: None),
)
with pytest.raises(NotFound, match="is not found"):
self.api.post()
def test_create_recommended_app(
self,
mocker,
mock_admin_auth,
mock_console_payload,
):
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.tenant_id = "tenant"
mock_app.is_public = False
# db.session.execute → fetch App
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: mock_app),
)
# session_factory.create_session → recommended_app lookup
mock_session = Mock()
mock_session.execute = Mock(return_value=Mock(scalar_one_or_none=lambda: None))
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: mock_session,
__exit__=Mock(return_value=False),
),
)
mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 201
assert response["result"] == "success"
assert mock_app.is_public is True
def test_update_recommended_app(self, mocker, mock_admin_auth, mock_console_payload, mock_session_factory):
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.is_public = False
mock_recommended = Mock(spec=RecommendedApp)
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: mock_recommended),
],
)
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 200
assert response["result"] == "success"
assert mock_app.is_public is True
def test_site_data_overrides_payload(
self,
mocker,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
):
site = Mock()
site.description = "Site Desc"
site.copyright = "Site Copyright"
site.privacy_policy = "Site Privacy"
site.custom_disclaimer = "Site Disclaimer"
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = site
mock_app.tenant_id = "tenant"
mock_app.is_public = False
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: None),
Mock(scalar_one_or_none=lambda: None),
],
)
commit_spy = mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 200
assert response["result"] == "success"
assert mock_app.is_public is True
commit_spy.assert_called_once()
def test_create_trial_app_when_can_trial_enabled(
self,
mocker,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
):
mock_console_payload["can_trial"] = True
mock_console_payload["trial_limit"] = 5
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.tenant_id = "tenant"
mock_app.is_public = False
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: None),
Mock(scalar_one_or_none=lambda: None),
],
)
add_spy = mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
self.api.post()
assert any(call.args[0].__class__.__name__ == "TrialApp" for call in add_spy.call_args_list)
def test_update_recommended_app_with_trial(
self,
mocker,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
):
"""Test updating a recommended app when trial is enabled."""
mock_console_payload["can_trial"] = True
mock_console_payload["trial_limit"] = 10
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.is_public = False
mock_app.tenant_id = "tenant-123"
mock_recommended = Mock(spec=RecommendedApp)
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: mock_recommended),
Mock(scalar_one_or_none=lambda: None),
],
)
add_spy = mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 200
assert response["result"] == "success"
assert mock_app.is_public is True
def test_update_recommended_app_without_trial(
self,
mocker,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
):
"""Test updating a recommended app without trial enabled."""
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.is_public = False
mock_recommended = Mock(spec=RecommendedApp)
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: mock_recommended),
],
)
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 200
assert response["result"] == "success"
assert mock_app.is_public is True
class TestInsertExploreAppPayload:

View File

@ -0,0 +1,138 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.console.apikey import (
BaseApiKeyListResource,
BaseApiKeyResource,
_get_resource,
)
@pytest.fixture
def tenant_context_admin():
with patch("controllers.console.apikey.current_account_with_tenant") as mock:
user = MagicMock()
user.is_admin_or_owner = True
mock.return_value = (user, "tenant-123")
yield mock
@pytest.fixture
def tenant_context_non_admin():
with patch("controllers.console.apikey.current_account_with_tenant") as mock:
user = MagicMock()
user.is_admin_or_owner = False
mock.return_value = (user, "tenant-123")
yield mock
@pytest.fixture
def db_mock():
with patch("controllers.console.apikey.db") as mock_db:
mock_db.session = MagicMock()
yield mock_db
@pytest.fixture(autouse=True)
def bypass_permissions():
with patch(
"controllers.console.apikey.edit_permission_required",
lambda f: f,
):
yield
class DummyApiKeyListResource(BaseApiKeyListResource):
resource_type = "app"
resource_model = MagicMock()
resource_id_field = "app_id"
token_prefix = "app-"
class DummyApiKeyResource(BaseApiKeyResource):
resource_type = "app"
resource_model = MagicMock()
resource_id_field = "app_id"
class TestGetResource:
def test_get_resource_success(self):
fake_resource = MagicMock()
with (
patch("controllers.console.apikey.select") as mock_select,
patch("controllers.console.apikey.Session") as mock_session,
patch("controllers.console.apikey.db") as mock_db,
):
mock_db.engine = MagicMock()
mock_select.return_value.filter_by.return_value = MagicMock()
session = mock_session.return_value.__enter__.return_value
session.execute.return_value.scalar_one_or_none.return_value = fake_resource
result = _get_resource("rid", "tid", MagicMock)
assert result == fake_resource
def test_get_resource_not_found(self):
with (
patch("controllers.console.apikey.select") as mock_select,
patch("controllers.console.apikey.Session") as mock_session,
patch("controllers.console.apikey.db") as mock_db,
patch("controllers.console.apikey.flask_restx.abort") as abort,
):
mock_db.engine = MagicMock()
mock_select.return_value.filter_by.return_value = MagicMock()
session = mock_session.return_value.__enter__.return_value
session.execute.return_value.scalar_one_or_none.return_value = None
_get_resource("rid", "tid", MagicMock)
abort.assert_called_once()
class TestBaseApiKeyListResource:
def test_get_apikeys_success(self, tenant_context_admin, db_mock):
resource = DummyApiKeyListResource()
with patch("controllers.console.apikey._get_resource"):
db_mock.session.scalars.return_value.all.return_value = [MagicMock(), MagicMock()]
result = DummyApiKeyListResource.get.__wrapped__(resource, "resource-id")
assert "items" in result
class TestBaseApiKeyResource:
def test_delete_forbidden(self, tenant_context_non_admin, db_mock):
resource = DummyApiKeyResource()
with patch("controllers.console.apikey._get_resource"):
with pytest.raises(Forbidden):
DummyApiKeyResource.delete(resource, "rid", "kid")
def test_delete_key_not_found(self, tenant_context_admin, db_mock):
resource = DummyApiKeyResource()
db_mock.session.query.return_value.where.return_value.first.return_value = None
with patch("controllers.console.apikey._get_resource"):
with pytest.raises(Exception) as exc_info:
DummyApiKeyResource.delete(resource, "rid", "kid")
# flask_restx.abort raises HTTPException with message in data attribute
assert exc_info.value.data["message"] == "API key not found"
def test_delete_success(self, tenant_context_admin, db_mock):
resource = DummyApiKeyResource()
db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock()
with (
patch("controllers.console.apikey._get_resource"),
patch("controllers.console.apikey.ApiTokenCache.delete"),
):
result, status = DummyApiKeyResource.delete(resource, "rid", "kid")
assert status == 204
assert result == {"result": "success"}
db_mock.session.commit.assert_called_once()

View File

@ -1,46 +0,0 @@
import builtins
from unittest.mock import patch
import pytest
from flask import Flask
from flask.views import MethodView
from extensions import ext_fastopenapi
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@pytest.fixture
def app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
app.secret_key = "test-secret-key"
return app
def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch):
ext_fastopenapi.init_app(app)
monkeypatch.delenv("INIT_PASSWORD", raising=False)
with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"):
client = app.test_client()
response = client.get("/console/api/init")
assert response.status_code == 200
assert response.get_json() == {"status": "finished"}
def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch):
ext_fastopenapi.init_app(app)
monkeypatch.setenv("INIT_PASSWORD", "test-init-password")
with (
patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"),
patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0),
):
client = app.test_client()
response = client.post("/console/api/init", json={"password": "test-init-password"})
assert response.status_code == 201
assert response.get_json() == {"result": "success"}

View File

@ -1,286 +0,0 @@
"""Tests for remote file upload API endpoints using Flask-RESTX."""
import contextlib
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import Mock, patch
import httpx
import pytest
from flask import Flask, g
@pytest.fixture
def app() -> Flask:
"""Create Flask app for testing."""
app = Flask(__name__)
app.config["TESTING"] = True
app.config["SECRET_KEY"] = "test-secret-key"
return app
@pytest.fixture
def client(app):
"""Create test client with console blueprint registered."""
from controllers.console import bp
app.register_blueprint(bp)
return app.test_client()
@pytest.fixture
def mock_account():
"""Create a mock account for testing."""
from models import Account
account = Mock(spec=Account)
account.id = "test-account-id"
account.current_tenant_id = "test-tenant-id"
return account
@pytest.fixture
def auth_ctx(app, mock_account):
"""Context manager to set auth/tenant context in flask.g for a request."""
@contextlib.contextmanager
def _ctx():
with app.test_request_context():
g._login_user = mock_account
g._current_tenant = mock_account.current_tenant_id
yield
return _ctx
class TestGetRemoteFileInfo:
"""Test GET /console/api/remote-files/<path:url> endpoint."""
def test_get_remote_file_info_success(self, app, client, mock_account):
"""Test successful retrieval of remote file info."""
response = httpx.Response(
200,
request=httpx.Request("HEAD", "http://example.com/file.txt"),
headers={"Content-Type": "text/plain", "Content-Length": "1024"},
)
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response),
patch("libs.login.check_csrf_token", return_value=None),
):
with app.test_request_context():
g._login_user = mock_account
g._current_tenant = mock_account.current_tenant_id
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt"
resp = client.get(f"/console/api/remote-files/{encoded_url}")
assert resp.status_code == 200
data = resp.get_json()
assert data["file_type"] == "text/plain"
assert data["file_length"] == 1024
def test_get_remote_file_info_fallback_to_get_on_head_failure(self, app, client, mock_account):
"""Test fallback to GET when HEAD returns non-200 status."""
head_response = httpx.Response(
404,
request=httpx.Request("HEAD", "http://example.com/file.pdf"),
)
get_response = httpx.Response(
200,
request=httpx.Request("GET", "http://example.com/file.pdf"),
headers={"Content-Type": "application/pdf", "Content-Length": "2048"},
)
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response),
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_response),
patch("libs.login.check_csrf_token", return_value=None),
):
with app.test_request_context():
g._login_user = mock_account
g._current_tenant = mock_account.current_tenant_id
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.pdf"
resp = client.get(f"/console/api/remote-files/{encoded_url}")
assert resp.status_code == 200
data = resp.get_json()
assert data["file_type"] == "application/pdf"
assert data["file_length"] == 2048
class TestRemoteFileUpload:
"""Test POST /console/api/remote-files/upload endpoint."""
@pytest.mark.parametrize(
("head_status", "use_get"),
[
(200, False), # HEAD succeeds
(405, True), # HEAD fails -> fallback GET
],
)
def test_upload_remote_file_success_paths(self, client, mock_account, auth_ctx, head_status, use_get):
url = "http://example.com/file.pdf"
head_resp = httpx.Response(
head_status,
request=httpx.Request("HEAD", url),
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
)
get_resp = httpx.Response(
200,
request=httpx.Request("GET", url),
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
content=b"file content",
)
file_info = SimpleNamespace(
extension="pdf",
size=1024,
filename="file.pdf",
mimetype="application/pdf",
)
uploaded_file = SimpleNamespace(
id="uploaded-file-id",
name="file.pdf",
size=1024,
extension="pdf",
mime_type="application/pdf",
created_by="test-account-id",
created_at=datetime(2024, 1, 1, 12, 0, 0),
)
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp) as p_head,
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_resp) as p_get,
patch(
"controllers.console.remote_files.helpers.guess_file_info_from_response",
return_value=file_info,
),
patch(
"controllers.console.remote_files.FileService.is_file_size_within_limit",
return_value=True,
),
patch("controllers.console.remote_files.db", spec=["engine"]),
patch("controllers.console.remote_files.FileService") as mock_file_service,
patch(
"controllers.console.remote_files.file_helpers.get_signed_file_url",
return_value="http://example.com/signed-url",
),
patch("libs.login.check_csrf_token", return_value=None),
):
mock_file_service.return_value.upload_file.return_value = uploaded_file
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": url},
)
assert resp.status_code == 201
p_head.assert_called_once()
# GET is used either for fallback (HEAD fails) or to fetch content after HEAD succeeds
p_get.assert_called_once()
mock_file_service.return_value.upload_file.assert_called_once()
data = resp.get_json()
assert data["id"] == "uploaded-file-id"
assert data["name"] == "file.pdf"
assert data["size"] == 1024
assert data["extension"] == "pdf"
assert data["url"] == "http://example.com/signed-url"
assert data["mime_type"] == "application/pdf"
assert data["created_by"] == "test-account-id"
@pytest.mark.parametrize(
("size_ok", "raises", "expected_status", "expected_msg"),
[
# When size check fails in controller, API returns 413 with message "File size exceeded..."
(False, None, 413, "file size exceeded"),
# When service raises unsupported type, controller maps to 415 with message "File type not allowed."
(True, "unsupported", 415, "file type not allowed"),
],
)
def test_upload_remote_file_errors(
self, client, mock_account, auth_ctx, size_ok, raises, expected_status, expected_msg
):
url = "http://example.com/x.pdf"
head_resp = httpx.Response(
200,
request=httpx.Request("HEAD", url),
headers={"Content-Type": "application/pdf", "Content-Length": "9"},
)
file_info = SimpleNamespace(extension="pdf", size=9, filename="x.pdf", mimetype="application/pdf")
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp),
patch(
"controllers.console.remote_files.helpers.guess_file_info_from_response",
return_value=file_info,
),
patch(
"controllers.console.remote_files.FileService.is_file_size_within_limit",
return_value=size_ok,
),
patch("controllers.console.remote_files.db", spec=["engine"]),
patch("libs.login.check_csrf_token", return_value=None),
):
if raises == "unsupported":
from services.errors.file import UnsupportedFileTypeError
with patch("controllers.console.remote_files.FileService") as mock_file_service:
mock_file_service.return_value.upload_file.side_effect = UnsupportedFileTypeError("bad")
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": url},
)
else:
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": url},
)
assert resp.status_code == expected_status
data = resp.get_json()
msg = (data.get("error") or {}).get("message") or data.get("message", "")
assert expected_msg in msg.lower()
def test_upload_remote_file_fetch_failure(self, client, mock_account, auth_ctx):
"""Test upload when fetching of remote file fails."""
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch(
"controllers.console.remote_files.ssrf_proxy.head",
side_effect=httpx.RequestError("Connection failed"),
),
patch("libs.login.check_csrf_token", return_value=None),
):
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": "http://unreachable.com/file.pdf"},
)
assert resp.status_code == 400
data = resp.get_json()
msg = (data.get("error") or {}).get("message") or data.get("message", "")
assert "failed to fetch" in msg.lower()

View File

@ -0,0 +1,81 @@
from werkzeug.exceptions import Unauthorized
def unwrap(func):
"""
Recursively unwrap decorated functions.
"""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestFeatureApi:
def test_get_tenant_features_success(self, mocker):
from controllers.console.feature import FeatureApi
mocker.patch(
"controllers.console.feature.current_account_with_tenant",
return_value=("account_id", "tenant_123"),
)
mocker.patch("controllers.console.feature.FeatureService.get_features").return_value.model_dump.return_value = {
"features": {"feature_a": True}
}
api = FeatureApi()
raw_get = unwrap(FeatureApi.get)
result = raw_get(api)
assert result == {"features": {"feature_a": True}}
class TestSystemFeatureApi:
def test_get_system_features_authenticated(self, mocker):
"""
current_user.is_authenticated == True
"""
from controllers.console.feature import SystemFeatureApi
fake_user = mocker.Mock()
fake_user.is_authenticated = True
mocker.patch(
"controllers.console.feature.current_user",
fake_user,
)
mocker.patch(
"controllers.console.feature.FeatureService.get_system_features"
).return_value.model_dump.return_value = {"features": {"sys_feature": True}}
api = SystemFeatureApi()
result = api.get()
assert result == {"features": {"sys_feature": True}}
def test_get_system_features_unauthenticated(self, mocker):
"""
current_user.is_authenticated raises Unauthorized
"""
from controllers.console.feature import SystemFeatureApi
fake_user = mocker.Mock()
type(fake_user).is_authenticated = mocker.PropertyMock(side_effect=Unauthorized())
mocker.patch(
"controllers.console.feature.current_user",
fake_user,
)
mocker.patch(
"controllers.console.feature.FeatureService.get_system_features"
).return_value.model_dump.return_value = {"features": {"sys_feature": False}}
api = SystemFeatureApi()
result = api.get()
assert result == {"features": {"sys_feature": False}}

View File

@ -0,0 +1,300 @@
import io
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from constants import DOCUMENT_EXTENSIONS
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console.files import (
FileApi,
FilePreviewApi,
FileSupportTypeApi,
)
def unwrap(func):
"""
Recursively unwrap decorated functions.
"""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask(__name__)
app.testing = True
return app
@pytest.fixture(autouse=True)
def mock_decorators():
"""
Make decorators no-ops so logic is directly testable
"""
with (
patch("controllers.console.files.setup_required", new=lambda f: f),
patch("controllers.console.files.login_required", new=lambda f: f),
patch("controllers.console.files.account_initialization_required", new=lambda f: f),
patch("controllers.console.files.cloud_edition_billing_resource_check", return_value=lambda f: f),
):
yield
@pytest.fixture
def mock_current_user():
user = MagicMock()
user.is_dataset_editor = True
return user
@pytest.fixture
def mock_account_context(mock_current_user):
with patch(
"controllers.console.files.current_account_with_tenant",
return_value=(mock_current_user, None),
):
yield
@pytest.fixture
def mock_db():
with patch("controllers.console.files.db") as db_mock:
db_mock.engine = MagicMock()
yield db_mock
@pytest.fixture
def mock_file_service(mock_db):
with patch("controllers.console.files.FileService") as fs:
instance = fs.return_value
yield instance
class TestFileApiGet:
def test_get_upload_config(self, app):
api = FileApi()
get_method = unwrap(api.get)
with app.test_request_context():
data, status = get_method(api)
assert status == 200
assert "file_size_limit" in data
assert "batch_count_limit" in data
class TestFileApiPost:
def test_no_file_uploaded(self, app, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
with app.test_request_context(method="POST", data={}):
with pytest.raises(NoFileUploadedError):
post_method(api)
def test_too_many_files(self, app, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
with app.test_request_context(method="POST"):
from unittest.mock import MagicMock, patch
with patch("controllers.console.files.request") as mock_request:
mock_request.files = MagicMock()
mock_request.files.__len__.return_value = 2
mock_request.files.__contains__.return_value = True
mock_request.form = MagicMock()
mock_request.form.get.return_value = None
with pytest.raises(TooManyFilesError):
post_method(api)
def test_filename_missing(self, app, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
data = {
"file": (io.BytesIO(b"abc"), ""),
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(FilenameNotExistsError):
post_method(api)
def test_dataset_upload_without_permission(self, app, mock_current_user):
mock_current_user.is_dataset_editor = False
with patch(
"controllers.console.files.current_account_with_tenant",
return_value=(mock_current_user, None),
):
api = FileApi()
post_method = unwrap(api.post)
data = {
"file": (io.BytesIO(b"abc"), "test.txt"),
"source": "datasets",
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(Forbidden):
post_method(api)
def test_successful_upload(self, app, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
mock_file = MagicMock()
mock_file.id = "file-id-123"
mock_file.filename = "test.txt"
mock_file.name = "test.txt"
mock_file.size = 1024
mock_file.extension = "txt"
mock_file.mime_type = "text/plain"
mock_file.created_by = "user-123"
mock_file.created_at = 1234567890
mock_file.preview_url = "http://example.com/preview/file-id-123"
mock_file.source_url = "http://example.com/source/file-id-123"
mock_file.original_url = None
mock_file.user_id = "user-123"
mock_file.tenant_id = "tenant-123"
mock_file.conversation_id = None
mock_file.file_key = "file-key-123"
mock_file_service.upload_file.return_value = mock_file
data = {
"file": (io.BytesIO(b"hello"), "test.txt"),
}
with app.test_request_context(method="POST", data=data):
response, status = post_method(api)
assert status == 201
assert response["id"] == "file-id-123"
assert response["name"] == "test.txt"
def test_upload_with_invalid_source(self, app, mock_account_context, mock_file_service):
"""Test that invalid source parameter gets normalized to None"""
api = FileApi()
post_method = unwrap(api.post)
# Create a properly structured mock file object
mock_file = MagicMock()
mock_file.id = "file-id-456"
mock_file.filename = "test.txt"
mock_file.name = "test.txt"
mock_file.size = 512
mock_file.extension = "txt"
mock_file.mime_type = "text/plain"
mock_file.created_by = "user-456"
mock_file.created_at = 1234567890
mock_file.preview_url = None
mock_file.source_url = None
mock_file.original_url = None
mock_file.user_id = "user-456"
mock_file.tenant_id = "tenant-456"
mock_file.conversation_id = None
mock_file.file_key = "file-key-456"
mock_file_service.upload_file.return_value = mock_file
data = {
"file": (io.BytesIO(b"content"), "test.txt"),
"source": "invalid_source", # Should be normalized to None
}
with app.test_request_context(method="POST", data=data):
response, status = post_method(api)
assert status == 201
assert response["id"] == "file-id-456"
# Verify that FileService was called with source=None
mock_file_service.upload_file.assert_called_once()
call_kwargs = mock_file_service.upload_file.call_args[1]
assert call_kwargs["source"] is None
def test_file_too_large_error(self, app, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
error = ServiceFileTooLargeError("File is too large")
mock_file_service.upload_file.side_effect = error
data = {
"file": (io.BytesIO(b"x" * 1000000), "big.txt"),
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(FileTooLargeError):
post_method(api)
def test_unsupported_file_type(self, app, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
error = ServiceUnsupportedFileTypeError()
mock_file_service.upload_file.side_effect = error
data = {
"file": (io.BytesIO(b"x"), "bad.exe"),
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(UnsupportedFileTypeError):
post_method(api)
def test_blocked_extension(self, app, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
from services.errors.file import BlockedFileExtensionError as ServiceBlockedFileExtensionError
error = ServiceBlockedFileExtensionError("File extension is blocked")
mock_file_service.upload_file.side_effect = error
data = {
"file": (io.BytesIO(b"x"), "blocked.txt"),
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(BlockedFileExtensionError):
post_method(api)
class TestFilePreviewApi:
def test_get_preview(self, app, mock_file_service):
api = FilePreviewApi()
get_method = unwrap(api.get)
mock_file_service.get_file_preview.return_value = "preview text"
with app.test_request_context():
result = get_method(api, "1234")
assert result == {"content": "preview text"}
class TestFileSupportTypeApi:
def test_get_supported_types(self, app):
api = FileSupportTypeApi()
get_method = unwrap(api.get)
with app.test_request_context():
result = get_method(api)
assert result == {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}

View File

@ -0,0 +1,293 @@
from __future__ import annotations
import json
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Response
from controllers.console.human_input_form import (
ConsoleHumanInputFormApi,
ConsoleWorkflowEventsApi,
DifyAPIRepositoryFactory,
WorkflowResponseConverter,
_jsonify_form_definition,
)
from controllers.web.error import NotFoundError
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def test_jsonify_form_definition() -> None:
expiration = datetime(2024, 1, 1, tzinfo=UTC)
definition = SimpleNamespace(model_dump=lambda: {"fields": []})
form = SimpleNamespace(get_definition=lambda: definition, expiration_time=expiration)
response = _jsonify_form_definition(form)
assert isinstance(response, Response)
payload = json.loads(response.get_data(as_text=True))
assert payload["expiration_time"] == int(expiration.timestamp())
def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1")
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-2"))
with pytest.raises(NotFoundError):
ConsoleHumanInputFormApi._ensure_console_access(form)
def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
expiration = datetime(2024, 1, 1, tzinfo=UTC)
definition = SimpleNamespace(model_dump=lambda: {"fields": ["a"]})
form = SimpleNamespace(tenant_id="tenant-1", get_definition=lambda: definition, expiration_time=expiration)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_definition_by_token_for_console(self, _token):
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
response = handler(api, form_token="token")
payload = json.loads(response.get_data(as_text=True))
assert payload["fields"] == ["a"]
def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_definition_by_token_for_console(self, _token):
return None
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.EMAIL_MEMBER)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_by_token(self, _token):
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
method="POST",
json={"inputs": {"content": "ok"}, "action": "approve"},
):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
submit_mock = Mock()
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_by_token(self, _token):
return form
def submit_form_by_token(self, **kwargs):
submit_mock(**kwargs)
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
method="POST",
json={"inputs": {"content": "ok"}, "action": "approve"},
):
response = handler(api, form_token="token")
assert response.get_json() == {}
submit_mock.assert_called_once()
def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return None
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
handler(api, workflow_run_id="run-1")
def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
created_by_role=CreatorUserRole.END_USER,
created_by="user-1",
tenant_id="t1",
)
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return workflow_run
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
handler(api, workflow_run_id="run-1")
def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
created_by_role=CreatorUserRole.ACCOUNT,
created_by="user-2",
tenant_id="t1",
)
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return workflow_run
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
handler(api, workflow_run_id="run-1")
def test_workflow_events_finished(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
created_by_role=CreatorUserRole.ACCOUNT,
created_by="user-1",
tenant_id="t1",
app_id="app-1",
finished_at=datetime(2024, 1, 1, tzinfo=UTC),
)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return workflow_run
response_obj = SimpleNamespace(
event=SimpleNamespace(value="finished"),
model_dump=lambda mode="json": {"status": "done"},
)
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form._retrieve_app_for_workflow_run",
lambda *_args, **_kwargs: app_model,
)
monkeypatch.setattr(
WorkflowResponseConverter,
"workflow_run_result_to_finish_response",
lambda **_kwargs: response_obj,
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
response = handler(api, workflow_run_id="run-1")
assert response.mimetype == "text/event-stream"
assert "data" in response.get_data(as_text=True)

View File

@ -0,0 +1,108 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from controllers.console import init_validate
from controllers.console.error import AlreadySetupError, InitValidateFailedError
class _SessionStub:
def __init__(self, has_setup: bool):
self._has_setup = has_setup
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, *_args, **_kwargs):
return SimpleNamespace(scalar_one_or_none=lambda: Mock() if self._has_setup else None)
def test_get_init_status_finished(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: True)
result = init_validate.get_init_status()
assert result.status == "finished"
def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: False)
result = init_validate.get_init_status()
assert result.status == "not_started"
def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1)
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="POST"):
with pytest.raises(AlreadySetupError):
init_validate.validate_init_password(init_validate.InitValidatePayload(password="pw"))
def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="POST"):
with pytest.raises(InitValidateFailedError):
init_validate.validate_init_password(init_validate.InitValidatePayload(password="wrong"))
assert init_validate.session.get("is_init_validated") is False
def test_validate_init_password_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="POST"):
result = init_validate.validate_init_password(init_validate.InitValidatePayload(password="expected"))
assert result.result == "success"
assert init_validate.session.get("is_init_validated") is True
def test_get_init_validate_status_not_self_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "CLOUD")
assert init_validate.get_init_validate_status() is True
def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="GET"):
init_validate.session["is_init_validated"] = True
assert init_validate.get_init_validate_status() is True
def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setenv("INIT_PASSWORD", "expected")
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(True))
monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object()))
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="GET"):
init_validate.session.pop("is_init_validated", None)
assert init_validate.get_init_validate_status() is True
def test_get_init_validate_status_not_validated(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setenv("INIT_PASSWORD", "expected")
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(False))
monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object()))
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="GET"):
init_validate.session.pop("is_init_validated", None)
assert init_validate.get_init_validate_status() is False

View File

@ -0,0 +1,281 @@
from __future__ import annotations
import urllib.parse
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
import httpx
import pytest
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError
from controllers.console import remote_files as remote_files_module
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class _FakeResponse:
def __init__(
self,
*,
status_code: int = 200,
headers: dict[str, str] | None = None,
method: str = "GET",
content: bytes = b"",
text: str = "",
error: Exception | None = None,
) -> None:
self.status_code = status_code
self.headers = headers or {}
self.request = SimpleNamespace(method=method)
self.content = content
self.text = text
self._error = error
def raise_for_status(self) -> None:
if self._error:
raise self._error
def _mock_upload_dependencies(
monkeypatch: pytest.MonkeyPatch,
*,
file_size_within_limit: bool = True,
):
file_info = SimpleNamespace(
filename="report.txt",
extension=".txt",
mimetype="text/plain",
size=3,
)
monkeypatch.setattr(
remote_files_module.helpers,
"guess_file_info_from_response",
MagicMock(return_value=file_info),
)
file_service_cls = MagicMock()
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None))
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
remote_files_module.file_helpers,
"get_signed_file_url",
lambda upload_file_id: f"https://signed.example/{upload_file_id}",
)
return file_service_cls
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.GetRemoteFileInfo()
handler = _unwrap(api.get)
decoded_url = "https://example.com/test.txt"
encoded_url = urllib.parse.quote(decoded_url, safe="")
head_resp = _FakeResponse(
status_code=200,
headers={"Content-Type": "text/plain", "Content-Length": "128"},
method="HEAD",
)
head_mock = MagicMock(return_value=head_resp)
get_mock = MagicMock()
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
with app.test_request_context(method="GET"):
payload = handler(api, url=encoded_url)
assert payload == {"file_type": "text/plain", "file_length": 128}
head_mock.assert_called_once_with(decoded_url)
get_mock.assert_not_called()
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.GetRemoteFileInfo()
handler = _unwrap(api.get)
decoded_url = "https://example.com/test.txt"
encoded_url = urllib.parse.quote(decoded_url, safe="")
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=503)))
get_mock = MagicMock(return_value=_FakeResponse(status_code=200, headers={}, method="GET"))
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
with app.test_request_context(method="GET"):
payload = handler(api, url=encoded_url)
assert payload == {"file_type": "application/octet-stream", "file_length": 0}
get_mock.assert_called_once_with(decoded_url, timeout=3)
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/report.txt"
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=404)))
get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content")
get_mock = MagicMock(return_value=get_resp)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
file_service_cls = _mock_upload_dependencies(monkeypatch)
upload_file = SimpleNamespace(
id="file-1",
name="report.txt",
size=16,
extension=".txt",
mime_type="text/plain",
created_by="u1",
created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
file_service_cls.return_value.upload_file.return_value = upload_file
with app.test_request_context(method="POST", json={"url": url}):
payload, status = handler(api)
assert status == 201
assert payload["id"] == "file-1"
assert payload["url"] == "https://signed.example/file-1"
get_mock.assert_called_once_with(url=url, timeout=3, follow_redirects=True)
file_service_cls.return_value.upload_file.assert_called_once_with(
filename="report.txt",
content=b"fallback-content",
mimetype="text/plain",
user=SimpleNamespace(id="u1"),
source_url=url,
)
def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
app, monkeypatch: pytest.MonkeyPatch
) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/photo.jpg"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="HEAD", content=b"head-content")),
)
extra_get_resp = _FakeResponse(status_code=200, method="GET", content=b"downloaded-content")
get_mock = MagicMock(return_value=extra_get_resp)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
file_service_cls = _mock_upload_dependencies(monkeypatch)
upload_file = SimpleNamespace(
id="file-2",
name="photo.jpg",
size=18,
extension=".jpg",
mime_type="image/jpeg",
created_by="u1",
created_at=datetime(2024, 1, 2, tzinfo=UTC),
)
file_service_cls.return_value.upload_file.return_value = upload_file
with app.test_request_context(method="POST", json={"url": url}):
payload, status = handler(api)
assert status == 201
assert payload["id"] == "file-2"
get_mock.assert_called_once_with(url)
assert file_service_cls.return_value.upload_file.call_args.kwargs["content"] == b"downloaded-content"
def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/fail.txt"
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=500)))
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"get",
MagicMock(return_value=_FakeResponse(status_code=502, text="bad gateway")),
)
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
handler(api)
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/fail.txt"
request = httpx.Request("HEAD", url)
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(side_effect=httpx.RequestError("network down", request=request)),
)
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
handler(api)
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/large.bin"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
_mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(FileTooLargeError):
handler(api)
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/large.bin"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
file_service_cls = _mock_upload_dependencies(monkeypatch)
file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded")
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(FileTooLargeError, match="size exceeded"):
handler(api)
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/file.exe"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
file_service_cls = _mock_upload_dependencies(monkeypatch)
file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError()
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(UnsupportedFileTypeError):
handler(api)

View File

@ -0,0 +1,49 @@
from unittest.mock import patch
import controllers.console.spec as spec_module
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestSpecSchemaDefinitionsApi:
def test_get_success(self):
api = spec_module.SpecSchemaDefinitionsApi()
method = unwrap(api.get)
schema_definitions = [{"type": "string"}]
with patch.object(
spec_module,
"SchemaManager",
) as schema_manager_cls:
schema_manager_cls.return_value.get_all_schema_definitions.return_value = schema_definitions
resp, status = method(api)
assert status == 200
assert resp == schema_definitions
def test_get_exception_returns_empty_list(self):
api = spec_module.SpecSchemaDefinitionsApi()
method = unwrap(api.get)
with (
patch.object(
spec_module,
"SchemaManager",
side_effect=Exception("boom"),
),
patch.object(
spec_module.logger,
"exception",
) as log_exception,
):
resp, status = method(api)
assert status == 200
assert resp == []
log_exception.assert_called_once()

View File

@ -0,0 +1,162 @@
from unittest.mock import MagicMock, patch
import controllers.console.version as version_module
class TestHasNewVersion:
def test_has_new_version_true(self):
result = version_module._has_new_version(
latest_version="1.2.0",
current_version="1.1.0",
)
assert result is True
def test_has_new_version_false(self):
result = version_module._has_new_version(
latest_version="1.0.0",
current_version="1.1.0",
)
assert result is False
def test_has_new_version_invalid_version(self):
with patch.object(version_module.logger, "warning") as log_warning:
result = version_module._has_new_version(
latest_version="invalid",
current_version="1.0.0",
)
assert result is False
log_warning.assert_called_once()
class TestCheckVersionUpdate:
def test_no_check_update_url(self):
query = version_module.VersionQuery(current_version="1.0.0")
with (
patch.object(
version_module.dify_config,
"CHECK_UPDATE_URL",
"",
),
patch.object(
version_module.dify_config.project,
"version",
"1.0.0",
),
patch.object(
version_module.dify_config,
"CAN_REPLACE_LOGO",
True,
),
patch.object(
version_module.dify_config,
"MODEL_LB_ENABLED",
False,
),
):
result = version_module.check_version_update(query)
assert result.version == "1.0.0"
assert result.can_auto_update is False
assert result.features.can_replace_logo is True
assert result.features.model_load_balancing_enabled is False
def test_http_error_fallback(self):
query = version_module.VersionQuery(current_version="1.0.0")
with (
patch.object(
version_module.dify_config,
"CHECK_UPDATE_URL",
"http://example.com",
),
patch.object(
version_module.httpx,
"get",
side_effect=Exception("boom"),
),
patch.object(
version_module.logger,
"warning",
) as log_warning,
):
result = version_module.check_version_update(query)
assert result.version == "1.0.0"
log_warning.assert_called_once()
def test_new_version_available(self):
query = version_module.VersionQuery(current_version="1.0.0")
response = MagicMock()
response.json.return_value = {
"version": "1.2.0",
"releaseDate": "2024-01-01",
"releaseNotes": "New features",
"canAutoUpdate": True,
}
with (
patch.object(
version_module.dify_config,
"CHECK_UPDATE_URL",
"http://example.com",
),
patch.object(
version_module.httpx,
"get",
return_value=response,
),
patch.object(
version_module.dify_config.project,
"version",
"1.0.0",
),
patch.object(
version_module.dify_config,
"CAN_REPLACE_LOGO",
False,
),
patch.object(
version_module.dify_config,
"MODEL_LB_ENABLED",
True,
),
):
result = version_module.check_version_update(query)
assert result.version == "1.2.0"
assert result.release_date == "2024-01-01"
assert result.release_notes == "New features"
assert result.can_auto_update is True
def test_no_new_version(self):
query = version_module.VersionQuery(current_version="1.2.0")
response = MagicMock()
response.json.return_value = {
"version": "1.1.0",
}
with (
patch.object(
version_module.dify_config,
"CHECK_UPDATE_URL",
"http://example.com",
),
patch.object(
version_module.httpx,
"get",
return_value=response,
),
patch.object(
version_module.dify_config.project,
"version",
"1.2.0",
),
):
result = version_module.check_version_update(query)
assert result.version == "1.2.0"
assert result.can_auto_update is False

View File

@ -0,0 +1,93 @@
"""Shared factory helpers for core.trigger test suite."""
from __future__ import annotations
from typing import Any
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.common_entities import I18nObject
from core.trigger.entities.entities import (
EventEntity,
EventIdentity,
EventParameter,
OAuthSchema,
Subscription,
SubscriptionConstructor,
TriggerProviderEntity,
TriggerProviderIdentity,
)
from core.trigger.provider import PluginTriggerProviderController
from models.provider_ids import TriggerProviderID
# Valid format for TriggerProviderID: org/plugin/provider
VALID_PROVIDER_ID = "testorg/testplugin/testprovider"
def i18n(text: str = "test") -> I18nObject:
return I18nObject(en_US=text, zh_Hans=text)
def make_event(name: str = "test_event", parameters: list[EventParameter] | None = None) -> EventEntity:
return EventEntity(
identity=EventIdentity(author="a", name=name, label=i18n(name)),
description=i18n(name),
parameters=parameters or [],
)
def make_provider_entity(
name: str = "test_provider",
events: list[EventEntity] | None = None,
constructor: SubscriptionConstructor | None = None,
subscription_schema: list[ProviderConfig] | None = None,
icon: str | None = "icon.png",
icon_dark: str | None = None,
) -> TriggerProviderEntity:
return TriggerProviderEntity(
identity=TriggerProviderIdentity(
author="a",
name=name,
label=i18n(name),
description=i18n(name),
icon=icon,
icon_dark=icon_dark,
),
events=events if events is not None else [make_event()],
subscription_constructor=constructor,
subscription_schema=subscription_schema or [],
)
def make_controller(
entity: TriggerProviderEntity | None = None,
tenant_id: str = "tenant-1",
provider_id: str = VALID_PROVIDER_ID,
) -> PluginTriggerProviderController:
return PluginTriggerProviderController(
entity=entity or make_provider_entity(),
plugin_id="plugin-1",
plugin_unique_identifier="uid-1",
provider_id=TriggerProviderID(provider_id),
tenant_id=tenant_id,
)
def make_subscription(**overrides: Any) -> Subscription:
defaults = {"expires_at": 9999999999, "endpoint": "https://hook.test", "properties": {"k": "v"}, "parameters": {}}
defaults.update(overrides)
return Subscription(**defaults)
def make_provider_config(
name: str = "api_key", required: bool = True, config_type: str = "secret-input"
) -> ProviderConfig:
return ProviderConfig(name=name, label=i18n(name), type=config_type, required=required)
def make_constructor(
credentials_schema: list[ProviderConfig] | None = None,
oauth_schema: OAuthSchema | None = None,
) -> SubscriptionConstructor:
return SubscriptionConstructor(
parameters=[], credentials_schema=credentials_schema or [], oauth_schema=oauth_schema
)

View File

@ -0,0 +1,93 @@
"""
Tests for core.trigger.debug.event_bus.TriggerDebugEventBus.
Covers: Lua-script dispatch/poll with Redis error resilience.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from redis import RedisError
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import PluginTriggerDebugEvent
class TestDispatch:
@patch("core.trigger.debug.event_bus.redis_client")
def test_returns_dispatch_count(self, mock_redis):
mock_redis.eval.return_value = 3
event = MagicMock()
event.model_dump_json.return_value = '{"test": true}'
result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key")
assert result == 3
mock_redis.eval.assert_called_once()
@patch("core.trigger.debug.event_bus.redis_client")
def test_redis_error_returns_zero(self, mock_redis):
mock_redis.eval.side_effect = RedisError("connection lost")
event = MagicMock()
event.model_dump_json.return_value = "{}"
result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key")
assert result == 0
class TestPoll:
@patch("core.trigger.debug.event_bus.redis_client")
def test_returns_deserialized_event(self, mock_redis):
event_json = PluginTriggerDebugEvent(
timestamp=100,
name="push",
user_id="u1",
request_id="r1",
subscription_id="s1",
provider_id="p1",
).model_dump_json()
mock_redis.eval.return_value = event_json
result = TriggerDebugEventBus.poll(
event_type=PluginTriggerDebugEvent,
pool_key="pool:key",
tenant_id="t1",
user_id="u1",
app_id="a1",
node_id="n1",
)
assert result is not None
assert result.name == "push"
@patch("core.trigger.debug.event_bus.redis_client")
def test_returns_none_when_no_event(self, mock_redis):
mock_redis.eval.return_value = None
result = TriggerDebugEventBus.poll(
event_type=PluginTriggerDebugEvent,
pool_key="pool:key",
tenant_id="t1",
user_id="u1",
app_id="a1",
node_id="n1",
)
assert result is None
@patch("core.trigger.debug.event_bus.redis_client")
def test_redis_error_returns_none(self, mock_redis):
mock_redis.eval.side_effect = RedisError("timeout")
result = TriggerDebugEventBus.poll(
event_type=PluginTriggerDebugEvent,
pool_key="pool:key",
tenant_id="t1",
user_id="u1",
app_id="a1",
node_id="n1",
)
assert result is None

View File

@ -0,0 +1,276 @@
"""
Tests for core.trigger.debug.event_selectors.
Covers: Plugin/Webhook/Schedule pollers, create_event_poller factory,
and select_trigger_debug_events orchestrator.
"""
from __future__ import annotations
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.trigger.debug.event_selectors import (
PluginTriggerDebugEventPoller,
ScheduleTriggerDebugEventPoller,
WebhookTriggerDebugEventPoller,
create_event_poller,
select_trigger_debug_events,
)
from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent
from core.workflow.enums import NodeType
from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID
def _make_poller_args(node_config: dict | None = None) -> dict:
return {
"tenant_id": "t1",
"user_id": "u1",
"app_id": "a1",
"node_config": node_config or {"data": {}},
"node_id": "n1",
}
def _plugin_node_config(provider_id: str = VALID_PROVIDER_ID) -> dict:
"""Valid node config for TriggerEventNodeData.model_validate."""
return {
"data": {
"title": "test",
"plugin_id": "org/testplugin",
"provider_id": provider_id,
"event_name": "push",
"subscription_id": "s1",
"plugin_unique_identifier": "uid-1",
}
}
class TestPluginTriggerDebugEventPoller:
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
def test_returns_workflow_args_on_success(self, mock_bus):
event = PluginTriggerDebugEvent(
timestamp=100,
name="push",
user_id="u1",
request_id="r1",
subscription_id="s1",
provider_id="p1",
)
mock_bus.poll.return_value = event
with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc:
mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse(
variables={"repo": "dify"},
cancelled=False,
)
poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
result = poller.poll()
assert result is not None
assert result.workflow_args["inputs"] == {"repo": "dify"}
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
def test_returns_none_when_no_event(self, mock_bus):
mock_bus.poll.return_value = None
poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
assert poller.poll() is None
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
def test_returns_none_when_invoke_cancelled(self, mock_bus):
event = PluginTriggerDebugEvent(
timestamp=100,
name="push",
user_id="u1",
request_id="r1",
subscription_id="s1",
provider_id="p1",
)
mock_bus.poll.return_value = event
with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc:
mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse(
variables={},
cancelled=True,
)
poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
assert poller.poll() is None
class TestWebhookTriggerDebugEventPoller:
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
def test_uses_inputs_directly_when_present(self, mock_bus):
event = WebhookDebugEvent(
timestamp=100,
request_id="r1",
node_id="n1",
payload={"inputs": {"key": "val"}, "webhook_data": {}},
)
mock_bus.poll.return_value = event
poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
result = poller.poll()
assert result is not None
assert result.workflow_args["inputs"] == {"key": "val"}
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
def test_falls_back_to_webhook_data(self, mock_bus):
event = WebhookDebugEvent(
timestamp=100,
request_id="r1",
node_id="n1",
payload={"webhook_data": {"body": "raw"}},
)
mock_bus.poll.return_value = event
with patch("services.trigger.webhook_service.WebhookService") as mock_webhook_svc:
mock_webhook_svc.build_workflow_inputs.return_value = {"parsed": "data"}
poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
result = poller.poll()
assert result is not None
assert result.workflow_args["inputs"] == {"parsed": "data"}
mock_webhook_svc.build_workflow_inputs.assert_called_once_with({"body": "raw"})
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
def test_returns_none_when_no_event(self, mock_bus):
mock_bus.poll.return_value = None
poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
assert poller.poll() is None
class TestScheduleTriggerDebugEventPoller:
def _make_schedule_poller(self, mock_redis, mock_schedule_svc, next_run_at: datetime):
"""Set up mocks and create a schedule poller."""
mock_redis.get.return_value = None
mock_schedule_config = MagicMock()
mock_schedule_config.cron_expression = "0 * * * *"
mock_schedule_config.timezone = "UTC"
mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
return ScheduleTriggerDebugEventPoller(**_make_poller_args())
@patch("core.trigger.debug.event_selectors.redis_client")
@patch("core.trigger.debug.event_selectors.naive_utc_now")
@patch("core.trigger.debug.event_selectors.calculate_next_run_at")
@patch("core.trigger.debug.event_selectors.ensure_naive_utc")
def test_returns_none_when_not_yet_due(self, mock_ensure, mock_calc, mock_now, mock_redis):
now = datetime(2025, 1, 1, 12, 0, 0)
next_run = datetime(2025, 1, 1, 13, 0, 0) # future
mock_now.return_value = now
mock_calc.return_value = next_run
mock_ensure.return_value = next_run
mock_redis.get.return_value = None
with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc:
mock_schedule_config = MagicMock()
mock_schedule_config.cron_expression = "0 * * * *"
mock_schedule_config.timezone = "UTC"
mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
poller = ScheduleTriggerDebugEventPoller(**_make_poller_args())
assert poller.poll() is None
@patch("core.trigger.debug.event_selectors.redis_client")
@patch("core.trigger.debug.event_selectors.naive_utc_now")
@patch("core.trigger.debug.event_selectors.calculate_next_run_at")
@patch("core.trigger.debug.event_selectors.ensure_naive_utc")
def test_fires_event_when_due(self, mock_ensure, mock_calc, mock_now, mock_redis):
now = datetime(2025, 1, 1, 14, 0, 0)
next_run = datetime(2025, 1, 1, 12, 0, 0) # past
mock_now.return_value = now
mock_calc.return_value = next_run
mock_ensure.return_value = next_run
mock_redis.get.return_value = None
with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc:
mock_schedule_config = MagicMock()
mock_schedule_config.cron_expression = "0 * * * *"
mock_schedule_config.timezone = "UTC"
mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
poller = ScheduleTriggerDebugEventPoller(**_make_poller_args())
result = poller.poll()
assert result is not None
mock_redis.delete.assert_called_once()
class TestCreateEventPoller:
def _workflow_with_node(self, node_type: NodeType):
wf = MagicMock()
wf.get_node_config_by_id.return_value = {"data": {}}
wf.get_node_type_from_node_config.return_value = node_type
return wf
def test_creates_plugin_poller(self):
wf = self._workflow_with_node(NodeType.TRIGGER_PLUGIN)
poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
assert isinstance(poller, PluginTriggerDebugEventPoller)
def test_creates_webhook_poller(self):
wf = self._workflow_with_node(NodeType.TRIGGER_WEBHOOK)
poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
assert isinstance(poller, WebhookTriggerDebugEventPoller)
def test_creates_schedule_poller(self):
wf = self._workflow_with_node(NodeType.TRIGGER_SCHEDULE)
poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
assert isinstance(poller, ScheduleTriggerDebugEventPoller)
def test_raises_for_unknown_type(self):
wf = MagicMock()
wf.get_node_config_by_id.return_value = {"data": {}}
wf.get_node_type_from_node_config.return_value = NodeType.START
with pytest.raises(ValueError):
create_event_poller(wf, "t1", "u1", "a1", "n1")
def test_raises_when_node_config_missing(self):
wf = MagicMock()
wf.get_node_config_by_id.return_value = None
with pytest.raises(ValueError):
create_event_poller(wf, "t1", "u1", "a1", "n1")
class TestSelectTriggerDebugEvents:
def test_returns_first_non_none_event(self):
wf = MagicMock()
wf.get_node_config_by_id.return_value = {"data": {}}
wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK
app_model = MagicMock()
app_model.tenant_id = "t1"
app_model.id = "a1"
with patch.object(WebhookTriggerDebugEventPoller, "poll") as mock_poll:
expected = MagicMock()
mock_poll.return_value = expected
result = select_trigger_debug_events(wf, app_model, "u1", ["n1", "n2"])
assert result is expected
def test_returns_none_when_no_events(self):
wf = MagicMock()
wf.get_node_config_by_id.return_value = {"data": {}}
wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK
app_model = MagicMock()
app_model.tenant_id = "t1"
app_model.id = "a1"
with patch.object(WebhookTriggerDebugEventPoller, "poll", return_value=None):
result = select_trigger_debug_events(wf, app_model, "u1", ["n1"])
assert result is None

View File

@ -0,0 +1,332 @@
"""
Tests for core.trigger.provider.PluginTriggerProviderController.
Covers: to_api_entity creation-method logic, credential validation pipeline,
schema resolution by type, event lookup, dispatch/invoke/subscribe delegation.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.entities import (
EventParameter,
EventParameterType,
OAuthSchema,
TriggerCreationMethod,
)
from core.trigger.errors import TriggerProviderCredentialValidationError
from tests.unit_tests.core.trigger.conftest import (
i18n,
make_constructor,
make_controller,
make_event,
make_provider_config,
make_provider_entity,
make_subscription,
)
ICON_URL = "https://cdn/icon.png"
class TestToApiEntity:
@patch("core.trigger.provider.PluginService")
def test_includes_icons_when_present(self, mock_plugin_svc):
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
ctrl = make_controller(entity=make_provider_entity(icon="icon.png", icon_dark="dark.png"))
api = ctrl.to_api_entity()
assert api.icon == ICON_URL
assert api.icon_dark == ICON_URL
@patch("core.trigger.provider.PluginService")
def test_icons_none_when_absent(self, mock_plugin_svc):
ctrl = make_controller(entity=make_provider_entity(icon=None, icon_dark=None))
api = ctrl.to_api_entity()
assert api.icon is None
assert api.icon_dark is None
mock_plugin_svc.get_plugin_icon_url.assert_not_called()
@patch("core.trigger.provider.PluginService")
def test_manual_only_without_schemas(self, mock_plugin_svc):
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
ctrl = make_controller(entity=make_provider_entity(constructor=None))
api = ctrl.to_api_entity()
assert api.supported_creation_methods == [TriggerCreationMethod.MANUAL]
@patch("core.trigger.provider.PluginService")
def test_adds_oauth_when_oauth_schema_present(self, mock_plugin_svc):
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
oauth = OAuthSchema(client_schema=[], credentials_schema=[])
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
api = ctrl.to_api_entity()
assert TriggerCreationMethod.OAUTH in api.supported_creation_methods
assert TriggerCreationMethod.MANUAL in api.supported_creation_methods
@patch("core.trigger.provider.PluginService")
def test_adds_apikey_when_credentials_schema_present(self, mock_plugin_svc):
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
ctrl = make_controller(
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
)
api = ctrl.to_api_entity()
assert TriggerCreationMethod.APIKEY in api.supported_creation_methods
class TestGetEvent:
def test_returns_matching_event(self):
evt = make_event("push")
ctrl = make_controller(entity=make_provider_entity(events=[evt, make_event("pr")]))
assert ctrl.get_event("push") is evt
def test_returns_none_for_unknown(self):
ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")]))
assert ctrl.get_event("nonexistent") is None
class TestGetSubscriptionDefaultProperties:
def test_returns_defaults_skipping_none(self):
config1 = make_provider_config("key1")
config1.default = "val1"
config2 = make_provider_config("key2")
config2.default = None
ctrl = make_controller(entity=make_provider_entity(subscription_schema=[config1, config2]))
props = ctrl.get_subscription_default_properties()
assert props == {"key1": "val1"}
class TestValidateCredentials:
def test_raises_when_no_constructor(self):
ctrl = make_controller(entity=make_provider_entity(constructor=None))
with pytest.raises(ValueError, match="Subscription constructor not found"):
ctrl.validate_credentials("u1", {"key": "val"})
def test_raises_for_missing_required_field(self):
required_cfg = make_provider_config("api_key", required=True)
ctrl = make_controller(
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
)
with pytest.raises(TriggerProviderCredentialValidationError, match="Missing required"):
ctrl.validate_credentials("u1", {})
@patch("core.trigger.provider.PluginTriggerClient")
def test_passes_with_valid_credentials(self, mock_client):
required_cfg = make_provider_config("api_key", required=True)
ctrl = make_controller(
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
)
mock_client.return_value.validate_provider_credentials.return_value = True
ctrl.validate_credentials("u1", {"api_key": "secret123"}) # should not raise
@patch("core.trigger.provider.PluginTriggerClient")
def test_raises_when_plugin_rejects(self, mock_client):
required_cfg = make_provider_config("api_key", required=True)
ctrl = make_controller(
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
)
mock_client.return_value.validate_provider_credentials.return_value = None
with pytest.raises(TriggerProviderCredentialValidationError, match="Invalid credentials"):
ctrl.validate_credentials("u1", {"api_key": "bad"})
class TestGetSupportedCredentialTypes:
def test_empty_when_no_constructor(self):
ctrl = make_controller(entity=make_provider_entity(constructor=None))
assert ctrl.get_supported_credential_types() == []
def test_oauth_only(self):
oauth = OAuthSchema(client_schema=[], credentials_schema=[])
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
types = ctrl.get_supported_credential_types()
assert CredentialType.OAUTH2 in types
assert CredentialType.API_KEY not in types
def test_apikey_only(self):
ctrl = make_controller(
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
)
types = ctrl.get_supported_credential_types()
assert CredentialType.API_KEY in types
assert CredentialType.OAUTH2 not in types
def test_both(self):
oauth = OAuthSchema(client_schema=[], credentials_schema=[make_provider_config("oauth_secret")])
ctrl = make_controller(
entity=make_provider_entity(
constructor=make_constructor(credentials_schema=[make_provider_config()], oauth_schema=oauth)
)
)
types = ctrl.get_supported_credential_types()
assert CredentialType.OAUTH2 in types
assert CredentialType.API_KEY in types
class TestGetCredentialsSchema:
def test_returns_empty_when_no_constructor(self):
ctrl = make_controller(entity=make_provider_entity(constructor=None))
assert ctrl.get_credentials_schema(CredentialType.API_KEY) == []
def test_returns_apikey_credentials(self):
cfg = make_provider_config("token")
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(credentials_schema=[cfg])))
result = ctrl.get_credentials_schema(CredentialType.API_KEY)
assert len(result) == 1
assert result[0].name == "token"
def test_returns_oauth_credentials(self):
oauth_cred = make_provider_config("oauth_token")
oauth = OAuthSchema(client_schema=[], credentials_schema=[oauth_cred])
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
result = ctrl.get_credentials_schema(CredentialType.OAUTH2)
assert len(result) == 1
assert result[0].name == "oauth_token"
def test_unauthorized_returns_empty(self):
ctrl = make_controller(
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
)
assert ctrl.get_credentials_schema(CredentialType.UNAUTHORIZED) == []
def test_invalid_type_raises(self):
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor()))
with pytest.raises(ValueError, match="Invalid credential type"):
ctrl.get_credentials_schema("bogus_type")
class TestGetEventParameters:
def test_returns_params_for_known_event(self):
param = EventParameter(name="branch", label=i18n("branch"), type=EventParameterType.STRING)
evt = make_event("push", parameters=[param])
ctrl = make_controller(entity=make_provider_entity(events=[evt]))
result = ctrl.get_event_parameters("push")
assert "branch" in result
assert result["branch"].name == "branch"
def test_returns_empty_for_unknown_event(self):
ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")]))
assert ctrl.get_event_parameters("nonexistent") == {}
class TestDispatch:
@patch("core.trigger.provider.PluginTriggerClient")
def test_delegates_to_client(self, mock_client):
ctrl = make_controller()
expected = MagicMock()
mock_client.return_value.dispatch_event.return_value = expected
result = ctrl.dispatch(
request=MagicMock(),
subscription=make_subscription(),
credentials={"k": "v"},
credential_type=CredentialType.API_KEY,
)
assert result is expected
mock_client.return_value.dispatch_event.assert_called_once()
class TestInvokeTriggerEvent:
@patch("core.trigger.provider.PluginTriggerClient")
def test_delegates_to_client(self, mock_client):
ctrl = make_controller()
expected = MagicMock()
mock_client.return_value.invoke_trigger_event.return_value = expected
result = ctrl.invoke_trigger_event(
user_id="u1",
event_name="push",
parameters={},
credentials={},
credential_type=CredentialType.API_KEY,
subscription=make_subscription(),
request=MagicMock(),
payload={},
)
assert result is expected
class TestSubscribeTrigger:
@patch("core.trigger.provider.PluginTriggerClient")
def test_returns_validated_subscription(self, mock_client):
ctrl = make_controller()
mock_client.return_value.subscribe.return_value.subscription = {
"expires_at": 123,
"endpoint": "https://e",
"properties": {},
}
result = ctrl.subscribe_trigger(
user_id="u1",
endpoint="https://e",
parameters={},
credentials={},
credential_type=CredentialType.API_KEY,
)
assert result.endpoint == "https://e"
class TestUnsubscribeTrigger:
@patch("core.trigger.provider.PluginTriggerClient")
def test_returns_validated_result(self, mock_client):
ctrl = make_controller()
mock_client.return_value.unsubscribe.return_value.subscription = {"success": True, "message": "ok"}
result = ctrl.unsubscribe_trigger(
user_id="u1",
subscription=make_subscription(),
credentials={},
credential_type=CredentialType.API_KEY,
)
assert result.success is True
class TestRefreshTrigger:
@patch("core.trigger.provider.PluginTriggerClient")
def test_uses_system_user_id(self, mock_client):
ctrl = make_controller()
mock_client.return_value.refresh.return_value.subscription = {
"expires_at": 456,
"endpoint": "https://e",
"properties": {},
}
ctrl.refresh_trigger(subscription=make_subscription(), credentials={}, credential_type=CredentialType.API_KEY)
call_kwargs = mock_client.return_value.refresh.call_args[1]
assert call_kwargs["user_id"] == "system"

View File

@ -0,0 +1,307 @@
"""
Tests for core.trigger.trigger_manager.TriggerManager.
Covers: icon URL construction, provider listing with error resilience,
double-check lock caching, error translation, EventIgnoreError -> cancelled,
and delegation to provider controller.
"""
from __future__ import annotations
from threading import Lock
from unittest.mock import MagicMock, patch
import pytest
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginDaemonError, PluginNotFoundError
from core.trigger.errors import EventIgnoreError
from core.trigger.trigger_manager import TriggerManager
from models.provider_ids import TriggerProviderID
from tests.unit_tests.core.trigger.conftest import (
VALID_PROVIDER_ID,
make_controller,
make_provider_entity,
make_subscription,
)
PID = TriggerProviderID(VALID_PROVIDER_ID)
PID_STR = str(PID)
class TestGetTriggerPluginIcon:
@patch("core.trigger.trigger_manager.dify_config")
@patch("core.trigger.trigger_manager.PluginTriggerClient")
def test_builds_correct_url(self, mock_client, mock_config):
mock_config.CONSOLE_API_URL = "https://console.example.com"
provider = MagicMock()
provider.declaration.identity.icon = "my-icon.svg"
mock_client.return_value.fetch_trigger_provider.return_value = provider
url = TriggerManager.get_trigger_plugin_icon("tenant-1", VALID_PROVIDER_ID)
assert "tenant_id=tenant-1" in url
assert "filename=my-icon.svg" in url
assert url.startswith("https://console.example.com/console/api/workspaces/current/plugin/icon")
class TestListPluginTriggerProviders:
@patch("core.trigger.trigger_manager.PluginTriggerClient")
def test_wraps_entities_into_controllers(self, mock_client):
entity = MagicMock()
entity.declaration = make_provider_entity("p1")
entity.plugin_id = "plugin-1"
entity.plugin_unique_identifier = "uid-1"
entity.provider = VALID_PROVIDER_ID
mock_client.return_value.fetch_trigger_providers.return_value = [entity]
controllers = TriggerManager.list_plugin_trigger_providers("tenant-1")
assert len(controllers) == 1
assert controllers[0].plugin_id == "plugin-1"
@patch("core.trigger.trigger_manager.PluginTriggerClient")
def test_skips_failing_providers(self, mock_client):
good = MagicMock()
good.declaration = make_provider_entity("good")
good.plugin_id = "good-plugin"
good.plugin_unique_identifier = "uid-good"
good.provider = VALID_PROVIDER_ID
bad = MagicMock()
bad.declaration = make_provider_entity("bad")
bad.plugin_id = "bad-plugin"
bad.plugin_unique_identifier = "uid-bad"
bad.provider = "bad/format" # 2-part: fails TriggerProviderID validation
mock_client.return_value.fetch_trigger_providers.return_value = [bad, good]
controllers = TriggerManager.list_plugin_trigger_providers("tenant-1")
assert len(controllers) == 1
assert controllers[0].plugin_id == "good-plugin"
class TestGetTriggerProvider:
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_initializes_context_on_first_call(self, mock_ctx, mock_client):
# get() called 3 times: (1) try block, (2) after set, (3) under lock
mock_ctx.plugin_trigger_providers.get.side_effect = [LookupError, {}, {}]
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
provider = MagicMock()
provider.declaration = make_provider_entity()
provider.plugin_id = "p1"
provider.plugin_unique_identifier = "uid-1"
mock_client.return_value.fetch_trigger_provider.return_value = provider
result = TriggerManager.get_trigger_provider("t1", PID)
mock_ctx.plugin_trigger_providers.set.assert_called_once_with({})
mock_ctx.plugin_trigger_providers_lock.set.assert_called_once()
assert result is not None
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_returns_cached_without_fetch(self, mock_ctx, mock_client):
cached = make_controller()
mock_ctx.plugin_trigger_providers.get.return_value = {PID_STR: cached}
result = TriggerManager.get_trigger_provider("t1", PID)
assert result is cached
mock_client.return_value.fetch_trigger_provider.assert_not_called()
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_double_check_lock_uses_cached_from_other_thread(self, mock_ctx, mock_client):
cached = make_controller()
mock_ctx.plugin_trigger_providers.get.side_effect = [
{}, # first check misses
{PID_STR: cached}, # under-lock check hits
]
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
result = TriggerManager.get_trigger_provider("t1", PID)
assert result is cached
mock_client.return_value.fetch_trigger_provider.assert_not_called()
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_fetches_and_caches_on_miss(self, mock_ctx, mock_client):
cache: dict = {}
mock_ctx.plugin_trigger_providers.get.return_value = cache
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
provider = MagicMock()
provider.declaration = make_provider_entity()
provider.plugin_id = "p1"
provider.plugin_unique_identifier = "uid-1"
mock_client.return_value.fetch_trigger_provider.return_value = provider
result = TriggerManager.get_trigger_provider("t1", PID)
assert result is not None
assert PID_STR in cache
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_none_fetch_raises_value_error(self, mock_ctx, mock_client):
mock_ctx.plugin_trigger_providers.get.return_value = {}
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
mock_client.return_value.fetch_trigger_provider.return_value = None
with pytest.raises(ValueError):
TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/missing"))
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_plugin_not_found_becomes_value_error(self, mock_ctx, mock_client):
mock_ctx.plugin_trigger_providers.get.return_value = {}
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
mock_client.return_value.fetch_trigger_provider.side_effect = PluginNotFoundError("gone")
with pytest.raises(ValueError):
TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss"))
@patch("core.trigger.trigger_manager.PluginTriggerClient")
@patch("core.trigger.trigger_manager.contexts")
def test_plugin_daemon_error_propagates(self, mock_ctx, mock_client):
mock_ctx.plugin_trigger_providers.get.return_value = {}
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
mock_client.return_value.fetch_trigger_provider.side_effect = PluginDaemonError("test error")
with pytest.raises(PluginDaemonError):
TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss"))
class TestListAllTriggerProviders:
@patch.object(TriggerManager, "list_plugin_trigger_providers")
def test_delegates_to_list_plugin(self, mock_list):
expected = [make_controller()]
mock_list.return_value = expected
assert TriggerManager.list_all_trigger_providers("t1") is expected
mock_list.assert_called_once_with("t1")
class TestListTriggersByProvider:
@patch.object(TriggerManager, "get_trigger_provider")
def test_returns_provider_events(self, mock_get):
ctrl = make_controller()
mock_get.return_value = ctrl
result = TriggerManager.list_triggers_by_provider("t1", PID)
assert result == ctrl.get_events()
class TestInvokeTriggerEvent:
def _args(self):
return {
"tenant_id": "t1",
"user_id": "u1",
"provider_id": PID,
"event_name": "on_push",
"parameters": {"branch": "main"},
"credentials": {"token": "abc"},
"credential_type": CredentialType.API_KEY,
"subscription": make_subscription(),
"request": MagicMock(),
"payload": {"action": "push"},
}
@patch.object(TriggerManager, "get_trigger_provider")
def test_returns_invoke_response(self, mock_get):
ctrl = MagicMock()
expected = TriggerInvokeEventResponse(variables={"v": "1"}, cancelled=False)
ctrl.invoke_trigger_event.return_value = expected
mock_get.return_value = ctrl
result = TriggerManager.invoke_trigger_event(**self._args())
assert result is expected
assert result.cancelled is False
@patch.object(TriggerManager, "get_trigger_provider")
def test_event_ignore_returns_cancelled(self, mock_get):
ctrl = MagicMock()
ctrl.invoke_trigger_event.side_effect = EventIgnoreError("skip")
mock_get.return_value = ctrl
result = TriggerManager.invoke_trigger_event(**self._args())
assert result.cancelled is True
assert result.variables == {}
@patch.object(TriggerManager, "get_trigger_provider")
def test_other_errors_propagate(self, mock_get):
ctrl = MagicMock()
ctrl.invoke_trigger_event.side_effect = RuntimeError("boom")
mock_get.return_value = ctrl
with pytest.raises(RuntimeError, match="boom"):
TriggerManager.invoke_trigger_event(**self._args())
class TestSubscribeTrigger:
@patch.object(TriggerManager, "get_trigger_provider")
def test_delegates_with_correct_args(self, mock_get):
ctrl = MagicMock()
expected = make_subscription()
ctrl.subscribe_trigger.return_value = expected
mock_get.return_value = ctrl
result = TriggerManager.subscribe_trigger(
tenant_id="t1",
user_id="u1",
provider_id=PID,
endpoint="https://hook.test",
parameters={"f": "all"},
credentials={"token": "x"},
credential_type=CredentialType.API_KEY,
)
assert result is expected
ctrl.subscribe_trigger.assert_called_once()
class TestUnsubscribeTrigger:
@patch.object(TriggerManager, "get_trigger_provider")
def test_delegates_with_correct_args(self, mock_get):
ctrl = MagicMock()
expected = MagicMock()
ctrl.unsubscribe_trigger.return_value = expected
mock_get.return_value = ctrl
sub = make_subscription()
result = TriggerManager.unsubscribe_trigger(
tenant_id="t1",
user_id="u1",
provider_id=PID,
subscription=sub,
credentials={"token": "x"},
credential_type=CredentialType.API_KEY,
)
assert result is expected
class TestRefreshTrigger:
@patch.object(TriggerManager, "get_trigger_provider")
def test_delegates_with_correct_args(self, mock_get):
ctrl = MagicMock()
expected = make_subscription()
ctrl.refresh_trigger.return_value = expected
mock_get.return_value = ctrl
result = TriggerManager.refresh_trigger(
tenant_id="t1",
provider_id=PID,
subscription=make_subscription(),
credentials={"token": "x"},
credential_type=CredentialType.API_KEY,
)
assert result is expected

View File

@ -0,0 +1,62 @@
"""Tests for core.trigger.utils.encryption — masking logic and cache key generation."""
from __future__ import annotations
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.common_entities import I18nObject
from core.trigger.utils.encryption import (
TriggerProviderCredentialsCache,
TriggerProviderOAuthClientParamsCache,
TriggerProviderPropertiesCache,
masked_credentials,
)
def _make_schema(name: str, field_type: str = "secret-input") -> ProviderConfig:
return ProviderConfig(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
type=field_type,
)
class TestMaskedCredentials:
def test_short_secret_fully_masked(self):
schema = [_make_schema("key", "secret-input")]
result = masked_credentials(schema, {"key": "ab"})
assert result["key"] == "**"
def test_long_secret_partially_masked(self):
schema = [_make_schema("key", "secret-input")]
result = masked_credentials(schema, {"key": "abcdef"})
assert result["key"].startswith("ab")
assert result["key"].endswith("ef")
assert "**" in result["key"]
def test_non_secret_field_unchanged(self):
schema = [_make_schema("host", "text-input")]
result = masked_credentials(schema, {"host": "example.com"})
assert result["host"] == "example.com"
def test_unknown_key_passes_through(self):
result = masked_credentials([], {"unknown": "value"})
assert result["unknown"] == "value"
class TestCacheKeyGeneration:
def test_credentials_cache_key_contains_ids(self):
cache = TriggerProviderCredentialsCache(tenant_id="t1", provider_id="p1", credential_id="c1")
assert "t1" in cache.cache_key
assert "p1" in cache.cache_key
assert "c1" in cache.cache_key
def test_oauth_client_cache_key_contains_ids(self):
cache = TriggerProviderOAuthClientParamsCache(tenant_id="t1", provider_id="p1")
assert "t1" in cache.cache_key
assert "p1" in cache.cache_key
def test_properties_cache_key_contains_ids(self):
cache = TriggerProviderPropertiesCache(tenant_id="t1", provider_id="p1", subscription_id="s1")
assert "t1" in cache.cache_key
assert "p1" in cache.cache_key
assert "s1" in cache.cache_key

View File

@ -0,0 +1,31 @@
"""Tests for core.trigger.utils.endpoint — URL generation."""
from __future__ import annotations
from unittest.mock import patch
from yarl import URL
from core.trigger.utils import endpoint
class TestGeneratePluginTriggerEndpointUrl:
def test_builds_correct_url(self):
with patch.object(endpoint, "base_url", URL("https://api.example.com")):
url = endpoint.generate_plugin_trigger_endpoint_url("endpoint-123")
assert url == "https://api.example.com/triggers/plugin/endpoint-123"
class TestGenerateWebhookTriggerEndpoint:
def test_non_debug_url(self):
with patch.object(endpoint, "base_url", URL("https://api.example.com")):
url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=False)
assert url == "https://api.example.com/triggers/webhook/sub-456"
def test_debug_url(self):
with patch.object(endpoint, "base_url", URL("https://api.example.com")):
url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=True)
assert url == "https://api.example.com/triggers/webhook-debug/sub-456"

View File

@ -0,0 +1,23 @@
"""Tests for core.trigger.utils.locks — Redis lock key builders."""
from __future__ import annotations
from core.trigger.utils.locks import build_trigger_refresh_lock_key, build_trigger_refresh_lock_keys
class TestBuildTriggerRefreshLockKey:
def test_correct_format(self):
key = build_trigger_refresh_lock_key("tenant-1", "sub-1")
assert key == "trigger_provider_refresh_lock:tenant-1_sub-1"
class TestBuildTriggerRefreshLockKeys:
def test_maps_over_pairs(self):
pairs = [("t1", "s1"), ("t2", "s2")]
keys = build_trigger_refresh_lock_keys(pairs)
assert len(keys) == 2
assert keys[0] == "trigger_provider_refresh_lock:t1_s1"
assert keys[1] == "trigger_provider_refresh_lock:t2_s2"

View File

@ -22,7 +22,7 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from dify_graph.nodes.llm import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
from dify_graph.nodes.question_classifier import QuestionClassifierNode
from dify_graph.nodes.template_transform import TemplateTransformNode
from dify_graph.nodes.template_transform.template_renderer import (
@ -73,6 +73,12 @@ class MockNodeMixin:
if isinstance(self, TemplateTransformNode):
kwargs.setdefault("template_renderer", _TestJinja2Renderer())
# Provide default tool_file_manager_factory for ToolNode subclasses
from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles
if isinstance(self, _ToolNode):
kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol))
super().__init__(
id=id,
config=config,

View File

@ -31,6 +31,7 @@ def tool_node(monkeypatch) -> ToolNode:
ops_stub.TraceTask = object # pragma: no cover - stub attribute
monkeypatch.setitem(sys.modules, module_name, ops_stub)
from dify_graph.nodes.protocols import ToolFileManagerProtocol
from dify_graph.nodes.tool.tool_node import ToolNode
graph_config: dict[str, Any] = {
@ -69,11 +70,16 @@ def tool_node(monkeypatch) -> ToolNode:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
config = graph_config["nodes"][0]
# Provide a stub ToolFileManager to satisfy the updated ToolNode constructor
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
node = ToolNode(
id="node-instance",
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
)
return node