Merge remote-tracking branch 'upstream/main'

This commit is contained in:
FFXN
2025-12-16 15:50:49 +08:00
921 changed files with 51608 additions and 20120 deletions

View File

@ -0,0 +1,26 @@
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
from flask_restx import Namespace
from pydantic import BaseModel
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
"""Register a single BaseModel with a namespace for Swagger documentation."""
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
"""Register multiple BaseModels with a namespace."""
for model in models:
register_schema_model(namespace, model)
__all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
"register_schema_model",
"register_schema_models",
]

View File

@ -3,21 +3,47 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp
P = ParamSpec("P")
R = TypeVar("R")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InsertExploreAppPayload(BaseModel):
app_id: str = Field(...)
desc: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def admin_required(view: Callable[P, R]):
@wraps(view)
@ -40,59 +66,34 @@ def admin_required(view: Callable[P, R]):
class InsertExploreAppListApi(Resource):
@console_ns.doc("insert_explore_app")
@console_ns.doc(description="Insert or update an app in the explore list")
@console_ns.expect(
console_ns.model(
"InsertExploreAppRequest",
{
"app_id": fields.String(required=True, description="Application ID"),
"desc": fields.String(description="App description"),
"copyright": fields.String(description="Copyright information"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
},
)
)
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
@console_ns.response(200, "App updated successfully")
@console_ns.response(201, "App inserted successfully")
@console_ns.response(404, "App not found")
@only_edition_cloud
@admin_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
.add_argument("desc", type=str, location="json")
.add_argument("copyright", type=str, location="json")
.add_argument("privacy_policy", type=str, location="json")
.add_argument("custom_disclaimer", type=str, location="json")
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("position", type=int, required=True, nullable=False, location="json")
)
args = parser.parse_args()
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")
raise NotFound(f"App '{payload.app_id}' is not found")
site = app.site
if not site:
desc = args["desc"] or ""
copy_right = args["copyright"] or ""
privacy_policy = args["privacy_policy"] or ""
custom_disclaimer = args["custom_disclaimer"] or ""
desc = payload.desc or ""
copy_right = payload.copyright or ""
privacy_policy = payload.privacy_policy or ""
custom_disclaimer = payload.custom_disclaimer or ""
else:
desc = site.description or args["desc"] or ""
copy_right = site.copyright or args["copyright"] or ""
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
desc = site.description or payload.desc or ""
copy_right = site.copyright or payload.copyright or ""
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
with Session(db.engine) as session:
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
).scalar_one_or_none()
if not recommended_app:
@ -102,9 +103,9 @@ class InsertExploreAppListApi(Resource):
copyright=copy_right,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
language=args["language"],
category=args["category"],
position=args["position"],
language=payload.language,
category=payload.category,
position=payload.position,
)
db.session.add(recommended_app)
@ -118,9 +119,9 @@ class InsertExploreAppListApi(Resource):
recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args["language"]
recommended_app.category = args["category"]
recommended_app.position = args["position"]
recommended_app.language = payload.language
recommended_app.category = payload.category
recommended_app.position = payload.position
app.is_public = True
@ -138,7 +139,7 @@ class InsertExploreAppApi(Resource):
@only_edition_cloud
@admin_required
def delete(self, app_id):
with Session(db.engine) as session:
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
@ -146,13 +147,13 @@ class InsertExploreAppApi(Resource):
if not recommended_app:
return {"result": "success"}, 204
with Session(db.engine) as session:
with session_factory.create_session() as session:
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
with Session(db.engine) as session:
with session_factory.create_session() as session:
installed_apps = (
session.execute(
select(InstalledApp).where(

View File

@ -1,4 +1,6 @@
from flask_restx import Resource, fields, reqparse
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -8,10 +10,21 @@ from libs.login import login_required
from models.model import AppMode
from services.agent_service import AgentService
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AgentLogQuery(BaseModel):
message_id: str = Field(..., description="Message UUID")
conversation_id: str = Field(..., description="Conversation UUID")
@field_validator("message_id", "conversation_id")
@classmethod
def validate_uuid(cls, value: str) -> str:
return uuid_value(value)
console_ns.schema_model(
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@ -20,7 +33,7 @@ class AgentLogApi(Resource):
@console_ns.doc("get_agent_logs")
@console_ns.doc(description="Get agent execution logs for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[AgentLogQuery.__name__])
@console_ns.response(
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
)
@ -31,6 +44,6 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model):
"""Get agent logs"""
args = parser.parse_args()
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)

View File

@ -1,12 +1,15 @@
from typing import Literal
from typing import Any, Literal
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from flask import abort, make_response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
annotation_import_concurrency_limit,
annotation_import_rate_limit,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
@ -21,22 +24,79 @@ from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AnnotationReplyPayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold for annotation matching")
embedding_provider_name: str = Field(..., description="Embedding provider name")
embedding_model_name: str = Field(..., description="Embedding model name")
class AnnotationSettingUpdatePayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold")
class AnnotationListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, description="Page size")
keyword: str = Field(default="", description="Search keyword")
class CreateAnnotationPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
question: str | None = Field(default=None, description="Question text")
answer: str | None = Field(default=None, description="Answer text")
content: str | None = Field(default=None, description="Content text")
annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class UpdateAnnotationPayload(BaseModel):
question: str | None = None
answer: str | None = None
content: str | None = None
annotation_reply: dict[str, Any] | None = None
class AnnotationReplyStatusQuery(BaseModel):
action: Literal["enable", "disable"]
class AnnotationFilePayload(BaseModel):
message_id: str = Field(..., description="Message ID")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
def reg(model: type[BaseModel]) -> None:
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AnnotationReplyPayload)
reg(AnnotationSettingUpdatePayload)
reg(AnnotationListQuery)
reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload)
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource):
@console_ns.doc("annotation_reply_action")
@console_ns.doc(description="Enable or disable annotation reply for an app")
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
@console_ns.expect(
console_ns.model(
"AnnotationReplyActionRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
},
)
)
@console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
@console_ns.response(200, "Action completed successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -46,15 +106,9 @@ class AnnotationReplyActionApi(Resource):
@edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id)
parser = (
reqparse.RequestParser()
.add_argument("score_threshold", required=True, type=float, location="json")
.add_argument("embedding_provider_name", required=True, type=str, location="json")
.add_argument("embedding_model_name", required=True, type=str, location="json")
)
args = parser.parse_args()
args = AnnotationReplyPayload.model_validate(console_ns.payload)
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id)
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@ -82,16 +136,7 @@ class AppAnnotationSettingUpdateApi(Resource):
@console_ns.doc("update_annotation_setting")
@console_ns.doc(description="Update annotation settings for an app")
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
@console_ns.expect(
console_ns.model(
"AnnotationSettingUpdateRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
"embedding_model_name": fields.String(required=True, description="Embedding model"),
},
)
)
@console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
@console_ns.response(200, "Settings updated successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -102,10 +147,9 @@ class AppAnnotationSettingUpdateApi(Resource):
app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
args = parser.parse_args()
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
return result, 200
@ -142,12 +186,7 @@ class AnnotationApi(Resource):
@console_ns.doc("list_annotations")
@console_ns.doc(description="Get annotations for an app with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size")
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
)
@console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
@console_ns.response(200, "Annotations retrieved successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -155,9 +194,10 @@ class AnnotationApi(Resource):
@account_initialization_required
@edit_permission_required
def get(self, app_id):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args.page
limit = args.limit
keyword = args.keyword
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
@ -173,18 +213,7 @@ class AnnotationApi(Resource):
@console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"CreateAnnotationRequest",
{
"message_id": fields.String(description="Message ID (optional)"),
"question": fields.String(description="Question text (required when message_id not provided)"),
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
"content": fields.String(description="Content text (use 'answer' or 'content')"),
"annotation_reply": fields.Raw(description="Annotation reply data"),
},
)
)
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -195,16 +224,9 @@ class AnnotationApi(Resource):
@edit_permission_required
def post(self, app_id):
app_id = str(app_id)
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=False, type=str, location="json")
.add_argument("answer", required=False, type=str, location="json")
.add_argument("content", required=False, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
return annotation
@setup_required
@ -237,7 +259,7 @@ class AnnotationApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
class AnnotationExportApi(Resource):
@console_ns.doc("export_annotations")
@console_ns.doc(description="Export all annotations for an app")
@console_ns.doc(description="Export all annotations for an app with CSV injection protection")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(
200,
@ -252,15 +274,14 @@ class AnnotationExportApi(Resource):
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)}
return response, 200
response_data = {"data": marshal(annotation_list, annotation_fields)}
# Create response with secure headers for CSV export
response = make_response(response_data, 200)
response.headers["Content-Type"] = "application/json; charset=utf-8"
response.headers["X-Content-Type-Options"] = "nosniff"
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
return response
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
@ -271,7 +292,7 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
@console_ns.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -281,8 +302,10 @@ class AnnotationUpdateDeleteApi(Resource):
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
return annotation
@setup_required
@ -299,18 +322,25 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
class AnnotationBatchImportApi(Resource):
@console_ns.doc("batch_import_annotations")
@console_ns.doc(description="Batch import annotations from CSV file")
@console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Batch import started successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "No file uploaded or too many files")
@console_ns.response(413, "File too large")
@console_ns.response(429, "Too many requests or concurrent imports")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@annotation_import_rate_limit
@annotation_import_concurrency_limit
@edit_permission_required
def post(self, app_id):
from configs import dify_config
app_id = str(app_id)
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -320,9 +350,27 @@ class AnnotationBatchImportApi(Resource):
# get file from request
file = request.files["file"]
# check file type
if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
# Check file size before processing
file.seek(0, 2) # Seek to end of file
file_size = file.tell()
file.seek(0) # Reset to beginning
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
if file_size > max_size_bytes:
abort(
413,
f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. "
f"Please reduce the file size and try again.",
)
if file_size == 0:
raise ValueError("The uploaded file is empty")
return AppAnnotationService.batch_import_app_annotations(app_id, file)

View File

@ -31,7 +31,6 @@ from fields.app_fields import (
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
from libs.helper import AppIconUrlField, TimestampField
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import App, Workflow
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
@ -76,51 +75,30 @@ class AppListQuery(BaseModel):
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app")
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")
@ -146,7 +124,14 @@ class AppApiStatusPayload(BaseModel):
class AppTracePayload(BaseModel):
enabled: bool = Field(..., description="Enable or disable tracing")
tracing_provider: str = Field(..., description="Tracing provider")
tracing_provider: str | None = Field(default=None, description="Tracing provider")
@field_validator("tracing_provider")
@classmethod
def validate_tracing_provider(cls, value: str | None, info) -> str | None:
if info.data.get("enabled") and not value:
raise ValueError("tracing_provider is required when enabled is True")
return value
def reg(cls: type[BaseModel]):
@ -324,10 +309,13 @@ class AppListApi(Resource):
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
for _, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
try:
for _, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
except Exception:
continue
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids

View File

@ -1,4 +1,5 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.console.app.wraps import get_app_model
@ -35,23 +36,29 @@ app_import_check_dependencies_model = console_ns.model(
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
)
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("app_id", type=str, location="json")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppImportPayload(BaseModel):
mode: str = Field(..., description="Import mode")
yaml_content: str | None = None
yaml_url: str | None = None
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
app_id: str | None = None
console_ns.schema_model(
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/imports")
class AppImportApi(Resource):
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -61,7 +68,7 @@ class AppImportApi(Resource):
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
args = parser.parse_args()
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with Session(db.engine) as session:
@ -70,15 +77,15 @@ class AppImportApi(Resource):
account = current_user
result = import_service.import_app(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
app_id=args.get("app_id"),
import_mode=args.mode,
yaml_content=args.yaml_content,
yaml_url=args.yaml_url,
name=args.name,
description=args.description,
icon_type=args.icon_type,
icon=args.icon,
icon_background=args.icon_background,
app_id=args.app_id,
)
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@ -1,7 +1,8 @@
import logging
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
@ -32,6 +33,27 @@ from services.errors.audio import (
)
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TextToSpeechPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
text: str = Field(..., description="Text to convert")
voice: str | None = Field(default=None, description="Voice name")
streaming: bool | None = Field(default=None, description="Whether to stream audio")
class TextToSpeechVoiceQuery(BaseModel):
language: str = Field(..., description="Language code")
console_ns.schema_model(
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechVoiceQuery.__name__,
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@ -92,17 +114,7 @@ class ChatMessageTextApi(Resource):
@console_ns.doc("chat_message_text_to_speech")
@console_ns.doc(description="Convert text to speech for chat messages")
@console_ns.doc(params={"app_id": "App ID"})
@console_ns.expect(
console_ns.model(
"TextToSpeechRequest",
{
"message_id": fields.String(description="Message ID"),
"text": fields.String(required=True, description="Text to convert to speech"),
"voice": fields.String(description="Voice to use for TTS"),
"streaming": fields.Boolean(description="Whether to stream the audio"),
},
)
)
@console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
@console_ns.response(200, "Text to speech conversion successful")
@console_ns.response(400, "Bad request - Invalid parameters")
@get_app_model
@ -111,21 +123,14 @@ class ChatMessageTextApi(Resource):
@account_initialization_required
def post(self, app_model: App):
try:
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
payload = TextToSpeechPayload.model_validate(console_ns.payload)
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
app_model=app_model,
text=payload.text,
voice=payload.voice,
message_id=payload.message_id,
is_draft=True,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
@ -159,9 +164,7 @@ class TextModesApi(Resource):
@console_ns.doc("get_text_to_speech_voices")
@console_ns.doc(description="Get available TTS voices for a specific language")
@console_ns.doc(params={"app_id": "App ID"})
@console_ns.expect(
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
)
@console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
@console_ns.response(
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
)
@ -172,12 +175,11 @@ class TextModesApi(Resource):
@account_initialization_required
def get(self, app_model):
try:
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
args = parser.parse_args()
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id,
language=args["language"],
language=args.language,
)
return response

View File

@ -49,7 +49,6 @@ class CompletionConversationQuery(BaseConversationQuery):
class ChatConversationQuery(BaseConversationQuery):
message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort field and direction"
)
@ -509,14 +508,6 @@ class ChatConversationApi(Resource):
.having(func.count(MessageAnnotation.id) == 0)
)
if args.message_count_gte and args.message_count_gte >= 1:
query = (
query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(Message.id) >= args.message_count_gte)
)
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)

View File

@ -1,7 +1,8 @@
import json
from enum import StrEnum
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -12,6 +13,8 @@ from fields.app_fields import app_server_fields
from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register model for flask_restx to avoid dict type issues in Swagger
app_server_model = console_ns.model("AppServer", app_server_fields)
@ -21,6 +24,22 @@ class AppMCPServerStatus(StrEnum):
INACTIVE = "inactive"
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
class MCPServerUpdatePayload(BaseModel):
id: str = Field(..., description="Server ID")
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
status: str | None = Field(default=None, description="Server status")
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/apps/<uuid:app_id>/server")
class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server")
@ -39,15 +58,7 @@ class AppMCPServerController(Resource):
@console_ns.doc("create_app_mcp_server")
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"MCPServerCreateRequest",
{
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
},
)
)
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@ -58,21 +69,16 @@ class AppMCPServerController(Resource):
@edit_permission_required
def post(self, app_model):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("description", type=str, required=False, location="json")
.add_argument("parameters", type=dict, required=True, location="json")
)
args = parser.parse_args()
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
description = args.get("description")
description = payload.description
if not description:
description = app_model.description or ""
server = AppMCPServer(
name=app_model.name,
description=description,
parameters=json.dumps(args["parameters"], ensure_ascii=False),
parameters=json.dumps(payload.parameters, ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
tenant_id=current_tenant_id,
@ -85,17 +91,7 @@ class AppMCPServerController(Resource):
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"MCPServerUpdateRequest",
{
"id": fields.String(required=True, description="Server ID"),
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
"status": fields.String(description="Server status"),
},
)
)
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@ -106,19 +102,12 @@ class AppMCPServerController(Resource):
@marshal_with(app_server_model)
@edit_permission_required
def put(self, app_model):
parser = (
reqparse.RequestParser()
.add_argument("id", type=str, required=True, location="json")
.add_argument("description", type=str, required=False, location="json")
.add_argument("parameters", type=dict, required=True, location="json")
.add_argument("status", type=str, required=False, location="json")
)
args = parser.parse_args()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
if not server:
raise NotFound()
description = args.get("description")
description = payload.description
if description is None:
pass
elif not description:
@ -126,11 +115,11 @@ class AppMCPServerController(Resource):
else:
server.description = description
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
if args["status"]:
if args["status"] not in [status.value for status in AppMCPServerStatus]:
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
if payload.status:
if payload.status not in [status.value for status in AppMCPServerStatus]:
raise ValueError("Invalid status")
server.status = args["status"]
server.status = payload.status
db.session.commit()
return server

View File

@ -61,6 +61,7 @@ class ChatMessagesQuery(BaseModel):
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
@field_validator("message_id")
@classmethod
@ -324,6 +325,7 @@ class MessageFeedbackApi(Resource):
db.session.delete(feedback)
elif args.rating and feedback:
feedback.rating = args.rating
feedback.content = args.content
elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
@ -335,6 +337,7 @@ class MessageFeedbackApi(Resource):
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating_value,
content=args.content,
from_source="admin",
from_account_id=current_user.id,
)

View File

@ -1,4 +1,8 @@
from flask_restx import Resource, fields, reqparse
from typing import Any
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
@ -7,6 +11,26 @@ from controllers.console.wraps import account_initialization_required, setup_req
from libs.login import login_required
from services.ops_service import OpsService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TraceProviderQuery(BaseModel):
tracing_provider: str = Field(..., description="Tracing provider name")
class TraceConfigPayload(BaseModel):
tracing_provider: str = Field(..., description="Tracing provider name")
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
console_ns.schema_model(
TraceProviderQuery.__name__,
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/<uuid:app_id>/trace-config")
class TraceAppConfigApi(Resource):
@ -17,11 +41,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("get_trace_app_config")
@console_ns.doc(description="Get tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
@console_ns.response(
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
)
@ -30,11 +50,10 @@ class TraceAppConfigApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not trace_config:
return {"has_not_configured": True}
return trace_config
@ -44,15 +63,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("create_trace_app_config")
@console_ns.doc(description="Create a new tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"TraceConfigCreateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
},
)
)
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
@console_ns.response(
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
)
@ -62,16 +73,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def post(self, app_id):
"""Create a new trace app configuration"""
parser = (
reqparse.RequestParser()
.add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args()
args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.create_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigIsExist()
@ -84,15 +90,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("update_trace_app_config")
@console_ns.doc(description="Update an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"TraceConfigUpdateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
},
)
)
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
@console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required
@ -100,16 +98,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def patch(self, app_id):
"""Update an existing trace app configuration"""
parser = (
reqparse.RequestParser()
.add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args()
args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.update_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigNotExist()
@ -120,11 +113,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("delete_trace_app_config")
@console_ns.doc(description="Delete an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
@console_ns.response(204, "Tracing configuration deleted successfully")
@console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required
@ -132,11 +121,10 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def delete(self, app_id):
"""Delete an existing trace app configuration"""
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204

View File

@ -1,4 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
@ -16,69 +19,50 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import Site
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppSiteUpdatePayload(BaseModel):
title: str | None = Field(default=None)
icon_type: str | None = Field(default=None)
icon: str | None = Field(default=None)
icon_background: str | None = Field(default=None)
description: str | None = Field(default=None)
default_language: str | None = Field(default=None)
chat_color_theme: str | None = Field(default=None)
chat_color_theme_inverted: bool | None = Field(default=None)
customize_domain: str | None = Field(default=None)
copyright: str | None = Field(default=None)
privacy_policy: str | None = Field(default=None)
custom_disclaimer: str | None = Field(default=None)
customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
prompt_public: bool | None = Field(default=None)
show_workflow_steps: bool | None = Field(default=None)
use_icon_as_answer_icon: bool | None = Field(default=None)
@field_validator("default_language")
@classmethod
def validate_language(cls, value: str | None) -> str | None:
if value is None:
return value
return supported_language(value)
console_ns.schema_model(
AppSiteUpdatePayload.__name__,
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register model for flask_restx to avoid dict type issues in Swagger
app_site_model = console_ns.model("AppSite", app_site_fields)
def parse_app_site_args():
parser = (
reqparse.RequestParser()
.add_argument("title", type=str, required=False, location="json")
.add_argument("icon_type", type=str, required=False, location="json")
.add_argument("icon", type=str, required=False, location="json")
.add_argument("icon_background", type=str, required=False, location="json")
.add_argument("description", type=str, required=False, location="json")
.add_argument("default_language", type=supported_language, required=False, location="json")
.add_argument("chat_color_theme", type=str, required=False, location="json")
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
.add_argument("customize_domain", type=str, required=False, location="json")
.add_argument("copyright", type=str, required=False, location="json")
.add_argument("privacy_policy", type=str, required=False, location="json")
.add_argument("custom_disclaimer", type=str, required=False, location="json")
.add_argument(
"customize_token_strategy",
type=str,
choices=["must", "allow", "not_allow"],
required=False,
location="json",
)
.add_argument("prompt_public", type=bool, required=False, location="json")
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
)
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/site")
class AppSite(Resource):
@console_ns.doc("update_app_site")
@console_ns.doc(description="Update application site configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"AppSiteRequest",
{
"title": fields.String(description="Site title"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"description": fields.String(description="Site description"),
"default_language": fields.String(description="Default language"),
"chat_color_theme": fields.String(description="Chat color theme"),
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
"customize_domain": fields.String(description="Custom domain"),
"copyright": fields.String(description="Copyright text"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"customize_token_strategy": fields.String(
enum=["must", "allow", "not_allow"], description="Token strategy"
),
"prompt_public": fields.Boolean(description="Make prompt public"),
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
},
)
)
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "App not found")
@ -89,7 +73,7 @@ class AppSite(Resource):
@get_app_model
@marshal_with(app_site_model)
def post(self, app_model):
args = parse_app_site_args()
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
@ -113,7 +97,7 @@ class AppSite(Resource):
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = args.get(attr_name)
value = getattr(args, attr_name)
if value is not None:
setattr(site, attr_name, value)

View File

@ -1,10 +1,11 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import NoReturn, ParamSpec, TypeVar
from typing import Any, NoReturn, ParamSpec, TypeVar
from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.console import console_ns
@ -29,6 +30,27 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowDraftVariableListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=100_000, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Items per page")
class WorkflowDraftVariableUpdatePayload(BaseModel):
name: str | None = Field(default=None, description="Variable name")
value: Any | None = Field(default=None, description="Variable value")
console_ns.schema_model(
WorkflowDraftVariableListQuery.__name__,
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkflowDraftVariableUpdatePayload.__name__,
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _convert_values_to_json_serializable_object(value: Segment):
@ -57,22 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser():
parser = (
reqparse.RequestParser()
.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
return parser
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
return value_type.exposed_type().value
@ -201,7 +207,7 @@ def _api_prerequisite(f: Callable[P, R]):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
class WorkflowVariableCollectionApi(Resource):
@console_ns.expect(_create_pagination_parser())
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
@console_ns.doc("get_workflow_variables")
@console_ns.doc(description="Get draft workflow variables")
@console_ns.doc(params={"app_id": "Application ID"})
@ -215,8 +221,7 @@ class WorkflowVariableCollectionApi(Resource):
"""
Get draft workflow
"""
parser = _create_pagination_parser()
args = parser.parse_args()
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# fetch draft workflow by app_model
workflow_service = WorkflowService()
@ -323,15 +328,7 @@ class VariableApi(Resource):
@console_ns.doc("update_variable")
@console_ns.doc(description="Update a workflow variable")
@console_ns.expect(
console_ns.model(
"UpdateVariableRequest",
{
"name": fields.String(description="Variable name"),
"value": fields.Raw(description="Variable value"),
},
)
)
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@ -358,16 +355,10 @@ class VariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
parser = (
reqparse.RequestParser()
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
args = parser.parse_args(strict=True)
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
@ -375,8 +366,8 @@ class VariableApi(Resource):
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
new_name = args_model.name
raw_value = args_model.value
if new_name is None and raw_value is None:
return variable

View File

@ -114,7 +114,7 @@ class AppTriggersApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
@console_ns.expect(console_ns.models[ParserEnable.__name__])
@setup_required
@login_required
@account_initialization_required

View File

@ -1,28 +1,53 @@
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone
from libs.helper import EmailStr, extract_remote_ip, timezone
from models import AccountStatus
from services.account_service import AccountService, RegisterService
active_check_parser = (
reqparse.RequestParser()
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ActivateCheckQuery(BaseModel):
workspace_id: str | None = Field(default=None)
email: EmailStr | None = Field(default=None)
token: str
class ActivatePayload(BaseModel):
workspace_id: str | None = Field(default=None)
email: EmailStr | None = Field(default=None)
token: str
name: str = Field(..., max_length=30)
interface_language: str = Field(...)
timezone: str = Field(...)
@field_validator("interface_language")
@classmethod
def validate_lang(cls, value: str) -> str:
return supported_language(value)
@field_validator("timezone")
@classmethod
def validate_tz(cls, value: str) -> str:
return timezone(value)
for model in (ActivateCheckQuery, ActivatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/activate/check")
class ActivateCheckApi(Resource):
@console_ns.doc("check_activation_token")
@console_ns.doc(description="Check if activation token is valid")
@console_ns.expect(active_check_parser)
@console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
@console_ns.response(
200,
"Success",
@ -35,11 +60,11 @@ class ActivateCheckApi(Resource):
),
)
def get(self):
args = active_check_parser.parse_args()
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args["workspace_id"]
reg_email = args["email"]
token = args["token"]
workspaceId = args.workspace_id
reg_email = args.email
token = args.token
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation:
@ -56,22 +81,11 @@ class ActivateCheckApi(Resource):
return {"is_valid": False}
active_parser = (
reqparse.RequestParser()
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
.add_argument("email", type=email, required=False, nullable=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
)
@console_ns.route("/activate")
class ActivateApi(Resource):
@console_ns.doc("activate_account")
@console_ns.doc(description="Activate account with invitation token")
@console_ns.expect(active_parser)
@console_ns.expect(console_ns.models[ActivatePayload.__name__])
@console_ns.response(
200,
"Account activated successfully",
@ -85,19 +99,19 @@ class ActivateApi(Resource):
)
@console_ns.response(400, "Already activated or invalid token")
def post(self):
args = active_parser.parse_args()
args = ActivatePayload.model_validate(console_ns.payload)
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
account = invitation["account"]
account.name = args["name"]
account.name = args.name
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_language = args.interface_language
account.timezone = args.timezone
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()

View File

@ -1,12 +1,26 @@
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from ..wraps import account_initialization_required, setup_required
from .. import console_ns
from ..auth.error import ApiKeyAuthFailedError
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ApiKeyAuthBindingPayload(BaseModel):
category: str = Field(...)
provider: str = Field(...)
credentials: dict = Field(...)
console_ns.schema_model(
ApiKeyAuthBindingPayload.__name__,
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/api-key-auth/data-source")
@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource):
@login_required
@account_initialization_required
@is_admin_or_owner_required
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
def post(self):
# The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
data = payload.model_dump()
ApiKeyAuthService.validate_api_key_auth_args(data)
try:
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200

View File

@ -5,12 +5,11 @@ from flask import current_app, redirect, request
from flask_restx import Resource, fields
from configs import dify_config
from controllers.console import console_ns
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required
from .. import console_ns
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
logger = logging.getLogger(__name__)

View File

@ -1,5 +1,6 @@
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -14,16 +15,45 @@ from controllers.console.auth.error import (
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models import Account
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from ..error import AccountInFreezeError, EmailSendIpLimitError
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class EmailRegisterSendPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
language: str | None = Field(default=None, description="Language code")
class EmailRegisterValidityPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class EmailRegisterResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/email-register/send-email")
class EmailRegisterSendEmailApi(Resource):
@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
language = "en-US"
if args["language"] in languages:
language = args["language"]
if args.language in languages:
language = args.language
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = None
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
return {"result": "success", "data": token}
@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
user_email = args["email"]
user_email = args.email
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
token_data = AccountService.get_email_register_data(args["token"])
token_data = AccountService.get_email_register_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args["email"])
if args.code != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_email_register_token(args["token"])
AccountService.revoke_email_register_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_email_register_token(
user_email, code=args["code"], additional_data={"phase": "register"}
user_email, code=args.code, additional_data={"phase": "register"}
)
AccountService.reset_email_register_error_rate_limit(args["email"])
AccountService.reset_email_register_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = EmailRegisterResetPayload.model_validate(console_ns.payload)
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
if args.new_password != args.password_confirm:
raise PasswordMismatchError()
# Validate token and get register data
register_data = AccountService.get_email_register_data(args["token"])
register_data = AccountService.get_email_register_data(args.token)
if not register_data:
raise InvalidTokenError()
# Must use token in reset phase
@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_email_register_token(args["token"])
AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource):
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(email, args["password_confirm"])
account = self._create_new_account(email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))

View File

@ -2,7 +2,8 @@ import base64
import secrets
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -18,26 +19,46 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource):
@console_ns.doc("send_forgot_password_email")
@console_ns.doc(description="Send password reset email")
@console_ns.expect(
console_ns.model(
"ForgotPasswordEmailRequest",
{
"email": fields.String(required=True, description="Email address"),
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
},
)
)
@console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
@console_ns.response(
200,
"Email sent successfully",
@ -54,28 +75,23 @@ class ForgotPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = AccountService.send_reset_password_email(
account=account,
email=args["email"],
email=args.email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@ -87,16 +103,7 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource):
@console_ns.doc("check_forgot_password_code")
@console_ns.doc(description="Verify password reset code")
@console_ns.expect(
console_ns.model(
"ForgotPasswordCheckRequest",
{
"email": fields.String(required=True, description="Email address"),
"code": fields.String(required=True, description="Verification code"),
"token": fields.String(required=True, description="Reset token"),
},
)
)
@console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
@console_ns.response(
200,
"Code verified successfully",
@ -113,40 +120,34 @@ class ForgotPasswordCheckApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
user_email = args["email"]
user_email = args.email
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args["token"])
token_data = AccountService.get_reset_password_data(args.token)
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"])
if args.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args.email)
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"])
AccountService.revoke_reset_password_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"}
user_email, code=args.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args["email"])
AccountService.reset_forgot_password_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@ -154,16 +155,7 @@ class ForgotPasswordCheckApi(Resource):
class ForgotPasswordResetApi(Resource):
@console_ns.doc("reset_password")
@console_ns.doc(description="Reset password with verification token")
@console_ns.expect(
console_ns.model(
"ForgotPasswordResetRequest",
{
"token": fields.String(required=True, description="Verification token"),
"new_password": fields.String(required=True, description="New password"),
"password_confirm": fields.String(required=True, description="Password confirmation"),
},
)
)
@console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
@console_ns.response(
200,
"Password reset successfully",
@ -173,20 +165,14 @@ class ForgotPasswordResetApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
# Validate passwords match
if args["new_password"] != args["password_confirm"]:
if args.new_password != args.password_confirm:
raise PasswordMismatchError()
# Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args["token"])
reset_data = AccountService.get_reset_password_data(args.token)
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])
AccountService.revoke_reset_password_token(args.token)
# Generate secure salt and hash password
salt = secrets.token_bytes(16)
password_hashed = hash_password(args["new_password"], salt)
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")

View File

@ -1,6 +1,7 @@
import flask_login
from flask import make_response, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
import services
from configs import dify_config
@ -23,7 +24,7 @@ from controllers.console.error import (
)
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from libs.helper import EmailStr, extract_remote_ip
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
@ -40,6 +41,36 @@ from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class LoginPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., description="Password")
remember_me: bool = Field(default=False, description="Remember me flag")
invite_token: str | None = Field(default=None, description="Invitation token")
class EmailPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class EmailCodeLoginPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
language: str | None = Field(default=None)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(LoginPayload)
reg(EmailPayload)
reg(EmailCodeLoginPayload)
@console_ns.route("/login")
class LoginApi(Resource):
@ -47,41 +78,36 @@ class LoginApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[LoginPayload.__name__])
def post(self):
"""Authenticate user and login."""
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("password", type=str, required=True, location="json")
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
.add_argument("invite_token", type=str, required=False, default=None, location="json")
)
args = parser.parse_args()
args = LoginPayload.model_validate(console_ns.payload)
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
invitation = args["invite_token"]
# TODO: why invitation is re-assigned with different type?
invitation = args.invite_token # type: ignore
if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
try:
if invitation:
data = invitation.get("data", {})
data = invitation.get("data", {}) # type: ignore
invitee_email = data.get("email") if data else None
if invitee_email != args["email"]:
if invitee_email != args.email:
raise InvalidEmailError()
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
account = AccountService.authenticate(args.email, args.password, args.invite_token)
else:
account = AccountService.authenticate(args["email"], args["password"])
account = AccountService.authenticate(args.email, args.password)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args["email"])
AccountService.add_login_error_rate_limit(args.email)
raise AuthenticationFailedError()
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
@ -97,7 +123,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@ -134,25 +160,21 @@ class LogoutApi(Resource):
class ResetPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = EmailPayload.model_validate(console_ns.payload)
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
account = AccountService.get_user_through_email(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
email=args["email"],
email=args.email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@ -164,30 +186,26 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.route("/email-code-login")
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = EmailPayload.model_validate(console_ns.payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args["language"] is not None and args["language"] == "zh-Hans":
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
account = AccountService.get_user_through_email(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
token = AccountService.send_email_code_login_email(email=args.email, language=language)
else:
raise AccountNotFound()
else:
@ -199,30 +217,24 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.route("/email-code-login/validity")
class EmailCodeLoginApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
user_email = args["email"]
language = args["language"]
user_email = args.email
language = args.language
token_data = AccountService.get_email_code_login_data(args["token"])
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
if token_data["email"] != args.email:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
if token_data["code"] != args.code:
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args["token"])
AccountService.revoke_email_code_login_token(args.token)
try:
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError:
@ -255,7 +267,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
AccountService.reset_login_error_rate_limit(args.email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})

View File

@ -3,7 +3,8 @@ from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask import jsonify, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
@ -20,15 +21,34 @@ R = TypeVar("R")
T = TypeVar("T")
class OAuthClientPayload(BaseModel):
client_id: str
class OAuthProviderRequest(BaseModel):
client_id: str
redirect_uri: str
class OAuthTokenRequest(BaseModel):
client_id: str
grant_type: str
code: str | None = None
client_secret: str | None = None
redirect_uri: str | None = None
refresh_token: str | None = None
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args()
client_id = parsed_args.get("client_id")
if not client_id:
json_data = request.get_json()
if json_data is None:
raise BadRequest("client_id is required")
payload = OAuthClientPayload.model_validate(json_data)
client_id = payload.client_id
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
if not oauth_provider_app:
raise NotFound("client_id is invalid")
@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
parsed_args = parser.parse_args()
redirect_uri = parsed_args.get("redirect_uri")
payload = OAuthProviderRequest.model_validate(request.get_json())
redirect_uri = payload.redirect_uri
# check if redirect_uri is valid
if redirect_uri not in oauth_provider_app.redirect_uris:
@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = (
reqparse.RequestParser()
.add_argument("grant_type", type=str, required=True, location="json")
.add_argument("code", type=str, required=False, location="json")
.add_argument("client_secret", type=str, required=False, location="json")
.add_argument("redirect_uri", type=str, required=False, location="json")
.add_argument("refresh_token", type=str, required=False, location="json")
)
parsed_args = parser.parse_args()
payload = OAuthTokenRequest.model_validate(request.get_json())
try:
grant_type = OAuthGrantType(parsed_args["grant_type"])
grant_type = OAuthGrantType(payload.grant_type)
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]:
if not payload.code:
raise BadRequest("code is required")
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource):
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not parsed_args["refresh_token"]:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{

View File

@ -1,6 +1,8 @@
import base64
from flask_restx import Resource, fields, reqparse
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
@ -9,6 +11,35 @@ from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel):
plan: str = Field(..., description="Subscription plan")
interval: str = Field(..., description="Billing interval")
@field_validator("plan")
@classmethod
def validate_plan(cls, value: str) -> str:
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
raise ValueError("Invalid plan")
return value
@field_validator("interval")
@classmethod
def validate_interval(cls, value: str) -> str:
if value not in {"month", "year"}:
raise ValueError("Invalid interval")
return value
class PartnerTenantsPayload(BaseModel):
click_id: str = Field(..., description="Click Id from partner referral link")
for model in (SubscriptionQuery, PartnerTenantsPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/billing/subscription")
class Subscription(Resource):
@ -18,20 +49,9 @@ class Subscription(Resource):
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument(
"plan",
type=str,
required=True,
location="args",
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
)
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
)
args = parser.parse_args()
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
@console_ns.route("/billing/invoices")
@ -65,11 +85,10 @@ class PartnerTenants(Resource):
@only_edition_cloud
def put(self, partner_key: str):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
args = parser.parse_args()
try:
click_id = args["click_id"]
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
click_id = args.click_id
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
except Exception:
raise BadRequest("Invalid partner_key")

View File

@ -1,5 +1,6 @@
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
@ -9,16 +10,28 @@ from .. import console_ns
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
class ComplianceDownloadQuery(BaseModel):
doc_name: str = Field(..., description="Compliance document name")
console_ns.schema_model(
ComplianceDownloadQuery.__name__,
ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/compliance/download")
class ComplianceApi(Resource):
@console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__])
@console_ns.doc("download_compliance_document")
@console_ns.doc(description="Get compliance document download link")
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
args = parser.parse_args()
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device")

View File

@ -1,15 +1,15 @@
import json
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from flask import request
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.common.schema import register_schema_model
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
@ -25,6 +25,19 @@ from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task
from .. import console_ns
from ..wraps import account_initialization_required, setup_required
class NotionEstimatePayload(BaseModel):
notion_info_list: list[dict[str, Any]]
process_rule: dict[str, Any]
doc_form: str = Field(default="text_model")
doc_language: str = Field(default="English")
register_schema_model(console_ns, NotionEstimatePayload)
@console_ns.route(
"/data-source/integrates",
@ -243,20 +256,15 @@ class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
def post(self):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
# validate args
DocumentService.estimate_args_validate(args)
notion_info_list = args["notion_info_list"]
notion_info_list = payload.notion_info_list
extract_settings = []
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]

View File

@ -1,12 +1,14 @@
from typing import Any, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import (
api_key_item_model,
@ -48,7 +50,6 @@ from fields.dataset_fields import (
)
from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
@ -107,10 +108,75 @@ related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_mode
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_indexing_technique(value: str | None) -> str | None:
if value is None:
return value
if value not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Invalid indexing technique.")
return value
class DatasetCreatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field("", max_length=400)
indexing_technique: str | None = None
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
provider: str = "vendor"
external_knowledge_api_id: str | None = None
external_knowledge_id: str | None = None
@field_validator("indexing_technique")
@classmethod
def validate_indexing(cls, value: str | None) -> str | None:
return _validate_indexing_technique(value)
@field_validator("provider")
@classmethod
def validate_provider(cls, value: str) -> str:
if value not in Dataset.PROVIDER_LIST:
raise ValueError("Invalid provider.")
return value
class DatasetUpdatePayload(BaseModel):
name: str | None = Field(None, min_length=1, max_length=40)
description: str | None = Field(None, max_length=400)
permission: DatasetPermissionEnum | None = None
indexing_technique: str | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
partial_member_list: list[str] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
icon_info: dict[str, Any] | None = None
is_multimodal: bool | None = False
@field_validator("indexing_technique")
@classmethod
def validate_indexing(cls, value: str | None) -> str | None:
return _validate_indexing_technique(value)
class IndexingEstimatePayload(BaseModel):
info_list: dict[str, Any]
process_rule: dict[str, Any]
indexing_technique: str
doc_form: str = "text_model"
dataset_id: str | None = None
doc_language: str = "English"
@field_validator("indexing_technique")
@classmethod
def validate_indexing(cls, value: str) -> str:
result = _validate_indexing_technique(value)
if result is None:
raise ValueError("indexing_technique is required.")
return result
register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@ -164,6 +230,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.CLICKZETTA,
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
VectorType.IRIS,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
@ -255,20 +322,7 @@ class DatasetListApi(Resource):
@console_ns.doc("create_dataset")
@console_ns.doc(description="Create a new dataset")
@console_ns.expect(
console_ns.model(
"CreateDatasetRequest",
{
"name": fields.String(required=True, description="Dataset name (1-40 characters)"),
"description": fields.String(description="Dataset description (max 400 characters)"),
"indexing_technique": fields.String(description="Indexing technique"),
"permission": fields.String(description="Dataset permission"),
"provider": fields.String(description="Provider"),
"external_knowledge_api_id": fields.String(description="External knowledge API ID"),
"external_knowledge_id": fields.String(description="External knowledge ID"),
},
)
)
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
@console_ns.response(201, "Dataset created successfully")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@ -276,52 +330,7 @@ class DatasetListApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
)
.add_argument(
"provider",
type=str,
nullable=True,
choices=Dataset.PROVIDER_LIST,
required=False,
default="vendor",
)
.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
)
args = parser.parse_args()
payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@ -331,14 +340,14 @@ class DatasetListApi(Resource):
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
name=payload.name,
description=payload.description,
indexing_technique=payload.indexing_technique,
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"],
permission=payload.permission or DatasetPermissionEnum.ONLY_ME,
provider=payload.provider,
external_knowledge_api_id=payload.external_knowledge_api_id,
external_knowledge_id=payload.external_knowledge_id,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@ -399,18 +408,7 @@ class DatasetApi(Resource):
@console_ns.doc("update_dataset")
@console_ns.doc(description="Update dataset details")
@console_ns.expect(
console_ns.model(
"UpdateDatasetRequest",
{
"name": fields.String(description="Dataset name"),
"description": fields.String(description="Dataset description"),
"permission": fields.String(description="Dataset permission"),
"indexing_technique": fields.String(description="Indexing technique"),
"external_retrieval_model": fields.Raw(description="External retrieval model settings"),
},
)
)
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@ -424,93 +422,25 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(
DatasetPermissionEnum.ONLY_ME,
DatasetPermissionEnum.ALL_TEAM,
DatasetPermissionEnum.PARTIAL_TEAM,
),
help="Invalid permission.",
)
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
.add_argument(
"icon_info",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid icon info.",
)
)
args = parser.parse_args()
data = request.get_json()
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
data.get("indexing_technique") == "high_quality"
and data.get("embedding_model_provider") is not None
and data.get("embedding_model") is not None
payload.indexing_technique == "high_quality"
and payload.embedding_model_provider is not None
and payload.embedding_model is not None
):
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
is_multimodal = DatasetService.check_is_multimodal_model(
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
)
payload.is_multimodal = is_multimodal
payload_data = payload.model_dump(exclude_unset=True)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
current_user, dataset, payload.permission, payload.partial_member_list
)
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
@ -518,15 +448,10 @@ class DatasetApi(Resource):
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@ -615,24 +540,10 @@ class DatasetIndexingEstimateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument(
"indexing_technique",
type=str,
required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
location="json",
)
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
_, current_tenant_id = current_account_with_tenant()
# validate args
DocumentService.estimate_args_validate(args)

View File

@ -6,31 +6,14 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import (
ArchivedDocumentImmutableError,
DocumentAlreadyFinishedError,
DocumentIndexingError,
IndexingEstimateError,
InvalidActionError,
InvalidMetadataError,
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
@ -55,10 +38,30 @@ from fields.document_fields import (
)
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from ..datasets.error import (
ArchivedDocumentImmutableError,
DocumentAlreadyFinishedError,
DocumentIndexingError,
IndexingEstimateError,
InvalidActionError,
InvalidMetadataError,
)
from ..wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
logger = logging.getLogger(__name__)
@ -93,6 +96,24 @@ dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(docume
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DocumentRetryPayload(BaseModel):
document_ids: list[str]
class DocumentRenamePayload(BaseModel):
name: str
register_schema_models(
console_ns,
KnowledgeConfig,
ProcessRule,
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
)
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
@ -201,8 +222,9 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: str):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
@ -310,6 +332,7 @@ class DatasetDocumentListApi(Resource):
@marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
@ -328,23 +351,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
parser = (
reqparse.RequestParser()
.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
.add_argument("data_source", type=dict, required=False, location="json")
.add_argument("process_rule", type=dict, required=False, location="json")
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
.add_argument("original_document_id", type=str, required=False, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args)
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.")
@ -390,17 +397,7 @@ class DatasetDocumentListApi(Resource):
class DatasetInitApi(Resource):
@console_ns.doc("init_dataset")
@console_ns.doc(description="Initialize dataset with documents")
@console_ns.expect(
console_ns.model(
"DatasetInitRequest",
{
"upload_file_id": fields.String(required=True, description="Upload file ID"),
"indexing_technique": fields.String(description="Indexing technique"),
"process_rule": fields.Raw(description="Processing rules"),
"data_source": fields.Raw(description="Data source configuration"),
},
)
)
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
@console_ns.response(400, "Invalid request parameters")
@setup_required
@ -415,27 +412,7 @@ class DatasetInitApi(Resource):
if not current_user.is_dataset_editor:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument(
"indexing_technique",
type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
required=True,
nullable=False,
location="json",
)
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args)
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@ -443,10 +420,14 @@ class DatasetInitApi(Resource):
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=args["embedding_model_provider"],
provider=knowledge_config.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=args["embedding_model"],
model=knowledge_config.embedding_model,
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
)
knowledge_config.is_multimodal = is_multimodal
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
@ -1076,19 +1057,16 @@ class DocumentRetryApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
def post(self, dataset_id):
"""retry document."""
parser = reqparse.RequestParser().add_argument(
"document_ids", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
retry_documents = []
if not dataset:
raise NotFound("Dataset not found.")
for document_id in args["document_ids"]:
for document_id in payload.document_ids:
try:
document_id = str(document_id)
@ -1121,6 +1099,7 @@ class DocumentRenameApi(DocumentResource):
@login_required
@account_initialization_required
@marshal_with(document_fields)
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
@ -1130,11 +1109,10 @@ class DocumentRenameApi(DocumentResource):
if not dataset:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
try:
document = DocumentService.rename_document(dataset_id, document_id, args["name"])
document = DocumentService.rename_document(dataset_id, document_id, payload.name)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")

View File

@ -1,11 +1,13 @@
import uuid
from flask import request
from flask_restx import Resource, marshal, reqparse
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import (
@ -36,6 +38,58 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
hit_count_gte: int | None = None
enabled: str = Field(default="all")
keyword: str | None = None
page: int = Field(default=1, ge=1)
class SegmentCreatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
class SegmentUpdatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
class BatchImportPayload(BaseModel):
upload_file_id: str
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkUpdatePayload(BaseModel):
content: str
class ChildChunkBatchUpdatePayload(BaseModel):
chunks: list[ChildChunkUpdateArgs]
register_schema_models(
console_ns,
SegmentListQuery,
SegmentCreatePayload,
SegmentUpdatePayload,
BatchImportPayload,
ChildChunkCreatePayload,
ChildChunkUpdatePayload,
ChildChunkBatchUpdatePayload,
)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class DatasetDocumentSegmentListApi(Resource):
@setup_required
@ -60,23 +114,18 @@ class DatasetDocumentSegmentListApi(Resource):
if not document:
raise NotFound("Document not found.")
parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("status", type=str, action="append", default=[], location="args")
.add_argument("hit_count_gte", type=int, default=None, location="args")
.add_argument("enabled", type=str, default="all", location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
args = SegmentListQuery.model_validate(
{
**request.args.to_dict(),
"status": request.args.getlist("status"),
}
)
args = parser.parse_args()
page = args["page"]
limit = min(args["limit"], 100)
status_list = args["status"]
hit_count_gte = args["hit_count_gte"]
keyword = args["keyword"]
page = args.page
limit = min(args.limit, 100)
status_list = args.status
hit_count_gte = args.hit_count_gte
keyword = args.keyword
query = (
select(DocumentSegment)
@ -96,10 +145,10 @@ class DatasetDocumentSegmentListApi(Resource):
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
if args["enabled"].lower() != "all":
if args["enabled"].lower() == "true":
if args.enabled.lower() != "all":
if args.enabled.lower() == "true":
query = query.where(DocumentSegment.enabled == True)
elif args["enabled"].lower() == "false":
elif args.enabled.lower() == "false":
query = query.where(DocumentSegment.enabled == False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@ -210,6 +259,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -246,15 +296,10 @@ class DatasetDocumentSegmentAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = (
reqparse.RequestParser()
.add_argument("content", type=str, required=True, nullable=False, location="json")
.add_argument("answer", type=str, required=False, nullable=True, location="json")
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset)
payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@ -265,6 +310,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -313,18 +359,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = (
reqparse.RequestParser()
.add_argument("content", type=str, required=True, nullable=False, location="json")
.add_argument("answer", type=str, required=False, nullable=True, location="json")
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
.add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@ -377,6 +417,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -391,11 +432,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if not document:
raise NotFound("Document not found.")
parser = reqparse.RequestParser().add_argument(
"upload_file_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
upload_file_id = args["upload_file_id"]
payload = BatchImportPayload.model_validate(console_ns.payload or {})
upload_file_id = payload.upload_file_id
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if not upload_file:
@ -446,6 +484,7 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -491,13 +530,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try:
content = args["content"]
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@ -529,18 +564,17 @@ class ChildChunkAddApi(Resource):
)
if not segment:
raise NotFound("Segment not found.")
parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
args = SegmentListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
args = parser.parse_args()
page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]
page = args.page
limit = min(args.limit, 100)
keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
@ -588,14 +622,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser().add_argument(
"chunks", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {})
try:
chunks_data = args["chunks"]
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
@ -665,6 +694,7 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -711,13 +741,9 @@ class ChildChunkUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try:
content = args["content"]
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {})
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200

View File

@ -1,8 +1,10 @@
from flask import request
from flask_restx import Resource, fields, marshal, reqparse
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@ -71,10 +73,38 @@ except KeyError:
dataset_detail_model = _build_dataset_detail_model()
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 100:
raise ValueError("Name must be between 1 to 100 characters.")
return name
class ExternalKnowledgeApiPayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
settings: dict[str, object]
class ExternalDatasetCreatePayload(BaseModel):
external_knowledge_api_id: str
external_knowledge_id: str
name: str = Field(..., min_length=1, max_length=40)
description: str | None = Field(None, max_length=400)
external_retrieval_model: dict[str, object] | None = None
class ExternalHitTestingPayload(BaseModel):
query: str
external_retrieval_model: dict[str, object] | None = None
metadata_filtering_conditions: dict[str, object] | None = None
class BedrockRetrievalPayload(BaseModel):
retrieval_setting: dict[str, object]
query: str
knowledge_id: str
register_schema_models(
console_ns,
ExternalKnowledgeApiPayload,
ExternalDatasetCreatePayload,
ExternalHitTestingPayload,
BedrockRetrievalPayload,
)
@console_ns.route("/datasets/external-knowledge-api")
@ -113,28 +143,12 @@ class ExternalApiTemplateListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
)
args = parser.parse_args()
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
ExternalDatasetService.validate_api_list(args["settings"])
ExternalDatasetService.validate_api_list(payload.settings)
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@ -142,7 +156,7 @@ class ExternalApiTemplateListApi(Resource):
try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
tenant_id=current_tenant_id, user_id=current_user.id, args=args
tenant_id=current_tenant_id, user_id=current_user.id, args=payload.model_dump()
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@ -171,35 +185,19 @@ class ExternalApiTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"])
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
ExternalDatasetService.validate_api_list(payload.settings)
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_tenant_id,
user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id,
args=args,
args=payload.model_dump(),
)
return external_knowledge_api.to_dict(), 200
@ -240,17 +238,7 @@ class ExternalApiUseCheckApi(Resource):
class ExternalDatasetCreateApi(Resource):
@console_ns.doc("create_external_dataset")
@console_ns.doc(description="Create external knowledge dataset")
@console_ns.expect(
console_ns.model(
"CreateExternalDatasetRequest",
{
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
"external_knowledge_id": fields.String(required=True, description="External knowledge ID"),
"name": fields.String(required=True, description="Dataset name"),
"description": fields.String(description="Dataset description"),
},
)
)
@console_ns.expect(console_ns.models[ExternalDatasetCreatePayload.__name__])
@console_ns.response(201, "External dataset created successfully", dataset_detail_model)
@console_ns.response(400, "Invalid parameters")
@console_ns.response(403, "Permission denied")
@ -261,22 +249,8 @@ class ExternalDatasetCreateApi(Resource):
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
.add_argument(
"name",
nullable=False,
required=True,
help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument("description", type=str, required=False, nullable=True, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
args = parser.parse_args()
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@ -299,16 +273,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@console_ns.doc("test_external_knowledge_retrieval")
@console_ns.doc(description="Test external knowledge retrieval for dataset")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(
console_ns.model(
"ExternalHitTestingRequest",
{
"query": fields.String(required=True, description="Query text for testing"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"external_retrieval_model": fields.Raw(description="External retrieval model configuration"),
},
)
)
@console_ns.expect(console_ns.models[ExternalHitTestingPayload.__name__])
@console_ns.response(200, "External hit testing completed successfully")
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@ -327,23 +292,16 @@ class ExternalKnowledgeHitTestingApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
)
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
payload = ExternalHitTestingPayload.model_validate(console_ns.payload or {})
HitTestingService.hit_testing_args_check(payload.model_dump())
try:
response = HitTestingService.external_retrieve(
dataset=dataset,
query=args["query"],
query=payload.query,
account=current_user,
external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"],
external_retrieval_model=payload.external_retrieval_model,
metadata_filtering_conditions=payload.metadata_filtering_conditions,
)
return response
@ -356,33 +314,13 @@ class BedrockRetrievalApi(Resource):
# this api is only for internal testing
@console_ns.doc("bedrock_retrieval_test")
@console_ns.doc(description="Bedrock retrieval test (internal use only)")
@console_ns.expect(
console_ns.model(
"BedrockRetrievalTestRequest",
{
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
"query": fields.String(required=True, description="Query text"),
"knowledge_id": fields.String(required=True, description="Knowledge ID"),
},
)
)
@console_ns.expect(console_ns.models[BedrockRetrievalPayload.__name__])
@console_ns.response(200, "Bedrock retrieval test completed")
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
.add_argument("knowledge_id", nullable=False, required=True, type=str)
)
args = parser.parse_args()
payload = BedrockRetrievalPayload.model_validate(console_ns.payload or {})
# Call the knowledge retrieval service
result = ExternalDatasetTestService.knowledge_retrieval(
args["retrieval_setting"], args["query"], args["knowledge_id"]
payload.retrieval_setting, payload.query, payload.knowledge_id
)
return result, 200

View File

@ -1,13 +1,17 @@
from flask_restx import Resource, fields
from flask_restx import Resource
from controllers.console import console_ns
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.wraps import (
from controllers.common.schema import register_schema_model
from libs.login import login_required
from .. import console_ns
from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from ..wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from libs.login import login_required
register_schema_model(console_ns, HitTestingPayload)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@ -15,17 +19,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(
console_ns.model(
"HitTestingRequest",
{
"query": fields.String(required=True, description="Query text for testing"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"top_k": fields.Integer(description="Number of top results to return"),
"score_threshold": fields.Float(description="Score threshold for filtering results"),
},
)
)
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@ -37,7 +31,8 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args()
payload = HitTestingPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)

View File

@ -1,6 +1,8 @@
import logging
from typing import Any
from flask_restx import marshal, reqparse
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
@ -27,6 +29,13 @@ from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__)
class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None
class DatasetsHitTestingBase:
@staticmethod
def get_and_validate_dataset(dataset_id: str):
@ -43,14 +52,15 @@ class DatasetsHitTestingBase:
return dataset
@staticmethod
def hit_testing_args_check(args):
def hit_testing_args_check(args: dict[str, Any]):
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args():
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, location="json")
.add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
@ -62,10 +72,11 @@ class DatasetsHitTestingBase:
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
query=args.get("query"),
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
retrieval_model=args.get("retrieval_model"),
external_retrieval_model=args.get("external_retrieval_model"),
attachment_ids=args.get("attachment_ids"),
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}

View File

@ -1,8 +1,10 @@
from typing import Literal
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
@ -15,6 +17,14 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService
class MetadataUpdatePayload(BaseModel):
name: str
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
register_schema_model(console_ns, MetadataUpdatePayload)
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateApi(Resource):
@setup_required
@ -22,15 +32,10 @@ class DatasetMetadataCreateApi(Resource):
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
metadata_args = MetadataArgs.model_validate(args)
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -60,11 +65,11 @@ class DatasetMetadataApi(Resource):
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
name = args["name"]
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
name = payload.name
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@ -131,6 +136,7 @@ class DocumentMetadataEditApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@ -139,11 +145,7 @@ class DocumentMetadataEditApi(Resource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
parser = reqparse.RequestParser().add_argument(
"operation_data", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
metadata_args = MetadataOperationData.model_validate(args)
metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -1,20 +1,63 @@
from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
class DatasourceCredentialPayload(BaseModel):
name: str | None = Field(default=None, max_length=100)
credentials: dict[str, Any]
class DatasourceCredentialDeletePayload(BaseModel):
credential_id: str
class DatasourceCredentialUpdatePayload(BaseModel):
credential_id: str
name: str | None = Field(default=None, max_length=100)
credentials: dict[str, Any] | None = None
class DatasourceCustomClientPayload(BaseModel):
client_params: dict[str, Any] | None = None
enable_oauth_custom_client: bool | None = None
class DatasourceDefaultPayload(BaseModel):
id: str
class DatasourceUpdateNamePayload(BaseModel):
credential_id: str
name: str = Field(max_length=100)
register_schema_models(
console_ns,
DatasourceCredentialPayload,
DatasourceCredentialDeletePayload,
DatasourceCredentialUpdatePayload,
DatasourceCustomClientPayload,
DatasourceDefaultPayload,
DatasourceUpdateNamePayload,
)
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required
@ -121,16 +164,9 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_datasource = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource):
@console_ns.expect(parser_datasource)
@console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -138,7 +174,7 @@ class DatasourceAuth(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_datasource.parse_args()
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -146,8 +182,8 @@ class DatasourceAuth(Resource):
datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_tenant_id,
provider_id=datasource_provider_id,
credentials=args["credentials"],
name=args["name"],
credentials=payload.credentials,
name=payload.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
@ -169,14 +205,9 @@ class DatasourceAuth(Resource):
return {"result": datasources}, 200
parser_datasource_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource):
@console_ns.expect(parser_datasource_delete)
@console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -188,28 +219,20 @@ class DatasourceAuthDeleteApi(Resource):
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
args = parser_datasource_delete.parse_args()
payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {})
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_tenant_id,
auth_id=args["credential_id"],
auth_id=payload.credential_id,
provider=provider_name,
plugin_id=plugin_id,
)
return {"result": "success"}, 200
parser_datasource_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource):
@console_ns.expect(parser_datasource_update)
@console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -218,16 +241,16 @@ class DatasourceAuthUpdateApi(Resource):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
args = parser_datasource_update.parse_args()
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_tenant_id,
auth_id=args["credential_id"],
auth_id=payload.credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
credentials=args.get("credentials", {}),
name=args.get("name", None),
credentials=payload.credentials or {},
name=payload.name,
)
return {"result": "success"}, 201
@ -258,16 +281,9 @@ class DatasourceHardCodeAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200
parser_datasource_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource):
@console_ns.expect(parser_datasource_custom)
@console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -275,14 +291,14 @@ class DatasourceAuthOauthCustomClient(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_datasource_custom.parse_args()
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}),
enabled=args.get("enable_oauth_custom_client", False),
client_params=payload.client_params or {},
enabled=payload.enable_oauth_custom_client or False,
)
return {"result": "success"}, 200
@ -301,12 +317,9 @@ class DatasourceAuthOauthCustomClient(Resource):
return {"result": "success"}, 200
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource):
@console_ns.expect(parser_default)
@console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -314,27 +327,20 @@ class DatasourceAuthDefaultApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_default.parse_args()
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
credential_id=args["id"],
credential_id=payload.id,
)
return {"result": "success"}, 200
parser_update_name = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource):
@console_ns.expect(parser_update_name)
@console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -342,13 +348,13 @@ class DatasourceUpdateProviderNameApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_update_name.parse_args()
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
name=args["name"],
credential_id=args["credential_id"],
name=payload.name,
credential_id=payload.credential_id,
)
return {"result": "success"}, 200

View File

@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
@console_ns.expect(console_ns.models[Parser.__name__])
@setup_required
@login_required
@account_initialization_required

View File

@ -1,9 +1,11 @@
import logging
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
@ -20,18 +22,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description: str) -> str:
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/rag/pipeline/templates")
class PipelineTemplateListApi(Resource):
@setup_required
@ -59,6 +49,15 @@ class PipelineTemplateDetailApi(Resource):
return pipeline_template, 200
class Payload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field(default="", max_length=400)
icon_info: dict[str, object] | None = None
register_schema_models(console_ns, Payload)
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
class CustomizedPipelineTemplateApi(Resource):
@setup_required
@ -66,31 +65,8 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str):
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
payload = Payload.model_validate(console_ns.payload or {})
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200
@ -119,36 +95,14 @@ class CustomizedPipelineTemplateApi(Resource):
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
class PublishCustomizedPipelineTemplateApi(Resource):
@console_ns.expect(console_ns.models[Payload.__name__])
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str):
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
)
args = parser.parse_args()
payload = Payload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
return {"result": "success"}

View File

@ -1,8 +1,10 @@
from flask_restx import Resource, marshal, reqparse
from flask_restx import Resource, marshal
from pydantic import BaseModel
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
import services
from controllers.common.schema import register_schema_model
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
@ -19,22 +21,22 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineDatasetImportPayload(BaseModel):
yaml_content: str
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
@console_ns.route("/rag/pipeline/dataset")
class CreateRagPipelineDatasetApi(Resource):
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser().add_argument(
"yaml_content",
type=str,
nullable=False,
required=True,
help="yaml_content is required.",
)
args = parser.parse_args()
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@ -49,7 +51,7 @@ class CreateRagPipelineDatasetApi(Resource):
),
permission=DatasetPermissionEnum.ONLY_ME,
partial_member_list=None,
yaml_content=args["yaml_content"],
yaml_content=payload.yaml_content,
)
try:
with Session(db.engine) as session:

View File

@ -1,11 +1,13 @@
import logging
from typing import NoReturn
from typing import Any, NoReturn
from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
DraftWorkflowNotExist,
@ -33,19 +35,21 @@ logger = logging.getLogger(__name__)
def _create_pagination_parser():
parser = (
reqparse.RequestParser()
.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
return parser
class PaginationQuery(BaseModel):
page: int = Field(default=1, ge=1, le=100_000)
limit: int = Field(default=20, ge=1, le=100)
register_schema_models(console_ns, PaginationQuery)
return PaginationQuery
class WorkflowDraftVariablePatchPayload(BaseModel):
name: str | None = None
value: Any | None = None
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
@ -93,8 +97,8 @@ class RagPipelineVariableCollectionApi(Resource):
"""
Get draft workflow
"""
parser = _create_pagination_parser()
args = parser.parse_args()
pagination = _create_pagination_parser()
query = pagination.model_validate(request.args.to_dict())
# fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService()
@ -109,8 +113,8 @@ class RagPipelineVariableCollectionApi(Resource):
)
workflow_vars = draft_var_srv.list_variables_without_values(
app_id=pipeline.id,
page=args.page,
limit=args.limit,
page=query.page,
limit=query.limit,
)
return workflow_vars
@ -186,6 +190,7 @@ class RagPipelineVariableApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types:
#
@ -208,16 +213,11 @@ class RagPipelineVariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
parser = (
reqparse.RequestParser()
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
args = parser.parse_args(strict=True)
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:

View File

@ -1,6 +1,9 @@
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from flask import request
from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
@ -16,6 +19,25 @@ from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineImportPayload(BaseModel):
mode: str
yaml_content: str | None = None
yaml_url: str | None = None
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
pipeline_id: str | None = None
class IncludeSecretQuery(BaseModel):
include_secret: str = Field(default="false")
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource):
@setup_required
@ -23,23 +45,11 @@ class RagPipelineImportApi(Resource):
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("pipeline_id", type=str, location="json")
)
args = parser.parse_args()
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Create service with session
with Session(db.engine) as session:
@ -48,11 +58,11 @@ class RagPipelineImportApi(Resource):
account = current_user
result = import_service.import_rag_pipeline(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
pipeline_id=args.get("pipeline_id"),
dataset_name=args.get("name"),
import_mode=payload.mode,
yaml_content=payload.yaml_content,
yaml_url=payload.yaml_url,
pipeline_id=payload.pipeline_id,
dataset_name=payload.name,
)
session.commit()
@ -114,13 +124,12 @@ class RagPipelineExportApi(Resource):
@edit_permission_required
def get(self, pipeline: Pipeline):
# Add include_secret params
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
args = parser.parse_args()
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with Session(db.engine) as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(
pipeline=pipeline, include_secret=args["include_secret"] == "true"
pipeline=pipeline, include_secret=query.include_secret == "true"
)
return {"data": result}, 200

View File

@ -1,14 +1,16 @@
import json
import logging
from typing import cast
from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore
from flask_restx.inputs import int_range # type: ignore
from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
ConversationCompletedError,
@ -36,7 +38,7 @@ from fields.workflow_run_fields import (
workflow_run_pagination_fields,
)
from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
@ -51,6 +53,91 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
logger = logging.getLogger(__name__)
class DraftWorkflowSyncPayload(BaseModel):
graph: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] | None = None
conversation_variables: list[dict[str, Any]] | None = None
rag_pipeline_variables: list[dict[str, Any]] | None = None
features: dict[str, Any] | None = None
class NodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class NodeRunRequiredPayload(BaseModel):
inputs: dict[str, Any]
class DatasourceNodeRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
credential_id: str | None = None
class DraftWorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
datasource_info_list: list[dict[str, Any]]
start_node_id: str
class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
is_preview: bool = False
response_mode: Literal["streaming", "blocking"] = "streaming"
original_document_id: str | None = None
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class NodeIdQuery(BaseModel):
node_id: str
class WorkflowRunQuery(BaseModel):
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class DatasourceVariablesPayload(BaseModel):
datasource_type: str
datasource_info: dict[str, Any]
start_node_id: str
start_node_title: str
register_schema_models(
console_ns,
DraftWorkflowSyncPayload,
NodeRunPayload,
NodeRunRequiredPayload,
DatasourceNodeRunPayload,
DraftWorkflowRunPayload,
PublishedWorkflowRunPayload,
DefaultBlockConfigQuery,
WorkflowListQuery,
WorkflowUpdatePayload,
NodeIdQuery,
WorkflowRunQuery,
DatasourceVariablesPayload,
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
class DraftRagPipelineApi(Resource):
@setup_required
@ -88,15 +175,7 @@ class DraftRagPipelineApi(Resource):
content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type:
parser = (
reqparse.RequestParser()
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=False, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
payload_dict = console_ns.payload or {}
elif "text/plain" in content_type:
try:
data = json.loads(request.data.decode("utf-8"))
@ -106,7 +185,7 @@ class DraftRagPipelineApi(Resource):
if not isinstance(data.get("graph"), dict):
raise ValueError("graph is not a dict")
args = {
payload_dict = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
@ -119,24 +198,26 @@ class DraftRagPipelineApi(Resource):
else:
abort(415)
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
try:
environment_variables_list = args.get("environment_variables") or []
environment_variables_list = payload.environment_variables or []
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = args.get("conversation_variables") or []
conversation_variables_list = payload.conversation_variables or []
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline,
graph=args["graph"],
unique_hash=args.get("hash"),
graph=payload.graph,
unique_hash=payload.hash,
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=args.get("rag_pipeline_variables") or [],
rag_pipeline_variables=payload.rag_pipeline_variables or [],
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
@ -148,12 +229,9 @@ class DraftRagPipelineApi(Resource):
}
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.expect(parser_run)
@console_ns.expect(console_ns.models[NodeRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -166,7 +244,8 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_run.parse_args()
payload = NodeRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = PipelineGenerateService.generate_single_iteration(
@ -187,7 +266,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource):
@console_ns.expect(parser_run)
@console_ns.expect(console_ns.models[NodeRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -200,7 +279,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_run.parse_args()
payload = NodeRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = PipelineGenerateService.generate_single_loop(
@ -219,18 +299,9 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
raise InternalServerError()
parser_draft_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource):
@console_ns.expect(parser_draft_run)
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -243,7 +314,8 @@ class DraftRagPipelineRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_draft_run.parse_args()
payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
try:
response = PipelineGenerateService.generate(
@ -259,21 +331,9 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
parser_published_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource):
@console_ns.expect(parser_published_run)
@console_ns.expect(console_ns.models[PublishedWorkflowRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -286,16 +346,16 @@ class PublishedRagPipelineRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_published_run.parse_args()
streaming = args["response_mode"] == "streaming"
payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = payload.response_mode == "streaming"
try:
response = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
streaming=streaming,
)
@ -387,17 +447,9 @@ class PublishedRagPipelineRunApi(Resource):
#
# return result
#
parser_rag_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.expect(parser_rag_run)
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -410,14 +462,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_rag_run.parse_args()
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
return helper.compact_generate_response(
@ -425,11 +470,11 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline,
node_id=node_id,
user_inputs=inputs,
user_inputs=payload.inputs,
account=current_user,
datasource_type=datasource_type,
datasource_type=payload.datasource_type,
is_published=False,
credential_id=args.get("credential_id"),
credential_id=payload.credential_id,
)
)
)
@ -437,7 +482,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource):
@console_ns.expect(parser_rag_run)
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
@setup_required
@login_required
@edit_permission_required
@ -450,14 +495,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_rag_run.parse_args()
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
return helper.compact_generate_response(
@ -465,24 +503,19 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline,
node_id=node_id,
user_inputs=inputs,
user_inputs=payload.inputs,
account=current_user,
datasource_type=datasource_type,
datasource_type=payload.datasource_type,
is_published=False,
credential_id=args.get("credential_id"),
credential_id=payload.credential_id,
)
)
)
parser_run_api = reqparse.RequestParser().add_argument(
"inputs", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource):
@console_ns.expect(parser_run_api)
@console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__])
@setup_required
@login_required
@edit_permission_required
@ -496,11 +529,8 @@ class RagPipelineDraftNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_run_api.parse_args()
inputs = args.get("inputs")
if inputs == None:
raise ValueError("missing inputs")
payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {})
inputs = payload.inputs
rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
@ -602,12 +632,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
return rag_pipeline_service.get_default_block_configs()
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource):
@console_ns.expect(parser_default)
@setup_required
@login_required
@account_initialization_required
@ -617,14 +643,12 @@ class DefaultRagPipelineBlockConfigApi(Resource):
"""
Get default block config
"""
args = parser_default.parse_args()
q = args.get("q")
query = DefaultBlockConfigQuery.model_validate(request.args.to_dict())
filters = None
if q:
if query.q:
try:
filters = json.loads(args.get("q", ""))
filters = json.loads(query.q)
except json.JSONDecodeError:
raise ValueError("Invalid filters")
@ -633,18 +657,8 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
parser_wf = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource):
@console_ns.expect(parser_wf)
@setup_required
@login_required
@account_initialization_required
@ -657,16 +671,16 @@ class PublishedAllRagPipelineApi(Resource):
"""
current_user, _ = current_account_with_tenant()
args = parser_wf.parse_args()
page = args["page"]
limit = args["limit"]
user_id = args.get("user_id")
named_only = args.get("named_only", False)
query = WorkflowListQuery.model_validate(request.args.to_dict())
page = query.page
limit = query.limit
user_id = query.user_id
named_only = query.named_only
if user_id:
if user_id != current_user.id:
raise Forbidden()
user_id = cast(str, user_id)
rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session:
@ -687,16 +701,8 @@ class PublishedAllRagPipelineApi(Resource):
}
parser_wf_id = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource):
@console_ns.expect(parser_wf_id)
@setup_required
@login_required
@account_initialization_required
@ -710,20 +716,8 @@ class RagPipelineByIdApi(Resource):
# Check permission
current_user, _ = current_account_with_tenant()
args = parser_wf_id.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
# Prepare update data
update_data = {}
if args.get("marked_name") is not None:
update_data["marked_name"] = args["marked_name"]
if args.get("marked_comment") is not None:
update_data["marked_comment"] = args["marked_comment"]
payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
if not update_data:
return {"message": "No valid fields to update"}, 400
@ -749,12 +743,8 @@ class RagPipelineByIdApi(Resource):
return workflow
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -764,10 +754,8 @@ class PublishedRagPipelineSecondStepApi(Resource):
"""
Get second step parameters of rag pipeline
"""
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = query.node_id
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
return {
@ -777,7 +765,6 @@ class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -787,10 +774,8 @@ class PublishedRagPipelineFirstStepApi(Resource):
"""
Get first step parameters of rag pipeline
"""
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = query.node_id
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
return {
@ -800,7 +785,6 @@ class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -810,10 +794,8 @@ class DraftRagPipelineFirstStepApi(Resource):
"""
Get first step parameters of rag pipeline
"""
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = query.node_id
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
return {
@ -823,7 +805,6 @@ class DraftRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -833,10 +814,8 @@ class DraftRagPipelineSecondStepApi(Resource):
"""
Get second step parameters of rag pipeline
"""
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = query.node_id
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
@ -845,16 +824,8 @@ class DraftRagPipelineSecondStepApi(Resource):
}
parser_wf_run = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource):
@console_ns.expect(parser_wf_run)
@setup_required
@login_required
@account_initialization_required
@ -864,7 +835,16 @@ class RagPipelineWorkflowRunListApi(Resource):
"""
Get workflow run list
"""
args = parser_wf_run.parse_args()
query = WorkflowRunQuery.model_validate(
{
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", type=int, default=20),
}
)
args = {
"last_id": str(query.last_id) if query.last_id else None,
"limit": query.limit,
}
rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
@ -964,18 +944,9 @@ class RagPipelineTransformApi(Resource):
return result
parser_var = (
reqparse.RequestParser()
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info", type=dict, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("start_node_title", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource):
@console_ns.expect(parser_var)
@console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -987,7 +958,7 @@ class RagPipelineDatasourceVariableApi(Resource):
Set datasource variables
"""
current_user, _ = current_account_with_tenant()
args = parser_var.parse_args()
args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump()
rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.set_datasource_variables(

View File

@ -1,5 +1,10 @@
from flask_restx import Resource, fields, reqparse
from typing import Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.wraps import account_initialization_required, setup_required
@ -7,48 +12,35 @@ from libs.login import login_required
from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
class WebsiteCrawlPayload(BaseModel):
provider: Literal["firecrawl", "watercrawl", "jinareader"]
url: str
options: dict[str, object]
class WebsiteCrawlStatusQuery(BaseModel):
provider: Literal["firecrawl", "watercrawl", "jinareader"]
register_schema_models(console_ns, WebsiteCrawlPayload, WebsiteCrawlStatusQuery)
@console_ns.route("/website/crawl")
class WebsiteCrawlApi(Resource):
@console_ns.doc("crawl_website")
@console_ns.doc(description="Crawl website content")
@console_ns.expect(
console_ns.model(
"WebsiteCrawlRequest",
{
"provider": fields.String(
required=True,
description="Crawl provider (firecrawl/watercrawl/jinareader)",
enum=["firecrawl", "watercrawl", "jinareader"],
),
"url": fields.String(required=True, description="URL to crawl"),
"options": fields.Raw(required=True, description="Crawl options"),
},
)
)
@console_ns.expect(console_ns.models[WebsiteCrawlPayload.__name__])
@console_ns.response(200, "Website crawl initiated successfully")
@console_ns.response(400, "Invalid crawl parameters")
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument(
"provider",
type=str,
choices=["firecrawl", "watercrawl", "jinareader"],
required=True,
nullable=True,
location="json",
)
.add_argument("url", type=str, required=True, nullable=True, location="json")
.add_argument("options", type=dict, required=True, nullable=True, location="json")
)
args = parser.parse_args()
payload = WebsiteCrawlPayload.model_validate(console_ns.payload or {})
# Create typed request and validate
try:
api_request = WebsiteCrawlApiRequest.from_args(args)
api_request = WebsiteCrawlApiRequest.from_args(payload.model_dump())
except ValueError as e:
raise WebsiteCrawlError(str(e))
@ -65,6 +57,7 @@ class WebsiteCrawlStatusApi(Resource):
@console_ns.doc("get_crawl_status")
@console_ns.doc(description="Get website crawl status")
@console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
@console_ns.expect(console_ns.models[WebsiteCrawlStatusQuery.__name__])
@console_ns.response(200, "Crawl status retrieved successfully")
@console_ns.response(404, "Crawl job not found")
@console_ns.response(400, "Invalid provider")
@ -72,14 +65,11 @@ class WebsiteCrawlStatusApi(Resource):
@login_required
@account_initialization_required
def get(self, job_id: str):
parser = reqparse.RequestParser().add_argument(
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
args = parser.parse_args()
args = WebsiteCrawlStatusQuery.model_validate(request.args.to_dict())
# Create typed request and validate
try:
api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id)
api_request = WebsiteCrawlStatusApiRequest.from_args(args.model_dump(), job_id)
except ValueError as e:
raise WebsiteCrawlError(str(e))

View File

@ -1,9 +1,11 @@
import logging
from flask import request
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@ -31,6 +33,16 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(console_ns, TextToAudioPayload)
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/audio-to-text",
endpoint="installed_app_audio",
@ -76,23 +88,15 @@ class ChatAudioApi(InstalledAppResource):
endpoint="installed_app_text",
)
class ChatTextApi(InstalledAppResource):
@console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
def post(self, installed_app):
from flask_restx import reqparse
app_model = installed_app.app
try:
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, required=False, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
payload = TextToAudioPayload.model_validate(console_ns.payload or {})
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
message_id = payload.message_id
text = payload.text
voice = payload.voice
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
return response

View File

@ -1,9 +1,12 @@
import logging
from typing import Any, Literal
from uuid import UUID
from flask_restx import reqparse
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console.app.error import (
AppUnavailableError,
CompletionRequestError,
@ -25,7 +28,6 @@ from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode
@ -38,28 +40,56 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class CompletionMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = Field(default="explore_app")
class ChatMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str
files: list[dict[str, Any]] | None = None
conversation_id: str | None = None
parent_message_id: str | None = None
retriever_from: str = Field(default="explore_app")
@field_validator("conversation_id", "parent_message_id", mode="before")
@classmethod
def normalize_uuid(cls, value: str | UUID | None) -> str | None:
"""
Accept blank IDs and validate UUID format when provided.
"""
if not value:
return None
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("must be a valid UUID") from exc
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
# define completion api for user
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/completion-messages",
endpoint="installed_app_completion",
)
class CompletionApi(InstalledAppResource):
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args()
payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = args["response_mode"] == "streaming"
streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
installed_app.last_used_at = naive_utc_now()
@ -123,22 +153,15 @@ class CompletionStopApi(InstalledAppResource):
endpoint="installed_app_chat_completion",
)
class ChatApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
def post(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args()
payload = ChatMessagePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
args["auto_generate_name"] = False

View File

@ -1,14 +1,18 @@
from flask_restx import marshal_with, reqparse
from flask_restx.inputs import int_range
from typing import Any
from uuid import UUID
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode
@ -19,29 +23,51 @@ from services.web_conversation_service import WebConversationService
from .. import console_ns
class ConversationListQuery(BaseModel):
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
pinned: bool | None = None
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations",
endpoint="installed_app_conversations",
)
class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
)
args = parser.parse_args()
pinned = None
if "pinned" in args and args["pinned"] is not None:
pinned = args["pinned"] == "true"
raw_args: dict[str, Any] = {
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", default=20, type=int),
"pinned": request.args.get("pinned"),
}
if raw_args["last_id"] is None:
raw_args["last_id"] = None
pinned_value = raw_args["pinned"]
if isinstance(pinned_value, str):
raw_args["pinned"] = pinned_value == "true"
args = ConversationListQuery.model_validate(raw_args)
try:
if not isinstance(current_user, Account):
@ -51,10 +77,10 @@ class ConversationListApi(InstalledAppResource):
session=session,
app_model=app_model,
user=current_user,
last_id=args["last_id"],
limit=args["limit"],
last_id=str(args.last_id) if args.last_id else None,
limit=args.limit,
invoke_from=InvokeFrom.EXPLORE,
pinned=pinned,
pinned=args.pinned,
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@ -88,6 +114,7 @@ class ConversationApi(InstalledAppResource):
)
class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields)
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
@ -96,18 +123,13 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id)
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, location="json")
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
)
args = parser.parse_args()
payload = ConversationRenamePayload.model_validate(console_ns.payload or {})
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return ConversationService.rename(
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
app_model, conversation_id, current_user, payload.name, payload.auto_generate
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -1,9 +1,13 @@
import logging
from typing import Literal
from uuid import UUID
from flask_restx import marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
CompletionRequestError,
@ -22,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import uuid_value
from libs.login import current_account_with_tenant
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -40,12 +43,31 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUID
first_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
class MoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"]
register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, MoreLikeThisQuery)
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/messages",
endpoint="installed_app_messages",
)
class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
@ -53,18 +75,15 @@ class MessageListApi(InstalledAppResource):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
args = MessageListQuery.model_validate(request.args.to_dict())
try:
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
app_model,
current_user,
str(args.conversation_id),
str(args.first_id) if args.first_id else None,
args.limit,
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -77,26 +96,22 @@ class MessageListApi(InstalledAppResource):
endpoint="installed_app_message_feedback",
)
class MessageFeedbackApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
message_id = str(message_id)
parser = (
reqparse.RequestParser()
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
.add_argument("content", type=str, location="json")
)
args = parser.parse_args()
payload = MessageFeedbackPayload.model_validate(console_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=current_user,
rating=args.get("rating"),
content=args.get("content"),
rating=payload.rating,
content=payload.content,
)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@ -109,6 +124,7 @@ class MessageFeedbackApi(InstalledAppResource):
endpoint="installed_app_more_like_this",
)
class MessageMoreLikeThisApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
@ -117,12 +133,9 @@ class MessageMoreLikeThisApi(InstalledAppResource):
message_id = str(message_id)
parser = reqparse.RequestParser().add_argument(
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
)
args = parser.parse_args()
args = MoreLikeThisQuery.model_validate(request.args.to_dict())
streaming = args["response_mode"] == "streaming"
streaming = args.response_mode == "streaming"
try:
response = AppGenerateService.generate_more_like_this(

View File

@ -1,4 +1,6 @@
from flask_restx import Resource, fields, marshal_with, reqparse
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants.languages import languages
from controllers.console import console_ns
@ -35,20 +37,26 @@ recommended_app_list_fields = {
}
parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
class RecommendedAppsQuery(BaseModel):
language: str | None = Field(default=None)
console_ns.schema_model(
RecommendedAppsQuery.__name__,
RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource):
@console_ns.expect(parser_apps)
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_fields)
def get(self):
# language args
args = parser_apps.parse_args()
language = args.get("language")
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
language = args.language
if language and language in languages:
language_prefix = language
elif current_user and current_user.interface_language:

View File

@ -1,16 +1,33 @@
from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from uuid import UUID
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from libs.helper import TimestampField
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUID
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
feedback_fields = {"rating": fields.String}
message_fields = {
@ -33,32 +50,33 @@ class SavedMessageListApi(InstalledAppResource):
}
@marshal_with(saved_message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = SavedMessageListQuery.model_validate(request.args.to_dict())
return SavedMessageService.pagination_by_last_id(
app_model,
current_user,
str(args.last_id) if args.last_id else None,
args.limit,
)
args = parser.parse_args()
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
def post(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
args = parser.parse_args()
payload = SavedMessageCreatePayload.model_validate(console_ns.payload or {})
try:
SavedMessageService.save(app_model, current_user, args["message_id"])
SavedMessageService.save(app_model, current_user, str(payload.message_id))
except MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@ -1,8 +1,10 @@
import logging
from typing import Any
from flask_restx import reqparse
from pydantic import BaseModel
from werkzeug.exceptions import InternalServerError
from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
@ -32,8 +34,17 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
register_schema_model(console_ns, WorkflowRunPayload)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
def post(self, installed_app: InstalledApp):
"""
Run workflow
@ -46,12 +57,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
payload = WorkflowRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True

View File

@ -45,6 +45,9 @@ class FileApi(Resource):
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
}, 200
@setup_required

View File

@ -1,13 +1,13 @@
import os
from flask import session
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from extensions.ext_database import db
from libs.helper import StrLen
from models.model import DifySetup
from services.account_service import TenantService
@ -15,6 +15,18 @@ from . import console_ns
from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InitValidatePayload(BaseModel):
password: str = Field(..., max_length=30)
console_ns.schema_model(
InitValidatePayload.__name__,
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/init")
class InitValidateAPI(Resource):
@ -37,12 +49,7 @@ class InitValidateAPI(Resource):
@console_ns.doc("validate_init_password")
@console_ns.doc(description="Validate initialization password for self-hosted edition")
@console_ns.expect(
console_ns.model(
"InitValidateRequest",
{"password": fields.String(required=True, description="Initialization password", max_length=30)},
)
)
@console_ns.expect(console_ns.models[InitValidatePayload.__name__])
@console_ns.response(
201,
"Success",
@ -57,8 +64,8 @@ class InitValidateAPI(Resource):
if tenant_count > 0:
raise AlreadySetupError()
parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json")
input_password = parser.parse_args()["password"]
payload = InitValidatePayload.model_validate(console_ns.payload)
input_password = payload.password
if input_password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False

View File

@ -1,7 +1,8 @@
import urllib.parse
import httpx
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
import services
from controllers.common import helpers
@ -36,17 +37,23 @@ class RemoteFileInfoApi(Resource):
}
parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch")
console_ns.schema_model(
RemoteFileUploadPayload.__name__,
RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource):
@console_ns.expect(parser_upload)
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@marshal_with(file_fields_with_signed_url)
def post(self):
args = parser_upload.parse_args()
url = args["url"]
args = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = args.url
try:
resp = ssrf_proxy.head(url=url)

View File

@ -1,8 +1,9 @@
from flask import request
from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip
from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService
@ -12,6 +13,26 @@ from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SetupRequestPayload(BaseModel):
email: EmailStr = Field(..., description="Admin email address")
name: str = Field(..., max_length=30, description="Admin name (max 30 characters)")
password: str = Field(..., description="Admin password")
language: str | None = Field(default=None, description="Admin language")
@field_validator("password")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
console_ns.schema_model(
SetupRequestPayload.__name__,
SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/setup")
class SetupApi(Resource):
@ -42,17 +63,7 @@ class SetupApi(Resource):
@console_ns.doc("setup_system")
@console_ns.doc(description="Initialize system setup with admin account")
@console_ns.expect(
console_ns.model(
"SetupRequest",
{
"email": fields.String(required=True, description="Admin email address"),
"name": fields.String(required=True, description="Admin name (max 30 characters)"),
"password": fields.String(required=True, description="Admin password"),
"language": fields.String(required=False, description="Admin language"),
},
)
)
@console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
@console_ns.response(
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
)
@ -72,22 +83,15 @@ class SetupApi(Resource):
if not get_init_validate_status():
raise NotInitValidateError()
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("name", type=StrLen(30), required=True, location="json")
.add_argument("password", type=valid_password, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
args = SetupRequestPayload.model_validate(console_ns.payload)
# setup
RegisterService.setup(
email=args["email"],
name=args["name"],
password=args["password"],
email=args.email,
name=args.name,
password=args.password,
ip_address=extract_remote_ip(request),
language=args["language"],
language=args.language,
)
return {"result": "success"}, 201

View File

@ -2,8 +2,10 @@ import json
import logging
import httpx
from flask_restx import Resource, fields, reqparse
from flask import request
from flask_restx import Resource, fields
from packaging import version
from pydantic import BaseModel, Field
from configs import dify_config
@ -11,8 +13,14 @@ from . import console_ns
logger = logging.getLogger(__name__)
parser = reqparse.RequestParser().add_argument(
"current_version", type=str, required=True, location="args", help="Current application version"
class VersionQuery(BaseModel):
current_version: str = Field(..., description="Current application version")
console_ns.schema_model(
VersionQuery.__name__,
VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@ -20,7 +28,7 @@ parser = reqparse.RequestParser().add_argument(
class VersionApi(Resource):
@console_ns.doc("check_version_update")
@console_ns.doc(description="Check for application version updates")
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[VersionQuery.__name__])
@console_ns.response(
200,
"Success",
@ -37,7 +45,7 @@ class VersionApi(Resource):
)
def get(self):
"""Check for application version updates"""
args = parser.parse_args()
args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
check_update_url = dify_config.CHECK_UPDATE_URL
result = {
@ -57,16 +65,16 @@ class VersionApi(Resource):
try:
response = httpx.get(
check_update_url,
params={"current_version": args["current_version"]},
params={"current_version": args.current_version},
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
)
except Exception as error:
logger.warning("Check update version error: %s.", str(error))
result["version"] = args["current_version"]
result["version"] = args.current_version
return result
content = json.loads(response.content)
if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"):
if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
result["version"] = content["version"]
result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"]

View File

@ -37,7 +37,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, email, extract_remote_ip, timezone
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode
from services.account_service import AccountService
@ -111,14 +111,9 @@ class AccountDeletePayload(BaseModel):
class AccountDeletionFeedbackPayload(BaseModel):
email: str
email: EmailStr
feedback: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class EducationActivatePayload(BaseModel):
token: str
@ -133,45 +128,25 @@ class EducationAutocompleteQuery(BaseModel):
class ChangeEmailSendPayload(BaseModel):
email: str
email: EmailStr
language: str | None = None
phase: str | None = None
token: str | None = None
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailValidityPayload(BaseModel):
email: str
email: EmailStr
code: str
token: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailResetPayload(BaseModel):
new_email: str
new_email: EmailStr
token: str
@field_validator("new_email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class CheckEmailUniquePayload(BaseModel):
email: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
email: EmailStr
def reg(cls: type[BaseModel]):

View File

@ -230,7 +230,7 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 200
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@ -282,9 +282,10 @@ class ModelProviderModelCredentialApi(Resource):
tenant_id=tenant_id, provider_name=provider
)
else:
model_type = args.model_type
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
)
return jsonable_encoder(

View File

@ -46,8 +46,8 @@ class PluginDebuggingKeyApi(Resource):
class ParserList(BaseModel):
page: int = Field(default=1)
page_size: int = Field(default=256)
page: int = Field(default=1, ge=1, description="Page number")
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
reg(ParserList)
@ -106,8 +106,8 @@ class ParserPluginIdentifierQuery(BaseModel):
class ParserTasks(BaseModel):
page: int
page_size: int
page: int = Field(default=1, ge=1, description="Page number")
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
class ParserMarketplaceUpgrade(BaseModel):

View File

@ -22,7 +22,12 @@ from services.trigger.trigger_subscription_builder_service import TriggerSubscri
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
from .. import console_ns
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from ..wraps import (
account_initialization_required,
edit_permission_required,
is_admin_or_owner_required,
setup_required,
)
logger = logging.getLogger(__name__)
@ -72,7 +77,7 @@ class TriggerProviderInfoApi(Resource):
class TriggerSubscriptionListApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@edit_permission_required
@account_initialization_required
def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider"""
@ -104,7 +109,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
@console_ns.expect(parser)
@setup_required
@login_required
@is_admin_or_owner_required
@edit_permission_required
@account_initialization_required
def post(self, provider):
"""Add a new subscription instance for a trigger provider"""
@ -133,6 +138,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
class TriggerSubscriptionBuilderGetApi(Resource):
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get a subscription instance for a trigger provider"""
@ -155,7 +161,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
@console_ns.expect(parser_api)
@setup_required
@login_required
@is_admin_or_owner_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
@ -200,6 +206,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
@console_ns.expect(parser_update_api)
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Update a subscription instance for a trigger provider"""
@ -233,6 +240,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
class TriggerSubscriptionBuilderLogsApi(Resource):
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get the request logs for a subscription instance for a trigger provider"""
@ -255,7 +263,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
@console_ns.expect(parser_update_api)
@setup_required
@login_required
@is_admin_or_owner_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider"""

View File

@ -331,3 +331,91 @@ def is_admin_or_owner_required(f: Callable[P, R]):
return f(*args, **kwargs)
return decorated_function
def annotation_import_rate_limit(view: Callable[P, R]):
"""
Rate limiting decorator for annotation import operations.
Implements sliding window rate limiting with two tiers:
- Short-term: Configurable requests per minute (default: 5)
- Long-term: Configurable requests per hour (default: 20)
Uses Redis ZSET for distributed rate limiting across multiple instances.
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
current_time = int(time.time() * 1000)
# Check per-minute rate limit
minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min"
redis_client.zadd(minute_key, {current_time: current_time})
redis_client.zremrangebyscore(minute_key, 0, current_time - 60000)
minute_count = redis_client.zcard(minute_key)
redis_client.expire(minute_key, 120) # 2 minutes TTL
if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:
abort(
429,
f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE} "
f"requests per minute allowed. Please try again later.",
)
# Check per-hour rate limit
hour_key = f"annotation_import_rate_limit:{current_tenant_id}:1hour"
redis_client.zadd(hour_key, {current_time: current_time})
redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000)
hour_count = redis_client.zcard(hour_key)
redis_client.expire(hour_key, 7200) # 2 hours TTL
if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:
abort(
429,
f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR} "
f"requests per hour allowed. Please try again later.",
)
return view(*args, **kwargs)
return decorated
def annotation_import_concurrency_limit(view: Callable[P, R]):
"""
Concurrency control decorator for annotation import operations.
Limits the number of concurrent import tasks per tenant to prevent
resource exhaustion and ensure fair resource allocation.
Uses Redis ZSET to track active import jobs with automatic cleanup
of stale entries (jobs older than 2 minutes).
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
current_time = int(time.time() * 1000)
active_jobs_key = f"annotation_import_active:{current_tenant_id}"
# Clean up stale entries (jobs that should have completed or timed out)
stale_threshold = current_time - 120000 # 2 minutes ago
redis_client.zremrangebyscore(active_jobs_key, 0, stale_threshold)
# Check current active job count
active_count = redis_client.zcard(active_jobs_key)
if active_count >= dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT:
abort(
429,
f"Too many concurrent import tasks. Maximum {dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT} "
f"concurrent imports allowed per workspace. Please wait for existing imports to complete.",
)
# Allow the request to proceed
# The actual job registration will happen in the service layer
return view(*args, **kwargs)
return decorated

View File

@ -1,7 +1,8 @@
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
import services
@ -11,6 +12,26 @@ from extensions.ext_database import db
from services.account_service import TenantService
from services.file_service import FileService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class FileSignatureQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp used in the signature")
nonce: str = Field(..., description="Random string for signature")
sign: str = Field(..., description="HMAC signature")
class FilePreviewQuery(FileSignatureQuery):
as_attachment: bool = Field(default=False, description="Whether to download as attachment")
files_ns.schema_model(
FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
files_ns.schema_model(
FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@files_ns.route("/<uuid:file_id>/image-preview")
class ImagePreviewApi(Resource):
@ -36,12 +57,10 @@ class ImagePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
sign = request.args.get("sign")
if not timestamp or not nonce or not sign:
return {"content": "Invalid request."}, 400
args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
timestamp = args.timestamp
nonce = args.nonce
sign = args.sign
try:
generator, mimetype = FileService(db.engine).get_image_preview(
@ -80,25 +99,14 @@ class FilePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
parser = (
reqparse.RequestParser()
.add_argument("timestamp", type=str, required=True, location="args")
.add_argument("nonce", type=str, required=True, location="args")
.add_argument("sign", type=str, required=True, location="args")
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args()
if not args["timestamp"] or not args["nonce"] or not args["sign"]:
return {"content": "Invalid request."}, 400
args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
file_id=file_id,
timestamp=args["timestamp"],
nonce=args["nonce"],
sign=args["sign"],
timestamp=args.timestamp,
nonce=args.nonce,
sign=args.sign,
)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
@ -125,7 +133,7 @@ class FilePreviewApi(Resource):
response.headers["Accept-Ranges"] = "bytes"
if upload_file.size > 0:
response.headers["Content-Length"] = str(upload_file.size)
if args["as_attachment"]:
if args.as_attachment:
encoded_filename = quote(upload_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/octet-stream"

View File

@ -1,7 +1,8 @@
from urllib.parse import quote
from flask import Response
from flask_restx import Resource, reqparse
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from controllers.common.errors import UnsupportedFileTypeError
@ -10,6 +11,20 @@ from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db as global_db
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ToolFileQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp")
nonce: str = Field(..., description="Random nonce")
sign: str = Field(..., description="HMAC signature")
as_attachment: bool = Field(default=False, description="Download as attachment")
files_ns.schema_model(
ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@files_ns.route("/tools/<uuid:file_id>.<string:extension>")
class ToolFileApi(Resource):
@ -36,18 +51,8 @@ class ToolFileApi(Resource):
def get(self, file_id, extension):
file_id = str(file_id)
parser = (
reqparse.RequestParser()
.add_argument("timestamp", type=str, required=True, location="args")
.add_argument("nonce", type=str, required=True, location="args")
.add_argument("sign", type=str, required=True, location="args")
.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
)
args = parser.parse_args()
if not verify_tool_file_signature(
file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"]
):
args = ToolFileQuery.model_validate(request.args.to_dict())
if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign):
raise Forbidden("Invalid request.")
try:
@ -69,7 +74,7 @@ class ToolFileApi(Resource):
)
if tool_file.size > 0:
response.headers["Content-Length"] = str(tool_file.size)
if args["as_attachment"]:
if args.as_attachment:
encoded_filename = quote(tool_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"

View File

@ -1,40 +1,45 @@
from mimetypes import guess_extension
from flask_restx import Resource, reqparse
from flask import request
from flask_restx import Resource
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
import services
from controllers.common.errors import (
FileTooLargeError,
UnsupportedFileTypeError,
)
from controllers.console.wraps import setup_required
from controllers.files import files_ns
from controllers.inner_api.plugin.wraps import get_user
from core.file.helpers import verify_plugin_file_signature
from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import build_file_model
# Define parser for both documentation and validation
upload_parser = (
reqparse.RequestParser()
.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload")
.add_argument(
"timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification"
)
.add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification")
.add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation")
.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier")
.add_argument("user_id", type=str, required=False, location="args", help="User identifier")
from ..common.errors import (
FileTooLargeError,
UnsupportedFileTypeError,
)
from ..console.wraps import setup_required
from ..files import files_ns
from ..inner_api.plugin.wraps import get_user
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class PluginUploadQuery(BaseModel):
timestamp: str = Field(..., description="Unix timestamp for signature verification")
nonce: str = Field(..., description="Random nonce for signature verification")
sign: str = Field(..., description="HMAC signature")
tenant_id: str = Field(..., description="Tenant identifier")
user_id: str | None = Field(default=None, description="User identifier")
files_ns.schema_model(
PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@files_ns.route("/upload/for-plugin")
class PluginUploadFileApi(Resource):
@setup_required
@files_ns.expect(upload_parser)
@files_ns.expect(files_ns.models[PluginUploadQuery.__name__])
@files_ns.doc("upload_plugin_file")
@files_ns.doc(description="Upload a file for plugin usage with signature verification")
@files_ns.doc(
@ -62,15 +67,17 @@ class PluginUploadFileApi(Resource):
FileTooLargeError: File exceeds size limit
UnsupportedFileTypeError: File type not supported
"""
# Parse and validate all arguments
args = upload_parser.parse_args()
args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
file: FileStorage = args["file"]
timestamp: str = args["timestamp"]
nonce: str = args["nonce"]
sign: str = args["sign"]
tenant_id: str = args["tenant_id"]
user_id: str | None = args.get("user_id")
file: FileStorage | None = request.files.get("file")
if file is None:
raise Forbidden("File is required.")
timestamp = args.timestamp
nonce = args.nonce
sign = args.sign
tenant_id = args.tenant_id
user_id = args.user_id
user = get_user(tenant_id, user_id)
filename: str | None = file.filename

View File

@ -1,29 +1,38 @@
from flask_restx import Resource, reqparse
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required
from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
from tasks.mail_inner_task import send_inner_email_task
_mail_parser = (
reqparse.RequestParser()
.add_argument("to", type=str, action="append", required=True)
.add_argument("subject", type=str, required=True)
.add_argument("body", type=str, required=True)
.add_argument("substitutions", type=dict, required=False)
)
class InnerMailPayload(BaseModel):
to: list[str] = Field(description="Recipient email addresses", min_length=1)
subject: str
body: str
substitutions: dict[str, Any] | None = None
register_schema_model(inner_api_ns, InnerMailPayload)
class BaseMail(Resource):
"""Shared logic for sending an inner email."""
@inner_api_ns.doc("send_inner_mail")
@inner_api_ns.doc(description="Send internal email")
@inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
def post(self):
args = _mail_parser.parse_args()
send_inner_email_task.delay( # type: ignore
to=args["to"],
subject=args["subject"],
body=args["body"],
substitutions=args["substitutions"],
args = InnerMailPayload.model_validate(inner_api_ns.payload or {})
send_inner_email_task.delay(
to=args.to,
subject=args.subject,
body=args.body,
substitutions=args.substitutions, # type: ignore
)
return {"message": "success"}, 200
@ -34,7 +43,7 @@ class EnterpriseMail(BaseMail):
@inner_api_ns.doc("send_enterprise_mail")
@inner_api_ns.doc(description="Send internal email for enterprise features")
@inner_api_ns.expect(_mail_parser)
@inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
@inner_api_ns.doc(
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
)
@ -56,7 +65,7 @@ class BillingMail(BaseMail):
@inner_api_ns.doc("send_billing_mail")
@inner_api_ns.doc(description="Send internal email for billing notifications")
@inner_api_ns.expect(_mail_parser)
@inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
@inner_api_ns.doc(
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
)

View File

@ -1,10 +1,9 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar, cast
from typing import ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
from flask_restx import reqparse
from pydantic import BaseModel
from sqlalchemy.orm import Session
@ -17,6 +16,11 @@ P = ParamSpec("P")
R = TypeVar("R")
class TenantUserPayload(BaseModel):
tenant_id: str
user_id: str
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
"""
Get current user
@ -67,58 +71,45 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
return user_model
def get_user_tenant(view: Callable[P, R] | None = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
# fetch json body
parser = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="json")
.add_argument("user_id", type=str, required=True, location="json")
)
def get_user_tenant(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
p = parser.parse_args()
user_id = payload.user_id
tenant_id = payload.tenant_id
user_id = cast(str, p.get("user_id"))
tenant_id = cast(str, p.get("tenant_id"))
if not tenant_id:
raise ValueError("tenant_id is required")
if not tenant_id:
raise ValueError("tenant_id is required")
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
tenant_model = (
db.session.query(Tenant)
.where(
Tenant.id == tenant_id,
)
.first()
try:
tenant_model = (
db.session.query(Tenant)
.where(
Tenant.id == tenant_id,
)
except Exception:
raise ValueError("tenant not found")
.first()
)
except Exception:
raise ValueError("tenant not found")
if not tenant_model:
raise ValueError("tenant not found")
if not tenant_model:
raise ValueError("tenant not found")
kwargs["tenant_model"] = tenant_model
kwargs["tenant_model"] = tenant_model
user = get_user(tenant_id, user_id)
kwargs["user_model"] = user
user = get_user(tenant_id, user_id)
kwargs["user_model"] = user
current_app.login_manager._update_request_context_with_user(user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
current_app.login_manager._update_request_context_with_user(user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
return view_func(*args, **kwargs)
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
return decorated_view
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):

View File

@ -1,7 +1,9 @@
import json
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from controllers.common.schema import register_schema_models
from controllers.console.wraps import setup_required
from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import enterprise_inner_api_only
@ -11,12 +13,25 @@ from models import Account
from services.account_service import TenantService
class WorkspaceCreatePayload(BaseModel):
name: str
owner_email: str
class WorkspaceOwnerlessPayload(BaseModel):
name: str
register_schema_models(inner_api_ns, WorkspaceCreatePayload, WorkspaceOwnerlessPayload)
@inner_api_ns.route("/enterprise/workspace")
class EnterpriseWorkspace(Resource):
@setup_required
@enterprise_inner_api_only
@inner_api_ns.doc("create_enterprise_workspace")
@inner_api_ns.doc(description="Create a new enterprise workspace with owner assignment")
@inner_api_ns.expect(inner_api_ns.models[WorkspaceCreatePayload.__name__])
@inner_api_ns.doc(
responses={
200: "Workspace created successfully",
@ -25,18 +40,13 @@ class EnterpriseWorkspace(Resource):
}
)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("owner_email", type=str, required=True, location="json")
)
args = parser.parse_args()
args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {})
account = db.session.query(Account).filter_by(email=args["owner_email"]).first()
account = db.session.query(Account).filter_by(email=args.owner_email).first()
if account is None:
return {"message": "owner account not found."}, 404
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
TenantService.create_tenant_member(tenant, account, role="owner")
tenant_was_created.send(tenant)
@ -62,6 +72,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
@enterprise_inner_api_only
@inner_api_ns.doc("create_enterprise_workspace_ownerless")
@inner_api_ns.doc(description="Create a new enterprise workspace without initial owner assignment")
@inner_api_ns.expect(inner_api_ns.models[WorkspaceOwnerlessPayload.__name__])
@inner_api_ns.doc(
responses={
200: "Workspace created successfully",
@ -70,10 +81,9 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
}
)
def post(self):
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {})
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
tenant_was_created.send(tenant)

View File

@ -1,10 +1,11 @@
from typing import Union
from typing import Any, Union
from flask import Response
from flask_restx import Resource, reqparse
from pydantic import ValidationError
from flask_restx import Resource
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
@ -24,27 +25,19 @@ class MCPRequestError(Exception):
super().__init__(message)
def int_or_str(value):
"""Validate that a value is either an integer or string."""
if isinstance(value, (int, str)):
return value
else:
return None
class MCPRequestPayload(BaseModel):
jsonrpc: str = Field(description="JSON-RPC version (should be '2.0')")
method: str = Field(description="The method to invoke")
params: dict[str, Any] | None = Field(default=None, description="Parameters for the method")
id: int | str | None = Field(default=None, description="Request ID for tracking responses")
# Define parser for both documentation and validation
mcp_request_parser = (
reqparse.RequestParser()
.add_argument("jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')")
.add_argument("method", type=str, required=True, location="json", help="The method to invoke")
.add_argument("params", type=dict, required=False, location="json", help="Parameters for the method")
.add_argument("id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses")
)
register_schema_model(mcp_ns, MCPRequestPayload)
@mcp_ns.route("/server/<string:server_code>/mcp")
class MCPAppApi(Resource):
@mcp_ns.expect(mcp_request_parser)
@mcp_ns.expect(mcp_ns.models[MCPRequestPayload.__name__])
@mcp_ns.doc("handle_mcp_request")
@mcp_ns.doc(description="Handle Model Context Protocol (MCP) requests for a specific server")
@mcp_ns.doc(params={"server_code": "Unique identifier for the MCP server"})
@ -70,9 +63,9 @@ class MCPAppApi(Resource):
Raises:
ValidationError: Invalid request format or parameters
"""
args = mcp_request_parser.parse_args()
request_id: Union[int, str] | None = args.get("id")
mcp_request = self._parse_mcp_request(args)
args = MCPRequestPayload.model_validate(mcp_ns.payload or {})
request_id: Union[int, str] | None = args.id
mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))
with Session(db.engine, expire_on_commit=False) as session:
# Get MCP server and app

View File

@ -1,9 +1,11 @@
from typing import Literal
from flask import request
from flask_restx import Api, Namespace, Resource, fields, reqparse
from flask_restx import Api, Namespace, Resource, fields
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
@ -12,26 +14,24 @@ from fields.annotation_fields import annotation_fields, build_annotation_model
from models.model import App
from services.annotation_service import AppAnnotationService
# Define parsers for annotation API
annotation_create_parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json", help="Annotation question")
.add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
)
annotation_reply_action_parser = (
reqparse.RequestParser()
.add_argument(
"score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
)
.add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name")
.add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name")
)
class AnnotationCreatePayload(BaseModel):
question: str = Field(description="Annotation question")
answer: str = Field(description="Annotation answer")
class AnnotationReplyActionPayload(BaseModel):
score_threshold: float = Field(description="Score threshold for annotation matching")
embedding_provider_name: str = Field(description="Embedding provider name")
embedding_model_name: str = Field(description="Embedding model name")
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)
@service_api_ns.route("/apps/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource):
@service_api_ns.expect(annotation_reply_action_parser)
@service_api_ns.expect(service_api_ns.models[AnnotationReplyActionPayload.__name__])
@service_api_ns.doc("annotation_reply_action")
@service_api_ns.doc(description="Enable or disable annotation reply feature")
@service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"})
@ -44,7 +44,7 @@ class AnnotationReplyActionApi(Resource):
@validate_app_token
def post(self, app_model: App, action: Literal["enable", "disable"]):
"""Enable or disable annotation reply feature."""
args = annotation_reply_action_parser.parse_args()
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
@ -126,7 +126,7 @@ class AnnotationListApi(Resource):
"page": page,
}
@service_api_ns.expect(annotation_create_parser)
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
@service_api_ns.doc("create_annotation")
@service_api_ns.doc(description="Create a new annotation")
@service_api_ns.doc(
@ -139,14 +139,14 @@ class AnnotationListApi(Resource):
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
def post(self, app_model: App):
"""Create a new annotation."""
args = annotation_create_parser.parse_args()
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
return annotation, 201
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource):
@service_api_ns.expect(annotation_create_parser)
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
@service_api_ns.doc("update_annotation")
@service_api_ns.doc(description="Update an existing annotation")
@service_api_ns.doc(params={"annotation_id": "Annotation ID"})
@ -163,7 +163,7 @@ class AnnotationUpdateDeleteApi(Resource):
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
args = annotation_create_parser.parse_args()
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation

View File

@ -1,10 +1,12 @@
import logging
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
AppUnavailableError,
@ -84,19 +86,19 @@ class AudioApi(Resource):
raise InternalServerError()
# Define parser for text-to-audio API
text_to_audio_parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, required=False, location="json", help="Message ID")
.add_argument("voice", type=str, location="json", help="Voice to use for TTS")
.add_argument("text", type=str, location="json", help="Text to convert to audio")
.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
)
class TextToAudioPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
voice: str | None = Field(default=None, description="Voice to use for TTS")
text: str | None = Field(default=None, description="Text to convert to audio")
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(service_api_ns, TextToAudioPayload)
@service_api_ns.route("/text-to-audio")
class TextApi(Resource):
@service_api_ns.expect(text_to_audio_parser)
@service_api_ns.expect(service_api_ns.models[TextToAudioPayload.__name__])
@service_api_ns.doc("text_to_audio")
@service_api_ns.doc(description="Convert text to audio using text-to-speech")
@service_api_ns.doc(
@ -114,11 +116,11 @@ class TextApi(Resource):
Converts the provided text to audio using the specified voice.
"""
try:
args = text_to_audio_parser.parse_args()
payload = TextToAudioPayload.model_validate(service_api_ns.payload or {})
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
message_id = payload.message_id
text = payload.text
voice = payload.voice
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)

View File

@ -1,10 +1,14 @@
import logging
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
AppUnavailableError,
@ -26,7 +30,6 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
@ -36,40 +39,46 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
# Define parser for completion API
completion_parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion")
.add_argument("query", type=str, location="json", default="", help="The query string")
.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
)
class CompletionRequestPayload(BaseModel):
inputs: dict[str, Any]
query: str = Field(default="")
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = Field(default="dev")
# Define parser for chat API
chat_parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat")
.add_argument("query", type=str, required=True, location="json", help="The chat query")
.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
.add_argument(
"auto_generate_name",
type=bool,
required=False,
default=True,
location="json",
help="Auto generate conversation name",
)
.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
)
class ChatRequestPayload(BaseModel):
inputs: dict[str, Any]
query: str
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID")
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
@field_validator("conversation_id", mode="before")
@classmethod
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
if isinstance(value, str):
value = value.strip()
if not value:
return None
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("conversation_id must be a valid UUID") from exc
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
@service_api_ns.route("/completion-messages")
class CompletionApi(Resource):
@service_api_ns.expect(completion_parser)
@service_api_ns.expect(service_api_ns.models[CompletionRequestPayload.__name__])
@service_api_ns.doc("create_completion")
@service_api_ns.doc(description="Create a completion for the given prompt")
@service_api_ns.doc(
@ -91,12 +100,13 @@ class CompletionApi(Resource):
if app_model.mode != AppMode.COMPLETION:
raise AppUnavailableError()
args = completion_parser.parse_args()
payload = CompletionRequestPayload.model_validate(service_api_ns.payload or {})
external_trace_id = get_external_trace_id(request)
args = payload.model_dump(exclude_none=True)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args["response_mode"] == "streaming"
streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
@ -162,7 +172,7 @@ class CompletionStopApi(Resource):
@service_api_ns.route("/chat-messages")
class ChatApi(Resource):
@service_api_ns.expect(chat_parser)
@service_api_ns.expect(service_api_ns.models[ChatRequestPayload.__name__])
@service_api_ns.doc("create_chat_message")
@service_api_ns.doc(description="Send a message in a chat conversation")
@service_api_ns.doc(
@ -186,13 +196,14 @@ class ChatApi(Resource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
args = chat_parser.parse_args()
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
external_trace_id = get_external_trace_id(request)
args = payload.model_dump(exclude_none=True)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args["response_mode"] == "streaming"
streaming = payload.response_mode == "streaming"
try:
response = AppGenerateService.generate(

View File

@ -1,10 +1,15 @@
from flask_restx import Resource, reqparse
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from flask_restx.inputs import int_range
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
@ -19,74 +24,51 @@ from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model,
)
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
# Define parsers for conversation APIs
conversation_list_parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination")
.add_argument(
"limit",
type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of conversations to return",
)
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
help="Sort order for conversations",
)
)
conversation_rename_parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, location="json", help="New conversation name")
.add_argument(
"auto_generate",
type=bool,
required=False,
default=False,
location="json",
help="Auto-generate conversation name",
class ConversationListQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort order for conversations"
)
)
conversation_variables_parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination")
.add_argument(
"limit",
type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of variables to return",
)
)
conversation_variable_update_parser = reqparse.RequestParser().add_argument(
# using lambda is for passing the already-typed value without modification
# if no lambda, it will be converted to string
# the string cannot be converted using json.loads
"value",
required=True,
location="json",
type=lambda x: x,
help="New value for the conversation variable",
class ConversationRenamePayload(BaseModel):
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
class ConversationVariableUpdatePayload(BaseModel):
value: Any
register_schema_models(
service_api_ns,
ConversationListQuery,
ConversationRenamePayload,
ConversationVariablesQuery,
ConversationVariableUpdatePayload,
)
@service_api_ns.route("/conversations")
class ConversationApi(Resource):
@service_api_ns.expect(conversation_list_parser)
@service_api_ns.expect(service_api_ns.models[ConversationListQuery.__name__])
@service_api_ns.doc("list_conversations")
@service_api_ns.doc(description="List all conversations for the current user")
@service_api_ns.doc(
@ -107,7 +89,8 @@ class ConversationApi(Resource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
args = conversation_list_parser.parse_args()
query_args = ConversationListQuery.model_validate(request.args.to_dict())
last_id = str(query_args.last_id) if query_args.last_id else None
try:
with Session(db.engine) as session:
@ -115,10 +98,10 @@ class ConversationApi(Resource):
session=session,
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
last_id=last_id,
limit=query_args.limit,
invoke_from=InvokeFrom.SERVICE_API,
sort_by=args["sort_by"],
sort_by=query_args.sort_by,
)
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@ -155,7 +138,7 @@ class ConversationDetailApi(Resource):
@service_api_ns.route("/conversations/<uuid:c_id>/name")
class ConversationRenameApi(Resource):
@service_api_ns.expect(conversation_rename_parser)
@service_api_ns.expect(service_api_ns.models[ConversationRenamePayload.__name__])
@service_api_ns.doc("rename_conversation")
@service_api_ns.doc(description="Rename a conversation or auto-generate a name")
@service_api_ns.doc(params={"c_id": "Conversation ID"})
@ -176,17 +159,17 @@ class ConversationRenameApi(Resource):
conversation_id = str(c_id)
args = conversation_rename_parser.parse_args()
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
try:
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@service_api_ns.route("/conversations/<uuid:c_id>/variables")
class ConversationVariablesApi(Resource):
@service_api_ns.expect(conversation_variables_parser)
@service_api_ns.expect(service_api_ns.models[ConversationVariablesQuery.__name__])
@service_api_ns.doc("list_conversation_variables")
@service_api_ns.doc(description="List all variables for a conversation")
@service_api_ns.doc(params={"c_id": "Conversation ID"})
@ -211,11 +194,12 @@ class ConversationVariablesApi(Resource):
conversation_id = str(c_id)
args = conversation_variables_parser.parse_args()
query_args = ConversationVariablesQuery.model_validate(request.args.to_dict())
last_id = str(query_args.last_id) if query_args.last_id else None
try:
return ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, args["limit"], args["last_id"]
app_model, conversation_id, end_user, query_args.limit, last_id
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -223,7 +207,7 @@ class ConversationVariablesApi(Resource):
@service_api_ns.route("/conversations/<uuid:c_id>/variables/<uuid:variable_id>")
class ConversationVariableDetailApi(Resource):
@service_api_ns.expect(conversation_variable_update_parser)
@service_api_ns.expect(service_api_ns.models[ConversationVariableUpdatePayload.__name__])
@service_api_ns.doc("update_conversation_variable")
@service_api_ns.doc(description="Update a conversation variable's value")
@service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"})
@ -250,11 +234,11 @@ class ConversationVariableDetailApi(Resource):
conversation_id = str(c_id)
variable_id = str(variable_id)
args = conversation_variable_update_parser.parse_args()
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
try:
return ConversationService.update_conversation_variable(
app_model, conversation_id, variable_id, end_user, args["value"]
app_model, conversation_id, variable_id, end_user, payload.value
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -1,9 +1,11 @@
import logging
from urllib.parse import quote
from flask import Response
from flask_restx import Resource, reqparse
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
FileAccessDeniedError,
@ -17,10 +19,11 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile
logger = logging.getLogger(__name__)
# Define parser for file preview API
file_preview_parser = reqparse.RequestParser().add_argument(
"as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment"
)
class FilePreviewQuery(BaseModel):
as_attachment: bool = Field(default=False, description="Download as attachment")
register_schema_model(service_api_ns, FilePreviewQuery)
@service_api_ns.route("/files/<uuid:file_id>/preview")
@ -32,7 +35,7 @@ class FilePreviewApi(Resource):
Files can only be accessed if they belong to messages within the requesting app's context.
"""
@service_api_ns.expect(file_preview_parser)
@service_api_ns.expect(service_api_ns.models[FilePreviewQuery.__name__])
@service_api_ns.doc("preview_file")
@service_api_ns.doc(description="Preview or download a file uploaded via Service API")
@service_api_ns.doc(params={"file_id": "UUID of the file to preview"})
@ -55,7 +58,7 @@ class FilePreviewApi(Resource):
file_id = str(file_id)
# Parse query parameters
args = file_preview_parser.parse_args()
args = FilePreviewQuery.model_validate(request.args.to_dict())
# Validate file ownership and get file objects
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
@ -67,7 +70,7 @@ class FilePreviewApi(Resource):
raise FileNotFoundError(f"Failed to load file content: {str(e)}")
# Build response with appropriate headers
response = self._build_file_response(generator, upload_file, args["as_attachment"])
response = self._build_file_response(generator, upload_file, args.as_attachment)
return response

View File

@ -1,11 +1,15 @@
import json
import logging
from typing import Literal
from uuid import UUID
from flask_restx import Api, Namespace, Resource, fields, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import Namespace, Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
@ -13,7 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import build_message_file_model
from fields.message_fields import build_agent_thought_model, build_feedback_model
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@ -25,42 +29,26 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
# Define parsers for message APIs
message_list_parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID")
.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination")
.add_argument(
"limit",
type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of messages to return",
)
)
message_feedback_parser = (
reqparse.RequestParser()
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating")
.add_argument("content", type=str, location="json", help="Feedback content")
)
feedback_list_parser = (
reqparse.RequestParser()
.add_argument("page", type=int, default=1, location="args", help="Page number")
.add_argument(
"limit",
type=int_range(1, 101),
required=False,
default=20,
location="args",
help="Number of feedbacks per page",
)
)
class MessageListQuery(BaseModel):
conversation_id: UUID
first_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
def build_message_model(api_or_ns: Api | Namespace):
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class FeedbackListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
def build_message_model(api_or_ns: Namespace):
"""Build the message model for the API or Namespace."""
# First build the nested models
feedback_model = build_feedback_model(api_or_ns)
@ -90,7 +78,7 @@ def build_message_model(api_or_ns: Api | Namespace):
return api_or_ns.model("Message", message_fields)
def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the message infinite scroll pagination model for the API or Namespace."""
# Build the nested message model first
message_model = build_message_model(api_or_ns)
@ -105,7 +93,7 @@ def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
@service_api_ns.route("/messages")
class MessageListApi(Resource):
@service_api_ns.expect(message_list_parser)
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
@service_api_ns.doc("list_messages")
@service_api_ns.doc(description="List messages in a conversation")
@service_api_ns.doc(
@ -126,11 +114,13 @@ class MessageListApi(Resource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
args = message_list_parser.parse_args()
query_args = MessageListQuery.model_validate(request.args.to_dict())
conversation_id = str(query_args.conversation_id)
first_id = str(query_args.first_id) if query_args.first_id else None
try:
return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
app_model, end_user, conversation_id, first_id, query_args.limit
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -140,7 +130,7 @@ class MessageListApi(Resource):
@service_api_ns.route("/messages/<uuid:message_id>/feedbacks")
class MessageFeedbackApi(Resource):
@service_api_ns.expect(message_feedback_parser)
@service_api_ns.expect(service_api_ns.models[MessageFeedbackPayload.__name__])
@service_api_ns.doc("create_message_feedback")
@service_api_ns.doc(description="Submit feedback for a message")
@service_api_ns.doc(params={"message_id": "Message ID"})
@ -159,15 +149,15 @@ class MessageFeedbackApi(Resource):
"""
message_id = str(message_id)
args = message_feedback_parser.parse_args()
payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=end_user,
rating=args.get("rating"),
content=args.get("content"),
rating=payload.rating,
content=payload.content,
)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@ -177,7 +167,7 @@ class MessageFeedbackApi(Resource):
@service_api_ns.route("/app/feedbacks")
class AppGetFeedbacksApi(Resource):
@service_api_ns.expect(feedback_list_parser)
@service_api_ns.expect(service_api_ns.models[FeedbackListQuery.__name__])
@service_api_ns.doc("get_app_feedbacks")
@service_api_ns.doc(description="Get all feedbacks for the application")
@service_api_ns.doc(
@ -192,8 +182,8 @@ class AppGetFeedbacksApi(Resource):
Returns paginated list of all feedback submitted for messages in this app.
"""
args = feedback_list_parser.parse_args()
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"])
query_args = FeedbackListQuery.model_validate(request.args.to_dict())
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=query_args.page, limit=query_args.limit)
return {"data": feedbacks}

View File

@ -1,12 +1,14 @@
import logging
from typing import Any, Literal
from dateutil.parser import isoparse
from flask import request
from flask_restx import Api, Namespace, Resource, fields, reqparse
from flask_restx.inputs import int_range
from flask_restx import Api, Namespace, Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
CompletionRequestError,
@ -41,37 +43,25 @@ from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__)
# Define parsers for workflow APIs
workflow_run_parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
)
workflow_log_parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
.add_argument("created_at__before", type=str, location="args")
.add_argument("created_at__after", type=str, location="args")
.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
class WorkflowLogQuery(BaseModel):
keyword: str | None = None
status: Literal["succeeded", "failed", "stopped"] | None = None
created_at__before: str | None = None
created_at__after: str | None = None
created_by_end_user_session_id: str | None = None
created_by_account: str | None = None
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
workflow_run_fields = {
"id": fields.String,
@ -130,7 +120,7 @@ class WorkflowRunDetailApi(Resource):
@service_api_ns.route("/workflows/run")
class WorkflowRunApi(Resource):
@service_api_ns.expect(workflow_run_parser)
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
@service_api_ns.doc("run_workflow")
@service_api_ns.doc(description="Execute a workflow")
@service_api_ns.doc(
@ -154,11 +144,12 @@ class WorkflowRunApi(Resource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
args = workflow_run_parser.parse_args()
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args.get("response_mode") == "streaming"
streaming = payload.response_mode == "streaming"
try:
response = AppGenerateService.generate(
@ -185,7 +176,7 @@ class WorkflowRunApi(Resource):
@service_api_ns.route("/workflows/<string:workflow_id>/run")
class WorkflowRunByIdApi(Resource):
@service_api_ns.expect(workflow_run_parser)
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
@service_api_ns.doc("run_workflow_by_id")
@service_api_ns.doc(description="Execute a specific workflow by ID")
@service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
@ -209,7 +200,8 @@ class WorkflowRunByIdApi(Resource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
args = workflow_run_parser.parse_args()
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
# Add workflow_id to args for AppGenerateService
args["workflow_id"] = workflow_id
@ -217,7 +209,7 @@ class WorkflowRunByIdApi(Resource):
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args.get("response_mode") == "streaming"
streaming = payload.response_mode == "streaming"
try:
response = AppGenerateService.generate(
@ -279,7 +271,7 @@ class WorkflowTaskStopApi(Resource):
@service_api_ns.route("/workflows/logs")
class WorkflowAppLogApi(Resource):
@service_api_ns.expect(workflow_log_parser)
@service_api_ns.expect(service_api_ns.models[WorkflowLogQuery.__name__])
@service_api_ns.doc("get_workflow_logs")
@service_api_ns.doc(description="Get workflow execution logs")
@service_api_ns.doc(
@ -295,14 +287,11 @@ class WorkflowAppLogApi(Resource):
Returns paginated workflow execution logs with filtering options.
"""
args = workflow_log_parser.parse_args()
args = WorkflowLogQuery.model_validate(request.args.to_dict())
args.status = WorkflowExecutionStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = isoparse(args.created_at__after)
status = WorkflowExecutionStatus(args.status) if args.status else None
created_at_before = isoparse(args.created_at__before) if args.created_at__before else None
created_at_after = isoparse(args.created_at__after) if args.created_at__after else None
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
@ -311,9 +300,9 @@ class WorkflowAppLogApi(Resource):
session=session,
app_model=app_model,
keyword=args.keyword,
status=args.status,
created_at_before=args.created_at__before,
created_at_after=args.created_at__after,
status=status,
created_at_before=created_at_before,
created_at_after=created_at_after,
page=args.page,
limit=args.limit,
created_by_end_user_session_id=args.created_by_end_user_session_id,

View File

@ -1,10 +1,12 @@
from typing import Any, Literal, cast
from flask import request
from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
@ -18,173 +20,83 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user
from libs.validators import validate_description_length
from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
class DatasetCreatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field(default="", description="Dataset description (max 400 chars)", max_length=400)
indexing_technique: Literal["high_quality", "economy"] | None = None
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
external_knowledge_api_id: str | None = None
provider: str = "vendor"
external_knowledge_id: str | None = None
retrieval_model: RetrievalModel | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
# Define parsers for dataset operations
dataset_create_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
required=False,
nullable=False,
)
.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
default="_validate_name",
)
.add_argument(
"provider",
type=str,
nullable=True,
required=False,
default="vendor",
)
.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
class DatasetUpdatePayload(BaseModel):
name: str | None = Field(default=None, min_length=1, max_length=40)
description: str | None = Field(default=None, description="Dataset description (max 400 chars)", max_length=400)
indexing_technique: Literal["high_quality", "economy"] | None = None
permission: DatasetPermissionEnum | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: RetrievalModel | None = None
partial_member_list: list[str] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
dataset_update_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument("embedding_model_provider", type=str, location="json", help="Invalid embedding model provider.")
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
)
tag_create_parser = reqparse.RequestParser().add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=lambda x: x
if x and 1 <= len(x) <= 50
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
)
class TagNamePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=50)
tag_update_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=lambda x: x
if x and 1 <= len(x) <= 50
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
)
.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
)
tag_delete_parser = reqparse.RequestParser().add_argument(
"tag_id", nullable=False, required=True, help="Id of a tag.", type=str
)
class TagCreatePayload(TagNamePayload):
pass
tag_binding_parser = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
)
)
tag_unbinding_parser = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
class TagUpdatePayload(TagNamePayload):
tag_id: str
class TagDeletePayload(BaseModel):
tag_id: str
class TagBindingPayload(BaseModel):
tag_ids: list[str]
target_id: str
@field_validator("tag_ids")
@classmethod
def validate_tag_ids(cls, value: list[str]) -> list[str]:
if not value:
raise ValueError("Tag IDs is required.")
return value
class TagUnbindingPayload(BaseModel):
tag_id: str
target_id: str
register_schema_models(
service_api_ns,
DatasetCreatePayload,
DatasetUpdatePayload,
TagCreatePayload,
TagUpdatePayload,
TagDeletePayload,
TagBindingPayload,
TagUnbindingPayload,
)
@ -239,7 +151,7 @@ class DatasetListApi(DatasetApiResource):
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200
@service_api_ns.expect(dataset_create_parser)
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
@service_api_ns.doc("create_dataset")
@service_api_ns.doc(description="Create a new dataset")
@service_api_ns.doc(
@ -252,42 +164,41 @@ class DatasetListApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id):
"""Resource for creating datasets."""
args = dataset_create_parser.parse_args()
payload = DatasetCreatePayload.model_validate(service_api_ns.payload or {})
embedding_model_provider = args.get("embedding_model_provider")
embedding_model = args.get("embedding_model")
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
try:
assert isinstance(current_user, Account)
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
name=payload.name,
description=payload.description,
indexing_technique=payload.indexing_technique,
account=current_user,
permission=args["permission"],
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
if args["retrieval_model"] is not None
else None,
permission=str(payload.permission) if payload.permission else None,
provider=payload.provider,
external_knowledge_api_id=payload.external_knowledge_api_id,
external_knowledge_id=payload.external_knowledge_id,
embedding_model_provider=payload.embedding_model_provider,
embedding_model_name=payload.embedding_model,
retrieval_model=payload.retrieval_model,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@ -353,7 +264,7 @@ class DatasetApi(DatasetApiResource):
return data, 200
@service_api_ns.expect(dataset_update_parser)
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
@service_api_ns.doc("update_dataset")
@service_api_ns.doc(description="Update an existing dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -372,36 +283,45 @@ class DatasetApi(DatasetApiResource):
if dataset is None:
raise NotFound("Dataset not found.")
args = dataset_update_parser.parse_args()
data = request.get_json()
payload_dict = service_api_ns.payload or {}
payload = DatasetUpdatePayload.model_validate(payload_dict)
update_data = payload.model_dump(exclude_unset=True)
if payload.permission is not None:
update_data["permission"] = str(payload.permission)
if payload.retrieval_model is not None:
update_data["retrieval_model"] = payload.retrieval_model.model_dump()
# check embedding model setting
embedding_model_provider = data.get("embedding_model_provider")
embedding_model = data.get("embedding_model")
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if payload.indexing_technique == "high_quality" or embedding_model_provider:
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(
dataset.tenant_id, embedding_model_provider, embedding_model
)
retrieval_model = data.get("retrieval_model")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
dataset.tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
current_user,
dataset,
str(payload.permission) if payload.permission else None,
payload.partial_member_list,
)
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
dataset = DatasetService.update_dataset(dataset_id_str, update_data, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
@ -410,15 +330,10 @@ class DatasetApi(DatasetApiResource):
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
if payload.partial_member_list and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@ -556,7 +471,7 @@ class DatasetTagsApi(DatasetApiResource):
return tags, 200
@service_api_ns.expect(tag_create_parser)
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag")
@service_api_ns.doc(description="Add a knowledge type tag")
@service_api_ns.doc(
@ -574,14 +489,13 @@ class DatasetTagsApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_create_parser.parse_args()
args["type"] = "knowledge"
tag = TagService.save_tags(args)
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
@service_api_ns.expect(tag_update_parser)
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@service_api_ns.doc("update_dataset_tag")
@service_api_ns.doc(description="Update a knowledge type tag")
@service_api_ns.doc(
@ -598,10 +512,10 @@ class DatasetTagsApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_update_parser.parse_args()
args["type"] = "knowledge"
tag_id = args["tag_id"]
tag = TagService.update_tags(args, tag_id)
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
params = {"name": payload.name, "type": "knowledge"}
tag_id = payload.tag_id
tag = TagService.update_tags(params, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@ -609,7 +523,7 @@ class DatasetTagsApi(DatasetApiResource):
return response, 200
@service_api_ns.expect(tag_delete_parser)
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
@service_api_ns.doc("delete_dataset_tag")
@service_api_ns.doc(description="Delete a knowledge type tag")
@service_api_ns.doc(
@ -623,15 +537,15 @@ class DatasetTagsApi(DatasetApiResource):
@edit_permission_required
def delete(self, _, dataset_id):
"""Delete a knowledge type tag."""
args = tag_delete_parser.parse_args()
TagService.delete_tag(args["tag_id"])
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
return 204
@service_api_ns.route("/datasets/tags/binding")
class DatasetTagBindingApi(DatasetApiResource):
@service_api_ns.expect(tag_binding_parser)
@service_api_ns.expect(service_api_ns.models[TagBindingPayload.__name__])
@service_api_ns.doc("bind_dataset_tags")
@service_api_ns.doc(description="Bind tags to a dataset")
@service_api_ns.doc(
@ -648,16 +562,15 @@ class DatasetTagBindingApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_binding_parser.parse_args()
args["type"] = "knowledge"
TagService.save_tag_binding(args)
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
return 204
@service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource):
@service_api_ns.expect(tag_unbinding_parser)
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
@service_api_ns.doc("unbind_dataset_tag")
@service_api_ns.doc(description="Unbind a tag from a dataset")
@service_api_ns.doc(
@ -674,9 +587,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_unbinding_parser.parse_args()
args["type"] = "knowledge"
TagService.delete_tag_binding(args)
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
return 204

View File

@ -3,8 +3,8 @@ from typing import Self
from uuid import UUID
from flask import request
from flask_restx import marshal, reqparse
from pydantic import BaseModel, model_validator
from flask_restx import marshal
from pydantic import BaseModel, Field, model_validator
from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound
@ -37,22 +37,19 @@ from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
# Define parsers for document operations
document_text_create_parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("text", type=str, required=True, nullable=False, location="json")
.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
.add_argument("original_document_id", type=str, required=False, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
class DocumentTextCreatePayload(BaseModel):
name: str
text: str
process_rule: ProcessRule | None = None
original_document_id: str | None = None
doc_form: str = Field(default="text_model")
doc_language: str = Field(default="English")
indexing_technique: str | None = None
retrieval_model: RetrievalModel | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -72,7 +69,7 @@ class DocumentTextUpdate(BaseModel):
return self
for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]:
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
@ -83,7 +80,7 @@ for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents."""
@service_api_ns.expect(document_text_create_parser)
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
@service_api_ns.doc("create_document_by_text")
@service_api_ns.doc(description="Create a new document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -99,7 +96,8 @@ class DocumentAddByTextApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create document by text."""
args = document_text_create_parser.parse_args()
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
@ -111,33 +109,29 @@ class DocumentAddByTextApi(DatasetApiResource):
if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.")
text = args.get("text")
name = args.get("name")
if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.")
embedding_model_provider = args.get("embedding_model_provider")
embedding_model = args.get("embedding_model")
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
@ -174,7 +168,7 @@ class DocumentAddByTextApi(DatasetApiResource):
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents."""
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True)
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
@service_api_ns.doc("update_document_by_text")
@service_api_ns.doc(description="Update an existing document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -189,22 +183,23 @@ class DocumentUpdateByTextApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text."""
args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True)
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()
args = payload.model_dump(exclude_none=True)
if not dataset:
raise ValueError("Dataset does not exist.")
retrieval_model = args.get("retrieval_model")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
# indexing_technique is already set in dataset since this is an update

View File

@ -1,9 +1,11 @@
from typing import Literal
from flask_login import current_user
from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import dataset_metadata_fields
@ -14,25 +16,18 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.metadata_service import MetadataService
# Define parsers for metadata APIs
metadata_create_parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json", help="Metadata type")
.add_argument("name", type=str, required=True, nullable=False, location="json", help="Metadata name")
)
metadata_update_parser = reqparse.RequestParser().add_argument(
"name", type=str, required=True, nullable=False, location="json", help="New metadata name"
)
class MetadataUpdatePayload(BaseModel):
name: str
document_metadata_parser = reqparse.RequestParser().add_argument(
"operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data"
)
register_schema_model(service_api_ns, MetadataUpdatePayload)
register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData)
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateServiceApi(DatasetApiResource):
@service_api_ns.expect(metadata_create_parser)
@service_api_ns.expect(service_api_ns.models[MetadataArgs.__name__])
@service_api_ns.doc("create_dataset_metadata")
@service_api_ns.doc(description="Create metadata for a dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -46,8 +41,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset."""
args = metadata_create_parser.parse_args()
metadata_args = MetadataArgs.model_validate(args)
metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {})
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -79,7 +73,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
class DatasetMetadataServiceApi(DatasetApiResource):
@service_api_ns.expect(metadata_update_parser)
@service_api_ns.expect(service_api_ns.models[MetadataUpdatePayload.__name__])
@service_api_ns.doc("update_dataset_metadata")
@service_api_ns.doc(description="Update metadata name")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
@ -93,7 +87,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, metadata_id):
"""Update metadata name."""
args = metadata_update_parser.parse_args()
payload = MetadataUpdatePayload.model_validate(service_api_ns.payload or {})
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@ -102,7 +96,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
return marshal(metadata, dataset_metadata_fields), 200
@service_api_ns.doc("delete_dataset_metadata")
@ -175,7 +169,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
class DocumentMetadataEditServiceApi(DatasetApiResource):
@service_api_ns.expect(document_metadata_parser)
@service_api_ns.expect(service_api_ns.models[MetadataOperationData.__name__])
@service_api_ns.doc("update_documents_metadata")
@service_api_ns.doc(description="Update metadata for multiple documents")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -195,8 +189,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
args = document_metadata_parser.parse_args()
metadata_args = MetadataOperationData.model_validate(args)
metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -4,12 +4,12 @@ from collections.abc import Generator
from typing import Any
from flask import request
from flask_restx import reqparse
from flask_restx.reqparse import ParseResult, RequestParser
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.wraps import DatasetApiResource
@ -22,11 +22,25 @@ from models.dataset import Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
from services.rag_pipeline.entity.pipeline_service_api_entities import (
DatasourceNodeRunApiEntity,
PipelineRunApiEntity,
)
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService
class DatasourceNodeRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
credential_id: str | None = None
is_published: bool
register_schema_model(service_api_ns, DatasourceNodeRunPayload)
register_schema_model(service_api_ns, PipelineRunApiEntity)
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource):
"""Resource for datasource plugins."""
@ -88,22 +102,20 @@ class DatasourceNodeRunApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
def post(self, tenant_id: str, dataset_id: str, node_id: str):
"""Resource for getting datasource plugins."""
# Get query parameter to determine published or draft
parser: RequestParser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
.add_argument("is_published", type=bool, required=True, location="json")
)
args: ParseResult = parser.parse_args()
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(
{
**payload.model_dump(exclude_none=True),
"pipeline_id": str(pipeline.id),
"node_id": node_id,
}
)
return helper.compact_generate_response(
PipelineGenerator.convert_to_event_stream(
rag_pipeline_service.run_datasource_workflow_node(
@ -147,25 +159,10 @@ class PipelineRunApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
def post(self, tenant_id: str, dataset_id: str):
"""Resource for running a rag pipeline."""
parser: RequestParser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_published", type=bool, required=True, default=True, location="json")
.add_argument(
"response_mode",
type=str,
required=True,
choices=["streaming", "blocking"],
default="blocking",
location="json",
)
)
args: ParseResult = parser.parse_args()
payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
if not isinstance(current_user, Account):
raise Forbidden()
@ -176,9 +173,9 @@ class PipelineRunApi(DatasetApiResource):
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
streaming=args.get("response_mode") == "streaming",
args=payload.model_dump(),
invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER,
streaming=payload.response_mode == "streaming",
)
return helper.compact_generate_response(response)

View File

@ -1,8 +1,12 @@
from typing import Any
from flask import request
from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
@ -24,34 +28,42 @@ from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexing
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
# Define parsers for segment operations
segment_create_parser = reqparse.RequestParser().add_argument(
"segments", type=list, required=False, nullable=True, location="json"
)
segment_list_parser = (
reqparse.RequestParser()
.add_argument("status", type=str, action="append", default=[], location="args")
.add_argument("keyword", type=str, default=None, location="args")
)
class SegmentCreatePayload(BaseModel):
segments: list[dict[str, Any]] | None = None
segment_update_parser = reqparse.RequestParser().add_argument(
"segment", type=dict, required=False, nullable=True, location="json"
)
child_chunk_create_parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
class SegmentListQuery(BaseModel):
status: list[str] = Field(default_factory=list)
keyword: str | None = None
child_chunk_list_parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
child_chunk_update_parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
class SegmentUpdatePayload(BaseModel):
segment: SegmentUpdateArgs
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkListQuery(BaseModel):
limit: int = Field(default=20, ge=1)
keyword: str | None = None
page: int = Field(default=1, ge=1)
class ChildChunkUpdatePayload(BaseModel):
content: str
register_schema_models(
service_api_ns,
SegmentCreatePayload,
SegmentListQuery,
SegmentUpdatePayload,
ChildChunkCreatePayload,
ChildChunkListQuery,
ChildChunkUpdatePayload,
)
@ -59,7 +71,7 @@ child_chunk_update_parser = reqparse.RequestParser().add_argument(
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
@service_api_ns.expect(segment_create_parser)
@service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
@service_api_ns.doc("create_segments")
@service_api_ns.doc(description="Create segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -106,20 +118,20 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
args = segment_create_parser.parse_args()
if args["segments"] is not None:
payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
if payload.segments is not None:
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(args["segments"]) > segments_limit:
if segments_limit > 0 and len(payload.segments) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
for args_item in args["segments"]:
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else:
return {"error": "Segments is required"}, 400
@service_api_ns.expect(segment_list_parser)
@service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__])
@service_api_ns.doc("list_segments")
@service_api_ns.doc(description="List segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -160,13 +172,18 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
args = segment_list_parser.parse_args()
args = SegmentListQuery.model_validate(
{
"status": request.args.getlist("status"),
"keyword": request.args.get("keyword"),
}
)
segments, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=current_tenant_id,
status_list=args["status"],
keyword=args["keyword"],
status_list=args.status,
keyword=args.keyword,
page=page,
limit=limit,
)
@ -217,7 +234,7 @@ class DatasetSegmentApi(DatasetApiResource):
SegmentService.delete_segment(segment, document, dataset)
return 204
@service_api_ns.expect(segment_update_parser)
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment")
@service_api_ns.doc(description="Update a specific segment")
@service_api_ns.doc(
@ -265,12 +282,9 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
# validate args
args = segment_update_parser.parse_args()
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
)
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
@service_api_ns.doc("get_segment")
@ -308,7 +322,7 @@ class DatasetSegmentApi(DatasetApiResource):
class ChildChunkApi(DatasetApiResource):
"""Resource for child chunks."""
@service_api_ns.expect(child_chunk_create_parser)
@service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
@service_api_ns.doc("create_child_chunk")
@service_api_ns.doc(description="Create a new child chunk for a segment")
@service_api_ns.doc(
@ -360,16 +374,16 @@ class ChildChunkApi(DatasetApiResource):
raise ProviderNotInitializeError(ex.description)
# validate args
args = child_chunk_create_parser.parse_args()
payload = ChildChunkCreatePayload.model_validate(service_api_ns.payload or {})
try:
child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset)
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@service_api_ns.expect(child_chunk_list_parser)
@service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__])
@service_api_ns.doc("list_child_chunks")
@service_api_ns.doc(description="List child chunks for a segment")
@service_api_ns.doc(
@ -400,11 +414,17 @@ class ChildChunkApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
args = child_chunk_list_parser.parse_args()
args = ChildChunkListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]
page = args.page
limit = min(args.limit, 100)
keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
@ -480,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
return 204
@service_api_ns.expect(child_chunk_update_parser)
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk")
@service_api_ns.doc(description="Update a specific child chunk")
@service_api_ns.doc(
@ -533,10 +553,10 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.")
# validate args
args = child_chunk_update_parser.parse_args()
payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {})
try:
child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset)
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))

View File

@ -1,3 +1,4 @@
import logging
import time
from collections.abc import Callable
from datetime import timedelta
@ -28,6 +29,8 @@ P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
logger = logging.getLogger(__name__)
class WhereisUserArg(StrEnum):
"""
@ -238,8 +241,8 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
# Basic check: UUIDs are 36 chars with hyphens
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except:
pass
except Exception:
logger.exception("Failed to parse dataset_id from class method args")
elif len(args) > 0:
# Not a class method, check if args[0] looks like a UUID
potential_id = args[0]
@ -247,8 +250,8 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except:
pass
except Exception:
logger.exception("Failed to parse dataset_id from positional args")
# Validate dataset if dataset_id is provided
if dataset_id:
@ -316,18 +319,16 @@ def validate_and_get_api_token(scope: str | None = None):
ApiToken.type == scope,
)
.values(last_used_at=current_time)
.returning(ApiToken)
)
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
result = session.execute(update_stmt)
api_token = result.scalar_one_or_none()
api_token = session.scalar(stmt)
if hasattr(result, "rowcount") and result.rowcount > 0:
session.commit()
if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
else:
session.commit()
raise Unauthorized("Access token is invalid")
return api_token

View File

@ -33,7 +33,7 @@ def trigger_endpoint(endpoint_id: str):
if response:
break
if not response:
logger.error("Endpoint not found for {endpoint_id}")
logger.info("Endpoint not found for %s", endpoint_id)
return jsonify({"error": "Endpoint not found"}), 404
return response
except ValueError as e: