mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
Merge branch 'feat/model-plugins-implementing' into deploy/dev
This commit is contained in:
@ -45,7 +45,6 @@ allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
|
||||
|
||||
@ -111,7 +110,6 @@ ignore_imports =
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.tool.tool_node -> models
|
||||
dify_graph.nodes.agent.agent_node -> models.model
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
@ -134,7 +132,6 @@ ignore_imports =
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.errors
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.nodes.agent.agent_node -> models
|
||||
dify_graph.nodes.llm.node -> models.model
|
||||
dify_graph.nodes.agent.agent_node -> services
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import json
|
||||
from contextlib import ExitStack
|
||||
from typing import Self
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask import request, send_file
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import desc, select
|
||||
@ -100,6 +101,15 @@ class DocumentListQuery(BaseModel):
|
||||
status: str | None = Field(default=None, description="Document status filter")
|
||||
|
||||
|
||||
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
|
||||
|
||||
|
||||
class DocumentBatchDownloadZipPayload(BaseModel):
|
||||
"""Request payload for bulk downloading uploaded documents as a ZIP archive."""
|
||||
|
||||
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
|
||||
|
||||
|
||||
register_enum_models(service_api_ns, RetrievalMethod)
|
||||
|
||||
register_schema_models(
|
||||
@ -109,6 +119,7 @@ register_schema_models(
|
||||
DocumentTextCreatePayload,
|
||||
DocumentTextUpdate,
|
||||
DocumentListQuery,
|
||||
DocumentBatchDownloadZipPayload,
|
||||
Rule,
|
||||
PreProcessingRule,
|
||||
Segmentation,
|
||||
@ -540,6 +551,46 @@ class DocumentListApi(DatasetApiResource):
|
||||
return response
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/download-zip")
|
||||
class DocumentBatchDownloadZipApi(DatasetApiResource):
|
||||
"""Download multiple uploaded-file documents as a single ZIP archive."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentBatchDownloadZipPayload.__name__])
|
||||
@service_api_ns.doc("download_documents_as_zip")
|
||||
@service_api_ns.doc(description="Download selected uploaded documents as a single ZIP archive")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "ZIP archive generated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
403: "Forbidden - insufficient permissions",
|
||||
404: "Document or dataset not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
|
||||
dataset_id=str(dataset_id),
|
||||
document_ids=[str(document_id) for document_id in payload.document_ids],
|
||||
tenant_id=str(tenant_id),
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
with ExitStack() as stack:
|
||||
zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files))
|
||||
response = send_file(
|
||||
zip_path,
|
||||
mimetype="application/zip",
|
||||
as_attachment=True,
|
||||
download_name=download_name,
|
||||
)
|
||||
cleanup = stack.pop_all()
|
||||
response.call_on_close(cleanup.close)
|
||||
return response
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
|
||||
class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
@service_api_ns.doc("get_document_indexing_status")
|
||||
@ -600,6 +651,35 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
return data
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
|
||||
class DocumentDownloadApi(DatasetApiResource):
|
||||
"""Return a signed download URL for a document's original uploaded file."""
|
||||
|
||||
@service_api_ns.doc("get_document_download_url")
|
||||
@service_api_ns.doc(description="Get a signed download URL for a document's original uploaded file")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Download URL generated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
403: "Forbidden - insufficient permissions",
|
||||
404: "Document or upload file not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
dataset = self.get_dataset(str(dataset_id), str(tenant_id))
|
||||
document = DocumentService.get_document(dataset.id, str(document_id))
|
||||
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
if document.tenant_id != str(tenant_id):
|
||||
raise Forbidden("No permission.")
|
||||
|
||||
return {"url": DocumentService.get_document_download_url(document)}
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
class DocumentApi(DatasetApiResource):
|
||||
METADATA_CHOICES = {"all", "only", "without"}
|
||||
|
||||
@ -14,6 +14,7 @@ import httpx
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper import ssrf_proxy
|
||||
from dify_graph.file.models import ToolFile as ToolFilePydanticModel
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import MessageFile
|
||||
from models.tools import ToolFile
|
||||
@ -207,7 +208,9 @@ class ToolFileManager:
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]:
|
||||
def get_file_generator_by_tool_file_id(
|
||||
self, tool_file_id: str
|
||||
) -> tuple[Generator | None, ToolFilePydanticModel | None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
@ -229,7 +232,7 @@ class ToolFileManager:
|
||||
|
||||
stream = storage.load_stream(tool_file.file_key)
|
||||
|
||||
return stream, tool_file
|
||||
return stream, ToolFilePydanticModel.model_validate(tool_file)
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
|
||||
@ -50,6 +50,7 @@ from dify_graph.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
)
|
||||
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
@ -310,6 +311,15 @@ class DifyNodeFactory(NodeFactory):
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.TOOL:
|
||||
return ToolNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
tool_file_manager_factory=self._http_request_tool_file_manager_factory(),
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@ -43,6 +44,24 @@ class FileUploadConfig(BaseModel):
|
||||
number_limits: int = 0
|
||||
|
||||
|
||||
class ToolFile(BaseModel):
|
||||
id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file")
|
||||
user_id: UUID = Field(..., description="ID of the user who owns this file")
|
||||
tenant_id: UUID = Field(..., description="ID of the tenant/organization")
|
||||
conversation_id: UUID | None = Field(None, description="ID of the associated conversation")
|
||||
file_key: str = Field(..., max_length=255, description="Storage key for the file")
|
||||
mimetype: str = Field(..., max_length=255, description="MIME type of the file")
|
||||
original_url: str | None = Field(
|
||||
None, max_length=2048, description="Original URL if file was fetched from external source"
|
||||
)
|
||||
name: str = Field(default="", max_length=255, description="Display name of the file")
|
||||
size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)")
|
||||
|
||||
class Config:
|
||||
from_attributes = True # Enable ORM mode for SQLAlchemy compatibility
|
||||
populate_by_name = True
|
||||
|
||||
|
||||
class File(BaseModel):
|
||||
# NOTE: dify_model_identity is a special identifier used to distinguish between
|
||||
# new and old data formats during serialization and deserialization.
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Protocol
|
||||
|
||||
import httpx
|
||||
|
||||
from dify_graph.file import File
|
||||
from dify_graph.file.models import ToolFile
|
||||
|
||||
|
||||
class HttpClientProtocol(Protocol):
|
||||
@ -40,3 +42,5 @@ class ToolFileManagerProtocol(Protocol):
|
||||
mimetype: str,
|
||||
filename: str | None = None,
|
||||
) -> Any: ...
|
||||
|
||||
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ...
|
||||
|
||||
@ -1,9 +1,6 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
@ -21,11 +18,10 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.nodes.protocols import ToolFileManagerProtocol
|
||||
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from dify_graph.variables.variables import ArrayAnyVariable
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .entities import ToolNodeData
|
||||
@ -36,7 +32,8 @@ from .exc import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
|
||||
class ToolNode(Node[ToolNodeData]):
|
||||
@ -46,6 +43,23 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
tool_file_manager_factory: ToolFileManagerProtocol,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._tool_file_manager_factory = tool_file_manager_factory
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
@ -271,11 +285,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
|
||||
if not tool_file:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not found")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
@ -294,11 +306,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
_, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id)
|
||||
if not tool_file:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
|
||||
@ -8,6 +8,7 @@ from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.node_events import StreamCompletedEvent
|
||||
from dify_graph.nodes.protocols import ToolFileManagerProtocol
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
@ -55,11 +56,14 @@ def init_tool_node(config: dict):
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
|
||||
|
||||
node = ToolNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tool_file_manager_factory=tool_file_manager_factory,
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.console.app import annotation as annotation_module
|
||||
|
||||
|
||||
def test_annotation_reply_payload_valid():
|
||||
"""Test AnnotationReplyPayload with valid data."""
|
||||
payload = annotation_module.AnnotationReplyPayload(
|
||||
score_threshold=0.5,
|
||||
embedding_provider_name="openai",
|
||||
embedding_model_name="text-embedding-3-small",
|
||||
)
|
||||
assert payload.score_threshold == 0.5
|
||||
assert payload.embedding_provider_name == "openai"
|
||||
assert payload.embedding_model_name == "text-embedding-3-small"
|
||||
|
||||
|
||||
def test_annotation_setting_update_payload_valid():
|
||||
"""Test AnnotationSettingUpdatePayload with valid data."""
|
||||
payload = annotation_module.AnnotationSettingUpdatePayload(
|
||||
score_threshold=0.75,
|
||||
)
|
||||
assert payload.score_threshold == 0.75
|
||||
|
||||
|
||||
def test_annotation_list_query_defaults():
|
||||
"""Test AnnotationListQuery with default parameters."""
|
||||
query = annotation_module.AnnotationListQuery()
|
||||
assert query.page == 1
|
||||
assert query.limit == 20
|
||||
assert query.keyword == ""
|
||||
|
||||
|
||||
def test_annotation_list_query_custom_page():
|
||||
"""Test AnnotationListQuery with custom page."""
|
||||
query = annotation_module.AnnotationListQuery(page=3, limit=50)
|
||||
assert query.page == 3
|
||||
assert query.limit == 50
|
||||
|
||||
|
||||
def test_annotation_list_query_with_keyword():
|
||||
"""Test AnnotationListQuery with keyword."""
|
||||
query = annotation_module.AnnotationListQuery(keyword="test")
|
||||
assert query.keyword == "test"
|
||||
|
||||
|
||||
def test_create_annotation_payload_with_message_id():
|
||||
"""Test CreateAnnotationPayload with message ID."""
|
||||
payload = annotation_module.CreateAnnotationPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
question="What is AI?",
|
||||
)
|
||||
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"
|
||||
assert payload.question == "What is AI?"
|
||||
|
||||
|
||||
def test_create_annotation_payload_with_text():
|
||||
"""Test CreateAnnotationPayload with text content."""
|
||||
payload = annotation_module.CreateAnnotationPayload(
|
||||
question="What is ML?",
|
||||
answer="Machine learning is...",
|
||||
)
|
||||
assert payload.question == "What is ML?"
|
||||
assert payload.answer == "Machine learning is..."
|
||||
|
||||
|
||||
def test_update_annotation_payload():
|
||||
"""Test UpdateAnnotationPayload."""
|
||||
payload = annotation_module.UpdateAnnotationPayload(
|
||||
question="Updated question",
|
||||
answer="Updated answer",
|
||||
)
|
||||
assert payload.question == "Updated question"
|
||||
assert payload.answer == "Updated answer"
|
||||
|
||||
|
||||
def test_annotation_reply_status_query_enable():
|
||||
"""Test AnnotationReplyStatusQuery with enable action."""
|
||||
query = annotation_module.AnnotationReplyStatusQuery(action="enable")
|
||||
assert query.action == "enable"
|
||||
|
||||
|
||||
def test_annotation_reply_status_query_disable():
|
||||
"""Test AnnotationReplyStatusQuery with disable action."""
|
||||
query = annotation_module.AnnotationReplyStatusQuery(action="disable")
|
||||
assert query.action == "disable"
|
||||
|
||||
|
||||
def test_annotation_file_payload_valid():
|
||||
"""Test AnnotationFilePayload with valid message ID."""
|
||||
payload = annotation_module.AnnotationFilePayload(message_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"
|
||||
@ -13,6 +13,9 @@ from pandas.errors import ParserError
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit, annotation_import_rate_limit
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
|
||||
|
||||
|
||||
class TestAnnotationImportRateLimiting:
|
||||
@ -33,8 +36,6 @@ class TestAnnotationImportRateLimiting:
|
||||
|
||||
def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that per-minute rate limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate exceeding per-minute limit
|
||||
mock_redis.zcard.side_effect = [
|
||||
dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check
|
||||
@ -54,7 +55,6 @@ class TestAnnotationImportRateLimiting:
|
||||
|
||||
def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that per-hour rate limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate exceeding per-hour limit
|
||||
mock_redis.zcard.side_effect = [
|
||||
@ -74,7 +74,6 @@ class TestAnnotationImportRateLimiting:
|
||||
|
||||
def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account):
|
||||
"""Test that requests within limits are allowed."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate being under both limits
|
||||
mock_redis.zcard.return_value = 2
|
||||
@ -110,7 +109,6 @@ class TestAnnotationImportConcurrencyControl:
|
||||
|
||||
def test_concurrency_limit_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that concurrent task limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
# Simulate max concurrent tasks already running
|
||||
mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT
|
||||
@ -127,7 +125,6 @@ class TestAnnotationImportConcurrencyControl:
|
||||
|
||||
def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account):
|
||||
"""Test that requests within concurrency limits are allowed."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
# Simulate being under concurrent task limit
|
||||
mock_redis.zcard.return_value = 1
|
||||
@ -142,7 +139,6 @@ class TestAnnotationImportConcurrencyControl:
|
||||
|
||||
def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account):
|
||||
"""Test that old/stale job entries are removed."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
mock_redis.zcard.return_value = 0
|
||||
|
||||
@ -203,7 +199,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_max_records_limit_enforced(self, mock_app, mock_db_session):
|
||||
"""Test that files with too many records are rejected."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create CSV with too many records
|
||||
max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
|
||||
@ -229,7 +224,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_min_records_limit_enforced(self, mock_app, mock_db_session):
|
||||
"""Test that files with too few valid records are rejected."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create CSV with only header (no data rows)
|
||||
csv_content = "question,answer\n"
|
||||
@ -249,7 +243,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_invalid_csv_format_handled(self, mock_app, mock_db_session):
|
||||
"""Test that invalid CSV format is handled gracefully."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Any content is fine once we force ParserError
|
||||
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
|
||||
@ -270,7 +263,6 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_valid_import_succeeds(self, mock_app, mock_db_session):
|
||||
"""Test that valid import request succeeds."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create valid CSV
|
||||
csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n"
|
||||
@ -300,18 +292,10 @@ class TestAnnotationImportServiceValidation:
|
||||
class TestAnnotationImportTaskOptimization:
|
||||
"""Test optimizations in batch import task."""
|
||||
|
||||
def test_task_has_timeout_configured(self):
|
||||
"""Test that task has proper timeout configuration."""
|
||||
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
|
||||
|
||||
# Verify task configuration
|
||||
assert hasattr(batch_import_annotations_task, "time_limit")
|
||||
assert hasattr(batch_import_annotations_task, "soft_time_limit")
|
||||
|
||||
# Check timeout values are reasonable
|
||||
# Hard limit should be 6 minutes (360s)
|
||||
# Soft limit should be 5 minutes (300s)
|
||||
# Note: actual values depend on Celery configuration
|
||||
def test_task_is_registered_with_queue(self):
|
||||
"""Test that task is registered with the correct queue."""
|
||||
assert hasattr(batch_import_annotations_task, "apply_async")
|
||||
assert hasattr(batch_import_annotations_task, "delay")
|
||||
|
||||
|
||||
class TestConfigurationValues:
|
||||
|
||||
585
api/tests/unit_tests/controllers/console/app/test_app_apis.py
Normal file
585
api/tests/unit_tests/controllers/console/app/test_app_apis.py
Normal file
@ -0,0 +1,585 @@
|
||||
"""
|
||||
Additional tests to improve coverage for low-coverage modules in controllers/console/app.
|
||||
Target: increase coverage for files with <75% coverage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.app import (
|
||||
annotation as annotation_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
completion as completion_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
message as message_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
ops_trace as ops_trace_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
site as site_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
statistic as statistic_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_app_log as workflow_app_log_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_draft_variable as workflow_draft_variable_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_statistic as workflow_statistic_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_trigger as workflow_trigger_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
wraps as wraps_module,
|
||||
)
|
||||
from controllers.console.app.completion import ChatMessagePayload, CompletionMessagePayload
|
||||
from controllers.console.app.mcp_server import MCPServerCreatePayload, MCPServerUpdatePayload
|
||||
from controllers.console.app.ops_trace import TraceConfigPayload, TraceProviderQuery
|
||||
from controllers.console.app.site import AppSiteUpdatePayload
|
||||
from controllers.console.app.workflow import AdvancedChatWorkflowRunPayload, SyncDraftWorkflowPayload
|
||||
from controllers.console.app.workflow_app_log import WorkflowAppLogQuery
|
||||
from controllers.console.app.workflow_draft_variable import WorkflowDraftVariableUpdatePayload
|
||||
from controllers.console.app.workflow_statistic import WorkflowStatisticQuery
|
||||
from controllers.console.app.workflow_trigger import Parser, ParserEnable
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _ConnContext:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, _query, _args):
|
||||
return self._rows
|
||||
|
||||
|
||||
# ========== Completion Tests ==========
|
||||
class TestCompletionEndpoints:
|
||||
"""Tests for completion API endpoints."""
|
||||
|
||||
def test_completion_create_payload(self):
|
||||
"""Test completion creation payload."""
|
||||
payload = CompletionMessagePayload(inputs={"prompt": "test"}, model_config={})
|
||||
assert payload.inputs == {"prompt": "test"}
|
||||
|
||||
def test_chat_message_payload_uuid_validation(self):
|
||||
payload = ChatMessagePayload(
|
||||
inputs={},
|
||||
model_config={},
|
||||
query="hi",
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
parent_message_id=str(uuid.uuid4()),
|
||||
)
|
||||
assert payload.query == "hi"
|
||||
|
||||
def test_completion_api_success(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: {"text": "ok"},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
completion_module.helper,
|
||||
"compact_generate_response",
|
||||
lambda response: {"result": response},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
resp = method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
assert resp == {"result": {"text": "ok"}}
|
||||
|
||||
def test_completion_api_conversation_not_exists(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(
|
||||
completion_module.services.errors.conversation.ConversationNotExistsError()
|
||||
),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_provider_not_initialized(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(completion_module.ProviderTokenNotInitError("x")),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_quota_exceeded(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(completion_module.QuotaExceededError()),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderQuotaExceededError):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
|
||||
# ========== OpsTrace Tests ==========
|
||||
class TestOpsTraceEndpoints:
|
||||
"""Tests for ops_trace endpoint."""
|
||||
|
||||
def test_ops_trace_query_basic(self):
|
||||
"""Test ops_trace query."""
|
||||
query = TraceProviderQuery(tracing_provider="langfuse")
|
||||
assert query.tracing_provider == "langfuse"
|
||||
|
||||
def test_ops_trace_config_payload(self):
|
||||
payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"})
|
||||
assert payload.tracing_config["api_key"] == "k"
|
||||
|
||||
def test_trace_app_config_get_empty(self, app, monkeypatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
"get_tracing_app_config",
|
||||
lambda **_kwargs: None,
|
||||
)
|
||||
|
||||
with app.test_request_context("/?tracing_provider=langfuse"):
|
||||
result = method(app_id="app-1")
|
||||
|
||||
assert result == {"has_not_configured": True}
|
||||
|
||||
def test_trace_app_config_post_invalid(self, app, monkeypatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
"create_tracing_app_config",
|
||||
lambda **_kwargs: {"error": True},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_id="app-1")
|
||||
|
||||
def test_trace_app_config_delete_not_found(self, app, monkeypatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.delete)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
"delete_tracing_app_config",
|
||||
lambda **_kwargs: False,
|
||||
)
|
||||
|
||||
with app.test_request_context("/?tracing_provider=langfuse"):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_id="app-1")
|
||||
|
||||
|
||||
# ========== Site Tests ==========
|
||||
class TestSiteEndpoints:
|
||||
"""Tests for site endpoint."""
|
||||
|
||||
def test_site_response_structure(self):
|
||||
"""Test site response structure."""
|
||||
payload = AppSiteUpdatePayload(title="My Site", description="Test site")
|
||||
assert payload.title == "My Site"
|
||||
|
||||
def test_site_default_language_validation(self):
|
||||
payload = AppSiteUpdatePayload(default_language="en-US")
|
||||
assert payload.default_language == "en-US"
|
||||
|
||||
def test_app_site_update_post(self, app, monkeypatch):
|
||||
api = site_module.AppSite()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
|
||||
|
||||
with app.test_request_context("/", json={"title": "My Site"}):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is site
|
||||
|
||||
def test_app_site_access_token_reset(self, app, monkeypatch):
|
||||
api = site_module.AppSiteAccessTokenReset()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
|
||||
|
||||
with app.test_request_context("/"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is site
|
||||
|
||||
|
||||
# ========== Workflow Tests ==========
|
||||
class TestWorkflowEndpoints:
|
||||
"""Tests for workflow endpoints."""
|
||||
|
||||
def test_workflow_copy_payload(self):
|
||||
"""Test workflow copy payload."""
|
||||
payload = SyncDraftWorkflowPayload(graph={}, features={})
|
||||
assert payload.graph == {}
|
||||
|
||||
def test_workflow_mode_query(self):
|
||||
"""Test workflow mode query."""
|
||||
payload = AdvancedChatWorkflowRunPayload(inputs={}, query="hi")
|
||||
assert payload.query == "hi"
|
||||
|
||||
|
||||
# ========== Workflow App Log Tests ==========
|
||||
class TestWorkflowAppLogEndpoints:
|
||||
"""Tests for workflow app log endpoints."""
|
||||
|
||||
def test_workflow_app_log_query(self):
|
||||
"""Test workflow app log query."""
|
||||
query = WorkflowAppLogQuery(keyword="test", page=1, limit=20)
|
||||
assert query.keyword == "test"
|
||||
|
||||
def test_workflow_app_log_query_detail_bool(self):
|
||||
query = WorkflowAppLogQuery(detail="true")
|
||||
assert query.detail is True
|
||||
|
||||
def test_workflow_app_log_api_get(self, app, monkeypatch):
|
||||
api = workflow_app_log_module.WorkflowAppLogApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
return "session"
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(workflow_app_log_module, "Session", lambda *args, **kwargs: DummySession())
|
||||
|
||||
def fake_get_paginate(self, **_kwargs):
|
||||
return {"items": [], "total": 0}
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_app_log_module.WorkflowAppService,
|
||||
"get_paginate_workflow_app_logs",
|
||||
fake_get_paginate,
|
||||
)
|
||||
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# ========== Workflow Draft Variable Tests ==========
|
||||
class TestWorkflowDraftVariableEndpoints:
|
||||
"""Tests for workflow draft variable endpoints."""
|
||||
|
||||
def test_workflow_variable_creation(self):
|
||||
"""Test workflow variable creation."""
|
||||
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
|
||||
assert payload.name == "var1"
|
||||
|
||||
def test_workflow_variable_collection_get(self, app, monkeypatch):
|
||||
api = workflow_draft_variable_module.WorkflowVariableCollectionApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
return "session"
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class DummyDraftService:
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
|
||||
def list_variables_without_values(self, **_kwargs):
|
||||
return {"items": [], "total": 0}
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "Session", lambda *args, **kwargs: DummySession())
|
||||
|
||||
class DummyWorkflowService:
|
||||
def is_workflow_exist(self, *args, **kwargs):
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowDraftVariableService", DummyDraftService)
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService)
|
||||
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# ========== Workflow Statistic Tests ==========
|
||||
class TestWorkflowStatisticEndpoints:
|
||||
"""Tests for workflow statistic endpoints."""
|
||||
|
||||
def test_workflow_statistic_time_range(self):
|
||||
"""Test workflow statistic time range query."""
|
||||
query = WorkflowStatisticQuery(start="2024-01-01", end="2024-12-31")
|
||||
assert query.start == "2024-01-01"
|
||||
|
||||
def test_workflow_statistic_blank_to_none(self):
|
||||
query = WorkflowStatisticQuery(start="", end="")
|
||||
assert query.start is None
|
||||
assert query.end is None
|
||||
|
||||
def test_workflow_daily_runs_statistic(self, app, monkeypatch):
|
||||
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(get_daily_runs_statistics=lambda **_kw: [{"date": "2024-01-01"}]),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (None, None),
|
||||
)
|
||||
|
||||
api = workflow_statistic_module.WorkflowDailyRunsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
|
||||
|
||||
def test_workflow_daily_terminals_statistic(self, app, monkeypatch):
|
||||
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(
|
||||
get_daily_terminals_statistics=lambda **_kw: [{"date": "2024-01-02"}]
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (None, None),
|
||||
)
|
||||
|
||||
api = workflow_statistic_module.WorkflowDailyTerminalsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
|
||||
|
||||
|
||||
# ========== Workflow Trigger Tests ==========
|
||||
class TestWorkflowTriggerEndpoints:
|
||||
"""Tests for workflow trigger endpoints."""
|
||||
|
||||
def test_webhook_trigger_payload(self):
|
||||
"""Test webhook trigger payload."""
|
||||
payload = Parser(node_id="node-1")
|
||||
assert payload.node_id == "node-1"
|
||||
|
||||
enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True)
|
||||
assert enable_payload.enable_trigger is True
|
||||
|
||||
def test_webhook_trigger_api_get(self, app, monkeypatch):
|
||||
api = workflow_trigger_module.WebhookTriggerApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_trigger_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
trigger = MagicMock()
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = trigger
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(workflow_trigger_module, "Session", lambda *_args, **_kwargs: DummySession())
|
||||
|
||||
with app.test_request_context("/?node_id=node-1"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is trigger
|
||||
|
||||
|
||||
# ========== Wraps Tests ==========
|
||||
class TestWrapsEndpoints:
|
||||
"""Tests for wraps utility functions."""
|
||||
|
||||
def test_get_app_model_context(self):
|
||||
"""Test get_app_model wrapper context."""
|
||||
# These are decorator functions, so we test their availability
|
||||
assert hasattr(wraps_module, "get_app_model")
|
||||
|
||||
|
||||
# ========== MCP Server Tests ==========
|
||||
class TestMCPServerEndpoints:
|
||||
"""Tests for MCP server endpoints."""
|
||||
|
||||
def test_mcp_server_connection(self):
|
||||
"""Test MCP server connection."""
|
||||
payload = MCPServerCreatePayload(parameters={"url": "http://localhost:3000"})
|
||||
assert payload.parameters["url"] == "http://localhost:3000"
|
||||
|
||||
def test_mcp_server_update_payload(self):
|
||||
payload = MCPServerUpdatePayload(id="server-1", parameters={"timeout": 30}, status="active")
|
||||
assert payload.status == "active"
|
||||
|
||||
|
||||
# ========== Error Handling Tests ==========
|
||||
class TestErrorHandling:
|
||||
"""Tests for error handling in various endpoints."""
|
||||
|
||||
def test_annotation_list_query_validation(self):
|
||||
"""Test annotation list query validation."""
|
||||
with pytest.raises(ValueError):
|
||||
annotation_module.AnnotationListQuery(page=0)
|
||||
|
||||
|
||||
# ========== Integration-like Tests ==========
|
||||
class TestPayloadIntegration:
|
||||
"""Integration tests for payload handling."""
|
||||
|
||||
def test_multiple_payload_types(self):
|
||||
"""Test handling of multiple payload types."""
|
||||
payloads = [
|
||||
annotation_module.AnnotationReplyPayload(
|
||||
score_threshold=0.5, embedding_provider_name="openai", embedding_model_name="text-embedding-3-small"
|
||||
),
|
||||
message_module.MessageFeedbackPayload(message_id=str(uuid.uuid4()), rating="like"),
|
||||
statistic_module.StatisticTimeRangeQuery(start="2024-01-01"),
|
||||
]
|
||||
assert len(payloads) == 3
|
||||
assert all(p is not None for p in payloads)
|
||||
@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
|
||||
self.status = status
|
||||
self.app_id = app_id
|
||||
|
||||
def model_dump(self, mode: str = "json"):
|
||||
return {"status": self.status, "app_id": self.app_id}
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
def _install_session(monkeypatch: pytest.MonkeyPatch, session: MagicMock) -> None:
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_: _SessionContext(session))
|
||||
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
|
||||
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
|
||||
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
|
||||
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
|
||||
|
||||
|
||||
def test_import_post_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_import_post_returns_pending_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
|
||||
|
||||
def test_import_post_updates_webapp_auth_when_enabled(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=True)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
update_access = MagicMock()
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
update_access.assert_called_once_with("app-123", "private")
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
|
||||
def test_import_confirm_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportConfirmApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"confirm_import",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(import_id="import-1")
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_import_check_dependencies_returns_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportCheckDependenciesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"check_dependencies",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}),
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
|
||||
response, status = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert status == 200
|
||||
assert response["leaked_dependencies"] == []
|
||||
292
api/tests/unit_tests/controllers/console/app/test_audio.py
Normal file
292
api/tests/unit_tests/controllers/console/app/test_audio.py
Normal file
@ -0,0 +1,292 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.console.app.audio import ChatMessageAudioApi, ChatMessageTextApi, TextModesApi
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
ProviderNotSupportTextToSpeechLanageServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _file_data():
|
||||
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
|
||||
|
||||
|
||||
def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
response = handler(app_model=app_model)
|
||||
|
||||
assert response == {"text": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected"),
|
||||
[
|
||||
(AppModelConfigBrokenError(), AppUnavailableError),
|
||||
(NoAudioUploadedServiceError(), NoAudioUploadedError),
|
||||
(AudioTooLargeServiceError("too big"), AudioTooLargeError),
|
||||
(UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError),
|
||||
(ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError),
|
||||
(ProviderTokenNotInitError("token"), ProviderNotInitializeError),
|
||||
(QuotaExceededError(), ProviderQuotaExceededError),
|
||||
(ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError),
|
||||
(InvokeError("invoke"), CompletionRequestError),
|
||||
],
|
||||
)
|
||||
def test_console_audio_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(expected):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_console_audio_api_unhandled_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(InternalServerError):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
api = ChatMessageTextApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "voice": "v"},
|
||||
):
|
||||
response = handler(app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError()))
|
||||
|
||||
api = ChatMessageTextApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_console_text_modes_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
api = TextModesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(tenant_id="t1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
|
||||
response = handler(app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_tts_voices",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(ProviderNotSupportTextToSpeechLanageServiceError()),
|
||||
)
|
||||
|
||||
api = TextModesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(tenant_id="t1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
response_payload = {"text": "hello"}
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: response_payload)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == response_payload
|
||||
|
||||
|
||||
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_asr",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
method(app_model=app_model)
|
||||
|
||||
|
||||
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices",
|
||||
method="GET",
|
||||
query_string={"language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
# Should not raise, AudioService is mocked
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"text": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"audio": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_tts_voices",
|
||||
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
|
||||
method="GET",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert isinstance(response, list)
|
||||
156
api/tests/unit_tests/controllers/console/app/test_audio_api.py
Normal file
156
api/tests/unit_tests/controllers/console/app/test_audio_api.py
Normal file
@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import audio as audio_module
|
||||
from controllers.console.app.error import AudioTooLargeError
|
||||
from services.errors.audio import AudioTooLargeServiceError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
response_payload = {"text": "hello"}
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: response_payload)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == response_payload
|
||||
|
||||
|
||||
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
method(app_model=app_model)
|
||||
|
||||
|
||||
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices",
|
||||
method="GET",
|
||||
query_string={"language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
# Should not raise, AudioService is mocked
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"text": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"audio": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts_voices",
|
||||
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
|
||||
method="GET",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert isinstance(response, list)
|
||||
@ -0,0 +1,130 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.app import conversation as conversation_module
|
||||
from models.model import AppMode
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _make_account():
|
||||
return SimpleNamespace(timezone="UTC", id="u1")
|
||||
|
||||
|
||||
def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
|
||||
|
||||
paginate_result = MagicMock()
|
||||
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response is paginate_result
|
||||
|
||||
|
||||
def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(
|
||||
conversation_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad range")),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/completion-conversations",
|
||||
method="GET",
|
||||
query_string={"start": "bad"},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.ChatConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
|
||||
|
||||
paginate_result = MagicMock()
|
||||
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
|
||||
|
||||
assert response is paginate_result
|
||||
|
||||
|
||||
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
conversation = SimpleNamespace(id="c1", app_id="app-1")
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = conversation
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
||||
result = conversation_module._get_conversation(SimpleNamespace(id="app-1"), "c1")
|
||||
|
||||
assert result is conversation
|
||||
session.execute.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
session.refresh.assert_called_once_with(conversation)
|
||||
|
||||
|
||||
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = None
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
conversation_module._get_conversation(SimpleNamespace(id="app-1"), "missing")
|
||||
|
||||
|
||||
def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationDetailApi()
|
||||
method = _unwrap(api.delete)
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(
|
||||
conversation_module.ConversationService,
|
||||
"delete",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
method(app_model=SimpleNamespace(id="app-1"), conversation_id="c1")
|
||||
@ -0,0 +1,260 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import generator as generator_module
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _model_config_payload():
|
||||
return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
|
||||
|
||||
|
||||
def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow):
|
||||
class _Service:
|
||||
def get_draft_workflow(self, app_model):
|
||||
return workflow
|
||||
|
||||
monkeypatch.setattr(generator_module, "WorkflowService", lambda: _Service())
|
||||
|
||||
|
||||
def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.RuleGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []})
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/rule-generate",
|
||||
method="POST",
|
||||
json={"instruction": "do it", "model_config": _model_config_payload()},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert response == {"rules": []}
|
||||
|
||||
|
||||
def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.RuleCodeGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise ProviderTokenNotInitError("missing token")
|
||||
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", _raise)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/rule-code-generate",
|
||||
method="POST",
|
||||
json={"instruction": "do it", "model_config": _model_config_payload()},
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method()
|
||||
|
||||
|
||||
def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "app app-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
_install_workflow_service(monkeypatch, workflow=None)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "workflow app-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []})
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "node node-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{"id": "node-1", "data": {"type": "code"}},
|
||||
]
|
||||
}
|
||||
)
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", lambda **_kwargs: {"code": "x"})
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert response == {"code": "x"}
|
||||
|
||||
|
||||
def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(
|
||||
generator_module.LLMGenerator,
|
||||
"instruction_modify_legacy",
|
||||
lambda **_kwargs: {"instruction": "ok"},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "",
|
||||
"current": "old",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert response == {"instruction": "ok"}
|
||||
|
||||
|
||||
def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "",
|
||||
"current": "",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "incompatible parameters"
|
||||
|
||||
|
||||
def test_instruction_template_prompt(app) -> None:
|
||||
api = generator_module.InstructionGenerationTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate/template",
|
||||
method="POST",
|
||||
json={"type": "prompt"},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert "data" in response
|
||||
|
||||
|
||||
def test_instruction_template_invalid_type(app) -> None:
|
||||
api = generator_module.InstructionGenerationTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate/template",
|
||||
method="POST",
|
||||
json={"type": "unknown"},
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method()
|
||||
122
api/tests/unit_tests/controllers/console/app/test_message_api.py
Normal file
122
api/tests/unit_tests/controllers/console/app/test_message_api.py
Normal file
@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import message as message_module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test valid ChatMessagesQuery with all fields."""
|
||||
query = message_module.ChatMessagesQuery(
|
||||
conversation_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
first_id="550e8400-e29b-41d4-a716-446655440001",
|
||||
limit=50,
|
||||
)
|
||||
assert query.limit == 50
|
||||
|
||||
|
||||
def test_chat_messages_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test ChatMessagesQuery with defaults."""
|
||||
query = message_module.ChatMessagesQuery(conversation_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert query.first_id is None
|
||||
assert query.limit == 20
|
||||
|
||||
|
||||
def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test ChatMessagesQuery converts empty first_id to None."""
|
||||
query = message_module.ChatMessagesQuery(
|
||||
conversation_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
first_id="",
|
||||
)
|
||||
assert query.first_id is None
|
||||
|
||||
|
||||
def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload with like rating."""
|
||||
payload = message_module.MessageFeedbackPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
rating="like",
|
||||
content="Good answer",
|
||||
)
|
||||
assert payload.rating == "like"
|
||||
assert payload.content == "Good answer"
|
||||
|
||||
|
||||
def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload with dislike rating."""
|
||||
payload = message_module.MessageFeedbackPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
rating="dislike",
|
||||
)
|
||||
assert payload.rating == "dislike"
|
||||
|
||||
|
||||
def test_message_feedback_payload_no_rating(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload without rating."""
|
||||
payload = message_module.MessageFeedbackPayload(message_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert payload.rating is None
|
||||
|
||||
|
||||
def test_feedback_export_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with default format."""
|
||||
query = message_module.FeedbackExportQuery()
|
||||
assert query.format == "csv"
|
||||
assert query.from_source is None
|
||||
|
||||
|
||||
def test_feedback_export_query_json_format(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with JSON format."""
|
||||
query = message_module.FeedbackExportQuery(format="json")
|
||||
assert query.format == "json"
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_true(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as true string."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="true")
|
||||
assert query.has_comment is True
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_false(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as false string."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="false")
|
||||
assert query.has_comment is False
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_1(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as 1."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="1")
|
||||
assert query.has_comment is True
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_0(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as 0."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="0")
|
||||
assert query.has_comment is False
|
||||
|
||||
|
||||
def test_feedback_export_query_rating_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with rating filter."""
|
||||
query = message_module.FeedbackExportQuery(rating="like")
|
||||
assert query.rating == "like"
|
||||
|
||||
|
||||
def test_annotation_count_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test AnnotationCountResponse creation."""
|
||||
response = message_module.AnnotationCountResponse(count=10)
|
||||
assert response.count == 10
|
||||
|
||||
|
||||
def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test SuggestedQuestionsResponse creation."""
|
||||
response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0] == "What is AI?"
|
||||
@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import model_config as model_config_module
|
||||
from models.model import AppMode, AppModelConfig
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = model_config_module.ModelConfigResource()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
mode=AppMode.CHAT.value,
|
||||
is_agent=False,
|
||||
app_model_config_id=None,
|
||||
updated_by=None,
|
||||
updated_at=None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
model_config_module.AppModelConfigService,
|
||||
"validate_configuration",
|
||||
lambda **_kwargs: {"pre_prompt": "hi"},
|
||||
)
|
||||
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
session = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||
|
||||
def _from_model_config_dict(self, model_config):
|
||||
self.pre_prompt = model_config["pre_prompt"]
|
||||
self.id = "config-1"
|
||||
return self
|
||||
|
||||
monkeypatch.setattr(AppModelConfig, "from_model_config_dict", _from_model_config_dict)
|
||||
send_mock = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
session.add.assert_called_once()
|
||||
session.flush.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
send_mock.assert_called_once()
|
||||
assert app_model.app_model_config_id == "config-1"
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = model_config_module.ModelConfigResource()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
mode=AppMode.AGENT_CHAT.value,
|
||||
is_agent=True,
|
||||
app_model_config_id="config-0",
|
||||
updated_by=None,
|
||||
updated_at=None,
|
||||
)
|
||||
|
||||
original_config = AppModelConfig(app_id="app-1", created_by="u1", updated_by="u1")
|
||||
original_config.agent_mode = json.dumps(
|
||||
{
|
||||
"enabled": True,
|
||||
"strategy": "function-calling",
|
||||
"tools": [
|
||||
{
|
||||
"provider_id": "provider",
|
||||
"provider_type": "builtin",
|
||||
"tool_name": "tool",
|
||||
"tool_parameters": {"secret": "masked"},
|
||||
}
|
||||
],
|
||||
"prompt": None,
|
||||
}
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = original_config
|
||||
session.query.return_value = query
|
||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_config_module.AppModelConfigService,
|
||||
"validate_configuration",
|
||||
lambda **_kwargs: {
|
||||
"pre_prompt": "hi",
|
||||
"agent_mode": {
|
||||
"enabled": True,
|
||||
"strategy": "function-calling",
|
||||
"tools": [
|
||||
{
|
||||
"provider_id": "provider",
|
||||
"provider_type": "builtin",
|
||||
"tool_name": "tool",
|
||||
"tool_parameters": {"secret": "masked"},
|
||||
}
|
||||
],
|
||||
"prompt": None,
|
||||
},
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
monkeypatch.setattr(model_config_module.ToolManager, "get_agent_tool_runtime", lambda **_kwargs: object())
|
||||
|
||||
class _ParamManager:
|
||||
def __init__(self, **_kwargs):
|
||||
self.delete_called = False
|
||||
|
||||
def decrypt_tool_parameters(self, _value):
|
||||
return {"secret": "decrypted"}
|
||||
|
||||
def mask_tool_parameters(self, _value):
|
||||
return {"secret": "masked"}
|
||||
|
||||
def encrypt_tool_parameters(self, _value):
|
||||
return {"secret": "encrypted"}
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
self.delete_called = True
|
||||
|
||||
monkeypatch.setattr(model_config_module, "ToolParameterConfigurationManager", _ParamManager)
|
||||
send_mock = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
stored_config = session.add.call_args[0][0]
|
||||
stored_agent_mode = json.loads(stored_config.agent_mode)
|
||||
assert stored_agent_mode["tools"][0]["tool_parameters"]["secret"] == "encrypted"
|
||||
assert response["result"] == "success"
|
||||
@ -0,0 +1,215 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console.app import statistic as statistic_module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _ConnContext:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, _query, _args):
|
||||
return self._rows
|
||||
|
||||
|
||||
def _install_db(monkeypatch: pytest.MonkeyPatch, rows) -> None:
|
||||
engine = SimpleNamespace(begin=lambda: _ConnContext(rows))
|
||||
monkeypatch.setattr(statistic_module, "db", SimpleNamespace(engine=engine))
|
||||
|
||||
|
||||
def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (None, None),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
|
||||
def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-01", message_count=3)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
|
||||
|
||||
|
||||
def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyConversationStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
|
||||
def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTokenCostStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-03", token_count=10, total_price=0.25, currency="USD")]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["date"] == "2024-01-03"
|
||||
assert data["data"][0]["token_count"] == 10
|
||||
assert data["data"][0]["total_price"] == 0.25
|
||||
|
||||
|
||||
def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTerminalsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-04", terminal_count=7)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
|
||||
|
||||
|
||||
def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test that AverageSessionInteractionStatistic is limited to chat/agent modes."""
|
||||
# This just verifies the decorator is applied correctly
|
||||
# Actual endpoint testing would require complex JOIN mocking
|
||||
api = statistic_module.AverageSessionInteractionStatistic()
|
||||
method = _unwrap(api.get)
|
||||
assert callable(method)
|
||||
|
||||
|
||||
def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
def mock_parse(*args, **kwargs):
|
||||
raise ValueError("Invalid time range")
|
||||
|
||||
_install_db(monkeypatch, [])
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "parse_time_range", mock_parse)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [
|
||||
SimpleNamespace(date="2024-01-01", message_count=10),
|
||||
SimpleNamespace(date="2024-01-02", message_count=15),
|
||||
SimpleNamespace(date="2024-01-03", message_count=12),
|
||||
]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 3
|
||||
|
||||
|
||||
def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, [])
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": []}
|
||||
|
||||
|
||||
def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyConversationStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
|
||||
_install_db(monkeypatch, rows)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: ("s", "e"),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
|
||||
def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTokenCostStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [
|
||||
SimpleNamespace(date="2024-01-01", token_count=100, total_price=Decimal("0.50"), currency="USD"),
|
||||
SimpleNamespace(date="2024-01-02", token_count=200, total_price=Decimal("1.00"), currency="USD"),
|
||||
]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 2
|
||||
163
api/tests/unit_tests/controllers/console/app/test_workflow.py
Normal file
163
api/tests/unit_tests/controllers/console/app/test_workflow.py
Normal file
@ -0,0 +1,163 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import HTTPException, NotFound
|
||||
|
||||
from controllers.console.app import workflow as workflow_module
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.file.models import File
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def test_parse_file_no_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: None)
|
||||
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
|
||||
|
||||
assert workflow_module._parse_file(workflow, files=[{"id": "f"}]) == []
|
||||
|
||||
|
||||
def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
config = object()
|
||||
file_list = [
|
||||
File(
|
||||
tenant_id="t1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="http://u",
|
||||
)
|
||||
]
|
||||
build_mock = Mock(return_value=file_list)
|
||||
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: config)
|
||||
monkeypatch.setattr(workflow_module.file_factory, "build_from_mappings", build_mock)
|
||||
|
||||
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
|
||||
result = workflow_module._parse_file(workflow, files=[{"id": "f"}])
|
||||
|
||||
assert result == file_list
|
||||
build_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_draft_workflow_invalid_content_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
with app.test_request_context("/apps/app/workflows/draft", method="POST", data="x", content_type="text/html"):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert exc.value.code == 415
|
||||
|
||||
|
||||
def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
data="[]",
|
||||
content_type="application/json",
|
||||
):
|
||||
response, status = handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert status == 400
|
||||
assert response["message"] == "Invalid JSON data"
|
||||
|
||||
|
||||
def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = SimpleNamespace(
|
||||
unique_hash="h",
|
||||
updated_at=None,
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
monkeypatch.setattr(
|
||||
workflow_module.variable_factory, "build_environment_variable_from_mapping", lambda *_args: "env"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_module.variable_factory, "build_conversation_variable_from_mapping", lambda *_args: "conv"
|
||||
)
|
||||
|
||||
service = SimpleNamespace(sync_draft_workflow=lambda **_kwargs: workflow)
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
json={"graph": {}, "features": {}, "hash": "h"},
|
||||
):
|
||||
response = handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise workflow_module.WorkflowHashNotEqualError()
|
||||
|
||||
service = SimpleNamespace(sync_draft_workflow=_raise)
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
json={"graph": {}, "features": {}, "hash": "h"},
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotSync):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
|
||||
def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)
|
||||
)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
|
||||
def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
workflow_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(
|
||||
workflow_module.services.errors.conversation.ConversationNotExistsError()
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
api = workflow_module.AdvancedChatDraftWorkflowRunApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/advanced-chat/workflows/draft/run",
|
||||
method="POST",
|
||||
json={"inputs": {}},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
47
api/tests/unit_tests/controllers/console/app/test_wraps.py
Normal file
47
api/tests/unit_tests/controllers/console/app/test_wraps.py
Normal file
@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import wraps as wraps_module
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
@wraps_module.get_app_model
|
||||
def handler(app_model):
|
||||
return app_model.id
|
||||
|
||||
assert handler(app_id="app-1") == "app-1"
|
||||
|
||||
|
||||
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
|
||||
def handler(app_model):
|
||||
return app_model.id
|
||||
|
||||
with pytest.raises(AppNotFoundError):
|
||||
handler(app_id="app-1")
|
||||
|
||||
|
||||
def test_get_app_model_requires_app_id() -> None:
|
||||
@wraps_module.get_app_model
|
||||
def handler(app_model):
|
||||
return app_model.id
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
handler()
|
||||
402
api/tests/unit_tests/controllers/console/explore/test_audio.py
Normal file
402
api/tests/unit_tests/controllers/console/explore/test_audio.py
Normal file
@ -0,0 +1,402 @@
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import controllers.console.explore.audio as audio_module
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def installed_app():
|
||||
app = MagicMock()
|
||||
app.app = MagicMock()
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def audio_file():
|
||||
return (BytesIO(b"audio"), "audio.wav")
|
||||
|
||||
|
||||
class TestChatAudioApi:
|
||||
def setup_method(self):
|
||||
self.api = audio_module.ChatAudioApi()
|
||||
self.method = unwrap(self.api.post)
|
||||
|
||||
def test_post_success(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
return_value={"text": "ok"},
|
||||
),
|
||||
):
|
||||
resp = self.method(installed_app)
|
||||
|
||||
assert resp == {"text": "ok"}
|
||||
|
||||
def test_app_unavailable(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=audio_module.services.errors.app_model_config.AppModelConfigBrokenError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_no_audio_uploaded(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=NoAudioUploadedServiceError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NoAudioUploadedError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_audio_too_large(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=AudioTooLargeServiceError("too big"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_provider_quota_exceeded(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=QuotaExceededError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_unknown_exception(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_unsupported_audio_type(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=audio_module.UnsupportedAudioTypeServiceError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(audio_module.UnsupportedAudioTypeError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_provider_not_support_speech_to_text(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=audio_module.ProviderNotSupportSpeechToTextServiceError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(audio_module.ProviderNotSupportSpeechToTextError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_provider_not_initialized(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=ProviderTokenNotInitError("not init"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_model_currently_not_supported(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_invoke_error_asr(self, app, installed_app, audio_file):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
data={"file": audio_file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
side_effect=InvokeError("invoke failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
self.method(installed_app)
|
||||
|
||||
|
||||
class TestChatTextApi:
|
||||
def setup_method(self):
|
||||
self.api = audio_module.ChatTextApi()
|
||||
self.method = unwrap(self.api.post)
|
||||
|
||||
def test_post_success(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"message_id": "m1", "text": "hello", "voice": "v1"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
return_value={"audio": "ok"},
|
||||
),
|
||||
):
|
||||
resp = self.method(installed_app)
|
||||
|
||||
assert resp == {"audio": "ok"}
|
||||
|
||||
def test_provider_not_initialized(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=ProviderTokenNotInitError("not init"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_model_not_supported(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_invoke_error(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=InvokeError("invoke failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_unknown_exception(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_app_unavailable_tts(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=audio_module.services.errors.app_model_config.AppModelConfigBrokenError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_no_audio_uploaded_tts(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=NoAudioUploadedServiceError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NoAudioUploadedError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_audio_too_large_tts(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=AudioTooLargeServiceError("too big"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_unsupported_audio_type_tts(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=audio_module.UnsupportedAudioTypeServiceError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(audio_module.UnsupportedAudioTypeError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_provider_not_support_speech_to_text_tts(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=audio_module.ProviderNotSupportSpeechToTextServiceError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(audio_module.ProviderNotSupportSpeechToTextError):
|
||||
self.method(installed_app)
|
||||
|
||||
def test_quota_exceeded_tts(self, app, installed_app):
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
json={"text": "hi"},
|
||||
),
|
||||
patch.object(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts",
|
||||
side_effect=QuotaExceededError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
self.method(installed_app)
|
||||
100
api/tests/unit_tests/controllers/console/explore/test_banner.py
Normal file
100
api/tests/unit_tests/controllers/console/explore/test_banner.py
Normal file
@ -0,0 +1,100 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import controllers.console.explore.banner as banner_module
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestBannerApi:
|
||||
def test_get_banners_with_requested_language(self, app):
|
||||
api = banner_module.BannerApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
banner = MagicMock()
|
||||
banner.id = "b1"
|
||||
banner.content = {"text": "hello"}
|
||||
banner.link = "https://example.com"
|
||||
banner.sort = 1
|
||||
banner.status = "enabled"
|
||||
banner.created_at = datetime(2024, 1, 1)
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value = query
|
||||
query.all.return_value = [banner]
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session):
|
||||
result = method(api)
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"id": "b1",
|
||||
"content": {"text": "hello"},
|
||||
"link": "https://example.com",
|
||||
"sort": 1,
|
||||
"status": "enabled",
|
||||
"created_at": "2024-01-01T00:00:00",
|
||||
}
|
||||
]
|
||||
|
||||
def test_get_banners_fallback_to_en_us(self, app):
|
||||
api = banner_module.BannerApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
banner = MagicMock()
|
||||
banner.id = "b2"
|
||||
banner.content = {"text": "fallback"}
|
||||
banner.link = None
|
||||
banner.sort = 1
|
||||
banner.status = "enabled"
|
||||
banner.created_at = None
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value = query
|
||||
query.all.side_effect = [
|
||||
[],
|
||||
[banner],
|
||||
]
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session):
|
||||
result = method(api)
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"id": "b2",
|
||||
"content": {"text": "fallback"},
|
||||
"link": None,
|
||||
"sort": 1,
|
||||
"status": "enabled",
|
||||
"created_at": None,
|
||||
}
|
||||
]
|
||||
|
||||
def test_get_banners_default_language_en_us(self, app):
|
||||
api = banner_module.BannerApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.order_by.return_value = query
|
||||
query.all.return_value = []
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
with app.test_request_context("/"), patch.object(banner_module.db, "session", session):
|
||||
result = method(api)
|
||||
|
||||
assert result == []
|
||||
@ -0,0 +1,459 @@
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import controllers.console.explore.completion as completion_module
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
)
|
||||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user():
|
||||
return MagicMock(spec=Account)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def completion_app():
|
||||
return MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat_app():
|
||||
return MagicMock(app=MagicMock(mode=AppMode.CHAT))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload_data():
|
||||
return {"inputs": {}, "query": "hi"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload_patch(payload_data):
|
||||
return patch.object(
|
||||
type(completion_module.console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload_data,
|
||||
)
|
||||
|
||||
|
||||
class TestCompletionApi:
|
||||
def test_post_success(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
patch.object(
|
||||
completion_module.helper,
|
||||
"compact_generate_response",
|
||||
return_value=("ok", 200),
|
||||
),
|
||||
):
|
||||
result = method(completion_app)
|
||||
|
||||
assert result == ("ok", 200)
|
||||
|
||||
def test_post_wrong_app_mode(self):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT))
|
||||
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app)
|
||||
|
||||
def test_conversation_completed(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.services.errors.conversation.ConversationCompletedError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ConversationCompletedError):
|
||||
method(completion_app)
|
||||
|
||||
def test_internal_error(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(completion_app)
|
||||
|
||||
def test_conversation_not_exists(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.services.errors.conversation.ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.NotFound):
|
||||
method(completion_app)
|
||||
|
||||
def test_app_unavailable(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.services.errors.app_model_config.AppModelConfigBrokenError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.AppUnavailableError):
|
||||
method(completion_app)
|
||||
|
||||
def test_provider_not_initialized(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.ProviderTokenNotInitError("not init"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(completion_app)
|
||||
|
||||
def test_quota_exceeded(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.QuotaExceededError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderQuotaExceededError):
|
||||
method(completion_app)
|
||||
|
||||
def test_model_not_supported(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.ModelCurrentlyNotSupportError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError):
|
||||
method(completion_app)
|
||||
|
||||
def test_invoke_error(self, app, completion_app, user, payload_patch):
|
||||
api = completion_module.CompletionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.InvokeError("invoke failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.CompletionRequestError):
|
||||
method(completion_app)
|
||||
|
||||
|
||||
class TestCompletionStopApi:
|
||||
def test_stop_success(self, completion_app, user):
|
||||
api = completion_module.CompletionStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user.id = "u1"
|
||||
|
||||
with (
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(completion_module.AppTaskService, "stop_task"),
|
||||
):
|
||||
resp, status = method(completion_app, "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert resp == {"result": "success"}
|
||||
|
||||
def test_stop_wrong_app_mode(self):
|
||||
api = completion_module.CompletionStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT))
|
||||
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app, "task")
|
||||
|
||||
|
||||
class TestChatApi:
|
||||
def test_post_success(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
patch.object(
|
||||
completion_module.helper,
|
||||
"compact_generate_response",
|
||||
return_value=("ok", 200),
|
||||
),
|
||||
):
|
||||
result = method(chat_app)
|
||||
|
||||
assert result == ("ok", 200)
|
||||
|
||||
def test_post_not_chat_app(self):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
|
||||
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(installed_app)
|
||||
|
||||
def test_rate_limit_error(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(chat_app)
|
||||
|
||||
def test_conversation_completed_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.services.errors.conversation.ConversationCompletedError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ConversationCompletedError):
|
||||
method(chat_app)
|
||||
|
||||
def test_conversation_not_exists_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.services.errors.conversation.ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.NotFound):
|
||||
method(chat_app)
|
||||
|
||||
def test_app_unavailable_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.services.errors.app_model_config.AppModelConfigBrokenError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.AppUnavailableError):
|
||||
method(chat_app)
|
||||
|
||||
def test_provider_not_initialized_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.ProviderTokenNotInitError("not init"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(chat_app)
|
||||
|
||||
def test_quota_exceeded_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.QuotaExceededError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderQuotaExceededError):
|
||||
method(chat_app)
|
||||
|
||||
def test_model_not_supported_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.ModelCurrentlyNotSupportError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError):
|
||||
method(chat_app)
|
||||
|
||||
def test_invoke_error_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=completion_module.InvokeError("invoke failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(completion_module.CompletionRequestError):
|
||||
method(chat_app)
|
||||
|
||||
def test_internal_error_chat(self, app, chat_app, user, payload_patch):
|
||||
api = completion_module.ChatApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
payload_patch,
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(chat_app)
|
||||
|
||||
|
||||
class TestChatStopApi:
|
||||
def test_stop_success(self, chat_app, user):
|
||||
api = completion_module.ChatStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user.id = "u1"
|
||||
|
||||
with (
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(completion_module.AppTaskService, "stop_task"),
|
||||
):
|
||||
resp, status = method(chat_app, "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert resp == {"result": "success"}
|
||||
|
||||
def test_stop_not_chat_app(self):
|
||||
api = completion_module.ChatStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
|
||||
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(installed_app, "task")
|
||||
@ -0,0 +1,232 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import controllers.console.explore.conversation as conversation_module
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.errors.conversation import (
|
||||
ConversationNotExistsError,
|
||||
LastConversationNotExistsError,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class FakeConversation:
|
||||
def __init__(self, cid):
|
||||
self.id = cid
|
||||
self.name = "test"
|
||||
self.inputs = {}
|
||||
self.status = "normal"
|
||||
self.introduction = ""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat_app():
|
||||
app_model = MagicMock(mode=AppMode.CHAT, id="app-id")
|
||||
return MagicMock(app=app_model)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def non_chat_app():
|
||||
app_model = MagicMock(mode=AppMode.COMPLETION)
|
||||
return MagicMock(app=app_model)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user():
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "uid"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db_and_session():
|
||||
with (
|
||||
patch.object(
|
||||
conversation_module,
|
||||
"db",
|
||||
MagicMock(session=MagicMock(), engine=MagicMock()),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.conversation.Session",
|
||||
MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestConversationListApi:
|
||||
def test_get_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pagination = MagicMock(
|
||||
limit=20,
|
||||
has_more=False,
|
||||
data=[FakeConversation("c1"), FakeConversation("c2")],
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?limit=20"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.WebConversationService,
|
||||
"pagination_by_last_id",
|
||||
return_value=pagination,
|
||||
),
|
||||
):
|
||||
result = method(chat_app)
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.WebConversationService,
|
||||
"pagination_by_last_id",
|
||||
side_effect=LastConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app)
|
||||
|
||||
def test_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(non_chat_app)
|
||||
|
||||
|
||||
class TestConversationApi:
|
||||
def test_delete_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.ConversationService,
|
||||
"delete",
|
||||
),
|
||||
):
|
||||
result = method(chat_app, "cid")
|
||||
|
||||
body, status = result
|
||||
assert status == 204
|
||||
assert body["result"] == "success"
|
||||
|
||||
def test_delete_not_found(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.ConversationService,
|
||||
"delete",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app, "cid")
|
||||
|
||||
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(non_chat_app, "cid")
|
||||
|
||||
|
||||
class TestConversationRenameApi:
|
||||
def test_rename_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
conversation = FakeConversation("cid")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "new"}),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.ConversationService,
|
||||
"rename",
|
||||
return_value=conversation,
|
||||
),
|
||||
):
|
||||
result = method(chat_app, "cid")
|
||||
|
||||
assert result["id"] == "cid"
|
||||
|
||||
def test_rename_not_found(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "new"}),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.ConversationService,
|
||||
"rename",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app, "cid")
|
||||
|
||||
|
||||
class TestConversationPinApi:
|
||||
def test_pin_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.WebConversationService,
|
||||
"pin",
|
||||
),
|
||||
):
|
||||
result = method(chat_app, "cid")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
def test_unpin_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationUnPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(conversation_module, "current_user", user),
|
||||
patch.object(
|
||||
conversation_module.WebConversationService,
|
||||
"unpin",
|
||||
),
|
||||
):
|
||||
result = method(chat_app, "cid")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
@ -0,0 +1,363 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
import controllers.console.explore.installed_app as module
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_id():
|
||||
return "t1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user(tenant_id):
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.current_tenant = MagicMock(id=tenant_id)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def installed_app():
|
||||
app = MagicMock()
|
||||
app.id = "ia1"
|
||||
app.app = MagicMock(id="a1")
|
||||
app.app_owner_tenant_id = "t2"
|
||||
app.is_pinned = False
|
||||
app.last_used_at = datetime(2024, 1, 1)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload_patch():
|
||||
def _patch(payload):
|
||||
return patch.object(
|
||||
type(module.console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
)
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
class TestInstalledAppsListApi:
|
||||
def test_get_installed_apps(self, app, current_user, tenant_id, installed_app):
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [installed_app]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="owner"),
|
||||
patch.object(
|
||||
module.FeatureService,
|
||||
"get_system_features",
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert "installed_apps" in result
|
||||
assert result["installed_apps"][0]["editable"] is True
|
||||
assert result["installed_apps"][0]["uninstallable"] is False
|
||||
|
||||
def test_get_installed_apps_with_app_id_filter(self, app, current_user, tenant_id):
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
|
||||
with (
|
||||
app.test_request_context("/?app_id=a1"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="member"),
|
||||
patch.object(
|
||||
module.FeatureService,
|
||||
"get_system_features",
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result == {"installed_apps": []}
|
||||
|
||||
def test_get_installed_apps_with_webapp_auth_enabled(self, app, current_user, tenant_id, installed_app):
|
||||
"""Test filtering when webapp_auth is enabled."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [installed_app]
|
||||
|
||||
mock_webapp_setting = MagicMock()
|
||||
mock_webapp_setting.access_mode = "restricted"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="owner"),
|
||||
patch.object(
|
||||
module.FeatureService,
|
||||
"get_system_features",
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=True)),
|
||||
),
|
||||
patch.object(
|
||||
module.EnterpriseService.WebAppAuth,
|
||||
"batch_get_app_access_mode_by_id",
|
||||
return_value={"a1": mock_webapp_setting},
|
||||
),
|
||||
patch.object(
|
||||
module.EnterpriseService.WebAppAuth,
|
||||
"batch_is_user_allowed_to_access_webapps",
|
||||
return_value={"a1": True},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert len(result["installed_apps"]) == 1
|
||||
|
||||
def test_get_installed_apps_with_webapp_auth_user_denied(self, app, current_user, tenant_id, installed_app):
|
||||
"""Test filtering when user doesn't have access."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [installed_app]
|
||||
|
||||
mock_webapp_setting = MagicMock()
|
||||
mock_webapp_setting.access_mode = "restricted"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="member"),
|
||||
patch.object(
|
||||
module.FeatureService,
|
||||
"get_system_features",
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=True)),
|
||||
),
|
||||
patch.object(
|
||||
module.EnterpriseService.WebAppAuth,
|
||||
"batch_get_app_access_mode_by_id",
|
||||
return_value={"a1": mock_webapp_setting},
|
||||
),
|
||||
patch.object(
|
||||
module.EnterpriseService.WebAppAuth,
|
||||
"batch_is_user_allowed_to_access_webapps",
|
||||
return_value={"a1": False},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["installed_apps"] == []
|
||||
|
||||
def test_get_installed_apps_with_sso_verified_access(self, app, current_user, tenant_id, installed_app):
|
||||
"""Test that sso_verified access mode apps are skipped in filtering."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [installed_app]
|
||||
|
||||
mock_webapp_setting = MagicMock()
|
||||
mock_webapp_setting.access_mode = "sso_verified"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="owner"),
|
||||
patch.object(
|
||||
module.FeatureService,
|
||||
"get_system_features",
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=True)),
|
||||
),
|
||||
patch.object(
|
||||
module.EnterpriseService.WebAppAuth,
|
||||
"batch_get_app_access_mode_by_id",
|
||||
return_value={"a1": mock_webapp_setting},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert len(result["installed_apps"]) == 0
|
||||
|
||||
def test_get_installed_apps_filters_null_apps(self, app, current_user, tenant_id):
|
||||
"""Test that installed apps with null app are filtered out."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app_with_null = MagicMock()
|
||||
installed_app_with_null.app = None
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [installed_app_with_null]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="owner"),
|
||||
patch.object(
|
||||
module.FeatureService,
|
||||
"get_system_features",
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["installed_apps"] == []
|
||||
|
||||
def test_get_installed_apps_current_tenant_none(self, app, tenant_id, installed_app):
|
||||
"""Test error when current_user.current_tenant is None."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
current_user = MagicMock()
|
||||
current_user.current_tenant = None
|
||||
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = [installed_app]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
):
|
||||
with pytest.raises(ValueError, match="current_user.current_tenant must not be None"):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestInstalledAppsCreateApi:
|
||||
def test_post_success(self, app, tenant_id, payload_patch):
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
recommended = MagicMock()
|
||||
recommended.install_count = 0
|
||||
|
||||
app_entity = MagicMock()
|
||||
app_entity.id = "a1"
|
||||
app_entity.is_public = True
|
||||
app_entity.tenant_id = "t2"
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.side_effect = [
|
||||
recommended,
|
||||
app_entity,
|
||||
None,
|
||||
]
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
payload_patch({"app_id": "a1"}),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result == {"message": "App installed successfully"}
|
||||
assert recommended.install_count == 1
|
||||
|
||||
def test_post_recommended_not_found(self, app, payload_patch):
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
payload_patch({"app_id": "a1"}),
|
||||
patch.object(module.db, "session", session),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api)
|
||||
|
||||
def test_post_app_not_public(self, app, tenant_id, payload_patch):
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
recommended = MagicMock()
|
||||
app_entity = MagicMock(is_public=False)
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.side_effect = [
|
||||
recommended,
|
||||
app_entity,
|
||||
]
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
payload_patch({"app_id": "a1"}),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestInstalledAppApi:
|
||||
def test_delete_success(self, tenant_id, installed_app):
|
||||
api = module.InstalledAppApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
|
||||
patch.object(module.db, "session"),
|
||||
):
|
||||
resp, status = method(installed_app)
|
||||
|
||||
assert status == 204
|
||||
assert resp["result"] == "success"
|
||||
|
||||
def test_delete_owned_by_current_tenant(self, tenant_id):
|
||||
api = module.InstalledAppApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
installed_app = MagicMock(app_owner_tenant_id=tenant_id)
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)):
|
||||
with pytest.raises(BadRequest):
|
||||
method(installed_app)
|
||||
|
||||
def test_patch_update_pin(self, app, payload_patch, installed_app):
|
||||
api = module.InstalledAppApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"is_pinned": True}),
|
||||
payload_patch({"is_pinned": True}),
|
||||
patch.object(module.db, "session"),
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
assert installed_app.is_pinned is True
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_patch_no_change(self, app, payload_patch, installed_app):
|
||||
api = module.InstalledAppApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with app.test_request_context("/", json={}), payload_patch({}), patch.object(module.db, "session"):
|
||||
result = method(installed_app)
|
||||
|
||||
assert result["result"] == "success"
|
||||
552
api/tests/unit_tests/controllers/console/explore/test_message.py
Normal file
552
api/tests/unit_tests/controllers/console/explore/test_message.py
Normal file
@ -0,0 +1,552 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import controllers.console.explore.message as module
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.explore.error import (
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
)
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def make_message():
|
||||
msg = MagicMock()
|
||||
msg.id = "m1"
|
||||
msg.conversation_id = "11111111-1111-1111-1111-111111111111"
|
||||
msg.parent_message_id = None
|
||||
msg.inputs = {}
|
||||
msg.query = "hello"
|
||||
msg.re_sign_file_url_answer = ""
|
||||
msg.user_feedback = MagicMock(rating=None)
|
||||
msg.status = "success"
|
||||
msg.error = None
|
||||
return msg
|
||||
|
||||
|
||||
class TestMessageListApi:
|
||||
def test_get_success(self, app):
|
||||
api = module.MessageListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
pagination = MagicMock(
|
||||
limit=20,
|
||||
has_more=False,
|
||||
data=[make_message(), make_message()],
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"pagination_by_first_id",
|
||||
return_value=pagination,
|
||||
),
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
def test_get_not_chat_app(self):
|
||||
api = module.MessageListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(installed_app)
|
||||
|
||||
def test_conversation_not_exists(self, app):
|
||||
api = module.MessageListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"pagination_by_first_id",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app)
|
||||
|
||||
def test_first_message_not_exists(self, app):
|
||||
api = module.MessageListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"pagination_by_first_id",
|
||||
side_effect=FirstMessageNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app)
|
||||
|
||||
|
||||
class TestMessageFeedbackApi:
|
||||
def test_post_success(self, app):
|
||||
api = module.MessageFeedbackApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"rating": "like"}),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"create_feedback",
|
||||
),
|
||||
):
|
||||
result = method(installed_app, "mid")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_message_not_exists(self, app):
|
||||
api = module.MessageFeedbackApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"create_feedback",
|
||||
side_effect=MessageNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app, "mid")
|
||||
|
||||
|
||||
class TestMessageMoreLikeThisApi:
|
||||
def test_get_success(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
patch.object(
|
||||
module.helper,
|
||||
"compact_generate_response",
|
||||
return_value=("ok", 200),
|
||||
),
|
||||
):
|
||||
resp = method(installed_app, "mid")
|
||||
|
||||
assert resp == ("ok", 200)
|
||||
|
||||
def test_not_completion_app(self):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_more_like_this_disabled(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=module.MoreLikeThisDisabledError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AppMoreLikeThisDisabledError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_message_not_exists_more_like_this(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=MessageNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_provider_not_init_more_like_this(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=ProviderTokenNotInitError("test"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_quota_exceeded_more_like_this(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=QuotaExceededError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_model_not_support_more_like_this(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_invoke_error_more_like_this(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=InvokeError("test error"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_unexpected_error_more_like_this(self, app):
|
||||
api = module.MessageMoreLikeThisApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"response_mode": "blocking"},
|
||||
),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.AppGenerateService,
|
||||
"generate_more_like_this",
|
||||
side_effect=Exception("unexpected"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
|
||||
class TestMessageSuggestedQuestionApi:
|
||||
def test_get_success(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
return_value=["q1", "q2"],
|
||||
),
|
||||
):
|
||||
result = method(installed_app, "mid")
|
||||
|
||||
assert result["data"] == ["q1", "q2"]
|
||||
|
||||
def test_not_chat_app(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_disabled(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=SuggestedQuestionsAfterAnswerDisabledError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AppSuggestedQuestionsAfterAnswerDisabledError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_message_not_exists_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=MessageNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_conversation_not_exists_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_provider_not_init_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=ProviderTokenNotInitError("test"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_quota_exceeded_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=QuotaExceededError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_model_not_support_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_invoke_error_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=InvokeError("test error"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
method(installed_app, "mid")
|
||||
|
||||
def test_unexpected_error_suggested_question(self):
|
||||
api = module.MessageSuggestedQuestionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.MessageService,
|
||||
"get_suggested_questions_after_answer",
|
||||
side_effect=Exception("unexpected"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(installed_app, "mid")
|
||||
@ -0,0 +1,140 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import controllers.console.explore.parameter as module
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestAppParameterApi:
|
||||
def test_get_app_none(self):
|
||||
api = module.AppParameterApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock(app=None)
|
||||
|
||||
with pytest.raises(AppUnavailableError):
|
||||
method(installed_app)
|
||||
|
||||
def test_get_advanced_chat_workflow(self):
|
||||
api = module.AppParameterApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.features_dict = {"f": "v"}
|
||||
workflow.user_input_form.return_value = [{"name": "x"}]
|
||||
|
||||
app = MagicMock(
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
workflow=workflow,
|
||||
)
|
||||
|
||||
installed_app = MagicMock(app=app)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
module,
|
||||
"get_parameters_from_feature_dict",
|
||||
return_value={"any": "thing"},
|
||||
),
|
||||
patch.object(
|
||||
module.fields.Parameters,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: {"ok": True}),
|
||||
),
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_get_advanced_chat_workflow_missing(self):
|
||||
api = module.AppParameterApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
app = MagicMock(
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
workflow=None,
|
||||
)
|
||||
|
||||
installed_app = MagicMock(app=app)
|
||||
|
||||
with pytest.raises(AppUnavailableError):
|
||||
method(installed_app)
|
||||
|
||||
def test_get_non_workflow_app(self):
|
||||
api = module.AppParameterApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.to_dict.return_value = {"user_input_form": [{"name": "y"}]}
|
||||
|
||||
app = MagicMock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=app_model_config,
|
||||
)
|
||||
|
||||
installed_app = MagicMock(app=app)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
module,
|
||||
"get_parameters_from_feature_dict",
|
||||
return_value={"whatever": 123},
|
||||
),
|
||||
patch.object(
|
||||
module.fields.Parameters,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: {"ok": True}),
|
||||
),
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_get_non_workflow_missing_config(self):
|
||||
api = module.AppParameterApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
app = MagicMock(
|
||||
mode=AppMode.CHAT,
|
||||
app_model_config=None,
|
||||
)
|
||||
|
||||
installed_app = MagicMock(app=app)
|
||||
|
||||
with pytest.raises(AppUnavailableError):
|
||||
method(installed_app)
|
||||
|
||||
|
||||
class TestExploreAppMetaApi:
|
||||
def test_get_meta_success(self):
|
||||
api = module.ExploreAppMetaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
app = MagicMock()
|
||||
installed_app = MagicMock(app=app)
|
||||
|
||||
with patch.object(
|
||||
module.AppService,
|
||||
"get_app_meta",
|
||||
return_value={"meta": "ok"},
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
assert result == {"meta": "ok"}
|
||||
|
||||
def test_get_meta_app_missing(self):
|
||||
api = module.ExploreAppMetaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock(app=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(installed_app)
|
||||
@ -0,0 +1,92 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import controllers.console.explore.recommended_app as module
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestRecommendedAppListApi:
|
||||
def test_get_with_language_param(self, app):
|
||||
api = module.RecommendedAppListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
result_data = {"recommended_apps": [], "categories": []}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"language": "en-US"}),
|
||||
patch.object(module, "current_user", MagicMock(interface_language="fr-FR")),
|
||||
patch.object(
|
||||
module.RecommendedAppService,
|
||||
"get_recommended_apps_and_categories",
|
||||
return_value=result_data,
|
||||
) as service_mock,
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
service_mock.assert_called_once_with("en-US")
|
||||
assert result == result_data
|
||||
|
||||
def test_get_fallback_to_user_language(self, app):
|
||||
api = module.RecommendedAppListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
result_data = {"recommended_apps": [], "categories": []}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"language": "invalid"}),
|
||||
patch.object(module, "current_user", MagicMock(interface_language="fr-FR")),
|
||||
patch.object(
|
||||
module.RecommendedAppService,
|
||||
"get_recommended_apps_and_categories",
|
||||
return_value=result_data,
|
||||
) as service_mock,
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
service_mock.assert_called_once_with("fr-FR")
|
||||
assert result == result_data
|
||||
|
||||
def test_get_fallback_to_default_language(self, app):
|
||||
api = module.RecommendedAppListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
result_data = {"recommended_apps": [], "categories": []}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_user", MagicMock(interface_language=None)),
|
||||
patch.object(
|
||||
module.RecommendedAppService,
|
||||
"get_recommended_apps_and_categories",
|
||||
return_value=result_data,
|
||||
) as service_mock,
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
service_mock.assert_called_once_with(module.languages[0])
|
||||
assert result == result_data
|
||||
|
||||
|
||||
class TestRecommendedAppApi:
|
||||
def test_get_success(self, app):
|
||||
api = module.RecommendedAppApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
result_data = {"id": "app1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
module.RecommendedAppService,
|
||||
"get_recommend_app_detail",
|
||||
return_value=result_data,
|
||||
) as service_mock,
|
||||
):
|
||||
result = method(api, "11111111-1111-1111-1111-111111111111")
|
||||
|
||||
service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111")
|
||||
assert result == result_data
|
||||
@ -0,0 +1,154 @@
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import controllers.console.explore.saved_message as module
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def make_saved_message():
|
||||
msg = MagicMock()
|
||||
msg.id = str(uuid4())
|
||||
msg.message_id = str(uuid4())
|
||||
msg.app_id = str(uuid4())
|
||||
msg.inputs = {}
|
||||
msg.query = "hello"
|
||||
msg.answer = "world"
|
||||
msg.user_feedback = MagicMock(rating="like")
|
||||
msg.created_at = None
|
||||
return msg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload_patch():
|
||||
def _patch(payload):
|
||||
return patch.object(
|
||||
type(module.console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
)
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
class TestSavedMessageListApi:
|
||||
def test_get_success(self, app):
|
||||
api = module.SavedMessageListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
pagination = MagicMock(
|
||||
limit=20,
|
||||
has_more=False,
|
||||
data=[make_saved_message(), make_saved_message()],
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={}),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.SavedMessageService,
|
||||
"pagination_by_last_id",
|
||||
return_value=pagination,
|
||||
),
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
def test_get_not_completion_app(self):
|
||||
api = module.SavedMessageListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app)
|
||||
|
||||
def test_post_success(self, app, payload_patch):
|
||||
api = module.SavedMessageListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
payload = {"message_id": str(uuid4())}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
payload_patch(payload),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(module.SavedMessageService, "save") as save_mock,
|
||||
):
|
||||
result = method(installed_app)
|
||||
|
||||
save_mock.assert_called_once()
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_post_message_not_exists(self, app, payload_patch):
|
||||
api = module.SavedMessageListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
payload = {"message_id": str(uuid4())}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
payload_patch(payload),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.SavedMessageService,
|
||||
"save",
|
||||
side_effect=MessageNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app)
|
||||
|
||||
|
||||
class TestSavedMessageApi:
|
||||
def test_delete_success(self):
|
||||
api = module.SavedMessageApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(module.SavedMessageService, "delete") as delete_mock,
|
||||
):
|
||||
result, status = method(installed_app, str(uuid4()))
|
||||
|
||||
delete_mock.assert_called_once()
|
||||
assert status == 204
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_delete_not_completion_app(self):
|
||||
api = module.SavedMessageApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app, str(uuid4()))
|
||||
1101
api/tests/unit_tests/controllers/console/explore/test_trial.py
Normal file
1101
api/tests/unit_tests/controllers/console/explore/test_trial.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,151 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.console.explore.error import NotWorkflowAppError
|
||||
from controllers.console.explore.workflow import (
|
||||
InstalledAppWorkflowRunApi,
|
||||
InstalledAppWorkflowTaskStopApi,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from models.model import AppMode
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_app():
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.WORKFLOW
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def installed_workflow_app(workflow_app):
|
||||
return MagicMock(app=workflow_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def non_workflow_installed_app():
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.CHAT
|
||||
return MagicMock(app=app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload():
|
||||
return {"inputs": {"a": 1}}
|
||||
|
||||
|
||||
class TestInstalledAppWorkflowRunApi:
|
||||
def test_not_workflow_app(self, app, non_workflow_installed_app):
|
||||
api = InstalledAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.current_account_with_tenant",
|
||||
return_value=(MagicMock(), None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
method(non_workflow_installed_app)
|
||||
|
||||
def test_success(self, app, installed_workflow_app, user, payload):
|
||||
api = InstalledAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.current_account_with_tenant",
|
||||
return_value=(user, None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.AppGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
) as generate_mock,
|
||||
):
|
||||
result = method(installed_workflow_app)
|
||||
|
||||
generate_mock.assert_called_once()
|
||||
assert result is not None
|
||||
|
||||
def test_rate_limit_error(self, app, installed_workflow_app, user, payload):
|
||||
api = InstalledAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.current_account_with_tenant",
|
||||
return_value=(user, None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.AppGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("rate limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(installed_workflow_app)
|
||||
|
||||
def test_unexpected_exception(self, app, installed_workflow_app, user, payload):
|
||||
api = InstalledAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.current_account_with_tenant",
|
||||
return_value=(user, None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.AppGenerateService.generate",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(installed_workflow_app)
|
||||
|
||||
|
||||
class TestInstalledAppWorkflowTaskStopApi:
|
||||
def test_not_workflow_app(self, non_workflow_installed_app):
|
||||
api = InstalledAppWorkflowTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
method(non_workflow_installed_app, "task-1")
|
||||
|
||||
def test_success(self, installed_workflow_app):
|
||||
api = InstalledAppWorkflowTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
patch("controllers.console.explore.workflow.AppQueueManager.set_stop_flag_no_user_check") as stop_flag,
|
||||
patch("controllers.console.explore.workflow.GraphEngineManager.send_stop_command") as send_stop,
|
||||
):
|
||||
result = method(installed_workflow_app, "task-1")
|
||||
|
||||
stop_flag.assert_called_once_with("task-1")
|
||||
send_stop.assert_called_once_with("task-1")
|
||||
assert result == {"result": "success"}
|
||||
244
api/tests/unit_tests/controllers/console/explore/test_wraps.py
Normal file
244
api/tests/unit_tests/controllers/console/explore/test_wraps.py
Normal file
@ -0,0 +1,244 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.console.explore.error import (
|
||||
AppAccessDeniedError,
|
||||
TrialAppLimitExceeded,
|
||||
TrialAppNotAllowed,
|
||||
)
|
||||
from controllers.console.explore.wraps import (
|
||||
InstalledAppResource,
|
||||
TrialAppResource,
|
||||
installed_app_required,
|
||||
trial_app_required,
|
||||
trial_feature_enable,
|
||||
user_allowed_to_access_app,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def test_installed_app_required_not_found():
|
||||
@installed_app_required
|
||||
def view(installed_app):
|
||||
return "ok"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
view("app-id")
|
||||
|
||||
|
||||
def test_installed_app_required_app_deleted():
|
||||
installed_app = MagicMock(app=None)
|
||||
|
||||
@installed_app_required
|
||||
def view(installed_app):
|
||||
return "ok"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
patch("controllers.console.explore.wraps.db.session.delete"),
|
||||
patch("controllers.console.explore.wraps.db.session.commit"),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = installed_app
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
view("app-id")
|
||||
|
||||
|
||||
def test_installed_app_required_success():
|
||||
installed_app = MagicMock(app=MagicMock())
|
||||
|
||||
@installed_app_required
|
||||
def view(installed_app):
|
||||
return installed_app
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = installed_app
|
||||
|
||||
result = view("app-id")
|
||||
assert result == installed_app
|
||||
|
||||
|
||||
def test_user_allowed_to_access_app_denied():
|
||||
installed_app = MagicMock(app_id="app-1")
|
||||
|
||||
@user_allowed_to_access_app
|
||||
def view(installed_app):
|
||||
return "ok"
|
||||
|
||||
feature = MagicMock()
|
||||
feature.webapp_auth.enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.wraps.FeatureService.get_system_features",
|
||||
return_value=feature,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
with pytest.raises(AppAccessDeniedError):
|
||||
view(installed_app)
|
||||
|
||||
|
||||
def test_user_allowed_to_access_app_success():
|
||||
installed_app = MagicMock(app_id="app-1")
|
||||
|
||||
@user_allowed_to_access_app
|
||||
def view(installed_app):
|
||||
return "ok"
|
||||
|
||||
feature = MagicMock()
|
||||
feature.webapp_auth.enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.wraps.FeatureService.get_system_features",
|
||||
return_value=feature,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
assert view(installed_app) == "ok"
|
||||
|
||||
|
||||
def test_trial_app_required_not_allowed():
|
||||
@trial_app_required
|
||||
def view(app):
|
||||
return "ok"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(TrialAppNotAllowed):
|
||||
view("app-id")
|
||||
|
||||
|
||||
def test_trial_app_required_limit_exceeded():
|
||||
trial_app = MagicMock(trial_limit=1, app=MagicMock())
|
||||
record = MagicMock(count=1)
|
||||
|
||||
@trial_app_required
|
||||
def view(app):
|
||||
return "ok"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.side_effect = [
|
||||
trial_app,
|
||||
record,
|
||||
]
|
||||
|
||||
with pytest.raises(TrialAppLimitExceeded):
|
||||
view("app-id")
|
||||
|
||||
|
||||
def test_trial_app_required_success():
|
||||
trial_app = MagicMock(trial_limit=2, app=MagicMock())
|
||||
record = MagicMock(count=1)
|
||||
|
||||
@trial_app_required
|
||||
def view(app):
|
||||
return app
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.explore.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="user-1"), None),
|
||||
),
|
||||
patch("controllers.console.explore.wraps.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.side_effect = [
|
||||
trial_app,
|
||||
record,
|
||||
]
|
||||
|
||||
result = view("app-id")
|
||||
assert result == trial_app.app
|
||||
|
||||
|
||||
def test_trial_feature_enable_disabled():
|
||||
@trial_feature_enable
|
||||
def view():
|
||||
return "ok"
|
||||
|
||||
features = MagicMock(enable_trial_app=False)
|
||||
|
||||
with patch(
|
||||
"controllers.console.explore.wraps.FeatureService.get_system_features",
|
||||
return_value=features,
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
view()
|
||||
|
||||
|
||||
def test_trial_feature_enable_enabled():
|
||||
@trial_feature_enable
|
||||
def view():
|
||||
return "ok"
|
||||
|
||||
features = MagicMock(enable_trial_app=True)
|
||||
|
||||
with patch(
|
||||
"controllers.console.explore.wraps.FeatureService.get_system_features",
|
||||
return_value=features,
|
||||
):
|
||||
assert view() == "ok"
|
||||
|
||||
|
||||
def test_installed_app_resource_decorators():
|
||||
decorators = InstalledAppResource.method_decorators
|
||||
assert len(decorators) == 4
|
||||
|
||||
|
||||
def test_trial_app_resource_decorators():
|
||||
decorators = TrialAppResource.method_decorators
|
||||
assert len(decorators) == 3
|
||||
278
api/tests/unit_tests/controllers/console/tag/test_tags.py
Normal file
278
api/tests/unit_tests/controllers/console/tag/test_tags.py
Normal file
@ -0,0 +1,278 @@
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.tag.tags import (
|
||||
TagBindingCreateApi,
|
||||
TagBindingDeleteApi,
|
||||
TagListApi,
|
||||
TagUpdateDeleteApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""
|
||||
Recursively unwrap decorated functions.
|
||||
"""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_tag")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user():
|
||||
return MagicMock(
|
||||
id="user-1",
|
||||
has_edit_permission=True,
|
||||
is_dataset_editor=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def readonly_user():
|
||||
return MagicMock(
|
||||
id="user-2",
|
||||
has_edit_permission=False,
|
||||
is_dataset_editor=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tag():
|
||||
tag = MagicMock()
|
||||
tag.id = "tag-1"
|
||||
tag.name = "test-tag"
|
||||
tag.type = "knowledge"
|
||||
return tag
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload_patch():
|
||||
def _patch(payload):
|
||||
return patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
)
|
||||
|
||||
return _patch
|
||||
|
||||
|
||||
class TestTagListApi:
|
||||
def test_get_success(self, app):
|
||||
api = TagListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/?type=knowledge"):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.get_tags",
|
||||
return_value=[{"id": "1", "name": "tag"}],
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_post_success(self, app, admin_user, tag, payload_patch):
|
||||
api = TagListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "test-tag", "type": "knowledge"}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.save_tags",
|
||||
return_value=tag,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["name"] == "test-tag"
|
||||
|
||||
def test_post_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "x"}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestTagUpdateDeleteApi:
|
||||
def test_patch_success(self, app, admin_user, tag, payload_patch):
|
||||
api = TagUpdateDeleteApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {"name": "updated", "type": "knowledge"}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.update_tags",
|
||||
return_value=tag,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.get_tag_binding_count",
|
||||
return_value=3,
|
||||
),
|
||||
):
|
||||
result, status = method(api, "tag-1")
|
||||
|
||||
assert status == 200
|
||||
assert result["binding_count"] == 3
|
||||
|
||||
def test_patch_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagUpdateDeleteApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {"name": "x"}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "tag-1")
|
||||
|
||||
def test_delete_success(self, app, admin_user):
|
||||
api = TagUpdateDeleteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.tag.tags.TagService.delete_tag") as delete_mock,
|
||||
):
|
||||
result, status = method(api, "tag-1")
|
||||
|
||||
delete_mock.assert_called_once_with("tag-1")
|
||||
assert status == 204
|
||||
|
||||
|
||||
class TestTagBindingCreateApi:
|
||||
def test_create_success(self, app, admin_user, payload_patch):
|
||||
api = TagBindingCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"tag_ids": ["tag-1"],
|
||||
"target_id": "target-1",
|
||||
"type": "knowledge",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
save_mock.assert_called_once()
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_create_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagBindingCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/", json={}):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch({}),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestTagBindingDeleteApi:
|
||||
def test_remove_success(self, app, admin_user, payload_patch):
|
||||
api = TagBindingDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"tag_id": "tag-1",
|
||||
"target_id": "target-1",
|
||||
"type": "knowledge",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
delete_mock.assert_called_once()
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_remove_forbidden(self, app, readonly_user, payload_patch):
|
||||
api = TagBindingDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/", json={}):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch({}),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
@ -1,13 +1,483 @@
|
||||
"""Final working unit tests for admin endpoints - tests business logic directly."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.console.admin import InsertExploreAppPayload
|
||||
from models.model import App, RecommendedApp
|
||||
from controllers.console.admin import (
|
||||
DeleteExploreBannerApi,
|
||||
InsertExploreAppApi,
|
||||
InsertExploreAppListApi,
|
||||
InsertExploreAppPayload,
|
||||
InsertExploreBannerApi,
|
||||
InsertExploreBannerPayload,
|
||||
)
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_only_edition_cloud(mocker):
|
||||
"""
|
||||
Bypass only_edition_cloud decorator by setting EDITION to "CLOUD".
|
||||
"""
|
||||
mocker.patch(
|
||||
"controllers.console.wraps.dify_config.EDITION",
|
||||
new="CLOUD",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_auth(mocker):
|
||||
"""
|
||||
Provide valid admin authentication for controller tests.
|
||||
"""
|
||||
mocker.patch(
|
||||
"controllers.console.admin.dify_config.ADMIN_API_KEY",
|
||||
"test-admin-key",
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.admin.extract_access_token",
|
||||
return_value="test-admin-key",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_console_payload(mocker):
|
||||
payload = {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"language": "en-US",
|
||||
"category": "Productivity",
|
||||
"position": 1,
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"flask_restx.namespace.Namespace.payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_banner_payload(mocker):
|
||||
mocker.patch(
|
||||
"flask_restx.namespace.Namespace.payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value={
|
||||
"title": "Test Banner",
|
||||
"description": "Banner description",
|
||||
"img-src": "https://example.com/banner.png",
|
||||
"link": "https://example.com",
|
||||
"sort": 1,
|
||||
"category": "homepage",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory(mocker):
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock()
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: mock_session,
|
||||
__exit__=Mock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestDeleteExploreBannerApi:
|
||||
def setup_method(self):
|
||||
self.api = DeleteExploreBannerApi()
|
||||
|
||||
def test_delete_banner_not_found(self, mocker, mock_admin_auth):
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: None),
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound, match="is not found"):
|
||||
self.api.delete(uuid.uuid4())
|
||||
|
||||
def test_delete_banner_success(self, mocker, mock_admin_auth):
|
||||
mock_banner = Mock()
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: mock_banner),
|
||||
)
|
||||
mocker.patch("controllers.console.admin.db.session.delete")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.delete(uuid.uuid4())
|
||||
|
||||
assert status == 204
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
class TestInsertExploreBannerApi:
|
||||
def setup_method(self):
|
||||
self.api = InsertExploreBannerApi()
|
||||
|
||||
def test_insert_banner_success(self, mocker, mock_admin_auth, mock_banner_payload):
|
||||
mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 201
|
||||
assert response["result"] == "success"
|
||||
|
||||
def test_banner_payload_valid_language(self):
|
||||
payload = {
|
||||
"title": "Test Banner",
|
||||
"description": "Banner description",
|
||||
"img-src": "https://example.com/banner.png",
|
||||
"link": "https://example.com",
|
||||
"sort": 1,
|
||||
"category": "homepage",
|
||||
"language": "en-US",
|
||||
}
|
||||
|
||||
model = InsertExploreBannerPayload.model_validate(payload)
|
||||
assert model.language == "en-US"
|
||||
|
||||
def test_banner_payload_invalid_language(self):
|
||||
payload = {
|
||||
"title": "Test Banner",
|
||||
"description": "Banner description",
|
||||
"img-src": "https://example.com/banner.png",
|
||||
"link": "https://example.com",
|
||||
"sort": 1,
|
||||
"category": "homepage",
|
||||
"language": "invalid-lang",
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
|
||||
InsertExploreBannerPayload.model_validate(payload)
|
||||
|
||||
|
||||
class TestInsertExploreAppApiDelete:
|
||||
def setup_method(self):
|
||||
self.api = InsertExploreAppApi()
|
||||
|
||||
def test_delete_when_not_in_explore(self, mocker, mock_admin_auth):
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: s,
|
||||
__exit__=Mock(return_value=False),
|
||||
execute=lambda *_: Mock(scalar_one_or_none=lambda: None),
|
||||
),
|
||||
)
|
||||
|
||||
response, status = self.api.delete(uuid.uuid4())
|
||||
|
||||
assert status == 204
|
||||
assert response["result"] == "success"
|
||||
|
||||
def test_delete_when_in_explore_with_trial_app(self, mocker, mock_admin_auth):
|
||||
"""Test deleting an app from explore that has a trial app."""
|
||||
app_id = uuid.uuid4()
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
mock_recommended.app_id = "app-123"
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.is_public = True
|
||||
|
||||
mock_trial = Mock()
|
||||
|
||||
# Mock session context manager and its execute
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock()
|
||||
mock_session.delete = Mock()
|
||||
|
||||
# Set up side effects for execute calls
|
||||
mock_session.execute.side_effect = [
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalars=Mock(return_value=Mock(all=lambda: []))),
|
||||
Mock(scalar_one_or_none=lambda: mock_trial),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: mock_session,
|
||||
__exit__=Mock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.delete")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.delete(app_id)
|
||||
|
||||
assert status == 204
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is False
|
||||
|
||||
def test_delete_with_installed_apps(self, mocker, mock_admin_auth):
|
||||
"""Test deleting an app that has installed apps in other tenants."""
|
||||
app_id = uuid.uuid4()
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
mock_recommended.app_id = "app-123"
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.is_public = True
|
||||
|
||||
mock_installed_app = Mock(spec=InstalledApp)
|
||||
|
||||
# Mock session
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock()
|
||||
mock_session.delete = Mock()
|
||||
|
||||
mock_session.execute.side_effect = [
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalars=Mock(return_value=Mock(all=lambda: [mock_installed_app]))),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: mock_session,
|
||||
__exit__=Mock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.delete")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.delete(app_id)
|
||||
|
||||
assert status == 204
|
||||
assert mock_session.delete.called
|
||||
|
||||
|
||||
class TestInsertExploreAppListApi:
|
||||
def setup_method(self):
|
||||
self.api = InsertExploreAppListApi()
|
||||
|
||||
def test_app_not_found(self, mocker, mock_admin_auth, mock_console_payload):
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: None),
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound, match="is not found"):
|
||||
self.api.post()
|
||||
|
||||
def test_create_recommended_app(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
):
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.tenant_id = "tenant"
|
||||
mock_app.is_public = False
|
||||
|
||||
# db.session.execute → fetch App
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: mock_app),
|
||||
)
|
||||
|
||||
# session_factory.create_session → recommended_app lookup
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock(return_value=Mock(scalar_one_or_none=lambda: None))
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: mock_session,
|
||||
__exit__=Mock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 201
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
|
||||
def test_update_recommended_app(self, mocker, mock_admin_auth, mock_console_payload, mock_session_factory):
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.is_public = False
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
],
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
|
||||
def test_site_data_overrides_payload(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
):
|
||||
site = Mock()
|
||||
site.description = "Site Desc"
|
||||
site.copyright = "Site Copyright"
|
||||
site.privacy_policy = "Site Privacy"
|
||||
site.custom_disclaimer = "Site Disclaimer"
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = site
|
||||
mock_app.tenant_id = "tenant"
|
||||
mock_app.is_public = False
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
],
|
||||
)
|
||||
|
||||
commit_spy = mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
commit_spy.assert_called_once()
|
||||
|
||||
def test_create_trial_app_when_can_trial_enabled(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
):
|
||||
mock_console_payload["can_trial"] = True
|
||||
mock_console_payload["trial_limit"] = 5
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.tenant_id = "tenant"
|
||||
mock_app.is_public = False
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
],
|
||||
)
|
||||
|
||||
add_spy = mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
self.api.post()
|
||||
|
||||
assert any(call.args[0].__class__.__name__ == "TrialApp" for call in add_spy.call_args_list)
|
||||
|
||||
def test_update_recommended_app_with_trial(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
):
|
||||
"""Test updating a recommended app when trial is enabled."""
|
||||
mock_console_payload["can_trial"] = True
|
||||
mock_console_payload["trial_limit"] = 10
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.is_public = False
|
||||
mock_app.tenant_id = "tenant-123"
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
],
|
||||
)
|
||||
|
||||
add_spy = mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
|
||||
def test_update_recommended_app_without_trial(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
):
|
||||
"""Test updating a recommended app without trial enabled."""
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.is_public = False
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
],
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
|
||||
|
||||
class TestInsertExploreAppPayload:
|
||||
|
||||
138
api/tests/unit_tests/controllers/console/test_apikey.py
Normal file
138
api/tests/unit_tests/controllers/console/test_apikey.py
Normal file
@ -0,0 +1,138 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.apikey import (
|
||||
BaseApiKeyListResource,
|
||||
BaseApiKeyResource,
|
||||
_get_resource,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_context_admin():
|
||||
with patch("controllers.console.apikey.current_account_with_tenant") as mock:
|
||||
user = MagicMock()
|
||||
user.is_admin_or_owner = True
|
||||
mock.return_value = (user, "tenant-123")
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_context_non_admin():
|
||||
with patch("controllers.console.apikey.current_account_with_tenant") as mock:
|
||||
user = MagicMock()
|
||||
user.is_admin_or_owner = False
|
||||
mock.return_value = (user, "tenant-123")
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_mock():
|
||||
with patch("controllers.console.apikey.db") as mock_db:
|
||||
mock_db.session = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_permissions():
|
||||
with patch(
|
||||
"controllers.console.apikey.edit_permission_required",
|
||||
lambda f: f,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class DummyApiKeyListResource(BaseApiKeyListResource):
|
||||
resource_type = "app"
|
||||
resource_model = MagicMock()
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
|
||||
|
||||
class DummyApiKeyResource(BaseApiKeyResource):
|
||||
resource_type = "app"
|
||||
resource_model = MagicMock()
|
||||
resource_id_field = "app_id"
|
||||
|
||||
|
||||
class TestGetResource:
|
||||
def test_get_resource_success(self):
|
||||
fake_resource = MagicMock()
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey.select") as mock_select,
|
||||
patch("controllers.console.apikey.Session") as mock_session,
|
||||
patch("controllers.console.apikey.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
mock_select.return_value.filter_by.return_value = MagicMock()
|
||||
|
||||
session = mock_session.return_value.__enter__.return_value
|
||||
session.execute.return_value.scalar_one_or_none.return_value = fake_resource
|
||||
|
||||
result = _get_resource("rid", "tid", MagicMock)
|
||||
assert result == fake_resource
|
||||
|
||||
def test_get_resource_not_found(self):
|
||||
with (
|
||||
patch("controllers.console.apikey.select") as mock_select,
|
||||
patch("controllers.console.apikey.Session") as mock_session,
|
||||
patch("controllers.console.apikey.db") as mock_db,
|
||||
patch("controllers.console.apikey.flask_restx.abort") as abort,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
mock_select.return_value.filter_by.return_value = MagicMock()
|
||||
|
||||
session = mock_session.return_value.__enter__.return_value
|
||||
session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
_get_resource("rid", "tid", MagicMock)
|
||||
|
||||
abort.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseApiKeyListResource:
|
||||
def test_get_apikeys_success(self, tenant_context_admin, db_mock):
|
||||
resource = DummyApiKeyListResource()
|
||||
|
||||
with patch("controllers.console.apikey._get_resource"):
|
||||
db_mock.session.scalars.return_value.all.return_value = [MagicMock(), MagicMock()]
|
||||
|
||||
result = DummyApiKeyListResource.get.__wrapped__(resource, "resource-id")
|
||||
assert "items" in result
|
||||
|
||||
|
||||
class TestBaseApiKeyResource:
|
||||
def test_delete_forbidden(self, tenant_context_non_admin, db_mock):
|
||||
resource = DummyApiKeyResource()
|
||||
|
||||
with patch("controllers.console.apikey._get_resource"):
|
||||
with pytest.raises(Forbidden):
|
||||
DummyApiKeyResource.delete(resource, "rid", "kid")
|
||||
|
||||
def test_delete_key_not_found(self, tenant_context_admin, db_mock):
|
||||
resource = DummyApiKeyResource()
|
||||
db_mock.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with patch("controllers.console.apikey._get_resource"):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
DummyApiKeyResource.delete(resource, "rid", "kid")
|
||||
|
||||
# flask_restx.abort raises HTTPException with message in data attribute
|
||||
assert exc_info.value.data["message"] == "API key not found"
|
||||
|
||||
def test_delete_success(self, tenant_context_admin, db_mock):
|
||||
resource = DummyApiKeyResource()
|
||||
db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey._get_resource"),
|
||||
patch("controllers.console.apikey.ApiTokenCache.delete"),
|
||||
):
|
||||
result, status = DummyApiKeyResource.delete(resource, "rid", "kid")
|
||||
|
||||
assert status == 204
|
||||
assert result == {"result": "success"}
|
||||
db_mock.session.commit.assert_called_once()
|
||||
@ -1,46 +0,0 @@
|
||||
import builtins
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.secret_key = "test-secret-key"
|
||||
return app
|
||||
|
||||
|
||||
def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
ext_fastopenapi.init_app(app)
|
||||
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||
|
||||
with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"):
|
||||
client = app.test_client()
|
||||
response = client.get("/console/api/init")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"status": "finished"}
|
||||
|
||||
|
||||
def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
ext_fastopenapi.init_app(app)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "test-init-password")
|
||||
|
||||
with (
|
||||
patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0),
|
||||
):
|
||||
client = app.test_client()
|
||||
response = client.post("/console/api/init", json={"password": "test-init-password"})
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.get_json() == {"result": "success"}
|
||||
@ -1,286 +0,0 @@
|
||||
"""Tests for remote file upload API endpoints using Flask-RESTX."""
|
||||
|
||||
import contextlib
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Create Flask app for testing."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SECRET_KEY"] = "test-secret-key"
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create test client with console blueprint registered."""
|
||||
from controllers.console import bp
|
||||
|
||||
app.register_blueprint(bp)
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
"""Create a mock account for testing."""
|
||||
from models import Account
|
||||
|
||||
account = Mock(spec=Account)
|
||||
account.id = "test-account-id"
|
||||
account.current_tenant_id = "test-tenant-id"
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_ctx(app, mock_account):
|
||||
"""Context manager to set auth/tenant context in flask.g for a request."""
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _ctx():
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_account
|
||||
g._current_tenant = mock_account.current_tenant_id
|
||||
yield
|
||||
|
||||
return _ctx
|
||||
|
||||
|
||||
class TestGetRemoteFileInfo:
|
||||
"""Test GET /console/api/remote-files/<path:url> endpoint."""
|
||||
|
||||
def test_get_remote_file_info_success(self, app, client, mock_account):
|
||||
"""Test successful retrieval of remote file info."""
|
||||
response = httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("HEAD", "http://example.com/file.txt"),
|
||||
headers={"Content-Type": "text/plain", "Content-Length": "1024"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_account
|
||||
g._current_tenant = mock_account.current_tenant_id
|
||||
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt"
|
||||
resp = client.get(f"/console/api/remote-files/{encoded_url}")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["file_type"] == "text/plain"
|
||||
assert data["file_length"] == 1024
|
||||
|
||||
def test_get_remote_file_info_fallback_to_get_on_head_failure(self, app, client, mock_account):
|
||||
"""Test fallback to GET when HEAD returns non-200 status."""
|
||||
head_response = httpx.Response(
|
||||
404,
|
||||
request=httpx.Request("HEAD", "http://example.com/file.pdf"),
|
||||
)
|
||||
get_response = httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("GET", "http://example.com/file.pdf"),
|
||||
headers={"Content-Type": "application/pdf", "Content-Length": "2048"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_response),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_account
|
||||
g._current_tenant = mock_account.current_tenant_id
|
||||
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.pdf"
|
||||
resp = client.get(f"/console/api/remote-files/{encoded_url}")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["file_type"] == "application/pdf"
|
||||
assert data["file_length"] == 2048
|
||||
|
||||
|
||||
class TestRemoteFileUpload:
|
||||
"""Test POST /console/api/remote-files/upload endpoint."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("head_status", "use_get"),
|
||||
[
|
||||
(200, False), # HEAD succeeds
|
||||
(405, True), # HEAD fails -> fallback GET
|
||||
],
|
||||
)
|
||||
def test_upload_remote_file_success_paths(self, client, mock_account, auth_ctx, head_status, use_get):
|
||||
url = "http://example.com/file.pdf"
|
||||
head_resp = httpx.Response(
|
||||
head_status,
|
||||
request=httpx.Request("HEAD", url),
|
||||
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
|
||||
)
|
||||
get_resp = httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("GET", url),
|
||||
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
|
||||
content=b"file content",
|
||||
)
|
||||
|
||||
file_info = SimpleNamespace(
|
||||
extension="pdf",
|
||||
size=1024,
|
||||
filename="file.pdf",
|
||||
mimetype="application/pdf",
|
||||
)
|
||||
uploaded_file = SimpleNamespace(
|
||||
id="uploaded-file-id",
|
||||
name="file.pdf",
|
||||
size=1024,
|
||||
extension="pdf",
|
||||
mime_type="application/pdf",
|
||||
created_by="test-account-id",
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp) as p_head,
|
||||
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_resp) as p_get,
|
||||
patch(
|
||||
"controllers.console.remote_files.helpers.guess_file_info_from_response",
|
||||
return_value=file_info,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.remote_files.FileService.is_file_size_within_limit",
|
||||
return_value=True,
|
||||
),
|
||||
patch("controllers.console.remote_files.db", spec=["engine"]),
|
||||
patch("controllers.console.remote_files.FileService") as mock_file_service,
|
||||
patch(
|
||||
"controllers.console.remote_files.file_helpers.get_signed_file_url",
|
||||
return_value="http://example.com/signed-url",
|
||||
),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
mock_file_service.return_value.upload_file.return_value = uploaded_file
|
||||
|
||||
with auth_ctx():
|
||||
resp = client.post(
|
||||
"/console/api/remote-files/upload",
|
||||
json={"url": url},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
p_head.assert_called_once()
|
||||
# GET is used either for fallback (HEAD fails) or to fetch content after HEAD succeeds
|
||||
p_get.assert_called_once()
|
||||
mock_file_service.return_value.upload_file.assert_called_once()
|
||||
|
||||
data = resp.get_json()
|
||||
assert data["id"] == "uploaded-file-id"
|
||||
assert data["name"] == "file.pdf"
|
||||
assert data["size"] == 1024
|
||||
assert data["extension"] == "pdf"
|
||||
assert data["url"] == "http://example.com/signed-url"
|
||||
assert data["mime_type"] == "application/pdf"
|
||||
assert data["created_by"] == "test-account-id"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("size_ok", "raises", "expected_status", "expected_msg"),
|
||||
[
|
||||
# When size check fails in controller, API returns 413 with message "File size exceeded..."
|
||||
(False, None, 413, "file size exceeded"),
|
||||
# When service raises unsupported type, controller maps to 415 with message "File type not allowed."
|
||||
(True, "unsupported", 415, "file type not allowed"),
|
||||
],
|
||||
)
|
||||
def test_upload_remote_file_errors(
|
||||
self, client, mock_account, auth_ctx, size_ok, raises, expected_status, expected_msg
|
||||
):
|
||||
url = "http://example.com/x.pdf"
|
||||
head_resp = httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("HEAD", url),
|
||||
headers={"Content-Type": "application/pdf", "Content-Length": "9"},
|
||||
)
|
||||
file_info = SimpleNamespace(extension="pdf", size=9, filename="x.pdf", mimetype="application/pdf")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp),
|
||||
patch(
|
||||
"controllers.console.remote_files.helpers.guess_file_info_from_response",
|
||||
return_value=file_info,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.remote_files.FileService.is_file_size_within_limit",
|
||||
return_value=size_ok,
|
||||
),
|
||||
patch("controllers.console.remote_files.db", spec=["engine"]),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
if raises == "unsupported":
|
||||
from services.errors.file import UnsupportedFileTypeError
|
||||
|
||||
with patch("controllers.console.remote_files.FileService") as mock_file_service:
|
||||
mock_file_service.return_value.upload_file.side_effect = UnsupportedFileTypeError("bad")
|
||||
with auth_ctx():
|
||||
resp = client.post(
|
||||
"/console/api/remote-files/upload",
|
||||
json={"url": url},
|
||||
)
|
||||
else:
|
||||
with auth_ctx():
|
||||
resp = client.post(
|
||||
"/console/api/remote-files/upload",
|
||||
json={"url": url},
|
||||
)
|
||||
|
||||
assert resp.status_code == expected_status
|
||||
data = resp.get_json()
|
||||
msg = (data.get("error") or {}).get("message") or data.get("message", "")
|
||||
assert expected_msg in msg.lower()
|
||||
|
||||
def test_upload_remote_file_fetch_failure(self, client, mock_account, auth_ctx):
|
||||
"""Test upload when fetching of remote file fails."""
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.remote_files.ssrf_proxy.head",
|
||||
side_effect=httpx.RequestError("Connection failed"),
|
||||
),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
with auth_ctx():
|
||||
resp = client.post(
|
||||
"/console/api/remote-files/upload",
|
||||
json={"url": "http://unreachable.com/file.pdf"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
data = resp.get_json()
|
||||
msg = (data.get("error") or {}).get("message") or data.get("message", "")
|
||||
assert "failed to fetch" in msg.lower()
|
||||
81
api/tests/unit_tests/controllers/console/test_feature.py
Normal file
81
api/tests/unit_tests/controllers/console/test_feature.py
Normal file
@ -0,0 +1,81 @@
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""
|
||||
Recursively unwrap decorated functions.
|
||||
"""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestFeatureApi:
|
||||
def test_get_tenant_features_success(self, mocker):
|
||||
from controllers.console.feature import FeatureApi
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.current_account_with_tenant",
|
||||
return_value=("account_id", "tenant_123"),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.feature.FeatureService.get_features").return_value.model_dump.return_value = {
|
||||
"features": {"feature_a": True}
|
||||
}
|
||||
|
||||
api = FeatureApi()
|
||||
|
||||
raw_get = unwrap(FeatureApi.get)
|
||||
result = raw_get(api)
|
||||
|
||||
assert result == {"features": {"feature_a": True}}
|
||||
|
||||
|
||||
class TestSystemFeatureApi:
|
||||
def test_get_system_features_authenticated(self, mocker):
|
||||
"""
|
||||
current_user.is_authenticated == True
|
||||
"""
|
||||
|
||||
from controllers.console.feature import SystemFeatureApi
|
||||
|
||||
fake_user = mocker.Mock()
|
||||
fake_user.is_authenticated = True
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.current_user",
|
||||
fake_user,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.FeatureService.get_system_features"
|
||||
).return_value.model_dump.return_value = {"features": {"sys_feature": True}}
|
||||
|
||||
api = SystemFeatureApi()
|
||||
result = api.get()
|
||||
|
||||
assert result == {"features": {"sys_feature": True}}
|
||||
|
||||
def test_get_system_features_unauthenticated(self, mocker):
|
||||
"""
|
||||
current_user.is_authenticated raises Unauthorized
|
||||
"""
|
||||
|
||||
from controllers.console.feature import SystemFeatureApi
|
||||
|
||||
fake_user = mocker.Mock()
|
||||
type(fake_user).is_authenticated = mocker.PropertyMock(side_effect=Unauthorized())
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.current_user",
|
||||
fake_user,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.FeatureService.get_system_features"
|
||||
).return_value.model_dump.return_value = {"features": {"sys_feature": False}}
|
||||
|
||||
api = SystemFeatureApi()
|
||||
result = api.get()
|
||||
|
||||
assert result == {"features": {"sys_feature": False}}
|
||||
300
api/tests/unit_tests/controllers/console/test_files.py
Normal file
300
api/tests/unit_tests/controllers/console/test_files.py
Normal file
@ -0,0 +1,300 @@
|
||||
import io
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.files import (
|
||||
FileApi,
|
||||
FilePreviewApi,
|
||||
FileSupportTypeApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""
|
||||
Recursively unwrap decorated functions.
|
||||
"""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
app.testing = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_decorators():
|
||||
"""
|
||||
Make decorators no-ops so logic is directly testable
|
||||
"""
|
||||
with (
|
||||
patch("controllers.console.files.setup_required", new=lambda f: f),
|
||||
patch("controllers.console.files.login_required", new=lambda f: f),
|
||||
patch("controllers.console.files.account_initialization_required", new=lambda f: f),
|
||||
patch("controllers.console.files.cloud_edition_billing_resource_check", return_value=lambda f: f),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_user():
|
||||
user = MagicMock()
|
||||
user.is_dataset_editor = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account_context(mock_current_user):
|
||||
with patch(
|
||||
"controllers.console.files.current_account_with_tenant",
|
||||
return_value=(mock_current_user, None),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db():
|
||||
with patch("controllers.console.files.db") as db_mock:
|
||||
db_mock.engine = MagicMock()
|
||||
yield db_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_service(mock_db):
|
||||
with patch("controllers.console.files.FileService") as fs:
|
||||
instance = fs.return_value
|
||||
yield instance
|
||||
|
||||
|
||||
class TestFileApiGet:
|
||||
def test_get_upload_config(self, app):
|
||||
api = FileApi()
|
||||
get_method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context():
|
||||
data, status = get_method(api)
|
||||
|
||||
assert status == 200
|
||||
assert "file_size_limit" in data
|
||||
assert "batch_count_limit" in data
|
||||
|
||||
|
||||
class TestFileApiPost:
|
||||
def test_no_file_uploaded(self, app, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(method="POST", data={}):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
post_method(api)
|
||||
|
||||
def test_too_many_files(self, app, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(method="POST"):
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
with patch("controllers.console.files.request") as mock_request:
|
||||
mock_request.files = MagicMock()
|
||||
mock_request.files.__len__.return_value = 2
|
||||
mock_request.files.__contains__.return_value = True
|
||||
mock_request.form = MagicMock()
|
||||
mock_request.form.get.return_value = None
|
||||
|
||||
with pytest.raises(TooManyFilesError):
|
||||
post_method(api)
|
||||
|
||||
def test_filename_missing(self, app, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"abc"), ""),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
post_method(api)
|
||||
|
||||
def test_dataset_upload_without_permission(self, app, mock_current_user):
|
||||
mock_current_user.is_dataset_editor = False
|
||||
|
||||
with patch(
|
||||
"controllers.console.files.current_account_with_tenant",
|
||||
return_value=(mock_current_user, None),
|
||||
):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"abc"), "test.txt"),
|
||||
"source": "datasets",
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(Forbidden):
|
||||
post_method(api)
|
||||
|
||||
def test_successful_upload(self, app, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
mock_file = MagicMock()
|
||||
mock_file.id = "file-id-123"
|
||||
mock_file.filename = "test.txt"
|
||||
mock_file.name = "test.txt"
|
||||
mock_file.size = 1024
|
||||
mock_file.extension = "txt"
|
||||
mock_file.mime_type = "text/plain"
|
||||
mock_file.created_by = "user-123"
|
||||
mock_file.created_at = 1234567890
|
||||
mock_file.preview_url = "http://example.com/preview/file-id-123"
|
||||
mock_file.source_url = "http://example.com/source/file-id-123"
|
||||
mock_file.original_url = None
|
||||
mock_file.user_id = "user-123"
|
||||
mock_file.tenant_id = "tenant-123"
|
||||
mock_file.conversation_id = None
|
||||
mock_file.file_key = "file-key-123"
|
||||
|
||||
mock_file_service.upload_file.return_value = mock_file
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"hello"), "test.txt"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-123"
|
||||
assert response["name"] == "test.txt"
|
||||
|
||||
def test_upload_with_invalid_source(self, app, mock_account_context, mock_file_service):
|
||||
"""Test that invalid source parameter gets normalized to None"""
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
# Create a properly structured mock file object
|
||||
mock_file = MagicMock()
|
||||
mock_file.id = "file-id-456"
|
||||
mock_file.filename = "test.txt"
|
||||
mock_file.name = "test.txt"
|
||||
mock_file.size = 512
|
||||
mock_file.extension = "txt"
|
||||
mock_file.mime_type = "text/plain"
|
||||
mock_file.created_by = "user-456"
|
||||
mock_file.created_at = 1234567890
|
||||
mock_file.preview_url = None
|
||||
mock_file.source_url = None
|
||||
mock_file.original_url = None
|
||||
mock_file.user_id = "user-456"
|
||||
mock_file.tenant_id = "tenant-456"
|
||||
mock_file.conversation_id = None
|
||||
mock_file.file_key = "file-key-456"
|
||||
|
||||
mock_file_service.upload_file.return_value = mock_file
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"content"), "test.txt"),
|
||||
"source": "invalid_source", # Should be normalized to None
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-456"
|
||||
# Verify that FileService was called with source=None
|
||||
mock_file_service.upload_file.assert_called_once()
|
||||
call_kwargs = mock_file_service.upload_file.call_args[1]
|
||||
assert call_kwargs["source"] is None
|
||||
|
||||
def test_file_too_large_error(self, app, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
|
||||
|
||||
error = ServiceFileTooLargeError("File is too large")
|
||||
mock_file_service.upload_file.side_effect = error
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"x" * 1000000), "big.txt"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
post_method(api)
|
||||
|
||||
def test_unsupported_file_type(self, app, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
|
||||
|
||||
error = ServiceUnsupportedFileTypeError()
|
||||
mock_file_service.upload_file.side_effect = error
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"x"), "bad.exe"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
post_method(api)
|
||||
|
||||
def test_blocked_extension(self, app, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
from services.errors.file import BlockedFileExtensionError as ServiceBlockedFileExtensionError
|
||||
|
||||
error = ServiceBlockedFileExtensionError("File extension is blocked")
|
||||
mock_file_service.upload_file.side_effect = error
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"x"), "blocked.txt"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(BlockedFileExtensionError):
|
||||
post_method(api)
|
||||
|
||||
|
||||
class TestFilePreviewApi:
|
||||
def test_get_preview(self, app, mock_file_service):
|
||||
api = FilePreviewApi()
|
||||
get_method = unwrap(api.get)
|
||||
mock_file_service.get_file_preview.return_value = "preview text"
|
||||
|
||||
with app.test_request_context():
|
||||
result = get_method(api, "1234")
|
||||
|
||||
assert result == {"content": "preview text"}
|
||||
|
||||
|
||||
class TestFileSupportTypeApi:
|
||||
def test_get_supported_types(self, app):
|
||||
api = FileSupportTypeApi()
|
||||
get_method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context():
|
||||
result = get_method(api)
|
||||
|
||||
assert result == {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}
|
||||
@ -0,0 +1,293 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
|
||||
from controllers.console.human_input_form import (
|
||||
ConsoleHumanInputFormApi,
|
||||
ConsoleWorkflowEventsApi,
|
||||
DifyAPIRepositoryFactory,
|
||||
WorkflowResponseConverter,
|
||||
_jsonify_form_definition,
|
||||
)
|
||||
from controllers.web.error import NotFoundError
|
||||
from models.enums import CreatorUserRole
|
||||
from models.human_input import RecipientType
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def test_jsonify_form_definition() -> None:
|
||||
expiration = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
definition = SimpleNamespace(model_dump=lambda: {"fields": []})
|
||||
form = SimpleNamespace(get_definition=lambda: definition, expiration_time=expiration)
|
||||
|
||||
response = _jsonify_form_definition(form)
|
||||
|
||||
assert isinstance(response, Response)
|
||||
payload = json.loads(response.get_data(as_text=True))
|
||||
assert payload["expiration_time"] == int(expiration.timestamp())
|
||||
|
||||
|
||||
def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(tenant_id="tenant-1")
|
||||
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-2"))
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
ConsoleHumanInputFormApi._ensure_console_access(form)
|
||||
|
||||
|
||||
def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
expiration = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
definition = SimpleNamespace(model_dump=lambda: {"fields": ["a"]})
|
||||
form = SimpleNamespace(tenant_id="tenant-1", get_definition=lambda: definition, expiration_time=expiration)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_definition_by_token_for_console(self, _token):
|
||||
return form
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
|
||||
response = handler(api, form_token="token")
|
||||
|
||||
payload = json.loads(response.get_data(as_text=True))
|
||||
assert payload["fields"] == ["a"]
|
||||
|
||||
|
||||
def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_definition_by_token_for_console(self, _token):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.EMAIL_MEMBER)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_by_token(self, _token):
|
||||
return form
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
submit_mock = Mock()
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_by_token(self, _token):
|
||||
return form
|
||||
|
||||
def submit_form_by_token(self, **kwargs):
|
||||
submit_mock(**kwargs)
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
response = handler(api, form_token="token")
|
||||
|
||||
assert response.get_json() == {}
|
||||
submit_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="user-1",
|
||||
tenant_id="t1",
|
||||
)
|
||||
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return workflow_run
|
||||
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by="user-2",
|
||||
tenant_id="t1",
|
||||
)
|
||||
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return workflow_run
|
||||
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_finished(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by="user-1",
|
||||
tenant_id="t1",
|
||||
app_id="app-1",
|
||||
finished_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
|
||||
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return workflow_run
|
||||
|
||||
response_obj = SimpleNamespace(
|
||||
event=SimpleNamespace(value="finished"),
|
||||
model_dump=lambda mode="json": {"status": "done"},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form._retrieve_app_for_workflow_run",
|
||||
lambda *_args, **_kwargs: app_model,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
WorkflowResponseConverter,
|
||||
"workflow_run_result_to_finish_response",
|
||||
lambda **_kwargs: response_obj,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
response = handler(api, workflow_run_id="run-1")
|
||||
|
||||
assert response.mimetype == "text/event-stream"
|
||||
assert "data" in response.get_data(as_text=True)
|
||||
108
api/tests/unit_tests/controllers/console/test_init_validate.py
Normal file
108
api/tests/unit_tests/controllers/console/test_init_validate.py
Normal file
@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import init_validate
|
||||
from controllers.console.error import AlreadySetupError, InitValidateFailedError
|
||||
|
||||
|
||||
class _SessionStub:
|
||||
def __init__(self, has_setup: bool):
|
||||
self._has_setup = has_setup
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, *_args, **_kwargs):
|
||||
return SimpleNamespace(scalar_one_or_none=lambda: Mock() if self._has_setup else None)
|
||||
|
||||
|
||||
def test_get_init_status_finished(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: True)
|
||||
result = init_validate.get_init_status()
|
||||
assert result.status == "finished"
|
||||
|
||||
|
||||
def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: False)
|
||||
result = init_validate.get_init_status()
|
||||
assert result.status == "not_started"
|
||||
|
||||
|
||||
def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1)
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="POST"):
|
||||
with pytest.raises(AlreadySetupError):
|
||||
init_validate.validate_init_password(init_validate.InitValidatePayload(password="pw"))
|
||||
|
||||
|
||||
def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="POST"):
|
||||
with pytest.raises(InitValidateFailedError):
|
||||
init_validate.validate_init_password(init_validate.InitValidatePayload(password="wrong"))
|
||||
assert init_validate.session.get("is_init_validated") is False
|
||||
|
||||
|
||||
def test_validate_init_password_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="POST"):
|
||||
result = init_validate.validate_init_password(init_validate.InitValidatePayload(password="expected"))
|
||||
assert result.result == "success"
|
||||
assert init_validate.session.get("is_init_validated") is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_not_self_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "CLOUD")
|
||||
assert init_validate.get_init_validate_status() is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="GET"):
|
||||
init_validate.session["is_init_validated"] = True
|
||||
assert init_validate.get_init_validate_status() is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(True))
|
||||
monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object()))
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="GET"):
|
||||
init_validate.session.pop("is_init_validated", None)
|
||||
assert init_validate.get_init_validate_status() is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_not_validated(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(False))
|
||||
monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object()))
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="GET"):
|
||||
init_validate.session.pop("is_init_validated", None)
|
||||
assert init_validate.get_init_validate_status() is False
|
||||
281
api/tests/unit_tests/controllers/console/test_remote_files.py
Normal file
281
api/tests/unit_tests/controllers/console/test_remote_files.py
Normal file
@ -0,0 +1,281 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError
|
||||
from controllers.console import remote_files as remote_files_module
|
||||
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
|
||||
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
status_code: int = 200,
|
||||
headers: dict[str, str] | None = None,
|
||||
method: str = "GET",
|
||||
content: bytes = b"",
|
||||
text: str = "",
|
||||
error: Exception | None = None,
|
||||
) -> None:
|
||||
self.status_code = status_code
|
||||
self.headers = headers or {}
|
||||
self.request = SimpleNamespace(method=method)
|
||||
self.content = content
|
||||
self.text = text
|
||||
self._error = error
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
if self._error:
|
||||
raise self._error
|
||||
|
||||
|
||||
def _mock_upload_dependencies(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
*,
|
||||
file_size_within_limit: bool = True,
|
||||
):
|
||||
file_info = SimpleNamespace(
|
||||
filename="report.txt",
|
||||
extension=".txt",
|
||||
mimetype="text/plain",
|
||||
size=3,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.helpers,
|
||||
"guess_file_info_from_response",
|
||||
MagicMock(return_value=file_info),
|
||||
)
|
||||
|
||||
file_service_cls = MagicMock()
|
||||
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
|
||||
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
|
||||
monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None))
|
||||
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.file_helpers,
|
||||
"get_signed_file_url",
|
||||
lambda upload_file_id: f"https://signed.example/{upload_file_id}",
|
||||
)
|
||||
|
||||
return file_service_cls
|
||||
|
||||
|
||||
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.GetRemoteFileInfo()
|
||||
handler = _unwrap(api.get)
|
||||
decoded_url = "https://example.com/test.txt"
|
||||
encoded_url = urllib.parse.quote(decoded_url, safe="")
|
||||
|
||||
head_resp = _FakeResponse(
|
||||
status_code=200,
|
||||
headers={"Content-Type": "text/plain", "Content-Length": "128"},
|
||||
method="HEAD",
|
||||
)
|
||||
head_mock = MagicMock(return_value=head_resp)
|
||||
get_mock = MagicMock()
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
with app.test_request_context(method="GET"):
|
||||
payload = handler(api, url=encoded_url)
|
||||
|
||||
assert payload == {"file_type": "text/plain", "file_length": 128}
|
||||
head_mock.assert_called_once_with(decoded_url)
|
||||
get_mock.assert_not_called()
|
||||
|
||||
|
||||
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.GetRemoteFileInfo()
|
||||
handler = _unwrap(api.get)
|
||||
decoded_url = "https://example.com/test.txt"
|
||||
encoded_url = urllib.parse.quote(decoded_url, safe="")
|
||||
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=503)))
|
||||
get_mock = MagicMock(return_value=_FakeResponse(status_code=200, headers={}, method="GET"))
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
with app.test_request_context(method="GET"):
|
||||
payload = handler(api, url=encoded_url)
|
||||
|
||||
assert payload == {"file_type": "application/octet-stream", "file_length": 0}
|
||||
get_mock.assert_called_once_with(decoded_url, timeout=3)
|
||||
|
||||
|
||||
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/report.txt"
|
||||
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=404)))
|
||||
get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content")
|
||||
get_mock = MagicMock(return_value=get_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-1",
|
||||
name="report.txt",
|
||||
size=16,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by="u1",
|
||||
created_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
file_service_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
payload, status = handler(api)
|
||||
|
||||
assert status == 201
|
||||
assert payload["id"] == "file-1"
|
||||
assert payload["url"] == "https://signed.example/file-1"
|
||||
get_mock.assert_called_once_with(url=url, timeout=3, follow_redirects=True)
|
||||
file_service_cls.return_value.upload_file.assert_called_once_with(
|
||||
filename="report.txt",
|
||||
content=b"fallback-content",
|
||||
mimetype="text/plain",
|
||||
user=SimpleNamespace(id="u1"),
|
||||
source_url=url,
|
||||
)
|
||||
|
||||
|
||||
def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
|
||||
app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/photo.jpg"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="HEAD", content=b"head-content")),
|
||||
)
|
||||
extra_get_resp = _FakeResponse(status_code=200, method="GET", content=b"downloaded-content")
|
||||
get_mock = MagicMock(return_value=extra_get_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-2",
|
||||
name="photo.jpg",
|
||||
size=18,
|
||||
extension=".jpg",
|
||||
mime_type="image/jpeg",
|
||||
created_by="u1",
|
||||
created_at=datetime(2024, 1, 2, tzinfo=UTC),
|
||||
)
|
||||
file_service_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
payload, status = handler(api)
|
||||
|
||||
assert status == 201
|
||||
assert payload["id"] == "file-2"
|
||||
get_mock.assert_called_once_with(url)
|
||||
assert file_service_cls.return_value.upload_file.call_args.kwargs["content"] == b"downloaded-content"
|
||||
|
||||
|
||||
def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/fail.txt"
|
||||
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=500)))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"get",
|
||||
MagicMock(return_value=_FakeResponse(status_code=502, text="bad gateway")),
|
||||
)
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
|
||||
handler(api)
|
||||
|
||||
|
||||
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/fail.txt"
|
||||
|
||||
request = httpx.Request("HEAD", url)
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(side_effect=httpx.RequestError("network down", request=request)),
|
||||
)
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
|
||||
handler(api)
|
||||
|
||||
|
||||
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/large.bin"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
|
||||
_mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
handler(api)
|
||||
|
||||
|
||||
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/large.bin"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded")
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(FileTooLargeError, match="size exceeded"):
|
||||
handler(api)
|
||||
|
||||
|
||||
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/file.exe"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError()
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
handler(api)
|
||||
49
api/tests/unit_tests/controllers/console/test_spec.py
Normal file
49
api/tests/unit_tests/controllers/console/test_spec.py
Normal file
@ -0,0 +1,49 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import controllers.console.spec as spec_module
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestSpecSchemaDefinitionsApi:
|
||||
def test_get_success(self):
|
||||
api = spec_module.SpecSchemaDefinitionsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
schema_definitions = [{"type": "string"}]
|
||||
|
||||
with patch.object(
|
||||
spec_module,
|
||||
"SchemaManager",
|
||||
) as schema_manager_cls:
|
||||
schema_manager_cls.return_value.get_all_schema_definitions.return_value = schema_definitions
|
||||
|
||||
resp, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert resp == schema_definitions
|
||||
|
||||
def test_get_exception_returns_empty_list(self):
|
||||
api = spec_module.SpecSchemaDefinitionsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
spec_module,
|
||||
"SchemaManager",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
patch.object(
|
||||
spec_module.logger,
|
||||
"exception",
|
||||
) as log_exception,
|
||||
):
|
||||
resp, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert resp == []
|
||||
log_exception.assert_called_once()
|
||||
162
api/tests/unit_tests/controllers/console/test_version.py
Normal file
162
api/tests/unit_tests/controllers/console/test_version.py
Normal file
@ -0,0 +1,162 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import controllers.console.version as version_module
|
||||
|
||||
|
||||
class TestHasNewVersion:
|
||||
def test_has_new_version_true(self):
|
||||
result = version_module._has_new_version(
|
||||
latest_version="1.2.0",
|
||||
current_version="1.1.0",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_has_new_version_false(self):
|
||||
result = version_module._has_new_version(
|
||||
latest_version="1.0.0",
|
||||
current_version="1.1.0",
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_has_new_version_invalid_version(self):
|
||||
with patch.object(version_module.logger, "warning") as log_warning:
|
||||
result = version_module._has_new_version(
|
||||
latest_version="invalid",
|
||||
current_version="1.0.0",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
log_warning.assert_called_once()
|
||||
|
||||
|
||||
class TestCheckVersionUpdate:
|
||||
def test_no_check_update_url(self):
|
||||
query = version_module.VersionQuery(current_version="1.0.0")
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CHECK_UPDATE_URL",
|
||||
"",
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config.project,
|
||||
"version",
|
||||
"1.0.0",
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CAN_REPLACE_LOGO",
|
||||
True,
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"MODEL_LB_ENABLED",
|
||||
False,
|
||||
),
|
||||
):
|
||||
result = version_module.check_version_update(query)
|
||||
|
||||
assert result.version == "1.0.0"
|
||||
assert result.can_auto_update is False
|
||||
assert result.features.can_replace_logo is True
|
||||
assert result.features.model_load_balancing_enabled is False
|
||||
|
||||
def test_http_error_fallback(self):
|
||||
query = version_module.VersionQuery(current_version="1.0.0")
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CHECK_UPDATE_URL",
|
||||
"http://example.com",
|
||||
),
|
||||
patch.object(
|
||||
version_module.httpx,
|
||||
"get",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
patch.object(
|
||||
version_module.logger,
|
||||
"warning",
|
||||
) as log_warning,
|
||||
):
|
||||
result = version_module.check_version_update(query)
|
||||
|
||||
assert result.version == "1.0.0"
|
||||
log_warning.assert_called_once()
|
||||
|
||||
def test_new_version_available(self):
|
||||
query = version_module.VersionQuery(current_version="1.0.0")
|
||||
|
||||
response = MagicMock()
|
||||
response.json.return_value = {
|
||||
"version": "1.2.0",
|
||||
"releaseDate": "2024-01-01",
|
||||
"releaseNotes": "New features",
|
||||
"canAutoUpdate": True,
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CHECK_UPDATE_URL",
|
||||
"http://example.com",
|
||||
),
|
||||
patch.object(
|
||||
version_module.httpx,
|
||||
"get",
|
||||
return_value=response,
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config.project,
|
||||
"version",
|
||||
"1.0.0",
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CAN_REPLACE_LOGO",
|
||||
False,
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"MODEL_LB_ENABLED",
|
||||
True,
|
||||
),
|
||||
):
|
||||
result = version_module.check_version_update(query)
|
||||
|
||||
assert result.version == "1.2.0"
|
||||
assert result.release_date == "2024-01-01"
|
||||
assert result.release_notes == "New features"
|
||||
assert result.can_auto_update is True
|
||||
|
||||
def test_no_new_version(self):
|
||||
query = version_module.VersionQuery(current_version="1.2.0")
|
||||
|
||||
response = MagicMock()
|
||||
response.json.return_value = {
|
||||
"version": "1.1.0",
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CHECK_UPDATE_URL",
|
||||
"http://example.com",
|
||||
),
|
||||
patch.object(
|
||||
version_module.httpx,
|
||||
"get",
|
||||
return_value=response,
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config.project,
|
||||
"version",
|
||||
"1.2.0",
|
||||
),
|
||||
):
|
||||
result = version_module.check_version_update(query)
|
||||
|
||||
assert result.version == "1.2.0"
|
||||
assert result.can_auto_update is False
|
||||
0
api/tests/unit_tests/core/trigger/__init__.py
Normal file
0
api/tests/unit_tests/core/trigger/__init__.py
Normal file
93
api/tests/unit_tests/core/trigger/conftest.py
Normal file
93
api/tests/unit_tests/core/trigger/conftest.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Shared factory helpers for core.trigger test suite."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.trigger.entities.entities import (
|
||||
EventEntity,
|
||||
EventIdentity,
|
||||
EventParameter,
|
||||
OAuthSchema,
|
||||
Subscription,
|
||||
SubscriptionConstructor,
|
||||
TriggerProviderEntity,
|
||||
TriggerProviderIdentity,
|
||||
)
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from models.provider_ids import TriggerProviderID
|
||||
|
||||
# Valid format for TriggerProviderID: org/plugin/provider
|
||||
VALID_PROVIDER_ID = "testorg/testplugin/testprovider"
|
||||
|
||||
|
||||
def i18n(text: str = "test") -> I18nObject:
|
||||
return I18nObject(en_US=text, zh_Hans=text)
|
||||
|
||||
|
||||
def make_event(name: str = "test_event", parameters: list[EventParameter] | None = None) -> EventEntity:
|
||||
return EventEntity(
|
||||
identity=EventIdentity(author="a", name=name, label=i18n(name)),
|
||||
description=i18n(name),
|
||||
parameters=parameters or [],
|
||||
)
|
||||
|
||||
|
||||
def make_provider_entity(
|
||||
name: str = "test_provider",
|
||||
events: list[EventEntity] | None = None,
|
||||
constructor: SubscriptionConstructor | None = None,
|
||||
subscription_schema: list[ProviderConfig] | None = None,
|
||||
icon: str | None = "icon.png",
|
||||
icon_dark: str | None = None,
|
||||
) -> TriggerProviderEntity:
|
||||
return TriggerProviderEntity(
|
||||
identity=TriggerProviderIdentity(
|
||||
author="a",
|
||||
name=name,
|
||||
label=i18n(name),
|
||||
description=i18n(name),
|
||||
icon=icon,
|
||||
icon_dark=icon_dark,
|
||||
),
|
||||
events=events if events is not None else [make_event()],
|
||||
subscription_constructor=constructor,
|
||||
subscription_schema=subscription_schema or [],
|
||||
)
|
||||
|
||||
|
||||
def make_controller(
|
||||
entity: TriggerProviderEntity | None = None,
|
||||
tenant_id: str = "tenant-1",
|
||||
provider_id: str = VALID_PROVIDER_ID,
|
||||
) -> PluginTriggerProviderController:
|
||||
return PluginTriggerProviderController(
|
||||
entity=entity or make_provider_entity(),
|
||||
plugin_id="plugin-1",
|
||||
plugin_unique_identifier="uid-1",
|
||||
provider_id=TriggerProviderID(provider_id),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
def make_subscription(**overrides: Any) -> Subscription:
|
||||
defaults = {"expires_at": 9999999999, "endpoint": "https://hook.test", "properties": {"k": "v"}, "parameters": {}}
|
||||
defaults.update(overrides)
|
||||
return Subscription(**defaults)
|
||||
|
||||
|
||||
def make_provider_config(
|
||||
name: str = "api_key", required: bool = True, config_type: str = "secret-input"
|
||||
) -> ProviderConfig:
|
||||
return ProviderConfig(name=name, label=i18n(name), type=config_type, required=required)
|
||||
|
||||
|
||||
def make_constructor(
|
||||
credentials_schema: list[ProviderConfig] | None = None,
|
||||
oauth_schema: OAuthSchema | None = None,
|
||||
) -> SubscriptionConstructor:
|
||||
return SubscriptionConstructor(
|
||||
parameters=[], credentials_schema=credentials_schema or [], oauth_schema=oauth_schema
|
||||
)
|
||||
0
api/tests/unit_tests/core/trigger/debug/__init__.py
Normal file
0
api/tests/unit_tests/core/trigger/debug/__init__.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""
|
||||
Tests for core.trigger.debug.event_bus.TriggerDebugEventBus.
|
||||
|
||||
Covers: Lua-script dispatch/poll with Redis error resilience.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from redis import RedisError
|
||||
|
||||
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
||||
from core.trigger.debug.events import PluginTriggerDebugEvent
|
||||
|
||||
|
||||
class TestDispatch:
|
||||
@patch("core.trigger.debug.event_bus.redis_client")
|
||||
def test_returns_dispatch_count(self, mock_redis):
|
||||
mock_redis.eval.return_value = 3
|
||||
event = MagicMock()
|
||||
event.model_dump_json.return_value = '{"test": true}'
|
||||
|
||||
result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key")
|
||||
|
||||
assert result == 3
|
||||
mock_redis.eval.assert_called_once()
|
||||
|
||||
@patch("core.trigger.debug.event_bus.redis_client")
|
||||
def test_redis_error_returns_zero(self, mock_redis):
|
||||
mock_redis.eval.side_effect = RedisError("connection lost")
|
||||
event = MagicMock()
|
||||
event.model_dump_json.return_value = "{}"
|
||||
|
||||
result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key")
|
||||
|
||||
assert result == 0
|
||||
|
||||
|
||||
class TestPoll:
|
||||
@patch("core.trigger.debug.event_bus.redis_client")
|
||||
def test_returns_deserialized_event(self, mock_redis):
|
||||
event_json = PluginTriggerDebugEvent(
|
||||
timestamp=100,
|
||||
name="push",
|
||||
user_id="u1",
|
||||
request_id="r1",
|
||||
subscription_id="s1",
|
||||
provider_id="p1",
|
||||
).model_dump_json()
|
||||
mock_redis.eval.return_value = event_json
|
||||
|
||||
result = TriggerDebugEventBus.poll(
|
||||
event_type=PluginTriggerDebugEvent,
|
||||
pool_key="pool:key",
|
||||
tenant_id="t1",
|
||||
user_id="u1",
|
||||
app_id="a1",
|
||||
node_id="n1",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.name == "push"
|
||||
|
||||
@patch("core.trigger.debug.event_bus.redis_client")
|
||||
def test_returns_none_when_no_event(self, mock_redis):
|
||||
mock_redis.eval.return_value = None
|
||||
|
||||
result = TriggerDebugEventBus.poll(
|
||||
event_type=PluginTriggerDebugEvent,
|
||||
pool_key="pool:key",
|
||||
tenant_id="t1",
|
||||
user_id="u1",
|
||||
app_id="a1",
|
||||
node_id="n1",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("core.trigger.debug.event_bus.redis_client")
|
||||
def test_redis_error_returns_none(self, mock_redis):
|
||||
mock_redis.eval.side_effect = RedisError("timeout")
|
||||
|
||||
result = TriggerDebugEventBus.poll(
|
||||
event_type=PluginTriggerDebugEvent,
|
||||
pool_key="pool:key",
|
||||
tenant_id="t1",
|
||||
user_id="u1",
|
||||
app_id="a1",
|
||||
node_id="n1",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
@ -0,0 +1,276 @@
|
||||
"""
|
||||
Tests for core.trigger.debug.event_selectors.
|
||||
|
||||
Covers: Plugin/Webhook/Schedule pollers, create_event_poller factory,
|
||||
and select_trigger_debug_events orchestrator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.request import TriggerInvokeEventResponse
|
||||
from core.trigger.debug.event_selectors import (
|
||||
PluginTriggerDebugEventPoller,
|
||||
ScheduleTriggerDebugEventPoller,
|
||||
WebhookTriggerDebugEventPoller,
|
||||
create_event_poller,
|
||||
select_trigger_debug_events,
|
||||
)
|
||||
from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent
|
||||
from core.workflow.enums import NodeType
|
||||
from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID
|
||||
|
||||
|
||||
def _make_poller_args(node_config: dict | None = None) -> dict:
|
||||
return {
|
||||
"tenant_id": "t1",
|
||||
"user_id": "u1",
|
||||
"app_id": "a1",
|
||||
"node_config": node_config or {"data": {}},
|
||||
"node_id": "n1",
|
||||
}
|
||||
|
||||
|
||||
def _plugin_node_config(provider_id: str = VALID_PROVIDER_ID) -> dict:
|
||||
"""Valid node config for TriggerEventNodeData.model_validate."""
|
||||
return {
|
||||
"data": {
|
||||
"title": "test",
|
||||
"plugin_id": "org/testplugin",
|
||||
"provider_id": provider_id,
|
||||
"event_name": "push",
|
||||
"subscription_id": "s1",
|
||||
"plugin_unique_identifier": "uid-1",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TestPluginTriggerDebugEventPoller:
|
||||
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
|
||||
def test_returns_workflow_args_on_success(self, mock_bus):
|
||||
event = PluginTriggerDebugEvent(
|
||||
timestamp=100,
|
||||
name="push",
|
||||
user_id="u1",
|
||||
request_id="r1",
|
||||
subscription_id="s1",
|
||||
provider_id="p1",
|
||||
)
|
||||
mock_bus.poll.return_value = event
|
||||
|
||||
with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc:
|
||||
mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse(
|
||||
variables={"repo": "dify"},
|
||||
cancelled=False,
|
||||
)
|
||||
|
||||
poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
|
||||
result = poller.poll()
|
||||
|
||||
assert result is not None
|
||||
assert result.workflow_args["inputs"] == {"repo": "dify"}
|
||||
|
||||
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
|
||||
def test_returns_none_when_no_event(self, mock_bus):
|
||||
mock_bus.poll.return_value = None
|
||||
|
||||
poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
|
||||
|
||||
assert poller.poll() is None
|
||||
|
||||
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
|
||||
def test_returns_none_when_invoke_cancelled(self, mock_bus):
|
||||
event = PluginTriggerDebugEvent(
|
||||
timestamp=100,
|
||||
name="push",
|
||||
user_id="u1",
|
||||
request_id="r1",
|
||||
subscription_id="s1",
|
||||
provider_id="p1",
|
||||
)
|
||||
mock_bus.poll.return_value = event
|
||||
|
||||
with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc:
|
||||
mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse(
|
||||
variables={},
|
||||
cancelled=True,
|
||||
)
|
||||
|
||||
poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config()))
|
||||
|
||||
assert poller.poll() is None
|
||||
|
||||
|
||||
class TestWebhookTriggerDebugEventPoller:
|
||||
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
|
||||
def test_uses_inputs_directly_when_present(self, mock_bus):
|
||||
event = WebhookDebugEvent(
|
||||
timestamp=100,
|
||||
request_id="r1",
|
||||
node_id="n1",
|
||||
payload={"inputs": {"key": "val"}, "webhook_data": {}},
|
||||
)
|
||||
mock_bus.poll.return_value = event
|
||||
|
||||
poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
|
||||
result = poller.poll()
|
||||
|
||||
assert result is not None
|
||||
assert result.workflow_args["inputs"] == {"key": "val"}
|
||||
|
||||
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
|
||||
def test_falls_back_to_webhook_data(self, mock_bus):
|
||||
event = WebhookDebugEvent(
|
||||
timestamp=100,
|
||||
request_id="r1",
|
||||
node_id="n1",
|
||||
payload={"webhook_data": {"body": "raw"}},
|
||||
)
|
||||
mock_bus.poll.return_value = event
|
||||
|
||||
with patch("services.trigger.webhook_service.WebhookService") as mock_webhook_svc:
|
||||
mock_webhook_svc.build_workflow_inputs.return_value = {"parsed": "data"}
|
||||
|
||||
poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
|
||||
result = poller.poll()
|
||||
|
||||
assert result is not None
|
||||
assert result.workflow_args["inputs"] == {"parsed": "data"}
|
||||
mock_webhook_svc.build_workflow_inputs.assert_called_once_with({"body": "raw"})
|
||||
|
||||
@patch("core.trigger.debug.event_selectors.TriggerDebugEventBus")
|
||||
def test_returns_none_when_no_event(self, mock_bus):
|
||||
mock_bus.poll.return_value = None
|
||||
poller = WebhookTriggerDebugEventPoller(**_make_poller_args())
|
||||
|
||||
assert poller.poll() is None
|
||||
|
||||
|
||||
class TestScheduleTriggerDebugEventPoller:
|
||||
def _make_schedule_poller(self, mock_redis, mock_schedule_svc, next_run_at: datetime):
|
||||
"""Set up mocks and create a schedule poller."""
|
||||
mock_redis.get.return_value = None
|
||||
mock_schedule_config = MagicMock()
|
||||
mock_schedule_config.cron_expression = "0 * * * *"
|
||||
mock_schedule_config.timezone = "UTC"
|
||||
mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
|
||||
return ScheduleTriggerDebugEventPoller(**_make_poller_args())
|
||||
|
||||
@patch("core.trigger.debug.event_selectors.redis_client")
|
||||
@patch("core.trigger.debug.event_selectors.naive_utc_now")
|
||||
@patch("core.trigger.debug.event_selectors.calculate_next_run_at")
|
||||
@patch("core.trigger.debug.event_selectors.ensure_naive_utc")
|
||||
def test_returns_none_when_not_yet_due(self, mock_ensure, mock_calc, mock_now, mock_redis):
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
next_run = datetime(2025, 1, 1, 13, 0, 0) # future
|
||||
mock_now.return_value = now
|
||||
mock_calc.return_value = next_run
|
||||
mock_ensure.return_value = next_run
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc:
|
||||
mock_schedule_config = MagicMock()
|
||||
mock_schedule_config.cron_expression = "0 * * * *"
|
||||
mock_schedule_config.timezone = "UTC"
|
||||
mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
|
||||
|
||||
poller = ScheduleTriggerDebugEventPoller(**_make_poller_args())
|
||||
|
||||
assert poller.poll() is None
|
||||
|
||||
@patch("core.trigger.debug.event_selectors.redis_client")
|
||||
@patch("core.trigger.debug.event_selectors.naive_utc_now")
|
||||
@patch("core.trigger.debug.event_selectors.calculate_next_run_at")
|
||||
@patch("core.trigger.debug.event_selectors.ensure_naive_utc")
|
||||
def test_fires_event_when_due(self, mock_ensure, mock_calc, mock_now, mock_redis):
|
||||
now = datetime(2025, 1, 1, 14, 0, 0)
|
||||
next_run = datetime(2025, 1, 1, 12, 0, 0) # past
|
||||
mock_now.return_value = now
|
||||
mock_calc.return_value = next_run
|
||||
mock_ensure.return_value = next_run
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc:
|
||||
mock_schedule_config = MagicMock()
|
||||
mock_schedule_config.cron_expression = "0 * * * *"
|
||||
mock_schedule_config.timezone = "UTC"
|
||||
mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config
|
||||
|
||||
poller = ScheduleTriggerDebugEventPoller(**_make_poller_args())
|
||||
result = poller.poll()
|
||||
|
||||
assert result is not None
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
|
||||
class TestCreateEventPoller:
|
||||
def _workflow_with_node(self, node_type: NodeType):
|
||||
wf = MagicMock()
|
||||
wf.get_node_config_by_id.return_value = {"data": {}}
|
||||
wf.get_node_type_from_node_config.return_value = node_type
|
||||
return wf
|
||||
|
||||
def test_creates_plugin_poller(self):
|
||||
wf = self._workflow_with_node(NodeType.TRIGGER_PLUGIN)
|
||||
poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
|
||||
assert isinstance(poller, PluginTriggerDebugEventPoller)
|
||||
|
||||
def test_creates_webhook_poller(self):
|
||||
wf = self._workflow_with_node(NodeType.TRIGGER_WEBHOOK)
|
||||
poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
|
||||
assert isinstance(poller, WebhookTriggerDebugEventPoller)
|
||||
|
||||
def test_creates_schedule_poller(self):
|
||||
wf = self._workflow_with_node(NodeType.TRIGGER_SCHEDULE)
|
||||
poller = create_event_poller(wf, "t1", "u1", "a1", "n1")
|
||||
assert isinstance(poller, ScheduleTriggerDebugEventPoller)
|
||||
|
||||
def test_raises_for_unknown_type(self):
|
||||
wf = MagicMock()
|
||||
wf.get_node_config_by_id.return_value = {"data": {}}
|
||||
wf.get_node_type_from_node_config.return_value = NodeType.START
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
create_event_poller(wf, "t1", "u1", "a1", "n1")
|
||||
|
||||
def test_raises_when_node_config_missing(self):
|
||||
wf = MagicMock()
|
||||
wf.get_node_config_by_id.return_value = None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
create_event_poller(wf, "t1", "u1", "a1", "n1")
|
||||
|
||||
|
||||
class TestSelectTriggerDebugEvents:
|
||||
def test_returns_first_non_none_event(self):
|
||||
wf = MagicMock()
|
||||
wf.get_node_config_by_id.return_value = {"data": {}}
|
||||
wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK
|
||||
app_model = MagicMock()
|
||||
app_model.tenant_id = "t1"
|
||||
app_model.id = "a1"
|
||||
|
||||
with patch.object(WebhookTriggerDebugEventPoller, "poll") as mock_poll:
|
||||
expected = MagicMock()
|
||||
mock_poll.return_value = expected
|
||||
|
||||
result = select_trigger_debug_events(wf, app_model, "u1", ["n1", "n2"])
|
||||
|
||||
assert result is expected
|
||||
|
||||
def test_returns_none_when_no_events(self):
|
||||
wf = MagicMock()
|
||||
wf.get_node_config_by_id.return_value = {"data": {}}
|
||||
wf.get_node_type_from_node_config.return_value = NodeType.TRIGGER_WEBHOOK
|
||||
app_model = MagicMock()
|
||||
app_model.tenant_id = "t1"
|
||||
app_model.id = "a1"
|
||||
|
||||
with patch.object(WebhookTriggerDebugEventPoller, "poll", return_value=None):
|
||||
result = select_trigger_debug_events(wf, app_model, "u1", ["n1"])
|
||||
|
||||
assert result is None
|
||||
332
api/tests/unit_tests/core/trigger/test_provider.py
Normal file
332
api/tests/unit_tests/core/trigger/test_provider.py
Normal file
@ -0,0 +1,332 @@
|
||||
"""
|
||||
Tests for core.trigger.provider.PluginTriggerProviderController.
|
||||
|
||||
Covers: to_api_entity creation-method logic, credential validation pipeline,
|
||||
schema resolution by type, event lookup, dispatch/invoke/subscribe delegation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.entities import (
|
||||
EventParameter,
|
||||
EventParameterType,
|
||||
OAuthSchema,
|
||||
TriggerCreationMethod,
|
||||
)
|
||||
from core.trigger.errors import TriggerProviderCredentialValidationError
|
||||
from tests.unit_tests.core.trigger.conftest import (
|
||||
i18n,
|
||||
make_constructor,
|
||||
make_controller,
|
||||
make_event,
|
||||
make_provider_config,
|
||||
make_provider_entity,
|
||||
make_subscription,
|
||||
)
|
||||
|
||||
ICON_URL = "https://cdn/icon.png"
|
||||
|
||||
|
||||
class TestToApiEntity:
|
||||
@patch("core.trigger.provider.PluginService")
|
||||
def test_includes_icons_when_present(self, mock_plugin_svc):
|
||||
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
|
||||
ctrl = make_controller(entity=make_provider_entity(icon="icon.png", icon_dark="dark.png"))
|
||||
|
||||
api = ctrl.to_api_entity()
|
||||
|
||||
assert api.icon == ICON_URL
|
||||
assert api.icon_dark == ICON_URL
|
||||
|
||||
@patch("core.trigger.provider.PluginService")
|
||||
def test_icons_none_when_absent(self, mock_plugin_svc):
|
||||
ctrl = make_controller(entity=make_provider_entity(icon=None, icon_dark=None))
|
||||
|
||||
api = ctrl.to_api_entity()
|
||||
|
||||
assert api.icon is None
|
||||
assert api.icon_dark is None
|
||||
mock_plugin_svc.get_plugin_icon_url.assert_not_called()
|
||||
|
||||
@patch("core.trigger.provider.PluginService")
|
||||
def test_manual_only_without_schemas(self, mock_plugin_svc):
|
||||
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=None))
|
||||
|
||||
api = ctrl.to_api_entity()
|
||||
|
||||
assert api.supported_creation_methods == [TriggerCreationMethod.MANUAL]
|
||||
|
||||
@patch("core.trigger.provider.PluginService")
|
||||
def test_adds_oauth_when_oauth_schema_present(self, mock_plugin_svc):
|
||||
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
|
||||
oauth = OAuthSchema(client_schema=[], credentials_schema=[])
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
|
||||
|
||||
api = ctrl.to_api_entity()
|
||||
|
||||
assert TriggerCreationMethod.OAUTH in api.supported_creation_methods
|
||||
assert TriggerCreationMethod.MANUAL in api.supported_creation_methods
|
||||
|
||||
@patch("core.trigger.provider.PluginService")
|
||||
def test_adds_apikey_when_credentials_schema_present(self, mock_plugin_svc):
|
||||
mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
|
||||
)
|
||||
|
||||
api = ctrl.to_api_entity()
|
||||
|
||||
assert TriggerCreationMethod.APIKEY in api.supported_creation_methods
|
||||
|
||||
|
||||
class TestGetEvent:
|
||||
def test_returns_matching_event(self):
|
||||
evt = make_event("push")
|
||||
ctrl = make_controller(entity=make_provider_entity(events=[evt, make_event("pr")]))
|
||||
|
||||
assert ctrl.get_event("push") is evt
|
||||
|
||||
def test_returns_none_for_unknown(self):
|
||||
ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")]))
|
||||
|
||||
assert ctrl.get_event("nonexistent") is None
|
||||
|
||||
|
||||
class TestGetSubscriptionDefaultProperties:
|
||||
def test_returns_defaults_skipping_none(self):
|
||||
config1 = make_provider_config("key1")
|
||||
config1.default = "val1"
|
||||
config2 = make_provider_config("key2")
|
||||
config2.default = None
|
||||
ctrl = make_controller(entity=make_provider_entity(subscription_schema=[config1, config2]))
|
||||
|
||||
props = ctrl.get_subscription_default_properties()
|
||||
|
||||
assert props == {"key1": "val1"}
|
||||
|
||||
|
||||
class TestValidateCredentials:
|
||||
def test_raises_when_no_constructor(self):
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=None))
|
||||
|
||||
with pytest.raises(ValueError, match="Subscription constructor not found"):
|
||||
ctrl.validate_credentials("u1", {"key": "val"})
|
||||
|
||||
def test_raises_for_missing_required_field(self):
|
||||
required_cfg = make_provider_config("api_key", required=True)
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
|
||||
)
|
||||
|
||||
with pytest.raises(TriggerProviderCredentialValidationError, match="Missing required"):
|
||||
ctrl.validate_credentials("u1", {})
|
||||
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_passes_with_valid_credentials(self, mock_client):
|
||||
required_cfg = make_provider_config("api_key", required=True)
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
|
||||
)
|
||||
mock_client.return_value.validate_provider_credentials.return_value = True
|
||||
|
||||
ctrl.validate_credentials("u1", {"api_key": "secret123"}) # should not raise
|
||||
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_raises_when_plugin_rejects(self, mock_client):
|
||||
required_cfg = make_provider_config("api_key", required=True)
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg]))
|
||||
)
|
||||
mock_client.return_value.validate_provider_credentials.return_value = None
|
||||
|
||||
with pytest.raises(TriggerProviderCredentialValidationError, match="Invalid credentials"):
|
||||
ctrl.validate_credentials("u1", {"api_key": "bad"})
|
||||
|
||||
|
||||
class TestGetSupportedCredentialTypes:
|
||||
def test_empty_when_no_constructor(self):
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=None))
|
||||
assert ctrl.get_supported_credential_types() == []
|
||||
|
||||
def test_oauth_only(self):
|
||||
oauth = OAuthSchema(client_schema=[], credentials_schema=[])
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
|
||||
|
||||
types = ctrl.get_supported_credential_types()
|
||||
|
||||
assert CredentialType.OAUTH2 in types
|
||||
assert CredentialType.API_KEY not in types
|
||||
|
||||
def test_apikey_only(self):
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
|
||||
)
|
||||
|
||||
types = ctrl.get_supported_credential_types()
|
||||
|
||||
assert CredentialType.API_KEY in types
|
||||
assert CredentialType.OAUTH2 not in types
|
||||
|
||||
def test_both(self):
|
||||
oauth = OAuthSchema(client_schema=[], credentials_schema=[make_provider_config("oauth_secret")])
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(
|
||||
constructor=make_constructor(credentials_schema=[make_provider_config()], oauth_schema=oauth)
|
||||
)
|
||||
)
|
||||
|
||||
types = ctrl.get_supported_credential_types()
|
||||
|
||||
assert CredentialType.OAUTH2 in types
|
||||
assert CredentialType.API_KEY in types
|
||||
|
||||
|
||||
class TestGetCredentialsSchema:
|
||||
def test_returns_empty_when_no_constructor(self):
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=None))
|
||||
assert ctrl.get_credentials_schema(CredentialType.API_KEY) == []
|
||||
|
||||
def test_returns_apikey_credentials(self):
|
||||
cfg = make_provider_config("token")
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(credentials_schema=[cfg])))
|
||||
|
||||
result = ctrl.get_credentials_schema(CredentialType.API_KEY)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "token"
|
||||
|
||||
def test_returns_oauth_credentials(self):
|
||||
oauth_cred = make_provider_config("oauth_token")
|
||||
oauth = OAuthSchema(client_schema=[], credentials_schema=[oauth_cred])
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth)))
|
||||
|
||||
result = ctrl.get_credentials_schema(CredentialType.OAUTH2)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "oauth_token"
|
||||
|
||||
def test_unauthorized_returns_empty(self):
|
||||
ctrl = make_controller(
|
||||
entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()]))
|
||||
)
|
||||
assert ctrl.get_credentials_schema(CredentialType.UNAUTHORIZED) == []
|
||||
|
||||
def test_invalid_type_raises(self):
|
||||
ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor()))
|
||||
with pytest.raises(ValueError, match="Invalid credential type"):
|
||||
ctrl.get_credentials_schema("bogus_type")
|
||||
|
||||
|
||||
class TestGetEventParameters:
|
||||
def test_returns_params_for_known_event(self):
|
||||
param = EventParameter(name="branch", label=i18n("branch"), type=EventParameterType.STRING)
|
||||
evt = make_event("push", parameters=[param])
|
||||
ctrl = make_controller(entity=make_provider_entity(events=[evt]))
|
||||
|
||||
result = ctrl.get_event_parameters("push")
|
||||
|
||||
assert "branch" in result
|
||||
assert result["branch"].name == "branch"
|
||||
|
||||
def test_returns_empty_for_unknown_event(self):
|
||||
ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")]))
|
||||
|
||||
assert ctrl.get_event_parameters("nonexistent") == {}
|
||||
|
||||
|
||||
class TestDispatch:
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_delegates_to_client(self, mock_client):
|
||||
ctrl = make_controller()
|
||||
expected = MagicMock()
|
||||
mock_client.return_value.dispatch_event.return_value = expected
|
||||
|
||||
result = ctrl.dispatch(
|
||||
request=MagicMock(),
|
||||
subscription=make_subscription(),
|
||||
credentials={"k": "v"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result is expected
|
||||
mock_client.return_value.dispatch_event.assert_called_once()
|
||||
|
||||
|
||||
class TestInvokeTriggerEvent:
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_delegates_to_client(self, mock_client):
|
||||
ctrl = make_controller()
|
||||
expected = MagicMock()
|
||||
mock_client.return_value.invoke_trigger_event.return_value = expected
|
||||
|
||||
result = ctrl.invoke_trigger_event(
|
||||
user_id="u1",
|
||||
event_name="push",
|
||||
parameters={},
|
||||
credentials={},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
subscription=make_subscription(),
|
||||
request=MagicMock(),
|
||||
payload={},
|
||||
)
|
||||
|
||||
assert result is expected
|
||||
|
||||
|
||||
class TestSubscribeTrigger:
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_returns_validated_subscription(self, mock_client):
|
||||
ctrl = make_controller()
|
||||
mock_client.return_value.subscribe.return_value.subscription = {
|
||||
"expires_at": 123,
|
||||
"endpoint": "https://e",
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
result = ctrl.subscribe_trigger(
|
||||
user_id="u1",
|
||||
endpoint="https://e",
|
||||
parameters={},
|
||||
credentials={},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result.endpoint == "https://e"
|
||||
|
||||
|
||||
class TestUnsubscribeTrigger:
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_returns_validated_result(self, mock_client):
|
||||
ctrl = make_controller()
|
||||
mock_client.return_value.unsubscribe.return_value.subscription = {"success": True, "message": "ok"}
|
||||
|
||||
result = ctrl.unsubscribe_trigger(
|
||||
user_id="u1",
|
||||
subscription=make_subscription(),
|
||||
credentials={},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
|
||||
class TestRefreshTrigger:
|
||||
@patch("core.trigger.provider.PluginTriggerClient")
|
||||
def test_uses_system_user_id(self, mock_client):
|
||||
ctrl = make_controller()
|
||||
mock_client.return_value.refresh.return_value.subscription = {
|
||||
"expires_at": 456,
|
||||
"endpoint": "https://e",
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
ctrl.refresh_trigger(subscription=make_subscription(), credentials={}, credential_type=CredentialType.API_KEY)
|
||||
|
||||
call_kwargs = mock_client.return_value.refresh.call_args[1]
|
||||
assert call_kwargs["user_id"] == "system"
|
||||
307
api/tests/unit_tests/core/trigger/test_trigger_manager.py
Normal file
307
api/tests/unit_tests/core/trigger/test_trigger_manager.py
Normal file
@ -0,0 +1,307 @@
|
||||
"""
|
||||
Tests for core.trigger.trigger_manager.TriggerManager.
|
||||
|
||||
Covers: icon URL construction, provider listing with error resilience,
|
||||
double-check lock caching, error translation, EventIgnoreError -> cancelled,
|
||||
and delegation to provider controller.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from threading import Lock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.request import TriggerInvokeEventResponse
|
||||
from core.plugin.impl.exc import PluginDaemonError, PluginNotFoundError
|
||||
from core.trigger.errors import EventIgnoreError
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from tests.unit_tests.core.trigger.conftest import (
|
||||
VALID_PROVIDER_ID,
|
||||
make_controller,
|
||||
make_provider_entity,
|
||||
make_subscription,
|
||||
)
|
||||
|
||||
PID = TriggerProviderID(VALID_PROVIDER_ID)
|
||||
PID_STR = str(PID)
|
||||
|
||||
|
||||
class TestGetTriggerPluginIcon:
|
||||
@patch("core.trigger.trigger_manager.dify_config")
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
def test_builds_correct_url(self, mock_client, mock_config):
|
||||
mock_config.CONSOLE_API_URL = "https://console.example.com"
|
||||
provider = MagicMock()
|
||||
provider.declaration.identity.icon = "my-icon.svg"
|
||||
mock_client.return_value.fetch_trigger_provider.return_value = provider
|
||||
|
||||
url = TriggerManager.get_trigger_plugin_icon("tenant-1", VALID_PROVIDER_ID)
|
||||
|
||||
assert "tenant_id=tenant-1" in url
|
||||
assert "filename=my-icon.svg" in url
|
||||
assert url.startswith("https://console.example.com/console/api/workspaces/current/plugin/icon")
|
||||
|
||||
|
||||
class TestListPluginTriggerProviders:
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
def test_wraps_entities_into_controllers(self, mock_client):
|
||||
entity = MagicMock()
|
||||
entity.declaration = make_provider_entity("p1")
|
||||
entity.plugin_id = "plugin-1"
|
||||
entity.plugin_unique_identifier = "uid-1"
|
||||
entity.provider = VALID_PROVIDER_ID
|
||||
mock_client.return_value.fetch_trigger_providers.return_value = [entity]
|
||||
|
||||
controllers = TriggerManager.list_plugin_trigger_providers("tenant-1")
|
||||
|
||||
assert len(controllers) == 1
|
||||
assert controllers[0].plugin_id == "plugin-1"
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
def test_skips_failing_providers(self, mock_client):
|
||||
good = MagicMock()
|
||||
good.declaration = make_provider_entity("good")
|
||||
good.plugin_id = "good-plugin"
|
||||
good.plugin_unique_identifier = "uid-good"
|
||||
good.provider = VALID_PROVIDER_ID
|
||||
|
||||
bad = MagicMock()
|
||||
bad.declaration = make_provider_entity("bad")
|
||||
bad.plugin_id = "bad-plugin"
|
||||
bad.plugin_unique_identifier = "uid-bad"
|
||||
bad.provider = "bad/format" # 2-part: fails TriggerProviderID validation
|
||||
|
||||
mock_client.return_value.fetch_trigger_providers.return_value = [bad, good]
|
||||
|
||||
controllers = TriggerManager.list_plugin_trigger_providers("tenant-1")
|
||||
|
||||
assert len(controllers) == 1
|
||||
assert controllers[0].plugin_id == "good-plugin"
|
||||
|
||||
|
||||
class TestGetTriggerProvider:
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_initializes_context_on_first_call(self, mock_ctx, mock_client):
|
||||
# get() called 3 times: (1) try block, (2) after set, (3) under lock
|
||||
mock_ctx.plugin_trigger_providers.get.side_effect = [LookupError, {}, {}]
|
||||
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
|
||||
provider = MagicMock()
|
||||
provider.declaration = make_provider_entity()
|
||||
provider.plugin_id = "p1"
|
||||
provider.plugin_unique_identifier = "uid-1"
|
||||
mock_client.return_value.fetch_trigger_provider.return_value = provider
|
||||
|
||||
result = TriggerManager.get_trigger_provider("t1", PID)
|
||||
|
||||
mock_ctx.plugin_trigger_providers.set.assert_called_once_with({})
|
||||
mock_ctx.plugin_trigger_providers_lock.set.assert_called_once()
|
||||
assert result is not None
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_returns_cached_without_fetch(self, mock_ctx, mock_client):
|
||||
cached = make_controller()
|
||||
mock_ctx.plugin_trigger_providers.get.return_value = {PID_STR: cached}
|
||||
|
||||
result = TriggerManager.get_trigger_provider("t1", PID)
|
||||
|
||||
assert result is cached
|
||||
mock_client.return_value.fetch_trigger_provider.assert_not_called()
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_double_check_lock_uses_cached_from_other_thread(self, mock_ctx, mock_client):
|
||||
cached = make_controller()
|
||||
mock_ctx.plugin_trigger_providers.get.side_effect = [
|
||||
{}, # first check misses
|
||||
{PID_STR: cached}, # under-lock check hits
|
||||
]
|
||||
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
|
||||
|
||||
result = TriggerManager.get_trigger_provider("t1", PID)
|
||||
|
||||
assert result is cached
|
||||
mock_client.return_value.fetch_trigger_provider.assert_not_called()
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_fetches_and_caches_on_miss(self, mock_ctx, mock_client):
|
||||
cache: dict = {}
|
||||
mock_ctx.plugin_trigger_providers.get.return_value = cache
|
||||
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
|
||||
provider = MagicMock()
|
||||
provider.declaration = make_provider_entity()
|
||||
provider.plugin_id = "p1"
|
||||
provider.plugin_unique_identifier = "uid-1"
|
||||
mock_client.return_value.fetch_trigger_provider.return_value = provider
|
||||
|
||||
result = TriggerManager.get_trigger_provider("t1", PID)
|
||||
|
||||
assert result is not None
|
||||
assert PID_STR in cache
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_none_fetch_raises_value_error(self, mock_ctx, mock_client):
|
||||
mock_ctx.plugin_trigger_providers.get.return_value = {}
|
||||
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
|
||||
mock_client.return_value.fetch_trigger_provider.return_value = None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/missing"))
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_plugin_not_found_becomes_value_error(self, mock_ctx, mock_client):
|
||||
mock_ctx.plugin_trigger_providers.get.return_value = {}
|
||||
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
|
||||
mock_client.return_value.fetch_trigger_provider.side_effect = PluginNotFoundError("gone")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss"))
|
||||
|
||||
@patch("core.trigger.trigger_manager.PluginTriggerClient")
|
||||
@patch("core.trigger.trigger_manager.contexts")
|
||||
def test_plugin_daemon_error_propagates(self, mock_ctx, mock_client):
|
||||
mock_ctx.plugin_trigger_providers.get.return_value = {}
|
||||
mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock()
|
||||
mock_client.return_value.fetch_trigger_provider.side_effect = PluginDaemonError("test error")
|
||||
|
||||
with pytest.raises(PluginDaemonError):
|
||||
TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss"))
|
||||
|
||||
|
||||
class TestListAllTriggerProviders:
|
||||
@patch.object(TriggerManager, "list_plugin_trigger_providers")
|
||||
def test_delegates_to_list_plugin(self, mock_list):
|
||||
expected = [make_controller()]
|
||||
mock_list.return_value = expected
|
||||
|
||||
assert TriggerManager.list_all_trigger_providers("t1") is expected
|
||||
mock_list.assert_called_once_with("t1")
|
||||
|
||||
|
||||
class TestListTriggersByProvider:
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_returns_provider_events(self, mock_get):
|
||||
ctrl = make_controller()
|
||||
mock_get.return_value = ctrl
|
||||
|
||||
result = TriggerManager.list_triggers_by_provider("t1", PID)
|
||||
|
||||
assert result == ctrl.get_events()
|
||||
|
||||
|
||||
class TestInvokeTriggerEvent:
|
||||
def _args(self):
|
||||
return {
|
||||
"tenant_id": "t1",
|
||||
"user_id": "u1",
|
||||
"provider_id": PID,
|
||||
"event_name": "on_push",
|
||||
"parameters": {"branch": "main"},
|
||||
"credentials": {"token": "abc"},
|
||||
"credential_type": CredentialType.API_KEY,
|
||||
"subscription": make_subscription(),
|
||||
"request": MagicMock(),
|
||||
"payload": {"action": "push"},
|
||||
}
|
||||
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_returns_invoke_response(self, mock_get):
|
||||
ctrl = MagicMock()
|
||||
expected = TriggerInvokeEventResponse(variables={"v": "1"}, cancelled=False)
|
||||
ctrl.invoke_trigger_event.return_value = expected
|
||||
mock_get.return_value = ctrl
|
||||
|
||||
result = TriggerManager.invoke_trigger_event(**self._args())
|
||||
|
||||
assert result is expected
|
||||
assert result.cancelled is False
|
||||
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_event_ignore_returns_cancelled(self, mock_get):
|
||||
ctrl = MagicMock()
|
||||
ctrl.invoke_trigger_event.side_effect = EventIgnoreError("skip")
|
||||
mock_get.return_value = ctrl
|
||||
|
||||
result = TriggerManager.invoke_trigger_event(**self._args())
|
||||
|
||||
assert result.cancelled is True
|
||||
assert result.variables == {}
|
||||
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_other_errors_propagate(self, mock_get):
|
||||
ctrl = MagicMock()
|
||||
ctrl.invoke_trigger_event.side_effect = RuntimeError("boom")
|
||||
mock_get.return_value = ctrl
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
TriggerManager.invoke_trigger_event(**self._args())
|
||||
|
||||
|
||||
class TestSubscribeTrigger:
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_delegates_with_correct_args(self, mock_get):
|
||||
ctrl = MagicMock()
|
||||
expected = make_subscription()
|
||||
ctrl.subscribe_trigger.return_value = expected
|
||||
mock_get.return_value = ctrl
|
||||
|
||||
result = TriggerManager.subscribe_trigger(
|
||||
tenant_id="t1",
|
||||
user_id="u1",
|
||||
provider_id=PID,
|
||||
endpoint="https://hook.test",
|
||||
parameters={"f": "all"},
|
||||
credentials={"token": "x"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result is expected
|
||||
ctrl.subscribe_trigger.assert_called_once()
|
||||
|
||||
|
||||
class TestUnsubscribeTrigger:
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_delegates_with_correct_args(self, mock_get):
|
||||
ctrl = MagicMock()
|
||||
expected = MagicMock()
|
||||
ctrl.unsubscribe_trigger.return_value = expected
|
||||
mock_get.return_value = ctrl
|
||||
sub = make_subscription()
|
||||
|
||||
result = TriggerManager.unsubscribe_trigger(
|
||||
tenant_id="t1",
|
||||
user_id="u1",
|
||||
provider_id=PID,
|
||||
subscription=sub,
|
||||
credentials={"token": "x"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result is expected
|
||||
|
||||
|
||||
class TestRefreshTrigger:
|
||||
@patch.object(TriggerManager, "get_trigger_provider")
|
||||
def test_delegates_with_correct_args(self, mock_get):
|
||||
ctrl = MagicMock()
|
||||
expected = make_subscription()
|
||||
ctrl.refresh_trigger.return_value = expected
|
||||
mock_get.return_value = ctrl
|
||||
|
||||
result = TriggerManager.refresh_trigger(
|
||||
tenant_id="t1",
|
||||
provider_id=PID,
|
||||
subscription=make_subscription(),
|
||||
credentials={"token": "x"},
|
||||
credential_type=CredentialType.API_KEY,
|
||||
)
|
||||
|
||||
assert result is expected
|
||||
0
api/tests/unit_tests/core/trigger/utils/__init__.py
Normal file
0
api/tests/unit_tests/core/trigger/utils/__init__.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""Tests for core.trigger.utils.encryption — masking logic and cache key generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.trigger.utils.encryption import (
|
||||
TriggerProviderCredentialsCache,
|
||||
TriggerProviderOAuthClientParamsCache,
|
||||
TriggerProviderPropertiesCache,
|
||||
masked_credentials,
|
||||
)
|
||||
|
||||
|
||||
def _make_schema(name: str, field_type: str = "secret-input") -> ProviderConfig:
|
||||
return ProviderConfig(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
type=field_type,
|
||||
)
|
||||
|
||||
|
||||
class TestMaskedCredentials:
|
||||
def test_short_secret_fully_masked(self):
|
||||
schema = [_make_schema("key", "secret-input")]
|
||||
result = masked_credentials(schema, {"key": "ab"})
|
||||
assert result["key"] == "**"
|
||||
|
||||
def test_long_secret_partially_masked(self):
|
||||
schema = [_make_schema("key", "secret-input")]
|
||||
result = masked_credentials(schema, {"key": "abcdef"})
|
||||
assert result["key"].startswith("ab")
|
||||
assert result["key"].endswith("ef")
|
||||
assert "**" in result["key"]
|
||||
|
||||
def test_non_secret_field_unchanged(self):
|
||||
schema = [_make_schema("host", "text-input")]
|
||||
result = masked_credentials(schema, {"host": "example.com"})
|
||||
assert result["host"] == "example.com"
|
||||
|
||||
def test_unknown_key_passes_through(self):
|
||||
result = masked_credentials([], {"unknown": "value"})
|
||||
assert result["unknown"] == "value"
|
||||
|
||||
|
||||
class TestCacheKeyGeneration:
|
||||
def test_credentials_cache_key_contains_ids(self):
|
||||
cache = TriggerProviderCredentialsCache(tenant_id="t1", provider_id="p1", credential_id="c1")
|
||||
assert "t1" in cache.cache_key
|
||||
assert "p1" in cache.cache_key
|
||||
assert "c1" in cache.cache_key
|
||||
|
||||
def test_oauth_client_cache_key_contains_ids(self):
|
||||
cache = TriggerProviderOAuthClientParamsCache(tenant_id="t1", provider_id="p1")
|
||||
assert "t1" in cache.cache_key
|
||||
assert "p1" in cache.cache_key
|
||||
|
||||
def test_properties_cache_key_contains_ids(self):
|
||||
cache = TriggerProviderPropertiesCache(tenant_id="t1", provider_id="p1", subscription_id="s1")
|
||||
assert "t1" in cache.cache_key
|
||||
assert "p1" in cache.cache_key
|
||||
assert "s1" in cache.cache_key
|
||||
@ -0,0 +1,31 @@
|
||||
"""Tests for core.trigger.utils.endpoint — URL generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.trigger.utils import endpoint
|
||||
|
||||
|
||||
class TestGeneratePluginTriggerEndpointUrl:
|
||||
def test_builds_correct_url(self):
|
||||
with patch.object(endpoint, "base_url", URL("https://api.example.com")):
|
||||
url = endpoint.generate_plugin_trigger_endpoint_url("endpoint-123")
|
||||
|
||||
assert url == "https://api.example.com/triggers/plugin/endpoint-123"
|
||||
|
||||
|
||||
class TestGenerateWebhookTriggerEndpoint:
|
||||
def test_non_debug_url(self):
|
||||
with patch.object(endpoint, "base_url", URL("https://api.example.com")):
|
||||
url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=False)
|
||||
|
||||
assert url == "https://api.example.com/triggers/webhook/sub-456"
|
||||
|
||||
def test_debug_url(self):
|
||||
with patch.object(endpoint, "base_url", URL("https://api.example.com")):
|
||||
url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=True)
|
||||
|
||||
assert url == "https://api.example.com/triggers/webhook-debug/sub-456"
|
||||
23
api/tests/unit_tests/core/trigger/utils/test_utils_locks.py
Normal file
23
api/tests/unit_tests/core/trigger/utils/test_utils_locks.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""Tests for core.trigger.utils.locks — Redis lock key builders."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from core.trigger.utils.locks import build_trigger_refresh_lock_key, build_trigger_refresh_lock_keys
|
||||
|
||||
|
||||
class TestBuildTriggerRefreshLockKey:
|
||||
def test_correct_format(self):
|
||||
key = build_trigger_refresh_lock_key("tenant-1", "sub-1")
|
||||
|
||||
assert key == "trigger_provider_refresh_lock:tenant-1_sub-1"
|
||||
|
||||
|
||||
class TestBuildTriggerRefreshLockKeys:
|
||||
def test_maps_over_pairs(self):
|
||||
pairs = [("t1", "s1"), ("t2", "s2")]
|
||||
|
||||
keys = build_trigger_refresh_lock_keys(pairs)
|
||||
|
||||
assert len(keys) == 2
|
||||
assert keys[0] == "trigger_provider_refresh_lock:t1_s1"
|
||||
assert keys[1] == "trigger_provider_refresh_lock:t2_s2"
|
||||
@ -22,7 +22,7 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from dify_graph.nodes.llm import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
|
||||
from dify_graph.nodes.question_classifier import QuestionClassifierNode
|
||||
from dify_graph.nodes.template_transform import TemplateTransformNode
|
||||
from dify_graph.nodes.template_transform.template_renderer import (
|
||||
@ -73,6 +73,12 @@ class MockNodeMixin:
|
||||
if isinstance(self, TemplateTransformNode):
|
||||
kwargs.setdefault("template_renderer", _TestJinja2Renderer())
|
||||
|
||||
# Provide default tool_file_manager_factory for ToolNode subclasses
|
||||
from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles
|
||||
|
||||
if isinstance(self, _ToolNode):
|
||||
kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol))
|
||||
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
|
||||
@ -31,6 +31,7 @@ def tool_node(monkeypatch) -> ToolNode:
|
||||
ops_stub.TraceTask = object # pragma: no cover - stub attribute
|
||||
monkeypatch.setitem(sys.modules, module_name, ops_stub)
|
||||
|
||||
from dify_graph.nodes.protocols import ToolFileManagerProtocol
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
|
||||
graph_config: dict[str, Any] = {
|
||||
@ -69,11 +70,16 @@ def tool_node(monkeypatch) -> ToolNode:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
|
||||
config = graph_config["nodes"][0]
|
||||
|
||||
# Provide a stub ToolFileManager to satisfy the updated ToolNode constructor
|
||||
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
|
||||
|
||||
node = ToolNode(
|
||||
id="node-instance",
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tool_file_manager_factory=tool_file_manager_factory,
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
Reference in New Issue
Block a user