feat: sync main branch (#31938)

Signed-off-by: majiayu000 <1835304752@qq.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Signed-off-by: -LAN- <laipz8200@outlook.com>
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: 盐粒 Yanli <yanli@dify.ai>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cursx <33718736+Cursx@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: lif <1835304752@qq.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: fenglin <790872612@qq.com>
Co-authored-by: qiaofenglin <qiaofenglin@baidu.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: TomoOkuyama <49631611+TomoOkuyama@users.noreply.github.com>
Co-authored-by: Tomo Okuyama <tomo.okuyama@intersystems.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: zyssyz123 <916125788@qq.com>
Co-authored-by: hj24 <mambahj24@gmail.com>
Co-authored-by: Coding On Star <447357187@qq.com>
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: Xiangxuan Qu <fghpdf@outlook.com>
Co-authored-by: fghpdf <fghpdf@users.noreply.github.com>
Co-authored-by: coopercoder <whitetiger0127@163.com>
Co-authored-by: zhaiguangpeng <zhaiguangpeng@didiglobal.com>
Co-authored-by: Junyan Qin (Chin) <rockchinq@gmail.com>
Co-authored-by: E.G <146701565+GlobalStar117@users.noreply.github.com>
Co-authored-by: GlobalStar117 <GlobalStar117@users.noreply.github.com>
Co-authored-by: Claude Haiku 4.5 <noreply@anthropic.com>
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: heyszt <270985384@qq.com>
Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: moonpanda <chuanzegao@163.com>
Co-authored-by: warlocgao <warlocgao@tencent.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: KVOJJJin <jzongcode@gmail.com>
Co-authored-by: eux <euxx@users.noreply.github.com>
Co-authored-by: bangjiehan <bangjiehan@gmail.com>
Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: Nie Ronghua <nieronghua@sf-express.com>
Co-authored-by: JQSevenMiao <141806521+JQSevenMiao@users.noreply.github.com>
Co-authored-by: jiasiqi <jiasiqi3@tal.com>
Co-authored-by: Seokrin Taron Sung <sungsjade@gmail.com>
Co-authored-by: CrabSAMA <40541269+CrabSAMA@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: yihong <zouzou0208@gmail.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: yessenia <yessenia.contact@gmail.com>
Co-authored-by: Jax <anobaka@qq.com>
Co-authored-by: niveshdandyan <155956228+niveshdandyan@users.noreply.github.com>
Co-authored-by: OSS Contributor <oss-contributor@example.com>
Co-authored-by: niveshdandyan <niveshdandyan@users.noreply.github.com>
Co-authored-by: Sean Kenneth Doherty <Smaster7772@gmail.com>
This commit is contained in:
qiuqiua
2026-02-04 19:04:24 +08:00
committed by GitHub
parent 489d27f817
commit 9ef6b90843
988 changed files with 106174 additions and 23067 deletions

View File

@ -243,15 +243,13 @@ class InsertExploreBannerApi(Resource):
def post(self):
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
content = {
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
}
banner = ExporleBanner(
content=content,
content={
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
},
link=payload.link,
sort=payload.sort,
language=payload.language,

View File

@ -22,10 +22,10 @@ api_key_fields = {
"created_at": TimestampField,
}
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
api_key_list_model = console_ns.model(
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
)

View File

@ -1,10 +1,11 @@
from typing import Any, Literal
from flask import abort, make_response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
@ -16,9 +17,11 @@ from controllers.console.wraps import (
)
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,
annotation_hit_history_fields,
build_annotation_model,
Annotation,
AnnotationExportList,
AnnotationHitHistory,
AnnotationHitHistoryList,
AnnotationList,
)
from libs.helper import uuid_value
from libs.login import login_required
@ -89,6 +92,14 @@ reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload)
register_schema_models(
console_ns,
Annotation,
AnnotationList,
AnnotationExportList,
AnnotationHitHistory,
AnnotationHitHistoryList,
)
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
@ -107,10 +118,11 @@ class AnnotationReplyActionApi(Resource):
def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id)
args = AnnotationReplyPayload.model_validate(console_ns.payload)
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@ -201,33 +213,33 @@ class AnnotationApi(Resource):
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response = AnnotationList(
data=annotation_models,
has_more=len(annotation_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json"), 200
@console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
@console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__])
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
@edit_permission_required
def post(self, app_id):
app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
return annotation
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@login_required
@ -264,7 +276,7 @@ class AnnotationExportApi(Resource):
@console_ns.response(
200,
"Annotations exported successfully",
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
console_ns.models[AnnotationExportList.__name__],
)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -274,7 +286,8 @@ class AnnotationExportApi(Resource):
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response_data = {"data": marshal(annotation_list, annotation_fields)}
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")
# Create response with secure headers for CSV export
response = make_response(response_data, 200)
@ -289,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.doc("update_delete_annotation")
@console_ns.doc(description="Update or delete an annotation")
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
@console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__])
@console_ns.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@ -298,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
@ -306,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource):
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
return annotation
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@login_required
@ -414,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource):
@console_ns.response(
200,
"Hit histories retrieved successfully",
console_ns.model(
"AnnotationHitHistoryList",
{
"data": fields.List(
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
)
},
),
console_ns.models[AnnotationHitHistoryList.__name__],
)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -436,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource):
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
app_id, annotation_id, page, limit
)
response = {
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
"has_more": len(annotation_hit_history_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response
history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
annotation_hit_history_list, from_attributes=True
)
response = AnnotationHitHistoryList(
data=history_models,
has_more=len(annotation_hit_history_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json")

View File

@ -9,9 +9,11 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.common.helpers import FileInfo
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.workspace.models import LoadBalancingPayload
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
@ -22,18 +24,36 @@ from controllers.console.wraps import (
)
from core.file import helpers as file_helpers
from core.ops.ops_trace_manager import OpsTraceManager
from core.workflow.enums import NodeType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App, Workflow
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.knowledge_entities.knowledge_entities import (
DataSource,
InfoList,
NotionIcon,
NotionInfo,
NotionPage,
PreProcessingRule,
RerankingModel,
Rule,
Segmentation,
WebsiteInfo,
WeightKeywordSetting,
WeightModel,
WeightVectorSetting,
)
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
register_enum_models(console_ns, IconType)
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
@ -151,7 +171,7 @@ def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str |
if icon is None or icon_type is None:
return None
icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
if icon_type_value.lower() != IconType.IMAGE.value:
if icon_type_value.lower() != IconType.IMAGE:
return None
return file_helpers.get_signed_file_url(icon)
@ -391,6 +411,8 @@ class AppExportResponse(ResponseModel):
data: str
register_enum_models(console_ns, RetrievalMethod, WorkflowExecutionStatus, DatasetPermissionEnum)
register_schema_models(
console_ns,
AppListQuery,
@ -414,6 +436,22 @@ register_schema_models(
AppDetailWithSite,
AppPagination,
AppExportResponse,
Segmentation,
PreProcessingRule,
Rule,
WeightVectorSetting,
WeightKeywordSetting,
WeightModel,
RerankingModel,
InfoList,
NotionInfo,
FileInfo,
WebsiteInfo,
NotionPage,
NotionIcon,
RerankingModel,
DataSource,
LoadBalancingPayload,
)

View File

@ -41,14 +41,14 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppImportPayload(BaseModel):
mode: str = Field(..., description="Import mode")
yaml_content: str | None = None
yaml_url: str | None = None
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
app_id: str | None = None
yaml_content: str | None = Field(None)
yaml_url: str | None = Field(None)
name: str | None = Field(None)
description: str | None = Field(None)
icon_type: str | None = Field(None)
icon: str | None = Field(None)
icon_background: str | None = Field(None)
app_id: str | None = Field(None)
console_ns.schema_model(

View File

@ -6,6 +6,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
AppUnavailableError,
@ -33,7 +34,6 @@ from services.errors.audio import (
)
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TextToSpeechPayload(BaseModel):
@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel):
language: str = Field(..., description="Language code")
console_ns.schema_model(
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechVoiceQuery.__name__,
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
class AudioTranscriptResponse(BaseModel):
text: str = Field(description="Transcribed text from audio")
register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource):
@console_ns.response(
200,
"Audio transcription successful",
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
console_ns.models[AudioTranscriptResponse.__name__],
)
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
@console_ns.response(413, "Audio file too large")

View File

@ -508,16 +508,19 @@ class ChatConversationApi(Resource):
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
match args.annotation_status:
case "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
case "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
case "all":
pass
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)

View File

@ -82,13 +82,13 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
class DraftWorkflowNotExist(BaseHTTPException):
error_code = "draft_workflow_not_exist"
description = "Draft workflow need to be initialized."
code = 400
code = 404
class DraftWorkflowNotSync(BaseHTTPException):
error_code = "draft_workflow_not_sync"
description = "Workflow graph might have been modified, please refresh and resubmit."
code = 400
code = 409
class TracingConfigNotExist(BaseHTTPException):

View File

@ -1,6 +1,5 @@
import logging
from collections.abc import Sequence
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
@ -16,10 +15,12 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.app_config.entities import ModelConfig
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.generator import WorkflowGenerator
@ -31,28 +32,13 @@ from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
@ -94,6 +80,7 @@ reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(FlowchartGeneratePayload)
reg(ModelConfig)
@console_ns.route("/rule-generate")
@ -112,12 +99,7 @@ class RuleGenerateApi(Resource):
_, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=args.no_variable,
)
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
@ -148,9 +130,7 @@ class RuleCodeGenerateApi(Resource):
try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.code_language,
args=args,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -182,8 +162,7 @@ class RuleStructuredOutputGenerateApi(Resource):
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
args=args,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -234,23 +213,29 @@ class InstructionGenerateApi(Resource):
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
args=RuleGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
),
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
args=RuleGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
),
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
args=RuleCodeGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
),
)
case _:
return {"error": f"invalid node type: {node_type}"}

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel):
@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel):
raise ValueError("has_comment must be a boolean value")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class AnnotationCountResponse(BaseModel):
count: int = Field(description="Number of annotations")
reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)
class SuggestedQuestionsResponse(BaseModel):
data: list[str] = Field(description="Suggested question")
register_schema_models(
console_ns,
ChatMessagesQuery,
MessageFeedbackPayload,
FeedbackExportQuery,
AnnotationCountResponse,
SuggestedQuestionsResponse,
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -231,7 +240,7 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = ChatMessagesQuery.model_validate(request.args.to_dict())
conversation = (
db.session.query(Conversation)
@ -356,7 +365,7 @@ class MessageAnnotationCountApi(Resource):
@console_ns.response(
200,
"Annotation count retrieved successfully",
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
console_ns.models[AnnotationCountResponse.__name__],
)
@get_app_model
@setup_required
@ -376,9 +385,7 @@ class MessageSuggestedQuestionApi(Resource):
@console_ns.response(
200,
"Suggested questions retrieved successfully",
console_ns.model(
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
),
console_ns.models[SuggestedQuestionsResponse.__name__],
)
@console_ns.response(404, "Message or conversation not found")
@setup_required
@ -428,7 +435,7 @@ class MessageFeedbackExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = FeedbackExportQuery.model_validate(request.args.to_dict())
# Import the service function
from services.feedback_service import FeedbackService

View File

@ -12,6 +12,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow_run import workflow_run_node_execution_model
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@ -35,7 +36,6 @@ from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.member_fields import simple_account_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, uuid_value
@ -88,26 +88,6 @@ workflow_pagination_fields_copy = workflow_pagination_fields.copy()
workflow_pagination_fields_copy["items"] = fields.List(fields.Nested(workflow_model), attribute="items")
workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields_copy)
# Reuse workflow_run_node_execution_model from workflow_run.py if already registered
# Otherwise register it here
from fields.end_user_fields import simple_end_user_fields
simple_end_user_model = None
try:
simple_end_user_model = console_ns.models.get("SimpleEndUser")
except AttributeError:
pass
if simple_end_user_model is None:
simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
workflow_run_node_execution_model = None
try:
workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution")
except AttributeError:
pass
if workflow_run_node_execution_model is None:
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
class SyncDraftWorkflowPayload(BaseModel):
graph: dict[str, Any]
@ -470,7 +450,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
response = AppGenerateService.generate_single_loop(
@ -508,7 +488,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
response = AppGenerateService.generate_single_loop(
@ -999,6 +979,7 @@ class DraftWorkflowTriggerRunApi(Resource):
if not event:
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
workflow_args = dict(event.workflow_args)
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
return helper.compact_generate_response(
AppGenerateService.generate(
@ -1147,6 +1128,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
try:
workflow_args = dict(trigger_debug_event.workflow_args)
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
response = AppGenerateService.generate(
app_model=app_model,

View File

@ -11,7 +11,10 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from fields.workflow_app_log_fields import (
build_workflow_app_log_pagination_model,
build_workflow_archived_log_pagination_model,
)
from libs.login import login_required
from models import App
from models.model import AppMode
@ -61,6 +64,7 @@ console_ns.schema_model(
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
@ -99,3 +103,33 @@ class WorkflowAppLogApi(Resource):
)
return workflow_app_log_pagination
@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
class WorkflowArchivedLogApi(Resource):
@console_ns.doc("get_workflow_archived_logs")
@console_ns.doc(description="Get workflow archived execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@marshal_with(workflow_archived_log_pagination_model)
def get(self, app_model: App):
"""
Get workflow archived logs
"""
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workflow_app_service = WorkflowAppService()
with Session(db.engine) as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
session=session,
app_model=app_model,
page=args.page,
limit=args.limit,
)
return workflow_app_log_pagination

View File

@ -1,12 +1,15 @@
from datetime import UTC, datetime, timedelta
from typing import Literal, cast
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
from fields.workflow_run_fields import (
@ -19,14 +22,17 @@ from fields.workflow_run_fields import (
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
from libs.custom_inputs import time_duration
from libs.helper import uuid_value
from libs.login import current_user, login_required
from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunService
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -93,6 +99,15 @@ workflow_run_node_execution_list_model = console_ns.model(
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
)
workflow_run_export_fields = console_ns.model(
"WorkflowRunExport",
{
"status": fields.String(description="Export status: success/failed"),
"presigned_url": fields.String(description="Pre-signed URL for download", required=False),
"presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False),
},
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -181,6 +196,56 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
return result
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/export")
class WorkflowRunExportApi(Resource):
@console_ns.doc("get_workflow_run_export_url")
@console_ns.doc(description="Generate a download URL for an archived workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(200, "Export URL generated", workflow_run_export_fields)
@setup_required
@login_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App, run_id: str):
tenant_id = str(app_model.tenant_id)
app_id = str(app_model.id)
run_id_str = str(run_id)
run_created_at = db.session.scalar(
select(WorkflowArchiveLog.run_created_at)
.where(
WorkflowArchiveLog.tenant_id == tenant_id,
WorkflowArchiveLog.app_id == app_id,
WorkflowArchiveLog.workflow_run_id == run_id_str,
)
.limit(1)
)
if not run_created_at:
return {"code": "archive_log_not_found", "message": "workflow run archive not found"}, 404
prefix = (
f"{tenant_id}/app_id={app_id}/year={run_created_at.strftime('%Y')}/"
f"month={run_created_at.strftime('%m')}/workflow_run_id={run_id_str}"
)
archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
try:
archive_storage = get_archive_storage()
except ArchiveStorageNotConfiguredError as e:
return {"code": "archive_storage_not_configured", "message": str(e)}, 500
presigned_url = archive_storage.generate_presigned_url(
archive_key,
expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS,
)
expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS)
return {
"status": "success",
"presigned_url": presigned_url,
"presigned_url_expires_at": expires_at.isoformat(),
}, 200
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource):
@console_ns.doc("get_advanced_chat_workflow_runs_count")

View File

@ -1,13 +1,14 @@
import logging
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.schema import get_or_create_model
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
@ -22,6 +23,14 @@ from ..wraps import account_initialization_required, edit_permission_required, s
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
triggers_list_fields_copy = triggers_list_fields.copy()
triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)
class Parser(BaseModel):
node_id: str
@ -48,7 +57,7 @@ class WebhookTriggerApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(webhook_trigger_fields)
@marshal_with(webhook_trigger_model)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
@ -80,7 +89,7 @@ class AppTriggersApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(triggers_list_fields)
@marshal_with(triggers_list_model)
def get(self, app_model: App):
"""Get app triggers list"""
assert isinstance(current_user, Account)
@ -120,7 +129,7 @@ class AppTriggerEnableApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_fields)
@marshal_with(trigger_model)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
args = ParserEnable.model_validate(console_ns.payload)

View File

@ -2,9 +2,11 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field
from configs import dify_config
from controllers.common.schema import register_schema_models
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required,
logger = logging.getLogger(__name__)
class OAuthDataSourceResponse(BaseModel):
data: str = Field(description="Authorization URL or 'internal' for internal setup")
class OAuthDataSourceBindingResponse(BaseModel):
result: str = Field(description="Operation result")
class OAuthDataSourceSyncResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(
console_ns,
OAuthDataSourceResponse,
OAuthDataSourceBindingResponse,
OAuthDataSourceSyncResponse,
)
def get_oauth_providers():
with current_app.app_context():
notion_oauth = NotionOAuth(
@ -34,10 +56,7 @@ class OAuthDataSource(Resource):
@console_ns.response(
200,
"Authorization URL or internal setup success",
console_ns.model(
"OAuthDataSourceResponse",
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
),
console_ns.models[OAuthDataSourceResponse.__name__],
)
@console_ns.response(400, "Invalid provider")
@console_ns.response(403, "Admin privileges required")
@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource):
@console_ns.response(
200,
"Data source binding success",
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
console_ns.models[OAuthDataSourceBindingResponse.__name__],
)
@console_ns.response(400, "Invalid provider or code")
def get(self, provider: str):
@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource):
@console_ns.response(
200,
"Data source sync success",
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
console_ns.models[OAuthDataSourceSyncResponse.__name__],
)
@console_ns.response(400, "Invalid provider or sync failed")
@setup_required

View File

@ -2,10 +2,11 @@ import base64
import secrets
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailCodeError,
@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel):
return valid_password(value)
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class ForgotPasswordEmailResponse(BaseModel):
result: str = Field(description="Operation result")
data: str | None = Field(default=None, description="Reset token")
code: str | None = Field(default=None, description="Error code if account not found")
class ForgotPasswordCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether code is valid")
email: EmailStr = Field(description="Email address")
token: str = Field(description="New reset token")
class ForgotPasswordResetResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(
console_ns,
ForgotPasswordSendPayload,
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordEmailResponse,
ForgotPasswordCheckResponse,
ForgotPasswordResetResponse,
)
@console_ns.route("/forgot-password")
@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
@console_ns.response(
200,
"Email sent successfully",
console_ns.model(
"ForgotPasswordEmailResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.String(description="Reset token"),
"code": fields.String(description="Error code if account not found"),
},
),
console_ns.models[ForgotPasswordEmailResponse.__name__],
)
@console_ns.response(400, "Invalid email or rate limit exceeded")
@setup_required
@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource):
@console_ns.response(
200,
"Code verified successfully",
console_ns.model(
"ForgotPasswordCheckResponse",
{
"is_valid": fields.Boolean(description="Whether code is valid"),
"email": fields.String(description="Email address"),
"token": fields.String(description="New reset token"),
},
),
console_ns.models[ForgotPasswordCheckResponse.__name__],
)
@console_ns.response(400, "Invalid code or token")
@setup_required
@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource):
@console_ns.response(
200,
"Password reset successfully",
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
console_ns.models[ForgotPasswordResetResponse.__name__],
)
@console_ns.response(400, "Invalid token or password mismatch")
@setup_required

View File

@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource):
grant_type = OAuthGrantType(payload.grant_type)
except ValueError:
raise BadRequest("invalid grant_type")
match grant_type:
case OAuthGrantType.AUTHORIZATION_CODE:
if not payload.code:
raise BadRequest("code is required")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not payload.code:
raise BadRequest("code is required")
if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
case OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
@console_ns.route("/oauth/provider/account")

View File

@ -1,15 +1,15 @@
import json
from collections.abc import Generator
from typing import Any, cast
from typing import Any, Literal, cast
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model
from controllers.common.schema import get_or_create_model, register_schema_model
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
@ -17,7 +17,14 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from fields.data_source_fields import (
integrate_fields,
integrate_icon_fields,
integrate_list_fields,
integrate_notion_info_list_fields,
integrate_page_fields,
integrate_workspace_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
@ -36,9 +43,62 @@ class NotionEstimatePayload(BaseModel):
doc_language: str = Field(default="English")
class DataSourceNotionListQuery(BaseModel):
dataset_id: str | None = Field(default=None, description="Dataset ID")
credential_id: str = Field(..., description="Credential ID", min_length=1)
datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
class DataSourceNotionPreviewQuery(BaseModel):
credential_id: str = Field(..., description="Credential ID", min_length=1)
register_schema_model(console_ns, NotionEstimatePayload)
integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
integrate_page_fields_copy = integrate_page_fields.copy()
integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True)
integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy)
integrate_workspace_fields_copy = integrate_workspace_fields.copy()
integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model))
integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy)
integrate_fields_copy = integrate_fields.copy()
integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model)
integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy)
integrate_list_fields_copy = integrate_list_fields.copy()
integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model))
integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy)
notion_page_fields = {
"page_name": fields.String,
"page_id": fields.String,
"page_icon": fields.Nested(integrate_icon_model, allow_null=True),
"is_bound": fields.Boolean,
"parent_id": fields.String,
"type": fields.String,
}
notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields)
notion_workspace_fields = {
"workspace_name": fields.String,
"workspace_id": fields.String,
"workspace_icon": fields.String,
"pages": fields.List(fields.Nested(notion_page_model)),
}
notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields)
integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy()
integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model))
integrate_notion_info_list_model = get_or_create_model(
"NotionIntegrateInfoList", integrate_notion_info_list_fields_copy
)
@console_ns.route(
"/data-source/integrates",
"/data-source/integrates/<uuid:binding_id>/<string:action>",
@ -47,7 +107,7 @@ class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_fields)
@marshal_with(integrate_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
@ -97,9 +157,8 @@ class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, binding_id, action):
def patch(self, binding_id, action: Literal["enable", "disable"]):
binding_id = str(binding_id)
action = str(action)
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
@ -107,23 +166,24 @@ class DataSourceApi(Resource):
if data_source_binding is None:
raise NotFound("Data source binding not found.")
# enable binding
if action == "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is not disabled.")
# disable binding
if action == "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is disabled.")
match action:
case "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is not disabled.")
# disable binding
case "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is disabled.")
return {"result": "success"}, 200
@ -132,30 +192,19 @@ class DataSourceNotionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_notion_info_list_fields)
@marshal_with(integrate_notion_info_list_model)
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
datasource_parameters = {}
if datasource_parameters_str:
try:
datasource_parameters = json.loads(datasource_parameters_str)
if not isinstance(datasource_parameters, dict):
raise ValueError("datasource_parameters must be a JSON object.")
except json.JSONDecodeError:
raise ValueError("Invalid datasource_parameters JSON format.")
datasource_parameters = query.datasource_parameters or {}
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
credential_id=credential_id,
credential_id=query.credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
@ -164,8 +213,8 @@ class DataSourceNotionListApi(Resource):
exist_page_ids = []
with Session(db.engine) as session:
# import notion in the exist dataset
if dataset_id:
dataset = DatasetService.get_dataset(dataset_id)
if query.dataset_id:
dataset = DatasetService.get_dataset(query.dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
@ -173,7 +222,7 @@ class DataSourceNotionListApi(Resource):
documents = session.scalars(
select(Document).filter_by(
dataset_id=dataset_id,
dataset_id=query.dataset_id,
tenant_id=current_tenant_id,
data_source_type="notion_import",
enabled=True,
@ -240,13 +289,12 @@ class DataSourceNotionApi(Resource):
def get(self, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
credential_id=credential_id,
credential_id=query.credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)

View File

@ -8,7 +8,7 @@ from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import (
api_key_item_model,
@ -34,6 +34,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
content_fields,
dataset_detail_fields,
dataset_fields,
dataset_query_detail_fields,
@ -41,6 +42,7 @@ from fields.dataset_fields import (
doc_metadata_fields,
external_knowledge_info_fields,
external_retrieval_model_fields,
file_info_fields,
icon_info_fields,
keyword_setting_fields,
reranking_model_fields,
@ -55,41 +57,33 @@ from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = _get_or_create_model("DatasetBase", dataset_fields)
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
tag_model = _get_or_create_model("Tag", tag_fields)
tag_model = get_or_create_model("Tag", tag_fields)
keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
@ -98,14 +92,22 @@ dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_k
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
dataset_detail_model = _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
dataset_query_detail_model = _get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields)
file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
app_detail_kernel_model = _get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
content_fields_copy = content_fields.copy()
content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
content_model = get_or_create_model("DatasetContent", content_fields_copy)
dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
related_app_list_copy = related_app_list.copy()
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_indexing_technique(value: str | None) -> str | None:
@ -146,6 +148,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
summary_index_setting: dict[str, Any] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
@ -176,7 +179,18 @@ class IndexingEstimatePayload(BaseModel):
return result
register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
class ConsoleDatasetListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
limit: int = Field(default=20, description="Number of items per page")
keyword: str | None = Field(default=None, description="Search keyword")
include_all: bool = Field(default=False, description="Include all datasets")
ids: list[str] = Field(default_factory=list, description="Filter by dataset IDs")
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
register_schema_models(
console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
)
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@ -275,18 +289,26 @@ class DatasetListApi(Resource):
@enterprise_license_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
# Convert query parameters to dict, handling list parameters correctly
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
# Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value)
if "ids" in request.args:
query_params["ids"] = request.args.getlist("ids")
if "tag_ids" in request.args:
query_params["tag_ids"] = request.args.getlist("tag_ids")
query = ConsoleDatasetListQuery.model_validate(query_params)
# provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
if query.ids:
datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
page, limit, current_tenant_id, current_user, search, tag_ids, include_all
query.page,
query.limit,
current_tenant_id,
current_user,
query.keyword,
query.tag_ids,
query.include_all,
)
# check embedding setting
@ -318,7 +340,13 @@ class DatasetListApi(Resource):
else:
item.update({"partial_member_list": []})
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
response = {
"data": data,
"has_more": len(datasets) == query.limit,
"limit": query.limit,
"total": total,
"page": query.page,
}
return response, 200
@console_ns.doc("create_dataset")

View File

@ -14,7 +14,7 @@ from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from core.errors.error import (
LLMBadRequestError,
@ -45,6 +45,7 @@ from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
from tasks.generate_summary_index_task import generate_summary_index_task
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
@ -72,34 +73,27 @@ logger = logging.getLogger(__name__)
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_model = _get_or_create_model("Dataset", dataset_fields)
dataset_model = get_or_create_model("Dataset", dataset_fields)
document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields)
document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
document_fields_copy = document_fields.copy()
document_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_model = _get_or_create_model("Document", document_fields_copy)
document_model = get_or_create_model("Document", document_fields_copy)
document_with_segments_fields_copy = document_with_segments_fields.copy()
document_with_segments_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DocumentRetryPayload(BaseModel):
@ -110,6 +104,10 @@ class DocumentRenamePayload(BaseModel):
name: str
class GenerateSummaryPayload(BaseModel):
document_list: list[str]
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading documents as a zip archive."""
@ -132,6 +130,7 @@ register_schema_models(
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
GenerateSummaryPayload,
DocumentBatchDownloadZipPayload,
)
@ -319,6 +318,13 @@ class DatasetDocumentListApi(Resource):
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
DocumentService.enrich_documents_with_summary_index_status(
documents=documents,
dataset=dataset,
tenant_id=current_tenant_id,
)
if fetch:
for document in documents:
completed_segments = (
@ -570,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
match document.data_source_type:
case "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
if document.data_source_type == "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
if file_detail is None:
raise NotFound("File not found.")
if file_detail is None:
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
case "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
case "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
else:
raise ValueError("Data source type not support")
case _:
raise ValueError("Data source type not support")
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
@ -804,6 +809,7 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
@ -839,6 +845,7 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
return response, 200
@ -946,23 +953,24 @@ class DocumentProcessingApi(DocumentResource):
if not current_user.is_dataset_editor:
raise Forbidden()
if action == "pause":
if document.indexing_status != "indexing":
raise InvalidActionError("Document not in indexing state.")
match action:
case "pause":
if document.indexing_status != "indexing":
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
elif action == "resume":
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
case "resume":
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None
document.paused_at = None
document.is_paused = False
db.session.commit()
document.paused_by = None
document.paused_at = None
document.is_paused = False
db.session.commit()
return {"result": "success"}, 200
@ -1178,7 +1186,7 @@ class DocumentRenameApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(document_fields)
@marshal_with(document_model)
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@ -1262,3 +1270,149 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
"input_data": log.input_data,
"datasource_node_id": log.datasource_node_id,
}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
class DocumentGenerateSummaryApi(Resource):
@console_ns.doc("generate_summary_for_documents")
@console_ns.doc(description="Generate summary index for documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
@console_ns.response(200, "Summary generation started successfully")
@console_ns.response(400, "Invalid request or dataset configuration")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""
Generate summary index for specified documents.
This endpoint checks if the dataset configuration supports summary generation
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Validate request payload
payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
document_list = payload.document_list
if not document_list:
from werkzeug.exceptions import BadRequest
raise BadRequest("document_list cannot be empty.")
# Check if dataset configuration supports summary generation
if dataset.indexing_technique != "high_quality":
raise ValueError(
f"Summary generation is only available for 'high_quality' indexing technique. "
f"Current indexing technique: {dataset.indexing_technique}"
)
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
# Verify all documents exist and belong to the dataset
documents = DocumentService.get_documents_by_ids(dataset_id, document_list)
if len(documents) != len(document_list):
found_ids = {doc.id for doc in documents}
missing_ids = set(document_list) - found_ids
raise NotFound(f"Some documents not found: {list(missing_ids)}")
# Update need_summary to True for documents that don't have it set
# This handles the case where documents were created when summary_index_setting was disabled
documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"]
if documents_to_update:
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
DocumentService.update_documents_need_summary(
dataset_id=dataset_id,
document_ids=document_ids_to_update,
need_summary=True,
)
# Dispatch async tasks for each document
for document in documents:
# Skip qa_model documents as they don't generate summaries
if document.doc_form == "qa_model":
logger.info("Skipping summary generation for qa_model document %s", document.id)
continue
# Dispatch async task
generate_summary_index_task.delay(dataset_id, document.id)
logger.info(
"Dispatched summary generation task for document %s in dataset %s",
document.id,
dataset_id,
)
return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
class DocumentSummaryStatusApi(DocumentResource):
@console_ns.doc("get_document_summary_status")
@console_ns.doc(description="Get summary index generation status for a document")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(200, "Summary status retrieved successfully")
@console_ns.response(404, "Document not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""
Get summary index generation status for a document.
Returns:
- total_segments: Total number of segments in the document
- summary_status: Dictionary with status counts
- completed: Number of summaries completed
- generating: Number of summaries being generated
- error: Number of summaries with errors
- not_started: Number of segments without summary records
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Get summary status detail from service
from services.summary_index_service import SummaryIndexService
result = SummaryIndexService.get_document_summary_status_detail(
document_id=document_id,
dataset_id=dataset_id,
)
return result, 200

View File

@ -41,6 +41,17 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
from services.summary_index_service import SummaryIndexService
segment_dict = dict(marshal(segment, segment_fields))
# Query summary for this segment (only enabled summaries)
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
@ -63,6 +74,7 @@ class SegmentUpdatePayload(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
summary: str | None = None # Summary content for summary index
class BatchImportPayload(BaseModel):
@ -90,6 +102,7 @@ register_schema_models(
ChildChunkCreatePayload,
ChildChunkUpdatePayload,
ChildChunkBatchUpdatePayload,
ChildChunkUpdateArgs,
)
@ -180,8 +193,25 @@ class DatasetDocumentSegmentListApi(Resource):
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
# Query summaries for all segments in this page (batch query for efficiency)
segment_ids = [segment.id for segment in segments.items]
summaries = {}
if segment_ids:
from services.summary_index_service import SummaryIndexService
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
# Only include enabled summaries (already filtered by service)
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = dict(marshal(segment, segment_fields))
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
response = {
"data": marshal(segments.items, segment_fields),
"data": segments_with_summary,
"limit": limit,
"total": segments.total,
"total_pages": segments.pages,
@ -327,7 +357,7 @@ class DatasetDocumentSegmentAddApi(Resource):
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
@ -389,10 +419,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
# Update segment (summary update with change detection is handled in SegmentService.update_segment)
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@setup_required
@login_required

View File

@ -4,7 +4,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@ -28,34 +28,27 @@ from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
def _build_dataset_detail_model():
keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
tag_model = _get_or_create_model("Tag", tag_fields)
doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
tag_model = get_or_create_model("Tag", tag_fields)
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
@ -64,7 +57,7 @@ def _build_dataset_detail_model():
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
return get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
try:
@ -98,12 +91,19 @@ class BedrockRetrievalPayload(BaseModel):
knowledge_id: str
class ExternalApiTemplateListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
limit: int = Field(default=20, description="Number of items per page")
keyword: str | None = Field(default=None, description="Search keyword")
register_schema_models(
console_ns,
ExternalKnowledgeApiPayload,
ExternalDatasetCreatePayload,
ExternalHitTestingPayload,
BedrockRetrievalPayload,
ExternalApiTemplateListQuery,
)
@ -124,19 +124,17 @@ class ExternalApiTemplateListApi(Resource):
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
page, limit, current_tenant_id, search
query.page, query.limit, current_tenant_id, query.keyword
)
response = {
"data": [item.to_dict() for item in external_knowledge_apis],
"has_more": len(external_knowledge_apis) == limit,
"limit": limit,
"has_more": len(external_knowledge_apis) == query.limit,
"limit": query.limit,
"total": total,
"page": page,
"page": query.page,
}
return response, 200

View File

@ -1,6 +1,13 @@
from flask_restx import Resource
from flask_restx import Resource, fields
from controllers.common.schema import register_schema_model
from fields.hit_testing_fields import (
child_chunk_fields,
document_fields,
files_fields,
hit_testing_record_fields,
segment_fields,
)
from libs.login import login_required
from .. import console_ns
@ -14,13 +21,45 @@ from ..wraps import (
register_schema_model(console_ns, HitTestingPayload)
def _get_or_create_model(model_name: str, field_def):
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)
segment_fields_copy = segment_fields.copy()
segment_fields_copy["document"] = fields.Nested(document_model)
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
# Response model for hit testing API
hit_testing_response_fields = {
"query": fields.String,
"records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@setup_required

View File

@ -4,14 +4,16 @@ from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
MetadataArgs,
MetadataDetail,
MetadataOperationData,
)
from services.metadata_service import MetadataService
@ -21,8 +23,9 @@ class MetadataUpdatePayload(BaseModel):
name: str
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
register_schema_model(console_ns, MetadataUpdatePayload)
register_schema_models(
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
)
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
@ -123,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
if action == "enable":
MetadataService.enable_built_in_field(dataset)
elif action == "disable":
MetadataService.disable_built_in_field(dataset)
match action:
case "enable":
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200

View File

@ -2,7 +2,7 @@ import logging
from typing import Any, NoReturn
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
from flask_restx import Resource, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@ -14,7 +14,9 @@ from controllers.console.app.error import (
)
from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
workflow_draft_variable_list_model,
workflow_draft_variable_list_without_value_model,
workflow_draft_variable_model,
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
@ -27,7 +29,6 @@ from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
@ -52,20 +53,6 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
return var_list.variables
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
"total": fields.Raw(),
}
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}
def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.
@ -92,7 +79,7 @@ def _api_prerequisite(f):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables")
class RagPipelineVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
@marshal_with(workflow_draft_variable_list_without_value_model)
def get(self, pipeline: Pipeline):
"""
Get draft workflow
@ -150,7 +137,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables")
class RagPipelineNodeVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
@ -176,7 +163,7 @@ class RagPipelineVariableApi(Resource):
_PATCH_VALUE_FIELD = "value"
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
@marshal_with(workflow_draft_variable_model)
def get(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
@ -189,7 +176,7 @@ class RagPipelineVariableApi(Resource):
return variable
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
@marshal_with(workflow_draft_variable_model)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types:
@ -307,7 +294,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables")
class RagPipelineSystemVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline):
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)

View File

@ -1,9 +1,9 @@
from flask import request
from flask_restx import Resource, marshal_with # type: ignore
from flask_restx import Resource, fields, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
@ -12,7 +12,11 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from fields.rag_pipeline_fields import (
leaked_dependency_fields,
pipeline_import_check_dependencies_fields,
pipeline_import_fields,
)
from libs.login import current_account_with_tenant, login_required
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
@ -38,13 +42,25 @@ class IncludeSecretQuery(BaseModel):
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields)
leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields)
pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy()
pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(
fields.Nested(leaked_dependency_model)
)
pipeline_import_check_dependencies_model = get_or_create_model(
"RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy
)
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields)
@marshal_with(pipeline_import_model)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self):
# Check user role first
@ -81,7 +97,7 @@ class RagPipelineImportConfirmApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields)
@marshal_with(pipeline_import_model)
def post(self, import_id):
current_user, _ = current_account_with_tenant()
@ -106,7 +122,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@get_rag_pipeline
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_fields)
@marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)

View File

@ -1,10 +1,9 @@
import json
import logging
from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from flask_restx import Resource, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -17,6 +16,13 @@ from controllers.console.app.error import (
DraftWorkflowNotExist,
DraftWorkflowNotSync,
)
from controllers.console.app.workflow import workflow_model, workflow_pagination_model
from controllers.console.app.workflow_run import (
workflow_run_detail_model,
workflow_run_node_execution_list_model,
workflow_run_node_execution_model,
workflow_run_pagination_model,
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
@ -30,15 +36,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import (
workflow_run_detail_fields,
workflow_run_node_execution_fields,
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
from libs import helper
from libs.helper import TimestampField
from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel):
class WorkflowRunQuery(BaseModel):
last_id: UUID | None = None
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel):
start_node_title: str
class RagPipelineRecommendedPluginQuery(BaseModel):
type: str = "all"
register_schema_models(
console_ns,
DraftWorkflowSyncPayload,
@ -135,6 +138,7 @@ register_schema_models(
NodeIdQuery,
WorkflowRunQuery,
DatasourceVariablesPayload,
RagPipelineRecommendedPluginQuery,
)
@ -145,7 +149,7 @@ class DraftRagPipelineApi(Resource):
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_fields)
@marshal_with(workflow_model)
def get(self, pipeline: Pipeline):
"""
Get draft rag pipeline's workflow
@ -521,7 +525,7 @@ class RagPipelineDraftNodeRunApi(Resource):
@edit_permission_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_fields)
@marshal_with(workflow_run_node_execution_model)
def post(self, pipeline: Pipeline, node_id: str):
"""
Run draft workflow node
@ -569,7 +573,7 @@ class PublishedRagPipelineApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
@marshal_with(workflow_fields)
@marshal_with(workflow_model)
def get(self, pipeline: Pipeline):
"""
Get published pipeline
@ -664,7 +668,7 @@ class PublishedAllRagPipelineApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
@marshal_with(workflow_pagination_fields)
@marshal_with(workflow_pagination_model)
def get(self, pipeline: Pipeline):
"""
Get published workflows
@ -708,7 +712,7 @@ class RagPipelineByIdApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
@marshal_with(workflow_fields)
@marshal_with(workflow_model)
def patch(self, pipeline: Pipeline, workflow_id: str):
"""
Update workflow attributes
@ -830,7 +834,7 @@ class RagPipelineWorkflowRunListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_pagination_fields)
@marshal_with(workflow_run_pagination_model)
def get(self, pipeline: Pipeline):
"""
Get workflow run list
@ -858,7 +862,7 @@ class RagPipelineWorkflowRunDetailApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_detail_fields)
@marshal_with(workflow_run_detail_model)
def get(self, pipeline: Pipeline, run_id):
"""
Get workflow run detail
@ -877,7 +881,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_list_fields)
@marshal_with(workflow_run_node_execution_list_model)
def get(self, pipeline: Pipeline, run_id: str):
"""
Get workflow run node execution list
@ -911,7 +915,7 @@ class RagPipelineWorkflowLastRunApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_fields)
@marshal_with(workflow_run_node_execution_model)
def get(self, pipeline: Pipeline, node_id: str):
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@ -952,7 +956,7 @@ class RagPipelineDatasourceVariableApi(Resource):
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_run_node_execution_fields)
@marshal_with(workflow_run_node_execution_model)
def post(self, pipeline: Pipeline):
"""
Set datasource variables
@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, location="args", required=False, default="all")
args = parser.parse_args()
type = args["type"]
query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
return recommended_plugins

View File

@ -2,16 +2,17 @@ import logging
from typing import Any
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.common.schema import get_or_create_model
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
@ -28,22 +29,37 @@ class InstalledAppUpdatePayload(BaseModel):
is_pinned: bool | None = None
class InstalledAppsListQuery(BaseModel):
app_id: str | None = Field(default=None, description="App ID to filter by")
logger = logging.getLogger(__name__)
app_model = get_or_create_model("InstalledAppInfo", app_fields)
installed_app_fields_copy = installed_app_fields.copy()
installed_app_fields_copy["app"] = fields.Nested(app_model)
installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
installed_app_list_fields_copy = installed_app_list_fields.copy()
installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
@console_ns.route("/installed-apps")
class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@marshal_with(installed_app_list_fields)
@marshal_with(installed_app_list_model)
def get(self):
app_id = request.args.get("app_id", default=None, type=str)
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
if app_id:
if query.app_id:
installed_apps = db.session.scalars(
select(InstalledApp).where(
and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)
and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == query.app_id)
)
).all()
else:

View File

@ -3,6 +3,7 @@ from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from constants.languages import languages
from controllers.common.schema import get_or_create_model
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
@ -19,8 +20,10 @@ app_fields = {
"icon_background": fields.String,
}
app_model = get_or_create_model("RecommendedAppInfo", app_fields)
recommended_app_fields = {
"app": fields.Nested(app_fields, attribute="app"),
"app": fields.Nested(app_model, attribute="app"),
"app_id": fields.String,
"description": fields.String(attribute="description"),
"copyright": fields.String,
@ -32,11 +35,15 @@ recommended_app_fields = {
"can_trial": fields.Boolean,
}
recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields)
recommended_app_list_fields = {
"recommended_apps": fields.List(fields.Nested(recommended_app_fields)),
"recommended_apps": fields.List(fields.Nested(recommended_app_model)),
"categories": fields.List(fields.String),
}
recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields)
class RecommendedAppsQuery(BaseModel):
language: str | None = Field(default=None)
@ -53,7 +60,7 @@ class RecommendedAppListApi(Resource):
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
@login_required
@account_initialization_required
@marshal_with(recommended_app_list_fields)
@marshal_with(recommended_app_list_model)
def get(self):
# language args
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore

View File

@ -1,14 +1,16 @@
import logging
from typing import Any, cast
from typing import Any, Literal, cast
from flask import request
from flask_restx import Resource, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.fields import Parameters as ParametersResponse
from controllers.common.fields import Site as SiteResponse
from controllers.console import api
from controllers.common.schema import get_or_create_model
from controllers.console import api, console_ns
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@ -42,9 +44,21 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from fields.app_fields import app_detail_fields_with_site
from fields.app_fields import (
app_detail_fields_with_site,
deleted_tool_fields,
model_config_fields,
site_fields,
tag_fields,
)
from fields.dataset_fields import dataset_fields
from fields.workflow_fields import workflow_fields
from fields.member_fields import simple_account_fields
from fields.workflow_fields import (
conversation_variable_fields,
pipeline_variable_fields,
workflow_fields,
workflow_partial_fields,
)
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
@ -74,7 +88,86 @@ from services.recommended_app_service import RecommendedAppService
logger = logging.getLogger(__name__)
model_config_model = get_or_create_model("TrialAppModelConfig", model_config_fields)
workflow_partial_model = get_or_create_model("TrialWorkflowPartial", workflow_partial_fields)
deleted_tool_model = get_or_create_model("TrialDeletedTool", deleted_tool_fields)
tag_model = get_or_create_model("TrialTag", tag_fields)
site_model = get_or_create_model("TrialSite", site_fields)
app_detail_fields_with_site_copy = app_detail_fields_with_site.copy()
app_detail_fields_with_site_copy["model_config"] = fields.Nested(
model_config_model, attribute="app_model_config", allow_null=True
)
app_detail_fields_with_site_copy["workflow"] = fields.Nested(workflow_partial_model, allow_null=True)
app_detail_fields_with_site_copy["deleted_tools"] = fields.List(fields.Nested(deleted_tool_model))
app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields)
conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
workflow_fields_copy = workflow_fields.copy()
workflow_fields_copy["created_by"] = fields.Nested(simple_account_model, attribute="created_by_account")
workflow_fields_copy["updated_by"] = fields.Nested(
simple_account_model, attribute="updated_by_account", allow_null=True
)
workflow_fields_copy["conversation_variables"] = fields.List(fields.Nested(conversation_variable_model))
workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipeline_variable_model))
workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
# Pydantic models for request validation
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowRunRequest(BaseModel):
inputs: dict
files: list | None = None
class ChatRequest(BaseModel):
inputs: dict
query: str
files: list | None = None
conversation_id: str | None = None
parent_message_id: str | None = None
retriever_from: str = "explore_app"
class TextToSpeechRequest(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
class CompletionRequest(BaseModel):
inputs: dict
query: str = ""
files: list | None = None
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = "explore_app"
# Register schemas for Swagger documentation
console_ns.schema_model(
WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class TrialAppWorkflowRunApi(TrialAppResource):
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
def post(self, trial_app):
"""
Run workflow
@ -86,10 +179,8 @@ class TrialAppWorkflowRunApi(TrialAppResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
request_data = WorkflowRunRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
assert current_user is not None
try:
app_id = app_model.id
@ -140,6 +231,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
class TrialChatApi(TrialAppResource):
@console_ns.expect(console_ns.models[ChatRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
@ -147,14 +239,14 @@ class TrialChatApi(TrialAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
request_data = ChatRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
# Validate UUID values if provided
if args.get("conversation_id"):
args["conversation_id"] = uuid_value(args["conversation_id"])
if args.get("parent_message_id"):
args["parent_message_id"] = uuid_value(args["parent_message_id"])
args["auto_generate_name"] = False
@ -277,20 +369,16 @@ class TrialChatAudioApi(TrialAppResource):
class TrialChatTextApi(TrialAppResource):
@console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
request_data = TextToSpeechRequest.model_validate(console_ns.payload)
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
message_id = request_data.message_id
text = request_data.text
voice = request_data.voice
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
@ -328,19 +416,15 @@ class TrialChatTextApi(TrialAppResource):
class TrialCompletionApi(TrialAppResource):
@console_ns.expect(console_ns.models[CompletionRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
request_data = CompletionRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
@ -437,7 +521,7 @@ class TrialAppParameterApi(Resource):
class AppApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(app_detail_fields_with_site)
@marshal_with(app_detail_with_site_model)
def get(self, app_model):
"""Get app detail"""
@ -450,7 +534,7 @@ class AppApi(Resource):
class AppWorkflowApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(workflow_fields)
@marshal_with(workflow_model)
def get(self, app_model):
"""Get workflow detail"""
if not app_model.workflow_id:

View File

@ -1,6 +1,7 @@
from flask_restx import Resource, fields
from werkzeug.exceptions import Unauthorized
from libs.login import current_account_with_tenant, login_required
from libs.login import current_account_with_tenant, current_user, login_required
from services.feature_service import FeatureService
from . import console_ns
@ -39,5 +40,21 @@ class SystemFeatureApi(Resource):
),
)
def get(self):
"""Get system-wide feature configuration"""
return FeatureService.get_system_features().model_dump()
"""Get system-wide feature configuration
NOTE: This endpoint is unauthenticated by design, as it provides system features
data required for dashboard initialization.
Authentication would create circular dependency (can't login without dashboard loading).
Only non-sensitive configuration data should be returned by this endpoint.
"""
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
# raise `Unauthorized` exception if authentication token is not provided.
try:
is_authenticated = current_user.is_authenticated
except Unauthorized:
is_authenticated = False
return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()

View File

@ -1,87 +1,74 @@
import os
from typing import Literal
from flask import session
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from controllers.fastopenapi import console_router
from extensions.ext_database import db
from models.model import DifySetup
from services.account_service import TenantService
from . import console_ns
from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InitValidatePayload(BaseModel):
password: str = Field(..., max_length=30)
password: str = Field(..., max_length=30, description="Initialization password")
console_ns.schema_model(
InitValidatePayload.__name__,
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
class InitStatusResponse(BaseModel):
status: Literal["finished", "not_started"] = Field(..., description="Initialization status")
class InitValidateResponse(BaseModel):
result: str = Field(description="Operation result", examples=["success"])
@console_router.get(
"/init",
response_model=InitStatusResponse,
tags=["console"],
)
def get_init_status() -> InitStatusResponse:
"""Get initialization validation status."""
init_status = get_init_validate_status()
if init_status:
return InitStatusResponse(status="finished")
return InitStatusResponse(status="not_started")
@console_ns.route("/init")
class InitValidateAPI(Resource):
@console_ns.doc("get_init_status")
@console_ns.doc(description="Get initialization validation status")
@console_ns.response(
200,
"Success",
model=console_ns.model(
"InitStatusResponse",
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
),
)
def get(self):
"""Get initialization validation status"""
init_status = get_init_validate_status()
if init_status:
return {"status": "finished"}
return {"status": "not_started"}
@console_router.post(
"/init",
response_model=InitValidateResponse,
tags=["console"],
status_code=201,
)
@only_edition_self_hosted
def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
"""Validate initialization password."""
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
@console_ns.doc("validate_init_password")
@console_ns.doc(description="Validate initialization password for self-hosted edition")
@console_ns.expect(console_ns.models[InitValidatePayload.__name__])
@console_ns.response(
201,
"Success",
model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
)
@console_ns.response(400, "Already setup or validation failed")
@only_edition_self_hosted
def post(self):
"""Validate initialization password"""
# is tenant created
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
if payload.password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False
raise InitValidateFailedError()
payload = InitValidatePayload.model_validate(console_ns.payload)
input_password = payload.password
if input_password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False
raise InitValidateFailedError()
session["is_init_validated"] = True
return {"result": "success"}, 201
session["is_init_validated"] = True
return InitValidateResponse(result="success")
def get_init_validate_status():
def get_init_validate_status() -> bool:
if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"):
if session.get("is_init_validated"):
return True
with Session(db.engine) as db_session:
return db_session.execute(select(DifySetup)).scalar_one_or_none()
return db_session.execute(select(DifySetup)).scalar_one_or_none() is not None
return True

View File

@ -1,17 +1,17 @@
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from . import console_ns
from controllers.fastopenapi import console_router
@console_ns.route("/ping")
class PingApi(Resource):
@console_ns.doc("health_check")
@console_ns.doc(description="Health check endpoint for connection testing")
@console_ns.response(
200,
"Success",
console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
)
def get(self):
"""Health check endpoint for connection testing"""
return {"result": "pong"}
class PingResponse(BaseModel):
result: str = Field(description="Health check result", examples=["pong"])
@console_router.get(
"/ping",
response_model=PingResponse,
tags=["console"],
)
def ping() -> PingResponse:
"""Health check endpoint for connection testing."""
return PingResponse(result="pong")

View File

@ -1,7 +1,6 @@
import urllib.parse
import httpx
from flask_restx import Resource
from pydantic import BaseModel, Field
import services
@ -11,7 +10,7 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from controllers.fastopenapi import console_router
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
@ -19,84 +18,74 @@ from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from libs.login import current_account_with_tenant
from services.file_service import FileService
from . import console_ns
register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl)
@console_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(Resource):
@console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__])
def get(self, url):
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
# failed back to get method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)
return info.model_dump(mode="json")
class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch")
console_ns.schema_model(
RemoteFileUploadPayload.__name__,
RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
@console_router.get(
"/remote-files/<path:url>",
response_model=RemoteFileInfo,
tags=["console"],
)
def get_remote_file_info(url: str) -> RemoteFileInfo:
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource):
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__])
def post(self):
args = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = args.url
@console_router.post(
"/remote-files/upload",
response_model=FileWithSignedUrl,
tags=["console"],
status_code=201,
)
def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl:
url = payload.url
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
file_info = helpers.guess_file_info_from_response(resp)
file_info = helpers.guess_file_info_from_response(resp)
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
payload = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
source_url=url,
)
return payload.model_dump(mode="json"), 201
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)

View File

@ -1,20 +1,19 @@
from typing import Literal
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from configs import dify_config
from controllers.fastopenapi import console_router
from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService
from . import console_ns
from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SetupRequestPayload(BaseModel):
email: EmailStr = Field(..., description="Admin email address")
@ -28,78 +27,66 @@ class SetupRequestPayload(BaseModel):
return valid_password(value)
console_ns.schema_model(
SetupRequestPayload.__name__,
SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
class SetupStatusResponse(BaseModel):
step: Literal["not_started", "finished"] = Field(description="Setup step status")
setup_at: str | None = Field(default=None, description="Setup completion time (ISO format)")
class SetupResponse(BaseModel):
result: str = Field(description="Setup result", examples=["success"])
@console_router.get(
"/setup",
response_model=SetupStatusResponse,
tags=["console"],
)
def get_setup_status_api() -> SetupStatusResponse:
"""Get system setup status."""
if dify_config.EDITION == "SELF_HOSTED":
setup_status = get_setup_status()
if setup_status and not isinstance(setup_status, bool):
return SetupStatusResponse(step="finished", setup_at=setup_status.setup_at.isoformat())
if setup_status:
return SetupStatusResponse(step="finished")
return SetupStatusResponse(step="not_started")
return SetupStatusResponse(step="finished")
@console_ns.route("/setup")
class SetupApi(Resource):
@console_ns.doc("get_setup_status")
@console_ns.doc(description="Get system setup status")
@console_ns.response(
200,
"Success",
console_ns.model(
"SetupStatusResponse",
{
"step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
"setup_at": fields.String(description="Setup completion time (ISO format)", required=False),
},
),
@console_router.post(
"/setup",
response_model=SetupResponse,
tags=["console"],
status_code=201,
)
@only_edition_self_hosted
def setup_system(payload: SetupRequestPayload) -> SetupResponse:
"""Initialize system setup with admin account."""
if get_setup_status():
raise AlreadySetupError()
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
if not get_init_validate_status():
raise NotInitValidateError()
normalized_email = payload.email.lower()
RegisterService.setup(
email=normalized_email,
name=payload.name,
password=payload.password,
ip_address=extract_remote_ip(request),
language=payload.language,
)
def get(self):
"""Get system setup status"""
if dify_config.EDITION == "SELF_HOSTED":
setup_status = get_setup_status()
# Check if setup_status is a DifySetup object rather than a bool
if setup_status and not isinstance(setup_status, bool):
return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()}
elif setup_status:
return {"step": "finished"}
return {"step": "not_started"}
return {"step": "finished"}
@console_ns.doc("setup_system")
@console_ns.doc(description="Initialize system setup with admin account")
@console_ns.expect(console_ns.models[SetupRequestPayload.__name__])
@console_ns.response(
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")})
)
@console_ns.response(400, "Already setup or validation failed")
@only_edition_self_hosted
def post(self):
"""Initialize system setup with admin account"""
# is set up
if get_setup_status():
raise AlreadySetupError()
# is tenant created
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
if not get_init_validate_status():
raise NotInitValidateError()
args = SetupRequestPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
# setup
RegisterService.setup(
email=normalized_email,
name=args.name,
password=args.password,
ip_address=extract_remote_ip(request),
language=args.language,
)
return {"result": "success"}, 201
return SetupResponse(result="success")
def get_setup_status():
def get_setup_status() -> DifySetup | bool | None:
if dify_config.EDITION == "SELF_HOSTED":
return db.session.query(DifySetup).first()
else:
return True
return True

View File

@ -1,17 +1,27 @@
from typing import Literal
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Namespace, Resource, fields, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required
from services.tag_service import TagService
dataset_tag_fields = {
"id": fields.String,
"name": fields.String,
"type": fields.String,
"binding_count": fields.String,
}
def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
@ -40,6 +50,7 @@ register_schema_models(
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagListQueryParam,
)

View File

@ -1,15 +1,11 @@
import json
import logging
import httpx
from flask import request
from flask_restx import Resource, fields
from packaging import version
from pydantic import BaseModel, Field
from configs import dify_config
from . import console_ns
from controllers.fastopenapi import console_router
logger = logging.getLogger(__name__)
@ -18,69 +14,61 @@ class VersionQuery(BaseModel):
current_version: str = Field(..., description="Current application version")
console_ns.schema_model(
VersionQuery.__name__,
VersionQuery.model_json_schema(ref_template="#/definitions/{model}"),
class VersionFeatures(BaseModel):
can_replace_logo: bool = Field(description="Whether logo replacement is supported")
model_load_balancing_enabled: bool = Field(description="Whether model load balancing is enabled")
class VersionResponse(BaseModel):
version: str = Field(description="Latest version number")
release_date: str = Field(description="Release date of latest version")
release_notes: str = Field(description="Release notes for latest version")
can_auto_update: bool = Field(description="Whether auto-update is supported")
features: VersionFeatures = Field(description="Feature flags and capabilities")
@console_router.get(
"/version",
response_model=VersionResponse,
tags=["console"],
)
def check_version_update(query: VersionQuery) -> VersionResponse:
"""Check for application version updates."""
check_update_url = dify_config.CHECK_UPDATE_URL
@console_ns.route("/version")
class VersionApi(Resource):
@console_ns.doc("check_version_update")
@console_ns.doc(description="Check for application version updates")
@console_ns.expect(console_ns.models[VersionQuery.__name__])
@console_ns.response(
200,
"Success",
console_ns.model(
"VersionResponse",
{
"version": fields.String(description="Latest version number"),
"release_date": fields.String(description="Release date of latest version"),
"release_notes": fields.String(description="Release notes for latest version"),
"can_auto_update": fields.Boolean(description="Whether auto-update is supported"),
"features": fields.Raw(description="Feature flags and capabilities"),
},
result = VersionResponse(
version=dify_config.project.version,
release_date="",
release_notes="",
can_auto_update=False,
features=VersionFeatures(
can_replace_logo=dify_config.CAN_REPLACE_LOGO,
model_load_balancing_enabled=dify_config.MODEL_LB_ENABLED,
),
)
def get(self):
"""Check for application version updates"""
args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
check_update_url = dify_config.CHECK_UPDATE_URL
result = {
"version": dify_config.project.version,
"release_date": "",
"release_notes": "",
"can_auto_update": False,
"features": {
"can_replace_logo": dify_config.CAN_REPLACE_LOGO,
"model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED,
},
}
if not check_update_url:
return result
try:
response = httpx.get(
check_update_url,
params={"current_version": args.current_version},
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
)
except Exception as error:
logger.warning("Check update version error: %s.", str(error))
result["version"] = args.current_version
return result
content = json.loads(response.content)
if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"):
result["version"] = content["version"]
result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"]
result["can_auto_update"] = content["canAutoUpdate"]
if not check_update_url:
return result
try:
response = httpx.get(
check_update_url,
params={"current_version": query.current_version},
timeout=httpx.Timeout(timeout=10.0, connect=3.0),
)
content = response.json()
except Exception as error:
logger.warning("Check update version error: %s.", str(error))
result.version = query.current_version
return result
latest_version = content.get("version", result.version)
if _has_new_version(latest_version=latest_version, current_version=f"{query.current_version}"):
result.version = latest_version
result.release_date = content.get("releaseDate", "")
result.release_notes = content.get("releaseNotes", "")
result.can_auto_update = content.get("canAutoUpdate", False)
return result
def _has_new_version(*, latest_version: str, current_version: str) -> bool:
try:

View File

@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
@ -37,7 +38,7 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_fields
from fields.member_fields import Account as AccountResponse
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
@ -170,6 +171,25 @@ reg(ChangeEmailSendPayload)
reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)
def _serialize_account(account) -> dict:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
integrate_fields = {
"provider": fields.String,
"created_at": TimestampField,
"is_bound": fields.Boolean,
"link": fields.String,
}
integrate_model = console_ns.model("AccountIntegrate", integrate_fields)
integrate_list_model = console_ns.model(
"AccountIntegrateList",
{"data": fields.List(fields.Nested(integrate_model))},
)
@console_ns.route("/account/init")
@ -223,11 +243,11 @@ class AccountProfileApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@enterprise_license_required
def get(self):
current_user, _ = current_account_with_tenant()
return current_user
return _serialize_account(current_user)
@console_ns.route("/account/name")
@ -236,14 +256,14 @@ class AccountNameApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/avatar")
@ -252,7 +272,7 @@ class AccountAvatarApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@ -260,7 +280,7 @@ class AccountAvatarApi(Resource):
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/interface-language")
@ -269,7 +289,7 @@ class AccountInterfaceLanguageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@ -277,7 +297,7 @@ class AccountInterfaceLanguageApi(Resource):
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/interface-theme")
@ -286,7 +306,7 @@ class AccountInterfaceThemeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@ -294,7 +314,7 @@ class AccountInterfaceThemeApi(Resource):
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/timezone")
@ -303,7 +323,7 @@ class AccountTimezoneApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@ -311,7 +331,7 @@ class AccountTimezoneApi(Resource):
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/password")
@ -320,7 +340,7 @@ class AccountPasswordApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@ -331,26 +351,15 @@ class AccountPasswordApi(Resource):
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
return {"result": "success"}
return _serialize_account(current_user)
@console_ns.route("/account/integrates")
class AccountIntegrateApi(Resource):
integrate_fields = {
"provider": fields.String,
"created_at": TimestampField,
"is_bound": fields.Boolean,
"link": fields.String,
}
integrate_list_fields = {
"data": fields.List(fields.Nested(integrate_fields)),
}
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_fields)
@marshal_with(integrate_list_model)
def get(self):
account, _ = current_account_with_tenant()
@ -618,7 +627,7 @@ class ChangeEmailResetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_fields)
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
@ -647,7 +656,7 @@ class ChangeEmailResetApi(Resource):
email=normalized_new_email,
)
return updated_account
return _serialize_account(updated_account)
@console_ns.route("/account/change-email/check-email-unique")

View File

@ -1,9 +1,10 @@
from typing import Any
from flask import request
from flask_restx import Resource, fields
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery):
plugin_id: str
class EndpointCreateResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointListResponse(BaseModel):
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
class PluginEndpointListResponse(BaseModel):
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
class EndpointDeleteResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointUpdateResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointEnableResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointDisableResponse(BaseModel):
success: bool = Field(description="Operation success")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(EndpointCreatePayload)
reg(EndpointIdPayload)
reg(EndpointUpdatePayload)
reg(EndpointListQuery)
reg(EndpointListForPluginQuery)
register_schema_models(
console_ns,
EndpointCreatePayload,
EndpointIdPayload,
EndpointUpdatePayload,
EndpointListQuery,
EndpointListForPluginQuery,
EndpointCreateResponse,
EndpointListResponse,
PluginEndpointListResponse,
EndpointDeleteResponse,
EndpointUpdateResponse,
EndpointEnableResponse,
EndpointDisableResponse,
)
@console_ns.route("/workspaces/current/endpoints/create")
@ -57,7 +96,7 @@ class EndpointCreateApi(Resource):
@console_ns.response(
200,
"Endpoint created successfully",
console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointCreateResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@ -91,9 +130,7 @@ class EndpointListApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
console_ns.models[EndpointListResponse.__name__],
)
@setup_required
@login_required
@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.model(
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
console_ns.models[PluginEndpointListResponse.__name__],
)
@setup_required
@login_required
@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource):
@console_ns.response(
200,
"Endpoint deleted successfully",
console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointDeleteResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource):
@console_ns.response(
200,
"Endpoint updated successfully",
console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointUpdateResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@ -221,7 +256,7 @@ class EndpointEnableApi(Resource):
@console_ns.response(
200,
"Endpoint enabled successfully",
console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointEnableResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@ -248,7 +283,7 @@ class EndpointDisableApi(Resource):
@console_ns.response(
200,
"Endpoint disabled successfully",
console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
console_ns.models[EndpointDisableResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required

View File

@ -1,11 +1,12 @@
from urllib import parse
from flask import abort, request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter
import services
from configs import dify_config
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
@ -24,7 +25,7 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
from fields.member_fields import AccountWithRole, AccountWithRoleList
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
@ -67,6 +68,8 @@ reg(MemberRoleUpdatePayload)
reg(OwnerTransferEmailPayload)
reg(OwnerTransferCheckPayload)
reg(OwnerTransferPayload)
register_enum_models(console_ns, TenantAccountRole)
register_schema_models(console_ns, AccountWithRole, AccountWithRoleList)
@console_ns.route("/workspaces/current/members")
@ -76,13 +79,15 @@ class MemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_with_role_list_fields)
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/members/invite-email")
@ -227,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(account_with_role_list_fields)
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
return {"result": "success", "accounts": members}, 200
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")

View File

@ -5,6 +5,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
@ -23,12 +24,13 @@ class ParserGetDefault(BaseModel):
model_type: ModelType
class ParserPostDefault(BaseModel):
class Inner(BaseModel):
model_type: ModelType
model: str | None = None
provider: str | None = None
class Inner(BaseModel):
model_type: ModelType
model: str | None = None
provider: str | None = None
class ParserPostDefault(BaseModel):
model_settings: list[Inner]
@ -105,19 +107,21 @@ class ParserParameter(BaseModel):
model: str
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
register_schema_models(
console_ns,
ParserGetDefault,
ParserPostDefault,
ParserDeleteModels,
ParserPostModels,
ParserGetCredentials,
ParserCreateCredential,
ParserUpdateCredential,
ParserDeleteCredential,
ParserParameter,
Inner,
)
reg(ParserGetDefault)
reg(ParserPostDefault)
reg(ParserDeleteModels)
reg(ParserPostModels)
reg(ParserGetCredentials)
reg(ParserCreateCredential)
reg(ParserUpdateCredential)
reg(ParserDeleteCredential)
reg(ParserParameter)
register_enum_models(console_ns, ModelType)
@console_ns.route("/workspaces/current/default-model")

View File

@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
@ -20,57 +21,12 @@ from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
try:
return {
"key": PluginService.get_debugging_key(tenant_id),
"host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
}
except PluginDaemonClientSideError as e:
raise ValueError(e)
class ParserList(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
reg(ParserList)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
@console_ns.expect(console_ns.models[ParserList.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
class ParserLatest(BaseModel):
plugin_ids: list[str]
@ -180,23 +136,73 @@ class ParserReadme(BaseModel):
language: str = Field(default="en-US")
reg(ParserLatest)
reg(ParserIcon)
reg(ParserAsset)
reg(ParserGithubUpload)
reg(ParserPluginIdentifiers)
reg(ParserGithubInstall)
reg(ParserPluginIdentifierQuery)
reg(ParserTasks)
reg(ParserMarketplaceUpgrade)
reg(ParserGithubUpgrade)
reg(ParserUninstall)
reg(ParserPermissionChange)
reg(ParserDynamicOptions)
reg(ParserDynamicOptionsWithCredentials)
reg(ParserPreferencesChange)
reg(ParserExcludePlugin)
reg(ParserReadme)
register_schema_models(
console_ns,
ParserList,
PluginAutoUpgradeSettingsPayload,
PluginPermissionSettingsPayload,
ParserLatest,
ParserIcon,
ParserAsset,
ParserGithubUpload,
ParserPluginIdentifiers,
ParserGithubInstall,
ParserPluginIdentifierQuery,
ParserTasks,
ParserMarketplaceUpgrade,
ParserGithubUpgrade,
ParserUninstall,
ParserPermissionChange,
ParserDynamicOptions,
ParserDynamicOptionsWithCredentials,
ParserPreferencesChange,
ParserExcludePlugin,
ParserReadme,
)
register_enum_models(
console_ns,
TenantPluginPermission.DebugPermission,
TenantPluginAutoUpgradeStrategy.UpgradeMode,
TenantPluginAutoUpgradeStrategy.StrategySetting,
TenantPluginPermission.InstallPermission,
)
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
try:
return {
"key": PluginService.get_debugging_key(tenant_id),
"host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
}
except PluginDaemonClientSideError as e:
raise ValueError(e)
@console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource):
@console_ns.expect(console_ns.models[ParserList.__name__])
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
@console_ns.route("/workspaces/current/plugin/list/latest-versions")

File diff suppressed because it is too large Load Diff