Mirror of https://github.com/langgenius/dify.git, synced 2026-04-30 23:48:04 +08:00
feat: sync main branch (#31938)
Signed-off-by: majiayu000 <1835304752@qq.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Signed-off-by: -LAN- <laipz8200@outlook.com>
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: 盐粒 Yanli <yanli@dify.ai>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cursx <33718736+Cursx@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: lif <1835304752@qq.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: fenglin <790872612@qq.com>
Co-authored-by: qiaofenglin <qiaofenglin@baidu.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: TomoOkuyama <49631611+TomoOkuyama@users.noreply.github.com>
Co-authored-by: Tomo Okuyama <tomo.okuyama@intersystems.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: zyssyz123 <916125788@qq.com>
Co-authored-by: hj24 <mambahj24@gmail.com>
Co-authored-by: Coding On Star <447357187@qq.com>
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: Xiangxuan Qu <fghpdf@outlook.com>
Co-authored-by: fghpdf <fghpdf@users.noreply.github.com>
Co-authored-by: coopercoder <whitetiger0127@163.com>
Co-authored-by: zhaiguangpeng <zhaiguangpeng@didiglobal.com>
Co-authored-by: Junyan Qin (Chin) <rockchinq@gmail.com>
Co-authored-by: E.G <146701565+GlobalStar117@users.noreply.github.com>
Co-authored-by: GlobalStar117 <GlobalStar117@users.noreply.github.com>
Co-authored-by: Claude Haiku 4.5 <noreply@anthropic.com>
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: heyszt <270985384@qq.com>
Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: moonpanda <chuanzegao@163.com>
Co-authored-by: warlocgao <warlocgao@tencent.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: KVOJJJin <jzongcode@gmail.com>
Co-authored-by: eux <euxx@users.noreply.github.com>
Co-authored-by: bangjiehan <bangjiehan@gmail.com>
Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: Nie Ronghua <nieronghua@sf-express.com>
Co-authored-by: JQSevenMiao <141806521+JQSevenMiao@users.noreply.github.com>
Co-authored-by: jiasiqi <jiasiqi3@tal.com>
Co-authored-by: Seokrin Taron Sung <sungsjade@gmail.com>
Co-authored-by: CrabSAMA <40541269+CrabSAMA@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: yihong <zouzou0208@gmail.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: yessenia <yessenia.contact@gmail.com>
Co-authored-by: Jax <anobaka@qq.com>
Co-authored-by: niveshdandyan <155956228+niveshdandyan@users.noreply.github.com>
Co-authored-by: OSS Contributor <oss-contributor@example.com>
Co-authored-by: niveshdandyan <niveshdandyan@users.noreply.github.com>
Co-authored-by: Sean Kenneth Doherty <Smaster7772@gmail.com>
@@ -243,15 +243,13 @@ class InsertExploreBannerApi(Resource):
     def post(self):
         payload = InsertExploreBannerPayload.model_validate(console_ns.payload)

-        content = {
-            "category": payload.category,
-            "title": payload.title,
-            "description": payload.description,
-            "img-src": payload.img_src,
-        }
-
         banner = ExporleBanner(
-            content=content,
+            content={
+                "category": payload.category,
+                "title": payload.title,
+                "description": payload.description,
+                "img-src": payload.img_src,
+            },
             link=payload.link,
             sort=payload.sort,
             language=payload.language,
@@ -22,10 +22,10 @@ api_key_fields = {
     "created_at": TimestampField,
 }

-api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
+api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
+
+api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
+
+api_key_list_model = console_ns.model(
+    "ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
+)
@@ -1,10 +1,11 @@
 from typing import Any, Literal

 from flask import abort, make_response, request
-from flask_restx import Resource, fields, marshal, marshal_with
-from pydantic import BaseModel, Field, field_validator
+from flask_restx import Resource
+from pydantic import BaseModel, Field, TypeAdapter, field_validator

 from controllers.common.errors import NoFileUploadedError, TooManyFilesError
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import (
     account_initialization_required,
@@ -16,9 +17,11 @@ from controllers.console.wraps import (
 )
 from extensions.ext_redis import redis_client
 from fields.annotation_fields import (
-    annotation_fields,
-    annotation_hit_history_fields,
-    build_annotation_model,
+    Annotation,
+    AnnotationExportList,
+    AnnotationHitHistory,
+    AnnotationHitHistoryList,
+    AnnotationList,
 )
 from libs.helper import uuid_value
 from libs.login import login_required
@@ -89,6 +92,14 @@ reg(CreateAnnotationPayload)
 reg(UpdateAnnotationPayload)
 reg(AnnotationReplyStatusQuery)
 reg(AnnotationFilePayload)
+register_schema_models(
+    console_ns,
+    Annotation,
+    AnnotationList,
+    AnnotationExportList,
+    AnnotationHitHistory,
+    AnnotationHitHistoryList,
+)


 @console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
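
Note: register_schema_models itself is not shown in this diff; from the call sites it appears to be a batch version of the old per-model reg() helpers, registering each Pydantic model's JSON schema on the namespace so decorators can later reference console_ns.models[Model.__name__]. A minimal sketch of what such a helper plausibly looks like (the real implementation lives in controllers/common/schema.py and may differ):

    # Hypothetical sketch; mirrors the deleted reg() helpers seen later in
    # this diff, not the actual controllers/common/schema.py source.
    from flask_restx import Namespace
    from pydantic import BaseModel

    DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

    def register_schema_models(ns: Namespace, *models: type[BaseModel]) -> None:
        for model in models:
            # schema_model() stores a raw JSON schema under the model's name,
            # making it addressable as ns.models[model.__name__].
            ns.schema_model(
                model.__name__,
                model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
            )
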
@@ -107,10 +118,11 @@ class AnnotationReplyActionApi(Resource):
     def post(self, app_id, action: Literal["enable", "disable"]):
         app_id = str(app_id)
         args = AnnotationReplyPayload.model_validate(console_ns.payload)
-        if action == "enable":
-            result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
-        elif action == "disable":
-            result = AppAnnotationService.disable_app_annotation(app_id)
+        match action:
+            case "enable":
+                result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
+            case "disable":
+                result = AppAnnotationService.disable_app_annotation(app_id)
         return result, 200

@@ -201,33 +213,33 @@ class AnnotationApi(Resource):

         app_id = str(app_id)
         annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
-        response = {
-            "data": marshal(annotation_list, annotation_fields),
-            "has_more": len(annotation_list) == limit,
-            "limit": limit,
-            "total": total,
-            "page": page,
-        }
-        return response, 200
+        annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
+        response = AnnotationList(
+            data=annotation_models,
+            has_more=len(annotation_list) == limit,
+            limit=limit,
+            total=total,
+            page=page,
+        )
+        return response.model_dump(mode="json"), 200

     @console_ns.doc("create_annotation")
     @console_ns.doc(description="Create a new annotation for an app")
     @console_ns.doc(params={"app_id": "Application ID"})
     @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
-    @console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
+    @console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__])
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("annotation")
-    @marshal_with(annotation_fields)
     @edit_permission_required
     def post(self, app_id):
         app_id = str(app_id)
         args = CreateAnnotationPayload.model_validate(console_ns.payload)
         data = args.model_dump(exclude_none=True)
         annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
-        return annotation
+        return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")

     @setup_required
     @login_required
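
Note: the replacement for marshal() above relies on Pydantic v2's TypeAdapter; from_attributes=True lets validation read fields off object attributes (such as SQLAlchemy rows) instead of mapping keys. A self-contained illustration of the pattern (Row is a stand-in for the ORM model, not part of this diff):

    from pydantic import BaseModel, TypeAdapter

    class Annotation(BaseModel):
        id: str
        question: str

    class Row:  # stand-in for a SQLAlchemy model instance
        def __init__(self, id: str, question: str):
            self.id = id
            self.question = question

    rows = [Row("a1", "What is Dify?")]
    # Reads "id" and "question" from attributes, then serializes to JSON types.
    models = TypeAdapter(list[Annotation]).validate_python(rows, from_attributes=True)
    print([m.model_dump(mode="json") for m in models])
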
@@ -264,7 +276,7 @@ class AnnotationExportApi(Resource):
     @console_ns.response(
         200,
         "Annotations exported successfully",
-        console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
+        console_ns.models[AnnotationExportList.__name__],
     )
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -274,7 +286,8 @@ class AnnotationExportApi(Resource):
     def get(self, app_id):
         app_id = str(app_id)
         annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
-        response_data = {"data": marshal(annotation_list, annotation_fields)}
+        annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
+        response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")

         # Create response with secure headers for CSV export
         response = make_response(response_data, 200)
@@ -289,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource):
     @console_ns.doc("update_delete_annotation")
     @console_ns.doc(description="Update or delete an annotation")
     @console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
-    @console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
+    @console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__])
     @console_ns.response(204, "Annotation deleted successfully")
     @console_ns.response(403, "Insufficient permissions")
     @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@@ -298,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("annotation")
     @edit_permission_required
-    @marshal_with(annotation_fields)
     def post(self, app_id, annotation_id):
         app_id = str(app_id)
         annotation_id = str(annotation_id)
@@ -306,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource):
         annotation = AppAnnotationService.update_app_annotation_directly(
             args.model_dump(exclude_none=True), app_id, annotation_id
         )
-        return annotation
+        return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")

     @setup_required
     @login_required
@@ -414,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource):
     @console_ns.response(
         200,
         "Hit histories retrieved successfully",
-        console_ns.model(
-            "AnnotationHitHistoryList",
-            {
-                "data": fields.List(
-                    fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
-                )
-            },
-        ),
+        console_ns.models[AnnotationHitHistoryList.__name__],
     )
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -436,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource):
         annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
             app_id, annotation_id, page, limit
         )
-        response = {
-            "data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
-            "has_more": len(annotation_hit_history_list) == limit,
-            "limit": limit,
-            "total": total,
-            "page": page,
-        }
-        return response
+        history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
+            annotation_hit_history_list, from_attributes=True
+        )
+        response = AnnotationHitHistoryList(
+            data=history_models,
+            has_more=len(annotation_hit_history_list) == limit,
+            limit=limit,
+            total=total,
+            page=page,
+        )
+        return response.model_dump(mode="json")
@@ -9,9 +9,11 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest

-from controllers.common.schema import register_schema_models
+from controllers.common.helpers import FileInfo
+from controllers.common.schema import register_enum_models, register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.wraps import get_app_model
+from controllers.console.workspace.models import LoadBalancingPayload
 from controllers.console.wraps import (
     account_initialization_required,
     cloud_edition_billing_resource_check,
@@ -22,18 +24,36 @@ from controllers.console.wraps import (
 )
 from core.file import helpers as file_helpers
 from core.ops.ops_trace_manager import OpsTraceManager
-from core.workflow.enums import NodeType
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.workflow.enums import NodeType, WorkflowExecutionStatus
 from extensions.ext_database import db
 from libs.login import current_account_with_tenant, login_required
-from models import App, Workflow
+from models import App, DatasetPermissionEnum, Workflow
 from models.model import IconType
 from services.app_dsl_service import AppDslService, ImportMode
 from services.app_service import AppService
 from services.enterprise.enterprise_service import EnterpriseService
+from services.entities.knowledge_entities.knowledge_entities import (
+    DataSource,
+    InfoList,
+    NotionIcon,
+    NotionInfo,
+    NotionPage,
+    PreProcessingRule,
+    RerankingModel,
+    Rule,
+    Segmentation,
+    WebsiteInfo,
+    WeightKeywordSetting,
+    WeightModel,
+    WeightVectorSetting,
+)
 from services.feature_service import FeatureService

 ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]

+register_enum_models(console_ns, IconType)
+

 class AppListQuery(BaseModel):
     page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
@@ -151,7 +171,7 @@ def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
     if icon is None or icon_type is None:
         return None
     icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
-    if icon_type_value.lower() != IconType.IMAGE.value:
+    if icon_type_value.lower() != IconType.IMAGE:
         return None
     return file_helpers.get_signed_file_url(icon)
@@ -391,6 +411,8 @@ class AppExportResponse(ResponseModel):
     data: str


+register_enum_models(console_ns, RetrievalMethod, WorkflowExecutionStatus, DatasetPermissionEnum)
+
 register_schema_models(
     console_ns,
     AppListQuery,
@@ -414,6 +436,22 @@ register_schema_models(
     AppDetailWithSite,
     AppPagination,
     AppExportResponse,
+    Segmentation,
+    PreProcessingRule,
+    Rule,
+    WeightVectorSetting,
+    WeightKeywordSetting,
+    WeightModel,
+    RerankingModel,
+    InfoList,
+    NotionInfo,
+    FileInfo,
+    WebsiteInfo,
+    NotionPage,
+    NotionIcon,
+    RerankingModel,
+    DataSource,
+    LoadBalancingPayload,
 )

@@ -41,14 +41,14 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

 class AppImportPayload(BaseModel):
     mode: str = Field(..., description="Import mode")
-    yaml_content: str | None = None
-    yaml_url: str | None = None
-    name: str | None = None
-    description: str | None = None
-    icon_type: str | None = None
-    icon: str | None = None
-    icon_background: str | None = None
-    app_id: str | None = None
+    yaml_content: str | None = Field(None)
+    yaml_url: str | None = Field(None)
+    name: str | None = Field(None)
+    description: str | None = Field(None)
+    icon_type: str | None = Field(None)
+    icon: str | None = Field(None)
+    icon_background: str | None = Field(None)
+    app_id: str | None = Field(None)


 console_ns.schema_model(
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
 from werkzeug.exceptions import InternalServerError

 import services
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.error import (
     AppUnavailableError,
@@ -33,7 +34,6 @@ from services.errors.audio import (
 )

 logger = logging.getLogger(__name__)
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


 class TextToSpeechPayload(BaseModel):
@@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel):
     language: str = Field(..., description="Language code")


-console_ns.schema_model(
-    TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
-)
-console_ns.schema_model(
-    TextToSpeechVoiceQuery.__name__,
-    TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
-)
+class AudioTranscriptResponse(BaseModel):
+    text: str = Field(description="Transcribed text from audio")
+
+
+register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery)


 @console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource):
     @console_ns.response(
         200,
         "Audio transcription successful",
-        console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
+        console_ns.models[AudioTranscriptResponse.__name__],
     )
     @console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
     @console_ns.response(413, "Audio file too large")
@@ -508,16 +508,19 @@ class ChatConversationApi(Resource):
             case "created_at" | "-created_at" | _:
                 query = query.where(Conversation.created_at <= end_datetime_utc)

-        if args.annotation_status == "annotated":
-            query = query.options(joinedload(Conversation.message_annotations)).join(  # type: ignore
-                MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
-            )
-        elif args.annotation_status == "not_annotated":
-            query = (
-                query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
-                .group_by(Conversation.id)
-                .having(func.count(MessageAnnotation.id) == 0)
-            )
+        match args.annotation_status:
+            case "annotated":
+                query = query.options(joinedload(Conversation.message_annotations)).join(  # type: ignore
+                    MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
+                )
+            case "not_annotated":
+                query = (
+                    query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
+                    .group_by(Conversation.id)
+                    .having(func.count(MessageAnnotation.id) == 0)
+                )
+            case "all":
+                pass

         if app_model.mode == AppMode.ADVANCED_CHAT:
             query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
@@ -82,13 +82,13 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
 class DraftWorkflowNotExist(BaseHTTPException):
     error_code = "draft_workflow_not_exist"
     description = "Draft workflow need to be initialized."
-    code = 400
+    code = 404


 class DraftWorkflowNotSync(BaseHTTPException):
     error_code = "draft_workflow_not_sync"
     description = "Workflow graph might have been modified, please refresh and resubmit."
-    code = 400
+    code = 409


 class TracingConfigNotExist(BaseHTTPException):
@@ -1,6 +1,5 @@
 import logging
-from collections.abc import Sequence
 from typing import Any

 from flask_restx import Resource
 from pydantic import BaseModel, Field
@@ -16,10 +15,12 @@ from controllers.console.app.error import (
     ProviderQuotaExceededError,
 )
 from controllers.console.wraps import account_initialization_required, setup_required
+from core.app.app_config.entities import ModelConfig
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.helper.code_executor.code_node_provider import CodeNodeProvider
 from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
 from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
+from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
 from core.llm_generator.llm_generator import LLMGenerator
 from core.model_runtime.errors.invoke import InvokeError
 from core.workflow.generator import WorkflowGenerator
@@ -31,28 +32,13 @@ from services.workflow_service import WorkflowService
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


-class RuleGeneratePayload(BaseModel):
-    instruction: str = Field(..., description="Rule generation instruction")
-    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
-    no_variable: bool = Field(default=False, description="Whether to exclude variables")
-
-
-class RuleCodeGeneratePayload(RuleGeneratePayload):
-    code_language: str = Field(default="javascript", description="Programming language for code generation")
-
-
-class RuleStructuredOutputPayload(BaseModel):
-    instruction: str = Field(..., description="Structured output generation instruction")
-    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
-
-
 class InstructionGeneratePayload(BaseModel):
     flow_id: str = Field(..., description="Workflow/Flow ID")
     node_id: str = Field(default="", description="Node ID for workflow context")
     current: str = Field(default="", description="Current instruction text")
     language: str = Field(default="javascript", description="Programming language (javascript/python)")
     instruction: str = Field(..., description="Instruction for generation")
-    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
+    model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
     ideal_output: str = Field(default="", description="Expected ideal output")

@@ -94,6 +80,7 @@ reg(RuleStructuredOutputPayload)
 reg(InstructionGeneratePayload)
 reg(InstructionTemplatePayload)
 reg(FlowchartGeneratePayload)
+reg(ModelConfig)


 @console_ns.route("/rule-generate")
@@ -112,12 +99,7 @@ class RuleGenerateApi(Resource):
         _, current_tenant_id = current_account_with_tenant()

         try:
-            rules = LLMGenerator.generate_rule_config(
-                tenant_id=current_tenant_id,
-                instruction=args.instruction,
-                model_config=args.model_config_data,
-                no_variable=args.no_variable,
-            )
+            rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
@@ -148,9 +130,7 @@ class RuleCodeGenerateApi(Resource):
         try:
             code_result = LLMGenerator.generate_code(
                 tenant_id=current_tenant_id,
-                instruction=args.instruction,
-                model_config=args.model_config_data,
-                code_language=args.code_language,
+                args=args,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -182,8 +162,7 @@ class RuleStructuredOutputGenerateApi(Resource):
         try:
             structured_output = LLMGenerator.generate_structured_output(
                 tenant_id=current_tenant_id,
-                instruction=args.instruction,
-                model_config=args.model_config_data,
+                args=args,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -234,23 +213,29 @@ class InstructionGenerateApi(Resource):
             case "llm":
                 return LLMGenerator.generate_rule_config(
                     current_tenant_id,
-                    instruction=args.instruction,
-                    model_config=args.model_config_data,
-                    no_variable=True,
+                    args=RuleGeneratePayload(
+                        instruction=args.instruction,
+                        model_config=args.model_config_data,
+                        no_variable=True,
+                    ),
                 )
             case "agent":
                 return LLMGenerator.generate_rule_config(
                     current_tenant_id,
-                    instruction=args.instruction,
-                    model_config=args.model_config_data,
-                    no_variable=True,
+                    args=RuleGeneratePayload(
+                        instruction=args.instruction,
+                        model_config=args.model_config_data,
+                        no_variable=True,
+                    ),
                 )
             case "code":
                 return LLMGenerator.generate_code(
                     tenant_id=current_tenant_id,
-                    instruction=args.instruction,
-                    model_config=args.model_config_data,
-                    code_language=args.language,
+                    args=RuleCodeGeneratePayload(
+                        instruction=args.instruction,
+                        model_config=args.model_config_data,
+                        code_language=args.language,
+                    ),
                 )
             case _:
                 return {"error": f"invalid node type: {node_type}"}
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import exists, select
 from werkzeug.exceptions import InternalServerError, NotFound

+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.error import (
     CompletionRequestError,
@@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
 from services.message_service import MessageService

 logger = logging.getLogger(__name__)
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


 class ChatMessagesQuery(BaseModel):
@@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel):
         raise ValueError("has_comment must be a boolean value")


-def reg(cls: type[BaseModel]):
-    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+class AnnotationCountResponse(BaseModel):
+    count: int = Field(description="Number of annotations")


-reg(ChatMessagesQuery)
-reg(MessageFeedbackPayload)
-reg(FeedbackExportQuery)
+class SuggestedQuestionsResponse(BaseModel):
+    data: list[str] = Field(description="Suggested question")
+
+
+register_schema_models(
+    console_ns,
+    ChatMessagesQuery,
+    MessageFeedbackPayload,
+    FeedbackExportQuery,
+    AnnotationCountResponse,
+    SuggestedQuestionsResponse,
+)

 # Register models for flask_restx to avoid dict type issues in Swagger
 # Register in dependency order: base models first, then dependent models
@@ -231,7 +240,7 @@ class ChatMessageListApi(Resource):
     @marshal_with(message_infinite_scroll_pagination_model)
     @edit_permission_required
     def get(self, app_model):
-        args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
+        args = ChatMessagesQuery.model_validate(request.args.to_dict())

         conversation = (
             db.session.query(Conversation)
@@ -356,7 +365,7 @@ class MessageAnnotationCountApi(Resource):
     @console_ns.response(
         200,
         "Annotation count retrieved successfully",
-        console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
+        console_ns.models[AnnotationCountResponse.__name__],
     )
     @get_app_model
     @setup_required
@@ -376,9 +385,7 @@ class MessageSuggestedQuestionApi(Resource):
     @console_ns.response(
         200,
         "Suggested questions retrieved successfully",
-        console_ns.model(
-            "SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
-        ),
+        console_ns.models[SuggestedQuestionsResponse.__name__],
     )
     @console_ns.response(404, "Message or conversation not found")
     @setup_required
@@ -428,7 +435,7 @@ class MessageFeedbackExportApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, app_model):
-        args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
+        args = FeedbackExportQuery.model_validate(request.args.to_dict())

         # Import the service function
         from services.feedback_service import FeedbackService
@@ -12,6 +12,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 import services
 from controllers.console import console_ns
 from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
+from controllers.console.app.workflow_run import workflow_run_node_execution_model
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@@ -35,7 +36,6 @@ from extensions.ext_database import db
 from factories import file_factory, variable_factory
 from fields.member_fields import simple_account_fields
 from fields.workflow_fields import workflow_fields, workflow_pagination_fields
-from fields.workflow_run_fields import workflow_run_node_execution_fields
 from libs import helper
 from libs.datetime_utils import naive_utc_now
 from libs.helper import TimestampField, uuid_value
@@ -88,26 +88,6 @@ workflow_pagination_fields_copy = workflow_pagination_fields.copy()
 workflow_pagination_fields_copy["items"] = fields.List(fields.Nested(workflow_model), attribute="items")
 workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields_copy)

-# Reuse workflow_run_node_execution_model from workflow_run.py if already registered
-# Otherwise register it here
-from fields.end_user_fields import simple_end_user_fields
-
-simple_end_user_model = None
-try:
-    simple_end_user_model = console_ns.models.get("SimpleEndUser")
-except AttributeError:
-    pass
-if simple_end_user_model is None:
-    simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
-
-workflow_run_node_execution_model = None
-try:
-    workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution")
-except AttributeError:
-    pass
-if workflow_run_node_execution_model is None:
-    workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)


 class SyncDraftWorkflowPayload(BaseModel):
     graph: dict[str, Any]
@@ -470,7 +450,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
         Run draft workflow loop node
         """
         current_user, _ = current_account_with_tenant()
-        args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
+        args = LoopNodeRunPayload.model_validate(console_ns.payload or {})

         try:
             response = AppGenerateService.generate_single_loop(
@@ -508,7 +488,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
         Run draft workflow loop node
         """
         current_user, _ = current_account_with_tenant()
-        args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
+        args = LoopNodeRunPayload.model_validate(console_ns.payload or {})

         try:
             response = AppGenerateService.generate_single_loop(
@@ -999,6 +979,7 @@ class DraftWorkflowTriggerRunApi(Resource):
         if not event:
             return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
         workflow_args = dict(event.workflow_args)

+        workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
         return helper.compact_generate_response(
             AppGenerateService.generate(
@@ -1147,6 +1128,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):

         try:
             workflow_args = dict(trigger_debug_event.workflow_args)

+            workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
             response = AppGenerateService.generate(
                 app_model=app_model,
@@ -11,7 +11,10 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from core.workflow.enums import WorkflowExecutionStatus
 from extensions.ext_database import db
-from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
+from fields.workflow_app_log_fields import (
+    build_workflow_app_log_pagination_model,
+    build_workflow_archived_log_pagination_model,
+)
 from libs.login import login_required
 from models import App
 from models.model import AppMode
@@ -61,6 +64,7 @@ console_ns.schema_model(

 # Register model for flask_restx to avoid dict type issues in Swagger
 workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
+workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)


 @console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
@@ -99,3 +103,33 @@ class WorkflowAppLogApi(Resource):
         )

         return workflow_app_log_pagination
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
+class WorkflowArchivedLogApi(Resource):
+    @console_ns.doc("get_workflow_archived_logs")
+    @console_ns.doc(description="Get workflow archived execution logs")
+    @console_ns.doc(params={"app_id": "Application ID"})
+    @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
+    @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.WORKFLOW])
+    @marshal_with(workflow_archived_log_pagination_model)
+    def get(self, app_model: App):
+        """
+        Get workflow archived logs
+        """
+        args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
+
+        workflow_app_service = WorkflowAppService()
+        with Session(db.engine) as session:
+            workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
+                session=session,
+                app_model=app_model,
+                page=args.page,
+                limit=args.limit,
+            )
+
+        return workflow_app_log_pagination
@@ -1,12 +1,15 @@
+from datetime import UTC, datetime, timedelta
 from typing import Literal, cast

 from flask import request
 from flask_restx import Resource, fields, marshal_with
 from pydantic import BaseModel, Field, field_validator
+from sqlalchemy import select

 from controllers.console import console_ns
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
+from extensions.ext_database import db
 from fields.end_user_fields import simple_end_user_fields
 from fields.member_fields import simple_account_fields
 from fields.workflow_run_fields import (
@@ -19,14 +22,17 @@ from fields.workflow_run_fields import (
     workflow_run_node_execution_list_fields,
     workflow_run_pagination_fields,
 )
+from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
 from libs.custom_inputs import time_duration
 from libs.helper import uuid_value
 from libs.login import current_user, login_required
-from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
+from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
+from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
 from services.workflow_run_service import WorkflowRunService

 # Workflow run status choices for filtering
 WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
+EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600

 # Register models for flask_restx to avoid dict type issues in Swagger
 # Register in dependency order: base models first, then dependent models
@@ -93,6 +99,15 @@ workflow_run_node_execution_list_model = console_ns.model(
     "WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
 )

+workflow_run_export_fields = console_ns.model(
+    "WorkflowRunExport",
+    {
+        "status": fields.String(description="Export status: success/failed"),
+        "presigned_url": fields.String(description="Pre-signed URL for download", required=False),
+        "presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False),
+    },
+)
+
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

@@ -181,6 +196,56 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
         return result


+@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/export")
+class WorkflowRunExportApi(Resource):
+    @console_ns.doc("get_workflow_run_export_url")
+    @console_ns.doc(description="Generate a download URL for an archived workflow run.")
+    @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
+    @console_ns.response(200, "Export URL generated", workflow_run_export_fields)
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model()
+    def get(self, app_model: App, run_id: str):
+        tenant_id = str(app_model.tenant_id)
+        app_id = str(app_model.id)
+        run_id_str = str(run_id)
+
+        run_created_at = db.session.scalar(
+            select(WorkflowArchiveLog.run_created_at)
+            .where(
+                WorkflowArchiveLog.tenant_id == tenant_id,
+                WorkflowArchiveLog.app_id == app_id,
+                WorkflowArchiveLog.workflow_run_id == run_id_str,
+            )
+            .limit(1)
+        )
+        if not run_created_at:
+            return {"code": "archive_log_not_found", "message": "workflow run archive not found"}, 404
+
+        prefix = (
+            f"{tenant_id}/app_id={app_id}/year={run_created_at.strftime('%Y')}/"
+            f"month={run_created_at.strftime('%m')}/workflow_run_id={run_id_str}"
+        )
+        archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
+
+        try:
+            archive_storage = get_archive_storage()
+        except ArchiveStorageNotConfiguredError as e:
+            return {"code": "archive_storage_not_configured", "message": str(e)}, 500
+
+        presigned_url = archive_storage.generate_presigned_url(
+            archive_key,
+            expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS,
+        )
+        expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS)
+        return {
+            "status": "success",
+            "presigned_url": presigned_url,
+            "presigned_url_expires_at": expires_at.isoformat(),
+        }, 200
+
+
 @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
 class AdvancedChatAppWorkflowRunCountApi(Resource):
     @console_ns.doc("get_advanced_chat_workflow_runs_count")
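
Note: a sketch of how a console client might call the new export endpoint; the base URL, IDs, and auth header are placeholders, and the response keys follow the WorkflowRunExport model registered above:

    import requests

    BASE = "https://example.com/console/api"  # placeholder console API root
    app_id = "00000000-0000-0000-0000-000000000000"  # placeholder IDs
    run_id = "00000000-0000-0000-0000-000000000001"

    resp = requests.get(
        f"{BASE}/apps/{app_id}/workflow-runs/{run_id}/export",
        headers={"Authorization": "Bearer <console-token>"},  # placeholder auth
        timeout=10,
    )
    data = resp.json()
    if resp.status_code == 200 and data["status"] == "success":
        # The pre-signed URL expires after EXPORT_SIGNED_URL_EXPIRE_SECONDS (3600s).
        print(data["presigned_url"], data["presigned_url_expires_at"])
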
@@ -1,13 +1,14 @@
 import logging

 from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource, fields, marshal_with
 from pydantic import BaseModel
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

 from configs import dify_config
+from controllers.common.schema import get_or_create_model
 from extensions.ext_database import db
 from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
 from libs.login import current_user, login_required
@@ -22,6 +23,14 @@ from ..wraps import account_initialization_required, edit_permission_required, setup_required
 logger = logging.getLogger(__name__)
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

+trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
+
+triggers_list_fields_copy = triggers_list_fields.copy()
+triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
+triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
+
+webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
+

 class Parser(BaseModel):
     node_id: str
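
Note: get_or_create_model looks like the idempotent replacement for the hand-rolled try/except registration deleted from the workflow controller earlier in this diff. A plausible sketch (the actual helper in controllers/common/schema.py may differ):

    from flask_restx import Model

    from controllers.console import console_ns

    def get_or_create_model(name: str, field_spec: dict) -> Model:
        existing = console_ns.models.get(name)
        if existing is not None:
            # Reuse the registered model so repeated imports don't re-register
            # "WorkflowTrigger" etc. under the same name.
            return existing
        return console_ns.model(name, field_spec)
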
@@ -48,7 +57,7 @@ class WebhookTriggerApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(webhook_trigger_fields)
+    @marshal_with(webhook_trigger_model)
     def get(self, app_model: App):
         """Get webhook trigger for a node"""
         args = Parser.model_validate(request.args.to_dict(flat=True))  # type: ignore
@@ -80,7 +89,7 @@ class AppTriggersApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(triggers_list_fields)
+    @marshal_with(triggers_list_model)
     def get(self, app_model: App):
         """Get app triggers list"""
         assert isinstance(current_user, Account)
@@ -120,7 +129,7 @@ class AppTriggerEnableApi(Resource):
     @account_initialization_required
     @edit_permission_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(trigger_fields)
+    @marshal_with(trigger_model)
     def post(self, app_model: App):
         """Update app trigger (enable/disable)"""
         args = ParserEnable.model_validate(console_ns.payload)
@@ -2,9 +2,11 @@ import logging

 import httpx
 from flask import current_app, redirect, request
-from flask_restx import Resource, fields
+from flask_restx import Resource
+from pydantic import BaseModel, Field

 from configs import dify_config
+from controllers.common.schema import register_schema_models
 from libs.login import login_required
 from libs.oauth_data_source import NotionOAuth

@@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
 logger = logging.getLogger(__name__)


+class OAuthDataSourceResponse(BaseModel):
+    data: str = Field(description="Authorization URL or 'internal' for internal setup")
+
+
+class OAuthDataSourceBindingResponse(BaseModel):
+    result: str = Field(description="Operation result")
+
+
+class OAuthDataSourceSyncResponse(BaseModel):
+    result: str = Field(description="Operation result")
+
+
+register_schema_models(
+    console_ns,
+    OAuthDataSourceResponse,
+    OAuthDataSourceBindingResponse,
+    OAuthDataSourceSyncResponse,
+)
+
+
 def get_oauth_providers():
     with current_app.app_context():
         notion_oauth = NotionOAuth(
@@ -34,10 +56,7 @@ class OAuthDataSource(Resource):
     @console_ns.response(
         200,
         "Authorization URL or internal setup success",
-        console_ns.model(
-            "OAuthDataSourceResponse",
-            {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
-        ),
+        console_ns.models[OAuthDataSourceResponse.__name__],
     )
     @console_ns.response(400, "Invalid provider")
     @console_ns.response(403, "Admin privileges required")
@@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource):
     @console_ns.response(
         200,
         "Data source binding success",
-        console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
+        console_ns.models[OAuthDataSourceBindingResponse.__name__],
     )
     @console_ns.response(400, "Invalid provider or code")
     def get(self, provider: str):
@@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource):
     @console_ns.response(
         200,
         "Data source sync success",
-        console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
+        console_ns.models[OAuthDataSourceSyncResponse.__name__],
     )
     @console_ns.response(400, "Invalid provider or sync failed")
     @setup_required
@@ -2,10 +2,11 @@ import base64
 import secrets

 from flask import request
-from flask_restx import Resource, fields
+from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy.orm import Session

+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.auth.error import (
     EmailCodeError,
@@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel):
         return valid_password(value)


-for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
-    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+class ForgotPasswordEmailResponse(BaseModel):
+    result: str = Field(description="Operation result")
+    data: str | None = Field(default=None, description="Reset token")
+    code: str | None = Field(default=None, description="Error code if account not found")
+
+
+class ForgotPasswordCheckResponse(BaseModel):
+    is_valid: bool = Field(description="Whether code is valid")
+    email: EmailStr = Field(description="Email address")
+    token: str = Field(description="New reset token")
+
+
+class ForgotPasswordResetResponse(BaseModel):
+    result: str = Field(description="Operation result")
+
+
+register_schema_models(
+    console_ns,
+    ForgotPasswordSendPayload,
+    ForgotPasswordCheckPayload,
+    ForgotPasswordResetPayload,
+    ForgotPasswordEmailResponse,
+    ForgotPasswordCheckResponse,
+    ForgotPasswordResetResponse,
+)


 @console_ns.route("/forgot-password")
@@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
     @console_ns.response(
         200,
         "Email sent successfully",
-        console_ns.model(
-            "ForgotPasswordEmailResponse",
-            {
-                "result": fields.String(description="Operation result"),
-                "data": fields.String(description="Reset token"),
-                "code": fields.String(description="Error code if account not found"),
-            },
-        ),
+        console_ns.models[ForgotPasswordEmailResponse.__name__],
     )
     @console_ns.response(400, "Invalid email or rate limit exceeded")
     @setup_required
@@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource):
     @console_ns.response(
         200,
         "Code verified successfully",
-        console_ns.model(
-            "ForgotPasswordCheckResponse",
-            {
-                "is_valid": fields.Boolean(description="Whether code is valid"),
-                "email": fields.String(description="Email address"),
-                "token": fields.String(description="New reset token"),
-            },
-        ),
+        console_ns.models[ForgotPasswordCheckResponse.__name__],
     )
     @console_ns.response(400, "Invalid code or token")
     @setup_required
@@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource):
     @console_ns.response(
         200,
         "Password reset successfully",
-        console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
+        console_ns.models[ForgotPasswordResetResponse.__name__],
     )
     @console_ns.response(400, "Invalid token or password mismatch")
     @setup_required
@@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource):
             grant_type = OAuthGrantType(payload.grant_type)
         except ValueError:
             raise BadRequest("invalid grant_type")
-
-        if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
-            if not payload.code:
-                raise BadRequest("code is required")
-
-            if payload.client_secret != oauth_provider_app.client_secret:
-                raise BadRequest("client_secret is invalid")
-
-            if payload.redirect_uri not in oauth_provider_app.redirect_uris:
-                raise BadRequest("redirect_uri is invalid")
-
-            access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
-                grant_type, code=payload.code, client_id=oauth_provider_app.client_id
-            )
-            return jsonable_encoder(
-                {
-                    "access_token": access_token,
-                    "token_type": "Bearer",
-                    "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
-                    "refresh_token": refresh_token,
-                }
-            )
-        elif grant_type == OAuthGrantType.REFRESH_TOKEN:
-            if not payload.refresh_token:
-                raise BadRequest("refresh_token is required")
-
-            access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
-                grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
-            )
-            return jsonable_encoder(
-                {
-                    "access_token": access_token,
-                    "token_type": "Bearer",
-                    "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
-                    "refresh_token": refresh_token,
-                }
-            )
+        match grant_type:
+            case OAuthGrantType.AUTHORIZATION_CODE:
+                if not payload.code:
+                    raise BadRequest("code is required")
+
+                if payload.client_secret != oauth_provider_app.client_secret:
+                    raise BadRequest("client_secret is invalid")
+
+                if payload.redirect_uri not in oauth_provider_app.redirect_uris:
+                    raise BadRequest("redirect_uri is invalid")
+
+                access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
+                    grant_type, code=payload.code, client_id=oauth_provider_app.client_id
+                )
+                return jsonable_encoder(
+                    {
+                        "access_token": access_token,
+                        "token_type": "Bearer",
+                        "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
+                        "refresh_token": refresh_token,
+                    }
+                )
+            case OAuthGrantType.REFRESH_TOKEN:
+                if not payload.refresh_token:
+                    raise BadRequest("refresh_token is required")
+
+                access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
+                    grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
+                )
+                return jsonable_encoder(
+                    {
+                        "access_token": access_token,
+                        "token_type": "Bearer",
+                        "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
+                        "refresh_token": refresh_token,
+                    }
+                )


 @console_ns.route("/oauth/provider/account")
@@ -1,15 +1,15 @@
 import json
 from collections.abc import Generator
-from typing import Any, cast
+from typing import Any, Literal, cast

 from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Resource, fields, marshal_with
 from pydantic import BaseModel, Field
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

-from controllers.common.schema import register_schema_model
+from controllers.common.schema import get_or_create_model, register_schema_model
 from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.indexing_runner import IndexingRunner
@@ -17,7 +17,14 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
 from core.rag.extractor.notion_extractor import NotionExtractor
 from extensions.ext_database import db
-from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
+from fields.data_source_fields import (
+    integrate_fields,
+    integrate_icon_fields,
+    integrate_list_fields,
+    integrate_notion_info_list_fields,
+    integrate_page_fields,
+    integrate_workspace_fields,
+)
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
 from models import DataSourceOauthBinding, Document
@ -36,9 +43,62 @@ class NotionEstimatePayload(BaseModel):
|
||||
doc_language: str = Field(default="English")
|
||||
|
||||
|
||||
class DataSourceNotionListQuery(BaseModel):
|
||||
dataset_id: str | None = Field(default=None, description="Dataset ID")
|
||||
credential_id: str = Field(..., description="Credential ID", min_length=1)
|
||||
datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
|
||||
|
||||
|
||||
class DataSourceNotionPreviewQuery(BaseModel):
|
||||
credential_id: str = Field(..., description="Credential ID", min_length=1)
|
||||
|
||||
|
||||
register_schema_model(console_ns, NotionEstimatePayload)
|
||||
|
||||
|
||||
integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
|
||||
|
||||
integrate_page_fields_copy = integrate_page_fields.copy()
|
||||
integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True)
|
||||
integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy)
|
||||
|
||||
integrate_workspace_fields_copy = integrate_workspace_fields.copy()
|
||||
integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model))
|
||||
integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy)
|
||||
|
||||
integrate_fields_copy = integrate_fields.copy()
|
||||
integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model)
|
||||
integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy)
|
||||
|
||||
integrate_list_fields_copy = integrate_list_fields.copy()
|
||||
integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model))
|
||||
integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy)
|
||||
|
||||
notion_page_fields = {
|
||||
"page_name": fields.String,
|
||||
"page_id": fields.String,
|
||||
"page_icon": fields.Nested(integrate_icon_model, allow_null=True),
|
||||
"is_bound": fields.Boolean,
|
||||
"parent_id": fields.String,
|
||||
"type": fields.String,
|
||||
}
|
||||
notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields)
|
||||
|
||||
notion_workspace_fields = {
|
||||
"workspace_name": fields.String,
|
||||
"workspace_id": fields.String,
|
||||
"workspace_icon": fields.String,
|
||||
"pages": fields.List(fields.Nested(notion_page_model)),
|
||||
}
|
||||
notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields)
|
||||
|
||||
integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy()
|
||||
integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model))
|
||||
integrate_notion_info_list_model = get_or_create_model(
|
||||
"NotionIntegrateInfoList", integrate_notion_info_list_fields_copy
|
||||
)
|
||||
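
The model-building block above leans on the shared get_or_create_model helper imported from controllers.common.schema, which replaces the per-module _get_or_create_model copies removed further down in this patch. The calls pass only a model name and a field dict, so the shared helper presumably defaults to the console namespace; its real implementation is not part of this diff. A minimal sketch of the idempotent pattern, with the namespace made explicit for clarity:

# Sketch only -- mirrors the removed _get_or_create_model copies; the actual
# helper in controllers.common.schema may bind the namespace differently.
from flask_restx import Namespace

def get_or_create_model_sketch(ns: Namespace, model_name: str, field_def: dict):
    # Registering the same model name twice would clash in the Swagger spec,
    # so reuse the existing registration when one is present.
    existing = ns.models.get(model_name)
    if existing is None:
        existing = ns.model(model_name, field_def)
    return existing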


@console_ns.route(
    "/data-source/integrates",
    "/data-source/integrates/<uuid:binding_id>/<string:action>",
@@ -47,7 +107,7 @@ class DataSourceApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    @marshal_with(integrate_list_model)
    def get(self):
        _, current_tenant_id = current_account_with_tenant()

@@ -97,9 +157,8 @@ class DataSourceApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, binding_id, action):
    def patch(self, binding_id, action: Literal["enable", "disable"]):
        binding_id = str(binding_id)
        action = str(action)
        with Session(db.engine) as session:
            data_source_binding = session.execute(
                select(DataSourceOauthBinding).filter_by(id=binding_id)
@@ -107,23 +166,24 @@ class DataSourceApi(Resource):
        if data_source_binding is None:
            raise NotFound("Data source binding not found.")
        # enable binding
        if action == "enable":
            if data_source_binding.disabled:
                data_source_binding.disabled = False
                data_source_binding.updated_at = naive_utc_now()
                db.session.add(data_source_binding)
                db.session.commit()
            else:
                raise ValueError("Data source is not disabled.")
        # disable binding
        if action == "disable":
            if not data_source_binding.disabled:
                data_source_binding.disabled = True
                data_source_binding.updated_at = naive_utc_now()
                db.session.add(data_source_binding)
                db.session.commit()
            else:
                raise ValueError("Data source is disabled.")
        match action:
            case "enable":
                if data_source_binding.disabled:
                    data_source_binding.disabled = False
                    data_source_binding.updated_at = naive_utc_now()
                    db.session.add(data_source_binding)
                    db.session.commit()
                else:
                    raise ValueError("Data source is not disabled.")
            # disable binding
            case "disable":
                if not data_source_binding.disabled:
                    data_source_binding.disabled = True
                    data_source_binding.updated_at = naive_utc_now()
                    db.session.add(data_source_binding)
                    db.session.commit()
                else:
                    raise ValueError("Data source is disabled.")
        return {"result": "success"}, 200
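
The hunk above swaps an if/if chain for a match statement (Python 3.10+). Matching on bare string literals is plain equality, so behavior is unchanged; the match form simply makes the two actions mutually exclusive and gives one place to reject unknown values. A condensed sketch, not project code:

# Equivalent control flow to the enable/disable branches above.
def toggle_binding(action: str, currently_disabled: bool) -> bool:
    match action:
        case "enable":
            if not currently_disabled:
                raise ValueError("Data source is not disabled.")
            return False  # binding becomes enabled
        case "disable":
            if currently_disabled:
                raise ValueError("Data source is disabled.")
            return True  # binding becomes disabled
        case _:
            raise ValueError(f"Unknown action: {action}")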


@@ -132,30 +192,19 @@ class DataSourceNotionListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_notion_info_list_fields)
    @marshal_with(integrate_notion_info_list_model)
    def get(self):
        current_user, current_tenant_id = current_account_with_tenant()

        dataset_id = request.args.get("dataset_id", default=None, type=str)
        credential_id = request.args.get("credential_id", default=None, type=str)
        if not credential_id:
            raise ValueError("Credential id is required.")
        query = DataSourceNotionListQuery.model_validate(request.args.to_dict())

        # Get datasource_parameters from query string (optional, for GitHub and other datasources)
        datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
        datasource_parameters = {}
        if datasource_parameters_str:
            try:
                datasource_parameters = json.loads(datasource_parameters_str)
                if not isinstance(datasource_parameters, dict):
                    raise ValueError("datasource_parameters must be a JSON object.")
            except json.JSONDecodeError:
                raise ValueError("Invalid datasource_parameters JSON format.")
        datasource_parameters = query.datasource_parameters or {}

        datasource_provider_service = DatasourceProviderService()
        credential = datasource_provider_service.get_datasource_credentials(
            tenant_id=current_tenant_id,
            credential_id=credential_id,
            credential_id=query.credential_id,
            provider="notion_datasource",
            plugin_id="langgenius/notion_datasource",
        )
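
Both endpoints in this file now funnel request.args through a Pydantic model instead of hand-rolled request.args.get calls, so missing or malformed parameters surface as validation errors rather than ad-hoc ValueErrors. A standalone sketch of the pattern (the class and field names here are illustrative, not from the patch):

# Sketch: validating a Flask query string with Pydantic v2.
from pydantic import BaseModel, Field

class ExampleQuery(BaseModel):
    credential_id: str = Field(..., min_length=1)
    dataset_id: str | None = None

# Inside a request handler:
#   query = ExampleQuery.model_validate(request.args.to_dict())
# An absent or empty credential_id now raises pydantic.ValidationError,
# replacing the manual `if not credential_id: raise ValueError(...)` check.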
@@ -164,8 +213,8 @@ class DataSourceNotionListApi(Resource):
        exist_page_ids = []
        with Session(db.engine) as session:
            # import notion in the exist dataset
            if dataset_id:
                dataset = DatasetService.get_dataset(dataset_id)
            if query.dataset_id:
                dataset = DatasetService.get_dataset(query.dataset_id)
                if not dataset:
                    raise NotFound("Dataset not found.")
                if dataset.data_source_type != "notion_import":
@@ -173,7 +222,7 @@ class DataSourceNotionListApi(Resource):

            documents = session.scalars(
                select(Document).filter_by(
                    dataset_id=dataset_id,
                    dataset_id=query.dataset_id,
                    tenant_id=current_tenant_id,
                    data_source_type="notion_import",
                    enabled=True,
@@ -240,13 +289,12 @@ class DataSourceNotionApi(Resource):
    def get(self, page_id, page_type):
        _, current_tenant_id = current_account_with_tenant()

        credential_id = request.args.get("credential_id", default=None, type=str)
        if not credential_id:
            raise ValueError("Credential id is required.")
        query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())

        datasource_provider_service = DatasourceProviderService()
        credential = datasource_provider_service.get_datasource_credentials(
            tenant_id=current_tenant_id,
            credential_id=credential_id,
            credential_id=query.credential_id,
            provider="notion_datasource",
            plugin_id="langgenius/notion_datasource",
        )

@@ -8,7 +8,7 @@ from werkzeug.exceptions import Forbidden, NotFound

import services
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import (
    api_key_item_model,
@@ -34,6 +34,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
    content_fields,
    dataset_detail_fields,
    dataset_fields,
    dataset_query_detail_fields,
@@ -41,6 +42,7 @@ from fields.dataset_fields import (
    doc_metadata_fields,
    external_knowledge_info_fields,
    external_retrieval_model_fields,
    file_info_fields,
    icon_info_fields,
    keyword_setting_fields,
    reranking_model_fields,
@@ -55,41 +57,33 @@ from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService


def _get_or_create_model(model_name: str, field_def):
    existing = console_ns.models.get(model_name)
    if existing is None:
        existing = console_ns.model(model_name, field_def)
    return existing


# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = _get_or_create_model("DatasetBase", dataset_fields)
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)

tag_model = _get_or_create_model("Tag", tag_fields)
tag_model = get_or_create_model("Tag", tag_fields)

keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)

weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)

reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)

dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)

external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)

external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)

doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)

icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)

dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
@@ -98,14 +92,22 @@ dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_k
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
dataset_detail_model = _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)

dataset_query_detail_model = _get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields)
file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)

app_detail_kernel_model = _get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
content_fields_copy = content_fields.copy()
content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
content_model = get_or_create_model("DatasetContent", content_fields_copy)

dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)

app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
related_app_list_copy = related_app_list.copy()
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)


def _validate_indexing_technique(value: str | None) -> str | None:
@@ -146,6 +148,7 @@ class DatasetUpdatePayload(BaseModel):
    embedding_model: str | None = None
    embedding_model_provider: str | None = None
    retrieval_model: dict[str, Any] | None = None
    summary_index_setting: dict[str, Any] | None = None
    partial_member_list: list[dict[str, str]] | None = None
    external_retrieval_model: dict[str, Any] | None = None
    external_knowledge_id: str | None = None
@@ -176,7 +179,18 @@ class IndexingEstimatePayload(BaseModel):
        return result


register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
class ConsoleDatasetListQuery(BaseModel):
    page: int = Field(default=1, description="Page number")
    limit: int = Field(default=20, description="Number of items per page")
    keyword: str | None = Field(default=None, description="Search keyword")
    include_all: bool = Field(default=False, description="Include all datasets")
    ids: list[str] = Field(default_factory=list, description="Filter by dataset IDs")
    tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")


register_schema_models(
    console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
)


def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@@ -275,18 +289,26 @@ class DatasetListApi(Resource):
    @enterprise_license_required
    def get(self):
        current_user, current_tenant_id = current_account_with_tenant()
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        ids = request.args.getlist("ids")
        # Convert query parameters to dict, handling list parameters correctly
        query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
        # Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value)
        if "ids" in request.args:
            query_params["ids"] = request.args.getlist("ids")
        if "tag_ids" in request.args:
            query_params["tag_ids"] = request.args.getlist("tag_ids")
        query = ConsoleDatasetListQuery.model_validate(query_params)
        # provider = request.args.get("provider", default="vendor")
        search = request.args.get("keyword", default=None, type=str)
        tag_ids = request.args.getlist("tag_ids")
        include_all = request.args.get("include_all", default="false").lower() == "true"
        if ids:
            datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
        if query.ids:
            datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id)
        else:
            datasets, total = DatasetService.get_datasets(
                page, limit, current_tenant_id, current_user, search, tag_ids, include_all
                query.page,
                query.limit,
                current_tenant_id,
                current_user,
                query.keyword,
                query.tag_ids,
                query.include_all,
            )

        # check embedding setting
@@ -318,7 +340,13 @@ class DatasetListApi(Resource):
            else:
                item.update({"partial_member_list": []})

        response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
        response = {
            "data": data,
            "has_more": len(datasets) == query.limit,
            "limit": query.limit,
            "total": total,
            "page": query.page,
        }
        return response, 200

    @console_ns.doc("create_dataset")

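request.args.to_dict() keeps only the first value of a repeated query parameter, which is why the hunk above re-reads ids and tag_ids with getlist before handing the dict to ConsoleDatasetListQuery. A self-contained sketch of the pitfall:

# Sketch: why repeated query params need request.args.getlist.
from werkzeug.datastructures import MultiDict

args = MultiDict([("ids", "a"), ("ids", "b"), ("page", "2")])
print(args.to_dict())       # {'ids': 'a', 'page': '2'} -- second value lost
print(args.getlist("ids"))  # ['a', 'b']

params: dict[str, str | list[str]] = dict(args.to_dict())
if "ids" in args:
    params["ids"] = args.getlist("ids")  # restore the full list for validation
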
@@ -14,7 +14,7 @@ from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound

import services
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from core.errors.error import (
    LLMBadRequestError,
@@ -45,6 +45,7 @@ from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
from tasks.generate_summary_index_task import generate_summary_index_task

from ..app.error import (
    ProviderModelCurrentlyNotSupportError,
@@ -72,34 +73,27 @@ logger = logging.getLogger(__name__)
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100


def _get_or_create_model(model_name: str, field_def):
    existing = console_ns.models.get(model_name)
    if existing is None:
        existing = console_ns.model(model_name, field_def)
    return existing


# Register models for flask_restx to avoid dict type issues in Swagger
dataset_model = _get_or_create_model("Dataset", dataset_fields)
dataset_model = get_or_create_model("Dataset", dataset_fields)

document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields)
document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)

document_fields_copy = document_fields.copy()
document_fields_copy["doc_metadata"] = fields.List(
    fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_model = _get_or_create_model("Document", document_fields_copy)
document_model = get_or_create_model("Document", document_fields_copy)

document_with_segments_fields_copy = document_with_segments_fields.copy()
document_with_segments_fields_copy["doc_metadata"] = fields.List(
    fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)

dataset_and_document_fields_copy = dataset_and_document_fields.copy()
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)


class DocumentRetryPayload(BaseModel):
@@ -110,6 +104,10 @@ class DocumentRenamePayload(BaseModel):
    name: str


class GenerateSummaryPayload(BaseModel):
    document_list: list[str]


class DocumentBatchDownloadZipPayload(BaseModel):
    """Request payload for bulk downloading documents as a zip archive."""

@@ -132,6 +130,7 @@ register_schema_models(
    RetrievalModel,
    DocumentRetryPayload,
    DocumentRenamePayload,
    GenerateSummaryPayload,
    DocumentBatchDownloadZipPayload,
)

@@ -319,6 +318,13 @@ class DatasetDocumentListApi(Resource):

        paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
        documents = paginated_documents.items

        DocumentService.enrich_documents_with_summary_index_status(
            documents=documents,
            dataset=dataset,
            tenant_id=current_tenant_id,
        )

        if fetch:
            for document in documents:
                completed_segments = (
@@ -570,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
            if document.indexing_status in {"completed", "error"}:
                raise DocumentAlreadyFinishedError()
            data_source_info = document.data_source_info_dict
            match document.data_source_type:
                case "upload_file":
                    if not data_source_info:
                        continue
                    file_id = data_source_info["upload_file_id"]
                    file_detail = (
                        db.session.query(UploadFile)
                        .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
                        .first()
                    )

            if document.data_source_type == "upload_file":
                if not data_source_info:
                    continue
                file_id = data_source_info["upload_file_id"]
                file_detail = (
                    db.session.query(UploadFile)
                    .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
                    .first()
                )
                    if file_detail is None:
                        raise NotFound("File not found.")

                if file_detail is None:
                    raise NotFound("File not found.")
                    extract_setting = ExtractSetting(
                        datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
                    )
                    extract_settings.append(extract_setting)
                case "notion_import":
                    if not data_source_info:
                        continue
                    extract_setting = ExtractSetting(
                        datasource_type=DatasourceType.NOTION,
                        notion_info=NotionInfo.model_validate(
                            {
                                "credential_id": data_source_info.get("credential_id"),
                                "notion_workspace_id": data_source_info["notion_workspace_id"],
                                "notion_obj_id": data_source_info["notion_page_id"],
                                "notion_page_type": data_source_info["type"],
                                "tenant_id": current_tenant_id,
                            }
                        ),
                        document_model=document.doc_form,
                    )
                    extract_settings.append(extract_setting)
                case "website_crawl":
                    if not data_source_info:
                        continue
                    extract_setting = ExtractSetting(
                        datasource_type=DatasourceType.WEBSITE,
                        website_info=WebsiteInfo.model_validate(
                            {
                                "provider": data_source_info["provider"],
                                "job_id": data_source_info["job_id"],
                                "url": data_source_info["url"],
                                "tenant_id": current_tenant_id,
                                "mode": data_source_info["mode"],
                                "only_main_content": data_source_info["only_main_content"],
                            }
                        ),
                        document_model=document.doc_form,
                    )
                    extract_settings.append(extract_setting)

                extract_setting = ExtractSetting(
                    datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
                )
                extract_settings.append(extract_setting)

            elif document.data_source_type == "notion_import":
                if not data_source_info:
                    continue
                extract_setting = ExtractSetting(
                    datasource_type=DatasourceType.NOTION,
                    notion_info=NotionInfo.model_validate(
                        {
                            "credential_id": data_source_info.get("credential_id"),
                            "notion_workspace_id": data_source_info["notion_workspace_id"],
                            "notion_obj_id": data_source_info["notion_page_id"],
                            "notion_page_type": data_source_info["type"],
                            "tenant_id": current_tenant_id,
                        }
                    ),
                    document_model=document.doc_form,
                )
                extract_settings.append(extract_setting)
            elif document.data_source_type == "website_crawl":
                if not data_source_info:
                    continue
                extract_setting = ExtractSetting(
                    datasource_type=DatasourceType.WEBSITE,
                    website_info=WebsiteInfo.model_validate(
                        {
                            "provider": data_source_info["provider"],
                            "job_id": data_source_info["job_id"],
                            "url": data_source_info["url"],
                            "tenant_id": current_tenant_id,
                            "mode": data_source_info["mode"],
                            "only_main_content": data_source_info["only_main_content"],
                        }
                    ),
                    document_model=document.doc_form,
                )
                extract_settings.append(extract_setting)

            else:
                raise ValueError("Data source type not supported")
                case _:
                    raise ValueError("Data source type not supported")
        indexing_runner = IndexingRunner()
        try:
            response = indexing_runner.indexing_estimate(
@@ -804,6 +809,7 @@ class DocumentApi(DocumentResource):
                "display_status": document.display_status,
                "doc_form": document.doc_form,
                "doc_language": document.doc_language,
                "need_summary": document.need_summary if document.need_summary is not None else False,
            }
        else:
            dataset_process_rules = DatasetService.get_process_rules(dataset_id)
@@ -839,6 +845,7 @@ class DocumentApi(DocumentResource):
                "display_status": document.display_status,
                "doc_form": document.doc_form,
                "doc_language": document.doc_language,
                "need_summary": document.need_summary if document.need_summary is not None else False,
            }

        return response, 200
@@ -946,23 +953,24 @@ class DocumentProcessingApi(DocumentResource):
        if not current_user.is_dataset_editor:
            raise Forbidden()

        if action == "pause":
            if document.indexing_status != "indexing":
                raise InvalidActionError("Document not in indexing state.")
        match action:
            case "pause":
                if document.indexing_status != "indexing":
                    raise InvalidActionError("Document not in indexing state.")

            document.paused_by = current_user.id
            document.paused_at = naive_utc_now()
            document.is_paused = True
            db.session.commit()
                document.paused_by = current_user.id
                document.paused_at = naive_utc_now()
                document.is_paused = True
                db.session.commit()

        elif action == "resume":
            if document.indexing_status not in {"paused", "error"}:
                raise InvalidActionError("Document not in paused or error state.")
            case "resume":
                if document.indexing_status not in {"paused", "error"}:
                    raise InvalidActionError("Document not in paused or error state.")

            document.paused_by = None
            document.paused_at = None
            document.is_paused = False
            db.session.commit()
                document.paused_by = None
                document.paused_at = None
                document.is_paused = False
                db.session.commit()

        return {"result": "success"}, 200

@@ -1178,7 +1186,7 @@ class DocumentRenameApi(DocumentResource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(document_fields)
    @marshal_with(document_model)
    @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
    def post(self, dataset_id, document_id):
        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@@ -1262,3 +1270,149 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
            "input_data": log.input_data,
            "datasource_node_id": log.datasource_node_id,
        }, 200


@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
class DocumentGenerateSummaryApi(Resource):
    @console_ns.doc("generate_summary_for_documents")
    @console_ns.doc(description="Generate summary index for documents")
    @console_ns.doc(params={"dataset_id": "Dataset ID"})
    @console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
    @console_ns.response(200, "Summary generation started successfully")
    @console_ns.response(400, "Invalid request or dataset configuration")
    @console_ns.response(403, "Permission denied")
    @console_ns.response(404, "Dataset not found")
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_rate_limit_check("knowledge")
    def post(self, dataset_id):
        """
        Generate summary index for specified documents.

        This endpoint checks if the dataset configuration supports summary generation
        (indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
        then asynchronously generates summary indexes for the provided documents.
        """
        current_user, _ = current_account_with_tenant()
        dataset_id = str(dataset_id)

        # Get dataset
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")

        # Check permissions
        if not current_user.is_dataset_editor:
            raise Forbidden()

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        # Validate request payload
        payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
        document_list = payload.document_list

        if not document_list:
            from werkzeug.exceptions import BadRequest

            raise BadRequest("document_list cannot be empty.")

        # Check if dataset configuration supports summary generation
        if dataset.indexing_technique != "high_quality":
            raise ValueError(
                f"Summary generation is only available for 'high_quality' indexing technique. "
                f"Current indexing technique: {dataset.indexing_technique}"
            )

        summary_index_setting = dataset.summary_index_setting
        if not summary_index_setting or not summary_index_setting.get("enable"):
            raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")

        # Verify all documents exist and belong to the dataset
        documents = DocumentService.get_documents_by_ids(dataset_id, document_list)

        if len(documents) != len(document_list):
            found_ids = {doc.id for doc in documents}
            missing_ids = set(document_list) - found_ids
            raise NotFound(f"Some documents not found: {list(missing_ids)}")

        # Update need_summary to True for documents that don't have it set
        # This handles the case where documents were created when summary_index_setting was disabled
        documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"]

        if documents_to_update:
            document_ids_to_update = [str(doc.id) for doc in documents_to_update]
            DocumentService.update_documents_need_summary(
                dataset_id=dataset_id,
                document_ids=document_ids_to_update,
                need_summary=True,
            )

        # Dispatch async tasks for each document
        for document in documents:
            # Skip qa_model documents as they don't generate summaries
            if document.doc_form == "qa_model":
                logger.info("Skipping summary generation for qa_model document %s", document.id)
                continue

            # Dispatch async task
            generate_summary_index_task.delay(dataset_id, document.id)
            logger.info(
                "Dispatched summary generation task for document %s in dataset %s",
                document.id,
                dataset_id,
            )

        return {"result": "success"}, 200
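
A hedged usage sketch for the endpoint above. The console API mount point (/console/api here) and the authorization scheme are deployment details not shown in this patch; all IDs are placeholders:

# Illustrative client call only; URL prefix, auth, and IDs are assumptions.
import requests

resp = requests.post(
    "https://example.com/console/api/datasets/<dataset-id>/documents/generate-summary",
    headers={"Authorization": "Bearer <console-access-token>"},
    json={"document_list": ["<document-id-1>", "<document-id-2>"]},
)
resp.raise_for_status()
print(resp.json())  # expected on success: {"result": "success"}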


@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
class DocumentSummaryStatusApi(DocumentResource):
    @console_ns.doc("get_document_summary_status")
    @console_ns.doc(description="Get summary index generation status for a document")
    @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
    @console_ns.response(200, "Summary status retrieved successfully")
    @console_ns.response(404, "Document not found")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id):
        """
        Get summary index generation status for a document.

        Returns:
            - total_segments: Total number of segments in the document
            - summary_status: Dictionary with status counts
                - completed: Number of summaries completed
                - generating: Number of summaries being generated
                - error: Number of summaries with errors
                - not_started: Number of segments without summary records
            - summaries: List of summary records with status and content preview
        """
        current_user, _ = current_account_with_tenant()
        dataset_id = str(dataset_id)
        document_id = str(document_id)

        # Get dataset
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")

        # Check permissions
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        # Get summary status detail from service
        from services.summary_index_service import SummaryIndexService

        result = SummaryIndexService.get_document_summary_status_detail(
            document_id=document_id,
            dataset_id=dataset_id,
        )

        return result, 200
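
Going by the docstring above, a successful response carries the aggregate counts plus per-segment summary records. An illustrative payload shape (every value below is invented; the exact record fields come from SummaryIndexService, which this diff does not show):

# Hypothetical example of the documented response shape.
example_summary_status = {
    "total_segments": 42,
    "summary_status": {
        "completed": 30,
        "generating": 8,
        "error": 1,
        "not_started": 3,
    },
    "summaries": [
        {"status": "completed", "summary_content": "Preview of the generated summary..."},
    ],
}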

@@ -41,6 +41,17 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task


def _get_segment_with_summary(segment, dataset_id):
    """Helper function to marshal segment and add summary information."""
    from services.summary_index_service import SummaryIndexService

    segment_dict = dict(marshal(segment, segment_fields))
    # Query summary for this segment (only enabled summaries)
    summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
    segment_dict["summary"] = summary.summary_content if summary else None
    return segment_dict


class SegmentListQuery(BaseModel):
    limit: int = Field(default=20, ge=1, le=100)
    status: list[str] = Field(default_factory=list)
@@ -63,6 +74,7 @@ class SegmentUpdatePayload(BaseModel):
    keywords: list[str] | None = None
    regenerate_child_chunks: bool = False
    attachment_ids: list[str] | None = None
    summary: str | None = None  # Summary content for summary index


class BatchImportPayload(BaseModel):
@@ -90,6 +102,7 @@ register_schema_models(
    ChildChunkCreatePayload,
    ChildChunkUpdatePayload,
    ChildChunkBatchUpdatePayload,
    ChildChunkUpdateArgs,
)


@@ -180,8 +193,25 @@ class DatasetDocumentSegmentListApi(Resource):

        segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

        # Query summaries for all segments in this page (batch query for efficiency)
        segment_ids = [segment.id for segment in segments.items]
        summaries = {}
        if segment_ids:
            from services.summary_index_service import SummaryIndexService

            summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
            # Only include enabled summaries (already filtered by service)
            summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}

        # Add summary to each segment
        segments_with_summary = []
        for segment in segments.items:
            segment_dict = dict(marshal(segment, segment_fields))
            segment_dict["summary"] = summaries.get(segment.id)
            segments_with_summary.append(segment_dict)

        response = {
            "data": marshal(segments.items, segment_fields),
            "data": segments_with_summary,
            "limit": limit,
            "total": segments.total,
            "total_pages": segments.pages,
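
The list endpoint above deliberately fetches summaries for the whole page in one service call, keyed by segment id, while the _get_segment_with_summary helper stays reserved for single-segment responses; this avoids one extra query per row. A sketch of the two shapes, using the names from this file:

# Per-row lookups (the N+1 pattern the list endpoint avoids):
#     for segment in segments.items:
#         SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
#
# Batched lookup (what the hunk above does):
#     records = SummaryIndexService.get_segments_summaries(
#         segment_ids=[s.id for s in segments.items], dataset_id=dataset_id
#     )
#     summaries = {sid: r.summary_content for sid, r in records.items()}
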
@@ -327,7 +357,7 @@ class DatasetDocumentSegmentAddApi(Resource):
        payload_dict = payload.model_dump(exclude_none=True)
        SegmentService.segment_create_args_validate(payload_dict, document)
        segment = SegmentService.create_segment(payload_dict, document, dataset)
        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
        return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200


@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
@@ -389,10 +419,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
        payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
        payload_dict = payload.model_dump(exclude_none=True)
        SegmentService.segment_create_args_validate(payload_dict, document)

        # Update segment (summary update with change detection is handled in SegmentService.update_segment)
        segment = SegmentService.update_segment(
            SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
        )
        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
        return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200

    @setup_required
    @login_required

@@ -4,7 +4,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

import services
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@@ -28,34 +28,27 @@ from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService


def _get_or_create_model(model_name: str, field_def):
    existing = console_ns.models.get(model_name)
    if existing is None:
        existing = console_ns.model(model_name, field_def)
    return existing


def _build_dataset_detail_model():
    keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
    vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
    keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
    vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)

    weighted_score_fields_copy = weighted_score_fields.copy()
    weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
    weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
    weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
    weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)

    reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
    reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)

    dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
    dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
    dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
    dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
    dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)

    tag_model = _get_or_create_model("Tag", tag_fields)
    doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
    external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
    external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
    icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
    tag_model = get_or_create_model("Tag", tag_fields)
    doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
    external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
    external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
    icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)

    dataset_detail_fields_copy = dataset_detail_fields.copy()
    dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
@@ -64,7 +57,7 @@ def _build_dataset_detail_model():
    dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
    dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
    dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
    return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
    return get_or_create_model("DatasetDetail", dataset_detail_fields_copy)


try:
@@ -98,12 +91,19 @@ class BedrockRetrievalPayload(BaseModel):
    knowledge_id: str


class ExternalApiTemplateListQuery(BaseModel):
    page: int = Field(default=1, description="Page number")
    limit: int = Field(default=20, description="Number of items per page")
    keyword: str | None = Field(default=None, description="Search keyword")


register_schema_models(
    console_ns,
    ExternalKnowledgeApiPayload,
    ExternalDatasetCreatePayload,
    ExternalHitTestingPayload,
    BedrockRetrievalPayload,
    ExternalApiTemplateListQuery,
)


@@ -124,19 +124,17 @@ class ExternalApiTemplateListApi(Resource):
    @account_initialization_required
    def get(self):
        _, current_tenant_id = current_account_with_tenant()
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        search = request.args.get("keyword", default=None, type=str)
        query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())

        external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
            page, limit, current_tenant_id, search
            query.page, query.limit, current_tenant_id, query.keyword
        )
        response = {
            "data": [item.to_dict() for item in external_knowledge_apis],
            "has_more": len(external_knowledge_apis) == limit,
            "limit": limit,
            "has_more": len(external_knowledge_apis) == query.limit,
            "limit": query.limit,
            "total": total,
            "page": page,
            "page": query.page,
        }
        return response, 200


@@ -1,6 +1,13 @@
from flask_restx import Resource
from flask_restx import Resource, fields

from controllers.common.schema import register_schema_model
from fields.hit_testing_fields import (
    child_chunk_fields,
    document_fields,
    files_fields,
    hit_testing_record_fields,
    segment_fields,
)
from libs.login import login_required

from .. import console_ns
@@ -14,13 +21,45 @@ from ..wraps import (
register_schema_model(console_ns, HitTestingPayload)


def _get_or_create_model(model_name: str, field_def):
    """Get or create a flask_restx model to avoid dict type issues in Swagger."""
    existing = console_ns.models.get(model_name)
    if existing is None:
        existing = console_ns.model(model_name, field_def)
    return existing


# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)

segment_fields_copy = segment_fields.copy()
segment_fields_copy["document"] = fields.Nested(document_model)
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)

child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)

hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)

# Response model for hit testing API
hit_testing_response_fields = {
    "query": fields.String,
    "records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
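
Passing the model to @console_ns.response, as the HitTestingApi hunk below does, is what puts the nested record schema into the generated Swagger document; a bare @response(200, "...") documents only the status code. A minimal self-contained sketch (names invented):

# Sketch: documenting a response body schema in flask_restx.
from flask_restx import Namespace, Resource, fields

demo_ns = Namespace("demo")
demo_record = demo_ns.model("DemoRecord", {"score": fields.Float})
demo_result = demo_ns.model("DemoResult", {"records": fields.List(fields.Nested(demo_record))})

@demo_ns.route("/demo")
class Demo(Resource):
    @demo_ns.response(200, "OK", model=demo_result)  # schema shows up in Swagger
    def get(self):
        return {"records": [{"score": 0.9}]}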
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@console_ns.doc("test_dataset_retrieval")
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
|
||||
@console_ns.response(200, "Hit testing completed successfully")
|
||||
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
|
||||
@ -4,14 +4,16 @@ from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_model, register_schema_models
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
MetadataArgs,
|
||||
MetadataDetail,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
@ -21,8 +23,9 @@ class MetadataUpdatePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
|
||||
register_schema_model(console_ns, MetadataUpdatePayload)
|
||||
register_schema_models(
|
||||
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
@ -123,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
if action == "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
elif action == "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
match action:
|
||||
case "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@ -14,7 +14,9 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
|
||||
workflow_draft_variable_list_model,
|
||||
workflow_draft_variable_list_without_value_model,
|
||||
workflow_draft_variable_model,
|
||||
)
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
@ -27,7 +29,6 @@ from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
@ -52,20 +53,6 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||
return var_list.variables
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
|
||||
"total": fields.Raw(),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
|
||||
def _api_prerequisite(f):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
@ -92,7 +79,7 @@ def _api_prerequisite(f):
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables")
|
||||
class RagPipelineVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
@ -150,7 +137,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
@ -176,7 +163,7 @@ class RagPipelineVariableApi(Resource):
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
@ -189,7 +176,7 @@ class RagPipelineVariableApi(Resource):
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
|
||||
def patch(self, pipeline: Pipeline, variable_id: str):
|
||||
# Request payload for file types:
|
||||
@ -307,7 +294,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables")
|
||||
class RagPipelineSystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
 from flask import request
-from flask_restx import Resource, marshal_with # type: ignore
+from flask_restx import Resource, fields, marshal_with # type: ignore
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session

-from controllers.common.schema import register_schema_models
+from controllers.common.schema import get_or_create_model, register_schema_models
 from controllers.console import console_ns
 from controllers.console.datasets.wraps import get_rag_pipeline
 from controllers.console.wraps import (
@@ -12,7 +12,11 @@ from controllers.console.wraps import (
     setup_required,
 )
 from extensions.ext_database import db
-from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
+from fields.rag_pipeline_fields import (
+    leaked_dependency_fields,
+    pipeline_import_check_dependencies_fields,
+    pipeline_import_fields,
+)
 from libs.login import current_account_with_tenant, login_required
 from models.dataset import Pipeline
 from services.app_dsl_service import ImportStatus
@@ -38,13 +42,25 @@ class IncludeSecretQuery(BaseModel):
 register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)


+pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields)
+
+leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields)
+pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy()
+pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(
+    fields.Nested(leaked_dependency_model)
+)
+pipeline_import_check_dependencies_model = get_or_create_model(
+    "RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy
+)
+
+
 @console_ns.route("/rag/pipelines/imports")
 class RagPipelineImportApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     @edit_permission_required
-    @marshal_with(pipeline_import_fields)
+    @marshal_with(pipeline_import_model)
     @console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
     def post(self):
         # Check user role first
@@ -81,7 +97,7 @@ class RagPipelineImportConfirmApi(Resource):
     @login_required
     @account_initialization_required
     @edit_permission_required
-    @marshal_with(pipeline_import_fields)
+    @marshal_with(pipeline_import_model)
     def post(self, import_id):
         current_user, _ = current_account_with_tenant()

@@ -106,7 +122,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
     @get_rag_pipeline
     @account_initialization_required
     @edit_permission_required
-    @marshal_with(pipeline_import_check_dependencies_fields)
+    @marshal_with(pipeline_import_check_dependencies_model)
     def get(self, pipeline: Pipeline):
         with Session(db.engine) as session:
             import_service = RagPipelineDslService(session)

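Note the copy before the override: pipeline_import_check_dependencies_fields is copied so the shared dict exported by fields.rag_pipeline_fields stays untouched for every other module that imports it. A sketch of the copy-then-override pattern with illustrative field dicts:

from flask import Flask
from flask_restx import Api, fields

api = Api(Flask(__name__))

leaked_dependency_fields = {"type": fields.String, "value": fields.Raw}
check_dependencies_fields = {"leaked_dependencies": fields.List(fields.Raw)}

leaked_dependency_model = api.model("LeakedDependency", leaked_dependency_fields)

# Copy first: mutating the shared dict would leak the nested override into
# every other consumer of check_dependencies_fields.
fields_copy = check_dependencies_fields.copy()
fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
check_dependencies_model = api.model("CheckDependencies", fields_copy)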
@@ -1,10 +1,9 @@
 import json
 import logging
 from typing import Any, Literal, cast
-from uuid import UUID

 from flask import abort, request
-from flask_restx import Resource, marshal_with, reqparse # type: ignore
+from flask_restx import Resource, marshal_with # type: ignore
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -17,6 +16,13 @@ from controllers.console.app.error import (
     DraftWorkflowNotExist,
     DraftWorkflowNotSync,
 )
+from controllers.console.app.workflow import workflow_model, workflow_pagination_model
+from controllers.console.app.workflow_run import (
+    workflow_run_detail_model,
+    workflow_run_node_execution_list_model,
+    workflow_run_node_execution_model,
+    workflow_run_pagination_model,
+)
 from controllers.console.datasets.wraps import get_rag_pipeline
 from controllers.console.wraps import (
     account_initialization_required,
@@ -30,15 +36,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.model_runtime.utils.encoders import jsonable_encoder
 from extensions.ext_database import db
 from factories import variable_factory
-from fields.workflow_fields import workflow_fields, workflow_pagination_fields
-from fields.workflow_run_fields import (
-    workflow_run_detail_fields,
-    workflow_run_node_execution_fields,
-    workflow_run_node_execution_list_fields,
-    workflow_run_pagination_fields,
-)
 from libs import helper
-from libs.helper import TimestampField
+from libs.helper import TimestampField, UUIDStrOrEmpty
 from libs.login import current_account_with_tenant, current_user, login_required
 from models import Account
 from models.dataset import Pipeline
@@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel):


 class WorkflowRunQuery(BaseModel):
-    last_id: UUID | None = None
+    last_id: UUIDStrOrEmpty | None = None
     limit: int = Field(default=20, ge=1, le=100)


@@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel):
     start_node_title: str


+class RagPipelineRecommendedPluginQuery(BaseModel):
+    type: str = "all"
+
+
 register_schema_models(
     console_ns,
     DraftWorkflowSyncPayload,
@@ -135,6 +138,7 @@ register_schema_models(
     NodeIdQuery,
     WorkflowRunQuery,
     DatasourceVariablesPayload,
+    RagPipelineRecommendedPluginQuery,
 )


@@ -145,7 +149,7 @@ class DraftRagPipelineApi(Resource):
     @account_initialization_required
     @get_rag_pipeline
     @edit_permission_required
-    @marshal_with(workflow_fields)
+    @marshal_with(workflow_model)
     def get(self, pipeline: Pipeline):
         """
         Get draft rag pipeline's workflow
@@ -521,7 +525,7 @@ class RagPipelineDraftNodeRunApi(Resource):
     @edit_permission_required
     @account_initialization_required
     @get_rag_pipeline
-    @marshal_with(workflow_run_node_execution_fields)
+    @marshal_with(workflow_run_node_execution_model)
     def post(self, pipeline: Pipeline, node_id: str):
         """
         Run draft workflow node
@@ -569,7 +573,7 @@ class PublishedRagPipelineApi(Resource):
     @account_initialization_required
     @edit_permission_required
     @get_rag_pipeline
-    @marshal_with(workflow_fields)
+    @marshal_with(workflow_model)
     def get(self, pipeline: Pipeline):
         """
         Get published pipeline
@@ -664,7 +668,7 @@ class PublishedAllRagPipelineApi(Resource):
     @account_initialization_required
     @edit_permission_required
     @get_rag_pipeline
-    @marshal_with(workflow_pagination_fields)
+    @marshal_with(workflow_pagination_model)
     def get(self, pipeline: Pipeline):
         """
         Get published workflows
@@ -708,7 +712,7 @@ class RagPipelineByIdApi(Resource):
     @account_initialization_required
     @edit_permission_required
     @get_rag_pipeline
-    @marshal_with(workflow_fields)
+    @marshal_with(workflow_model)
     def patch(self, pipeline: Pipeline, workflow_id: str):
         """
         Update workflow attributes
@@ -830,7 +834,7 @@ class RagPipelineWorkflowRunListApi(Resource):
     @login_required
     @account_initialization_required
     @get_rag_pipeline
-    @marshal_with(workflow_run_pagination_fields)
+    @marshal_with(workflow_run_pagination_model)
     def get(self, pipeline: Pipeline):
         """
         Get workflow run list
@@ -858,7 +862,7 @@ class RagPipelineWorkflowRunDetailApi(Resource):
     @login_required
     @account_initialization_required
     @get_rag_pipeline
-    @marshal_with(workflow_run_detail_fields)
+    @marshal_with(workflow_run_detail_model)
     def get(self, pipeline: Pipeline, run_id):
         """
         Get workflow run detail
@@ -877,7 +881,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
     @login_required
     @account_initialization_required
     @get_rag_pipeline
-    @marshal_with(workflow_run_node_execution_list_fields)
+    @marshal_with(workflow_run_node_execution_list_model)
     def get(self, pipeline: Pipeline, run_id: str):
         """
         Get workflow run node execution list
@@ -911,7 +915,7 @@ class RagPipelineWorkflowLastRunApi(Resource):
     @login_required
     @account_initialization_required
     @get_rag_pipeline
-    @marshal_with(workflow_run_node_execution_fields)
+    @marshal_with(workflow_run_node_execution_model)
     def get(self, pipeline: Pipeline, node_id: str):
         rag_pipeline_service = RagPipelineService()
         workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@@ -952,7 +956,7 @@ class RagPipelineDatasourceVariableApi(Resource):
     @account_initialization_required
     @get_rag_pipeline
     @edit_permission_required
-    @marshal_with(workflow_run_node_execution_fields)
+    @marshal_with(workflow_run_node_execution_model)
     def post(self, pipeline: Pipeline):
         """
         Set datasource variables
@@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        parser = reqparse.RequestParser()
-        parser.add_argument("type", type=str, location="args", required=False, default="all")
-        args = parser.parse_args()
-        type = args["type"]
+        query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())

         rag_pipeline_service = RagPipelineService()
-        recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
+        recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
         return recommended_plugins

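The RagPipelineRecommendedPluginApi hunk shows the recurring reqparse-to-Pydantic migration in this patch, which also removes the shadowing of the builtin name "type". A sketch of the query-string side of that migration; only the model name comes from the diff, the surrounding app is illustrative:

from flask import Flask, request
from pydantic import BaseModel

app = Flask(__name__)

class RagPipelineRecommendedPluginQuery(BaseModel):
    type: str = "all"  # same default reqparse used to supply

@app.get("/recommended-plugins")
def recommended_plugins():
    # to_dict() flattens the MultiDict; missing keys fall back to defaults.
    query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
    return {"type": query.type}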
@@ -2,16 +2,17 @@ import logging
 from typing import Any

 from flask import request
-from flask_restx import Resource, marshal_with
-from pydantic import BaseModel
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field
 from sqlalchemy import and_, select
 from werkzeug.exceptions import BadRequest, Forbidden, NotFound

+from controllers.common.schema import get_or_create_model
 from controllers.console import console_ns
 from controllers.console.explore.wraps import InstalledAppResource
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from extensions.ext_database import db
-from fields.installed_app_fields import installed_app_list_fields
+from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
 from models import App, InstalledApp, RecommendedApp
@@ -28,22 +29,37 @@ class InstalledAppUpdatePayload(BaseModel):
     is_pinned: bool | None = None


+class InstalledAppsListQuery(BaseModel):
+    app_id: str | None = Field(default=None, description="App ID to filter by")
+
+
 logger = logging.getLogger(__name__)


+app_model = get_or_create_model("InstalledAppInfo", app_fields)
+
+installed_app_fields_copy = installed_app_fields.copy()
+installed_app_fields_copy["app"] = fields.Nested(app_model)
+installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
+
+installed_app_list_fields_copy = installed_app_list_fields.copy()
+installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
+installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
+
+
 @console_ns.route("/installed-apps")
 class InstalledAppsListApi(Resource):
     @login_required
     @account_initialization_required
-    @marshal_with(installed_app_list_fields)
+    @marshal_with(installed_app_list_model)
     def get(self):
-        app_id = request.args.get("app_id", default=None, type=str)
+        query = InstalledAppsListQuery.model_validate(request.args.to_dict())
         current_user, current_tenant_id = current_account_with_tenant()

-        if app_id:
+        if query.app_id:
             installed_apps = db.session.scalars(
                 select(InstalledApp).where(
-                    and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)
+                    and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == query.app_id)
                 )
             ).all()
         else:

@@ -3,6 +3,7 @@ from flask_restx import Resource, fields, marshal_with
 from pydantic import BaseModel, Field

 from constants.languages import languages
+from controllers.common.schema import get_or_create_model
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required
 from libs.helper import AppIconUrlField
@@ -19,8 +20,10 @@ app_fields = {
     "icon_background": fields.String,
 }

+app_model = get_or_create_model("RecommendedAppInfo", app_fields)
+
 recommended_app_fields = {
-    "app": fields.Nested(app_fields, attribute="app"),
+    "app": fields.Nested(app_model, attribute="app"),
     "app_id": fields.String,
     "description": fields.String(attribute="description"),
     "copyright": fields.String,
@@ -32,11 +35,15 @@ recommended_app_fields = {
     "can_trial": fields.Boolean,
 }

+recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
+
 recommended_app_list_fields = {
-    "recommended_apps": fields.List(fields.Nested(recommended_app_fields)),
+    "recommended_apps": fields.List(fields.Nested(recommended_app_model)),
     "categories": fields.List(fields.String),
 }

+recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
+

 class RecommendedAppsQuery(BaseModel):
     language: str | None = Field(default=None)
@@ -53,7 +60,7 @@ class RecommendedAppListApi(Resource):
     @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
     @login_required
     @account_initialization_required
-    @marshal_with(recommended_app_list_fields)
+    @marshal_with(recommended_app_list_model)
     def get(self):
         # language args
         args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore

@@ -1,14 +1,16 @@
 import logging
-from typing import Any, cast
+from typing import Any, Literal, cast

 from flask import request
-from flask_restx import Resource, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal, marshal_with
 from pydantic import BaseModel
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

 import services
 from controllers.common.fields import Parameters as ParametersResponse
 from controllers.common.fields import Site as SiteResponse
-from controllers.console import api
+from controllers.common.schema import get_or_create_model
+from controllers.console import api, console_ns
 from controllers.console.app.error import (
     AppUnavailableError,
     AudioTooLargeError,
@@ -42,9 +44,21 @@ from core.errors.error import (
 from core.model_runtime.errors.invoke import InvokeError
 from core.workflow.graph_engine.manager import GraphEngineManager
 from extensions.ext_database import db
-from fields.app_fields import app_detail_fields_with_site
+from fields.app_fields import (
+    app_detail_fields_with_site,
+    deleted_tool_fields,
+    model_config_fields,
+    site_fields,
+    tag_fields,
+)
 from fields.dataset_fields import dataset_fields
-from fields.workflow_fields import workflow_fields
+from fields.member_fields import simple_account_fields
+from fields.workflow_fields import (
+    conversation_variable_fields,
+    pipeline_variable_fields,
+    workflow_fields,
+    workflow_partial_fields,
+)
 from libs import helper
 from libs.helper import uuid_value
 from libs.login import current_user
@@ -74,7 +88,86 @@ from services.recommended_app_service import RecommendedAppService
 logger = logging.getLogger(__name__)


+model_config_model = get_or_create_model("TrialAppModelConfig", model_config_fields)
+workflow_partial_model = get_or_create_model("TrialWorkflowPartial", workflow_partial_fields)
+deleted_tool_model = get_or_create_model("TrialDeletedTool", deleted_tool_fields)
+tag_model = get_or_create_model("TrialTag", tag_fields)
+site_model = get_or_create_model("TrialSite", site_fields)
+
+app_detail_fields_with_site_copy = app_detail_fields_with_site.copy()
+app_detail_fields_with_site_copy["model_config"] = fields.Nested(
+    model_config_model, attribute="app_model_config", allow_null=True
+)
+app_detail_fields_with_site_copy["workflow"] = fields.Nested(workflow_partial_model, allow_null=True)
+app_detail_fields_with_site_copy["deleted_tools"] = fields.List(fields.Nested(deleted_tool_model))
+app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
+app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
+app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
+
+simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields)
+conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
+pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
+
+workflow_fields_copy = workflow_fields.copy()
+workflow_fields_copy["created_by"] = fields.Nested(simple_account_model, attribute="created_by_account")
+workflow_fields_copy["updated_by"] = fields.Nested(
+    simple_account_model, attribute="updated_by_account", allow_null=True
+)
+workflow_fields_copy["conversation_variables"] = fields.List(fields.Nested(conversation_variable_model))
+workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipeline_variable_model))
+workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
+
+
+# Pydantic models for request validation
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class WorkflowRunRequest(BaseModel):
+    inputs: dict
+    files: list | None = None
+
+
+class ChatRequest(BaseModel):
+    inputs: dict
+    query: str
+    files: list | None = None
+    conversation_id: str | None = None
+    parent_message_id: str | None = None
+    retriever_from: str = "explore_app"
+
+
+class TextToSpeechRequest(BaseModel):
+    message_id: str | None = None
+    voice: str | None = None
+    text: str | None = None
+    streaming: bool | None = None
+
+
+class CompletionRequest(BaseModel):
+    inputs: dict
+    query: str = ""
+    files: list | None = None
+    response_mode: Literal["blocking", "streaming"] | None = None
+    retriever_from: str = "explore_app"
+
+
+# Register schemas for Swagger documentation
+console_ns.schema_model(
+    WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+console_ns.schema_model(
+    ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+console_ns.schema_model(
+    TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+console_ns.schema_model(
+    CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
+
 class TrialAppWorkflowRunApi(TrialAppResource):
+    @console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
     def post(self, trial_app):
         """
         Run workflow
@@ -86,10 +179,8 @@ class TrialAppWorkflowRunApi(TrialAppResource):
         if app_mode != AppMode.WORKFLOW:
             raise NotWorkflowAppError()

-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-        parser.add_argument("files", type=list, required=False, location="json")
-        args = parser.parse_args()
+        request_data = WorkflowRunRequest.model_validate(console_ns.payload)
+        args = request_data.model_dump()
         assert current_user is not None
         try:
             app_id = app_model.id
@@ -140,6 +231,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):


 class TrialChatApi(TrialAppResource):
+    @console_ns.expect(console_ns.models[ChatRequest.__name__])
     @trial_feature_enable
     def post(self, trial_app):
         app_model = trial_app
@@ -147,14 +239,14 @@ class TrialChatApi(TrialAppResource):
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()

-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, location="json")
-        parser.add_argument("query", type=str, required=True, location="json")
-        parser.add_argument("files", type=list, required=False, location="json")
-        parser.add_argument("conversation_id", type=uuid_value, location="json")
-        parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
-        parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
-        args = parser.parse_args()
+        request_data = ChatRequest.model_validate(console_ns.payload)
+        args = request_data.model_dump()
+
+        # Validate UUID values if provided
+        if args.get("conversation_id"):
+            args["conversation_id"] = uuid_value(args["conversation_id"])
+        if args.get("parent_message_id"):
+            args["parent_message_id"] = uuid_value(args["parent_message_id"])

         args["auto_generate_name"] = False

@@ -277,20 +369,16 @@ class TrialChatAudioApi(TrialAppResource):


 class TrialChatTextApi(TrialAppResource):
+    @console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
     @trial_feature_enable
     def post(self, trial_app):
         app_model = trial_app
         try:
-            parser = reqparse.RequestParser()
-            parser.add_argument("message_id", type=str, required=False, location="json")
-            parser.add_argument("voice", type=str, location="json")
-            parser.add_argument("text", type=str, location="json")
-            parser.add_argument("streaming", type=bool, location="json")
-            args = parser.parse_args()
+            request_data = TextToSpeechRequest.model_validate(console_ns.payload)

-            message_id = args.get("message_id", None)
-            text = args.get("text", None)
-            voice = args.get("voice", None)
+            message_id = request_data.message_id
+            text = request_data.text
+            voice = request_data.voice
             if not isinstance(current_user, Account):
                 raise ValueError("current_user must be an Account instance")

@@ -328,19 +416,15 @@ class TrialChatTextApi(TrialAppResource):


 class TrialCompletionApi(TrialAppResource):
+    @console_ns.expect(console_ns.models[CompletionRequest.__name__])
     @trial_feature_enable
     def post(self, trial_app):
         app_model = trial_app
         if app_model.mode != "completion":
             raise NotCompletionAppError()

-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, location="json")
-        parser.add_argument("query", type=str, location="json", default="")
-        parser.add_argument("files", type=list, required=False, location="json")
-        parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
-        parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
-        args = parser.parse_args()
+        request_data = CompletionRequest.model_validate(console_ns.payload)
+        args = request_data.model_dump()

         streaming = args["response_mode"] == "streaming"
         args["auto_generate_name"] = False
@@ -437,7 +521,7 @@ class TrialAppParameterApi(Resource):
 class AppApi(Resource):
     @trial_feature_enable
     @get_app_model_with_trial
-    @marshal_with(app_detail_fields_with_site)
+    @marshal_with(app_detail_with_site_model)
     def get(self, app_model):
         """Get app detail"""

@@ -450,7 +534,7 @@ class AppApi(Resource):
 class AppWorkflowApi(Resource):
     @trial_feature_enable
     @get_app_model_with_trial
-    @marshal_with(workflow_fields)
+    @marshal_with(workflow_model)
     def get(self, app_model):
         """Get workflow detail"""
         if not app_model.workflow_id:

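All four trial endpoints share the same body-validation shape: validate the parsed JSON against a Pydantic model, then model_dump() back to a plain dict so downstream code that still expects reqparse-style args keeps working. A sketch of that shape, assuming plain Flask (the real code reads the body from console_ns.payload):

from flask import Flask, request
from pydantic import BaseModel, ValidationError

app = Flask(__name__)

class WorkflowRunRequest(BaseModel):
    inputs: dict
    files: list | None = None

@app.post("/workflows/run")
def run_workflow():
    try:
        request_data = WorkflowRunRequest.model_validate(request.get_json())
    except ValidationError as e:
        return {"error": str(e)}, 400
    args = request_data.model_dump()  # downstream code still expects a dict
    return {"received": args}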
@@ -1,6 +1,7 @@
 from flask_restx import Resource, fields
+from werkzeug.exceptions import Unauthorized

-from libs.login import current_account_with_tenant, login_required
+from libs.login import current_account_with_tenant, current_user, login_required
 from services.feature_service import FeatureService

 from . import console_ns
@@ -39,5 +40,21 @@ class SystemFeatureApi(Resource):
         ),
     )
     def get(self):
-        """Get system-wide feature configuration"""
-        return FeatureService.get_system_features().model_dump()
+        """Get system-wide feature configuration
+
+        NOTE: This endpoint is unauthenticated by design, as it provides system features
+        data required for dashboard initialization.
+
+        Authentication would create circular dependency (can't login without dashboard loading).
+
+        Only non-sensitive configuration data should be returned by this endpoint.
+        """
+        # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
+        # without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
+        # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
+        # raise `Unauthorized` exception if authentication token is not provided.
+        try:
+            is_authenticated = current_user.is_authenticated
+        except Unauthorized:
+            is_authenticated = False
+        return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()

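The try/except added above exists because flask-login's current_user proxy can raise rather than return an anonymous user when a custom user loader rejects a missing token. A sketch of the probe in isolation (Unauthorized is the werkzeug exception named in the hunk; the helper name is illustrative):

from werkzeug.exceptions import Unauthorized

def resolve_is_authenticated(current_user) -> bool:
    # current_user.is_authenticated may raise instead of returning False
    # when no credential is attached to the request.
    try:
        return bool(current_user.is_authenticated)
    except Unauthorized:
        return False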
@@ -1,87 +1,74 @@
 import os
+from typing import Literal

 from flask import session
-from flask_restx import Resource, fields
 from pydantic import BaseModel, Field
 from sqlalchemy import select
 from sqlalchemy.orm import Session

 from configs import dify_config
+from controllers.fastopenapi import console_router
 from extensions.ext_database import db
 from models.model import DifySetup
 from services.account_service import TenantService

-from . import console_ns
 from .error import AlreadySetupError, InitValidateFailedError
 from .wraps import only_edition_self_hosted

-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


 class InitValidatePayload(BaseModel):
-    password: str = Field(..., max_length=30)
+    password: str = Field(..., max_length=30, description="Initialization password")


-console_ns.schema_model(
-    InitValidatePayload.__name__,
-    InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
-)
+class InitStatusResponse(BaseModel):
+    status: Literal["finished", "not_started"] = Field(..., description="Initialization status")
+
+
+class InitValidateResponse(BaseModel):
+    result: str = Field(description="Operation result", examples=["success"])
+
+
+@console_router.get(
+    "/init",
+    response_model=InitStatusResponse,
+    tags=["console"],
+)
+def get_init_status() -> InitStatusResponse:
+    """Get initialization validation status."""
+    init_status = get_init_validate_status()
+    if init_status:
+        return InitStatusResponse(status="finished")
+    return InitStatusResponse(status="not_started")


-@console_ns.route("/init")
-class InitValidateAPI(Resource):
-    @console_ns.doc("get_init_status")
-    @console_ns.doc(description="Get initialization validation status")
-    @console_ns.response(
-        200,
-        "Success",
-        model=console_ns.model(
-            "InitStatusResponse",
-            {"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
-        ),
-    )
-    def get(self):
-        """Get initialization validation status"""
-        init_status = get_init_validate_status()
-        if init_status:
-            return {"status": "finished"}
-        return {"status": "not_started"}
+@console_router.post(
+    "/init",
+    response_model=InitValidateResponse,
+    tags=["console"],
+    status_code=201,
+)
+@only_edition_self_hosted
+def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
+    """Validate initialization password."""
+    tenant_count = TenantService.get_tenant_count()
+    if tenant_count > 0:
+        raise AlreadySetupError()

-    @console_ns.doc("validate_init_password")
-    @console_ns.doc(description="Validate initialization password for self-hosted edition")
-    @console_ns.expect(console_ns.models[InitValidatePayload.__name__])
-    @console_ns.response(
-        201,
-        "Success",
-        model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
-    )
-    @console_ns.response(400, "Already setup or validation failed")
-    @only_edition_self_hosted
-    def post(self):
-        """Validate initialization password"""
-        # is tenant created
-        tenant_count = TenantService.get_tenant_count()
-        if tenant_count > 0:
-            raise AlreadySetupError()
+    if payload.password != os.environ.get("INIT_PASSWORD"):
+        session["is_init_validated"] = False
+        raise InitValidateFailedError()

-        payload = InitValidatePayload.model_validate(console_ns.payload)
-        input_password = payload.password
-
-        if input_password != os.environ.get("INIT_PASSWORD"):
-            session["is_init_validated"] = False
-            raise InitValidateFailedError()
-
-        session["is_init_validated"] = True
-        return {"result": "success"}, 201
+    session["is_init_validated"] = True
+    return InitValidateResponse(result="success")


-def get_init_validate_status():
+def get_init_validate_status() -> bool:
     if dify_config.EDITION == "SELF_HOSTED":
         if os.environ.get("INIT_PASSWORD"):
             if session.get("is_init_validated"):
                 return True

             with Session(db.engine) as db_session:
-                return db_session.execute(select(DifySetup)).scalar_one_or_none()
+                return db_session.execute(select(DifySetup)).scalar_one_or_none() is not None

     return True

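The `is not None` change above is a behavioral fix, not a style edit: scalar_one_or_none() yields a DifySetup row or None, so the old code returned an ORM object from a function whose callers (and new annotation) expect a bool. Coercing at the query site keeps the `-> bool` honest:

row = None  # what scalar_one_or_none() returns before setup has run
assert (row is not None) is False  # previously the raw row/None leaked out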
@@ -1,17 +1,17 @@
-from flask_restx import Resource, fields
+from pydantic import BaseModel, Field

-from . import console_ns
+from controllers.fastopenapi import console_router


-@console_ns.route("/ping")
-class PingApi(Resource):
-    @console_ns.doc("health_check")
-    @console_ns.doc(description="Health check endpoint for connection testing")
-    @console_ns.response(
-        200,
-        "Success",
-        console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
-    )
-    def get(self):
-        """Health check endpoint for connection testing"""
-        return {"result": "pong"}
+class PingResponse(BaseModel):
+    result: str = Field(description="Health check result", examples=["pong"])
+
+
+@console_router.get(
+    "/ping",
+    response_model=PingResponse,
+    tags=["console"],
+)
+def ping() -> PingResponse:
+    """Health check endpoint for connection testing."""
+    return PingResponse(result="pong")

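The ping rewrite is the smallest instance of this patch's second big migration: class-based flask-restx Resources become plain functions on console_router, with a Pydantic response_model replacing the hand-built Swagger metadata. A sketch of the shape; note the console_router API (get/response_model/tags/status_code) is inferred from the hunks themselves, so treat the exact signature as an assumption:

from pydantic import BaseModel, Field

class PingResponse(BaseModel):
    result: str = Field(description="Health check result", examples=["pong"])

# from controllers.fastopenapi import console_router  # the real import in the patch

# @console_router.get("/ping", response_model=PingResponse, tags=["console"])
def ping() -> PingResponse:
    # The router serializes the returned model and documents it from the schema.
    return PingResponse(result="pong")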
@@ -1,7 +1,6 @@
 import urllib.parse

 import httpx
-from flask_restx import Resource
 from pydantic import BaseModel, Field

 import services
@@ -11,7 +10,7 @@ from controllers.common.errors import (
     RemoteFileUploadError,
     UnsupportedFileTypeError,
 )
-from controllers.common.schema import register_schema_models
+from controllers.fastopenapi import console_router
 from core.file import helpers as file_helpers
 from core.helper import ssrf_proxy
 from extensions.ext_database import db
@@ -19,84 +18,74 @@ from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
 from libs.login import current_account_with_tenant
 from services.file_service import FileService

-from . import console_ns
-
-register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl)
-
-
-@console_ns.route("/remote-files/<path:url>")
-class RemoteFileInfoApi(Resource):
-    @console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__])
-    def get(self, url):
-        decoded_url = urllib.parse.unquote(url)
-        resp = ssrf_proxy.head(decoded_url)
-        if resp.status_code != httpx.codes.OK:
-            # failed back to get method
-            resp = ssrf_proxy.get(decoded_url, timeout=3)
-        resp.raise_for_status()
-        info = RemoteFileInfo(
-            file_type=resp.headers.get("Content-Type", "application/octet-stream"),
-            file_length=int(resp.headers.get("Content-Length", 0)),
-        )
-        return info.model_dump(mode="json")
-

 class RemoteFileUploadPayload(BaseModel):
     url: str = Field(..., description="URL to fetch")


-console_ns.schema_model(
-    RemoteFileUploadPayload.__name__,
-    RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
-)
+@console_router.get(
+    "/remote-files/<path:url>",
+    response_model=RemoteFileInfo,
+    tags=["console"],
+)
+def get_remote_file_info(url: str) -> RemoteFileInfo:
+    decoded_url = urllib.parse.unquote(url)
+    resp = ssrf_proxy.head(decoded_url)
+    if resp.status_code != httpx.codes.OK:
+        resp = ssrf_proxy.get(decoded_url, timeout=3)
+    resp.raise_for_status()
+    return RemoteFileInfo(
+        file_type=resp.headers.get("Content-Type", "application/octet-stream"),
+        file_length=int(resp.headers.get("Content-Length", 0)),
+    )


-@console_ns.route("/remote-files/upload")
-class RemoteFileUploadApi(Resource):
-    @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
-    @console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__])
-    def post(self):
-        args = RemoteFileUploadPayload.model_validate(console_ns.payload)
-        url = args.url
+@console_router.post(
+    "/remote-files/upload",
+    response_model=FileWithSignedUrl,
+    tags=["console"],
+    status_code=201,
+)
+def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl:
+    url = payload.url

-        try:
-            resp = ssrf_proxy.head(url=url)
-            if resp.status_code != httpx.codes.OK:
-                resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
-            if resp.status_code != httpx.codes.OK:
-                raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
-        except httpx.RequestError as e:
-            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
+    try:
+        resp = ssrf_proxy.head(url=url)
+        if resp.status_code != httpx.codes.OK:
+            resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
+        if resp.status_code != httpx.codes.OK:
+            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
+    except httpx.RequestError as e:
+        raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")

-        file_info = helpers.guess_file_info_from_response(resp)
+    file_info = helpers.guess_file_info_from_response(resp)

-        if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
-            raise FileTooLargeError
+    if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
+        raise FileTooLargeError

-        content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
+    content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content

-        try:
-            user, _ = current_account_with_tenant()
-            upload_file = FileService(db.engine).upload_file(
-                filename=file_info.filename,
-                content=content,
-                mimetype=file_info.mimetype,
-                user=user,
-                source_url=url,
-            )
-        except services.errors.file.FileTooLargeError as file_too_large_error:
-            raise FileTooLargeError(file_too_large_error.description)
-        except services.errors.file.UnsupportedFileTypeError:
-            raise UnsupportedFileTypeError()
-
-        payload = FileWithSignedUrl(
-            id=upload_file.id,
-            name=upload_file.name,
-            size=upload_file.size,
-            extension=upload_file.extension,
-            url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
-            mime_type=upload_file.mime_type,
-            created_by=upload_file.created_by,
-            created_at=int(upload_file.created_at.timestamp()),
-        )
-        return payload.model_dump(mode="json"), 201
+    try:
+        user, _ = current_account_with_tenant()
+        upload_file = FileService(db.engine).upload_file(
+            filename=file_info.filename,
+            content=content,
+            mimetype=file_info.mimetype,
+            user=user,
+            source_url=url,
+        )
+    except services.errors.file.FileTooLargeError as file_too_large_error:
+        raise FileTooLargeError(file_too_large_error.description)
+    except services.errors.file.UnsupportedFileTypeError:
+        raise UnsupportedFileTypeError()
+
+    return FileWithSignedUrl(
+        id=upload_file.id,
+        name=upload_file.name,
+        size=upload_file.size,
+        extension=upload_file.extension,
+        url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
+        mime_type=upload_file.mime_type,
+        created_by=upload_file.created_by,
+        created_at=int(upload_file.created_at.timestamp()),
+    )

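Both remote-file endpoints keep the same HEAD-then-GET probe through the rewrite: HEAD is cheap, but some servers reject it, so a failed HEAD falls back to a short-timeout GET. A sketch of the probe in plain httpx (the real code routes through the ssrf_proxy wrapper; the function name is illustrative):

import httpx

def probe(url: str) -> httpx.Response:
    resp = httpx.head(url)
    if resp.status_code != httpx.codes.OK:
        # Fall back for servers that do not implement HEAD.
        resp = httpx.get(url, timeout=3, follow_redirects=True)
    resp.raise_for_status()
    return resp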
@@ -1,20 +1,19 @@
+from typing import Literal
+
 from flask import request
-from flask_restx import Resource, fields
 from pydantic import BaseModel, Field, field_validator

 from configs import dify_config
+from controllers.fastopenapi import console_router
 from libs.helper import EmailStr, extract_remote_ip
 from libs.password import valid_password
 from models.model import DifySetup, db
 from services.account_service import RegisterService, TenantService

-from . import console_ns
 from .error import AlreadySetupError, NotInitValidateError
 from .init_validate import get_init_validate_status
 from .wraps import only_edition_self_hosted

-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


 class SetupRequestPayload(BaseModel):
     email: EmailStr = Field(..., description="Admin email address")
@@ -28,78 +27,66 @@ class SetupRequestPayload(BaseModel):
         return valid_password(value)


-console_ns.schema_model(
-    SetupRequestPayload.__name__,
-    SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
-)
+class SetupStatusResponse(BaseModel):
+    step: Literal["not_started", "finished"] = Field(description="Setup step status")
+    setup_at: str | None = Field(default=None, description="Setup completion time (ISO format)")
+
+
+class SetupResponse(BaseModel):
+    result: str = Field(description="Setup result", examples=["success"])
+
+
+@console_router.get(
+    "/setup",
+    response_model=SetupStatusResponse,
+    tags=["console"],
+)
+def get_setup_status_api() -> SetupStatusResponse:
+    """Get system setup status."""
+    if dify_config.EDITION == "SELF_HOSTED":
+        setup_status = get_setup_status()
+        if setup_status and not isinstance(setup_status, bool):
+            return SetupStatusResponse(step="finished", setup_at=setup_status.setup_at.isoformat())
+        if setup_status:
+            return SetupStatusResponse(step="finished")
+        return SetupStatusResponse(step="not_started")
+    return SetupStatusResponse(step="finished")


-@console_ns.route("/setup")
-class SetupApi(Resource):
-    @console_ns.doc("get_setup_status")
-    @console_ns.doc(description="Get system setup status")
-    @console_ns.response(
-        200,
-        "Success",
-        console_ns.model(
-            "SetupStatusResponse",
-            {
-                "step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
-                "setup_at": fields.String(description="Setup completion time (ISO format)", required=False),
-            },
-        ),
-    )
-    def get(self):
-        """Get system setup status"""
-        if dify_config.EDITION == "SELF_HOSTED":
-            setup_status = get_setup_status()
-            # Check if setup_status is a DifySetup object rather than a bool
-            if setup_status and not isinstance(setup_status, bool):
-                return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
-            elif setup_status:
-                return {"step": "finished"}
-            return {"step": "not_started"}
-        return {"step": "finished"}
-
-    @console_ns.doc("setup_system")
-    @console_ns.doc(description="Initialize system setup with admin account")
-    @console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
-    @console_ns.response(
-        201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
-    )
-    @console_ns.response(400, "Already setup or validation failed")
-    @only_edition_self_hosted
-    def post(self):
-        """Initialize system setup with admin account"""
-        # is set up
-        if get_setup_status():
-            raise AlreadySetupError()
-
-        # is tenant created
-        tenant_count = TenantService.get_tenant_count()
-        if tenant_count > 0:
-            raise AlreadySetupError()
-
-        if not get_init_validate_status():
-            raise NotInitValidateError()
-
-        args = SetupRequestPayload.model_validate(console_ns.payload)
-        normalized_email = args.email.lower()
-
-        # setup
-        RegisterService.setup(
-            email=normalized_email,
-            name=args.name,
-            password=args.password,
-            ip_address=extract_remote_ip(request),
-            language=args.language,
-        )
-
-        return {"result": "success"}, 201
+@console_router.post(
+    "/setup",
+    response_model=SetupResponse,
+    tags=["console"],
+    status_code=201,
+)
+@only_edition_self_hosted
+def setup_system(payload: SetupRequestPayload) -> SetupResponse:
+    """Initialize system setup with admin account."""
+    if get_setup_status():
+        raise AlreadySetupError()
+
+    tenant_count = TenantService.get_tenant_count()
+    if tenant_count > 0:
+        raise AlreadySetupError()
+
+    if not get_init_validate_status():
+        raise NotInitValidateError()
+
+    normalized_email = payload.email.lower()
+
+    RegisterService.setup(
+        email=normalized_email,
+        name=payload.name,
+        password=payload.password,
+        ip_address=extract_remote_ip(request),
+        language=payload.language,
+    )
+
+    return SetupResponse(result="success")


-def get_setup_status():
+def get_setup_status() -> DifySetup | bool | None:
     if dify_config.EDITION == "SELF_HOSTED":
         return db.session.query(DifySetup).first()
-    else:
-        return True
+
+    return True

@@ -1,17 +1,27 @@
 from typing import Literal

 from flask import request
-from flask_restx import Resource, marshal_with
+from flask_restx import Namespace, Resource, fields, marshal_with
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden

 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
-from fields.tag_fields import dataset_tag_fields
 from libs.login import current_account_with_tenant, login_required
 from services.tag_service import TagService

+dataset_tag_fields = {
+    "id": fields.String,
+    "name": fields.String,
+    "type": fields.String,
+    "binding_count": fields.String,
+}
+
+
+def build_dataset_tag_fields(api_or_ns: Namespace):
+    return api_or_ns.model("DataSetTag", dataset_tag_fields)
+

 class TagBasePayload(BaseModel):
     name: str = Field(description="Tag name", min_length=1, max_length=50)
@@ -40,6 +50,7 @@ register_schema_models(
     TagBasePayload,
     TagBindingPayload,
     TagBindingRemovePayload,
+    TagListQueryParam,
 )


@@ -1,15 +1,11 @@
-import json
 import logging

 import httpx
-from flask import request
-from flask_restx import Resource, fields
 from packaging import version
 from pydantic import BaseModel, Field

 from configs import dify_config
-
-from . import console_ns
+from controllers.fastopenapi import console_router

 logger = logging.getLogger(__name__)

@@ -18,69 +14,61 @@ class VersionQuery(BaseModel):
     current_version: str = Field(..., description="Current application version")


-console_ns.schema_model(
-    VersionQuery.__name__,
-    VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
-)
+class VersionFeatures(BaseModel):
+    can_replace_logo: bool = Field(description="Whether logo replacement is supported")
+    model_load_balancing_enabled: bool = Field(description="Whether model load balancing is enabled")
+
+
+class VersionResponse(BaseModel):
+    version: str = Field(description="Latest version number")
+    release_date: str = Field(description="Release date of latest version")
+    release_notes: str = Field(description="Release notes for latest version")
+    can_auto_update: bool = Field(description="Whether auto-update is supported")
+    features: VersionFeatures = Field(description="Feature flags and capabilities")


-@console_ns.route("/version")
-class VersionApi(Resource):
-    @console_ns.doc("check_version_update")
-    @console_ns.doc(description="Check for application version updates")
-    @console_ns.expect(console_ns.models[VersionQuery.__name__])
-    @console_ns.response(
-        200,
-        "Success",
-        console_ns.model(
-            "VersionResponse",
-            {
-                "version": fields.String(description="Latest version number"),
-                "release_date": fields.String(description="Release date of latest version"),
-                "release_notes": fields.String(description="Release notes for latest version"),
-                "can_auto_update": fields.Boolean(description="Whether auto-update is supported"),
-                "features": fields.Raw(description="Feature flags and capabilities"),
-            },
-        ),
-    )
-    def get(self):
-        """Check for application version updates"""
-        args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
-        check_update_url = dify_config.CHECK_UPDATE_URL
-
-        result = {
-            "version": dify_config.project.version,
-            "release_date": "",
-            "release_notes": "",
-            "can_auto_update": False,
-            "features": {
-                "can_replace_logo": dify_config.CAN_REPLACE_LOGO,
-                "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED,
-            },
-        }
-
-        if not check_update_url:
-            return result
-
-        try:
-            response = httpx.get(
-                check_update_url,
-                params={"current_version": args.current_version},
-                timeout=httpx.Timeout(timeout=10.0, connect=3.0),
-            )
-        except Exception as error:
-            logger.warning("Check update version error: %s.", str(error))
-            result["version"] = args.current_version
-            return result
-
-        content = json.loads(response.content)
-        if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
-            result["version"] = content["version"]
-            result["release_date"] = content["releaseDate"]
-            result["release_notes"] = content["releaseNotes"]
-            result["can_auto_update"] = content["canAutoUpdate"]
+@console_router.get(
+    "/version",
+    response_model=VersionResponse,
+    tags=["console"],
+)
+def check_version_update(query: VersionQuery) -> VersionResponse:
+    """Check for application version updates."""
+    check_update_url = dify_config.CHECK_UPDATE_URL
+
+    result = VersionResponse(
+        version=dify_config.project.version,
+        release_date="",
+        release_notes="",
+        can_auto_update=False,
+        features=VersionFeatures(
+            can_replace_logo=dify_config.CAN_REPLACE_LOGO,
+            model_load_balancing_enabled=dify_config.MODEL_LB_ENABLED,
+        ),
+    )
+
+    if not check_update_url:
+        return result
+
+    try:
+        response = httpx.get(
+            check_update_url,
+            params={"current_version": query.current_version},
+            timeout=httpx.Timeout(timeout=10.0, connect=3.0),
+        )
+        content = response.json()
+    except Exception as error:
+        logger.warning("Check update version error: %s.", str(error))
+        result.version = query.current_version
+        return result
+
+    latest_version = content.get("version", result.version)
+    if _has_new_version(latest_version=latest_version, current_version=f"{query.current_version}"):
+        result.version = latest_version
+        result.release_date = content.get("releaseDate", "")
+        result.release_notes = content.get("releaseNotes", "")
+        result.can_auto_update = content.get("canAutoUpdate", False)
    return result


 def _has_new_version(*, latest_version: str, current_version: str) -> bool:
     try:

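Besides the router migration, the version check is hardened: the old code indexed content["version"] and friends directly, so a missing key in the update feed raised KeyError even after a successful fetch, and JSON decoding happened outside the try block. The rewrite decodes inside the try and reads every field with .get() and a default:

content = {"version": "1.9.0"}  # imagine a partial update-feed payload
latest_version = content.get("version", "0.0.0")
release_notes = content.get("releaseNotes", "")
can_auto_update = content.get("canAutoUpdate", False)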
@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
@ -37,7 +38,7 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@ -170,6 +171,25 @@ reg(ChangeEmailSendPayload)
|
||||
reg(ChangeEmailValidityPayload)
|
||||
reg(ChangeEmailResetPayload)
|
||||
reg(CheckEmailUniquePayload)
|
||||
register_schema_models(console_ns, AccountResponse)
|
||||
|
||||
|
||||
def _serialize_account(account) -> dict:
|
||||
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
integrate_fields = {
|
||||
"provider": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"is_bound": fields.Boolean,
|
||||
"link": fields.String,
|
||||
}
|
||||
|
||||
integrate_model = console_ns.model("AccountIntegrate", integrate_fields)
|
||||
integrate_list_model = console_ns.model(
|
||||
"AccountIntegrateList",
|
||||
{"data": fields.List(fields.Nested(integrate_model))},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/init")
|
||||
@ -223,11 +243,11 @@ class AccountProfileApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return current_user
|
||||
return _serialize_account(current_user)
|
||||
|
||||
|
||||
@console_ns.route("/account/name")
|
||||
@ -236,14 +256,14 @@ class AccountNameApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountNamePayload.model_validate(payload)
|
||||
updated_account = AccountService.update_account(current_user, name=args.name)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
@ -252,7 +272,7 @@ class AccountAvatarApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -260,7 +280,7 @@ class AccountAvatarApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-language")
|
||||
@ -269,7 +289,7 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -277,7 +297,7 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/interface-theme")
|
||||
@ -286,7 +306,7 @@ class AccountInterfaceThemeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -294,7 +314,7 @@ class AccountInterfaceThemeApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/timezone")
|
||||
@ -303,7 +323,7 @@ class AccountTimezoneApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@ -311,7 +331,7 @@ class AccountTimezoneApi(Resource):
|
||||
|
||||
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
|
||||
|
||||
return updated_account
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
|
||||
@console_ns.route("/account/password")
|
||||
@ -320,7 +340,7 @@ class AccountPasswordApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
@@ -331,26 +351,15 @@ class AccountPasswordApi(Resource):
         except ServiceCurrentPasswordIncorrectError:
             raise CurrentPasswordIncorrectError()
 
-        return {"result": "success"}
+        return _serialize_account(current_user)
 
 
 @console_ns.route("/account/integrates")
 class AccountIntegrateApi(Resource):
-    integrate_fields = {
-        "provider": fields.String,
-        "created_at": TimestampField,
-        "is_bound": fields.Boolean,
-        "link": fields.String,
-    }
-
-    integrate_list_fields = {
-        "data": fields.List(fields.Nested(integrate_fields)),
-    }
-
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(integrate_list_fields)
+    @marshal_with(integrate_list_model)
    def get(self):
        account, _ = current_account_with_tenant()
 
@@ -618,7 +627,7 @@ class ChangeEmailResetApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(account_fields)
+    @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
     def post(self):
         payload = console_ns.payload or {}
         args = ChangeEmailResetPayload.model_validate(payload)
@@ -647,7 +656,7 @@ class ChangeEmailResetApi(Resource):
             email=normalized_new_email,
         )
 
-        return updated_account
+        return _serialize_account(updated_account)
 
 
 @console_ns.route("/account/change-email/check-email-unique")
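
Every account mutation above now funnels through `_serialize_account(...)` instead of `@marshal_with(account_fields)`. The helper's body sits outside these hunks; a minimal sketch of the likely pattern, assuming `AccountResponse` is the Pydantic model registered on `console_ns` (the helper body below is a guess, not the diff's actual code):

    # Hypothetical reconstruction -- the real _serialize_account is defined
    # elsewhere in the controller and may differ.
    def _serialize_account(account) -> dict:
        # Validate the ORM object attribute-by-attribute against the response
        # schema, then dump JSON-safe primitives; no marshal_with needed.
        return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
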
@@ -1,9 +1,10 @@
 from typing import Any
 
 from flask import request
-from flask_restx import Resource, fields
+from flask_restx import Resource
 from pydantic import BaseModel, Field
 
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery):
     plugin_id: str
 
 
+class EndpointCreateResponse(BaseModel):
+    success: bool = Field(description="Operation success")
+
+
+class EndpointListResponse(BaseModel):
+    endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
+
+
+class PluginEndpointListResponse(BaseModel):
+    endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
+
+
+class EndpointDeleteResponse(BaseModel):
+    success: bool = Field(description="Operation success")
+
+
+class EndpointUpdateResponse(BaseModel):
+    success: bool = Field(description="Operation success")
+
+
+class EndpointEnableResponse(BaseModel):
+    success: bool = Field(description="Operation success")
+
+
+class EndpointDisableResponse(BaseModel):
+    success: bool = Field(description="Operation success")
+
+
-def reg(cls: type[BaseModel]):
-    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
-
-
-reg(EndpointCreatePayload)
-reg(EndpointIdPayload)
-reg(EndpointUpdatePayload)
-reg(EndpointListQuery)
-reg(EndpointListForPluginQuery)
+register_schema_models(
+    console_ns,
+    EndpointCreatePayload,
+    EndpointIdPayload,
+    EndpointUpdatePayload,
+    EndpointListQuery,
+    EndpointListForPluginQuery,
+    EndpointCreateResponse,
+    EndpointListResponse,
+    PluginEndpointListResponse,
+    EndpointDeleteResponse,
+    EndpointUpdateResponse,
+    EndpointEnableResponse,
+    EndpointDisableResponse,
+)
 
 
 @console_ns.route("/workspaces/current/endpoints/create")
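
The new `register_schema_models` call replaces the copy-pasted `reg()` helper that each controller used to define. Judging by the code it removes, it is presumably just a loop over `schema_model` with the shared Swagger 2.0 ref template; a sketch under that assumption:

    from flask_restx import Namespace
    from pydantic import BaseModel

    DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

    # Assumed implementation of controllers.common.schema.register_schema_models,
    # inferred from the reg() helpers this diff deletes.
    def register_schema_models(ns: Namespace, *models: type[BaseModel]) -> None:
        # Register each model's JSON schema under its class name so routes
        # can later reference it via ns.models[cls.__name__].
        for cls in models:
            ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
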
@@ -57,7 +96,7 @@ class EndpointCreateApi(Resource):
     @console_ns.response(
         200,
         "Endpoint created successfully",
-        console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
+        console_ns.models[EndpointCreateResponse.__name__],
     )
     @console_ns.response(403, "Admin privileges required")
     @setup_required
@@ -91,9 +130,7 @@ class EndpointListApi(Resource):
     @console_ns.response(
         200,
         "Success",
-        console_ns.model(
-            "EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
-        ),
+        console_ns.models[EndpointListResponse.__name__],
     )
     @setup_required
     @login_required
@@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource):
     @console_ns.response(
         200,
         "Success",
-        console_ns.model(
-            "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
-        ),
+        console_ns.models[PluginEndpointListResponse.__name__],
     )
     @setup_required
     @login_required
@@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource):
     @console_ns.response(
         200,
         "Endpoint deleted successfully",
-        console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
+        console_ns.models[EndpointDeleteResponse.__name__],
     )
     @console_ns.response(403, "Admin privileges required")
     @setup_required
@@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource):
     @console_ns.response(
         200,
         "Endpoint updated successfully",
-        console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
+        console_ns.models[EndpointUpdateResponse.__name__],
     )
     @console_ns.response(403, "Admin privileges required")
     @setup_required
@@ -221,7 +256,7 @@ class EndpointEnableApi(Resource):
     @console_ns.response(
         200,
         "Endpoint enabled successfully",
-        console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
+        console_ns.models[EndpointEnableResponse.__name__],
     )
     @console_ns.response(403, "Admin privileges required")
     @setup_required
@@ -248,7 +283,7 @@ class EndpointDisableApi(Resource):
     @console_ns.response(
         200,
         "Endpoint disabled successfully",
-        console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
+        console_ns.models[EndpointDisableResponse.__name__],
     )
     @console_ns.response(403, "Admin privileges required")
     @setup_required
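
Across all six endpoint resources the change is the same: the inline `console_ns.model(...)` calls each re-declared a response schema at the decorator site, whereas `console_ns.models[Cls.__name__]` looks the one registered schema up in the namespace's model registry. A stripped-down illustration (route and resource name hypothetical):

    # One registration up front ...
    register_schema_models(console_ns, EndpointEnableResponse)

    # ... then any number of routes reference the same schema by class name.
    @console_ns.route("/demo/enable")
    class DemoEnableApi(Resource):
        @console_ns.response(200, "Endpoint enabled successfully", console_ns.models[EndpointEnableResponse.__name__])
        def post(self):
            return {"success": True}
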
@@ -1,11 +1,12 @@
 from urllib import parse
 
 from flask import abort, request
-from flask_restx import Resource, marshal_with
-from pydantic import BaseModel, Field
+from flask_restx import Resource
+from pydantic import BaseModel, Field, TypeAdapter
 
 import services
 from configs import dify_config
+from controllers.common.schema import register_enum_models, register_schema_models
 from controllers.console import console_ns
 from controllers.console.auth.error import (
     CannotTransferOwnerToSelfError,
@@ -24,7 +25,7 @@ from controllers.console.wraps import (
     setup_required,
 )
 from extensions.ext_database import db
-from fields.member_fields import account_with_role_list_fields
+from fields.member_fields import AccountWithRole, AccountWithRoleList
 from libs.helper import extract_remote_ip
 from libs.login import current_account_with_tenant, login_required
 from models.account import Account, TenantAccountRole
@@ -67,6 +68,8 @@ reg(MemberRoleUpdatePayload)
 reg(OwnerTransferEmailPayload)
 reg(OwnerTransferCheckPayload)
 reg(OwnerTransferPayload)
+register_enum_models(console_ns, TenantAccountRole)
+register_schema_models(console_ns, AccountWithRole, AccountWithRoleList)
 
 
 @console_ns.route("/workspaces/current/members")
@@ -76,13 +79,15 @@ class MemberListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(account_with_role_list_fields)
+    @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
     def get(self):
         current_user, _ = current_account_with_tenant()
         if not current_user.current_tenant:
             raise ValueError("No current tenant")
         members = TenantService.get_tenant_members(current_user.current_tenant)
-        return {"result": "success", "accounts": members}, 200
+        member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
+        response = AccountWithRoleList(accounts=member_models)
+        return response.model_dump(mode="json"), 200
 
 
 @console_ns.route("/workspaces/current/members/invite-email")
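
`TypeAdapter` is Pydantic v2's validator for bare annotations such as `list[AccountWithRole]`, and `from_attributes=True` lets it read SQLAlchemy rows by attribute rather than by key, which is what makes the ORM-to-schema hop above work. A self-contained illustration with stand-in types:

    from pydantic import BaseModel, TypeAdapter

    class Member(BaseModel):          # stand-in for AccountWithRole
        name: str
        role: str

    class OrmMember:                  # stand-in for the SQLAlchemy row
        def __init__(self, name: str, role: str):
            self.name, self.role = name, role

    # Attribute access replaces dict lookup thanks to from_attributes=True.
    adapter = TypeAdapter(list[Member])
    members = adapter.validate_python([OrmMember("ada", "owner")], from_attributes=True)
    print(members[0].model_dump(mode="json"))  # {'name': 'ada', 'role': 'owner'}
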
@@ -227,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(account_with_role_list_fields)
+    @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
     def get(self):
         current_user, _ = current_account_with_tenant()
         if not current_user.current_tenant:
             raise ValueError("No current tenant")
         members = TenantService.get_dataset_operator_members(current_user.current_tenant)
-        return {"result": "success", "accounts": members}, 200
+        member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
+        response = AccountWithRoleList(accounts=member_models)
+        return response.model_dump(mode="json"), 200
 
 
 @console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
@@ -5,6 +5,7 @@ from flask import request
 from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
 
+from controllers.common.schema import register_enum_models, register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
 from core.model_runtime.entities.model_entities import ModelType
@@ -23,12 +24,13 @@ class ParserGetDefault(BaseModel):
     model_type: ModelType
 
 
-class ParserPostDefault(BaseModel):
-    class Inner(BaseModel):
-        model_type: ModelType
-        model: str | None = None
-        provider: str | None = None
+class Inner(BaseModel):
+    model_type: ModelType
+    model: str | None = None
+    provider: str | None = None
 
+
+class ParserPostDefault(BaseModel):
     model_settings: list[Inner]
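
Hoisting `Inner` out of `ParserPostDefault` lets it be registered on its own (it appears in the `register_schema_models(..., Inner)` call below), which matters because the Swagger 2.0 ref template flattens every nested model into a top-level definition. A quick check of the refs Pydantic emits (fields trimmed for brevity):

    from pydantic import BaseModel

    class Inner(BaseModel):
        model: str | None = None

    class ParserPostDefault(BaseModel):
        model_settings: list[Inner]

    # The list items point at a flat, top-level definitions entry.
    schema = ParserPostDefault.model_json_schema(ref_template="#/definitions/{model}")
    print(schema["properties"]["model_settings"]["items"])  # {'$ref': '#/definitions/Inner'}
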
@@ -105,19 +107,21 @@ class ParserParameter(BaseModel):
     model: str
 
 
-def reg(cls: type[BaseModel]):
-    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
-
-
-reg(ParserGetDefault)
-reg(ParserPostDefault)
-reg(ParserDeleteModels)
-reg(ParserPostModels)
-reg(ParserGetCredentials)
-reg(ParserCreateCredential)
-reg(ParserUpdateCredential)
-reg(ParserDeleteCredential)
-reg(ParserParameter)
+register_schema_models(
+    console_ns,
+    ParserGetDefault,
+    ParserPostDefault,
+    ParserDeleteModels,
+    ParserPostModels,
+    ParserGetCredentials,
+    ParserCreateCredential,
+    ParserUpdateCredential,
+    ParserDeleteCredential,
+    ParserParameter,
+    Inner,
+)
+register_enum_models(console_ns, ModelType)
 
 
 @console_ns.route("/workspaces/current/default-model")
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden
 
 from configs import dify_config
+from controllers.common.schema import register_enum_models, register_schema_models
 from controllers.console import console_ns
 from controllers.console.workspace import plugin_permission_required
 from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
@@ -20,57 +21,12 @@ from services.plugin.plugin_parameter_service import PluginParameterService
 from services.plugin.plugin_permission_service import PluginPermissionService
 from services.plugin.plugin_service import PluginService
 
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
-
-def reg(cls: type[BaseModel]):
-    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
-
-
-@console_ns.route("/workspaces/current/plugin/debugging-key")
-class PluginDebuggingKeyApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @plugin_permission_required(debug_required=True)
-    def get(self):
-        _, tenant_id = current_account_with_tenant()
-
-        try:
-            return {
-                "key": PluginService.get_debugging_key(tenant_id),
-                "host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
-                "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
-            }
-        except PluginDaemonClientSideError as e:
-            raise ValueError(e)
-
-
 class ParserList(BaseModel):
     page: int = Field(default=1, ge=1, description="Page number")
     page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
 
 
-reg(ParserList)
-
-
-@console_ns.route("/workspaces/current/plugin/list")
-class PluginListApi(Resource):
-    @console_ns.expect(console_ns.models[ParserList.__name__])
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def get(self):
-        _, tenant_id = current_account_with_tenant()
-        args = ParserList.model_validate(request.args.to_dict(flat=True))  # type: ignore
-        try:
-            plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
-        except PluginDaemonClientSideError as e:
-            raise ValueError(e)
-
-        return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
-
-
 class ParserLatest(BaseModel):
     plugin_ids: list[str]
 
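
`ParserList` survives the shuffle unchanged; its `model_validate(request.args.to_dict(flat=True))` call (visible in the relocated `PluginListApi` below) coerces raw query strings into ranged integers. In isolation:

    from pydantic import BaseModel, Field, ValidationError

    class ParserList(BaseModel):
        page: int = Field(default=1, ge=1, description="Page number")
        page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")

    # Query strings arrive as str; Pydantic coerces and range-checks them.
    print(ParserList.model_validate({"page": "2"}))      # page=2 page_size=256
    try:
        ParserList.model_validate({"page_size": "9999"})
    except ValidationError as exc:
        print(exc.errors()[0]["type"])                   # less_than_equal
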
@@ -180,23 +136,73 @@ class ParserReadme(BaseModel):
     language: str = Field(default="en-US")
 
 
-reg(ParserLatest)
-reg(ParserIcon)
-reg(ParserAsset)
-reg(ParserGithubUpload)
-reg(ParserPluginIdentifiers)
-reg(ParserGithubInstall)
-reg(ParserPluginIdentifierQuery)
-reg(ParserTasks)
-reg(ParserMarketplaceUpgrade)
-reg(ParserGithubUpgrade)
-reg(ParserUninstall)
-reg(ParserPermissionChange)
-reg(ParserDynamicOptions)
-reg(ParserDynamicOptionsWithCredentials)
-reg(ParserPreferencesChange)
-reg(ParserExcludePlugin)
-reg(ParserReadme)
+register_schema_models(
+    console_ns,
+    ParserList,
+    PluginAutoUpgradeSettingsPayload,
+    PluginPermissionSettingsPayload,
+    ParserLatest,
+    ParserIcon,
+    ParserAsset,
+    ParserGithubUpload,
+    ParserPluginIdentifiers,
+    ParserGithubInstall,
+    ParserPluginIdentifierQuery,
+    ParserTasks,
+    ParserMarketplaceUpgrade,
+    ParserGithubUpgrade,
+    ParserUninstall,
+    ParserPermissionChange,
+    ParserDynamicOptions,
+    ParserDynamicOptionsWithCredentials,
+    ParserPreferencesChange,
+    ParserExcludePlugin,
+    ParserReadme,
+)
+
+register_enum_models(
+    console_ns,
+    TenantPluginPermission.DebugPermission,
+    TenantPluginAutoUpgradeStrategy.UpgradeMode,
+    TenantPluginAutoUpgradeStrategy.StrategySetting,
+    TenantPluginPermission.InstallPermission,
+)
+
+
+@console_ns.route("/workspaces/current/plugin/debugging-key")
+class PluginDebuggingKeyApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @plugin_permission_required(debug_required=True)
+    def get(self):
+        _, tenant_id = current_account_with_tenant()
+
+        try:
+            return {
+                "key": PluginService.get_debugging_key(tenant_id),
+                "host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
+                "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
+            }
+        except PluginDaemonClientSideError as e:
+            raise ValueError(e)
+
+
+@console_ns.route("/workspaces/current/plugin/list")
+class PluginListApi(Resource):
+    @console_ns.expect(console_ns.models[ParserList.__name__])
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        _, tenant_id = current_account_with_tenant()
+        args = ParserList.model_validate(request.args.to_dict(flat=True))  # type: ignore
+        try:
+            plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
+        except PluginDaemonClientSideError as e:
+            raise ValueError(e)
+
+        return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
+
+
 @console_ns.route("/workspaces/current/plugin/list/latest-versions")
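
`register_enum_models` has no predecessor in the removed code, so its body is not shown anywhere in this diff; by analogy with `register_schema_models` it presumably publishes each enum as a named string schema so references like `TenantPluginPermission.DebugPermission` resolve in the generated Swagger document. A hedged sketch:

    from enum import Enum
    from flask_restx import Namespace

    # Assumed behaviour of controllers.common.schema.register_enum_models --
    # not shown in this diff; the real helper may differ.
    def register_enum_models(ns: Namespace, *enums: type[Enum]) -> None:
        for enum_cls in enums:
            # Expose the enum as a Swagger 2.0 string schema listing its values.
            ns.schema_model(enum_cls.__name__, {"type": "string", "enum": [member.value for member in enum_cls]})
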