diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py
index 423befafa..86236bb52 100644
--- a/admin/client/ragflow_client.py
+++ b/admin/client/ragflow_client.py
@@ -764,14 +764,14 @@ class RAGFlowClient:
         iterations = command.get("iterations", 1)
         if iterations > 1:
-            response = self.http_client.request("POST", "/kb/list", use_api_base=False, auth_kind="web",
+            response = self.http_client.request("GET", "/datasets", use_api_base=True, auth_kind="web",
                                                 iterations=iterations)
             return response
         else:
-            response = self.http_client.request("POST", "/kb/list", use_api_base=False, auth_kind="web")
+            response = self.http_client.request("GET", "/datasets", use_api_base=True, auth_kind="web")
             res_json = response.json()
             if response.status_code == 200:
-                self._print_table_simple(res_json["data"]["kbs"])
+                self._print_table_simple(res_json["data"])
             else:
                 print(f"Fail to list datasets, code: {res_json['code']}, message: {res_json['message']}")
             return None
@@ -781,13 +781,13 @@ class RAGFlowClient:
             print("This command is only allowed in USER mode")
         payload = {
             "name": command["dataset_name"],
-            "embd_id": command["embedding"]
+            "embedding_model": command["embedding"]
         }
         if "parser_id" in command:
-            payload["parser_id"] = command["parser"]
+            payload["chunk_method"] = command["parser"]
         if "pipeline" in command:
             payload["pipeline_id"] = command["pipeline"]
-        response = self.http_client.request("POST", "/kb/create", json_body=payload, use_api_base=False,
+        response = self.http_client.request("POST", "/datasets", json_body=payload, use_api_base=True,
                                             auth_kind="web")
         res_json = response.json()
         if response.status_code == 200:
@@ -803,8 +803,8 @@ class RAGFlowClient:
         dataset_id = self._get_dataset_id(dataset_name)
         if dataset_id is None:
             return
-        payload = {"kb_id": dataset_id}
-        response = self.http_client.request("POST", "/kb/rm", json_body=payload, use_api_base=False, auth_kind="web")
+        payload = {"ids": [dataset_id]}
+        response = self.http_client.request("DELETE", "/datasets", json_body=payload, use_api_base=True, auth_kind="web")
         res_json = response.json()
         if response.status_code == 200:
             print(f"Drop dataset {dataset_name} successfully")
@@ -1349,13 +1349,13 @@ class RAGFlowClient:
         return res_json["data"]["docs"]
 
     def _get_dataset_id(self, dataset_name: str):
-        response = self.http_client.request("POST", "/kb/list", use_api_base=False, auth_kind="web")
+        response = self.http_client.request("GET", "/datasets", use_api_base=True, auth_kind="web")
         res_json = response.json()
         if response.status_code != 200:
             print(f"Fail to list datasets, code: {res_json['code']}, message: {res_json['message']}")
             return None
-        dataset_list = res_json["data"]["kbs"]
+        dataset_list = res_json["data"]
         dataset_id: str = ""
         for dataset in dataset_list:
             if dataset["name"] == dataset_name:
diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py
index 8a57bcd63..f817de633 100644
--- a/api/apps/kb_app.py
+++ b/api/apps/kb_app.py
@@ -13,7 +13,6 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 #
-import json
 import logging
 import random
 import re
@@ -26,34 +25,29 @@ from api.db.services.connector_service import Connector2KbService
 from api.db.services.llm_service import LLMBundle
 from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
 from api.db.services.doc_metadata_service import DocMetadataService
-from api.db.services.file2document_service import File2DocumentService
-from api.db.services.file_service import FileService
 from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
 from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID
-from api.db.services.user_service import TenantService, UserTenantService
+from api.db.services.user_service import UserTenantService
 from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_model_config_by_id
 from api.utils.api_utils import (
     get_error_data_result,
     server_error_response,
     get_data_error_result,
     validate_request,
-    not_allowed_parameters,
     get_request_json,
 )
-from common.misc_utils import thread_pool_exec
 from api.db import VALID_FILE_TYPES
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.db_models import File
 from api.utils.api_utils import get_json_result
-from api.utils.tenant_utils import ensure_tenant_model_id_for_params
 from rag.nlp import search
-from api.constants import DATASET_NAME_LIMIT
 from rag.utils.redis_conn import REDIS_CONN
-from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
+from common.constants import RetCode, PipelineTaskType, VALID_TASK_STATUS, LLMType
 from common import settings
 from common.doc_store.doc_store_base import OrderByExpr
 from api.apps import login_required, current_user
 
+"""
+Deprecated, todo delete
 @manager.route('/create', methods=['post'])  # noqa: F821
 @login_required
 @validate_request("name")
@@ -186,7 +180,7 @@ async def update():
         return get_json_result(data=kb)
     except Exception as e:
         return server_error_response(e)
-
+"""
 
 @manager.route('/update_metadata_setting', methods=['post'])  # noqa: F821
 @login_required
@@ -234,7 +228,8 @@ def detail():
     except Exception as e:
         return server_error_response(e)
 
-
+"""
+Deprecated, todo delete
 @manager.route('/list', methods=['POST'])  # noqa: F821
 @login_required
 async def list_kbs():
@@ -329,7 +324,7 @@ async def rm():
         return await thread_pool_exec(_rm_sync)
     except Exception as e:
         return server_error_response(e)
-
+"""
 
 @manager.route('/<kb_id>/tags', methods=['GET'])  # noqa: F821
 @login_required
@@ -405,7 +400,8 @@ async def rename_tags(kb_id):
                                              kb_id)
     return get_json_result(data=True)
 
-
+"""
+Deprecated, todo delete
 @manager.route('/<kb_id>/knowledge_graph', methods=['GET'])  # noqa: F821
 @login_required
 async def knowledge_graph(kb_id):
@@ -459,7 +455,7 @@ def delete_knowledge_graph(kb_id):
     settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
                                  search.index_name(kb.tenant_id), kb_id)
     return get_json_result(data=True)
-
+"""
 
 @manager.route("/get_meta", methods=["GET"])  # noqa: F821
 @login_required
@@ -598,6 +594,8 @@ def pipeline_log_detail():
     return get_json_result(data=log.to_dict())
 
 
+"""
+Deprecated, todo delete
 @manager.route("/run_graphrag", methods=["POST"])  # noqa: F821
 @login_required
 async def run_graphrag():
@@ -734,7 +732,7 @@ def trace_raptor():
         return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
     return get_json_result(data=task.to_dict())
-
+"""
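Taken together, the two hunks above are the visible halves of one migration: the admin CLI stops calling the internal web endpoints (`POST /kb/list`, `/kb/create`, `/kb/rm`, now commented out in kb_app.py) and talks to the RESTful `/datasets` routes instead, with `embd_id`/`parser_id` renamed to `embedding_model`/`chunk_method` and the list payload flattened from `data["kbs"]` to `data`. A minimal sketch of the same calls with plain `requests`; the host, port, and bearer-token header are assumptions for illustration, not taken from this patch (RAGFlowClient's `http_client` handles base URL and auth internally):

```python
import requests

BASE = "http://localhost:9380/api/v1"            # assumed host/port and API base
HEADERS = {"Authorization": "Bearer <API_KEY>"}  # assumed auth scheme

# old: POST /kb/list            new: GET /datasets (data is a flat list, not data["kbs"])
datasets = requests.get(f"{BASE}/datasets", headers=HEADERS).json()["data"]

# old: POST /kb/create {"name": ..., "embd_id": ..., "parser_id": ...}
# new: POST /datasets  {"name": ..., "embedding_model": ..., "chunk_method": ...}
# embedding_model/chunk_method are optional; the tenant's defaults apply when omitted.
created = requests.post(f"{BASE}/datasets", headers=HEADERS,
                        json={"name": "demo"}).json()

# old: POST /kb/rm {"kb_id": <id>}   new: DELETE /datasets {"ids": [<id>, ...]}
requests.delete(f"{BASE}/datasets", headers=HEADERS,
                json={"ids": [created["data"]["id"]]})
```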
@manager.route("/run_mindmap", methods=["POST"]) # noqa: F821 @login_required diff --git a/api/apps/restful_apis/dataset_api.py b/api/apps/restful_apis/dataset_api.py new file mode 100644 index 000000000..4f3ff2d59 --- /dev/null +++ b/api/apps/restful_apis/dataset_api.py @@ -0,0 +1,517 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging + +from peewee import OperationalError +from quart import request +from common.constants import RetCode +from api.apps import login_required, current_user +from api.utils.api_utils import get_error_argument_result, get_error_data_result, get_result, add_tenant_id_to_kwargs +from api.utils.validation_utils import ( + CreateDatasetReq, + DeleteDatasetReq, + ListDatasetReq, + UpdateDatasetReq, + validate_and_parse_json_request, + validate_and_parse_request_args, +) +from api.apps.services import dataset_api_service + + +@manager.route("/datasets", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def create(tenant_id: str=None): + """ + Create a new dataset. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Dataset creation parameters. + required: true + schema: + type: object + required: + - name + properties: + name: + type: string + description: Dataset name (required). + avatar: + type: string + description: Optional base64-encoded avatar image. + description: + type: string + description: Optional dataset description. + embedding_model: + type: string + description: Optional embedding model name; if omitted, the tenant's default embedding model is used. + permission: + type: string + enum: ['me', 'team'] + description: Visibility of the dataset (private to me or shared with team). + chunk_method: + type: string + enum: ["naive", "book", "email", "laws", "manual", "one", "paper", + "picture", "presentation", "qa", "table", "tag"] + description: Chunking method; if omitted, defaults to "naive". + parser_config: + type: object + description: Optional parser configuration; server-side defaults will be applied. + responses: + 200: + description: Successful operation. + schema: + type: object + properties: + data: + type: object + """ + req, err = await validate_and_parse_json_request(request, CreateDatasetReq) + if err is not None: + return get_error_argument_result(err) + + try: + if not tenant_id: + tenant_id = current_user.id + success, result = await dataset_api_service.create_dataset(tenant_id, req) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets", methods=["DELETE"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def delete(tenant_id): + """ + Delete datasets. 
+ --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Dataset deletion parameters. + required: true + schema: + type: object + required: + - ids + properties: + ids: + type: array or null + items: + type: string + description: | + Specifies the datasets to delete: + - If `null`, all datasets will be deleted. + - If an array of IDs, only the specified datasets will be deleted. + - If an empty array, no datasets will be deleted. + responses: + 200: + description: Successful operation. + schema: + type: object + """ + req, err = await validate_and_parse_json_request(request, DeleteDatasetReq) + if err is not None: + return get_error_argument_result(err) + + try: + success, result = await dataset_api_service.delete_datasets(tenant_id, req.get("ids"), req.get("delete_all", False)) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets/", methods=["PUT"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def update(tenant_id, dataset_id): + """ + Update a dataset. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset to update. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Dataset update parameters. + required: true + schema: + type: object + properties: + name: + type: string + description: New name of the dataset. + avatar: + type: string + description: Updated base64 encoding of the avatar. + description: + type: string + description: Updated description of the dataset. + embedding_model: + type: string + description: Updated embedding model Name. + permission: + type: string + enum: ['me', 'team'] + description: Updated dataset permission. + chunk_method: + type: string + enum: ["naive", "book", "email", "laws", "manual", "one", "paper", + "picture", "presentation", "qa", "table", "tag" + ] + description: Updated chunking method. + pagerank: + type: integer + description: Updated page rank. + parser_config: + type: object + description: Updated parser configuration. + responses: + 200: + description: Successful operation. 
+ schema: + type: object + """ + # Field name transformations during model dump: + # | Original | Dump Output | + # |----------------|-------------| + # | embedding_model| embd_id | + # | chunk_method | parser_id | + extras = {"dataset_id": dataset_id} + req, err = await validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True) + if err is not None: + return get_error_argument_result(err) + + try: + success, result = await dataset_api_service.update_dataset(tenant_id, dataset_id, req) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def list_datasets(tenant_id): + """ + List datasets. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: query + name: id + type: string + required: false + description: Dataset ID to filter. + - in: query + name: name + type: string + required: false + description: Dataset name to filter. + - in: query + name: page + type: integer + required: false + default: 1 + description: Page number. + - in: query + name: page_size + type: integer + required: false + default: 30 + description: Number of items per page. + - in: query + name: orderby + type: string + required: false + default: "create_time" + description: Field to order by. + - in: query + name: desc + type: boolean + required: false + default: true + description: Order in descending. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Successful operation. 
+ schema: + type: array + items: + type: object + """ + args, err = validate_and_parse_request_args(request, ListDatasetReq) + if err is not None: + return get_error_argument_result(err) + + try: + success, result = dataset_api_service.list_datasets(tenant_id, args) + if success: + return get_result(data=result.get("data"), total=result.get("total")) + else: + return get_error_data_result(message=result) + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route('/datasets//knowledge_graph', methods=['GET']) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def knowledge_graph(tenant_id, dataset_id): + try: + success, result = await dataset_api_service.get_knowledge_graph(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_result( + data=False, + message=result, + code=RetCode.AUTHENTICATION_ERROR + ) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route('/datasets//knowledge_graph', methods=['DELETE']) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def delete_knowledge_graph(tenant_id, dataset_id): + try: + success, result = dataset_api_service.delete_knowledge_graph(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_result( + data=False, + message=result, + code=RetCode.AUTHENTICATION_ERROR + ) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//run_graphrag", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def run_graphrag(tenant_id, dataset_id): + try: + success, result = dataset_api_service.run_graphrag(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//trace_graphrag", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def trace_graphrag(tenant_id, dataset_id): + try: + success, result = dataset_api_service.trace_graphrag(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//run_raptor", methods=["POST"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def run_raptor(tenant_id, dataset_id): + try: + success, result = dataset_api_service.run_raptor(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//trace_raptor", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def trace_raptor(tenant_id, dataset_id): + try: + success, result = dataset_api_service.trace_raptor(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") 
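Every route in the new dataset_api.py follows the same shape: validate the request, delegate to dataset_api_service, then unwrap the service's `(success, result)` tuple, where `result` is the payload on success and an error message on failure. A sketch of that contract in isolation (the `unwrap` helper is hypothetical, not part of the patch):

```python
from typing import Any, Callable, Tuple

# The convention the new service layer follows:
# (True, payload) on success, (False, "error message") on failure.
ServiceOutcome = Tuple[bool, Any]

def unwrap(outcome: ServiceOutcome,
           ok: Callable[..., Any],
           fail: Callable[..., Any]) -> Any:
    """Map a (success, result) service tuple onto the HTTP response envelope."""
    success, result = outcome
    return ok(data=result) if success else fail(message=result)

# e.g. in a handler:
#     return unwrap(dataset_api_service.trace_raptor(dataset_id, tenant_id),
#                   ok=get_result, fail=get_error_data_result)
```

One detail worth noting while reviewing: `run_graphrag` and `run_raptor` are declared `async def` in the route layer but call synchronous service functions without `await`; that works under Quart, but the `async` qualifier is cosmetic here.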
+ + +@manager.route("/datasets//auto_metadata", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def get_auto_metadata(tenant_id, dataset_id): + """ + Get auto-metadata configuration for a dataset. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Successful operation. + schema: + type: object + """ + try: + success, result = dataset_api_service.get_auto_metadata(dataset_id, tenant_id) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") + + +@manager.route("/datasets//auto_metadata", methods=["PUT"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def update_auto_metadata(tenant_id, dataset_id): + """ + Update auto-metadata configuration for a dataset. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Auto-metadata configuration. + required: true + schema: + type: object + responses: + 200: + description: Successful operation. + schema: + type: object + """ + from api.utils.validation_utils import AutoMetadataConfig + cfg, err = await validate_and_parse_json_request(request, AutoMetadataConfig) + if err is not None: + return get_error_argument_result(err) + + try: + success, result = await dataset_api_service.update_auto_metadata(dataset_id, tenant_id, cfg) + if success: + return get_result(data=result) + else: + return get_error_data_result(message=result) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Internal server error") diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py index 79a85d631..672adde6e 100644 --- a/api/apps/restful_apis/memory_api.py +++ b/api/apps/restful_apis/memory_api.py @@ -1,5 +1,5 @@ # -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py deleted file mode 100644 index bda151671..000000000 --- a/api/apps/sdk/dataset.py +++ /dev/null @@ -1,798 +0,0 @@ -# -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - - -import logging -import os -import json -from quart import request -from peewee import OperationalError -from api.db.db_models import File -from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks -from api.db.services.file2document_service import File2DocumentService -from api.db.services.file_service import FileService -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService -from api.db.services.user_service import TenantService -from common.constants import RetCode, FileSource, StatusEnum -from api.utils.api_utils import ( - deep_merge, - get_error_argument_result, - get_error_data_result, - get_error_permission_result, - get_parser_config, - get_result, - remap_dictionary_keys, - token_required, - verify_embedding_availability, -) -from api.utils.validation_utils import ( - AutoMetadataConfig, - CreateDatasetReq, - DeleteDatasetReq, - ListDatasetReq, - UpdateDatasetReq, - validate_and_parse_json_request, - validate_and_parse_request_args, -) -from rag.nlp import search -from common.constants import PAGERANK_FLD -from common import settings - - -@manager.route("/datasets", methods=["POST"]) # noqa: F821 -@token_required -async def create(tenant_id): - """ - Create a new dataset. - --- - tags: - - Datasets - security: - - ApiKeyAuth: [] - parameters: - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - - in: body - name: body - description: Dataset creation parameters. - required: true - schema: - type: object - required: - - name - properties: - name: - type: string - description: Dataset name (required). - avatar: - type: string - description: Optional base64-encoded avatar image. - description: - type: string - description: Optional dataset description. - embedding_model: - type: string - description: Optional embedding model name; if omitted, the tenant's default embedding model is used. - permission: - type: string - enum: ['me', 'team'] - description: Visibility of the dataset (private to me or shared with team). - chunk_method: - type: string - enum: ["naive", "book", "email", "laws", "manual", "one", "paper", - "picture", "presentation", "qa", "table", "tag"] - description: Chunking method; if omitted, defaults to "naive". - parser_config: - type: object - description: Optional parser configuration; server-side defaults will be applied. - responses: - 200: - description: Successful operation. 
- schema: - type: object - properties: - data: - type: object - """ - # Field name transformations during model dump: - # | Original | Dump Output | - # |----------------|-------------| - # | embedding_model| embd_id | - # | chunk_method | parser_id | - - req, err = await validate_and_parse_json_request(request, CreateDatasetReq) - if err is not None: - return get_error_argument_result(err) - # Map auto_metadata_config (if provided) into parser_config structure - auto_meta = req.pop("auto_metadata_config", None) - if auto_meta: - parser_cfg = req.get("parser_config") or {} - fields = [] - for f in auto_meta.get("fields", []): - fields.append( - { - "name": f.get("name", ""), - "type": f.get("type", ""), - "description": f.get("description"), - "examples": f.get("examples"), - "restrict_values": f.get("restrict_values", False), - } - ) - parser_cfg["metadata"] = fields - parser_cfg["enable_metadata"] = auto_meta.get("enabled", True) - req["parser_config"] = parser_cfg - e, req = KnowledgebaseService.create_with_name(name=req.pop("name", None), tenant_id=tenant_id, parser_id=req.pop("parser_id", None), **req) - - if not e: - return req - - # Insert embedding model(embd id) - ok, t = TenantService.get_by_id(tenant_id) - if not ok: - return get_error_permission_result(message="Tenant not found") - if not req.get("embd_id"): - req["embd_id"] = t.embd_id - else: - ok, err = verify_embedding_availability(req["embd_id"], tenant_id) - if not ok: - return err - - try: - if not KnowledgebaseService.save(**req): - return get_error_data_result() - ok, k = KnowledgebaseService.get_by_id(req["id"]) - if not ok: - return get_error_data_result(message="Dataset created failed") - response_data = remap_dictionary_keys(k.to_dict()) - return get_result(data=response_data) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - - -@manager.route("/datasets", methods=["DELETE"]) # noqa: F821 -@token_required -async def delete(tenant_id): - """ - Delete datasets. - --- - tags: - - Datasets - security: - - ApiKeyAuth: [] - parameters: - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - - in: body - name: body - description: Dataset deletion parameters. - required: true - schema: - type: object - required: - - ids - properties: - ids: - type: array or null - items: - type: string - description: | - List of dataset IDs to delete. - If `null` or an empty array is provided, no datasets will be deleted - unless `delete_all` is set to `true`. - delete_all: - type: boolean - description: | - If `true` and `ids` is null or empty, delete all datasets owned by the current user. - Defaults to `false`. - responses: - 200: - description: Successful operation. 
- schema: - type: object - """ - req, err = await validate_and_parse_json_request(request, DeleteDatasetReq) - if err is not None: - return get_error_argument_result(err) - - try: - kb_id_instance_pairs = [] - if req["ids"] is None or len(req["ids"]) == 0: - if req.get("delete_all"): - req["ids"] = [kb.id for kb in KnowledgebaseService.query(tenant_id=tenant_id)] - if not req["ids"]: - return get_result() - else: - return get_result() - - error_kb_ids = [] - for kb_id in req["ids"]: - kb = KnowledgebaseService.get_or_none(id=kb_id, tenant_id=tenant_id) - if kb is None: - error_kb_ids.append(kb_id) - continue - kb_id_instance_pairs.append((kb_id, kb)) - if len(error_kb_ids) > 0: - return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") - - errors = [] - success_count = 0 - for kb_id, kb in kb_id_instance_pairs: - for doc in DocumentService.query(kb_id=kb_id): - if not DocumentService.remove_document(doc, tenant_id): - errors.append(f"Remove document '{doc.id}' error for dataset '{kb_id}'") - continue - f2d = File2DocumentService.get_by_document_id(doc.id) - FileService.filter_delete( - [ - File.source_type == FileSource.KNOWLEDGEBASE, - File.id == f2d[0].file_id, - ] - ) - File2DocumentService.delete_by_document_id(doc.id) - FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) - - # Drop index for this dataset - try: - from rag.nlp import search - - idxnm = search.index_name(kb.tenant_id) - settings.docStoreConn.delete_idx(idxnm, kb_id) - except Exception as e: - logging.warning(f"Failed to drop index for dataset {kb_id}: {e}") - - if not KnowledgebaseService.delete_by_id(kb_id): - errors.append(f"Delete dataset error for {kb_id}") - continue - success_count += 1 - - if not errors: - return get_result() - - error_message = f"Successfully deleted {success_count} datasets, {len(errors)} failed. Details: {'; '.join(errors)[:128]}..." - if success_count == 0: - return get_error_data_result(message=error_message) - - return get_result(data={"success_count": success_count, "errors": errors[:5]}, message=error_message) - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - - -@manager.route("/datasets/", methods=["PUT"]) # noqa: F821 -@token_required -async def update(tenant_id, dataset_id): - """ - Update a dataset. - --- - tags: - - Datasets - security: - - ApiKeyAuth: [] - parameters: - - in: path - name: dataset_id - type: string - required: true - description: ID of the dataset to update. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - - in: body - name: body - description: Dataset update parameters. - required: true - schema: - type: object - properties: - name: - type: string - description: New name of the dataset. - avatar: - type: string - description: Updated base64 encoding of the avatar. - description: - type: string - description: Updated description of the dataset. - embedding_model: - type: string - description: Updated embedding model Name. - permission: - type: string - enum: ['me', 'team'] - description: Updated dataset permission. - chunk_method: - type: string - enum: ["naive", "book", "email", "laws", "manual", "one", "paper", - "picture", "presentation", "qa", "table", "tag" - ] - description: Updated chunking method. - pagerank: - type: integer - description: Updated page rank. 
- parser_config: - type: object - description: Updated parser configuration. - responses: - 200: - description: Successful operation. - schema: - type: object - """ - # Field name transformations during model dump: - # | Original | Dump Output | - # |----------------|-------------| - # | embedding_model| embd_id | - # | chunk_method | parser_id | - extras = {"dataset_id": dataset_id} - req, err = await validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True) - if err is not None: - return get_error_argument_result(err) - - if not req: - return get_error_argument_result(message="No properties were modified") - - try: - kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) - if kb is None: - return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") - - # Map auto_metadata_config into parser_config if present - auto_meta = req.pop("auto_metadata_config", None) - if auto_meta: - parser_cfg = req.get("parser_config") or {} - fields = [] - for f in auto_meta.get("fields", []): - fields.append( - { - "name": f.get("name", ""), - "type": f.get("type", ""), - "description": f.get("description"), - "examples": f.get("examples"), - "restrict_values": f.get("restrict_values", False), - } - ) - parser_cfg["metadata"] = fields - parser_cfg["enable_metadata"] = auto_meta.get("enabled", True) - req["parser_config"] = parser_cfg - - if req.get("parser_config"): - req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"]) - - if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id: - if not req.get("parser_config"): - req["parser_config"] = get_parser_config(chunk_method, None) - elif "parser_config" in req and not req["parser_config"]: - del req["parser_config"] - - if "name" in req and req["name"].lower() != kb.name.lower(): - exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value) - if exists: - return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") - - if "embd_id" in req: - if not req["embd_id"]: - req["embd_id"] = kb.embd_id - if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: - return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") - ok, err = verify_embedding_availability(req["embd_id"], tenant_id) - if not ok: - return err - - if "pagerank" in req and req["pagerank"] != kb.pagerank: - if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity": - return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch") - - if req["pagerank"] > 0: - settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) - else: - # Elasticsearch requires PAGERANK_FLD be non-zero! 
- settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) - - if not KnowledgebaseService.update_by_id(kb.id, req): - return get_error_data_result(message="Update dataset error.(Database error)") - - ok, k = KnowledgebaseService.get_by_id(kb.id) - if not ok: - return get_error_data_result(message="Dataset created failed") - - response_data = remap_dictionary_keys(k.to_dict()) - return get_result(data=response_data) - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - - -@manager.route("/datasets", methods=["GET"]) # noqa: F821 -@token_required -def list_datasets(tenant_id): - """ - List datasets. - --- - tags: - - Datasets - security: - - ApiKeyAuth: [] - parameters: - - in: query - name: id - type: string - required: false - description: Dataset ID to filter. - - in: query - name: name - type: string - required: false - description: Dataset name to filter. - - in: query - name: page - type: integer - required: false - default: 1 - description: Page number. - - in: query - name: page_size - type: integer - required: false - default: 30 - description: Number of items per page. - - in: query - name: orderby - type: string - required: false - default: "create_time" - description: Field to order by. - - in: query - name: desc - type: boolean - required: false - default: true - description: Order in descending. - - in: query - name: include_parsing_status - type: boolean - required: false - default: false - description: | - Whether to include document parsing status counts in the response. - When true, each dataset object will include: unstart_count, running_count, - cancel_count, done_count, and fail_count. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - responses: - 200: - description: Successful operation. 
- schema: - type: array - items: - type: object - """ - args, err = validate_and_parse_request_args(request, ListDatasetReq) - if err is not None: - return get_error_argument_result(err) - - include_parsing_status = args.get("include_parsing_status", False) - - try: - kb_id = request.args.get("id") - name = args.get("name") - # check whether user has permission for the dataset with specified id - if kb_id: - if not KnowledgebaseService.get_kb_by_id(kb_id, tenant_id): - return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'") - # check whether user has permission for the dataset with specified name - if name: - if not KnowledgebaseService.get_kb_by_name(name, tenant_id): - return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'") - - tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) - kbs, total = KnowledgebaseService.get_list( - [m["tenant_id"] for m in tenants], - tenant_id, - args["page"], - args["page_size"], - args["orderby"], - args["desc"], - kb_id, - name, - ) - - parsing_status_map = {} - if include_parsing_status and kbs: - kb_ids = [kb["id"] for kb in kbs] - parsing_status_map = DocumentService.get_parsing_status_by_kb_ids(kb_ids) - - response_data_list = [] - for kb in kbs: - data = remap_dictionary_keys(kb) - if include_parsing_status: - data.update(parsing_status_map.get(kb["id"], {})) - response_data_list.append(data) - return get_result(data=response_data_list, total=total) - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - - -@manager.route("/datasets//auto_metadata", methods=["GET"]) # noqa: F821 -@token_required -def get_auto_metadata(tenant_id, dataset_id): - """ - Get auto-metadata configuration for a dataset. - """ - try: - kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) - if kb is None: - return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") - - parser_cfg = kb.parser_config or {} - metadata = parser_cfg.get("metadata") or [] - enabled = parser_cfg.get("enable_metadata", bool(metadata)) - # Normalize to AutoMetadataConfig-like JSON - fields = [] - for f in metadata: - if not isinstance(f, dict): - continue - fields.append( - { - "name": f.get("name", ""), - "type": f.get("type", ""), - "description": f.get("description"), - "examples": f.get("examples"), - "restrict_values": f.get("restrict_values", False), - } - ) - return get_result(data={"enabled": enabled, "fields": fields}) - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - - -@manager.route("/datasets//auto_metadata", methods=["PUT"]) # noqa: F821 -@token_required -async def update_auto_metadata(tenant_id, dataset_id): - """ - Update auto-metadata configuration for a dataset. 
- """ - cfg, err = await validate_and_parse_json_request(request, AutoMetadataConfig) - if err is not None: - return get_error_argument_result(err) - - try: - kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) - if kb is None: - return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") - - parser_cfg = kb.parser_config or {} - fields = [] - for f in cfg.get("fields", []): - fields.append( - { - "name": f.get("name", ""), - "type": f.get("type", ""), - "description": f.get("description"), - "examples": f.get("examples"), - "restrict_values": f.get("restrict_values", False), - } - ) - parser_cfg["metadata"] = fields - parser_cfg["enable_metadata"] = cfg.get("enabled", True) - - if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": parser_cfg}): - return get_error_data_result(message="Update auto-metadata error.(Database error)") - - return get_result(data={"enabled": parser_cfg["enable_metadata"], "fields": fields}) - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - - -@manager.route("/datasets//knowledge_graph", methods=["GET"]) # noqa: F821 -@token_required -async def knowledge_graph(tenant_id, dataset_id): - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - _, kb = KnowledgebaseService.get_by_id(dataset_id) - req = {"kb_id": [dataset_id], "knowledge_graph_kwd": ["graph"]} - - obj = {"graph": {}, "mind_map": {}} - if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id): - return get_result(data=obj) - sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id]) - if not len(sres.ids): - return get_result(data=obj) - - for id in sres.ids[:1]: - ty = sres.field[id]["knowledge_graph_kwd"] - try: - content_json = json.loads(sres.field[id]["content_with_weight"]) - except Exception: - continue - - obj[ty] = content_json - - if "nodes" in obj["graph"]: - obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256] - if "edges" in obj["graph"]: - node_id_set = {o["id"] for o in obj["graph"]["nodes"]} - filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set] - obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128] - return get_result(data=obj) - - -@manager.route("/datasets//knowledge_graph", methods=["DELETE"]) # noqa: F821 -@token_required -def delete_knowledge_graph(tenant_id, dataset_id): - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - _, kb = KnowledgebaseService.get_by_id(dataset_id) - settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id) - - return get_result(data=True) - - -@manager.route("/datasets//run_graphrag", methods=["POST"]) # noqa: F821 -@token_required -def run_graphrag(tenant_id, dataset_id): - if not dataset_id: - return get_error_data_result(message='Lack of "Dataset ID"') - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - - ok, kb = 
KnowledgebaseService.get_by_id(dataset_id) - if not ok: - return get_error_data_result(message="Invalid Dataset ID") - - task_id = kb.graphrag_task_id - if task_id: - ok, task = TaskService.get_by_id(task_id) - if not ok: - logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}") - - if task and task.progress not in [-1, 1]: - return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.") - - documents, _ = DocumentService.get_by_kb_id( - kb_id=dataset_id, - page_number=0, - items_per_page=0, - orderby="create_time", - desc=False, - keywords="", - run_status=[], - types=[], - suffix=[], - ) - if not documents: - return get_error_data_result(message=f"No documents in Dataset {dataset_id}") - - sample_document = documents[0] - document_ids = [document["id"] for document in documents] - - task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) - - if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}): - logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}") - - return get_result(data={"graphrag_task_id": task_id}) - - -@manager.route("/datasets//trace_graphrag", methods=["GET"]) # noqa: F821 -@token_required -def trace_graphrag(tenant_id, dataset_id): - if not dataset_id: - return get_error_data_result(message='Lack of "Dataset ID"') - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - - ok, kb = KnowledgebaseService.get_by_id(dataset_id) - if not ok: - return get_error_data_result(message="Invalid Dataset ID") - - task_id = kb.graphrag_task_id - if not task_id: - return get_result(data={}) - - ok, task = TaskService.get_by_id(task_id) - if not ok: - return get_result(data={}) - - return get_result(data=task.to_dict()) - - -@manager.route("/datasets//run_raptor", methods=["POST"]) # noqa: F821 -@token_required -def run_raptor(tenant_id, dataset_id): - if not dataset_id: - return get_error_data_result(message='Lack of "Dataset ID"') - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - - ok, kb = KnowledgebaseService.get_by_id(dataset_id) - if not ok: - return get_error_data_result(message="Invalid Dataset ID") - - task_id = kb.raptor_task_id - if task_id: - ok, task = TaskService.get_by_id(task_id) - if not ok: - logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}") - - if task and task.progress not in [-1, 1]: - return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. 
A RAPTOR Task is already running.") - - documents, _ = DocumentService.get_by_kb_id( - kb_id=dataset_id, - page_number=0, - items_per_page=0, - orderby="create_time", - desc=False, - keywords="", - run_status=[], - types=[], - suffix=[], - ) - if not documents: - return get_error_data_result(message=f"No documents in Dataset {dataset_id}") - - sample_document = documents[0] - document_ids = [document["id"] for document in documents] - - task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) - - if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}): - logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}") - - return get_result(data={"raptor_task_id": task_id}) - - -@manager.route("/datasets//trace_raptor", methods=["GET"]) # noqa: F821 -@token_required -def trace_raptor(tenant_id, dataset_id): - if not dataset_id: - return get_error_data_result(message='Lack of "Dataset ID"') - - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return get_result( - data=False, - message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR - ) - ok, kb = KnowledgebaseService.get_by_id(dataset_id) - if not ok: - return get_error_data_result(message="Invalid Dataset ID") - - task_id = kb.raptor_task_id - if not task_id: - return get_result(data={}) - - ok, task = TaskService.get_by_id(task_id) - if not ok: - return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred") - - return get_result(data=task.to_dict()) diff --git a/api/apps/services/dataset_api_service.py b/api/apps/services/dataset_api_service.py new file mode 100644 index 000000000..094570528 --- /dev/null +++ b/api/apps/services/dataset_api_service.py @@ -0,0 +1,613 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +import json +import os +from common.constants import PAGERANK_FLD +from common import settings +from api.db.db_models import File +from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks +from api.db.services.file2document_service import File2DocumentService +from api.db.services.file_service import FileService +from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.connector_service import Connector2KbService +from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService +from api.db.services.user_service import TenantService, UserService +from common.constants import FileSource, StatusEnum +from api.utils.api_utils import deep_merge, get_parser_config, remap_dictionary_keys, verify_embedding_availability + + +async def create_dataset(tenant_id: str, req: dict): + """ + Create a new dataset. 
+ + :param tenant_id: tenant ID + :param req: dataset creation request + :return: (success, result) or (success, error_message) + """ + # Extract ext field for additional parameters + ext_fields = req.pop("ext", {}) + + # Map auto_metadata_config (if provided) into parser_config structure + auto_meta = req.pop("auto_metadata_config", {}) + if auto_meta: + parser_cfg = req.get("parser_config") or {} + fields = [] + for f in auto_meta.get("fields", []): + fields.append( + { + "name": f.get("name", ""), + "type": f.get("type", ""), + "description": f.get("description"), + "examples": f.get("examples"), + "restrict_values": f.get("restrict_values", False), + } + ) + parser_cfg["metadata"] = fields + parser_cfg["enable_metadata"] = auto_meta.get("enabled", True) + req["parser_config"] = parser_cfg + req.update(ext_fields) + + e, create_dict = KnowledgebaseService.create_with_name( + name=req.pop("name", None), + tenant_id=tenant_id, + parser_id=req.pop("parser_id", None), + **req + ) + + if not e: + return False, create_dict + + # Insert embedding model(embd id) + ok, t = TenantService.get_by_id(tenant_id) + if not ok: + return False, "Tenant not found" + if not create_dict.get("embd_id"): + create_dict["embd_id"] = t.embd_id + else: + ok, err = verify_embedding_availability(create_dict["embd_id"], tenant_id) + if not ok: + return False, err + + if not KnowledgebaseService.save(**create_dict): + return False, "Failed to save dataset" + ok, k = KnowledgebaseService.get_by_id(create_dict["id"]) + if not ok: + return False, "Dataset created failed" + response_data = remap_dictionary_keys(k.to_dict()) + return True, response_data + + +async def delete_datasets(tenant_id: str, ids: list = None, delete_all: bool = False): + """ + Delete datasets. + + :param tenant_id: tenant ID + :param ids: list of dataset IDs + :param delete_all: whether to delete all datasets of the tenant (if ids is not provided) + :return: (success, result) or (success, error_message) + """ + kb_id_instance_pairs = [] + if not ids: + if not delete_all: + return True, {"success_count": 0} + else: + ids = [kb.id for kb in KnowledgebaseService.query(tenant_id=tenant_id)] + + error_kb_ids = [] + for kb_id in ids: + kb = KnowledgebaseService.get_or_none(id=kb_id, tenant_id=tenant_id) + if kb is None: + error_kb_ids.append(kb_id) + continue + kb_id_instance_pairs.append((kb_id, kb)) + if len(error_kb_ids) > 0: + return False, f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""" + + errors = [] + success_count = 0 + for kb_id, kb in kb_id_instance_pairs: + for doc in DocumentService.query(kb_id=kb_id): + if not DocumentService.remove_document(doc, tenant_id): + errors.append(f"Remove document '{doc.id}' error for dataset '{kb_id}'") + continue + f2d = File2DocumentService.get_by_document_id(doc.id) + FileService.filter_delete( + [ + File.source_type == FileSource.KNOWLEDGEBASE, + File.id == f2d[0].file_id, + ] + ) + File2DocumentService.delete_by_document_id(doc.id) + FileService.filter_delete( + [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) + + # Drop index for this dataset + try: + from rag.nlp import search + idxnm = search.index_name(kb.tenant_id) + settings.docStoreConn.delete_idx(idxnm, kb_id) + except Exception as e: + errors.append(f"Failed to drop index for dataset {kb_id}: {e}") + + if not KnowledgebaseService.delete_by_id(kb_id): + errors.append(f"Delete dataset error for {kb_id}") + continue + success_count += 1 + + if not errors: + return True, 
{"success_count": success_count} + + error_message = f"Successfully deleted {success_count} datasets, {len(errors)} failed. Details: {'; '.join(errors)[:128]}..." + if success_count == 0: + return False, error_message + + return True, {"success_count": success_count, "errors": errors[:5]} + + +async def update_dataset(tenant_id: str, dataset_id: str, req: dict): + """ + Update a dataset. + + :param tenant_id: tenant ID + :param dataset_id: dataset ID + :param req: dataset update request + :return: (success, result) or (success, error_message) + """ + if not req: + return False, "No properties were modified" + + kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) + if kb is None: + return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" + + # Extract ext field for additional parameters + ext_fields = req.pop("ext", {}) + + # Map auto_metadata_config into parser_config if present + auto_meta = req.pop("auto_metadata_config", {}) + if auto_meta: + parser_cfg = req.get("parser_config") or {} + fields = [] + for f in auto_meta.get("fields", []): + fields.append( + { + "name": f.get("name", ""), + "type": f.get("type", ""), + "description": f.get("description"), + "examples": f.get("examples"), + "restrict_values": f.get("restrict_values", False), + } + ) + parser_cfg["metadata"] = fields + parser_cfg["enable_metadata"] = auto_meta.get("enabled", True) + req["parser_config"] = parser_cfg + + # Merge ext fields with req + req.update(ext_fields) + + # Extract connectors from request + connectors = [] + if "connectors" in req: + connectors = req["connectors"] + del req["connectors"] + + if req.get("parser_config"): + parser_config = req["parser_config"] + req_ext_fields = parser_config.pop("ext", {}) + parser_config.update(req_ext_fields) + req["parser_config"] = deep_merge(kb.parser_config, parser_config) + + if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id: + if not req.get("parser_config"): + req["parser_config"] = get_parser_config(chunk_method, None) + elif "parser_config" in req and not req["parser_config"]: + del req["parser_config"] + + if "name" in req and req["name"].lower() != kb.name.lower(): + exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, + status=StatusEnum.VALID.value) + if exists: + return False, f"Dataset name '{req['name']}' already exists" + + if "embd_id" in req: + if not req["embd_id"]: + req["embd_id"] = kb.embd_id + if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: + return False, f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}" + ok, err = verify_embedding_availability(req["embd_id"], tenant_id) + if not ok: + return False, err + + if "pagerank" in req and req["pagerank"] != kb.pagerank: + if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity": + return False, "'pagerank' can only be set when doc_engine is elasticsearch" + + if req["pagerank"] > 0: + from rag.nlp import search + settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, + search.index_name(kb.tenant_id), kb.id) + else: + # Elasticsearch requires PAGERANK_FLD be non-zero! 
+ from rag.nlp import search + settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, + search.index_name(kb.tenant_id), kb.id) + + if not KnowledgebaseService.update_by_id(kb.id, req): + return False, "Update dataset error.(Database error)" + + ok, k = KnowledgebaseService.get_by_id(kb.id) + if not ok: + return False, "Dataset updated failed" + + # Link connectors to the dataset + errors = Connector2KbService.link_connectors(kb.id, [conn for conn in connectors], tenant_id) + if errors: + logging.error("Link KB errors: %s", errors) + + response_data = remap_dictionary_keys(k.to_dict()) + response_data["connectors"] = connectors + return True, response_data + + +def list_datasets(tenant_id: str, args: dict): + """ + List datasets. + + :param tenant_id: tenant ID + :param args: query arguments + :return: (success, result) or (success, error_message) + """ + kb_id = args.get("id") + name = args.get("name") + page = int(args.get("page", 1)) + page_size = int(args.get("page_size", 30)) + ext_fields = args.get("ext", {}) + parser_id = ext_fields.get("parser_id") + keywords = ext_fields.get("keywords", "") + orderby = args.get("orderby", "create_time") + desc_arg = args.get("desc", "true") + if isinstance(desc_arg, str): + desc = desc_arg.lower() != "false" + elif isinstance(desc_arg, bool): + desc = desc_arg + else: + # unknown type, default to True + desc = True + + if kb_id: + kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id) + if not kbs: + return False, f"User '{tenant_id}' lacks permission for dataset '{kb_id}'" + if name: + kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id) + if not kbs: + return False, f"User '{tenant_id}' lacks permission for dataset '{name}'" + if ext_fields.get("owner_ids", []): + tenant_ids = ext_fields["owner_ids"] + else: + tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) + tenant_ids = [m["tenant_id"] for m in tenants] + kbs, total = KnowledgebaseService.get_list( + tenant_ids, + tenant_id, + page, + page_size, + orderby, + desc, + kb_id, + name, + keywords, + parser_id + ) + users = UserService.get_by_ids([m["tenant_id"] for m in kbs]) + user_map = {m.id: m.to_dict() for m in users} + response_data_list = [] + for kb in kbs: + user_dict = user_map.get(kb["tenant_id"], {}) + kb.update({ + "nickname": user_dict.get("nickname", ""), + "tenant_avatar": user_dict.get("avatar", "") + }) + response_data_list.append(remap_dictionary_keys(kb)) + return True, {"data": response_data_list, "total": total} + + +async def get_knowledge_graph(dataset_id: str, tenant_id: str): + """ + Get knowledge graph for a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :return: (success, result) or (success, error_message) + """ + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return False, "No authorization." 
+    _, kb = KnowledgebaseService.get_by_id(dataset_id)
+
+    req = {
+        "kb_id": [dataset_id],
+        "knowledge_graph_kwd": ["graph"]
+    }
+
+    obj = {"graph": {}, "mind_map": {}}
+    from rag.nlp import search
+    if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
+        return True, obj
+    sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
+    if not sres.ids:
+        return True, obj
+
+    for chunk_id in sres.ids[:1]:
+        ty = sres.field[chunk_id]["knowledge_graph_kwd"]
+        try:
+            content_json = json.loads(sres.field[chunk_id]["content_with_weight"])
+        except Exception:
+            continue
+
+        obj[ty] = content_json
+
+    if "nodes" in obj["graph"]:
+        obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256]
+    if "edges" in obj["graph"]:
+        node_id_set = {o["id"] for o in obj["graph"]["nodes"]}
+        filtered_edges = [o for o in obj["graph"]["edges"] if
+                          o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set]
+        obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
+    return True, obj
+
+
+def delete_knowledge_graph(dataset_id: str, tenant_id: str):
+    """
+    Delete knowledge graph for a dataset.
+
+    :param dataset_id: dataset ID
+    :param tenant_id: tenant ID
+    :return: (success, result) or (success, error_message)
+    """
+    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+        return False, "No authorization."
+    _, kb = KnowledgebaseService.get_by_id(dataset_id)
+    from rag.nlp import search
+    settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]},
+                                 search.index_name(kb.tenant_id), dataset_id)
+
+    return True, True
+
+
+def run_graphrag(dataset_id: str, tenant_id: str):
+    """
+    Run GraphRAG for a dataset.
+
+    :param dataset_id: dataset ID
+    :param tenant_id: tenant ID
+    :return: (success, result) or (success, error_message)
+    """
+    if not dataset_id:
+        return False, 'Lack of "Dataset ID"'
+    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+        return False, "No authorization."
+
+    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
+    if not ok:
+        return False, "Invalid Dataset ID"
+
+    task_id = kb.graphrag_task_id
+    if task_id:
+        ok, task = TaskService.get_by_id(task_id)
+        if not ok:
+            logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}")
+
+        if task and task.progress not in [-1, 1]:
+            return False, f"Task {task_id} in progress with status {task.progress}. A GraphRAG task is already running."
+
+    documents, _ = DocumentService.get_by_kb_id(
+        kb_id=dataset_id,
+        page_number=0,
+        items_per_page=0,
+        orderby="create_time",
+        desc=False,
+        keywords="",
+        run_status=[],
+        types=[],
+        suffix=[],
+    )
+    if not documents:
+        return False, f"No documents in Dataset {dataset_id}"
+
+    sample_document = documents[0]
+    document_ids = [document["id"] for document in documents]
+
+    task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=document_ids)
+
+    if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
+        logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}")
+
+    return True, {"graphrag_task_id": task_id}
+
+
+def trace_graphrag(dataset_id: str, tenant_id: str):
+    """
+    Trace GraphRAG task for a dataset.
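+    Returns an empty dict rather than an error when no GraphRAG task has been queued yet.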
+
+    :param dataset_id: dataset ID
+    :param tenant_id: tenant ID
+    :return: (success, result) or (success, error_message)
+    """
+    if not dataset_id:
+        return False, 'Lack of "Dataset ID"'
+    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+        return False, "No authorization."
+
+    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
+    if not ok:
+        return False, "Invalid Dataset ID"
+
+    task_id = kb.graphrag_task_id
+    if not task_id:
+        return True, {}
+
+    ok, task = TaskService.get_by_id(task_id)
+    if not ok:
+        return True, {}
+
+    return True, task.to_dict()
+
+
+def run_raptor(dataset_id: str, tenant_id: str):
+    """
+    Run RAPTOR for a dataset.
+
+    :param dataset_id: dataset ID
+    :param tenant_id: tenant ID
+    :return: (success, result) or (success, error_message)
+    """
+    if not dataset_id:
+        return False, 'Lack of "Dataset ID"'
+    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+        return False, "No authorization."
+
+    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
+    if not ok:
+        return False, "Invalid Dataset ID"
+
+    task_id = kb.raptor_task_id
+    if task_id:
+        ok, task = TaskService.get_by_id(task_id)
+        if not ok:
+            logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}")
+
+        if task and task.progress not in [-1, 1]:
+            return False, f"Task {task_id} in progress with status {task.progress}. A RAPTOR task is already running."
+
+    documents, _ = DocumentService.get_by_kb_id(
+        kb_id=dataset_id,
+        page_number=0,
+        items_per_page=0,
+        orderby="create_time",
+        desc=False,
+        keywords="",
+        run_status=[],
+        types=[],
+        suffix=[],
+    )
+    if not documents:
+        return False, f"No documents in Dataset {dataset_id}"
+
+    sample_document = documents[0]
+    document_ids = [document["id"] for document in documents]
+
+    task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=document_ids)
+
+    if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
+        logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}")
+
+    return True, {"raptor_task_id": task_id}
+
+
+def trace_raptor(dataset_id: str, tenant_id: str):
+    """
+    Trace RAPTOR task for a dataset.
+
+    :param dataset_id: dataset ID
+    :param tenant_id: tenant ID
+    :return: (success, result) or (success, error_message)
+    """
+    if not dataset_id:
+        return False, 'Lack of "Dataset ID"'
+
+    if not KnowledgebaseService.accessible(dataset_id, tenant_id):
+        return False, "No authorization."
+
+    ok, kb = KnowledgebaseService.get_by_id(dataset_id)
+    if not ok:
+        return False, "Invalid Dataset ID"
+
+    task_id = kb.raptor_task_id
+    if not task_id:
+        return True, {}
+
+    ok, task = TaskService.get_by_id(task_id)
+    if not ok:
+        return False, "RAPTOR Task Not Found or Error Occurred"
+
+    return True, task.to_dict()
+
+
+def get_auto_metadata(dataset_id: str, tenant_id: str):
+    """
+    Get auto-metadata configuration for a dataset.
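+    The returned shape mirrors AutoMetadataConfig, e.g.:
+        {"enabled": True, "fields": [{"name": "year", "type": "string",
+                                      "description": None, "examples": None,
+                                      "restrict_values": False}]}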
+ + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :return: (success, result) or (success, error_message) + """ + kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) + if kb is None: + return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" + + parser_cfg = kb.parser_config or {} + metadata = parser_cfg.get("metadata") or [] + enabled = parser_cfg.get("enable_metadata", bool(metadata)) + # Normalize to AutoMetadataConfig-like JSON + fields = [] + for f in metadata: + if not isinstance(f, dict): + continue + fields.append( + { + "name": f.get("name", ""), + "type": f.get("type", ""), + "description": f.get("description"), + "examples": f.get("examples"), + "restrict_values": f.get("restrict_values", False), + } + ) + return True, {"enabled": enabled, "fields": fields} + + +async def update_auto_metadata(dataset_id: str, tenant_id: str, cfg: dict): + """ + Update auto-metadata configuration for a dataset. + + :param dataset_id: dataset ID + :param tenant_id: tenant ID + :param cfg: auto-metadata configuration + :return: (success, result) or (success, error_message) + """ + kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) + if kb is None: + return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" + + parser_cfg = kb.parser_config or {} + fields = [] + for f in cfg.get("fields", []): + fields.append( + { + "name": f.get("name", ""), + "type": f.get("type", ""), + "description": f.get("description"), + "examples": f.get("examples"), + "restrict_values": f.get("restrict_values", False), + } + ) + parser_cfg["metadata"] = fields + parser_cfg["enable_metadata"] = cfg.get("enabled", True) + + if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": parser_cfg}): + return False, "Update auto-metadata error.(Database error)" + + return True, {"enabled": parser_cfg["enable_metadata"], "fields": fields} diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py index ca6627d5a..1b640cff6 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -1,5 +1,5 @@ # -# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index dcd403887..c66d66a68 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -433,7 +433,7 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def get_list(cls, joined_tenant_ids, user_id, - page_number, items_per_page, orderby, desc, id, name): + page_number, items_per_page, orderby, desc, id, name, keywords, parser_id=None): # Get list of knowledge bases with filtering and pagination # Args: # joined_tenant_ids: List of tenant IDs @@ -444,6 +444,8 @@ class KnowledgebaseService(CommonService): # desc: Boolean indicating descending order # id: Optional ID filter # name: Optional name filter + # keywords: Optional keywords filter + # parser_id: Optional parser ID filter # Returns: # List of knowledge bases # Total count of knowledge bases @@ -452,6 +454,11 @@ class KnowledgebaseService(CommonService): kbs = kbs.where(cls.model.id == id) if name: kbs = kbs.where(cls.model.name == name) + if keywords: + kbs = kbs.where(fn.LOWER(cls.model.name).contains(keywords.lower())) + if parser_id: + kbs = kbs.where(cls.model.parser_id == parser_id) + kbs = kbs.where( ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | ( diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index b70ff2f9f..9cf5e5a3f 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -28,7 +28,6 @@ from typing import Any import requests from quart import ( - Response, jsonify, request, has_app_context, @@ -234,6 +233,17 @@ def active_required(func): return wrapper +def add_tenant_id_to_kwargs(func): + @wraps(func) + async def wrapper(**kwargs): + from api.apps import current_user + kwargs["tenant_id"] = current_user.id + if inspect.iscoroutinefunction(func): + return await func(**kwargs) + return func(**kwargs) + return wrapper + + def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None): response = {"code": code, "message": message, "data": data} return _safe_jsonify(response) @@ -513,7 +523,7 @@ def check_duplicate_ids(ids, id_type="item"): return list(set(ids)), duplicate_messages -def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]: +def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, str | None]: from api.db.services.llm_service import LLMService from api.db.services.tenant_llm_service import TenantLLMService @@ -559,13 +569,16 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, R is_builtin_model = llm_factory == "Builtin" if not (is_builtin_model or is_tenant_model or in_llm_service): - return False, get_error_argument_result(f"Unsupported model: <{embd_id}>") + return False, f"Unsupported model: <{embd_id}>" if not (is_builtin_model or is_tenant_model): - return False, get_error_argument_result(f"Unauthorized model: <{embd_id}>") + return False, f"Unauthorized model: <{embd_id}>" except OperationalError as e: logging.exception(e) - return False, get_error_data_result(message="Database operation failed") + return False, "Database operation failed" + except Exception as e: + logging.exception(e) + return False, "Internal server error" return True, None diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index 54d5f67dc..35e0b91f5 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -27,6 +27,7 @@ from 
pydantic import ( ValidationError, field_validator, model_validator, + ValidationInfo ) from pydantic_core import PydanticCustomError from werkzeug.exceptions import BadRequest, UnsupportedMediaType @@ -162,6 +163,15 @@ def validate_and_parse_request_args(request: Request, validator: type[BaseModel] - Preserves type conversion from Pydantic validation """ args = request.args.to_dict(flat=True) + + # Handle ext parameter: parse JSON string to dict if it's a string + if 'ext' in args and isinstance(args['ext'], str): + import json + try: + args['ext'] = json.loads(args['ext']) + except json.JSONDecodeError: + pass # Keep the string and let validation handle the error + try: if extras is not None: args.update(extras) @@ -336,6 +346,7 @@ class RaptorConfig(Base): max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)] random_seed: Annotated[int, Field(default=0, ge=0)] auto_disable_for_structured_data: Annotated[bool, Field(default=True)] + ext: Annotated[dict, Field(default={})] class GraphragConfig(Base): @@ -377,6 +388,7 @@ class ParserConfig(Base): filename_embd_weight: Annotated[float | None, Field(default=0.1, ge=0.0, le=1.0)] task_page_size: Annotated[int | None, Field(default=None, ge=1)] pages: Annotated[list[list[int]] | None, Field(default=None)] + ext: Annotated[dict, Field(default={})] class CreateDatasetReq(Base): @@ -390,6 +402,25 @@ class CreateDatasetReq(Base): pipeline_id: Annotated[str | None, Field(default=None, min_length=32, max_length=32, serialization_alias="pipeline_id")] parser_config: Annotated[ParserConfig | None, Field(default=None)] auto_metadata_config: Annotated[AutoMetadataConfig | None, Field(default=None)] + ext: Annotated[dict, Field(default={})] + + @field_validator("pipeline_id", mode="before") + @classmethod + def handle_pipeline_id(cls, v: str | None, info: ValidationInfo): + if v is None: + return v + if info.data.get("chunk_method") is not None and isinstance(v, str): + v = None + return v + + @field_validator("parse_type", mode="before") + @classmethod + def handle_parse_type(cls, v: int | None, info: ValidationInfo): + if v is None: + return v + if info.data.get("chunk_method") is not None and isinstance(v, int): + v = None + return v @field_validator("avatar", mode="after") @classmethod @@ -747,3 +778,4 @@ class BaseListReq(BaseModel): class ListDatasetReq(BaseListReq): include_parsing_status: Annotated[bool, Field(default=False)] + ext: Annotated[dict, Field(default={})] diff --git a/sdk/python/ragflow_sdk/modules/dataset.py b/sdk/python/ragflow_sdk/modules/dataset.py index b686dceec..158cebfa8 100644 --- a/sdk/python/ragflow_sdk/modules/dataset.py +++ b/sdk/python/ragflow_sdk/modules/dataset.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from typing import Any from .base import Base from .document import Document @@ -151,3 +151,23 @@ class DataSet(Base): res = res.json() if res.get("code") != 0: raise Exception(res.get("message")) + + def get_auto_metadata(self) -> dict[str, Any]: + """ + Retrieve auto-metadata configuration for a dataset via SDK. + """ + res = self.get(f"/datasets/{self.id}/auto_metadata") + res = res.json() + if res.get("code") == 0: + return res["data"] + raise Exception(res["message"]) + + def update_auto_metadata(self, **config: Any) -> dict[str, Any]: + """ + Update auto-metadata configuration for a dataset via SDK. 
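+
+        Example (same shape the SDK tests exercise):
+            dataset.update_auto_metadata(
+                enabled=False,
+                fields=[{"name": "year"}],
+            )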
+ """ + res = self.put(f"/datasets/{self.id}/auto_metadata", config) + res = res.json() + if res.get("code") == 0: + return res["data"] + raise Exception(res["message"]) diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index ff4f423c4..15b571872 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -111,26 +111,6 @@ class RAGFlow: return result_list raise Exception(res["message"]) - def get_auto_metadata(self, dataset_id: str) -> dict[str, Any]: - """ - Retrieve auto-metadata configuration for a dataset via SDK. - """ - res = self.get(f"/datasets/{dataset_id}/auto_metadata") - res = res.json() - if res.get("code") == 0: - return res["data"] - raise Exception(res["message"]) - - def update_auto_metadata(self, dataset_id: str, **config: Any) -> dict[str, Any]: - """ - Update auto-metadata configuration for a dataset via SDK. - """ - res = self.put(f"/datasets/{dataset_id}/auto_metadata", config) - res = res.json() - if res.get("code") == 0: - return res["data"] - raise Exception(res["message"]) - def create_chat(self, name: str, avatar: str = "", dataset_ids=None, llm: Chat.LLM | None = None, prompt: Chat.Prompt | None = None) -> Chat: if dataset_ids is None: dataset_ids = [] diff --git a/test/testcases/test_http_api/test_dataset_management/test_create_dataset.py b/test/testcases/test_http_api/test_dataset_management/test_create_dataset.py index 15bd9df1c..18a3506b0 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_create_dataset.py +++ b/test/testcases/test_http_api/test_dataset_management/test_create_dataset.py @@ -23,7 +23,7 @@ from utils import encode_avatar from utils.file_utils import create_image_file from utils.hypothesis_utils import valid_names -from common import create_dataset +from test_http_api.common import create_dataset @pytest.mark.usefixtures("clear_datasets") @@ -32,11 +32,11 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 0, "`Authorization` can't be empty"), + (None, 401, ""), ( RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 109, - "Authentication error: API key is invalid!", + 401, + "", ), ], ids=["empty_auth", "invalid_api_token"], @@ -250,7 +250,7 @@ class TestDatasetCreate: def test_embedding_model_invalid(self, HttpApiAuth, name, embedding_model): payload = {"name": name, "embedding_model": embedding_model} res = create_dataset(HttpApiAuth, payload) - assert res["code"] == 101, res + assert res["code"] == 102, res if "tenant_no_auth" in name: assert res["message"] == f"Unauthorized model: <{embedding_model}>", res else: diff --git a/test/testcases/test_http_api/test_dataset_management/test_delete_datasets.py b/test/testcases/test_http_api/test_dataset_management/test_delete_datasets.py index 024085741..77e9e0f92 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_delete_datasets.py +++ b/test/testcases/test_http_api/test_dataset_management/test_delete_datasets.py @@ -31,11 +31,11 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 0, "`Authorization` can't be empty"), + (None, 401, ""), ( RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 109, - "Authentication error: API key is invalid!", + 401, + "", ), ], ) @@ -160,7 +160,7 @@ class TestDatasetsDelete: def test_id_wrong_uuid(self, HttpApiAuth): payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]} res = delete_datasets(HttpApiAuth, payload) - assert res["code"] == 108, res + assert 
res["code"] == 102, res assert "lacks permission for dataset" in res["message"], res res = list_datasets(HttpApiAuth) @@ -180,7 +180,7 @@ class TestDatasetsDelete: if callable(func): payload = func(dataset_ids) res = delete_datasets(HttpApiAuth, payload) - assert res["code"] == 108, res + assert res["code"] == 102, res assert "lacks permission for dataset" in res["message"], res res = list_datasets(HttpApiAuth) @@ -205,7 +205,7 @@ class TestDatasetsDelete: assert res["code"] == 0, res res = delete_datasets(HttpApiAuth, payload) - assert res["code"] == 108, res + assert res["code"] == 102, res assert "lacks permission for dataset" in res["message"], res @pytest.mark.p3 diff --git a/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py b/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py index 665635f16..0398f7723 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py +++ b/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py @@ -24,8 +24,8 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 0, "Authorization"), - (RAGFlowHttpApiAuth(INVALID_API_TOKEN), 109, "API key is invalid"), + (None, 401, ""), + (RAGFlowHttpApiAuth(INVALID_API_TOKEN), 401, ""), ], ) def test_invalid_auth(self, invalid_auth, expected_code, expected_message): diff --git a/test/testcases/test_http_api/test_dataset_management/test_list_datasets.py b/test/testcases/test_http_api/test_dataset_management/test_list_datasets.py index 7887ff1fd..794761ed8 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_list_datasets.py +++ b/test/testcases/test_http_api/test_dataset_management/test_list_datasets.py @@ -28,11 +28,11 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 0, "`Authorization` can't be empty"), + (None, 401, ""), ( RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 109, - "Authentication error: API key is invalid!", + 401, + "", ), ], ) @@ -237,7 +237,7 @@ class TestDatasetsList: def test_name_wrong(self, HttpApiAuth): params = {"name": "wrong name"} res = list_datasets(HttpApiAuth, params) - assert res["code"] == 108, res + assert res["code"] == 102, res assert "lacks permission for dataset" in res["message"], res @pytest.mark.p2 @@ -281,7 +281,7 @@ class TestDatasetsList: def test_id_wrong_uuid(self, HttpApiAuth): params = {"id": "d94a8dc02c9711f0930f7fbc369eab6d"} res = list_datasets(HttpApiAuth, params) - assert res["code"] == 108, res + assert res["code"] == 102, res assert "lacks permission for dataset" in res["message"], res @pytest.mark.p2 diff --git a/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py b/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py index 0cafe1f67..a1fcefb66 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py +++ b/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py @@ -33,11 +33,11 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 0, "`Authorization` can't be empty"), + (None, 401, ""), ( RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 109, - "Authentication error: API key is invalid!", + 401, + "", ), ], ids=["empty_auth", "invalid_api_token"], @@ -76,7 +76,7 @@ class TestRquest: def test_payload_empty(self, HttpApiAuth, add_dataset_func): dataset_id = add_dataset_func res = 
update_dataset(HttpApiAuth, dataset_id, {}) - assert res["code"] == 101, res + assert res["code"] == 102, res assert res["message"] == "No properties were modified", res @pytest.mark.p3 @@ -313,7 +313,7 @@ class TestDatasetUpdate: dataset_id = add_dataset_func payload = {"name": name, "embedding_model": embedding_model} res = update_dataset(HttpApiAuth, dataset_id, payload) - assert res["code"] == 101, res + assert res["code"] == 102, res if "tenant_no_auth" in name: assert res["message"] == f"Unauthorized model: <{embedding_model}>", res else: @@ -494,7 +494,7 @@ class TestDatasetUpdate: dataset_id = add_dataset_func payload = {"pagerank": 50} res = update_dataset(HttpApiAuth, dataset_id, payload) - assert res["code"] == 101, res + assert res["code"] == 102, res assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res @pytest.mark.p2 diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_auto_metadata.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_auto_metadata.py index 2d2dd9246..908d95dae 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_auto_metadata.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_auto_metadata.py @@ -44,7 +44,7 @@ class TestAutoMetadataOnCreate: dataset = client.create_dataset(**payload) # The SDK should expose parser_config via internal properties or metadata; # we rely on the HTTP API for verification via get_auto_metadata. - cfg = client.get_auto_metadata(dataset_id=dataset.id) + cfg = dataset.get_auto_metadata() assert cfg["enabled"] is True assert len(cfg["fields"]) == 2 names = {f["name"] for f in cfg["fields"]} @@ -74,7 +74,7 @@ class TestAutoMetadataOnUpdate: } dataset.update(payload) - cfg = client.get_auto_metadata(dataset_id=dataset.id) + cfg = dataset.get_auto_metadata() assert cfg["enabled"] is True assert len(cfg["fields"]) == 1 assert cfg["fields"][0]["name"] == "tags" @@ -93,9 +93,9 @@ class TestAutoMetadataOnUpdate: } ], } - client.update_auto_metadata(dataset_id=dataset.id, **update_cfg) + dataset.update_auto_metadata(**update_cfg) - cfg2 = client.get_auto_metadata(dataset_id=dataset.id) + cfg2 = dataset.get_auto_metadata() assert cfg2["enabled"] is False assert len(cfg2["fields"]) == 1 assert cfg2["fields"][0]["name"] == "year" diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py index 444b05d14..6b7679544 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py @@ -31,8 +31,8 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_message", [ - (None, "Authentication error: API key is invalid!"), - (INVALID_API_TOKEN, "Authentication error: API key is invalid!"), + (None, ""), + (INVALID_API_TOKEN, ""), ], ids=["empty_auth", "invalid_api_token"], ) diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py index dbf0e588e..88e95742d 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py @@ -27,8 +27,8 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_message", [ - (None, "Authentication error: API key is invalid!"), - (INVALID_API_TOKEN, "Authentication error: API key is 
invalid!"), + (None, ""), + (INVALID_API_TOKEN, ""), ], ) def test_auth_invalid(self, invalid_auth, expected_message): diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_list_datasets.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_list_datasets.py index aa7e1b163..b2648d8fd 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_list_datasets.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_list_datasets.py @@ -26,8 +26,8 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_message", [ - (None, "Authentication error: API key is invalid!"), - (INVALID_API_TOKEN, "Authentication error: API key is invalid!"), + (None, ""), + (INVALID_API_TOKEN, ""), ], ) def test_auth_invalid(self, invalid_auth, expected_message): diff --git a/test/testcases/test_web_api/common.py b/test/testcases/test_web_api/common.py index 05e47eb18..9b091539f 100644 --- a/test/testcases/test_web_api/common.py +++ b/test/testcases/test_web_api/common.py @@ -27,6 +27,7 @@ from utils.file_utils import create_txt_file HEADERS = {"Content-Type": "application/json"} KB_APP_URL = f"/{VERSION}/kb" +DATASETS_URL = f"/api/{VERSION}/datasets" DOCUMENT_APP_URL = f"/{VERSION}/document" CHUNK_API_URL = f"/{VERSION}/chunk" DIALOG_APP_URL = f"/{VERSION}/dialog" @@ -168,25 +169,28 @@ def search_rm(auth, payload=None, *, headers=HEADERS, data=None): # KB APP -def create_kb(auth, payload=None, *, headers=HEADERS, data=None): - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/create", headers=headers, auth=auth, json=payload, data=data) +def create_dataset(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_URL}", headers=headers, auth=auth, json=payload, data=data) return res.json() -def list_kbs(auth, params=None, payload=None, *, headers=HEADERS, data=None): - if payload is None: - payload = {} - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/list", headers=headers, auth=auth, params=params, json=payload, data=data) +def list_datasets(auth, params=None, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_URL}", headers=headers, auth=auth, params=params) return res.json() -def update_kb(auth, payload=None, *, headers=HEADERS, data=None): - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/update", headers=headers, auth=auth, json=payload, data=data) +def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): + res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}", headers=headers, auth=auth, json=payload, data=data) return res.json() -def rm_kb(auth, payload=None, *, headers=HEADERS, data=None): - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/rm", headers=headers, auth=auth, json=payload, data=data) +def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None): + """ + Delete datasets. + The endpoint is DELETE /api/{VERSION}/datasets with payload {"ids": [...]} + This is the standard SDK REST API endpoint for dataset deletion. 
+ """ + res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_URL}", headers=headers, auth=auth, json=payload, data=data) return res.json() @@ -236,23 +240,43 @@ def kb_pipeline_log_detail(auth, params=None, *, headers=HEADERS): return res.json() -def kb_run_graphrag(auth, payload=None, *, headers=HEADERS, data=None): - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/run_graphrag", headers=headers, auth=auth, json=payload, data=data) +# DATASET GRAPH AND TASKS +def knowledge_graph(auth, dataset_id, params=None): + url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/knowledge_graph" + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) return res.json() -def kb_trace_graphrag(auth, params=None, *, headers=HEADERS): - res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/trace_graphrag", headers=headers, auth=auth, params=params) +def delete_knowledge_graph(auth, dataset_id, payload=None): + url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/knowledge_graph" + if payload is None: + res = requests.delete(url=url, headers=HEADERS, auth=auth) + else: + res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() -def kb_run_raptor(auth, payload=None, *, headers=HEADERS, data=None): - res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/run_raptor", headers=headers, auth=auth, json=payload, data=data) +def run_graphrag(auth, dataset_id, payload=None): + url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/run_graphrag" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) return res.json() -def kb_trace_raptor(auth, params=None, *, headers=HEADERS): - res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/trace_raptor", headers=headers, auth=auth, params=params) +def trace_graphrag(auth, dataset_id, params=None): + url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/trace_graphrag" + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) + return res.json() + + +def run_raptor(auth, dataset_id, payload=None): + url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/run_raptor" + res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) + return res.json() + + +def trace_raptor(auth, dataset_id, params=None): + url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/trace_raptor" + res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) return res.json() @@ -286,21 +310,11 @@ def rename_tags(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): return res.json() -def knowledge_graph(auth, dataset_id, params=None, *, headers=HEADERS): - res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/knowledge_graph", headers=headers, auth=auth, params=params) - return res.json() - - -def delete_knowledge_graph(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): - res = requests.delete(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/knowledge_graph", headers=headers, auth=auth, json=payload, data=data) - return res.json() - - def batch_create_datasets(auth, num): ids = [] for i in range(num): - res = create_kb(auth, {"name": f"kb_{i}"}) - ids.append(res["data"]["kb_id"]) + res = create_dataset(auth, {"name": f"kb_{i}"}) + ids.append(res["data"]["id"]) return ids diff --git a/test/testcases/test_web_api/conftest.py b/test/testcases/test_web_api/conftest.py index 51db85b3d..2f3a2c628 100644 --- a/test/testcases/test_web_api/conftest.py +++ b/test/testcases/test_web_api/conftest.py @@ -26,9 +26,9 @@ from common import ( delete_dialogs, list_chunks, list_documents, - list_kbs, + 
list_datasets, parse_documents, - rm_kb, + delete_datasets, ) from libs.auth import RAGFlowWebApiAuth from pytest import FixtureRequest @@ -104,9 +104,9 @@ def require_env_flag(): @pytest.fixture(scope="function") def clear_datasets(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth): def cleanup(): - res = list_kbs(WebApiAuth, params={"page_size": 1000}) - for kb in res["data"]["kbs"]: - rm_kb(WebApiAuth, {"kb_id": kb["id"]}) + res = list_datasets(WebApiAuth, params={"page_size": 1000}) + kb_ids = [kb["id"] for kb in res["data"]] + delete_datasets(WebApiAuth, {"ids": kb_ids}) request.addfinalizer(cleanup) @@ -122,9 +122,9 @@ def clear_dialogs(request, WebApiAuth): @pytest.fixture(scope="class") def add_dataset(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth) -> str: def cleanup(): - res = list_kbs(WebApiAuth, params={"page_size": 1000}) - for kb in res["data"]["kbs"]: - rm_kb(WebApiAuth, {"kb_id": kb["id"]}) + res = list_datasets(WebApiAuth, params={"page_size": 1000}) + kb_ids = [kb["id"] for kb in res["data"]] + delete_datasets(WebApiAuth, {"ids": kb_ids}) request.addfinalizer(cleanup) return batch_create_datasets(WebApiAuth, 1)[0] @@ -133,9 +133,9 @@ def add_dataset(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth) -> str: @pytest.fixture(scope="function") def add_dataset_func(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth) -> str: def cleanup(): - res = list_kbs(WebApiAuth, params={"page_size": 1000}) - for kb in res["data"]["kbs"]: - rm_kb(WebApiAuth, {"kb_id": kb["id"]}) + res = list_datasets(WebApiAuth, params={"page_size": 1000}) + kb_ids = [kb["id"] for kb in res["data"]] + delete_datasets(WebApiAuth, {"ids": kb_ids}) request.addfinalizer(cleanup) return batch_create_datasets(WebApiAuth, 1)[0] diff --git a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py index 967c95ef7..445f74c6a 100644 --- a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py +++ b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py @@ -409,7 +409,7 @@ def _load_dataset_module(monkeypatch): rag_nlp_pkg.search = search_mod module_name = "test_dataset_sdk_routes_unit_module" - module_path = repo_root / "api" / "apps" / "sdk" / "dataset.py" + module_path = repo_root / "api" / "apps" / "restful_apis" / "dataset_api.py" spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) module.manager = _DummyManager() @@ -418,7 +418,7 @@ def _load_dataset_module(monkeypatch): return module -@pytest.mark.p2 +@pytest.mark.p3 def test_create_route_error_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) req_state = {"name": "kb"} @@ -448,7 +448,7 @@ def test_create_route_error_matrix_unit(monkeypatch): assert res["message"] == "Database operation failed", res -@pytest.mark.p2 +@pytest.mark.p3 def test_delete_route_error_summary_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) req_state = {"ids": ["kb-1"]} @@ -476,7 +476,7 @@ def test_delete_route_error_summary_matrix_unit(monkeypatch): assert res["code"] == module.RetCode.SUCCESS, res -@pytest.mark.p2 +@pytest.mark.p3 def test_update_route_branch_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) req_state = {"name": "new"} @@ -556,7 +556,7 @@ def test_update_route_branch_matrix_unit(monkeypatch): assert res["message"] == "Database operation failed", res 
-@pytest.mark.p2 +@pytest.mark.p3 def test_list_knowledge_graph_delete_kg_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) @@ -629,7 +629,7 @@ def test_list_knowledge_graph_delete_kg_matrix_unit(monkeypatch): assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res -@pytest.mark.p2 +@pytest.mark.p3 def test_run_trace_graphrag_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) @@ -705,7 +705,7 @@ def test_run_trace_graphrag_matrix_unit(monkeypatch): assert res["data"]["id"] == "task-1", res -@pytest.mark.p2 +@pytest.mark.p3 def test_run_trace_raptor_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) diff --git a/test/testcases/test_web_api/test_document_app/test_create_document.py b/test/testcases/test_web_api/test_document_app/test_create_document.py index 8c39bdf4e..4e590ba3d 100644 --- a/test/testcases/test_web_api/test_document_app/test_create_document.py +++ b/test/testcases/test_web_api/test_document_app/test_create_document.py @@ -19,7 +19,7 @@ from types import SimpleNamespace from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from common import create_document, list_kbs +from test_web_api.common import create_document, list_datasets from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth from utils.file_utils import create_txt_file @@ -91,8 +91,8 @@ class TestDocumentCreate: assert len(responses) == count, responses assert all(future.result()["code"] == 0 for future in futures), responses - res = list_kbs(WebApiAuth, {"id": kb_id}) - assert res["data"]["kbs"][0]["doc_num"] == count, res + res = list_datasets(WebApiAuth, {"id": kb_id}) + assert res["data"][0]["document_count"] == count, res def _run(coro): diff --git a/test/testcases/test_web_api/test_document_app/test_upload_documents.py b/test/testcases/test_web_api/test_document_app/test_upload_documents.py index b4a151551..c8b82774e 100644 --- a/test/testcases/test_web_api/test_document_app/test_upload_documents.py +++ b/test/testcases/test_web_api/test_document_app/test_upload_documents.py @@ -20,7 +20,7 @@ from types import ModuleType, SimpleNamespace from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from common import list_kbs, upload_documents +from common import list_datasets, upload_documents from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth from utils.file_utils import create_txt_file @@ -172,8 +172,8 @@ class TestDocumentsUpload: res = upload_documents(WebApiAuth, {"kb_id": kb_id}, fps) assert res["code"] == 0, res - res = list_kbs(WebApiAuth) - assert res["data"]["kbs"][0]["doc_num"] == expected_document_count, res + res = list_datasets(WebApiAuth) + assert res["data"][0]["document_count"] == expected_document_count, res @pytest.mark.p3 def test_concurrent_upload(self, WebApiAuth, add_dataset_func, tmp_path): @@ -191,8 +191,8 @@ class TestDocumentsUpload: assert len(responses) == count, responses assert all(future.result()["code"] == 0 for future in futures), responses - res = list_kbs(WebApiAuth) - assert res["data"]["kbs"][0]["doc_num"] == count, res + res = list_datasets(WebApiAuth) + assert res["data"][0]["document_count"] == count, res class _AwaitableValue: diff --git a/test/testcases/test_web_api/test_kb_app/conftest.py b/test/testcases/test_web_api/test_kb_app/conftest.py index 8a2387391..d51df5c21 100644 --- a/test/testcases/test_web_api/test_kb_app/conftest.py +++ 
b/test/testcases/test_web_api/test_kb_app/conftest.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 import pytest
-from common import batch_create_datasets, list_kbs, rm_kb
+from common import batch_create_datasets, list_datasets, delete_datasets
 from libs.auth import RAGFlowWebApiAuth
 from pytest import FixtureRequest
 from ragflow_sdk import RAGFlow
@@ -26,11 +26,11 @@ def add_datasets(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGFlowWe

     def cleanup():
         # Web KB cleanup cannot call SDK dataset bulk delete with empty ids; deletion must stay explicit.
-        res = list_kbs(WebApiAuth, params={"page_size": 1000})
-        existing_ids = {kb["id"] for kb in res["data"]["kbs"]}
-        for dataset_id in dataset_ids:
-            if dataset_id in existing_ids:
-                rm_kb(WebApiAuth, {"kb_id": dataset_id})
+        res = list_datasets(WebApiAuth, params={"page_size": 1000})
+        existing_ids = {kb["id"] for kb in res["data"]}
+        ids_to_delete = list({dataset_id for dataset_id in dataset_ids if dataset_id in existing_ids})
+        if ids_to_delete:
+            delete_datasets(WebApiAuth, {"ids": ids_to_delete})

     request.addfinalizer(cleanup)
     return dataset_ids
@@ -42,11 +42,11 @@ def add_datasets_func(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGF

     def cleanup():
         # Web KB cleanup cannot call SDK dataset bulk delete with empty ids; deletion must stay explicit.
-        res = list_kbs(WebApiAuth, params={"page_size": 1000})
-        existing_ids = {kb["id"] for kb in res["data"]["kbs"]}
-        for dataset_id in dataset_ids:
-            if dataset_id in existing_ids:
-                rm_kb(WebApiAuth, {"kb_id": dataset_id})
+        res = list_datasets(WebApiAuth, params={"page_size": 1000})
+        existing_ids = {kb["id"] for kb in res["data"]}
+        ids_to_delete = list({dataset_id for dataset_id in dataset_ids if dataset_id in existing_ids})
+        if ids_to_delete:
+            delete_datasets(WebApiAuth, {"ids": ids_to_delete})

     request.addfinalizer(cleanup)
     return dataset_ids
diff --git a/test/testcases/test_web_api/test_kb_app/test_create_kb.py b/test/testcases/test_web_api/test_kb_app/test_create_kb.py
index 0e7fe0c55..b0942cd6e 100644
--- a/test/testcases/test_web_api/test_kb_app/test_create_kb.py
+++ b/test/testcases/test_web_api/test_kb_app/test_create_kb.py
@@ -16,7 +16,7 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed

 import pytest
-from common import create_kb
+from common import create_dataset
 from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN
 from hypothesis import example, given, settings
 from libs.auth import RAGFlowWebApiAuth
@@ -35,7 +35,7 @@ class TestAuthorization:
         ids=["empty_auth", "invalid_api_token"],
     )
     def test_auth_invalid(self, invalid_auth, expected_code, expected_message):
-        res = create_kb(invalid_auth, {"name": "auth_test"})
+        res = create_dataset(invalid_auth, {"name": "auth_test"})
         assert res["code"] == expected_code, res
         assert res["message"] == expected_message, res

@@ -46,14 +46,14 @@ class TestCapability:
     def test_create_kb_1k(self, WebApiAuth):
         for i in range(1_000):
             payload = {"name": f"dataset_{i}"}
-            res = create_kb(WebApiAuth, payload)
+            res = create_dataset(WebApiAuth, payload)
             assert res["code"] == 0, f"Failed to create dataset {i}"

     @pytest.mark.p3
     def test_create_kb_concurrent(self, WebApiAuth):
         count = 100
         with ThreadPoolExecutor(max_workers=5) as executor:
-            futures = [executor.submit(create_kb, WebApiAuth, {"name": f"dataset_{i}"}) for i in range(count)]
+            futures = [executor.submit(create_dataset, WebApiAuth, {"name": f"dataset_{i}"}) for i in range(count)]
             responses = list(as_completed(futures))
         assert len(responses) == count, responses
         assert
all(future.result()["code"] == 0 for future in futures) @@ -66,44 +66,44 @@ class TestDatasetCreate: @example("a" * 128) @settings(max_examples=20) def test_name(self, WebApiAuth, name): - res = create_kb(WebApiAuth, {"name": name}) + res = create_dataset(WebApiAuth, {"name": name}) assert res["code"] == 0, res @pytest.mark.p2 @pytest.mark.parametrize( "name, expected_message", [ - ("", "Dataset name can't be empty."), - (" ", "Dataset name can't be empty."), - ("a" * (DATASET_NAME_LIMIT + 1), "Dataset name length is 129 which is large than 128"), - (0, "Dataset name must be string."), - (None, "Dataset name must be string."), + ("", "Field: - Message: "), + (" ", "Field: - Message: "), + ("a" * (DATASET_NAME_LIMIT + 1), "Field: - Message: "), + (0, "Field: - Message: "), + (None, "Field: - Message: "), ], ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], ) def test_name_invalid(self, WebApiAuth, name, expected_message): payload = {"name": name} - res = create_kb(WebApiAuth, payload) - assert res["code"] == 102, res + res = create_dataset(WebApiAuth, payload) + assert res["code"] == 101, res assert expected_message in res["message"], res @pytest.mark.p3 def test_name_duplicated(self, WebApiAuth): name = "duplicated_name" payload = {"name": name} - res = create_kb(WebApiAuth, payload) + res = create_dataset(WebApiAuth, payload) assert res["code"] == 0, res - res = create_kb(WebApiAuth, payload) + res = create_dataset(WebApiAuth, payload) assert res["code"] == 0, res @pytest.mark.p3 def test_name_case_insensitive(self, WebApiAuth): name = "CaseInsensitive" payload = {"name": name.upper()} - res = create_kb(WebApiAuth, payload) + res = create_dataset(WebApiAuth, payload) assert res["code"] == 0, res payload = {"name": name.lower()} - res = create_kb(WebApiAuth, payload) + res = create_dataset(WebApiAuth, payload) assert res["code"] == 0, res diff --git a/test/testcases/test_web_api/test_kb_app/test_kb_pipeline_tasks.py b/test/testcases/test_web_api/test_kb_app/test_kb_pipeline_tasks.py index 21fb0ec50..6bf2da491 100644 --- a/test/testcases/test_web_api/test_kb_app/test_kb_pipeline_tasks.py +++ b/test/testcases/test_web_api/test_kb_app/test_kb_pipeline_tasks.py @@ -14,17 +14,17 @@ # limitations under the License. 
# import pytest -from common import ( +from test_web_api.common import ( kb_delete_pipeline_logs, kb_list_pipeline_dataset_logs, kb_list_pipeline_logs, kb_pipeline_log_detail, - kb_run_graphrag, + run_graphrag, + trace_graphrag, + run_raptor, + trace_raptor, kb_run_mindmap, - kb_run_raptor, - kb_trace_graphrag, kb_trace_mindmap, - kb_trace_raptor, list_documents, parse_documents, ) @@ -101,13 +101,13 @@ class TestKbPipelineTasks: @pytest.mark.p3 def test_graphrag_run_and_trace(self, WebApiAuth, add_chunks): kb_id, _, _ = add_chunks - run_res = kb_run_graphrag(WebApiAuth, {"kb_id": kb_id}) + run_res = run_graphrag(WebApiAuth, kb_id) assert run_res["code"] == 0, run_res task_id = run_res["data"]["graphrag_task_id"] assert task_id, run_res - _wait_for_task(kb_trace_graphrag, WebApiAuth, kb_id, task_id) - trace_res = kb_trace_graphrag(WebApiAuth, {"kb_id": kb_id}) + _wait_for_task(trace_graphrag, WebApiAuth, kb_id, task_id) + trace_res = trace_graphrag(WebApiAuth, kb_id) assert trace_res["code"] == 0, trace_res task = _find_task(trace_res["data"], task_id) assert task, trace_res @@ -118,13 +118,13 @@ class TestKbPipelineTasks: @pytest.mark.p3 def test_raptor_run_and_trace(self, WebApiAuth, add_chunks): kb_id, _, _ = add_chunks - run_res = kb_run_raptor(WebApiAuth, {"kb_id": kb_id}) + run_res = run_raptor(WebApiAuth, kb_id) assert run_res["code"] == 0, run_res task_id = run_res["data"]["raptor_task_id"] assert task_id, run_res - _wait_for_task(kb_trace_raptor, WebApiAuth, kb_id, task_id) - trace_res = kb_trace_raptor(WebApiAuth, {"kb_id": kb_id}) + _wait_for_task(trace_raptor, WebApiAuth, kb_id, task_id) + trace_res = trace_raptor(WebApiAuth, kb_id) assert trace_res["code"] == 0, trace_res task = _find_task(trace_res["data"], task_id) assert task, trace_res diff --git a/test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py b/test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py index d3a3cde43..40a74ae12 100644 --- a/test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py +++ b/test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py @@ -181,7 +181,7 @@ def set_tenant_info(): return None -@pytest.mark.p2 +@pytest.mark.p3 def test_create_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -211,7 +211,7 @@ def test_create_branches(monkeypatch): assert "save boom" in res["message"], res -@pytest.mark.p2 +@pytest.mark.p3 def test_update_branches(monkeypatch): module = _load_kb_module(monkeypatch) update_route = _unwrap_route(module.update) @@ -326,7 +326,7 @@ def test_update_branches(monkeypatch): assert "update boom" in res["message"], res -@pytest.mark.p2 +@pytest.mark.p3 def test_update_metadata_setting_not_found(monkeypatch): module = _load_kb_module(monkeypatch) _set_request_json(monkeypatch, module, {"kb_id": "missing-kb", "metadata": {}}) @@ -336,7 +336,7 @@ def test_update_metadata_setting_not_found(monkeypatch): assert "Database error" in res["message"], res -@pytest.mark.p2 +@pytest.mark.p3 def test_detail_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -380,7 +380,7 @@ def test_detail_branches(monkeypatch): assert "detail boom" in res["message"], res -@pytest.mark.p2 +@pytest.mark.p3 def test_list_kbs_owner_ids_and_desc(monkeypatch): module = _load_kb_module(monkeypatch) @@ -414,7 +414,7 @@ def test_list_kbs_owner_ids_and_desc(monkeypatch): assert "list boom" in res["message"], res -@pytest.mark.p2 +@pytest.mark.p3 def test_rm_and_rm_sync_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -491,7 +491,7 @@ def 
test_rm_and_rm_sync_branches(monkeypatch): assert "rm boom" in res["message"], res -@pytest.mark.p2 +@pytest.mark.p3 def test_tags_and_meta_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -560,7 +560,7 @@ def test_tags_and_meta_branches(monkeypatch): assert res["data"]["finished"] == 1, res -@pytest.mark.p2 +@pytest.mark.p3 def test_knowledge_graph_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -636,7 +636,7 @@ def test_knowledge_graph_branches(monkeypatch): assert res["data"] is True, res -@pytest.mark.p2 +@pytest.mark.p3 def test_list_pipeline_logs_validation_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -681,7 +681,7 @@ def test_list_pipeline_logs_validation_branches(monkeypatch): assert "Create data filter is abnormal." in res["message"], res -@pytest.mark.p2 +@pytest.mark.p3 def test_list_pipeline_logs_filter_and_exception_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -718,7 +718,7 @@ def test_list_pipeline_logs_filter_and_exception_branches(monkeypatch): assert "logs boom" in res["message"], res -@pytest.mark.p2 +@pytest.mark.p3 def test_list_pipeline_dataset_logs_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -792,7 +792,7 @@ def test_list_pipeline_dataset_logs_branches(monkeypatch): assert "dataset logs boom" in res["message"], res -@pytest.mark.p2 +@pytest.mark.p3 def test_pipeline_log_detail_and_delete_routes_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -841,7 +841,7 @@ def test_pipeline_log_detail_and_delete_routes_branches(monkeypatch): assert res["data"]["id"] == "log-1", res -@pytest.mark.p2 +@pytest.mark.p3 @pytest.mark.parametrize( "route_name,task_attr,response_key,task_type", [ @@ -914,7 +914,7 @@ def test_run_pipeline_task_routes_branch_matrix(monkeypatch, route_name, task_at assert queue_calls["doc_ids"] == ["doc-1", "doc-2"], queue_calls -@pytest.mark.p2 +@pytest.mark.p3 @pytest.mark.parametrize( "route_name,task_attr,empty_on_missing_task,error_text", [ @@ -970,7 +970,7 @@ def test_trace_pipeline_task_routes_branch_matrix(monkeypatch, route_name, task_ assert res["data"]["id"] == "task-1", res -@pytest.mark.p2 +@pytest.mark.p3 def test_unbind_task_branch_matrix(monkeypatch): module = _load_kb_module(monkeypatch) route = inspect.unwrap(module.delete_kb_task) @@ -1060,7 +1060,7 @@ def test_unbind_task_branch_matrix(monkeypatch): assert "cannot delete task" in res["message"], res -@pytest.mark.p2 +@pytest.mark.p3 def test_check_embedding_similarity_threshold_matrix_unit(monkeypatch): module = _load_kb_module(monkeypatch) route = inspect.unwrap(module.check_embedding) @@ -1229,7 +1229,7 @@ def test_check_embedding_similarity_threshold_matrix_unit(monkeypatch): assert res["data"]["summary"]["avg_cos_sim"] > 0.9, res -@pytest.mark.p2 +@pytest.mark.p3 def test_check_embedding_error_and_empty_sample_paths_unit(monkeypatch): module = _load_kb_module(monkeypatch) route = inspect.unwrap(module.check_embedding) diff --git a/test/testcases/test_web_api/test_kb_app/test_kb_tags_meta.py b/test/testcases/test_web_api/test_kb_app/test_kb_tags_meta.py index 810b636de..9b5bb5237 100644 --- a/test/testcases/test_web_api/test_kb_app/test_kb_tags_meta.py +++ b/test/testcases/test_web_api/test_kb_app/test_kb_tags_meta.py @@ -16,7 +16,7 @@ import uuid import pytest -from common import ( +from test_web_api.common import ( delete_knowledge_graph, kb_basic_info, kb_get_meta, diff --git a/test/testcases/test_web_api/test_kb_app/test_list_kbs.py 
b/test/testcases/test_web_api/test_kb_app/test_list_kbs.py index 530686788..b6ed92f76 100644 --- a/test/testcases/test_web_api/test_kb_app/test_list_kbs.py +++ b/test/testcases/test_web_api/test_kb_app/test_list_kbs.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import json from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from common import list_kbs +from common import list_datasets from configs import INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth from utils import is_sorted @@ -32,7 +33,7 @@ class TestAuthorization: ], ) def test_auth_invalid(self, invalid_auth, expected_code, expected_message): - res = list_kbs(invalid_auth) + res = list_datasets(invalid_auth) assert res["code"] == expected_code, res assert res["message"] == expected_message, res @@ -42,7 +43,7 @@ class TestCapability: def test_concurrent_list(self, WebApiAuth): count = 100 with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(list_kbs, WebApiAuth) for i in range(count)] + futures = [executor.submit(list_datasets, WebApiAuth) for i in range(count)] responses = list(as_completed(futures)) assert len(responses) == count, responses assert all(future.result()["code"] == 0 for future in futures) @@ -52,15 +53,15 @@ class TestCapability: class TestDatasetsList: @pytest.mark.p2 def test_params_unset(self, WebApiAuth): - res = list_kbs(WebApiAuth, None) + res = list_datasets(WebApiAuth, None) assert res["code"] == 0, res - assert len(res["data"]["kbs"]) == 5, res + assert len(res["data"]) == 5, res @pytest.mark.p2 def test_params_empty(self, WebApiAuth): - res = list_kbs(WebApiAuth, {}) + res = list_datasets(WebApiAuth, {}) assert res["code"] == 0, res - assert len(res["data"]["kbs"]) == 5, res + assert len(res["data"]) == 5, res @pytest.mark.p1 @pytest.mark.parametrize( @@ -75,9 +76,9 @@ class TestDatasetsList: ids=["normal_middle_page", "normal_last_partial_page", "beyond_max_page", "string_page_number", "full_data_single_page"], ) def test_page(self, WebApiAuth, params, expected_page_size): - res = list_kbs(WebApiAuth, params) + res = list_datasets(WebApiAuth, params) assert res["code"] == 0, res - assert len(res["data"]["kbs"]) == expected_page_size, res + assert len(res["data"]) == expected_page_size, res @pytest.mark.skip @pytest.mark.p2 @@ -90,16 +91,16 @@ class TestDatasetsList: ids=["page_0", "page_a"], ) def test_page_invalid(self, WebApiAuth, params, expected_code, expected_message): - res = list_kbs(WebApiAuth, params=params) + res = list_datasets(WebApiAuth, params=params) assert res["code"] == expected_code, res assert expected_message in res["message"], res @pytest.mark.p2 def test_page_none(self, WebApiAuth): params = {"page": None} - res = list_kbs(WebApiAuth, params) + res = list_datasets(WebApiAuth, params) assert res["code"] == 0, res - assert len(res["data"]["kbs"]) == 5, res + assert len(res["data"]) == 5, res @pytest.mark.p1 @pytest.mark.parametrize( @@ -114,9 +115,9 @@ class TestDatasetsList: ids=["min_valid_page_size", "medium_page_size", "page_size_equals_total", "page_size_exceeds_total", "string_type_page_size"], ) def test_page_size(self, WebApiAuth, params, expected_page_size): - res = list_kbs(WebApiAuth, params) + res = list_datasets(WebApiAuth, params) assert res["code"] == 0, res - assert len(res["data"]["kbs"]) == expected_page_size, res + assert len(res["data"]) == expected_page_size, res @pytest.mark.skip @pytest.mark.p2 @@ -128,27 +129,27 @@ 
class TestDatasetsList: ], ) def test_page_size_invalid(self, WebApiAuth, params, expected_code, expected_message): - res = list_kbs(WebApiAuth, params) + res = list_datasets(WebApiAuth, params) assert res["code"] == expected_code, res assert expected_message in res["message"], res @pytest.mark.p2 def test_page_size_none(self, WebApiAuth): params = {"page_size": None} - res = list_kbs(WebApiAuth, params) + res = list_datasets(WebApiAuth, params) assert res["code"] == 0, res - assert len(res["data"]["kbs"]) == 5, res + assert len(res["data"]) == 5, res @pytest.mark.p3 @pytest.mark.parametrize( "params, assertions", [ - ({"orderby": "update_time"}, lambda r: (is_sorted(r["data"]["kbs"], "update_time", True))), + ({"orderby": "update_time"}, lambda r: (is_sorted(r["data"], "update_time", True))), ], ids=["orderby_update_time"], ) def test_orderby(self, WebApiAuth, params, assertions): - res = list_kbs(WebApiAuth, params) + res = list_datasets(WebApiAuth, params) assert res["code"] == 0, res if callable(assertions): assert assertions(res), res @@ -157,13 +158,13 @@ class TestDatasetsList: @pytest.mark.parametrize( "params, assertions", [ - ({"desc": "True"}, lambda r: (is_sorted(r["data"]["kbs"], "update_time", True))), - ({"desc": "False"}, lambda r: (is_sorted(r["data"]["kbs"], "update_time", False))), + ({"desc": "True"}, lambda r: (is_sorted(r["data"], "update_time", True))), + ({"desc": "False"}, lambda r: (is_sorted(r["data"], "update_time", False))), ], ids=["desc=True", "desc=False"], ) def test_desc(self, WebApiAuth, params, assertions): - res = list_kbs(WebApiAuth, params) + res = list_datasets(WebApiAuth, params) assert res["code"] == 0, res if callable(assertions): @@ -173,29 +174,28 @@ class TestDatasetsList: @pytest.mark.parametrize( "params, expected_page_size", [ - ({"parser_id": "naive"}, 5), - ({"parser_id": "qa"}, 0), + ({"ext": json.dumps({"parser_id": "naive"})}, 5), + ({"ext": json.dumps({"parser_id": "qa"})}, 0), ], ids=["naive", "dqa"], ) def test_parser_id(self, WebApiAuth, params, expected_page_size): - res = list_kbs(WebApiAuth, params) + res = list_datasets(WebApiAuth, params) assert res["code"] == 0, res - assert len(res["data"]["kbs"]) == expected_page_size, res + assert len(res["data"]) == expected_page_size, res @pytest.mark.p2 def test_owner_ids_payload_mode(self, WebApiAuth): - base_res = list_kbs(WebApiAuth, {"page_size": 10}) + base_res = list_datasets(WebApiAuth, {"page_size": 10}) assert base_res["code"] == 0, base_res - assert base_res["data"]["kbs"], base_res - owner_id = base_res["data"]["kbs"][0]["tenant_id"] + assert base_res["data"], base_res + owner_id = base_res["data"][0]["tenant_id"] - res = list_kbs( + res = list_datasets( WebApiAuth, - params={"page": 1, "page_size": 2, "desc": "false"}, - payload={"owner_ids": [owner_id]}, + params={"page": 1, "page_size": 2, "desc": "false", "ext": json.dumps({"owner_ids": [owner_id]})}, ) assert res["code"] == 0, res - assert res["data"]["total"] >= len(res["data"]["kbs"]), res - assert len(res["data"]["kbs"]) <= 2, res - assert all(kb["tenant_id"] == owner_id for kb in res["data"]["kbs"]), res + assert res["total_datasets"] >= len(res["data"]), res + assert len(res["data"]) <= 2, res + assert all(kb["tenant_id"] == owner_id for kb in res["data"]), res diff --git a/test/testcases/test_web_api/test_kb_app/test_rm_kb.py b/test/testcases/test_web_api/test_kb_app/test_rm_kb.py index 21ea624a6..d421bb3aa 100644 --- a/test/testcases/test_web_api/test_kb_app/test_rm_kb.py +++ 
@@ -16,8 +16,8 @@
 import pytest
 from common import (
-    list_kbs,
-    rm_kb,
+    list_datasets,
+    delete_datasets,
 )
 from configs import INVALID_API_TOKEN
 from libs.auth import RAGFlowWebApiAuth
@@ -33,7 +33,7 @@ class TestAuthorization:
         ],
     )
     def test_auth_invalid(self, invalid_auth, expected_code, expected_message):
-        res = rm_kb(invalid_auth)
+        res = delete_datasets(invalid_auth)
         assert res["code"] == expected_code, res
         assert res["message"] == expected_message, res
@@ -42,20 +42,20 @@ class TestDatasetsDelete:
     @pytest.mark.p1
     def test_kb_id(self, WebApiAuth, add_datasets_func):
         kb_ids = add_datasets_func
-        payload = {"kb_id": kb_ids[0]}
-        res = rm_kb(WebApiAuth, payload)
+        payload = {"ids": [kb_ids[0]]}
+        res = delete_datasets(WebApiAuth, payload)
         assert res["code"] == 0, res
 
-        res = list_kbs(WebApiAuth)
-        assert len(res["data"]["kbs"]) == 2, res
+        res = list_datasets(WebApiAuth)
+        assert len(res["data"]) == 2, res
 
     @pytest.mark.p2
     @pytest.mark.usefixtures("add_dataset_func")
     def test_id_wrong_uuid(self, WebApiAuth):
-        payload = {"kb_id": "d94a8dc02c9711f0930f7fbc369eab6d"}
-        res = rm_kb(WebApiAuth, payload)
-        assert res["code"] == 109, res
-        assert "No authorization." in res["message"], res
+        payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]}
+        res = delete_datasets(WebApiAuth, payload)
+        assert res["code"] == 102, res
+        assert "lacks permission" in res["message"], res
 
-        res = list_kbs(WebApiAuth)
-        assert len(res["data"]["kbs"]) == 1, res
+        res = list_datasets(WebApiAuth)
+        assert len(res["data"]) == 1, res
diff --git a/test/testcases/test_web_api/test_kb_app/test_update_kb.py b/test/testcases/test_web_api/test_kb_app/test_update_kb.py
index 641ed3b1f..7ee1bad5f 100644
--- a/test/testcases/test_web_api/test_kb_app/test_update_kb.py
+++ b/test/testcases/test_web_api/test_kb_app/test_update_kb.py
@@ -17,7 +17,7 @@
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import pytest
-from common import update_kb
+from test_web_api.common import update_dataset
 from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN
 from hypothesis import HealthCheck, example, given, settings
 from libs.auth import RAGFlowWebApiAuth
@@ -37,7 +37,7 @@ class TestAuthorization:
         ids=["empty_auth", "invalid_api_token"],
     )
     def test_auth_invalid(self, invalid_auth, expected_code, expected_message):
-        res = update_kb(invalid_auth, "dataset_id")
+        res = update_dataset(invalid_auth, "dataset_id")
         assert res["code"] == expected_code, res
         assert res["message"] == expected_message, res
@@ -50,13 +50,13 @@ class TestCapability:
         with ThreadPoolExecutor(max_workers=5) as executor:
             futures = [
                 executor.submit(
-                    update_kb,
+                    update_dataset,
                     WebApiAuth,
+                    dataset_id,
                     {
-                        "kb_id": dataset_id,
                         "name": f"dataset_{i}",
                         "description": "",
-                        "parser_id": "naive",
+                        "chunk_method": "naive",
                     },
                 )
                 for i in range(count)
@@ -69,8 +69,8 @@ class TestCapability:
 class TestDatasetUpdate:
     @pytest.mark.p3
     def test_dataset_id_not_uuid(self, WebApiAuth):
-        payload = {"name": "not uuid", "description": "", "parser_id": "naive", "kb_id": "not_uuid"}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": "not uuid", "description": "", "chunk_method": "naive"}
+        res = update_dataset(WebApiAuth, "not_uuid", payload)
         assert res["code"] == 109, res
         assert "No authorization." in res["message"], res
@@ -81,8 +81,8 @@ class TestDatasetUpdate:
     @settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
     def test_name(self, WebApiAuth, add_dataset_func, name):
         dataset_id = add_dataset_func
-        payload = {"name": name, "description": "", "parser_id": "naive", "kb_id": dataset_id}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": name, "description": "", "chunk_method": "naive"}
+        res = update_dataset(WebApiAuth, dataset_id, payload)
         assert res["code"] == 0, res
         assert res["data"]["name"] == name, res
@@ -90,27 +90,27 @@ class TestDatasetUpdate:
     @pytest.mark.parametrize(
         "name, expected_message",
         [
-            ("", "Dataset name can't be empty."),
-            (" ", "Dataset name can't be empty."),
-            ("a" * (DATASET_NAME_LIMIT + 1), "Dataset name length is 129 which is large than 128"),
-            (0, "Dataset name must be string."),
-            (None, "Dataset name must be string."),
+            ("", "Field: - Message: "),
+            (" ", "Field: - Message: "),
+            ("a" * (DATASET_NAME_LIMIT + 1), "Field: - Message: "),
+            (0, "Field: - Message: "),
+            (None, "Field: - Message: "),
         ],
         ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"],
     )
     def test_name_invalid(self, WebApiAuth, add_dataset_func, name, expected_message):
         kb_id = add_dataset_func
-        payload = {"name": name, "description": "", "parser_id": "naive", "kb_id": kb_id}
-        res = update_kb(WebApiAuth, payload)
-        assert res["code"] == 102, res
+        payload = {"name": name, "description": "", "chunk_method": "naive"}
+        res = update_dataset(WebApiAuth, kb_id, payload)
+        assert res["code"] == 101, res
         assert expected_message in res["message"], res
 
     @pytest.mark.p3
     def test_name_duplicated(self, WebApiAuth, add_datasets_func):
         kb_id = add_datasets_func[0]
         name = "kb_1"
-        payload = {"name": name, "description": "", "parser_id": "naive", "kb_id": kb_id}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": name, "description": "", "chunk_method": "naive"}
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 102, res
         assert res["message"] == "Duplicated dataset name.", res
@@ -118,8 +118,8 @@ class TestDatasetUpdate:
     def test_name_case_insensitive(self, WebApiAuth, add_datasets_func):
         kb_id = add_datasets_func[0]
         name = "KB_1"
-        payload = {"name": name, "description": "", "parser_id": "naive", "kb_id": kb_id}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": name, "description": "", "chunk_method": "naive"}
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 102, res
         assert res["message"] == "Duplicated dataset name.", res
@@ -130,19 +130,18 @@ class TestDatasetUpdate:
         payload = {
             "name": "avatar",
             "description": "",
-            "parser_id": "naive",
-            "kb_id": kb_id,
+            "chunk_method": "naive",
             "avatar": f"data:image/png;base64,{encode_avatar(fn)}",
         }
-        res = update_kb(WebApiAuth, payload)
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 0, res
         assert res["data"]["avatar"] == f"data:image/png;base64,{encode_avatar(fn)}", res
 
     @pytest.mark.p2
     def test_description(self, WebApiAuth, add_dataset_func):
         kb_id = add_dataset_func
-        payload = {"name": "description", "description": "description", "parser_id": "naive", "kb_id": kb_id}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": "description", "description": "description", "chunk_method": "naive"}
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 0, res
         assert res["data"]["description"] == "description", res
@@ -157,10 +156,10 @@ class TestDatasetUpdate:
     )
     def test_embedding_model(self, WebApiAuth, add_dataset_func, embedding_model):
         kb_id = add_dataset_func
-        payload = {"name": "embedding_model", "description": "", "parser_id": "naive", "kb_id": kb_id, "embd_id": embedding_model}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": "embedding_model", "description": "", "chunk_method": "naive", "embedding_model": embedding_model}
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 0, res
-        assert res["data"]["embd_id"] == embedding_model, res
+        assert res["data"]["embedding_model"] == embedding_model, res
 
     @pytest.mark.p2
     @pytest.mark.parametrize(
@@ -173,8 +172,8 @@ class TestDatasetUpdate:
     )
     def test_permission(self, WebApiAuth, add_dataset_func, permission):
         kb_id = add_dataset_func
-        payload = {"name": "permission", "description": "", "parser_id": "naive", "kb_id": kb_id, "permission": permission}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": "permission", "description": "", "chunk_method": "naive", "permission": permission}
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 0, res
         assert res["data"]["permission"] == permission.lower().strip(), res
@@ -199,17 +198,17 @@ class TestDatasetUpdate:
     )
     def test_chunk_method(self, WebApiAuth, add_dataset_func, chunk_method):
         kb_id = add_dataset_func
-        payload = {"name": "chunk_method", "description": "", "parser_id": chunk_method, "kb_id": kb_id}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": "chunk_method", "description": "", "chunk_method": chunk_method}
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 0, res
-        assert res["data"]["parser_id"] == chunk_method, res
+        assert res["data"]["chunk_method"] == chunk_method, res
 
     @pytest.mark.p1
     @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="Infinity does not support parser_id=tag")
     def test_chunk_method_tag_with_infinity(self, WebApiAuth, add_dataset_func):
         kb_id = add_dataset_func
-        payload = {"name": "chunk_method", "description": "", "parser_id": "tag", "kb_id": kb_id}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": "chunk_method", "description": "", "chunk_method": "tag"}
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 103, res
         assert res["message"] == "The chunking method Tag has not been supported by Infinity yet.", res
@@ -218,8 +217,8 @@ class TestDatasetUpdate:
     @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"])
     def test_pagerank(self, WebApiAuth, add_dataset_func, pagerank):
         kb_id = add_dataset_func
-        payload = {"name": "pagerank", "description": "", "parser_id": "naive", "kb_id": kb_id, "pagerank": pagerank}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": "pagerank", "description": "", "chunk_method": "naive", "pagerank": pagerank}
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 0, res
         assert res["data"]["pagerank"] == pagerank, res
@@ -227,13 +226,13 @@ class TestDatasetUpdate:
     @pytest.mark.p2
     def test_pagerank_set_to_0(self, WebApiAuth, add_dataset_func):
         kb_id = add_dataset_func
-        payload = {"name": "pagerank", "description": "", "parser_id": "naive", "kb_id": kb_id, "pagerank": 50}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": "pagerank", "description": "", "chunk_method": "naive", "pagerank": 50}
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 0, res
         assert res["data"]["pagerank"] == 50, res
 
-        payload = {"name": "pagerank", "description": "", "parser_id": "naive", "kb_id": kb_id, "pagerank": 0}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": "pagerank", "description": "", "chunk_method": "naive", "pagerank": 0}
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 0, res
         assert res["data"]["pagerank"] == 0, res
@@ -241,8 +240,8 @@ class TestDatasetUpdate:
     @pytest.mark.p2
     def test_pagerank_infinity(self, WebApiAuth, add_dataset_func):
         kb_id = add_dataset_func
-        payload = {"name": "pagerank", "description": "", "parser_id": "naive", "kb_id": kb_id, "pagerank": 50}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": "pagerank", "description": "", "chunk_method": "naive", "pagerank": 50}
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 102, res
         assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res
@@ -352,10 +351,15 @@ class TestDatasetUpdate:
     )
     def test_parser_config(self, WebApiAuth, add_dataset_func, parser_config):
         kb_id = add_dataset_func
-        payload = {"name": "parser_config", "description": "", "parser_id": "naive", "kb_id": kb_id, "parser_config": parser_config}
-        res = update_kb(WebApiAuth, payload)
+        payload = {"name": "parser_config", "description": "", "chunk_method": "naive", "parser_config": parser_config}
+        res = update_dataset(WebApiAuth, kb_id, payload)
         assert res["code"] == 0, res
-        assert res["data"]["parser_config"] == parser_config, res
+        for key, value in parser_config.items():
+            if not isinstance(value, dict):
+                assert res["data"]["parser_config"].get(key) == value, res
+            else:
+                for sub_key, sub_value in value.items():
+                    assert res["data"]["parser_config"].get(key, {}).get(sub_key) == sub_value, res
@@ -372,7 +376,7 @@ class TestDatasetUpdate:
     )
     def test_field_unsupported(self, WebApiAuth, add_dataset_func, payload):
         kb_id = add_dataset_func
-        full_payload = {"name": "field_unsupported", "description": "", "parser_id": "naive", "kb_id": kb_id, **payload}
-        res = update_kb(WebApiAuth, full_payload)
+        full_payload = {"name": "field_unsupported", "description": "", "chunk_method": "naive", **payload}
+        res = update_dataset(WebApiAuth, kb_id, full_payload)
         assert res["code"] == 101, res
-        assert "isn't allowed" in res["message"], res
+        assert "are not permitted" in res["message"], res
diff --git a/web/src/components/chunk-method-dialog/index.tsx b/web/src/components/chunk-method-dialog/index.tsx
index c845fda35..34dfb60e2 100644
--- a/web/src/components/chunk-method-dialog/index.tsx
+++ b/web/src/components/chunk-method-dialog/index.tsx
@@ -181,7 +181,7 @@ export function ChunkMethodDialog({
   });
 
   const selectedTag = useWatch({
-    name: 'parser_id',
+    name: 'chunk_method',
     control: form.control,
   });
   const isMineruSelected =
diff --git a/web/src/components/knowledge-base-item.tsx b/web/src/components/knowledge-base-item.tsx
index 9ced7bfc2..90faf38f3 100644
--- a/web/src/components/knowledge-base-item.tsx
+++ b/web/src/components/knowledge-base-item.tsx
@@ -23,7 +23,7 @@ export function useDisableDifferenceEmbeddingDataset() {
   useEffect(() => {
     const datasetListMap = datasetListOrigin
-      .filter((x) => x.parser_id !== DocumentParserType.Tag)
+      .filter((x) => x.chunk_method !== DocumentParserType.Tag)
       .map((item: IKnowledge) => {
         return {
           label: item.name,
@@ -36,12 +36,12 @@ export function useDisableDifferenceEmbeddingDataset() {
           ),
           suffix: (
-            {item.embd_id}
+            {item.embedding_model}
           ),
           value: item.id,
           disabled:
-            item.embd_id !== datasetSelectEmbedId &&
+            item.embedding_model !== datasetSelectEmbedId &&
             datasetSelectEmbedId !== '',
         };
       });
@@ -54,7 +54,7 @@ export function useDisableDifferenceEmbeddingDataset() {
   ) => {
     if (value.length) {
       const data = datasetListOrigin?.find((item) => item.id === value[0]);
-      setDatasetSelectEmbedId(data?.embd_id ?? '');
+      setDatasetSelectEmbedId(data?.embedding_model ?? '');
     } else {
       setDatasetSelectEmbedId('');
     }
diff --git a/web/src/components/ui/multi-select.tsx b/web/src/components/ui/multi-select.tsx
index b4464aff8..287ec26e4 100644
--- a/web/src/components/ui/multi-select.tsx
+++ b/web/src/components/ui/multi-select.tsx
@@ -242,7 +242,9 @@ export const MultiSelect = React.forwardRef<
   const disabledValueSet = React.useMemo(() => {
     return new Set(
-      flatOptions.filter((option) => option.disabled).map((option) => option.value),
+      flatOptions
+        .filter((option) => option.disabled)
+        .map((option) => option.value),
     );
   }, [flatOptions]);
diff --git a/web/src/hooks/use-knowledge-request.ts b/web/src/hooks/use-knowledge-request.ts
index d7294b3a3..4985a21b6 100644
--- a/web/src/hooks/use-knowledge-request.ts
+++ b/web/src/hooks/use-knowledge-request.ts
@@ -18,6 +18,7 @@ import kbService, {
   listTag,
   removeTag,
   renameTag,
+  updateKb,
 } from '@/services/knowledge-service';
 import {
   useIsMutating,
@@ -137,22 +138,20 @@ export const useFetchNextKnowledgeListByPage = () => {
     ],
     initialData: {
       kbs: [],
-      total: 0,
+      total_datasets: 0,
     },
     gcTime: 0,
     queryFn: async () => {
-      const { data } = await listDataset(
-        {
+      const { data } = await listDataset({
+        page_size: pagination.pageSize,
+        page: pagination.current,
+        ext: {
           keywords: debouncedSearchString,
-          page_size: pagination.pageSize,
-          page: pagination.current,
-        },
-        {
           owner_ids: filterValue.owner,
         },
-      );
+      });
 
-      return data?.data;
+      return { kbs: data?.data, total_datasets: data?.total_datasets };
     },
   });
@@ -168,7 +167,7 @@
     ...data,
     searchString,
     handleInputChange: onInputChange,
-    pagination: { ...pagination, total: data?.total },
+    pagination: { ...pagination, total: data?.total_datasets },
     setPagination,
     loading,
     filterValue,
@@ -184,7 +183,18 @@ export const useCreateKnowledge = () => {
     mutateAsync,
   } = useMutation({
     mutationKey: [KnowledgeApiAction.CreateKnowledge],
-    mutationFn: async (params: { id?: string; name: string }) => {
+    mutationFn: async (params: {
+      id?: string;
+      name: string;
+      embedding_model?: string;
+      chunk_method?: string;
+      parseType?: number;
+      pipeline_id?: string | null;
+      ext?: {
+        language?: string;
+        [key: string]: any;
+      };
+    }) => {
       const { data = {} } = await kbService.createKb(params);
       if (data.code === 0) {
         message.success(
@@ -208,7 +218,7 @@ export const useDeleteKnowledge = () => {
   } = useMutation({
     mutationKey: [KnowledgeApiAction.DeleteKnowledge],
     mutationFn: async (id: string) => {
-      const { data } = await kbService.rmKb({ kb_id: id });
+      const { data } = await kbService.rmKb({ ids: [id] });
       if (data.code === 0) {
         message.success(i18n.t(`message.deleted`));
         queryClient.invalidateQueries({
@@ -225,17 +235,119 @@ export const useUpdateKnowledge = (shouldFetchList = false) => {
   const knowledgeBaseId = useKnowledgeBaseId();
   const queryClient = useQueryClient();
+
+  const extractRaptorConfigExt = (
+    raptorConfig: Record<string, any> | undefined,
+  ) => {
+    if (!raptorConfig) return raptorConfig;
+    const {
+      use_raptor,
+      prompt,
+      max_token,
+      threshold,
+      max_cluster,
+      random_seed,
+      auto_disable_for_structured_data,
+      ext,
+      ...raptorExt
+    } = raptorConfig;
+    return {
+      use_raptor,
+      prompt,
+      max_token,
+      threshold,
+      max_cluster,
+      random_seed,
+      auto_disable_for_structured_data,
+      ext: { ...ext, ...raptorExt },
+    };
+  };
+
+  const extractParserConfigExt = (
+    parserConfig: Record<string, any> | undefined,
+  ) => {
+    if (!parserConfig) return parserConfig;
+    const {
+      auto_keywords,
+      auto_questions,
+      chunk_token_num,
+      delimiter,
+      graphrag,
+      html4excel,
+      layout_recognize,
+      raptor,
+      tag_kb_ids,
+      topn_tags,
+      filename_embd_weight,
+      task_page_size,
+      pages,
+      ext,
+      ...parserExt
+    } = parserConfig;
+    return {
+      auto_keywords,
+      auto_questions,
+      chunk_token_num,
+      delimiter,
+      graphrag,
+      html4excel,
+      layout_recognize,
+      raptor: extractRaptorConfigExt(raptor),
+      tag_kb_ids,
+      topn_tags,
+      filename_embd_weight,
+      task_page_size,
+      pages,
+      ext: { ...ext, ...parserExt },
+    };
+  };
+
   const {
     data,
     isPending: loading,
     mutateAsync,
   } = useMutation({
     mutationKey: [KnowledgeApiAction.SaveKnowledge],
-    mutationFn: async (params: Record<string, any>) => {
-      const { data = {} } = await kbService.updateKb({
-        kb_id: params?.kb_id ? params?.kb_id : knowledgeBaseId,
-        ...params,
-      });
+    mutationFn: async (params: {
+      kb_id?: string;
+      name?: string;
+      embedding_model?: string;
+      chunk_method?: string;
+      pipeline_id?: string | null;
+      avatar?: string | null;
+      description?: string;
+      permission?: string;
+      pagerank?: number;
+      parser_config?: Record<string, any>;
+      [key: string]: any;
+    }) => {
+      const kbId = params?.kb_id || knowledgeBaseId;
+      const {
+        kb_id,
+        name,
+        embedding_model,
+        chunk_method,
+        pipeline_id,
+        avatar,
+        description,
+        permission,
+        pagerank,
+        parser_config,
+        ...ext
+      } = params;
+      const requestBody: Record<string, any> = {
+        name,
+        embedding_model,
+        chunk_method,
+        pipeline_id,
+        avatar,
+        description,
+        permission,
+        pagerank,
+        parser_config: extractParserConfigExt(parser_config),
+        ext,
+      };
+      const { data = {} } = await updateKb(kbId, requestBody);
       if (data.code === 0) {
         message.success(i18n.t(`message.updated`));
         if (shouldFetchList) {
@@ -359,9 +471,9 @@ export const useFetchKnowledgeList = (
     gcTime: 0, // https://tanstack.com/query/latest/docs/framework/react/guides/caching?from=reactQueryV3
     queryFn: async () => {
       const { data } = await listDataset();
-      const list = data?.data?.kbs ?? [];
+      const list = data?.data ?? [];
       return shouldFilterListWithoutDocument
-        ? list.filter((x: IKnowledge) => x.chunk_num > 0)
+        ? list.filter((x: IKnowledge) => x.chunk_count > 0)
         : list;
     },
   });
diff --git a/web/src/interfaces/database/knowledge.ts b/web/src/interfaces/database/knowledge.ts
index f68a6f5b0..6ef02986b 100644
--- a/web/src/interfaces/database/knowledge.ts
+++ b/web/src/interfaces/database/knowledge.ts
@@ -11,16 +11,16 @@ export interface IConnector {
 
 // knowledge base
 export interface IKnowledge {
   avatar?: any;
-  chunk_num: number;
+  chunk_count: number;
   create_date: string;
   create_time: number;
   created_by: string;
   description: string;
-  doc_num: number;
+  document_count: number;
   id: string;
   name: string;
   parser_config: ParserConfig;
-  parser_id: string;
+  chunk_method: string;
   pipeline_id: string;
   pipeline_name: string;
   pipeline_avatar: string;
@@ -32,7 +32,7 @@ export interface IKnowledge {
   update_date: string;
   update_time: number;
   vector_similarity_weight: number;
-  embd_id: string;
+  embedding_model: string;
   nickname: string;
   operator_permission: number;
   size: number;
@@ -47,7 +47,7 @@
 
 export interface IKnowledgeResult {
   kbs: IKnowledge[];
-  total: number;
+  total_datasets: number;
 }
 
 export interface Raptor {
diff --git a/web/src/interfaces/request/knowledge.ts b/web/src/interfaces/request/knowledge.ts
index 8690c1062..f93c5dd93 100644
--- a/web/src/interfaces/request/knowledge.ts
+++ b/web/src/interfaces/request/knowledge.ts
@@ -24,10 +24,14 @@
 }
 
 export interface IFetchKnowledgeListRequestParams {
-  kb_id?: string;
-  keywords?: string;
+  id?: string;
   page?: number;
   page_size?: number;
+  ext?: {
+    keywords?: string;
+    owner_ids?: string[];
+    parser_id?: string;
+  };
 }
 
 export interface IFetchDocumentListRequestBody {
diff --git a/web/src/pages/dataset/dataset-setting/chunk-method-form.tsx b/web/src/pages/dataset/dataset-setting/chunk-method-form.tsx
index 8d6debc16..3f48eb39c 100644
--- a/web/src/pages/dataset/dataset-setting/chunk-method-form.tsx
+++ b/web/src/pages/dataset/dataset-setting/chunk-method-form.tsx
@@ -45,7 +45,7 @@ export function ChunkMethodForm() {
   const finalParserId: DocumentParserType = useWatch({
     control: form.control,
-    name: 'parser_id',
+    name: 'chunk_method',
   });
 
   const ConfigurationComponent = useMemo(() => {
diff --git a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx
index 0ce24fc29..70446e5a3 100644
--- a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx
+++ b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx
@@ -69,7 +69,7 @@ export function ChunkMethodItem(props: IProps) {
   return (
 (
@@ -121,7 +121,7 @@ export const EmbeddingSelect = ({
   const { handleChange } = useHandleKbEmbedding();
 
   const oldValue = useMemo(() => {
-    const embdStr = form.getValues(name || 'embd_id');
+    const embdStr = form.getValues(name || 'embedding_model');
     return embdStr || '';
   }, [form]);
   const [loading, setLoading] = useState(false);
@@ -165,7 +165,7 @@ export function EmbeddingModelItem({ line = 1, isEdit }: IProps) {
     <>
 (
0;
+    return knowledgeDetails.chunk_count > 0;
 }
 
 export const useFetchKnowledgeConfigurationOnMount = (
@@ -60,14 +60,14 @@
         'description',
         'name',
         'permission',
-        'embd_id',
-        'parser_id',
         'language',
         'parser_config',
         'connectors',
         'pagerank',
         'avatar',
       ]),
+      embedding_model: knowledgeDetails.embd_id,
+      chunk_method: knowledgeDetails.parser_id,
     } as z.infer;
     form.reset(formValues);
   }, [form, knowledgeDetails]);
diff --git a/web/src/pages/dataset/dataset-setting/index.tsx b/web/src/pages/dataset/dataset-setting/index.tsx
index 5cce7049a..53e2048f8 100644
--- a/web/src/pages/dataset/dataset-setting/index.tsx
+++ b/web/src/pages/dataset/dataset-setting/index.tsx
@@ -219,7 +219,7 @@ export default function DatasetSettings() {
     defaultValue: knowledgeDetails.pipeline_id ? 2 : 1,
   });
   const selectedTag = useWatch({
-    name: 'parser_id',
+    name: 'chunk_method',
     control: form.control,
   });
   useEffect(() => {
diff --git a/web/src/pages/dataset/dataset-setting/saving-button.tsx b/web/src/pages/dataset/dataset-setting/saving-button.tsx
index 8d77ca57a..76ceb1eef 100644
--- a/web/src/pages/dataset/dataset-setting/saving-button.tsx
+++ b/web/src/pages/dataset/dataset-setting/saving-button.tsx
@@ -16,7 +16,7 @@ export function GeneralSavingButton() {
     () => form.formState.defaultValues ?? {},
     [form.formState.defaultValues],
   );
-  const parser_id = defaultValues['parser_id'];
+  const chunk_method = defaultValues['chunk_method'];
 
   return (
 {
     retryDelay: 1000,
     enabled: open,
     queryFn: async () => {
-      const { data } = await kbService.traceGraphRag({
-        kb_id: id,
-      });
+      const { data } = await traceGraphRag(id);
       return data?.data || {};
     },
   });
@@ -70,9 +74,7 @@ export const useTraceGenerate = ({ open }: { open: boolean }) => {
     retryDelay: 1000,
     enabled: open,
     queryFn: async () => {
-      const { data } = await kbService.traceRaptor({
-        kb_id: id,
-      });
+      const { data } = await traceRaptor(id);
       return data?.data || {};
     },
   });
@@ -133,12 +135,8 @@ export const useDatasetGenerate = () => {
     mutationKey: [DatasetKey.generate],
     mutationFn: async ({ type }: { type: GenerateType }) => {
       const func =
-        type === GenerateType.KnowledgeGraph
-          ? kbService.runGraphRag
-          : kbService.runRaptor;
-      const { data } = await func({
-        kb_id: id,
-      });
+        type === GenerateType.KnowledgeGraph ? runGraphRag : runRaptor;
+      const { data } = await func(id);
       if (data.code === 0) {
         message.success(t('message.operated'));
         queryClient.invalidateQueries({
diff --git a/web/src/pages/datasets/dataset-card.tsx b/web/src/pages/datasets/dataset-card.tsx
index d49fd6308..338951b27 100644
--- a/web/src/pages/datasets/dataset-card.tsx
+++ b/web/src/pages/datasets/dataset-card.tsx
@@ -23,7 +23,7 @@ export function DatasetCard({
 ) {
   })
     .trim(),
   parseType: z.number().optional(),
-  embd_id: z
+  embedding_model: z
     .string()
     .min(1, {
       message: t('knowledgeConfiguration.embeddingModelPlaceholder'),
     })
     .trim(),
-  parser_id: z.string().optional(),
+  chunk_method: z.string().optional(),
   pipeline_id: z.string().optional(),
 })
   .superRefine((data, ctx) => {
-    // When parseType === 1, parser_id is required
+    // When parseType === 1, chunk_method is required
     if (
       data.parseType === 1 &&
-      (!data.parser_id || data.parser_id.trim() === '')
+      (!data.chunk_method || data.chunk_method.trim() === '')
     ) {
       ctx.addIssue({
         code: z.ZodIssueCode.custom,
@@ -82,8 +82,8 @@ export function InputForm({ onOk }: IModalProps) {
     defaultValues: {
       name: '',
       parseType: 1,
-      parser_id: '',
-      embd_id: tenantInfo?.embd_id,
+      chunk_method: '',
+      embedding_model: tenantInfo?.embd_id,
     },
   });
diff --git a/web/src/pages/datasets/hooks.ts b/web/src/pages/datasets/hooks.ts
index d194af9e1..140976ee7 100644
--- a/web/src/pages/datasets/hooks.ts
+++ b/web/src/pages/datasets/hooks.ts
@@ -16,8 +16,14 @@ export const useSearchKnowledge = () => {
 
 export interface Iknowledge {
   name: string;
-  embd_id: string;
-  parser_id: string;
+  embedding_model?: string;
+  chunk_method?: string;
+  parseType?: number;
+  pipeline_id?: string | null;
+  ext?: {
+    language?: string;
+    [key: string]: any;
+  };
 }
 export const useSaveKnowledge = () => {
   const { visible: visible, hideModal, showModal } = useSetModalState();
@@ -30,7 +36,7 @@
       if (ret?.code === 0) {
         hideModal();
-        navigateToDataset(ret.data.kb_id)();
+        navigateToDataset(ret.data.id)();
       }
     },
     [createKnowledge, hideModal, navigateToDataset],
   );
diff --git a/web/src/pages/datasets/index.tsx b/web/src/pages/datasets/index.tsx
index 05ff2a78d..85fa19226 100644
--- a/web/src/pages/datasets/index.tsx
+++ b/web/src/pages/datasets/index.tsx
@@ -30,7 +30,7 @@ export default function Datasets() {
   const {
     kbs,
-    total,
+    total_datasets,
     pagination,
     setPagination,
     handleInputChange,
@@ -107,7 +107,7 @@ export default function Datasets() {
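Taken together, the web layer now mirrors the admin client: dataset CRUD goes through the external /datasets routes instead of the /kb/* web routes. Below is a minimal Python sketch of the new request and response shapes; the base URL http://localhost:9380, the /api/v1 external prefix, the bearer token, and the placeholder embedding-model id are illustrative assumptions, while the payload keys and response fields follow the changes above.

import json
import requests

BASE = "http://localhost:9380/api/v1"          # assumed host and external prefix, illustration only
HEADERS = {"Authorization": "Bearer <web-auth-token>"}  # placeholder token

# List: GET /datasets. Filters travel in the JSON-encoded `ext` query
# parameter; the list comes back directly under `data`, with the count in
# the sibling `total_datasets` field (no more data.kbs / data.total).
res = requests.get(
    f"{BASE}/datasets",
    params={"page": 1, "page_size": 10, "ext": json.dumps({"keywords": "demo"})},
    headers=HEADERS,
).json()
datasets, total = res["data"], res["total_datasets"]

# Create: POST /datasets. `embedding_model` and `chunk_method` replace the
# old `embd_id` / `parser_id` payload keys; the response carries `id`.
created = requests.post(
    f"{BASE}/datasets",
    json={"name": "demo", "embedding_model": "<embedding-model-id>", "chunk_method": "naive"},
    headers=HEADERS,
).json()
dataset_id = created["data"]["id"]

# Update: PUT /datasets/{id}; the dataset id moves from the body to the path.
requests.put(f"{BASE}/datasets/{dataset_id}",
             json={"name": "demo-renamed", "chunk_method": "naive"},
             headers=HEADERS)

# Delete: DELETE /datasets takes a bulk `ids` list instead of a single kb_id.
requests.delete(f"{BASE}/datasets", json={"ids": [dataset_id]}, headers=HEADERS)

These are the same shapes the renamed test helpers (list_datasets, delete_datasets, update_dataset) exercise in the test files above, and the service-layer rewiring that backs them follows below.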
diff --git a/web/src/services/knowledge-service.ts b/web/src/services/knowledge-service.ts
index 4c1225990..e727fc500 100644
--- a/web/src/services/knowledge-service.ts
+++ b/web/src/services/knowledge-service.ts
@@ -1,7 +1,6 @@
 import { IRenameTag } from '@/interfaces/database/knowledge';
 import {
   IFetchDocumentListRequestBody,
-  IFetchKnowledgeListRequestBody,
   IFetchKnowledgeListRequestParams,
 } from '@/interfaces/request/knowledge';
 import { ProcessingType } from '@/pages/dataset/dataset-overview/dataset-common';
@@ -11,7 +10,6 @@ import request, { post } from '@/utils/request';
 const {
   create_kb,
-  update_kb,
   rm_kb,
   get_kb_detail,
   kb_list,
@@ -42,10 +40,6 @@
   getKnowledgeBasicInfo,
   fetchDataPipelineLog,
   fetchPipelineDatasetLogs,
-  runGraphRag,
-  traceGraphRag,
-  runRaptor,
-  traceRaptor,
   check_embedding,
   kbUpdateMetaData,
   documentUpdateMetaData,
@@ -56,13 +50,9 @@ const methods = {
     url: create_kb,
     method: 'post',
   },
-  updateKb: {
-    url: update_kb,
-    method: 'post',
-  },
   rmKb: {
     url: rm_kb,
-    method: 'post',
+    method: 'delete',
   },
   get_kb_detail: {
     url: get_kb_detail,
@@ -70,7 +60,7 @@
   },
   getList: {
     url: kb_list,
-    method: 'post',
+    method: 'get',
   },
   // document manager
   get_document_list: {
@@ -191,22 +181,6 @@
     method: 'get',
   },
-  runGraphRag: {
-    url: runGraphRag,
-    method: 'post',
-  },
-  traceGraphRag: {
-    url: traceGraphRag,
-    method: 'get',
-  },
-  runRaptor: {
-    url: runRaptor,
-    method: 'post',
-  },
-  traceRaptor: {
-    url: traceRaptor,
-    method: 'get',
-  },
   pipelineRerun: {
     url: api.pipelineRerun,
     method: 'post',
@@ -251,10 +225,23 @@
 export function deleteKnowledgeGraph(knowledgeId: string) {
   return request.delete(api.getKnowledgeGraph(knowledgeId));
 }
 
-export const listDataset = (
-  params?: IFetchKnowledgeListRequestParams,
-  body?: IFetchKnowledgeListRequestBody,
-) => request.post(api.kb_list, { data: body || {}, params });
+export const listDataset = (params?: IFetchKnowledgeListRequestParams) =>
+  request.get(api.kb_list, { params });
+
+export const updateKb = (datasetId: string, data: Record<string, any>) =>
+  request.put(api.update_kb(datasetId), { data });
+
+export const runGraphRag = (datasetId: string) =>
+  request.post(api.runGraphRag(datasetId));
+
+export const traceGraphRag = (datasetId: string) =>
+  request.get(api.traceGraphRag(datasetId));
+
+export const runRaptor = (datasetId: string) =>
+  request.post(api.runRaptor(datasetId));
+
+export const traceRaptor = (datasetId: string) =>
+  request.get(api.traceRaptor(datasetId));
 
 export const listDocument = (
   params?: IFetchKnowledgeListRequestParams,
diff --git a/web/src/utils/api.ts b/web/src/utils/api.ts
index 20440065a..5750c9aee 100644
--- a/web/src/utils/api.ts
+++ b/web/src/utils/api.ts
@@ -57,23 +57,30 @@ export default {
   // knowledge base
   check_embedding: `${api_host}/kb/check_embedding`,
-  kb_list: `${api_host}/kb/list`,
-  create_kb: `${api_host}/kb/create`,
-  update_kb: `${api_host}/kb/update`,
-  rm_kb: `${api_host}/kb/rm`,
+  kb_list: `${ExternalApi}${api_host}/datasets`,
+  create_kb: `${ExternalApi}${api_host}/datasets`,
+  update_kb: (datasetId: string) =>
+    `${ExternalApi}${api_host}/datasets/${datasetId}`,
+  rm_kb: `${ExternalApi}${api_host}/datasets`,
   get_kb_detail: `${api_host}/kb/detail`,
   getKnowledgeGraph: (knowledgeId: string) =>
-    `${api_host}/kb/${knowledgeId}/knowledge_graph`,
+    `${ExternalApi}${api_host}/datasets/${knowledgeId}/knowledge_graph`,
+  deleteKnowledgeGraph: (knowledgeId: string) =>
+    `${ExternalApi}${api_host}/datasets/${knowledgeId}/knowledge_graph`,
   getMeta: `${api_host}/kb/get_meta`,
   getKnowledgeBasicInfo: `${api_host}/kb/basic_info`,
   // data pipeline log
   fetchDataPipelineLog: `${api_host}/kb/list_pipeline_logs`,
   get_pipeline_detail: `${api_host}/kb/pipeline_log_detail`,
   fetchPipelineDatasetLogs: `${api_host}/kb/list_pipeline_dataset_logs`,
-  runGraphRag: `${api_host}/kb/run_graphrag`,
-  traceGraphRag: `${api_host}/kb/trace_graphrag`,
-  runRaptor: `${api_host}/kb/run_raptor`,
-  traceRaptor: `${api_host}/kb/trace_raptor`,
+  runGraphRag: (datasetId: string) =>
+    `${ExternalApi}${api_host}/datasets/${datasetId}/run_graphrag`,
+  traceGraphRag: (datasetId: string) =>
+    `${ExternalApi}${api_host}/datasets/${datasetId}/trace_graphrag`,
+  runRaptor: (datasetId: string) =>
+    `${ExternalApi}${api_host}/datasets/${datasetId}/run_raptor`,
+  traceRaptor: (datasetId: string) =>
+    `${ExternalApi}${api_host}/datasets/${datasetId}/trace_raptor`,
   unbindPipelineTask: ({ kb_id, type }: { kb_id: string; type: string }) =>
     `${api_host}/kb/unbind_task?kb_id=${kb_id}&pipeline_task_type=${type}`,
   pipelineRerun: `${api_host}/canvas/rerun`,
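The GraphRAG and RAPTOR routes follow the same pattern: the dataset id moves from the request body into the path, POST starts the job, and GET polls its trace. Here is a sketch of that trigger-and-poll flow in Python, under the same assumed base URL and token as the earlier sketch; the attempt cap is arbitrary, and the one-second sleep mirrors the hooks' retryDelay: 1000.

import time
import requests

BASE = "http://localhost:9380/api/v1"          # assumed host and external prefix
HEADERS = {"Authorization": "Bearer <web-auth-token>"}  # placeholder token

def run_and_trace(dataset_id: str, kind: str = "raptor") -> dict:
    """Trigger a RAPTOR or GraphRAG job for a dataset, then poll its trace route."""
    assert kind in ("raptor", "graphrag")
    # POST /datasets/{id}/run_raptor or /datasets/{id}/run_graphrag
    requests.post(f"{BASE}/datasets/{dataset_id}/run_{kind}", headers=HEADERS)
    for _ in range(30):  # attempt cap is illustrative, not taken from the diff
        # GET /datasets/{id}/trace_raptor or /datasets/{id}/trace_graphrag
        res = requests.get(f"{BASE}/datasets/{dataset_id}/trace_{kind}",
                           headers=HEADERS).json()
        task = res.get("data") or {}
        if task:  # the web hooks likewise fall back to {} while no task exists
            return task
        time.sleep(1)  # matches the hooks' retryDelay of 1000 ms
    return {}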