Merge branch 'main' into feat/knowledgebase-summaryIndex

This commit is contained in:
FFXN
2026-01-28 11:22:25 +08:00
committed by GitHub
196 changed files with 4868 additions and 6453 deletions

View File

@ -3,13 +3,13 @@ from collections.abc import Generator
from typing import Any, cast
from flask import request
from flask_restx import Resource, marshal_with
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model
from controllers.common.schema import get_or_create_model, register_schema_model
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
@ -17,7 +17,14 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from fields.data_source_fields import (
integrate_fields,
integrate_icon_fields,
integrate_list_fields,
integrate_notion_info_list_fields,
integrate_page_fields,
integrate_workspace_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
@ -49,6 +56,49 @@ class DataSourceNotionPreviewQuery(BaseModel):
register_schema_model(console_ns, NotionEstimatePayload)
integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
integrate_page_fields_copy = integrate_page_fields.copy()
integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True)
integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy)
integrate_workspace_fields_copy = integrate_workspace_fields.copy()
integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model))
integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy)
integrate_fields_copy = integrate_fields.copy()
integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model)
integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy)
integrate_list_fields_copy = integrate_list_fields.copy()
integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model))
integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy)
notion_page_fields = {
"page_name": fields.String,
"page_id": fields.String,
"page_icon": fields.Nested(integrate_icon_model, allow_null=True),
"is_bound": fields.Boolean,
"parent_id": fields.String,
"type": fields.String,
}
notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields)
notion_workspace_fields = {
"workspace_name": fields.String,
"workspace_id": fields.String,
"workspace_icon": fields.String,
"pages": fields.List(fields.Nested(notion_page_model)),
}
notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields)
integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy()
integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model))
integrate_notion_info_list_model = get_or_create_model(
"NotionIntegrateInfoList", integrate_notion_info_list_fields_copy
)
@console_ns.route(
"/data-source/integrates",
"/data-source/integrates/<uuid:binding_id>/<string:action>",
@ -57,7 +107,7 @@ class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_list_fields)
@marshal_with(integrate_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
@ -142,7 +192,7 @@ class DataSourceNotionListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(integrate_notion_info_list_fields)
@marshal_with(integrate_notion_info_list_model)
def get(self):
current_user, current_tenant_id = current_account_with_tenant()

View File

@ -8,7 +8,7 @@ from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import (
api_key_item_model,
@ -34,6 +34,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
content_fields,
dataset_detail_fields,
dataset_fields,
dataset_query_detail_fields,
@ -41,6 +42,7 @@ from fields.dataset_fields import (
doc_metadata_fields,
external_knowledge_info_fields,
external_retrieval_model_fields,
file_info_fields,
icon_info_fields,
keyword_setting_fields,
reranking_model_fields,
@ -55,41 +57,33 @@ from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = _get_or_create_model("DatasetBase", dataset_fields)
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
tag_model = _get_or_create_model("Tag", tag_fields)
tag_model = get_or_create_model("Tag", tag_fields)
keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
@ -98,14 +92,22 @@ dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_k
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
dataset_detail_model = _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
dataset_query_detail_model = _get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields)
file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
app_detail_kernel_model = _get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
content_fields_copy = content_fields.copy()
content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
content_model = get_or_create_model("DatasetContent", content_fields_copy)
dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
related_app_list_copy = related_app_list.copy()
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_indexing_technique(value: str | None) -> str | None:

View File

@ -14,7 +14,7 @@ from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from core.errors.error import (
LLMBadRequestError,
@ -73,34 +73,27 @@ logger = logging.getLogger(__name__)
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_model = _get_or_create_model("Dataset", dataset_fields)
dataset_model = get_or_create_model("Dataset", dataset_fields)
document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields)
document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
document_fields_copy = document_fields.copy()
document_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_model = _get_or_create_model("Document", document_fields_copy)
document_model = get_or_create_model("Document", document_fields_copy)
document_with_segments_fields_copy = document_with_segments_fields.copy()
document_with_segments_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DocumentRetryPayload(BaseModel):
@ -1266,7 +1259,7 @@ class DocumentRenameApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(document_fields)
@marshal_with(document_model)
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator

View File

@ -108,6 +108,7 @@ register_schema_models(
ChildChunkCreatePayload,
ChildChunkUpdatePayload,
ChildChunkBatchUpdatePayload,
ChildChunkUpdateArgs,
)

View File

@ -4,7 +4,7 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@ -28,34 +28,27 @@ from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
def _build_dataset_detail_model():
keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
tag_model = _get_or_create_model("Tag", tag_fields)
doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
tag_model = get_or_create_model("Tag", tag_fields)
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
@ -64,7 +57,7 @@ def _build_dataset_detail_model():
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
return get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
try:

View File

@ -4,14 +4,16 @@ from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
MetadataArgs,
MetadataDetail,
MetadataOperationData,
)
from services.metadata_service import MetadataService
@ -21,8 +23,9 @@ class MetadataUpdatePayload(BaseModel):
name: str
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
register_schema_model(console_ns, MetadataUpdatePayload)
register_schema_models(
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
)
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")

View File

@ -2,7 +2,7 @@ import logging
from typing import Any, NoReturn
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
from flask_restx import Resource, marshal, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
@ -14,7 +14,9 @@ from controllers.console.app.error import (
)
from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
workflow_draft_variable_list_model,
workflow_draft_variable_list_without_value_model,
workflow_draft_variable_model,
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
@ -27,7 +29,6 @@ from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
@ -52,20 +53,6 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
return var_list.variables
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
"total": fields.Raw(),
}
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}
def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.
@ -92,7 +79,7 @@ def _api_prerequisite(f):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables")
class RagPipelineVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
@marshal_with(workflow_draft_variable_list_without_value_model)
def get(self, pipeline: Pipeline):
"""
Get draft workflow
@ -150,7 +137,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables")
class RagPipelineNodeVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
@ -176,7 +163,7 @@ class RagPipelineVariableApi(Resource):
_PATCH_VALUE_FIELD = "value"
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
@marshal_with(workflow_draft_variable_model)
def get(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
@ -189,7 +176,7 @@ class RagPipelineVariableApi(Resource):
return variable
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
@marshal_with(workflow_draft_variable_model)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types:
@ -307,7 +294,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables")
class RagPipelineSystemVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline):
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)

View File

@ -1,9 +1,9 @@
from flask import request
from flask_restx import Resource, marshal_with # type: ignore
from flask_restx import Resource, fields, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
@ -12,7 +12,11 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from fields.rag_pipeline_fields import (
leaked_dependency_fields,
pipeline_import_check_dependencies_fields,
pipeline_import_fields,
)
from libs.login import current_account_with_tenant, login_required
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
@ -38,13 +42,25 @@ class IncludeSecretQuery(BaseModel):
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields)
leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields)
pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy()
pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(
fields.Nested(leaked_dependency_model)
)
pipeline_import_check_dependencies_model = get_or_create_model(
"RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy
)
@console_ns.route("/rag/pipelines/imports")
class RagPipelineImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields)
@marshal_with(pipeline_import_model)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self):
# Check user role first
@ -81,7 +97,7 @@ class RagPipelineImportConfirmApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields)
@marshal_with(pipeline_import_model)
def post(self, import_id):
current_user, _ = current_account_with_tenant()
@ -106,7 +122,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@get_rag_pipeline
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_fields)
@marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)

View File

@ -17,6 +17,13 @@ from controllers.console.app.error import (
DraftWorkflowNotExist,
DraftWorkflowNotSync,
)
from controllers.console.app.workflow import workflow_model, workflow_pagination_model
from controllers.console.app.workflow_run import (
workflow_run_detail_model,
workflow_run_node_execution_list_model,
workflow_run_node_execution_model,
workflow_run_pagination_model,
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
@ -30,13 +37,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import (
workflow_run_detail_fields,
workflow_run_node_execution_fields,
workflow_run_node_execution_list_fields,
workflow_run_pagination_fields,
)
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, current_user, login_required
@ -145,7 +145,7 @@ class DraftRagPipelineApi(Resource):
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_fields)
@marshal_with(workflow_model)
def get(self, pipeline: Pipeline):
"""
Get draft rag pipeline's workflow
@ -521,7 +521,7 @@ class RagPipelineDraftNodeRunApi(Resource):
@edit_permission_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_fields)
@marshal_with(workflow_run_node_execution_model)
def post(self, pipeline: Pipeline, node_id: str):
"""
Run draft workflow node
@ -569,7 +569,7 @@ class PublishedRagPipelineApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
@marshal_with(workflow_fields)
@marshal_with(workflow_model)
def get(self, pipeline: Pipeline):
"""
Get published pipeline
@ -664,7 +664,7 @@ class PublishedAllRagPipelineApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
@marshal_with(workflow_pagination_fields)
@marshal_with(workflow_pagination_model)
def get(self, pipeline: Pipeline):
"""
Get published workflows
@ -708,7 +708,7 @@ class RagPipelineByIdApi(Resource):
@account_initialization_required
@edit_permission_required
@get_rag_pipeline
@marshal_with(workflow_fields)
@marshal_with(workflow_model)
def patch(self, pipeline: Pipeline, workflow_id: str):
"""
Update workflow attributes
@ -830,7 +830,7 @@ class RagPipelineWorkflowRunListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_pagination_fields)
@marshal_with(workflow_run_pagination_model)
def get(self, pipeline: Pipeline):
"""
Get workflow run list
@ -858,7 +858,7 @@ class RagPipelineWorkflowRunDetailApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_detail_fields)
@marshal_with(workflow_run_detail_model)
def get(self, pipeline: Pipeline, run_id):
"""
Get workflow run detail
@ -877,7 +877,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_list_fields)
@marshal_with(workflow_run_node_execution_list_model)
def get(self, pipeline: Pipeline, run_id: str):
"""
Get workflow run node execution list
@ -911,7 +911,7 @@ class RagPipelineWorkflowLastRunApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
@marshal_with(workflow_run_node_execution_fields)
@marshal_with(workflow_run_node_execution_model)
def get(self, pipeline: Pipeline, node_id: str):
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
@ -952,7 +952,7 @@ class RagPipelineDatasourceVariableApi(Resource):
@account_initialization_required
@get_rag_pipeline
@edit_permission_required
@marshal_with(workflow_run_node_execution_fields)
@marshal_with(workflow_run_node_execution_model)
def post(self, pipeline: Pipeline):
"""
Set datasource variables