Merge branch 'main' into feat/agent-node-v2

This commit is contained in:
Novice
2025-12-15 15:26:48 +08:00
694 changed files with 37577 additions and 16560 deletions

View File

@ -0,0 +1,26 @@
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
from flask_restx import Namespace
from pydantic import BaseModel
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
"""Register a single BaseModel with a namespace for Swagger documentation."""
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
"""Register multiple BaseModels with a namespace."""
for model in models:
register_schema_model(namespace, model)
__all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
"register_schema_model",
"register_schema_models",
]

View File

@ -6,19 +6,20 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp
P = ParamSpec("P")
R = TypeVar("R")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -90,7 +91,7 @@ class InsertExploreAppListApi(Resource):
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
with Session(db.engine) as session:
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
).scalar_one_or_none()
@ -138,7 +139,7 @@ class InsertExploreAppApi(Resource):
@only_edition_cloud
@admin_required
def delete(self, app_id):
with Session(db.engine) as session:
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
@ -146,13 +147,13 @@ class InsertExploreAppApi(Resource):
if not recommended_app:
return {"result": "success"}, 204
with Session(db.engine) as session:
with session_factory.create_session() as session:
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
with Session(db.engine) as session:
with session_factory.create_session() as session:
installed_apps = (
session.execute(
select(InstalledApp).where(

View File

@ -1,6 +1,6 @@
from typing import Any, Literal
from flask import request
from flask import abort, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
@ -8,6 +8,8 @@ from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
annotation_import_concurrency_limit,
annotation_import_rate_limit,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
@ -314,18 +316,25 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
class AnnotationBatchImportApi(Resource):
@console_ns.doc("batch_import_annotations")
@console_ns.doc(description="Batch import annotations from CSV file")
@console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Batch import started successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "No file uploaded or too many files")
@console_ns.response(413, "File too large")
@console_ns.response(429, "Too many requests or concurrent imports")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@annotation_import_rate_limit
@annotation_import_concurrency_limit
@edit_permission_required
def post(self, app_id):
from configs import dify_config
app_id = str(app_id)
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -335,9 +344,27 @@ class AnnotationBatchImportApi(Resource):
# get file from request
file = request.files["file"]
# check file type
if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
# Check file size before processing
file.seek(0, 2) # Seek to end of file
file_size = file.tell()
file.seek(0) # Reset to beginning
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
if file_size > max_size_bytes:
abort(
413,
f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. "
f"Please reduce the file size and try again.",
)
if file_size == 0:
raise ValueError("The uploaded file is empty")
return AppAnnotationService.batch_import_app_annotations(app_id, file)

View File

@ -31,7 +31,6 @@ from fields.app_fields import (
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
from libs.helper import AppIconUrlField, TimestampField
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import App, Workflow
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
@ -76,51 +75,30 @@ class AppListQuery(BaseModel):
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app")
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")

View File

@ -61,6 +61,7 @@ class ChatMessagesQuery(BaseModel):
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
@field_validator("message_id")
@classmethod
@ -325,6 +326,7 @@ class MessageFeedbackApi(Resource):
db.session.delete(feedback)
elif args.rating and feedback:
feedback.rating = args.rating
feedback.content = args.content
elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
@ -336,6 +338,7 @@ class MessageFeedbackApi(Resource):
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating_value,
content=args.content,
from_source="admin",
from_account_id=current_user.id,
)

View File

@ -114,7 +114,7 @@ class AppTriggersApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
@console_ns.expect(console_ns.models[ParserEnable.__name__])
@setup_required
@login_required
@account_initialization_required

View File

@ -1,15 +1,15 @@
import json
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from flask import request
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.common.schema import register_schema_model
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
@ -25,6 +25,19 @@ from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task
from .. import console_ns
from ..wraps import account_initialization_required, setup_required
class NotionEstimatePayload(BaseModel):
notion_info_list: list[dict[str, Any]]
process_rule: dict[str, Any]
doc_form: str = Field(default="text_model")
doc_language: str = Field(default="English")
register_schema_model(console_ns, NotionEstimatePayload)
@console_ns.route(
"/data-source/integrates",
@ -243,20 +256,15 @@ class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
def post(self):
_, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
# validate args
DocumentService.estimate_args_validate(args)
notion_info_list = args["notion_info_list"]
notion_info_list = payload.notion_info_list
extract_settings = []
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]

View File

@ -1,12 +1,14 @@
from typing import Any, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import (
api_key_item_model,
@ -48,7 +50,6 @@ from fields.dataset_fields import (
)
from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
@ -107,10 +108,75 @@ related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_mode
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_indexing_technique(value: str | None) -> str | None:
if value is None:
return value
if value not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Invalid indexing technique.")
return value
class DatasetCreatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field("", max_length=400)
indexing_technique: str | None = None
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
provider: str = "vendor"
external_knowledge_api_id: str | None = None
external_knowledge_id: str | None = None
@field_validator("indexing_technique")
@classmethod
def validate_indexing(cls, value: str | None) -> str | None:
return _validate_indexing_technique(value)
@field_validator("provider")
@classmethod
def validate_provider(cls, value: str) -> str:
if value not in Dataset.PROVIDER_LIST:
raise ValueError("Invalid provider.")
return value
class DatasetUpdatePayload(BaseModel):
name: str | None = Field(None, min_length=1, max_length=40)
description: str | None = Field(None, max_length=400)
permission: DatasetPermissionEnum | None = None
indexing_technique: str | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
partial_member_list: list[str] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
icon_info: dict[str, Any] | None = None
is_multimodal: bool | None = False
@field_validator("indexing_technique")
@classmethod
def validate_indexing(cls, value: str | None) -> str | None:
return _validate_indexing_technique(value)
class IndexingEstimatePayload(BaseModel):
info_list: dict[str, Any]
process_rule: dict[str, Any]
indexing_technique: str
doc_form: str = "text_model"
dataset_id: str | None = None
doc_language: str = "English"
@field_validator("indexing_technique")
@classmethod
def validate_indexing(cls, value: str) -> str:
result = _validate_indexing_technique(value)
if result is None:
raise ValueError("indexing_technique is required.")
return result
register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@ -164,6 +230,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.CLICKZETTA,
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
VectorType.IRIS,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
@ -255,20 +322,7 @@ class DatasetListApi(Resource):
@console_ns.doc("create_dataset")
@console_ns.doc(description="Create a new dataset")
@console_ns.expect(
console_ns.model(
"CreateDatasetRequest",
{
"name": fields.String(required=True, description="Dataset name (1-40 characters)"),
"description": fields.String(description="Dataset description (max 400 characters)"),
"indexing_technique": fields.String(description="Indexing technique"),
"permission": fields.String(description="Dataset permission"),
"provider": fields.String(description="Provider"),
"external_knowledge_api_id": fields.String(description="External knowledge API ID"),
"external_knowledge_id": fields.String(description="External knowledge ID"),
},
)
)
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
@console_ns.response(201, "Dataset created successfully")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@ -276,52 +330,7 @@ class DatasetListApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
)
.add_argument(
"provider",
type=str,
nullable=True,
choices=Dataset.PROVIDER_LIST,
required=False,
default="vendor",
)
.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
)
args = parser.parse_args()
payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@ -331,14 +340,14 @@ class DatasetListApi(Resource):
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
name=payload.name,
description=payload.description,
indexing_technique=payload.indexing_technique,
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"],
permission=payload.permission or DatasetPermissionEnum.ONLY_ME,
provider=payload.provider,
external_knowledge_api_id=payload.external_knowledge_api_id,
external_knowledge_id=payload.external_knowledge_id,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@ -399,18 +408,7 @@ class DatasetApi(Resource):
@console_ns.doc("update_dataset")
@console_ns.doc(description="Update dataset details")
@console_ns.expect(
console_ns.model(
"UpdateDatasetRequest",
{
"name": fields.String(description="Dataset name"),
"description": fields.String(description="Dataset description"),
"permission": fields.String(description="Dataset permission"),
"indexing_technique": fields.String(description="Indexing technique"),
"external_retrieval_model": fields.Raw(description="External retrieval model settings"),
},
)
)
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@ -424,93 +422,25 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(
DatasetPermissionEnum.ONLY_ME,
DatasetPermissionEnum.ALL_TEAM,
DatasetPermissionEnum.PARTIAL_TEAM,
),
help="Invalid permission.",
)
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
.add_argument(
"icon_info",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid icon info.",
)
)
args = parser.parse_args()
data = request.get_json()
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
data.get("indexing_technique") == "high_quality"
and data.get("embedding_model_provider") is not None
and data.get("embedding_model") is not None
payload.indexing_technique == "high_quality"
and payload.embedding_model_provider is not None
and payload.embedding_model is not None
):
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
is_multimodal = DatasetService.check_is_multimodal_model(
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
)
payload.is_multimodal = is_multimodal
payload_data = payload.model_dump(exclude_unset=True)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
current_user, dataset, payload.permission, payload.partial_member_list
)
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
@ -518,15 +448,10 @@ class DatasetApi(Resource):
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@ -615,24 +540,10 @@ class DatasetIndexingEstimateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument(
"indexing_technique",
type=str,
required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
location="json",
)
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
_, current_tenant_id = current_account_with_tenant()
# validate args
DocumentService.estimate_args_validate(args)

View File

@ -6,31 +6,14 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import (
ArchivedDocumentImmutableError,
DocumentAlreadyFinishedError,
DocumentIndexingError,
IndexingEstimateError,
InvalidActionError,
InvalidMetadataError,
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
@ -55,10 +38,30 @@ from fields.document_fields import (
)
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from ..datasets.error import (
ArchivedDocumentImmutableError,
DocumentAlreadyFinishedError,
DocumentIndexingError,
IndexingEstimateError,
InvalidActionError,
InvalidMetadataError,
)
from ..wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
logger = logging.getLogger(__name__)
@ -93,6 +96,24 @@ dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(docume
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DocumentRetryPayload(BaseModel):
document_ids: list[str]
class DocumentRenamePayload(BaseModel):
name: str
register_schema_models(
console_ns,
KnowledgeConfig,
ProcessRule,
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
)
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
@ -201,8 +222,9 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: str):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
@ -310,6 +332,7 @@ class DatasetDocumentListApi(Resource):
@marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
@ -328,23 +351,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
parser = (
reqparse.RequestParser()
.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
.add_argument("data_source", type=dict, required=False, location="json")
.add_argument("process_rule", type=dict, required=False, location="json")
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
.add_argument("original_document_id", type=str, required=False, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args)
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.")
@ -390,17 +397,7 @@ class DatasetDocumentListApi(Resource):
class DatasetInitApi(Resource):
@console_ns.doc("init_dataset")
@console_ns.doc(description="Initialize dataset with documents")
@console_ns.expect(
console_ns.model(
"DatasetInitRequest",
{
"upload_file_id": fields.String(required=True, description="Upload file ID"),
"indexing_technique": fields.String(description="Indexing technique"),
"process_rule": fields.Raw(description="Processing rules"),
"data_source": fields.Raw(description="Data source configuration"),
},
)
)
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
@console_ns.response(400, "Invalid request parameters")
@setup_required
@ -415,27 +412,7 @@ class DatasetInitApi(Resource):
if not current_user.is_dataset_editor:
raise Forbidden()
parser = (
reqparse.RequestParser()
.add_argument(
"indexing_technique",
type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
required=True,
nullable=False,
location="json",
)
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args)
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@ -443,10 +420,14 @@ class DatasetInitApi(Resource):
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=args["embedding_model_provider"],
provider=knowledge_config.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=args["embedding_model"],
model=knowledge_config.embedding_model,
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
)
knowledge_config.is_multimodal = is_multimodal
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
@ -1076,19 +1057,16 @@ class DocumentRetryApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
def post(self, dataset_id):
"""retry document."""
parser = reqparse.RequestParser().add_argument(
"document_ids", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
retry_documents = []
if not dataset:
raise NotFound("Dataset not found.")
for document_id in args["document_ids"]:
for document_id in payload.document_ids:
try:
document_id = str(document_id)
@ -1121,6 +1099,7 @@ class DocumentRenameApi(DocumentResource):
@login_required
@account_initialization_required
@marshal_with(document_fields)
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
@ -1130,11 +1109,10 @@ class DocumentRenameApi(DocumentResource):
if not dataset:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
try:
document = DocumentService.rename_document(dataset_id, document_id, args["name"])
document = DocumentService.rename_document(dataset_id, document_id, payload.name)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")

View File

@ -1,11 +1,13 @@
import uuid
from flask import request
from flask_restx import Resource, marshal, reqparse
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import (
@ -36,6 +38,58 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
hit_count_gte: int | None = None
enabled: str = Field(default="all")
keyword: str | None = None
page: int = Field(default=1, ge=1)
class SegmentCreatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
class SegmentUpdatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
class BatchImportPayload(BaseModel):
upload_file_id: str
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkUpdatePayload(BaseModel):
content: str
class ChildChunkBatchUpdatePayload(BaseModel):
chunks: list[ChildChunkUpdateArgs]
register_schema_models(
console_ns,
SegmentListQuery,
SegmentCreatePayload,
SegmentUpdatePayload,
BatchImportPayload,
ChildChunkCreatePayload,
ChildChunkUpdatePayload,
ChildChunkBatchUpdatePayload,
)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class DatasetDocumentSegmentListApi(Resource):
@setup_required
@ -60,23 +114,18 @@ class DatasetDocumentSegmentListApi(Resource):
if not document:
raise NotFound("Document not found.")
parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("status", type=str, action="append", default=[], location="args")
.add_argument("hit_count_gte", type=int, default=None, location="args")
.add_argument("enabled", type=str, default="all", location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
args = SegmentListQuery.model_validate(
{
**request.args.to_dict(),
"status": request.args.getlist("status"),
}
)
args = parser.parse_args()
page = args["page"]
limit = min(args["limit"], 100)
status_list = args["status"]
hit_count_gte = args["hit_count_gte"]
keyword = args["keyword"]
page = args.page
limit = min(args.limit, 100)
status_list = args.status
hit_count_gte = args.hit_count_gte
keyword = args.keyword
query = (
select(DocumentSegment)
@ -96,10 +145,10 @@ class DatasetDocumentSegmentListApi(Resource):
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
if args["enabled"].lower() != "all":
if args["enabled"].lower() == "true":
if args.enabled.lower() != "all":
if args.enabled.lower() == "true":
query = query.where(DocumentSegment.enabled == True)
elif args["enabled"].lower() == "false":
elif args.enabled.lower() == "false":
query = query.where(DocumentSegment.enabled == False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@ -210,6 +259,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -246,15 +296,10 @@ class DatasetDocumentSegmentAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = (
reqparse.RequestParser()
.add_argument("content", type=str, required=True, nullable=False, location="json")
.add_argument("answer", type=str, required=False, nullable=True, location="json")
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset)
payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@ -265,6 +310,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -313,18 +359,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = (
reqparse.RequestParser()
.add_argument("content", type=str, required=True, nullable=False, location="json")
.add_argument("answer", type=str, required=False, nullable=True, location="json")
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
.add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@ -377,6 +417,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -391,11 +432,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if not document:
raise NotFound("Document not found.")
parser = reqparse.RequestParser().add_argument(
"upload_file_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
upload_file_id = args["upload_file_id"]
payload = BatchImportPayload.model_validate(console_ns.payload or {})
upload_file_id = payload.upload_file_id
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if not upload_file:
@ -446,6 +484,7 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -491,13 +530,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try:
content = args["content"]
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@ -529,18 +564,17 @@ class ChildChunkAddApi(Resource):
)
if not segment:
raise NotFound("Segment not found.")
parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
args = SegmentListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
args = parser.parse_args()
page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]
page = args.page
limit = min(args.limit, 100)
keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
@ -588,14 +622,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser().add_argument(
"chunks", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {})
try:
chunks_data = args["chunks"]
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
@ -665,6 +694,7 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -711,13 +741,9 @@ class ChildChunkUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try:
content = args["content"]
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {})
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200

View File

@ -1,8 +1,10 @@
from flask import request
from flask_restx import Resource, fields, marshal, reqparse
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@ -71,10 +73,38 @@ except KeyError:
dataset_detail_model = _build_dataset_detail_model()
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 100:
raise ValueError("Name must be between 1 to 100 characters.")
return name
class ExternalKnowledgeApiPayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
settings: dict[str, object]
class ExternalDatasetCreatePayload(BaseModel):
external_knowledge_api_id: str
external_knowledge_id: str
name: str = Field(..., min_length=1, max_length=40)
description: str | None = Field(None, max_length=400)
external_retrieval_model: dict[str, object] | None = None
class ExternalHitTestingPayload(BaseModel):
query: str
external_retrieval_model: dict[str, object] | None = None
metadata_filtering_conditions: dict[str, object] | None = None
class BedrockRetrievalPayload(BaseModel):
retrieval_setting: dict[str, object]
query: str
knowledge_id: str
register_schema_models(
console_ns,
ExternalKnowledgeApiPayload,
ExternalDatasetCreatePayload,
ExternalHitTestingPayload,
BedrockRetrievalPayload,
)
@console_ns.route("/datasets/external-knowledge-api")
@ -113,28 +143,12 @@ class ExternalApiTemplateListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
)
args = parser.parse_args()
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
ExternalDatasetService.validate_api_list(args["settings"])
ExternalDatasetService.validate_api_list(payload.settings)
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@ -142,7 +156,7 @@ class ExternalApiTemplateListApi(Resource):
try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
tenant_id=current_tenant_id, user_id=current_user.id, args=args
tenant_id=current_tenant_id, user_id=current_user.id, args=payload.model_dump()
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@ -171,35 +185,19 @@ class ExternalApiTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"])
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
ExternalDatasetService.validate_api_list(payload.settings)
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_tenant_id,
user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id,
args=args,
args=payload.model_dump(),
)
return external_knowledge_api.to_dict(), 200
@ -240,17 +238,7 @@ class ExternalApiUseCheckApi(Resource):
class ExternalDatasetCreateApi(Resource):
@console_ns.doc("create_external_dataset")
@console_ns.doc(description="Create external knowledge dataset")
@console_ns.expect(
console_ns.model(
"CreateExternalDatasetRequest",
{
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
"external_knowledge_id": fields.String(required=True, description="External knowledge ID"),
"name": fields.String(required=True, description="Dataset name"),
"description": fields.String(description="Dataset description"),
},
)
)
@console_ns.expect(console_ns.models[ExternalDatasetCreatePayload.__name__])
@console_ns.response(201, "External dataset created successfully", dataset_detail_model)
@console_ns.response(400, "Invalid parameters")
@console_ns.response(403, "Permission denied")
@ -261,22 +249,8 @@ class ExternalDatasetCreateApi(Resource):
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
.add_argument(
"name",
nullable=False,
required=True,
help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument("description", type=str, required=False, nullable=True, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
args = parser.parse_args()
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@ -299,16 +273,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@console_ns.doc("test_external_knowledge_retrieval")
@console_ns.doc(description="Test external knowledge retrieval for dataset")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(
console_ns.model(
"ExternalHitTestingRequest",
{
"query": fields.String(required=True, description="Query text for testing"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"external_retrieval_model": fields.Raw(description="External retrieval model configuration"),
},
)
)
@console_ns.expect(console_ns.models[ExternalHitTestingPayload.__name__])
@console_ns.response(200, "External hit testing completed successfully")
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@ -327,23 +292,16 @@ class ExternalKnowledgeHitTestingApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
)
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
payload = ExternalHitTestingPayload.model_validate(console_ns.payload or {})
HitTestingService.hit_testing_args_check(payload.model_dump())
try:
response = HitTestingService.external_retrieve(
dataset=dataset,
query=args["query"],
query=payload.query,
account=current_user,
external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"],
external_retrieval_model=payload.external_retrieval_model,
metadata_filtering_conditions=payload.metadata_filtering_conditions,
)
return response
@ -356,33 +314,13 @@ class BedrockRetrievalApi(Resource):
# this api is only for internal testing
@console_ns.doc("bedrock_retrieval_test")
@console_ns.doc(description="Bedrock retrieval test (internal use only)")
@console_ns.expect(
console_ns.model(
"BedrockRetrievalTestRequest",
{
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
"query": fields.String(required=True, description="Query text"),
"knowledge_id": fields.String(required=True, description="Knowledge ID"),
},
)
)
@console_ns.expect(console_ns.models[BedrockRetrievalPayload.__name__])
@console_ns.response(200, "Bedrock retrieval test completed")
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
.add_argument("knowledge_id", nullable=False, required=True, type=str)
)
args = parser.parse_args()
payload = BedrockRetrievalPayload.model_validate(console_ns.payload or {})
# Call the knowledge retrieval service
result = ExternalDatasetTestService.knowledge_retrieval(
args["retrieval_setting"], args["query"], args["knowledge_id"]
payload.retrieval_setting, payload.query, payload.knowledge_id
)
return result, 200

View File

@ -1,13 +1,17 @@
from flask_restx import Resource, fields
from flask_restx import Resource
from controllers.console import console_ns
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.wraps import (
from controllers.common.schema import register_schema_model
from libs.login import login_required
from .. import console_ns
from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from ..wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from libs.login import login_required
register_schema_model(console_ns, HitTestingPayload)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@ -15,17 +19,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(
console_ns.model(
"HitTestingRequest",
{
"query": fields.String(required=True, description="Query text for testing"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"top_k": fields.Integer(description="Number of top results to return"),
"score_threshold": fields.Float(description="Score threshold for filtering results"),
},
)
)
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@ -37,7 +31,8 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args()
payload = HitTestingPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)

View File

@ -1,6 +1,8 @@
import logging
from typing import Any
from flask_restx import marshal, reqparse
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
@ -27,6 +29,13 @@ from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__)
class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None
class DatasetsHitTestingBase:
@staticmethod
def get_and_validate_dataset(dataset_id: str):
@ -43,14 +52,15 @@ class DatasetsHitTestingBase:
return dataset
@staticmethod
def hit_testing_args_check(args):
def hit_testing_args_check(args: dict[str, Any]):
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args():
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, location="json")
.add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
@ -62,10 +72,11 @@ class DatasetsHitTestingBase:
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
query=args.get("query"),
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
retrieval_model=args.get("retrieval_model"),
external_retrieval_model=args.get("external_retrieval_model"),
attachment_ids=args.get("attachment_ids"),
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}

View File

@ -1,8 +1,10 @@
from typing import Literal
from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
@ -15,6 +17,14 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService
class MetadataUpdatePayload(BaseModel):
name: str
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
register_schema_model(console_ns, MetadataUpdatePayload)
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateApi(Resource):
@setup_required
@ -22,15 +32,10 @@ class DatasetMetadataCreateApi(Resource):
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
metadata_args = MetadataArgs.model_validate(args)
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -60,11 +65,11 @@ class DatasetMetadataApi(Resource):
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
name = args["name"]
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
name = payload.name
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@ -131,6 +136,7 @@ class DocumentMetadataEditApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@ -139,11 +145,7 @@ class DocumentMetadataEditApi(Resource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
parser = reqparse.RequestParser().add_argument(
"operation_data", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
metadata_args = MetadataOperationData.model_validate(args)
metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -1,20 +1,63 @@
from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
class DatasourceCredentialPayload(BaseModel):
name: str | None = Field(default=None, max_length=100)
credentials: dict[str, Any]
class DatasourceCredentialDeletePayload(BaseModel):
credential_id: str
class DatasourceCredentialUpdatePayload(BaseModel):
credential_id: str
name: str | None = Field(default=None, max_length=100)
credentials: dict[str, Any] | None = None
class DatasourceCustomClientPayload(BaseModel):
client_params: dict[str, Any] | None = None
enable_oauth_custom_client: bool | None = None
class DatasourceDefaultPayload(BaseModel):
id: str
class DatasourceUpdateNamePayload(BaseModel):
credential_id: str
name: str = Field(max_length=100)
register_schema_models(
console_ns,
DatasourceCredentialPayload,
DatasourceCredentialDeletePayload,
DatasourceCredentialUpdatePayload,
DatasourceCustomClientPayload,
DatasourceDefaultPayload,
DatasourceUpdateNamePayload,
)
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required
@ -121,16 +164,9 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_datasource = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource):
@console_ns.expect(parser_datasource)
@console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -138,7 +174,7 @@ class DatasourceAuth(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_datasource.parse_args()
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -146,8 +182,8 @@ class DatasourceAuth(Resource):
datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_tenant_id,
provider_id=datasource_provider_id,
credentials=args["credentials"],
name=args["name"],
credentials=payload.credentials,
name=payload.name,
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
@ -169,14 +205,9 @@ class DatasourceAuth(Resource):
return {"result": datasources}, 200
parser_datasource_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource):
@console_ns.expect(parser_datasource_delete)
@console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -188,28 +219,20 @@ class DatasourceAuthDeleteApi(Resource):
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
args = parser_datasource_delete.parse_args()
payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {})
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_tenant_id,
auth_id=args["credential_id"],
auth_id=payload.credential_id,
provider=provider_name,
plugin_id=plugin_id,
)
return {"result": "success"}, 200
parser_datasource_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource):
@console_ns.expect(parser_datasource_update)
@console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -218,16 +241,16 @@ class DatasourceAuthUpdateApi(Resource):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
args = parser_datasource_update.parse_args()
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_tenant_id,
auth_id=args["credential_id"],
auth_id=payload.credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
credentials=args.get("credentials", {}),
name=args.get("name", None),
credentials=payload.credentials or {},
name=payload.name,
)
return {"result": "success"}, 201
@ -258,16 +281,9 @@ class DatasourceHardCodeAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200
parser_datasource_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource):
@console_ns.expect(parser_datasource_custom)
@console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -275,14 +291,14 @@ class DatasourceAuthOauthCustomClient(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_datasource_custom.parse_args()
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}),
enabled=args.get("enable_oauth_custom_client", False),
client_params=payload.client_params or {},
enabled=payload.enable_oauth_custom_client or False,
)
return {"result": "success"}, 200
@ -301,12 +317,9 @@ class DatasourceAuthOauthCustomClient(Resource):
return {"result": "success"}, 200
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource):
@console_ns.expect(parser_default)
@console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -314,27 +327,20 @@ class DatasourceAuthDefaultApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_default.parse_args()
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
credential_id=args["id"],
credential_id=payload.id,
)
return {"result": "success"}, 200
parser_update_name = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource):
@console_ns.expect(parser_update_name)
@console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -342,13 +348,13 @@ class DatasourceUpdateProviderNameApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
args = parser_update_name.parse_args()
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
name=args["name"],
credential_id=args["credential_id"],
name=payload.name,
credential_id=payload.credential_id,
)
return {"result": "success"}, 200

View File

@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
@console_ns.expect(console_ns.models[Parser.__name__])
@setup_required
@login_required
@account_initialization_required

View File

@ -1,9 +1,11 @@
import logging
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
@ -20,18 +22,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description: str) -> str:
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/rag/pipeline/templates")
class PipelineTemplateListApi(Resource):
@setup_required
@ -59,6 +49,15 @@ class PipelineTemplateDetailApi(Resource):
return pipeline_template, 200
class Payload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field(default="", max_length=400)
icon_info: dict[str, object] | None = None
register_schema_models(console_ns, Payload)
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
class CustomizedPipelineTemplateApi(Resource):
@setup_required
@ -66,31 +65,8 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str):
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
payload = Payload.model_validate(console_ns.payload or {})
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200
@ -119,36 +95,14 @@ class CustomizedPipelineTemplateApi(Resource):
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
class PublishCustomizedPipelineTemplateApi(Resource):
@console_ns.expect(console_ns.models[Payload.__name__])
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str):
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
)
args = parser.parse_args()
payload = Payload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
return {"result": "success"}

View File

@ -1,8 +1,10 @@
from flask_restx import Resource, marshal, reqparse
from flask_restx import Resource, marshal
from pydantic import BaseModel
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
import services
from controllers.common.schema import register_schema_model
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
@ -19,22 +21,22 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineDatasetImportPayload(BaseModel):
yaml_content: str
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
@console_ns.route("/rag/pipeline/dataset")
class CreateRagPipelineDatasetApi(Resource):
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser().add_argument(
"yaml_content",
type=str,
nullable=False,
required=True,
help="yaml_content is required.",
)
args = parser.parse_args()
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@ -49,7 +51,7 @@ class CreateRagPipelineDatasetApi(Resource):
),
permission=DatasetPermissionEnum.ONLY_ME,
partial_member_list=None,
yaml_content=args["yaml_content"],
yaml_content=payload.yaml_content,
)
try:
with Session(db.engine) as session:

View File

@ -1,11 +1,13 @@
import logging
from typing import NoReturn
from typing import Any, NoReturn
from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
DraftWorkflowNotExist,
@ -33,19 +35,21 @@ logger = logging.getLogger(__name__)
def _create_pagination_parser():
parser = (
reqparse.RequestParser()
.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
return parser
class PaginationQuery(BaseModel):
page: int = Field(default=1, ge=1, le=100_000)
limit: int = Field(default=20, ge=1, le=100)
register_schema_models(console_ns, PaginationQuery)
return PaginationQuery
class WorkflowDraftVariablePatchPayload(BaseModel):
name: str | None = None
value: Any | None = None
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
@ -93,8 +97,8 @@ class RagPipelineVariableCollectionApi(Resource):
"""
Get draft workflow
"""
parser = _create_pagination_parser()
args = parser.parse_args()
pagination = _create_pagination_parser()
query = pagination.model_validate(request.args.to_dict())
# fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService()
@ -109,8 +113,8 @@ class RagPipelineVariableCollectionApi(Resource):
)
workflow_vars = draft_var_srv.list_variables_without_values(
app_id=pipeline.id,
page=args.page,
limit=args.limit,
page=query.page,
limit=query.limit,
)
return workflow_vars
@ -186,6 +190,7 @@ class RagPipelineVariableApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types:
#
@ -208,16 +213,11 @@ class RagPipelineVariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
parser = (
reqparse.RequestParser()
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
args = parser.parse_args(strict=True)
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:

View File

@ -1,6 +1,9 @@
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from flask import request
from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
@ -16,6 +19,25 @@ from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineImportPayload(BaseModel):
mode: str
yaml_content: str | None = None
yaml_url: str | None = None
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
pipeline_id: str | None = None
class IncludeSecretQuery(BaseModel):
include_secret: str = Field(default="false")
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource):
@setup_required
@ -23,23 +45,11 @@ class RagPipelineImportApi(Resource):
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("pipeline_id", type=str, location="json")
)
args = parser.parse_args()
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Create service with session
with Session(db.engine) as session:
@ -48,11 +58,11 @@ class RagPipelineImportApi(Resource):
account = current_user
result = import_service.import_rag_pipeline(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
pipeline_id=args.get("pipeline_id"),
dataset_name=args.get("name"),
import_mode=payload.mode,
yaml_content=payload.yaml_content,
yaml_url=payload.yaml_url,
pipeline_id=payload.pipeline_id,
dataset_name=payload.name,
)
session.commit()
@ -114,13 +124,12 @@ class RagPipelineExportApi(Resource):
@edit_permission_required
def get(self, pipeline: Pipeline):
# Add include_secret params
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
args = parser.parse_args()
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with Session(db.engine) as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(
pipeline=pipeline, include_secret=args["include_secret"] == "true"
pipeline=pipeline, include_secret=query.include_secret == "true"
)
return {"data": result}, 200

View File

@ -1,14 +1,16 @@
import json
import logging
from typing import cast
from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore
from flask_restx.inputs import int_range # type: ignore
from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
ConversationCompletedError,
@ -36,7 +38,7 @@ from fields.workflow_run_fields import (
workflow_run_pagination_fields,
)
from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
@ -51,6 +53,91 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
logger = logging.getLogger(__name__)
class DraftWorkflowSyncPayload(BaseModel):
graph: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] | None = None
conversation_variables: list[dict[str, Any]] | None = None
rag_pipeline_variables: list[dict[str, Any]] | None = None
features: dict[str, Any] | None = None
class NodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class NodeRunRequiredPayload(BaseModel):
inputs: dict[str, Any]
class DatasourceNodeRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
credential_id: str | None = None
class DraftWorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
datasource_info_list: list[dict[str, Any]]
start_node_id: str
class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
is_preview: bool = False
response_mode: Literal["streaming", "blocking"] = "streaming"
original_document_id: str | None = None
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class NodeIdQuery(BaseModel):
node_id: str
class WorkflowRunQuery(BaseModel):
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class DatasourceVariablesPayload(BaseModel):
datasource_type: str
datasource_info: dict[str, Any]
start_node_id: str
start_node_title: str
register_schema_models(
console_ns,
DraftWorkflowSyncPayload,
NodeRunPayload,
NodeRunRequiredPayload,
DatasourceNodeRunPayload,
DraftWorkflowRunPayload,
PublishedWorkflowRunPayload,
DefaultBlockConfigQuery,
WorkflowListQuery,
WorkflowUpdatePayload,
NodeIdQuery,
WorkflowRunQuery,
DatasourceVariablesPayload,
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
class DraftRagPipelineApi(Resource):
@setup_required
@ -88,15 +175,7 @@ class DraftRagPipelineApi(Resource):
content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type:
parser = (
reqparse.RequestParser()
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=False, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
payload_dict = console_ns.payload or {}
elif "text/plain" in content_type:
try:
data = json.loads(request.data.decode("utf-8"))
@ -106,7 +185,7 @@ class DraftRagPipelineApi(Resource):
if not isinstance(data.get("graph"), dict):
raise ValueError("graph is not a dict")
args = {
payload_dict = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
@ -119,24 +198,26 @@ class DraftRagPipelineApi(Resource):
else:
abort(415)
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
try:
environment_variables_list = args.get("environment_variables") or []
environment_variables_list = payload.environment_variables or []
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = args.get("conversation_variables") or []
conversation_variables_list = payload.conversation_variables or []
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline,
graph=args["graph"],
unique_hash=args.get("hash"),
graph=payload.graph,
unique_hash=payload.hash,
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=args.get("rag_pipeline_variables") or [],
rag_pipeline_variables=payload.rag_pipeline_variables or [],
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
@ -148,12 +229,9 @@ class DraftRagPipelineApi(Resource):
}
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.expect(parser_run)
@console_ns.expect(console_ns.models[NodeRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -166,7 +244,8 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_run.parse_args()
payload = NodeRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = PipelineGenerateService.generate_single_iteration(
@ -187,7 +266,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource):
@console_ns.expect(parser_run)
@console_ns.expect(console_ns.models[NodeRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -200,7 +279,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_run.parse_args()
payload = NodeRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = PipelineGenerateService.generate_single_loop(
@ -219,18 +299,9 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
raise InternalServerError()
parser_draft_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource):
@console_ns.expect(parser_draft_run)
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -243,7 +314,8 @@ class DraftRagPipelineRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_draft_run.parse_args()
payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
try:
response = PipelineGenerateService.generate(
@ -259,21 +331,9 @@ class DraftRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
parser_published_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
.add_argument("original_document_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource):
@console_ns.expect(parser_published_run)
@console_ns.expect(console_ns.models[PublishedWorkflowRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -286,16 +346,16 @@ class PublishedRagPipelineRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_published_run.parse_args()
streaming = args["response_mode"] == "streaming"
payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = payload.response_mode == "streaming"
try:
response = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
streaming=streaming,
)
@ -387,17 +447,9 @@ class PublishedRagPipelineRunApi(Resource):
#
# return result
#
parser_rag_run = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.expect(parser_rag_run)
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -410,14 +462,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_rag_run.parse_args()
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
return helper.compact_generate_response(
@ -425,11 +470,11 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline,
node_id=node_id,
user_inputs=inputs,
user_inputs=payload.inputs,
account=current_user,
datasource_type=datasource_type,
datasource_type=payload.datasource_type,
is_published=False,
credential_id=args.get("credential_id"),
credential_id=payload.credential_id,
)
)
)
@ -437,7 +482,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource):
@console_ns.expect(parser_rag_run)
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
@setup_required
@login_required
@edit_permission_required
@ -450,14 +495,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_rag_run.parse_args()
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
return helper.compact_generate_response(
@ -465,24 +503,19 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline,
node_id=node_id,
user_inputs=inputs,
user_inputs=payload.inputs,
account=current_user,
datasource_type=datasource_type,
datasource_type=payload.datasource_type,
is_published=False,
credential_id=args.get("credential_id"),
credential_id=payload.credential_id,
)
)
)
parser_run_api = reqparse.RequestParser().add_argument(
"inputs", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource):
@console_ns.expect(parser_run_api)
@console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__])
@setup_required
@login_required
@edit_permission_required
@ -496,11 +529,8 @@ class RagPipelineDraftNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = parser_run_api.parse_args()
inputs = args.get("inputs")
if inputs == None:
raise ValueError("missing inputs")
payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {})
inputs = payload.inputs
rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
@ -602,12 +632,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
return rag_pipeline_service.get_default_block_configs()
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource):
@console_ns.expect(parser_default)
@setup_required
@login_required
@account_initialization_required
@ -617,14 +643,12 @@ class DefaultRagPipelineBlockConfigApi(Resource):
"""
Get default block config
"""
args = parser_default.parse_args()
q = args.get("q")
query = DefaultBlockConfigQuery.model_validate(request.args.to_dict())
filters = None
if q:
if query.q:
try:
filters = json.loads(args.get("q", ""))
filters = json.loads(query.q)
except json.JSONDecodeError:
raise ValueError("Invalid filters")
@ -633,18 +657,8 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
parser_wf = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource):
@console_ns.expect(parser_wf)
@setup_required
@login_required
@account_initialization_required
@ -657,16 +671,16 @@ class PublishedAllRagPipelineApi(Resource):
"""
current_user, _ = current_account_with_tenant()
args = parser_wf.parse_args()
page = args["page"]
limit = args["limit"]
user_id = args.get("user_id")
named_only = args.get("named_only", False)
query = WorkflowListQuery.model_validate(request.args.to_dict())
page = query.page
limit = query.limit
user_id = query.user_id
named_only = query.named_only
if user_id:
if user_id != current_user.id:
raise Forbidden()
user_id = cast(str, user_id)
rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session:
@ -687,16 +701,8 @@ class PublishedAllRagPipelineApi(Resource):
}
parser_wf_id = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource):
@console_ns.expect(parser_wf_id)
@setup_required
@login_required
@account_initialization_required
@ -710,20 +716,8 @@ class RagPipelineByIdApi(Resource):
# Check permission
current_user, _ = current_account_with_tenant()
args = parser_wf_id.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
# Prepare update data
update_data = {}
if args.get("marked_name") is not None:
update_data["marked_name"] = args["marked_name"]
if args.get("marked_comment") is not None:
update_data["marked_comment"] = args["marked_comment"]
payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
if not update_data:
return {"message": "No valid fields to update"}, 400
@ -749,12 +743,8 @@ class RagPipelineByIdApi(Resource):
return workflow
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -764,10 +754,8 @@ class PublishedRagPipelineSecondStepApi(Resource):
"""
Get second step parameters of rag pipeline
"""
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = query.node_id
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
return {
@ -777,7 +765,6 @@ class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -787,10 +774,8 @@ class PublishedRagPipelineFirstStepApi(Resource):
"""
Get first step parameters of rag pipeline
"""
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = query.node_id
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
return {
@ -800,7 +785,6 @@ class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -810,10 +794,8 @@ class DraftRagPipelineFirstStepApi(Resource):
"""
Get first step parameters of rag pipeline
"""
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = query.node_id
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
return {
@ -823,7 +805,6 @@ class DraftRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource):
@console_ns.expect(parser_parameters)
@setup_required
@login_required
@account_initialization_required
@ -833,10 +814,8 @@ class DraftRagPipelineSecondStepApi(Resource):
"""
Get second step parameters of rag pipeline
"""
args = parser_parameters.parse_args()
node_id = args.get("node_id")
if not node_id:
raise ValueError("Node ID is required")
query = NodeIdQuery.model_validate(request.args.to_dict())
node_id = query.node_id
rag_pipeline_service = RagPipelineService()
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
@ -845,16 +824,8 @@ class DraftRagPipelineSecondStepApi(Resource):
}
parser_wf_run = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource):
@console_ns.expect(parser_wf_run)
@setup_required
@login_required
@account_initialization_required
@ -864,7 +835,16 @@ class RagPipelineWorkflowRunListApi(Resource):
"""
Get workflow run list
"""
args = parser_wf_run.parse_args()
query = WorkflowRunQuery.model_validate(
{
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", type=int, default=20),
}
)
args = {
"last_id": str(query.last_id) if query.last_id else None,
"limit": query.limit,
}
rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
@ -964,18 +944,9 @@ class RagPipelineTransformApi(Resource):
return result
parser_var = (
reqparse.RequestParser()
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info", type=dict, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("start_node_title", type=str, required=True, location="json")
)
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource):
@console_ns.expect(parser_var)
@console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@ -987,7 +958,7 @@ class RagPipelineDatasourceVariableApi(Resource):
Set datasource variables
"""
current_user, _ = current_account_with_tenant()
args = parser_var.parse_args()
args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump()
rag_pipeline_service = RagPipelineService()
workflow_node_execution = rag_pipeline_service.set_datasource_variables(

View File

@ -1,5 +1,10 @@
from flask_restx import Resource, fields, reqparse
from typing import Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.wraps import account_initialization_required, setup_required
@ -7,48 +12,35 @@ from libs.login import login_required
from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
class WebsiteCrawlPayload(BaseModel):
provider: Literal["firecrawl", "watercrawl", "jinareader"]
url: str
options: dict[str, object]
class WebsiteCrawlStatusQuery(BaseModel):
provider: Literal["firecrawl", "watercrawl", "jinareader"]
register_schema_models(console_ns, WebsiteCrawlPayload, WebsiteCrawlStatusQuery)
@console_ns.route("/website/crawl")
class WebsiteCrawlApi(Resource):
@console_ns.doc("crawl_website")
@console_ns.doc(description="Crawl website content")
@console_ns.expect(
console_ns.model(
"WebsiteCrawlRequest",
{
"provider": fields.String(
required=True,
description="Crawl provider (firecrawl/watercrawl/jinareader)",
enum=["firecrawl", "watercrawl", "jinareader"],
),
"url": fields.String(required=True, description="URL to crawl"),
"options": fields.Raw(required=True, description="Crawl options"),
},
)
)
@console_ns.expect(console_ns.models[WebsiteCrawlPayload.__name__])
@console_ns.response(200, "Website crawl initiated successfully")
@console_ns.response(400, "Invalid crawl parameters")
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = (
reqparse.RequestParser()
.add_argument(
"provider",
type=str,
choices=["firecrawl", "watercrawl", "jinareader"],
required=True,
nullable=True,
location="json",
)
.add_argument("url", type=str, required=True, nullable=True, location="json")
.add_argument("options", type=dict, required=True, nullable=True, location="json")
)
args = parser.parse_args()
payload = WebsiteCrawlPayload.model_validate(console_ns.payload or {})
# Create typed request and validate
try:
api_request = WebsiteCrawlApiRequest.from_args(args)
api_request = WebsiteCrawlApiRequest.from_args(payload.model_dump())
except ValueError as e:
raise WebsiteCrawlError(str(e))
@ -65,6 +57,7 @@ class WebsiteCrawlStatusApi(Resource):
@console_ns.doc("get_crawl_status")
@console_ns.doc(description="Get website crawl status")
@console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
@console_ns.expect(console_ns.models[WebsiteCrawlStatusQuery.__name__])
@console_ns.response(200, "Crawl status retrieved successfully")
@console_ns.response(404, "Crawl job not found")
@console_ns.response(400, "Invalid provider")
@ -72,14 +65,11 @@ class WebsiteCrawlStatusApi(Resource):
@login_required
@account_initialization_required
def get(self, job_id: str):
parser = reqparse.RequestParser().add_argument(
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
args = parser.parse_args()
args = WebsiteCrawlStatusQuery.model_validate(request.args.to_dict())
# Create typed request and validate
try:
api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id)
api_request = WebsiteCrawlStatusApiRequest.from_args(args.model_dump(), job_id)
except ValueError as e:
raise WebsiteCrawlError(str(e))

View File

@ -1,9 +1,11 @@
import logging
from flask import request
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@ -31,6 +33,16 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(console_ns, TextToAudioPayload)
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/audio-to-text",
endpoint="installed_app_audio",
@ -76,23 +88,15 @@ class ChatAudioApi(InstalledAppResource):
endpoint="installed_app_text",
)
class ChatTextApi(InstalledAppResource):
@console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
def post(self, installed_app):
from flask_restx import reqparse
app_model = installed_app.app
try:
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, required=False, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
payload = TextToAudioPayload.model_validate(console_ns.payload or {})
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
message_id = payload.message_id
text = payload.text
voice = payload.voice
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
return response

View File

@ -1,9 +1,12 @@
import logging
from typing import Any, Literal
from uuid import UUID
from flask_restx import reqparse
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console.app.error import (
AppUnavailableError,
CompletionRequestError,
@ -25,7 +28,6 @@ from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode
@ -38,28 +40,56 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class CompletionMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = Field(default="explore_app")
class ChatMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str
files: list[dict[str, Any]] | None = None
conversation_id: str | None = None
parent_message_id: str | None = None
retriever_from: str = Field(default="explore_app")
@field_validator("conversation_id", "parent_message_id", mode="before")
@classmethod
def normalize_uuid(cls, value: str | UUID | None) -> str | None:
"""
Accept blank IDs and validate UUID format when provided.
"""
if not value:
return None
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("must be a valid UUID") from exc
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
# define completion api for user
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/completion-messages",
endpoint="installed_app_completion",
)
class CompletionApi(InstalledAppResource):
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args()
payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = args["response_mode"] == "streaming"
streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
installed_app.last_used_at = naive_utc_now()
@ -123,22 +153,15 @@ class CompletionStopApi(InstalledAppResource):
endpoint="installed_app_chat_completion",
)
class ChatApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
def post(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
)
args = parser.parse_args()
payload = ChatMessagePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
args["auto_generate_name"] = False

View File

@ -1,14 +1,18 @@
from flask_restx import marshal_with, reqparse
from flask_restx.inputs import int_range
from typing import Any
from uuid import UUID
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.model import AppMode
@ -19,29 +23,51 @@ from services.web_conversation_service import WebConversationService
from .. import console_ns
class ConversationListQuery(BaseModel):
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
pinned: bool | None = None
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/conversations",
endpoint="installed_app_conversations",
)
class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
)
args = parser.parse_args()
pinned = None
if "pinned" in args and args["pinned"] is not None:
pinned = args["pinned"] == "true"
raw_args: dict[str, Any] = {
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", default=20, type=int),
"pinned": request.args.get("pinned"),
}
if raw_args["last_id"] is None:
raw_args["last_id"] = None
pinned_value = raw_args["pinned"]
if isinstance(pinned_value, str):
raw_args["pinned"] = pinned_value == "true"
args = ConversationListQuery.model_validate(raw_args)
try:
if not isinstance(current_user, Account):
@ -51,10 +77,10 @@ class ConversationListApi(InstalledAppResource):
session=session,
app_model=app_model,
user=current_user,
last_id=args["last_id"],
limit=args["limit"],
last_id=str(args.last_id) if args.last_id else None,
limit=args.limit,
invoke_from=InvokeFrom.EXPLORE,
pinned=pinned,
pinned=args.pinned,
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@ -88,6 +114,7 @@ class ConversationApi(InstalledAppResource):
)
class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields)
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
@ -96,18 +123,13 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id)
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, location="json")
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
)
args = parser.parse_args()
payload = ConversationRenamePayload.model_validate(console_ns.payload or {})
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return ConversationService.rename(
app_model, conversation_id, current_user, args["name"], args["auto_generate"]
app_model, conversation_id, current_user, payload.name, payload.auto_generate
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -1,9 +1,13 @@
import logging
from typing import Literal
from uuid import UUID
from flask_restx import marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
CompletionRequestError,
@ -22,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import uuid_value
from libs.login import current_account_with_tenant
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -40,12 +43,31 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUID
first_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
class MoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"]
register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, MoreLikeThisQuery)
@console_ns.route(
"/installed-apps/<uuid:installed_app_id>/messages",
endpoint="installed_app_messages",
)
class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
@ -53,18 +75,15 @@ class MessageListApi(InstalledAppResource):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
args = MessageListQuery.model_validate(request.args.to_dict())
try:
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
app_model,
current_user,
str(args.conversation_id),
str(args.first_id) if args.first_id else None,
args.limit,
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -77,26 +96,22 @@ class MessageListApi(InstalledAppResource):
endpoint="installed_app_message_feedback",
)
class MessageFeedbackApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
message_id = str(message_id)
parser = (
reqparse.RequestParser()
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
.add_argument("content", type=str, location="json")
)
args = parser.parse_args()
payload = MessageFeedbackPayload.model_validate(console_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=current_user,
rating=args.get("rating"),
content=args.get("content"),
rating=payload.rating,
content=payload.content,
)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@ -109,6 +124,7 @@ class MessageFeedbackApi(InstalledAppResource):
endpoint="installed_app_more_like_this",
)
class MessageMoreLikeThisApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
@ -117,12 +133,9 @@ class MessageMoreLikeThisApi(InstalledAppResource):
message_id = str(message_id)
parser = reqparse.RequestParser().add_argument(
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
)
args = parser.parse_args()
args = MoreLikeThisQuery.model_validate(request.args.to_dict())
streaming = args["response_mode"] == "streaming"
streaming = args.response_mode == "streaming"
try:
response = AppGenerateService.generate_more_like_this(

View File

@ -1,16 +1,33 @@
from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from uuid import UUID
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from libs.helper import TimestampField
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUID
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
feedback_fields = {"rating": fields.String}
message_fields = {
@ -33,32 +50,33 @@ class SavedMessageListApi(InstalledAppResource):
}
@marshal_with(saved_message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
args = SavedMessageListQuery.model_validate(request.args.to_dict())
return SavedMessageService.pagination_by_last_id(
app_model,
current_user,
str(args.last_id) if args.last_id else None,
args.limit,
)
args = parser.parse_args()
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
def post(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
args = parser.parse_args()
payload = SavedMessageCreatePayload.model_validate(console_ns.payload or {})
try:
SavedMessageService.save(app_model, current_user, args["message_id"])
SavedMessageService.save(app_model, current_user, str(payload.message_id))
except MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@ -1,8 +1,10 @@
import logging
from typing import Any
from flask_restx import reqparse
from pydantic import BaseModel
from werkzeug.exceptions import InternalServerError
from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
@ -32,8 +34,17 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
register_schema_model(console_ns, WorkflowRunPayload)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
def post(self, installed_app: InstalledApp):
"""
Run workflow
@ -46,12 +57,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
payload = WorkflowRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True

View File

@ -45,6 +45,9 @@ class FileApi(Resource):
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
}, 200
@setup_required

View File

@ -230,7 +230,7 @@ class ModelProviderModelApi(Resource):
return {"result": "success"}, 200
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True)
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@ -282,9 +282,10 @@ class ModelProviderModelCredentialApi(Resource):
tenant_id=tenant_id, provider_name=provider
)
else:
model_type = args.model_type
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model
tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model
)
return jsonable_encoder(

View File

@ -46,8 +46,8 @@ class PluginDebuggingKeyApi(Resource):
class ParserList(BaseModel):
page: int = Field(default=1)
page_size: int = Field(default=256)
page: int = Field(default=1, ge=1, description="Page number")
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
reg(ParserList)
@ -106,8 +106,8 @@ class ParserPluginIdentifierQuery(BaseModel):
class ParserTasks(BaseModel):
page: int
page_size: int
page: int = Field(default=1, ge=1, description="Page number")
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
class ParserMarketplaceUpgrade(BaseModel):

View File

@ -22,7 +22,12 @@ from services.trigger.trigger_subscription_builder_service import TriggerSubscri
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
from .. import console_ns
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from ..wraps import (
account_initialization_required,
edit_permission_required,
is_admin_or_owner_required,
setup_required,
)
logger = logging.getLogger(__name__)
@ -72,7 +77,7 @@ class TriggerProviderInfoApi(Resource):
class TriggerSubscriptionListApi(Resource):
@setup_required
@login_required
@is_admin_or_owner_required
@edit_permission_required
@account_initialization_required
def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider"""
@ -104,7 +109,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
@console_ns.expect(parser)
@setup_required
@login_required
@is_admin_or_owner_required
@edit_permission_required
@account_initialization_required
def post(self, provider):
"""Add a new subscription instance for a trigger provider"""
@ -133,6 +138,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
class TriggerSubscriptionBuilderGetApi(Resource):
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get a subscription instance for a trigger provider"""
@ -155,7 +161,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
@console_ns.expect(parser_api)
@setup_required
@login_required
@is_admin_or_owner_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider"""
@ -200,6 +206,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
@console_ns.expect(parser_update_api)
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Update a subscription instance for a trigger provider"""
@ -233,6 +240,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
class TriggerSubscriptionBuilderLogsApi(Resource):
@setup_required
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
"""Get the request logs for a subscription instance for a trigger provider"""
@ -255,7 +263,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
@console_ns.expect(parser_update_api)
@setup_required
@login_required
@is_admin_or_owner_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider"""

View File

@ -331,3 +331,91 @@ def is_admin_or_owner_required(f: Callable[P, R]):
return f(*args, **kwargs)
return decorated_function
def annotation_import_rate_limit(view: Callable[P, R]):
"""
Rate limiting decorator for annotation import operations.
Implements sliding window rate limiting with two tiers:
- Short-term: Configurable requests per minute (default: 5)
- Long-term: Configurable requests per hour (default: 20)
Uses Redis ZSET for distributed rate limiting across multiple instances.
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
current_time = int(time.time() * 1000)
# Check per-minute rate limit
minute_key = f"annotation_import_rate_limit:{current_tenant_id}:1min"
redis_client.zadd(minute_key, {current_time: current_time})
redis_client.zremrangebyscore(minute_key, 0, current_time - 60000)
minute_count = redis_client.zcard(minute_key)
redis_client.expire(minute_key, 120) # 2 minutes TTL
if minute_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE:
abort(
429,
f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE} "
f"requests per minute allowed. Please try again later.",
)
# Check per-hour rate limit
hour_key = f"annotation_import_rate_limit:{current_tenant_id}:1hour"
redis_client.zadd(hour_key, {current_time: current_time})
redis_client.zremrangebyscore(hour_key, 0, current_time - 3600000)
hour_count = redis_client.zcard(hour_key)
redis_client.expire(hour_key, 7200) # 2 hours TTL
if hour_count > dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR:
abort(
429,
f"Too many annotation import requests. Maximum {dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR} "
f"requests per hour allowed. Please try again later.",
)
return view(*args, **kwargs)
return decorated
def annotation_import_concurrency_limit(view: Callable[P, R]):
"""
Concurrency control decorator for annotation import operations.
Limits the number of concurrent import tasks per tenant to prevent
resource exhaustion and ensure fair resource allocation.
Uses Redis ZSET to track active import jobs with automatic cleanup
of stale entries (jobs older than 2 minutes).
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
current_time = int(time.time() * 1000)
active_jobs_key = f"annotation_import_active:{current_tenant_id}"
# Clean up stale entries (jobs that should have completed or timed out)
stale_threshold = current_time - 120000 # 2 minutes ago
redis_client.zremrangebyscore(active_jobs_key, 0, stale_threshold)
# Check current active job count
active_count = redis_client.zcard(active_jobs_key)
if active_count >= dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT:
abort(
429,
f"Too many concurrent import tasks. Maximum {dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT} "
f"concurrent imports allowed per workspace. Please wait for existing imports to complete.",
)
# Allow the request to proceed
# The actual job registration will happen in the service layer
return view(*args, **kwargs)
return decorated

View File

@ -1,29 +1,38 @@
from flask_restx import Resource, reqparse
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required
from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only
from tasks.mail_inner_task import send_inner_email_task
_mail_parser = (
reqparse.RequestParser()
.add_argument("to", type=str, action="append", required=True)
.add_argument("subject", type=str, required=True)
.add_argument("body", type=str, required=True)
.add_argument("substitutions", type=dict, required=False)
)
class InnerMailPayload(BaseModel):
to: list[str] = Field(description="Recipient email addresses", min_length=1)
subject: str
body: str
substitutions: dict[str, Any] | None = None
register_schema_model(inner_api_ns, InnerMailPayload)
class BaseMail(Resource):
"""Shared logic for sending an inner email."""
@inner_api_ns.doc("send_inner_mail")
@inner_api_ns.doc(description="Send internal email")
@inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
def post(self):
args = _mail_parser.parse_args()
send_inner_email_task.delay( # type: ignore
to=args["to"],
subject=args["subject"],
body=args["body"],
substitutions=args["substitutions"],
args = InnerMailPayload.model_validate(inner_api_ns.payload or {})
send_inner_email_task.delay(
to=args.to,
subject=args.subject,
body=args.body,
substitutions=args.substitutions, # type: ignore
)
return {"message": "success"}, 200
@ -34,7 +43,7 @@ class EnterpriseMail(BaseMail):
@inner_api_ns.doc("send_enterprise_mail")
@inner_api_ns.doc(description="Send internal email for enterprise features")
@inner_api_ns.expect(_mail_parser)
@inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
@inner_api_ns.doc(
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
)
@ -56,7 +65,7 @@ class BillingMail(BaseMail):
@inner_api_ns.doc("send_billing_mail")
@inner_api_ns.doc(description="Send internal email for billing notifications")
@inner_api_ns.expect(_mail_parser)
@inner_api_ns.expect(inner_api_ns.models[InnerMailPayload.__name__])
@inner_api_ns.doc(
responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"}
)

View File

@ -1,10 +1,9 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar, cast
from typing import ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
from flask_restx import reqparse
from pydantic import BaseModel
from sqlalchemy.orm import Session
@ -17,6 +16,11 @@ P = ParamSpec("P")
R = TypeVar("R")
class TenantUserPayload(BaseModel):
tenant_id: str
user_id: str
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
"""
Get current user
@ -67,58 +71,45 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
return user_model
def get_user_tenant(view: Callable[P, R] | None = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
# fetch json body
parser = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="json")
.add_argument("user_id", type=str, required=True, location="json")
)
def get_user_tenant(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
p = parser.parse_args()
user_id = payload.user_id
tenant_id = payload.tenant_id
user_id = cast(str, p.get("user_id"))
tenant_id = cast(str, p.get("tenant_id"))
if not tenant_id:
raise ValueError("tenant_id is required")
if not tenant_id:
raise ValueError("tenant_id is required")
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
tenant_model = (
db.session.query(Tenant)
.where(
Tenant.id == tenant_id,
)
.first()
try:
tenant_model = (
db.session.query(Tenant)
.where(
Tenant.id == tenant_id,
)
except Exception:
raise ValueError("tenant not found")
.first()
)
except Exception:
raise ValueError("tenant not found")
if not tenant_model:
raise ValueError("tenant not found")
if not tenant_model:
raise ValueError("tenant not found")
kwargs["tenant_model"] = tenant_model
kwargs["tenant_model"] = tenant_model
user = get_user(tenant_id, user_id)
kwargs["user_model"] = user
user = get_user(tenant_id, user_id)
kwargs["user_model"] = user
current_app.login_manager._update_request_context_with_user(user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
current_app.login_manager._update_request_context_with_user(user) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
return view_func(*args, **kwargs)
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
return decorated_view
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):

View File

@ -1,7 +1,9 @@
import json
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from controllers.common.schema import register_schema_models
from controllers.console.wraps import setup_required
from controllers.inner_api import inner_api_ns
from controllers.inner_api.wraps import enterprise_inner_api_only
@ -11,12 +13,25 @@ from models import Account
from services.account_service import TenantService
class WorkspaceCreatePayload(BaseModel):
name: str
owner_email: str
class WorkspaceOwnerlessPayload(BaseModel):
name: str
register_schema_models(inner_api_ns, WorkspaceCreatePayload, WorkspaceOwnerlessPayload)
@inner_api_ns.route("/enterprise/workspace")
class EnterpriseWorkspace(Resource):
@setup_required
@enterprise_inner_api_only
@inner_api_ns.doc("create_enterprise_workspace")
@inner_api_ns.doc(description="Create a new enterprise workspace with owner assignment")
@inner_api_ns.expect(inner_api_ns.models[WorkspaceCreatePayload.__name__])
@inner_api_ns.doc(
responses={
200: "Workspace created successfully",
@ -25,18 +40,13 @@ class EnterpriseWorkspace(Resource):
}
)
def post(self):
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("owner_email", type=str, required=True, location="json")
)
args = parser.parse_args()
args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {})
account = db.session.query(Account).filter_by(email=args["owner_email"]).first()
account = db.session.query(Account).filter_by(email=args.owner_email).first()
if account is None:
return {"message": "owner account not found."}, 404
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
TenantService.create_tenant_member(tenant, account, role="owner")
tenant_was_created.send(tenant)
@ -62,6 +72,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
@enterprise_inner_api_only
@inner_api_ns.doc("create_enterprise_workspace_ownerless")
@inner_api_ns.doc(description="Create a new enterprise workspace without initial owner assignment")
@inner_api_ns.expect(inner_api_ns.models[WorkspaceOwnerlessPayload.__name__])
@inner_api_ns.doc(
responses={
200: "Workspace created successfully",
@ -70,10 +81,9 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
}
)
def post(self):
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {})
tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
tenant_was_created.send(tenant)

View File

@ -1,10 +1,11 @@
from typing import Union
from typing import Any, Union
from flask import Response
from flask_restx import Resource, reqparse
from pydantic import ValidationError
from flask_restx import Resource
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
@ -24,27 +25,19 @@ class MCPRequestError(Exception):
super().__init__(message)
def int_or_str(value):
"""Validate that a value is either an integer or string."""
if isinstance(value, (int, str)):
return value
else:
return None
class MCPRequestPayload(BaseModel):
jsonrpc: str = Field(description="JSON-RPC version (should be '2.0')")
method: str = Field(description="The method to invoke")
params: dict[str, Any] | None = Field(default=None, description="Parameters for the method")
id: int | str | None = Field(default=None, description="Request ID for tracking responses")
# Define parser for both documentation and validation
mcp_request_parser = (
reqparse.RequestParser()
.add_argument("jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')")
.add_argument("method", type=str, required=True, location="json", help="The method to invoke")
.add_argument("params", type=dict, required=False, location="json", help="Parameters for the method")
.add_argument("id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses")
)
register_schema_model(mcp_ns, MCPRequestPayload)
@mcp_ns.route("/server/<string:server_code>/mcp")
class MCPAppApi(Resource):
@mcp_ns.expect(mcp_request_parser)
@mcp_ns.expect(mcp_ns.models[MCPRequestPayload.__name__])
@mcp_ns.doc("handle_mcp_request")
@mcp_ns.doc(description="Handle Model Context Protocol (MCP) requests for a specific server")
@mcp_ns.doc(params={"server_code": "Unique identifier for the MCP server"})
@ -70,9 +63,9 @@ class MCPAppApi(Resource):
Raises:
ValidationError: Invalid request format or parameters
"""
args = mcp_request_parser.parse_args()
request_id: Union[int, str] | None = args.get("id")
mcp_request = self._parse_mcp_request(args)
args = MCPRequestPayload.model_validate(mcp_ns.payload or {})
request_id: Union[int, str] | None = args.id
mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))
with Session(db.engine, expire_on_commit=False) as session:
# Get MCP server and app

View File

@ -1,9 +1,11 @@
from typing import Literal
from flask import request
from flask_restx import Api, Namespace, Resource, fields, reqparse
from flask_restx import Api, Namespace, Resource, fields
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
@ -12,26 +14,24 @@ from fields.annotation_fields import annotation_fields, build_annotation_model
from models.model import App
from services.annotation_service import AppAnnotationService
# Define parsers for annotation API
annotation_create_parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json", help="Annotation question")
.add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
)
annotation_reply_action_parser = (
reqparse.RequestParser()
.add_argument(
"score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
)
.add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name")
.add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name")
)
class AnnotationCreatePayload(BaseModel):
question: str = Field(description="Annotation question")
answer: str = Field(description="Annotation answer")
class AnnotationReplyActionPayload(BaseModel):
score_threshold: float = Field(description="Score threshold for annotation matching")
embedding_provider_name: str = Field(description="Embedding provider name")
embedding_model_name: str = Field(description="Embedding model name")
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)
@service_api_ns.route("/apps/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource):
@service_api_ns.expect(annotation_reply_action_parser)
@service_api_ns.expect(service_api_ns.models[AnnotationReplyActionPayload.__name__])
@service_api_ns.doc("annotation_reply_action")
@service_api_ns.doc(description="Enable or disable annotation reply feature")
@service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"})
@ -44,7 +44,7 @@ class AnnotationReplyActionApi(Resource):
@validate_app_token
def post(self, app_model: App, action: Literal["enable", "disable"]):
"""Enable or disable annotation reply feature."""
args = annotation_reply_action_parser.parse_args()
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
@ -126,7 +126,7 @@ class AnnotationListApi(Resource):
"page": page,
}
@service_api_ns.expect(annotation_create_parser)
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
@service_api_ns.doc("create_annotation")
@service_api_ns.doc(description="Create a new annotation")
@service_api_ns.doc(
@ -139,14 +139,14 @@ class AnnotationListApi(Resource):
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
def post(self, app_model: App):
"""Create a new annotation."""
args = annotation_create_parser.parse_args()
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
return annotation, 201
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource):
@service_api_ns.expect(annotation_create_parser)
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
@service_api_ns.doc("update_annotation")
@service_api_ns.doc(description="Update an existing annotation")
@service_api_ns.doc(params={"annotation_id": "Annotation ID"})
@ -163,7 +163,7 @@ class AnnotationUpdateDeleteApi(Resource):
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
args = annotation_create_parser.parse_args()
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation

View File

@ -1,10 +1,12 @@
import logging
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
AppUnavailableError,
@ -84,19 +86,19 @@ class AudioApi(Resource):
raise InternalServerError()
# Define parser for text-to-audio API
text_to_audio_parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, required=False, location="json", help="Message ID")
.add_argument("voice", type=str, location="json", help="Voice to use for TTS")
.add_argument("text", type=str, location="json", help="Text to convert to audio")
.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
)
class TextToAudioPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
voice: str | None = Field(default=None, description="Voice to use for TTS")
text: str | None = Field(default=None, description="Text to convert to audio")
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(service_api_ns, TextToAudioPayload)
@service_api_ns.route("/text-to-audio")
class TextApi(Resource):
@service_api_ns.expect(text_to_audio_parser)
@service_api_ns.expect(service_api_ns.models[TextToAudioPayload.__name__])
@service_api_ns.doc("text_to_audio")
@service_api_ns.doc(description="Convert text to audio using text-to-speech")
@service_api_ns.doc(
@ -114,11 +116,11 @@ class TextApi(Resource):
Converts the provided text to audio using the specified voice.
"""
try:
args = text_to_audio_parser.parse_args()
payload = TextToAudioPayload.model_validate(service_api_ns.payload or {})
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
message_id = payload.message_id
text = payload.text
voice = payload.voice
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)

View File

@ -1,10 +1,14 @@
import logging
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
AppUnavailableError,
@ -26,7 +30,6 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
@ -36,40 +39,46 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
# Define parser for completion API
completion_parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion")
.add_argument("query", type=str, location="json", default="", help="The query string")
.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
)
class CompletionRequestPayload(BaseModel):
inputs: dict[str, Any]
query: str = Field(default="")
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = Field(default="dev")
# Define parser for chat API
chat_parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat")
.add_argument("query", type=str, required=True, location="json", help="The chat query")
.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
.add_argument(
"auto_generate_name",
type=bool,
required=False,
default=True,
location="json",
help="Auto generate conversation name",
)
.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
)
class ChatRequestPayload(BaseModel):
inputs: dict[str, Any]
query: str
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID")
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
@field_validator("conversation_id", mode="before")
@classmethod
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
if isinstance(value, str):
value = value.strip()
if not value:
return None
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("conversation_id must be a valid UUID") from exc
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
@service_api_ns.route("/completion-messages")
class CompletionApi(Resource):
@service_api_ns.expect(completion_parser)
@service_api_ns.expect(service_api_ns.models[CompletionRequestPayload.__name__])
@service_api_ns.doc("create_completion")
@service_api_ns.doc(description="Create a completion for the given prompt")
@service_api_ns.doc(
@ -91,12 +100,13 @@ class CompletionApi(Resource):
if app_model.mode != AppMode.COMPLETION:
raise AppUnavailableError()
args = completion_parser.parse_args()
payload = CompletionRequestPayload.model_validate(service_api_ns.payload or {})
external_trace_id = get_external_trace_id(request)
args = payload.model_dump(exclude_none=True)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args["response_mode"] == "streaming"
streaming = payload.response_mode == "streaming"
args["auto_generate_name"] = False
@ -162,7 +172,7 @@ class CompletionStopApi(Resource):
@service_api_ns.route("/chat-messages")
class ChatApi(Resource):
@service_api_ns.expect(chat_parser)
@service_api_ns.expect(service_api_ns.models[ChatRequestPayload.__name__])
@service_api_ns.doc("create_chat_message")
@service_api_ns.doc(description="Send a message in a chat conversation")
@service_api_ns.doc(
@ -186,13 +196,14 @@ class ChatApi(Resource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
args = chat_parser.parse_args()
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
external_trace_id = get_external_trace_id(request)
args = payload.model_dump(exclude_none=True)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args["response_mode"] == "streaming"
streaming = payload.response_mode == "streaming"
try:
response = AppGenerateService.generate(

View File

@ -1,10 +1,15 @@
from flask_restx import Resource, reqparse
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from flask_restx.inputs import int_range
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
@ -19,74 +24,51 @@ from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model,
)
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
# Define parsers for conversation APIs
conversation_list_parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination")
.add_argument(
"limit",
type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of conversations to return",
)
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
help="Sort order for conversations",
)
)
conversation_rename_parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, location="json", help="New conversation name")
.add_argument(
"auto_generate",
type=bool,
required=False,
default=False,
location="json",
help="Auto-generate conversation name",
class ConversationListQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort order for conversations"
)
)
conversation_variables_parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination")
.add_argument(
"limit",
type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of variables to return",
)
)
conversation_variable_update_parser = reqparse.RequestParser().add_argument(
# using lambda is for passing the already-typed value without modification
# if no lambda, it will be converted to string
# the string cannot be converted using json.loads
"value",
required=True,
location="json",
type=lambda x: x,
help="New value for the conversation variable",
class ConversationRenamePayload(BaseModel):
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
class ConversationVariablesQuery(BaseModel):
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
class ConversationVariableUpdatePayload(BaseModel):
value: Any
register_schema_models(
service_api_ns,
ConversationListQuery,
ConversationRenamePayload,
ConversationVariablesQuery,
ConversationVariableUpdatePayload,
)
@service_api_ns.route("/conversations")
class ConversationApi(Resource):
@service_api_ns.expect(conversation_list_parser)
@service_api_ns.expect(service_api_ns.models[ConversationListQuery.__name__])
@service_api_ns.doc("list_conversations")
@service_api_ns.doc(description="List all conversations for the current user")
@service_api_ns.doc(
@ -107,7 +89,8 @@ class ConversationApi(Resource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
args = conversation_list_parser.parse_args()
query_args = ConversationListQuery.model_validate(request.args.to_dict())
last_id = str(query_args.last_id) if query_args.last_id else None
try:
with Session(db.engine) as session:
@ -115,10 +98,10 @@ class ConversationApi(Resource):
session=session,
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
last_id=last_id,
limit=query_args.limit,
invoke_from=InvokeFrom.SERVICE_API,
sort_by=args["sort_by"],
sort_by=query_args.sort_by,
)
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@ -155,7 +138,7 @@ class ConversationDetailApi(Resource):
@service_api_ns.route("/conversations/<uuid:c_id>/name")
class ConversationRenameApi(Resource):
@service_api_ns.expect(conversation_rename_parser)
@service_api_ns.expect(service_api_ns.models[ConversationRenamePayload.__name__])
@service_api_ns.doc("rename_conversation")
@service_api_ns.doc(description="Rename a conversation or auto-generate a name")
@service_api_ns.doc(params={"c_id": "Conversation ID"})
@ -176,17 +159,17 @@ class ConversationRenameApi(Resource):
conversation_id = str(c_id)
args = conversation_rename_parser.parse_args()
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
try:
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@service_api_ns.route("/conversations/<uuid:c_id>/variables")
class ConversationVariablesApi(Resource):
@service_api_ns.expect(conversation_variables_parser)
@service_api_ns.expect(service_api_ns.models[ConversationVariablesQuery.__name__])
@service_api_ns.doc("list_conversation_variables")
@service_api_ns.doc(description="List all variables for a conversation")
@service_api_ns.doc(params={"c_id": "Conversation ID"})
@ -211,11 +194,12 @@ class ConversationVariablesApi(Resource):
conversation_id = str(c_id)
args = conversation_variables_parser.parse_args()
query_args = ConversationVariablesQuery.model_validate(request.args.to_dict())
last_id = str(query_args.last_id) if query_args.last_id else None
try:
return ConversationService.get_conversational_variable(
app_model, conversation_id, end_user, args["limit"], args["last_id"]
app_model, conversation_id, end_user, query_args.limit, last_id
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -223,7 +207,7 @@ class ConversationVariablesApi(Resource):
@service_api_ns.route("/conversations/<uuid:c_id>/variables/<uuid:variable_id>")
class ConversationVariableDetailApi(Resource):
@service_api_ns.expect(conversation_variable_update_parser)
@service_api_ns.expect(service_api_ns.models[ConversationVariableUpdatePayload.__name__])
@service_api_ns.doc("update_conversation_variable")
@service_api_ns.doc(description="Update a conversation variable's value")
@service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"})
@ -250,11 +234,11 @@ class ConversationVariableDetailApi(Resource):
conversation_id = str(c_id)
variable_id = str(variable_id)
args = conversation_variable_update_parser.parse_args()
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
try:
return ConversationService.update_conversation_variable(
app_model, conversation_id, variable_id, end_user, args["value"]
app_model, conversation_id, variable_id, end_user, payload.value
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -1,9 +1,11 @@
import logging
from urllib.parse import quote
from flask import Response
from flask_restx import Resource, reqparse
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
FileAccessDeniedError,
@ -17,10 +19,11 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile
logger = logging.getLogger(__name__)
# Define parser for file preview API
file_preview_parser = reqparse.RequestParser().add_argument(
"as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment"
)
class FilePreviewQuery(BaseModel):
as_attachment: bool = Field(default=False, description="Download as attachment")
register_schema_model(service_api_ns, FilePreviewQuery)
@service_api_ns.route("/files/<uuid:file_id>/preview")
@ -32,7 +35,7 @@ class FilePreviewApi(Resource):
Files can only be accessed if they belong to messages within the requesting app's context.
"""
@service_api_ns.expect(file_preview_parser)
@service_api_ns.expect(service_api_ns.models[FilePreviewQuery.__name__])
@service_api_ns.doc("preview_file")
@service_api_ns.doc(description="Preview or download a file uploaded via Service API")
@service_api_ns.doc(params={"file_id": "UUID of the file to preview"})
@ -55,7 +58,7 @@ class FilePreviewApi(Resource):
file_id = str(file_id)
# Parse query parameters
args = file_preview_parser.parse_args()
args = FilePreviewQuery.model_validate(request.args.to_dict())
# Validate file ownership and get file objects
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
@ -67,7 +70,7 @@ class FilePreviewApi(Resource):
raise FileNotFoundError(f"Failed to load file content: {str(e)}")
# Build response with appropriate headers
response = self._build_file_response(generator, upload_file, args["as_attachment"])
response = self._build_file_response(generator, upload_file, args.as_attachment)
return response

View File

@ -1,11 +1,15 @@
import json
import logging
from typing import Literal
from uuid import UUID
from flask_restx import Api, Namespace, Resource, fields, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import Namespace, Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
@ -13,7 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import build_message_file_model
from fields.message_fields import build_agent_thought_model, build_feedback_model
from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@ -25,42 +29,26 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
# Define parsers for message APIs
message_list_parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID")
.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination")
.add_argument(
"limit",
type=int_range(1, 100),
required=False,
default=20,
location="args",
help="Number of messages to return",
)
)
message_feedback_parser = (
reqparse.RequestParser()
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating")
.add_argument("content", type=str, location="json", help="Feedback content")
)
feedback_list_parser = (
reqparse.RequestParser()
.add_argument("page", type=int, default=1, location="args", help="Page number")
.add_argument(
"limit",
type=int_range(1, 101),
required=False,
default=20,
location="args",
help="Number of feedbacks per page",
)
)
class MessageListQuery(BaseModel):
conversation_id: UUID
first_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
def build_message_model(api_or_ns: Api | Namespace):
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class FeedbackListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
def build_message_model(api_or_ns: Namespace):
"""Build the message model for the API or Namespace."""
# First build the nested models
feedback_model = build_feedback_model(api_or_ns)
@ -91,7 +79,7 @@ def build_message_model(api_or_ns: Api | Namespace):
return api_or_ns.model("Message", message_fields)
def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the message infinite scroll pagination model for the API or Namespace."""
# Build the nested message model first
message_model = build_message_model(api_or_ns)
@ -106,7 +94,7 @@ def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
@service_api_ns.route("/messages")
class MessageListApi(Resource):
@service_api_ns.expect(message_list_parser)
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
@service_api_ns.doc("list_messages")
@service_api_ns.doc(description="List messages in a conversation")
@service_api_ns.doc(
@ -127,11 +115,13 @@ class MessageListApi(Resource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
args = message_list_parser.parse_args()
query_args = MessageListQuery.model_validate(request.args.to_dict())
conversation_id = str(query_args.conversation_id)
first_id = str(query_args.first_id) if query_args.first_id else None
try:
return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
app_model, end_user, conversation_id, first_id, query_args.limit
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -141,7 +131,7 @@ class MessageListApi(Resource):
@service_api_ns.route("/messages/<uuid:message_id>/feedbacks")
class MessageFeedbackApi(Resource):
@service_api_ns.expect(message_feedback_parser)
@service_api_ns.expect(service_api_ns.models[MessageFeedbackPayload.__name__])
@service_api_ns.doc("create_message_feedback")
@service_api_ns.doc(description="Submit feedback for a message")
@service_api_ns.doc(params={"message_id": "Message ID"})
@ -160,15 +150,15 @@ class MessageFeedbackApi(Resource):
"""
message_id = str(message_id)
args = message_feedback_parser.parse_args()
payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=end_user,
rating=args.get("rating"),
content=args.get("content"),
rating=payload.rating,
content=payload.content,
)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@ -178,7 +168,7 @@ class MessageFeedbackApi(Resource):
@service_api_ns.route("/app/feedbacks")
class AppGetFeedbacksApi(Resource):
@service_api_ns.expect(feedback_list_parser)
@service_api_ns.expect(service_api_ns.models[FeedbackListQuery.__name__])
@service_api_ns.doc("get_app_feedbacks")
@service_api_ns.doc(description="Get all feedbacks for the application")
@service_api_ns.doc(
@ -193,8 +183,8 @@ class AppGetFeedbacksApi(Resource):
Returns paginated list of all feedback submitted for messages in this app.
"""
args = feedback_list_parser.parse_args()
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"])
query_args = FeedbackListQuery.model_validate(request.args.to_dict())
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=query_args.page, limit=query_args.limit)
return {"data": feedbacks}

View File

@ -1,12 +1,14 @@
import logging
from typing import Any, Literal
from dateutil.parser import isoparse
from flask import request
from flask_restx import Api, Namespace, Resource, fields, reqparse
from flask_restx.inputs import int_range
from flask_restx import Api, Namespace, Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
CompletionRequestError,
@ -41,37 +43,25 @@ from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__)
# Define parsers for workflow APIs
workflow_run_parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
)
workflow_log_parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
.add_argument("created_at__before", type=str, location="args")
.add_argument("created_at__after", type=str, location="args")
.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
class WorkflowLogQuery(BaseModel):
keyword: str | None = None
status: Literal["succeeded", "failed", "stopped"] | None = None
created_at__before: str | None = None
created_at__after: str | None = None
created_by_end_user_session_id: str | None = None
created_by_account: str | None = None
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
workflow_run_fields = {
"id": fields.String,
@ -130,7 +120,7 @@ class WorkflowRunDetailApi(Resource):
@service_api_ns.route("/workflows/run")
class WorkflowRunApi(Resource):
@service_api_ns.expect(workflow_run_parser)
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
@service_api_ns.doc("run_workflow")
@service_api_ns.doc(description="Execute a workflow")
@service_api_ns.doc(
@ -154,11 +144,12 @@ class WorkflowRunApi(Resource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
args = workflow_run_parser.parse_args()
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args.get("response_mode") == "streaming"
streaming = payload.response_mode == "streaming"
try:
response = AppGenerateService.generate(
@ -185,7 +176,7 @@ class WorkflowRunApi(Resource):
@service_api_ns.route("/workflows/<string:workflow_id>/run")
class WorkflowRunByIdApi(Resource):
@service_api_ns.expect(workflow_run_parser)
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
@service_api_ns.doc("run_workflow_by_id")
@service_api_ns.doc(description="Execute a specific workflow by ID")
@service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
@ -209,7 +200,8 @@ class WorkflowRunByIdApi(Resource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
args = workflow_run_parser.parse_args()
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
# Add workflow_id to args for AppGenerateService
args["workflow_id"] = workflow_id
@ -217,7 +209,7 @@ class WorkflowRunByIdApi(Resource):
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args.get("response_mode") == "streaming"
streaming = payload.response_mode == "streaming"
try:
response = AppGenerateService.generate(
@ -279,7 +271,7 @@ class WorkflowTaskStopApi(Resource):
@service_api_ns.route("/workflows/logs")
class WorkflowAppLogApi(Resource):
@service_api_ns.expect(workflow_log_parser)
@service_api_ns.expect(service_api_ns.models[WorkflowLogQuery.__name__])
@service_api_ns.doc("get_workflow_logs")
@service_api_ns.doc(description="Get workflow execution logs")
@service_api_ns.doc(
@ -295,14 +287,11 @@ class WorkflowAppLogApi(Resource):
Returns paginated workflow execution logs with filtering options.
"""
args = workflow_log_parser.parse_args()
args = WorkflowLogQuery.model_validate(request.args.to_dict())
args.status = WorkflowExecutionStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = isoparse(args.created_at__after)
status = WorkflowExecutionStatus(args.status) if args.status else None
created_at_before = isoparse(args.created_at__before) if args.created_at__before else None
created_at_after = isoparse(args.created_at__after) if args.created_at__after else None
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
@ -311,9 +300,9 @@ class WorkflowAppLogApi(Resource):
session=session,
app_model=app_model,
keyword=args.keyword,
status=args.status,
created_at_before=args.created_at__before,
created_at_after=args.created_at__after,
status=status,
created_at_before=created_at_before,
created_at_after=created_at_after,
page=args.page,
limit=args.limit,
created_by_end_user_session_id=args.created_by_end_user_session_id,

View File

@ -1,10 +1,12 @@
from typing import Any, Literal, cast
from flask import request
from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
@ -18,173 +20,83 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user
from libs.validators import validate_description_length
from models.account import Account
from models.dataset import Dataset, DatasetPermissionEnum
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
class DatasetCreatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field(default="", description="Dataset description (max 400 chars)", max_length=400)
indexing_technique: Literal["high_quality", "economy"] | None = None
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
external_knowledge_api_id: str | None = None
provider: str = "vendor"
external_knowledge_id: str | None = None
retrieval_model: RetrievalModel | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
# Define parsers for dataset operations
dataset_create_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
required=False,
nullable=False,
)
.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
default="_validate_name",
)
.add_argument(
"provider",
type=str,
nullable=True,
required=False,
default="vendor",
)
.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
class DatasetUpdatePayload(BaseModel):
name: str | None = Field(default=None, min_length=1, max_length=40)
description: str | None = Field(default=None, description="Dataset description (max 400 chars)", max_length=400)
indexing_technique: Literal["high_quality", "economy"] | None = None
permission: DatasetPermissionEnum | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: RetrievalModel | None = None
partial_member_list: list[str] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
dataset_update_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument("embedding_model_provider", type=str, location="json", help="Invalid embedding model provider.")
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
)
tag_create_parser = reqparse.RequestParser().add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=lambda x: x
if x and 1 <= len(x) <= 50
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
)
class TagNamePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=50)
tag_update_parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=lambda x: x
if x and 1 <= len(x) <= 50
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
)
.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
)
tag_delete_parser = reqparse.RequestParser().add_argument(
"tag_id", nullable=False, required=True, help="Id of a tag.", type=str
)
class TagCreatePayload(TagNamePayload):
pass
tag_binding_parser = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument(
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
)
)
tag_unbinding_parser = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
class TagUpdatePayload(TagNamePayload):
tag_id: str
class TagDeletePayload(BaseModel):
tag_id: str
class TagBindingPayload(BaseModel):
tag_ids: list[str]
target_id: str
@field_validator("tag_ids")
@classmethod
def validate_tag_ids(cls, value: list[str]) -> list[str]:
if not value:
raise ValueError("Tag IDs is required.")
return value
class TagUnbindingPayload(BaseModel):
tag_id: str
target_id: str
register_schema_models(
service_api_ns,
DatasetCreatePayload,
DatasetUpdatePayload,
TagCreatePayload,
TagUpdatePayload,
TagDeletePayload,
TagBindingPayload,
TagUnbindingPayload,
)
@ -239,7 +151,7 @@ class DatasetListApi(DatasetApiResource):
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200
@service_api_ns.expect(dataset_create_parser)
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
@service_api_ns.doc("create_dataset")
@service_api_ns.doc(description="Create a new dataset")
@service_api_ns.doc(
@ -252,42 +164,41 @@ class DatasetListApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id):
"""Resource for creating datasets."""
args = dataset_create_parser.parse_args()
payload = DatasetCreatePayload.model_validate(service_api_ns.payload or {})
embedding_model_provider = args.get("embedding_model_provider")
embedding_model = args.get("embedding_model")
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
try:
assert isinstance(current_user, Account)
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
name=payload.name,
description=payload.description,
indexing_technique=payload.indexing_technique,
account=current_user,
permission=args["permission"],
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
if args["retrieval_model"] is not None
else None,
permission=str(payload.permission) if payload.permission else None,
provider=payload.provider,
external_knowledge_api_id=payload.external_knowledge_api_id,
external_knowledge_id=payload.external_knowledge_id,
embedding_model_provider=payload.embedding_model_provider,
embedding_model_name=payload.embedding_model,
retrieval_model=payload.retrieval_model,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@ -353,7 +264,7 @@ class DatasetApi(DatasetApiResource):
return data, 200
@service_api_ns.expect(dataset_update_parser)
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
@service_api_ns.doc("update_dataset")
@service_api_ns.doc(description="Update an existing dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -372,36 +283,45 @@ class DatasetApi(DatasetApiResource):
if dataset is None:
raise NotFound("Dataset not found.")
args = dataset_update_parser.parse_args()
data = request.get_json()
payload_dict = service_api_ns.payload or {}
payload = DatasetUpdatePayload.model_validate(payload_dict)
update_data = payload.model_dump(exclude_unset=True)
if payload.permission is not None:
update_data["permission"] = str(payload.permission)
if payload.retrieval_model is not None:
update_data["retrieval_model"] = payload.retrieval_model.model_dump()
# check embedding model setting
embedding_model_provider = data.get("embedding_model_provider")
embedding_model = data.get("embedding_model")
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if payload.indexing_technique == "high_quality" or embedding_model_provider:
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(
dataset.tenant_id, embedding_model_provider, embedding_model
)
retrieval_model = data.get("retrieval_model")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
dataset.tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
current_user,
dataset,
str(payload.permission) if payload.permission else None,
payload.partial_member_list,
)
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
dataset = DatasetService.update_dataset(dataset_id_str, update_data, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
@ -410,15 +330,10 @@ class DatasetApi(DatasetApiResource):
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
if payload.partial_member_list and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@ -556,7 +471,7 @@ class DatasetTagsApi(DatasetApiResource):
return tags, 200
@service_api_ns.expect(tag_create_parser)
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag")
@service_api_ns.doc(description="Add a knowledge type tag")
@service_api_ns.doc(
@ -574,14 +489,13 @@ class DatasetTagsApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_create_parser.parse_args()
args["type"] = "knowledge"
tag = TagService.save_tags(args)
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
@service_api_ns.expect(tag_update_parser)
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@service_api_ns.doc("update_dataset_tag")
@service_api_ns.doc(description="Update a knowledge type tag")
@service_api_ns.doc(
@ -598,10 +512,10 @@ class DatasetTagsApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_update_parser.parse_args()
args["type"] = "knowledge"
tag_id = args["tag_id"]
tag = TagService.update_tags(args, tag_id)
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
params = {"name": payload.name, "type": "knowledge"}
tag_id = payload.tag_id
tag = TagService.update_tags(params, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@ -609,7 +523,7 @@ class DatasetTagsApi(DatasetApiResource):
return response, 200
@service_api_ns.expect(tag_delete_parser)
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
@service_api_ns.doc("delete_dataset_tag")
@service_api_ns.doc(description="Delete a knowledge type tag")
@service_api_ns.doc(
@ -623,15 +537,15 @@ class DatasetTagsApi(DatasetApiResource):
@edit_permission_required
def delete(self, _, dataset_id):
"""Delete a knowledge type tag."""
args = tag_delete_parser.parse_args()
TagService.delete_tag(args["tag_id"])
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag(payload.tag_id)
return 204
@service_api_ns.route("/datasets/tags/binding")
class DatasetTagBindingApi(DatasetApiResource):
@service_api_ns.expect(tag_binding_parser)
@service_api_ns.expect(service_api_ns.models[TagBindingPayload.__name__])
@service_api_ns.doc("bind_dataset_tags")
@service_api_ns.doc(description="Bind tags to a dataset")
@service_api_ns.doc(
@ -648,16 +562,15 @@ class DatasetTagBindingApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_binding_parser.parse_args()
args["type"] = "knowledge"
TagService.save_tag_binding(args)
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
return 204
@service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource):
@service_api_ns.expect(tag_unbinding_parser)
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
@service_api_ns.doc("unbind_dataset_tag")
@service_api_ns.doc(description="Unbind a tag from a dataset")
@service_api_ns.doc(
@ -674,9 +587,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
args = tag_unbinding_parser.parse_args()
args["type"] = "knowledge"
TagService.delete_tag_binding(args)
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
return 204

View File

@ -3,8 +3,8 @@ from typing import Self
from uuid import UUID
from flask import request
from flask_restx import marshal, reqparse
from pydantic import BaseModel, model_validator
from flask_restx import marshal
from pydantic import BaseModel, Field, model_validator
from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound
@ -37,22 +37,19 @@ from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
# Define parsers for document operations
document_text_create_parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("text", type=str, required=True, nullable=False, location="json")
.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
.add_argument("original_document_id", type=str, required=False, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
class DocumentTextCreatePayload(BaseModel):
name: str
text: str
process_rule: ProcessRule | None = None
original_document_id: str | None = None
doc_form: str = Field(default="text_model")
doc_language: str = Field(default="English")
indexing_technique: str | None = None
retrieval_model: RetrievalModel | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -72,7 +69,7 @@ class DocumentTextUpdate(BaseModel):
return self
for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]:
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
@ -83,7 +80,7 @@ for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents."""
@service_api_ns.expect(document_text_create_parser)
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
@service_api_ns.doc("create_document_by_text")
@service_api_ns.doc(description="Create a new document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -99,7 +96,8 @@ class DocumentAddByTextApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create document by text."""
args = document_text_create_parser.parse_args()
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
@ -111,33 +109,29 @@ class DocumentAddByTextApi(DatasetApiResource):
if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.")
text = args.get("text")
name = args.get("name")
if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.")
embedding_model_provider = args.get("embedding_model_provider")
embedding_model = args.get("embedding_model")
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = args.get("retrieval_model")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
@ -174,7 +168,7 @@ class DocumentAddByTextApi(DatasetApiResource):
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents."""
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True)
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
@service_api_ns.doc("update_document_by_text")
@service_api_ns.doc(description="Update an existing document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -189,22 +183,23 @@ class DocumentUpdateByTextApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text."""
args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True)
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()
args = payload.model_dump(exclude_none=True)
if not dataset:
raise ValueError("Dataset does not exist.")
retrieval_model = args.get("retrieval_model")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.get("reranking_model")
and retrieval_model.get("reranking_model").get("reranking_provider_name")
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.get("reranking_model").get("reranking_provider_name"),
retrieval_model.get("reranking_model").get("reranking_model_name"),
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
# indexing_technique is already set in dataset since this is an update

View File

@ -1,9 +1,11 @@
from typing import Literal
from flask_login import current_user
from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import dataset_metadata_fields
@ -14,25 +16,18 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.metadata_service import MetadataService
# Define parsers for metadata APIs
metadata_create_parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json", help="Metadata type")
.add_argument("name", type=str, required=True, nullable=False, location="json", help="Metadata name")
)
metadata_update_parser = reqparse.RequestParser().add_argument(
"name", type=str, required=True, nullable=False, location="json", help="New metadata name"
)
class MetadataUpdatePayload(BaseModel):
name: str
document_metadata_parser = reqparse.RequestParser().add_argument(
"operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data"
)
register_schema_model(service_api_ns, MetadataUpdatePayload)
register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData)
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateServiceApi(DatasetApiResource):
@service_api_ns.expect(metadata_create_parser)
@service_api_ns.expect(service_api_ns.models[MetadataArgs.__name__])
@service_api_ns.doc("create_dataset_metadata")
@service_api_ns.doc(description="Create metadata for a dataset")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -46,8 +41,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset."""
args = metadata_create_parser.parse_args()
metadata_args = MetadataArgs.model_validate(args)
metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {})
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -79,7 +73,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
class DatasetMetadataServiceApi(DatasetApiResource):
@service_api_ns.expect(metadata_update_parser)
@service_api_ns.expect(service_api_ns.models[MetadataUpdatePayload.__name__])
@service_api_ns.doc("update_dataset_metadata")
@service_api_ns.doc(description="Update metadata name")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
@ -93,7 +87,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, metadata_id):
"""Update metadata name."""
args = metadata_update_parser.parse_args()
payload = MetadataUpdatePayload.model_validate(service_api_ns.payload or {})
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@ -102,7 +96,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
return marshal(metadata, dataset_metadata_fields), 200
@service_api_ns.doc("delete_dataset_metadata")
@ -175,7 +169,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
class DocumentMetadataEditServiceApi(DatasetApiResource):
@service_api_ns.expect(document_metadata_parser)
@service_api_ns.expect(service_api_ns.models[MetadataOperationData.__name__])
@service_api_ns.doc("update_documents_metadata")
@service_api_ns.doc(description="Update metadata for multiple documents")
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@ -195,8 +189,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
args = document_metadata_parser.parse_args()
metadata_args = MetadataOperationData.model_validate(args)
metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -4,12 +4,12 @@ from collections.abc import Generator
from typing import Any
from flask import request
from flask_restx import reqparse
from flask_restx.reqparse import ParseResult, RequestParser
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.wraps import DatasetApiResource
@ -22,11 +22,25 @@ from models.dataset import Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
from services.rag_pipeline.entity.pipeline_service_api_entities import (
DatasourceNodeRunApiEntity,
PipelineRunApiEntity,
)
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
from services.rag_pipeline.rag_pipeline import RagPipelineService
class DatasourceNodeRunPayload(BaseModel):
inputs: dict[str, Any]
datasource_type: str
credential_id: str | None = None
is_published: bool
register_schema_model(service_api_ns, DatasourceNodeRunPayload)
register_schema_model(service_api_ns, PipelineRunApiEntity)
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource):
"""Resource for datasource plugins."""
@ -88,22 +102,20 @@ class DatasourceNodeRunApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
def post(self, tenant_id: str, dataset_id: str, node_id: str):
"""Resource for getting datasource plugins."""
# Get query parameter to determine published or draft
parser: RequestParser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("credential_id", type=str, required=False, location="json")
.add_argument("is_published", type=bool, required=True, location="json")
)
args: ParseResult = parser.parse_args()
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(
{
**payload.model_dump(exclude_none=True),
"pipeline_id": str(pipeline.id),
"node_id": node_id,
}
)
return helper.compact_generate_response(
PipelineGenerator.convert_to_event_stream(
rag_pipeline_service.run_datasource_workflow_node(
@ -147,25 +159,10 @@ class PipelineRunApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
def post(self, tenant_id: str, dataset_id: str):
"""Resource for running a rag pipeline."""
parser: RequestParser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("datasource_type", type=str, required=True, location="json")
.add_argument("datasource_info_list", type=list, required=True, location="json")
.add_argument("start_node_id", type=str, required=True, location="json")
.add_argument("is_published", type=bool, required=True, default=True, location="json")
.add_argument(
"response_mode",
type=str,
required=True,
choices=["streaming", "blocking"],
default="blocking",
location="json",
)
)
args: ParseResult = parser.parse_args()
payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
if not isinstance(current_user, Account):
raise Forbidden()
@ -176,9 +173,9 @@ class PipelineRunApi(DatasetApiResource):
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
pipeline=pipeline,
user=current_user,
args=args,
invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
streaming=args.get("response_mode") == "streaming",
args=payload.model_dump(),
invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER,
streaming=payload.response_mode == "streaming",
)
return helper.compact_generate_response(response)

View File

@ -1,8 +1,12 @@
from typing import Any
from flask import request
from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
@ -24,34 +28,42 @@ from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexing
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
# Define parsers for segment operations
segment_create_parser = reqparse.RequestParser().add_argument(
"segments", type=list, required=False, nullable=True, location="json"
)
segment_list_parser = (
reqparse.RequestParser()
.add_argument("status", type=str, action="append", default=[], location="args")
.add_argument("keyword", type=str, default=None, location="args")
)
class SegmentCreatePayload(BaseModel):
segments: list[dict[str, Any]] | None = None
segment_update_parser = reqparse.RequestParser().add_argument(
"segment", type=dict, required=False, nullable=True, location="json"
)
child_chunk_create_parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
class SegmentListQuery(BaseModel):
status: list[str] = Field(default_factory=list)
keyword: str | None = None
child_chunk_list_parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
child_chunk_update_parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
class SegmentUpdatePayload(BaseModel):
segment: SegmentUpdateArgs
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkListQuery(BaseModel):
limit: int = Field(default=20, ge=1)
keyword: str | None = None
page: int = Field(default=1, ge=1)
class ChildChunkUpdatePayload(BaseModel):
content: str
register_schema_models(
service_api_ns,
SegmentCreatePayload,
SegmentListQuery,
SegmentUpdatePayload,
ChildChunkCreatePayload,
ChildChunkListQuery,
ChildChunkUpdatePayload,
)
@ -59,7 +71,7 @@ child_chunk_update_parser = reqparse.RequestParser().add_argument(
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
@service_api_ns.expect(segment_create_parser)
@service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
@service_api_ns.doc("create_segments")
@service_api_ns.doc(description="Create segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -106,20 +118,20 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
args = segment_create_parser.parse_args()
if args["segments"] is not None:
payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
if payload.segments is not None:
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(args["segments"]) > segments_limit:
if segments_limit > 0 and len(payload.segments) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
for args_item in args["segments"]:
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else:
return {"error": "Segments is required"}, 400
@service_api_ns.expect(segment_list_parser)
@service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__])
@service_api_ns.doc("list_segments")
@service_api_ns.doc(description="List segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -160,13 +172,18 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
args = segment_list_parser.parse_args()
args = SegmentListQuery.model_validate(
{
"status": request.args.getlist("status"),
"keyword": request.args.get("keyword"),
}
)
segments, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=current_tenant_id,
status_list=args["status"],
keyword=args["keyword"],
status_list=args.status,
keyword=args.keyword,
page=page,
limit=limit,
)
@ -217,7 +234,7 @@ class DatasetSegmentApi(DatasetApiResource):
SegmentService.delete_segment(segment, document, dataset)
return 204
@service_api_ns.expect(segment_update_parser)
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment")
@service_api_ns.doc(description="Update a specific segment")
@service_api_ns.doc(
@ -265,12 +282,9 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
# validate args
args = segment_update_parser.parse_args()
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
)
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
@service_api_ns.doc("get_segment")
@ -308,7 +322,7 @@ class DatasetSegmentApi(DatasetApiResource):
class ChildChunkApi(DatasetApiResource):
"""Resource for child chunks."""
@service_api_ns.expect(child_chunk_create_parser)
@service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
@service_api_ns.doc("create_child_chunk")
@service_api_ns.doc(description="Create a new child chunk for a segment")
@service_api_ns.doc(
@ -360,16 +374,16 @@ class ChildChunkApi(DatasetApiResource):
raise ProviderNotInitializeError(ex.description)
# validate args
args = child_chunk_create_parser.parse_args()
payload = ChildChunkCreatePayload.model_validate(service_api_ns.payload or {})
try:
child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset)
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@service_api_ns.expect(child_chunk_list_parser)
@service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__])
@service_api_ns.doc("list_child_chunks")
@service_api_ns.doc(description="List child chunks for a segment")
@service_api_ns.doc(
@ -400,11 +414,17 @@ class ChildChunkApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
args = child_chunk_list_parser.parse_args()
args = ChildChunkListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]
page = args.page
limit = min(args.limit, 100)
keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
@ -480,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
return 204
@service_api_ns.expect(child_chunk_update_parser)
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk")
@service_api_ns.doc(description="Update a specific child chunk")
@service_api_ns.doc(
@ -533,10 +553,10 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.")
# validate args
args = child_chunk_update_parser.parse_args()
payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {})
try:
child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset)
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))

View File

@ -33,7 +33,7 @@ def trigger_endpoint(endpoint_id: str):
if response:
break
if not response:
logger.error("Endpoint not found for {endpoint_id}")
logger.info("Endpoint not found for %s", endpoint_id)
return jsonify({"error": "Endpoint not found"}), 404
return response
except ValueError as e: