mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 00:18:03 +08:00
Merge branch 'main' into fix/ee-1511-set-default-provider-user-id-bug
This commit is contained in:
@ -1,15 +1,16 @@
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import Parameters as ParametersResponse
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.console import api
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
@ -117,7 +118,56 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel
|
||||
workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
|
||||
|
||||
|
||||
# Pydantic models for request validation
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowRunRequest(BaseModel):
|
||||
inputs: dict
|
||||
files: list | None = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
inputs: dict
|
||||
query: str
|
||||
files: list | None = None
|
||||
conversation_id: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
class TextToSpeechRequest(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
inputs: dict
|
||||
query: str = ""
|
||||
files: list | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
# Register schemas for Swagger documentation
|
||||
console_ns.schema_model(
|
||||
WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
Run workflow
|
||||
@ -129,10 +179,8 @@ class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = WorkflowRunRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
assert current_user is not None
|
||||
try:
|
||||
app_id = app_model.id
|
||||
@ -183,6 +231,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
@ -190,14 +239,14 @@ class TrialChatApi(TrialAppResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = ChatRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
|
||||
# Validate UUID values if provided
|
||||
if args.get("conversation_id"):
|
||||
args["conversation_id"] = uuid_value(args["conversation_id"])
|
||||
if args.get("parent_message_id"):
|
||||
args["parent_message_id"] = uuid_value(args["parent_message_id"])
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@ -320,20 +369,16 @@ class TrialChatAudioApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = TextToSpeechRequest.model_validate(console_ns.payload)
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
message_id = request_data.message_id
|
||||
text = request_data.text
|
||||
voice = request_data.voice
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@ -371,19 +416,15 @@ class TrialChatTextApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
request_data = CompletionRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@ -1,58 +1,60 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.fastopenapi import console_router
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
|
||||
class FeatureResponse(BaseModel):
|
||||
features: FeatureModel = Field(description="Feature configuration object")
|
||||
@console_ns.route("/features")
|
||||
class FeatureApi(Resource):
|
||||
@console_ns.doc("get_tenant_features")
|
||||
@console_ns.doc(description="Get feature configuration for current tenant")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get feature configuration for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
class SystemFeatureResponse(BaseModel):
|
||||
features: SystemFeatureModel = Field(description="System feature configuration object")
|
||||
@console_ns.route("/system-features")
|
||||
class SystemFeatureApi(Resource):
|
||||
@console_ns.doc("get_system_features")
|
||||
@console_ns.doc(description="Get system-wide feature configuration")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}
|
||||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Get system-wide feature configuration
|
||||
|
||||
NOTE: This endpoint is unauthenticated by design, as it provides system features
|
||||
data required for dashboard initialization.
|
||||
|
||||
@console_router.get(
|
||||
"/features",
|
||||
response_model=FeatureResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
def get_tenant_features() -> FeatureResponse:
|
||||
"""Get feature configuration for current tenant."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
Authentication would create circular dependency (can't login without dashboard loading).
|
||||
|
||||
return FeatureResponse(features=FeatureService.get_features(current_tenant_id))
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/system-features",
|
||||
response_model=SystemFeatureResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
def get_system_features() -> SystemFeatureResponse:
|
||||
"""Get system-wide feature configuration
|
||||
|
||||
NOTE: This endpoint is unauthenticated by design, as it provides system features
|
||||
data required for dashboard initialization.
|
||||
|
||||
Authentication would create circular dependency (can't login without dashboard loading).
|
||||
|
||||
Only non-sensitive configuration data should be returned by this endpoint.
|
||||
"""
|
||||
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
|
||||
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
|
||||
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
|
||||
# raise `Unauthorized` exception if authentication token is not provided.
|
||||
try:
|
||||
is_authenticated = current_user.is_authenticated
|
||||
except Unauthorized:
|
||||
is_authenticated = False
|
||||
return SystemFeatureResponse(features=FeatureService.get_system_features(is_authenticated=is_authenticated))
|
||||
Only non-sensitive configuration data should be returned by this endpoint.
|
||||
"""
|
||||
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
|
||||
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
|
||||
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
|
||||
# raise `Unauthorized` exception if authentication token is not provided.
|
||||
try:
|
||||
is_authenticated = current_user.is_authenticated
|
||||
except Unauthorized:
|
||||
is_authenticated = False
|
||||
return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()
|
||||
|
||||
@ -1,14 +1,27 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.fastopenapi import console_router
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.tag_service import TagService
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
"binding_count": fields.String,
|
||||
}
|
||||
|
||||
|
||||
def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
return api_or_ns.model("DataSetTag", dataset_tag_fields)
|
||||
|
||||
|
||||
class TagBasePayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
@ -32,129 +45,115 @@ class TagListQueryParam(BaseModel):
|
||||
keyword: str | None = Field(None, description="Search keyword")
|
||||
|
||||
|
||||
class TagResponse(BaseModel):
|
||||
id: str = Field(description="Tag ID")
|
||||
name: str = Field(description="Tag name")
|
||||
type: str = Field(description="Tag type")
|
||||
binding_count: int = Field(description="Number of bindings")
|
||||
|
||||
|
||||
class TagBindingResult(BaseModel):
|
||||
result: Literal["success"] = Field(description="Operation result", examples=["success"])
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/tags",
|
||||
response_model=list[TagResponse],
|
||||
tags=["console"],
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
TagBasePayload,
|
||||
TagBindingPayload,
|
||||
TagBindingRemovePayload,
|
||||
TagListQueryParam,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def list_tags(query: TagListQueryParam) -> list[TagResponse]:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tags = TagService.get_tags(query.type, current_tenant_id, query.keyword)
|
||||
|
||||
return [
|
||||
TagResponse(
|
||||
id=tag.id,
|
||||
name=tag.name,
|
||||
type=tag.type,
|
||||
binding_count=int(tag.binding_count),
|
||||
)
|
||||
for tag in tags
|
||||
]
|
||||
|
||||
|
||||
@console_router.post(
|
||||
"/tags",
|
||||
response_model=TagResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def create_tag(payload: TagBasePayload) -> TagResponse:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the tag table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
@console_ns.route("/tags")
|
||||
class TagListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.doc(
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
)
|
||||
@marshal_with(dataset_tag_fields)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
raw_args = request.args.to_dict()
|
||||
param = TagListQueryParam.model_validate(raw_args)
|
||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||
|
||||
tag = TagService.save_tags(payload.model_dump())
|
||||
return tags, 200
|
||||
|
||||
return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=0)
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.save_tags(payload.model_dump())
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
@console_router.patch(
|
||||
"/tags/<uuid:tag_id>",
|
||||
response_model=TagResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def update_tag(tag_id: UUID, payload: TagBasePayload) -> TagResponse:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tag_id_str = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
tag = TagService.update_tags(payload.model_dump(), tag_id_str)
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(payload.model_dump(), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id_str)
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=binding_count)
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
return response, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, tag_id):
|
||||
tag_id = str(tag_id)
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
@console_router.delete(
|
||||
"/tags/<uuid:tag_id>",
|
||||
tags=["console"],
|
||||
status_code=204,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete_tag(tag_id: UUID) -> None:
|
||||
tag_id_str = str(tag_id)
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag(tag_id_str)
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(payload.model_dump())
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_router.post(
|
||||
"/tag-bindings/create",
|
||||
response_model=TagBindingResult,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def create_tag_binding(payload: TagBindingPayload) -> TagBindingResult:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
TagService.save_tag_binding(payload.model_dump())
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(payload.model_dump())
|
||||
|
||||
return TagBindingResult(result="success")
|
||||
|
||||
|
||||
@console_router.post(
|
||||
"/tag-bindings/remove",
|
||||
response_model=TagBindingResult,
|
||||
tags=["console"],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete_tag_binding(payload: TagBindingRemovePayload) -> TagBindingResult:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the tag table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag_binding(payload.model_dump())
|
||||
|
||||
return TagBindingResult(result="success")
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.11.4"
|
||||
version = "1.12.1"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
|
||||
@ -24,7 +24,7 @@ class TagService:
|
||||
escaped_keyword = escape_like_pattern(keyword)
|
||||
query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
|
||||
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
|
||||
results = query.order_by(Tag.created_at.desc()).all()
|
||||
results: list = query.order_by(Tag.created_at.desc()).all()
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
|
||||
|
||||
|
||||
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
|
||||
def del_workflow_archive_log(workflow_archive_log_id: str):
|
||||
db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
|
||||
def del_workflow_archive_log(session, workflow_archive_log_id: str):
|
||||
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
|
||||
total_files_deleted = 0
|
||||
|
||||
while True:
|
||||
with session_factory.create_session() as session:
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
# Get a batch of draft variable IDs along with their file_ids
|
||||
query_sql = """
|
||||
SELECT id, file_id FROM workflow_draft_variables
|
||||
|
||||
@ -10,7 +10,10 @@ from models import Tenant
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, UploadFile
|
||||
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
|
||||
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
|
||||
from tasks.remove_app_and_related_data_task import (
|
||||
_delete_draft_variables,
|
||||
delete_draft_variables_batch,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
upload_file_ids = [uf.id for uf in data["upload_files"]]
|
||||
variable_file_ids = [vf.id for vf in data["variable_files"]]
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_before = session.query(UploadFile).count()
|
||||
var_files_before = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
|
||||
.count()
|
||||
)
|
||||
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert draft_vars_before == 3
|
||||
assert var_files_before == 2
|
||||
assert upload_files_before == 2
|
||||
@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
assert draft_vars_after == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
var_files_after = session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = session.query(UploadFile).count()
|
||||
var_files_after = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
|
||||
.count()
|
||||
)
|
||||
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
upload_file_ids = [uf.id for uf in data["upload_files"]]
|
||||
variable_file_ids = [vf.id for vf in data["variable_files"]]
|
||||
mock_storage.delete.side_effect = [Exception("Storage error"), None]
|
||||
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
assert draft_vars_after == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
var_files_after = session.query(WorkflowDraftVariableFile).count()
|
||||
upload_files_after = session.query(UploadFile).count()
|
||||
var_files_after = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
|
||||
.count()
|
||||
)
|
||||
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
if app2_obj:
|
||||
session.delete(app2_obj)
|
||||
session.commit()
|
||||
|
||||
|
||||
class TestDeleteDraftVariablesSessionCommit:
|
||||
"""Test suite to verify session commit behavior in delete_draft_variables_batch."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_offload_test_data(self, app_and_tenant):
|
||||
"""Create test data with offload files for session commit tests."""
|
||||
from core.variables.types import SegmentType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
tenant, app = app_and_tenant
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
upload_file1 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file1.json",
|
||||
name="file1.json",
|
||||
size=1024,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
upload_file2 = UploadFile(
|
||||
tenant_id=tenant.id,
|
||||
storage_type="local",
|
||||
key="test/file2.json",
|
||||
name="file2.json",
|
||||
size=2048,
|
||||
extension="json",
|
||||
mime_type="application/json",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid.uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
used=False,
|
||||
)
|
||||
session.add(upload_file1)
|
||||
session.add(upload_file2)
|
||||
session.flush()
|
||||
|
||||
var_file1 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file1.id,
|
||||
size=1024,
|
||||
length=10,
|
||||
value_type=SegmentType.STRING,
|
||||
)
|
||||
var_file2 = WorkflowDraftVariableFile(
|
||||
tenant_id=tenant.id,
|
||||
app_id=app.id,
|
||||
user_id=str(uuid.uuid4()),
|
||||
upload_file_id=upload_file2.id,
|
||||
size=2048,
|
||||
length=20,
|
||||
value_type=SegmentType.OBJECT,
|
||||
)
|
||||
session.add(var_file1)
|
||||
session.add(var_file2)
|
||||
session.flush()
|
||||
|
||||
draft_var1 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_1",
|
||||
name="large_var_1",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file1.id,
|
||||
)
|
||||
draft_var2 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_2",
|
||||
name="large_var_2",
|
||||
value=StringSegment(value="truncated..."),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
file_id=var_file2.id,
|
||||
)
|
||||
draft_var3 = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id="node_3",
|
||||
name="regular_var",
|
||||
value=StringSegment(value="regular_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(draft_var1)
|
||||
session.add(draft_var2)
|
||||
session.add(draft_var3)
|
||||
session.commit()
|
||||
|
||||
data = {
|
||||
"app": app,
|
||||
"tenant": tenant,
|
||||
"upload_files": [upload_file1, upload_file2],
|
||||
"variable_files": [var_file1, var_file2],
|
||||
"draft_variables": [draft_var1, draft_var2, draft_var3],
|
||||
}
|
||||
|
||||
yield data
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
for table, ids in [
|
||||
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
|
||||
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
|
||||
(UploadFile, [uf.id for uf in data["upload_files"]]),
|
||||
]:
|
||||
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
|
||||
session.execute(cleanup_query)
|
||||
session.commit()
|
||||
|
||||
@pytest.fixture
|
||||
def setup_commit_test_data(self, app_and_tenant):
|
||||
"""Create test data for session commit tests."""
|
||||
tenant, app = app_and_tenant
|
||||
variable_ids: list[str] = []
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
variables = []
|
||||
for i in range(10):
|
||||
var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=app.id,
|
||||
node_id=f"node_{i}",
|
||||
name=f"var_{i}",
|
||||
value=StringSegment(value="test_value"),
|
||||
node_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(var)
|
||||
variables.append(var)
|
||||
session.commit()
|
||||
variable_ids = [v.id for v in variables]
|
||||
|
||||
yield {
|
||||
"app": app,
|
||||
"tenant": tenant,
|
||||
"variable_ids": variable_ids,
|
||||
}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
cleanup_query = (
|
||||
delete(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.id.in_(variable_ids))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.execute(cleanup_query)
|
||||
session.commit()
|
||||
|
||||
def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data):
|
||||
"""Test that session.begin() is used for automatic transaction management."""
|
||||
data = setup_commit_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
# Since session.begin() is used, the transaction is automatically committed
|
||||
# when the with block exits successfully. We verify this by checking that
|
||||
# data is actually persisted.
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
|
||||
|
||||
# Verify all data was deleted (proves transaction was committed)
|
||||
with session_factory.create_session() as session:
|
||||
remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
|
||||
assert deleted_count == 10
|
||||
assert remaining_count == 0
|
||||
|
||||
def test_data_persisted_after_batch_deletion(self, setup_commit_test_data):
|
||||
"""Test that data is actually persisted to database after batch deletion with commits."""
|
||||
data = setup_commit_test_data
|
||||
app_id = data["app"].id
|
||||
variable_ids = data["variable_ids"]
|
||||
|
||||
# Verify initial state
|
||||
with session_factory.create_session() as session:
|
||||
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert initial_count == 10
|
||||
|
||||
# Perform deletion with small batch size to force multiple commits
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
|
||||
|
||||
assert deleted_count == 10
|
||||
|
||||
# Verify all data is deleted in a new session (proves commits worked)
|
||||
with session_factory.create_session() as session:
|
||||
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert final_count == 0
|
||||
|
||||
# Verify specific IDs are deleted
|
||||
with session_factory.create_session() as session:
|
||||
remaining_vars = (
|
||||
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count()
|
||||
)
|
||||
assert remaining_vars == 0
|
||||
|
||||
def test_session_commit_with_empty_dataset(self, setup_commit_test_data):
|
||||
"""Test session behavior when deleting from an empty dataset."""
|
||||
nonexistent_app_id = str(uuid.uuid4())
|
||||
|
||||
# Should not raise any errors and should return 0
|
||||
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10)
|
||||
assert deleted_count == 0
|
||||
|
||||
def test_session_commit_with_single_batch(self, setup_commit_test_data):
|
||||
"""Test that commit happens correctly when all data fits in a single batch."""
|
||||
data = setup_commit_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert initial_count == 10
|
||||
|
||||
# Delete all in a single batch
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=100)
|
||||
assert deleted_count == 10
|
||||
|
||||
# Verify data is persisted
|
||||
with session_factory.create_session() as session:
|
||||
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert final_count == 0
|
||||
|
||||
def test_invalid_batch_size_raises_error(self, setup_commit_test_data):
|
||||
"""Test that invalid batch size raises ValueError."""
|
||||
data = setup_commit_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
with pytest.raises(ValueError, match="batch_size must be positive"):
|
||||
delete_draft_variables_batch(app_id, batch_size=0)
|
||||
|
||||
with pytest.raises(ValueError, match="batch_size must be positive"):
|
||||
delete_draft_variables_batch(app_id, batch_size=-1)
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data):
|
||||
"""Test that session commits correctly when cleaning up offload data."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
upload_file_ids = [uf.id for uf in data["upload_files"]]
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
# Verify initial state
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
|
||||
.count()
|
||||
)
|
||||
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert draft_vars_before == 3
|
||||
assert var_files_before == 2
|
||||
assert upload_files_before == 2
|
||||
|
||||
# Delete variables with offload data
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
assert deleted_count == 3
|
||||
|
||||
# Verify all data is persisted (deleted) in new session
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_after = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
|
||||
.count()
|
||||
)
|
||||
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert draft_vars_after == 0
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
# Verify storage cleanup was called
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
@ -1,291 +0,0 @@
|
||||
import builtins
|
||||
import contextlib
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
from extensions.ext_database import db
|
||||
from services.feature_service import FeatureModel, SystemFeatureModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""
|
||||
Creates a Flask application instance configured for testing.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SECRET_KEY"] = "test-secret"
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
|
||||
# Initialize the database with the app
|
||||
db.init_app(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fix_method_view_issue(monkeypatch):
|
||||
"""
|
||||
Automatic fixture to patch 'builtins.MethodView'.
|
||||
|
||||
Why this is needed:
|
||||
The official legacy codebase contains a global patch in its initialization logic:
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView
|
||||
|
||||
Some dependencies (like ext_fastopenapi or older Flask extensions) might implicitly
|
||||
rely on 'MethodView' being available in the global builtins namespace.
|
||||
|
||||
Refactoring Note:
|
||||
While patching builtins is generally discouraged due to global side effects,
|
||||
this fixture reproduces the production environment's state to ensure tests are realistic.
|
||||
We use 'monkeypatch' to ensure that this change is undone after the test finishes,
|
||||
keeping other tests isolated.
|
||||
"""
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
# 'raising=False' allows us to set an attribute that doesn't exist yet
|
||||
monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Helper Functions for Fixture Complexity Reduction
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _create_isolated_router():
|
||||
"""
|
||||
Creates a fresh, isolated router instance to prevent route pollution.
|
||||
"""
|
||||
import controllers.fastopenapi
|
||||
|
||||
# Dynamically get the class type (e.g., FlaskRouter) to avoid hardcoding dependencies
|
||||
RouterClass = type(controllers.fastopenapi.console_router)
|
||||
return RouterClass()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _patch_auth_and_router(temp_router):
|
||||
"""
|
||||
Context manager that applies all necessary patches for:
|
||||
1. The console_router (redirecting to our isolated temp_router)
|
||||
2. Authentication decorators (disabling them with no-ops)
|
||||
3. User/Account loaders (mocking authenticated state)
|
||||
"""
|
||||
|
||||
def noop(f):
|
||||
return f
|
||||
|
||||
# We patch the SOURCE of the decorators/functions, not the destination module.
|
||||
# This ensures that when 'controllers.console.feature' imports them, it gets the mocks.
|
||||
with (
|
||||
patch("controllers.fastopenapi.console_router", temp_router),
|
||||
patch("extensions.ext_fastopenapi.console_router", temp_router),
|
||||
patch("controllers.console.wraps.setup_required", side_effect=noop),
|
||||
patch("libs.login.login_required", side_effect=noop),
|
||||
patch("controllers.console.wraps.account_initialization_required", side_effect=noop),
|
||||
patch("controllers.console.wraps.cloud_utm_record", side_effect=noop),
|
||||
patch("libs.login.current_account_with_tenant", return_value=(MagicMock(), "tenant-id")),
|
||||
patch("libs.login.current_user", MagicMock(is_authenticated=True)),
|
||||
):
|
||||
# Explicitly reload ext_fastopenapi to ensure it uses the patched console_router
|
||||
import extensions.ext_fastopenapi
|
||||
|
||||
importlib.reload(extensions.ext_fastopenapi)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def _force_reload_module(target_module: str, alias_module: str):
|
||||
"""
|
||||
Forces a reload of the specified module and handles sys.modules aliasing.
|
||||
|
||||
Why reload?
|
||||
Python decorators (like @route, @login_required) run at IMPORT time.
|
||||
To apply our patches (mocks/no-ops) to these decorators, we must re-import
|
||||
the module while the patches are active.
|
||||
|
||||
Why alias?
|
||||
If 'ext_fastopenapi' imports the controller as 'api.controllers...', but we import
|
||||
it as 'controllers...', Python treats them as two separate modules. This causes:
|
||||
1. Double execution of decorators (registering routes twice -> AssertionError).
|
||||
2. Type mismatch errors (Class A from module X is not Class A from module Y).
|
||||
|
||||
This function ensures both names point to the SAME loaded module instance.
|
||||
"""
|
||||
# 1. Clean existing entries to force re-import
|
||||
if target_module in sys.modules:
|
||||
del sys.modules[target_module]
|
||||
if alias_module in sys.modules:
|
||||
del sys.modules[alias_module]
|
||||
|
||||
# 2. Import the module (triggering decorators with active patches)
|
||||
module = importlib.import_module(target_module)
|
||||
|
||||
# 3. Alias the module in sys.modules to prevent double loading
|
||||
sys.modules[alias_module] = sys.modules[target_module]
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def _cleanup_modules(target_module: str, alias_module: str):
|
||||
"""
|
||||
Removes the module and its alias from sys.modules to prevent side effects
|
||||
on other tests.
|
||||
"""
|
||||
if target_module in sys.modules:
|
||||
del sys.modules[target_module]
|
||||
if alias_module in sys.modules:
|
||||
del sys.modules[alias_module]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_feature_module_env():
|
||||
"""
|
||||
Sets up a mocked environment for the feature module.
|
||||
|
||||
This fixture orchestrates:
|
||||
1. Creating an isolated router.
|
||||
2. Patching authentication and global dependencies.
|
||||
3. Reloading the controller module to apply patches to decorators.
|
||||
4. cleaning up sys.modules afterwards.
|
||||
"""
|
||||
target_module = "controllers.console.feature"
|
||||
alias_module = "api.controllers.console.feature"
|
||||
|
||||
# 1. Prepare isolated router
|
||||
temp_router = _create_isolated_router()
|
||||
|
||||
# 2. Apply patches
|
||||
try:
|
||||
with _patch_auth_and_router(temp_router):
|
||||
# 3. Reload module to register routes on the temp_router
|
||||
feature_module = _force_reload_module(target_module, alias_module)
|
||||
|
||||
yield feature_module
|
||||
|
||||
finally:
|
||||
# 4. Teardown: Clean up sys.modules
|
||||
_cleanup_modules(target_module, alias_module)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Test Cases
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("url", "service_mock_path", "mock_model_instance", "json_key"),
|
||||
[
|
||||
(
|
||||
"/console/api/features",
|
||||
"controllers.console.feature.FeatureService.get_features",
|
||||
FeatureModel(can_replace_logo=True),
|
||||
"features",
|
||||
),
|
||||
(
|
||||
"/console/api/system-features",
|
||||
"controllers.console.feature.FeatureService.get_system_features",
|
||||
SystemFeatureModel(enable_marketplace=True),
|
||||
"features",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_console_features_success(app, mock_feature_module_env, url, service_mock_path, mock_model_instance, json_key):
|
||||
"""
|
||||
Tests that the feature APIs return a 200 OK status and correct JSON structure.
|
||||
"""
|
||||
# Patch the service layer to return our mock model instance
|
||||
with patch(service_mock_path, return_value=mock_model_instance):
|
||||
# Initialize the API extension
|
||||
ext_fastopenapi.init_app(app)
|
||||
|
||||
client = app.test_client()
|
||||
response = client.get(url)
|
||||
|
||||
# Assertions
|
||||
assert response.status_code == 200, f"Request failed with status {response.status_code}: {response.text}"
|
||||
|
||||
# Verify the JSON response matches the Pydantic model dump
|
||||
expected_data = mock_model_instance.model_dump(mode="json")
|
||||
assert response.get_json() == {json_key: expected_data}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("url", "service_mock_path"),
|
||||
[
|
||||
("/console/api/features", "controllers.console.feature.FeatureService.get_features"),
|
||||
("/console/api/system-features", "controllers.console.feature.FeatureService.get_system_features"),
|
||||
],
|
||||
)
|
||||
def test_console_features_service_error(app, mock_feature_module_env, url, service_mock_path):
|
||||
"""
|
||||
Tests how the application handles Service layer errors.
|
||||
|
||||
Note: When an exception occurs in the view, it is typically caught by the framework
|
||||
(Flask or the OpenAPI wrapper) and converted to a 500 error response.
|
||||
This test verifies that the application returns a 500 status code.
|
||||
"""
|
||||
# Simulate a service failure
|
||||
with patch(service_mock_path, side_effect=ValueError("Service Failure")):
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# When an exception occurs in the view, it is typically caught by the framework
|
||||
# (Flask or the OpenAPI wrapper) and converted to a 500 error response.
|
||||
response = client.get(url)
|
||||
|
||||
assert response.status_code == 500
|
||||
# Check if the error details are exposed in the response (depends on error handler config)
|
||||
# We accept either generic 500 or the specific error message
|
||||
assert "Service Failure" in response.text or "Internal Server Error" in response.text
|
||||
|
||||
|
||||
def test_system_features_unauthenticated(app, mock_feature_module_env):
|
||||
"""
|
||||
Tests that /console/api/system-features endpoint works without authentication.
|
||||
|
||||
This test verifies the try-except block in get_system_features that handles
|
||||
unauthenticated requests by passing is_authenticated=False to the service layer.
|
||||
"""
|
||||
feature_module = mock_feature_module_env
|
||||
|
||||
# Override the behavior of the current_user mock
|
||||
# The fixture patched 'libs.login.current_user', so 'controllers.console.feature.current_user'
|
||||
# refers to that same Mock object.
|
||||
mock_user = feature_module.current_user
|
||||
|
||||
# Simulate property access raising Unauthorized
|
||||
# Note: We must reset side_effect if it was set, or set it here.
|
||||
# The fixture initialized it as MagicMock(is_authenticated=True).
|
||||
# We want type(mock_user).is_authenticated to raise Unauthorized.
|
||||
type(mock_user).is_authenticated = PropertyMock(side_effect=Unauthorized)
|
||||
|
||||
# Patch the service layer for this specific test
|
||||
with patch("controllers.console.feature.FeatureService.get_system_features") as mock_service:
|
||||
# Setup mock service return value
|
||||
mock_model = SystemFeatureModel(enable_marketplace=True)
|
||||
mock_service.return_value = mock_model
|
||||
|
||||
# Initialize app
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.get("/console/api/system-features")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200, f"Request failed: {response.text}"
|
||||
|
||||
# Verify service was called with is_authenticated=False
|
||||
mock_service.assert_called_once_with(is_authenticated=False)
|
||||
|
||||
# Verify response body
|
||||
expected_data = mock_model.model_dump(mode="json")
|
||||
assert response.get_json() == {"features": expected_data}
|
||||
@ -1,222 +0,0 @@
|
||||
import builtins
|
||||
import contextlib
|
||||
import importlib
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SECRET_KEY"] = "test-secret"
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
|
||||
db.init_app(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fix_method_view_issue(monkeypatch):
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False)
|
||||
|
||||
|
||||
def _create_isolated_router():
|
||||
import controllers.fastopenapi
|
||||
|
||||
router_class = type(controllers.fastopenapi.console_router)
|
||||
return router_class()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _patch_auth_and_router(temp_router):
|
||||
def noop(func):
|
||||
return func
|
||||
|
||||
default_user = MagicMock(has_edit_permission=True, is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
patch("controllers.fastopenapi.console_router", temp_router),
|
||||
patch("extensions.ext_fastopenapi.console_router", temp_router),
|
||||
patch("controllers.console.wraps.setup_required", side_effect=noop),
|
||||
patch("libs.login.login_required", side_effect=noop),
|
||||
patch("controllers.console.wraps.account_initialization_required", side_effect=noop),
|
||||
patch("controllers.console.wraps.edit_permission_required", side_effect=noop),
|
||||
patch("libs.login.current_account_with_tenant", return_value=(default_user, "tenant-id")),
|
||||
patch("configs.dify_config.EDITION", "CLOUD"),
|
||||
):
|
||||
import extensions.ext_fastopenapi
|
||||
|
||||
importlib.reload(extensions.ext_fastopenapi)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def _force_reload_module(target_module: str, alias_module: str):
|
||||
if target_module in sys.modules:
|
||||
del sys.modules[target_module]
|
||||
if alias_module in sys.modules:
|
||||
del sys.modules[alias_module]
|
||||
|
||||
module = importlib.import_module(target_module)
|
||||
sys.modules[alias_module] = sys.modules[target_module]
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def _dedupe_routes(router):
|
||||
seen = set()
|
||||
unique_routes = []
|
||||
for path, method, endpoint in reversed(router.get_routes()):
|
||||
key = (path, method, endpoint.__name__)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
unique_routes.append((path, method, endpoint))
|
||||
router._routes = list(reversed(unique_routes))
|
||||
|
||||
|
||||
def _cleanup_modules(target_module: str, alias_module: str):
|
||||
if target_module in sys.modules:
|
||||
del sys.modules[target_module]
|
||||
if alias_module in sys.modules:
|
||||
del sys.modules[alias_module]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tags_module_env():
|
||||
target_module = "controllers.console.tag.tags"
|
||||
alias_module = "api.controllers.console.tag.tags"
|
||||
temp_router = _create_isolated_router()
|
||||
|
||||
try:
|
||||
with _patch_auth_and_router(temp_router):
|
||||
tags_module = _force_reload_module(target_module, alias_module)
|
||||
_dedupe_routes(temp_router)
|
||||
yield tags_module
|
||||
finally:
|
||||
_cleanup_modules(target_module, alias_module)
|
||||
|
||||
|
||||
def test_list_tags_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
tag = SimpleNamespace(id="tag-1", name="Alpha", type="app", binding_count=2)
|
||||
with patch("controllers.console.tag.tags.TagService.get_tags", return_value=[tag]):
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.get("/console/api/tags?type=app&keyword=Alpha")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == [
|
||||
{"id": "tag-1", "name": "Alpha", "type": "app", "binding_count": 2},
|
||||
]
|
||||
|
||||
|
||||
def test_create_tag_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
tag = SimpleNamespace(id="tag-2", name="Beta", type="app")
|
||||
with patch("controllers.console.tag.tags.TagService.save_tags", return_value=tag) as mock_save:
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.post("/console/api/tags", json={"name": "Beta", "type": "app"})
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {
|
||||
"id": "tag-2",
|
||||
"name": "Beta",
|
||||
"type": "app",
|
||||
"binding_count": 0,
|
||||
}
|
||||
mock_save.assert_called_once_with({"name": "Beta", "type": "app"})
|
||||
|
||||
|
||||
def test_update_tag_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
tag = SimpleNamespace(id="tag-3", name="Gamma", type="app")
|
||||
with (
|
||||
patch("controllers.console.tag.tags.TagService.update_tags", return_value=tag) as mock_update,
|
||||
patch("controllers.console.tag.tags.TagService.get_tag_binding_count", return_value=4),
|
||||
):
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.patch(
|
||||
"/console/api/tags/11111111-1111-1111-1111-111111111111",
|
||||
json={"name": "Gamma", "type": "app"},
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {
|
||||
"id": "tag-3",
|
||||
"name": "Gamma",
|
||||
"type": "app",
|
||||
"binding_count": 4,
|
||||
}
|
||||
mock_update.assert_called_once_with(
|
||||
{"name": "Gamma", "type": "app"},
|
||||
"11111111-1111-1111-1111-111111111111",
|
||||
)
|
||||
|
||||
|
||||
def test_delete_tag_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
with patch("controllers.console.tag.tags.TagService.delete_tag") as mock_delete:
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.delete("/console/api/tags/11111111-1111-1111-1111-111111111111")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 204
|
||||
mock_delete.assert_called_once_with("11111111-1111-1111-1111-111111111111")
|
||||
|
||||
|
||||
def test_create_tag_binding_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
payload = {"tag_ids": ["tag-1", "tag-2"], "target_id": "target-1", "type": "app"}
|
||||
with patch("controllers.console.tag.tags.TagService.save_tag_binding") as mock_bind:
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.post("/console/api/tag-bindings/create", json=payload)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
mock_bind.assert_called_once_with(payload)
|
||||
|
||||
|
||||
def test_delete_tag_binding_success(app: Flask, mock_tags_module_env):
|
||||
# Arrange
|
||||
payload = {"tag_id": "tag-1", "target_id": "target-1", "type": "app"}
|
||||
with patch("controllers.console.tag.tags.TagService.delete_tag_binding") as mock_unbind:
|
||||
ext_fastopenapi.init_app(app)
|
||||
client = app.test_client()
|
||||
|
||||
# Act
|
||||
response = client.post("/console/api/tag-bindings/remove", json=payload)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
mock_unbind.assert_called_once_with(payload)
|
||||
@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs:
|
||||
mock_query.where.return_value = mock_delete_query
|
||||
mock_db.session.query.return_value = mock_query
|
||||
|
||||
delete_func("log-1")
|
||||
delete_func(mock_db.session, "log-1")
|
||||
|
||||
mock_db.session.query.assert_called_once_with(WorkflowArchiveLog)
|
||||
mock_query.where.assert_called_once()
|
||||
|
||||
2
api/uv.lock
generated
2
api/uv.lock
generated
@ -1368,7 +1368,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.11.4"
|
||||
version = "1.12.1"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "aliyun-log-python-sdk" },
|
||||
|
||||
Reference in New Issue
Block a user