diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py index 86236bb52..423befafa 100644 --- a/admin/client/ragflow_client.py +++ b/admin/client/ragflow_client.py @@ -764,14 +764,14 @@ class RAGFlowClient: iterations = command.get("iterations", 1) if iterations > 1: - response = self.http_client.request("GET", "/datasets", use_api_base=True, auth_kind="web", + response = self.http_client.request("POST", "/kb/list", use_api_base=False, auth_kind="web", iterations=iterations) return response else: - response = self.http_client.request("GET", "/datasets", use_api_base=True, auth_kind="web") + response = self.http_client.request("POST", "/kb/list", use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json["data"]) + self._print_table_simple(res_json["data"]["kbs"]) else: print(f"Fail to list datasets, code: {res_json['code']}, message: {res_json['message']}") return None @@ -781,13 +781,13 @@ class RAGFlowClient: print("This command is only allowed in USER mode") payload = { "name": command["dataset_name"], - "embedding_model": command["embedding"] + "embd_id": command["embedding"] } if "parser_id" in command: - payload["chunk_method"] = command["parser"] + payload["parser_id"] = command["parser"] if "pipeline" in command: payload["pipeline_id"] = command["pipeline"] - response = self.http_client.request("POST", "/datasets", json_body=payload, use_api_base=True, + response = self.http_client.request("POST", "/kb/create", json_body=payload, use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200: @@ -803,8 +803,8 @@ class RAGFlowClient: dataset_id = self._get_dataset_id(dataset_name) if dataset_id is None: return - payload = {"ids": [dataset_id]} - response = self.http_client.request("DELETE", "/datasets", json_body=payload, use_api_base=True, auth_kind="web") + payload = {"kb_id": dataset_id} + response = 
self.http_client.request("POST", "/kb/rm", json_body=payload, use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code == 200: print(f"Drop dataset {dataset_name} successfully") @@ -1349,13 +1349,13 @@ class RAGFlowClient: return res_json["data"]["docs"] def _get_dataset_id(self, dataset_name: str): - response = self.http_client.request("GET", "/datasets", use_api_base=True, auth_kind="web") + response = self.http_client.request("POST", "/kb/list", use_api_base=False, auth_kind="web") res_json = response.json() if response.status_code != 200: print(f"Fail to list datasets, code: {res_json['code']}, message: {res_json['message']}") return None - dataset_list = res_json["data"] + dataset_list = res_json["data"]["kbs"] dataset_id: str = "" for dataset in dataset_list: if dataset["name"] == dataset_name: diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index f817de633..8a57bcd63 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import json import logging import random import re @@ -25,29 +26,34 @@ from api.db.services.connector_service import Connector2KbService from api.db.services.llm_service import LLMBundle from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks from api.db.services.doc_metadata_service import DocMetadataService +from api.db.services.file2document_service import File2DocumentService +from api.db.services.file_service import FileService from api.db.services.pipeline_operation_log_service import PipelineOperationLogService from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID -from api.db.services.user_service import UserTenantService +from api.db.services.user_service import TenantService, UserTenantService from api.db.joint_services.tenant_model_service import get_model_config_by_type_and_name, get_model_config_by_id from api.utils.api_utils import ( get_error_data_result, server_error_response, get_data_error_result, validate_request, + not_allowed_parameters, get_request_json, ) +from common.misc_utils import thread_pool_exec from api.db import VALID_FILE_TYPES from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.db_models import File from api.utils.api_utils import get_json_result +from api.utils.tenant_utils import ensure_tenant_model_id_for_params from rag.nlp import search +from api.constants import DATASET_NAME_LIMIT from rag.utils.redis_conn import REDIS_CONN -from common.constants import RetCode, PipelineTaskType, VALID_TASK_STATUS, LLMType +from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD from common import settings from common.doc_store.doc_store_base import OrderByExpr from api.apps import login_required, current_user -""" -Deprecated, todo delete @manager.route('/create', methods=['post']) # noqa: F821 @login_required @validate_request("name") @@ -180,7 +186,7 @@ async def update(): return 
get_json_result(data=kb) except Exception as e: return server_error_response(e) -""" + @manager.route('/update_metadata_setting', methods=['post']) # noqa: F821 @login_required @@ -228,8 +234,7 @@ def detail(): except Exception as e: return server_error_response(e) -""" -Deprecated, todo delete + @manager.route('/list', methods=['POST']) # noqa: F821 @login_required async def list_kbs(): @@ -324,7 +329,7 @@ async def rm(): return await thread_pool_exec(_rm_sync) except Exception as e: return server_error_response(e) -""" + @manager.route('//tags', methods=['GET']) # noqa: F821 @login_required @@ -400,8 +405,7 @@ async def rename_tags(kb_id): kb_id) return get_json_result(data=True) -""" -Deprecated, todo delete + @manager.route('//knowledge_graph', methods=['GET']) # noqa: F821 @login_required async def knowledge_graph(kb_id): @@ -455,7 +459,7 @@ def delete_knowledge_graph(kb_id): settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id) return get_json_result(data=True) -""" + @manager.route("/get_meta", methods=["GET"]) # noqa: F821 @login_required @@ -594,8 +598,6 @@ def pipeline_log_detail(): return get_json_result(data=log.to_dict()) -""" -Deprecated, todo delete @manager.route("/run_graphrag", methods=["POST"]) # noqa: F821 @login_required async def run_graphrag(): @@ -732,7 +734,7 @@ def trace_raptor(): return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred") return get_json_result(data=task.to_dict()) -""" + @manager.route("/run_mindmap", methods=["POST"]) # noqa: F821 @login_required diff --git a/api/apps/restful_apis/dataset_api.py b/api/apps/restful_apis/dataset_api.py deleted file mode 100644 index 4f3ff2d59..000000000 --- a/api/apps/restful_apis/dataset_api.py +++ /dev/null @@ -1,517 +0,0 @@ -# -# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import logging - -from peewee import OperationalError -from quart import request -from common.constants import RetCode -from api.apps import login_required, current_user -from api.utils.api_utils import get_error_argument_result, get_error_data_result, get_result, add_tenant_id_to_kwargs -from api.utils.validation_utils import ( - CreateDatasetReq, - DeleteDatasetReq, - ListDatasetReq, - UpdateDatasetReq, - validate_and_parse_json_request, - validate_and_parse_request_args, -) -from api.apps.services import dataset_api_service - - -@manager.route("/datasets", methods=["POST"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -async def create(tenant_id: str=None): - """ - Create a new dataset. - --- - tags: - - Datasets - security: - - ApiKeyAuth: [] - parameters: - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - - in: body - name: body - description: Dataset creation parameters. - required: true - schema: - type: object - required: - - name - properties: - name: - type: string - description: Dataset name (required). - avatar: - type: string - description: Optional base64-encoded avatar image. - description: - type: string - description: Optional dataset description. - embedding_model: - type: string - description: Optional embedding model name; if omitted, the tenant's default embedding model is used. 
- permission: - type: string - enum: ['me', 'team'] - description: Visibility of the dataset (private to me or shared with team). - chunk_method: - type: string - enum: ["naive", "book", "email", "laws", "manual", "one", "paper", - "picture", "presentation", "qa", "table", "tag"] - description: Chunking method; if omitted, defaults to "naive". - parser_config: - type: object - description: Optional parser configuration; server-side defaults will be applied. - responses: - 200: - description: Successful operation. - schema: - type: object - properties: - data: - type: object - """ - req, err = await validate_and_parse_json_request(request, CreateDatasetReq) - if err is not None: - return get_error_argument_result(err) - - try: - if not tenant_id: - tenant_id = current_user.id - success, result = await dataset_api_service.create_dataset(tenant_id, req) - if success: - return get_result(data=result) - else: - return get_error_data_result(message=result) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - -@manager.route("/datasets", methods=["DELETE"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -async def delete(tenant_id): - """ - Delete datasets. - --- - tags: - - Datasets - security: - - ApiKeyAuth: [] - parameters: - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - - in: body - name: body - description: Dataset deletion parameters. - required: true - schema: - type: object - required: - - ids - properties: - ids: - type: array or null - items: - type: string - description: | - Specifies the datasets to delete: - - If `null`, all datasets will be deleted. - - If an array of IDs, only the specified datasets will be deleted. - - If an empty array, no datasets will be deleted. - responses: - 200: - description: Successful operation. 
- schema: - type: object - """ - req, err = await validate_and_parse_json_request(request, DeleteDatasetReq) - if err is not None: - return get_error_argument_result(err) - - try: - success, result = await dataset_api_service.delete_datasets(tenant_id, req.get("ids"), req.get("delete_all", False)) - if success: - return get_result(data=result) - else: - return get_error_data_result(message=result) - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - -@manager.route("/datasets/", methods=["PUT"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -async def update(tenant_id, dataset_id): - """ - Update a dataset. - --- - tags: - - Datasets - security: - - ApiKeyAuth: [] - parameters: - - in: path - name: dataset_id - type: string - required: true - description: ID of the dataset to update. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - - in: body - name: body - description: Dataset update parameters. - required: true - schema: - type: object - properties: - name: - type: string - description: New name of the dataset. - avatar: - type: string - description: Updated base64 encoding of the avatar. - description: - type: string - description: Updated description of the dataset. - embedding_model: - type: string - description: Updated embedding model Name. - permission: - type: string - enum: ['me', 'team'] - description: Updated dataset permission. - chunk_method: - type: string - enum: ["naive", "book", "email", "laws", "manual", "one", "paper", - "picture", "presentation", "qa", "table", "tag" - ] - description: Updated chunking method. - pagerank: - type: integer - description: Updated page rank. - parser_config: - type: object - description: Updated parser configuration. 
- responses: - 200: - description: Successful operation. - schema: - type: object - """ - # Field name transformations during model dump: - # | Original | Dump Output | - # |----------------|-------------| - # | embedding_model| embd_id | - # | chunk_method | parser_id | - extras = {"dataset_id": dataset_id} - req, err = await validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True) - if err is not None: - return get_error_argument_result(err) - - try: - success, result = await dataset_api_service.update_dataset(tenant_id, dataset_id, req) - if success: - return get_result(data=result) - else: - return get_error_data_result(message=result) - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - -@manager.route("/datasets", methods=["GET"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -def list_datasets(tenant_id): - """ - List datasets. - --- - tags: - - Datasets - security: - - ApiKeyAuth: [] - parameters: - - in: query - name: id - type: string - required: false - description: Dataset ID to filter. - - in: query - name: name - type: string - required: false - description: Dataset name to filter. - - in: query - name: page - type: integer - required: false - default: 1 - description: Page number. - - in: query - name: page_size - type: integer - required: false - default: 30 - description: Number of items per page. - - in: query - name: orderby - type: string - required: false - default: "create_time" - description: Field to order by. - - in: query - name: desc - type: boolean - required: false - default: true - description: Order in descending. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - responses: - 200: - description: Successful operation. 
- schema: - type: array - items: - type: object - """ - args, err = validate_and_parse_request_args(request, ListDatasetReq) - if err is not None: - return get_error_argument_result(err) - - try: - success, result = dataset_api_service.list_datasets(tenant_id, args) - if success: - return get_result(data=result.get("data"), total=result.get("total")) - else: - return get_error_data_result(message=result) - except OperationalError as e: - logging.exception(e) - return get_error_data_result(message="Database operation failed") - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - -@manager.route('/datasets//knowledge_graph', methods=['GET']) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -async def knowledge_graph(tenant_id, dataset_id): - try: - success, result = await dataset_api_service.get_knowledge_graph(dataset_id, tenant_id) - if success: - return get_result(data=result) - else: - return get_result( - data=False, - message=result, - code=RetCode.AUTHENTICATION_ERROR - ) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - -@manager.route('/datasets//knowledge_graph', methods=['DELETE']) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -def delete_knowledge_graph(tenant_id, dataset_id): - try: - success, result = dataset_api_service.delete_knowledge_graph(dataset_id, tenant_id) - if success: - return get_result(data=result) - else: - return get_result( - data=False, - message=result, - code=RetCode.AUTHENTICATION_ERROR - ) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - -@manager.route("/datasets//run_graphrag", methods=["POST"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -async def run_graphrag(tenant_id, dataset_id): - try: - success, result = dataset_api_service.run_graphrag(dataset_id, tenant_id) - if success: - return 
get_result(data=result) - else: - return get_error_data_result(message=result) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - -@manager.route("/datasets//trace_graphrag", methods=["GET"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -def trace_graphrag(tenant_id, dataset_id): - try: - success, result = dataset_api_service.trace_graphrag(dataset_id, tenant_id) - if success: - return get_result(data=result) - else: - return get_error_data_result(message=result) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - -@manager.route("/datasets//run_raptor", methods=["POST"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -async def run_raptor(tenant_id, dataset_id): - try: - success, result = dataset_api_service.run_raptor(dataset_id, tenant_id) - if success: - return get_result(data=result) - else: - return get_error_data_result(message=result) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - -@manager.route("/datasets//trace_raptor", methods=["GET"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -def trace_raptor(tenant_id, dataset_id): - try: - success, result = dataset_api_service.trace_raptor(dataset_id, tenant_id) - if success: - return get_result(data=result) - else: - return get_error_data_result(message=result) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - -@manager.route("/datasets//auto_metadata", methods=["GET"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -def get_auto_metadata(tenant_id, dataset_id): - """ - Get auto-metadata configuration for a dataset. - --- - tags: - - Datasets - security: - - ApiKeyAuth: [] - parameters: - - in: path - name: dataset_id - type: string - required: true - description: ID of the dataset. 
- - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - responses: - 200: - description: Successful operation. - schema: - type: object - """ - try: - success, result = dataset_api_service.get_auto_metadata(dataset_id, tenant_id) - if success: - return get_result(data=result) - else: - return get_error_data_result(message=result) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - -@manager.route("/datasets//auto_metadata", methods=["PUT"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -async def update_auto_metadata(tenant_id, dataset_id): - """ - Update auto-metadata configuration for a dataset. - --- - tags: - - Datasets - security: - - ApiKeyAuth: [] - parameters: - - in: path - name: dataset_id - type: string - required: true - description: ID of the dataset. - - in: header - name: Authorization - type: string - required: true - description: Bearer token for authentication. - - in: body - name: body - description: Auto-metadata configuration. - required: true - schema: - type: object - responses: - 200: - description: Successful operation. 
- schema: - type: object - """ - from api.utils.validation_utils import AutoMetadataConfig - cfg, err = await validate_and_parse_json_request(request, AutoMetadataConfig) - if err is not None: - return get_error_argument_result(err) - - try: - success, result = await dataset_api_service.update_auto_metadata(dataset_id, tenant_id, cfg) - if success: - return get_result(data=result) - else: - return get_error_data_result(message=result) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py index 672adde6e..79a85d631 100644 --- a/api/apps/restful_apis/memory_api.py +++ b/api/apps/restful_apis/memory_api.py @@ -1,5 +1,5 @@ # -# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py new file mode 100644 index 000000000..bda151671 --- /dev/null +++ b/api/apps/sdk/dataset.py @@ -0,0 +1,798 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import logging +import os +import json +from quart import request +from peewee import OperationalError +from api.db.db_models import File +from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks +from api.db.services.file2document_service import File2DocumentService +from api.db.services.file_service import FileService +from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService +from api.db.services.user_service import TenantService +from common.constants import RetCode, FileSource, StatusEnum +from api.utils.api_utils import ( + deep_merge, + get_error_argument_result, + get_error_data_result, + get_error_permission_result, + get_parser_config, + get_result, + remap_dictionary_keys, + token_required, + verify_embedding_availability, +) +from api.utils.validation_utils import ( + AutoMetadataConfig, + CreateDatasetReq, + DeleteDatasetReq, + ListDatasetReq, + UpdateDatasetReq, + validate_and_parse_json_request, + validate_and_parse_request_args, +) +from rag.nlp import search +from common.constants import PAGERANK_FLD +from common import settings + + +@manager.route("/datasets", methods=["POST"]) # noqa: F821 +@token_required +async def create(tenant_id): + """ + Create a new dataset. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Dataset creation parameters. + required: true + schema: + type: object + required: + - name + properties: + name: + type: string + description: Dataset name (required). + avatar: + type: string + description: Optional base64-encoded avatar image. + description: + type: string + description: Optional dataset description. 
+ embedding_model: + type: string + description: Optional embedding model name; if omitted, the tenant's default embedding model is used. + permission: + type: string + enum: ['me', 'team'] + description: Visibility of the dataset (private to me or shared with team). + chunk_method: + type: string + enum: ["naive", "book", "email", "laws", "manual", "one", "paper", + "picture", "presentation", "qa", "table", "tag"] + description: Chunking method; if omitted, defaults to "naive". + parser_config: + type: object + description: Optional parser configuration; server-side defaults will be applied. + responses: + 200: + description: Successful operation. + schema: + type: object + properties: + data: + type: object + """ + # Field name transformations during model dump: + # | Original | Dump Output | + # |----------------|-------------| + # | embedding_model| embd_id | + # | chunk_method | parser_id | + + req, err = await validate_and_parse_json_request(request, CreateDatasetReq) + if err is not None: + return get_error_argument_result(err) + # Map auto_metadata_config (if provided) into parser_config structure + auto_meta = req.pop("auto_metadata_config", None) + if auto_meta: + parser_cfg = req.get("parser_config") or {} + fields = [] + for f in auto_meta.get("fields", []): + fields.append( + { + "name": f.get("name", ""), + "type": f.get("type", ""), + "description": f.get("description"), + "examples": f.get("examples"), + "restrict_values": f.get("restrict_values", False), + } + ) + parser_cfg["metadata"] = fields + parser_cfg["enable_metadata"] = auto_meta.get("enabled", True) + req["parser_config"] = parser_cfg + e, req = KnowledgebaseService.create_with_name(name=req.pop("name", None), tenant_id=tenant_id, parser_id=req.pop("parser_id", None), **req) + + if not e: + return req + + # Insert embedding model(embd id) + ok, t = TenantService.get_by_id(tenant_id) + if not ok: + return get_error_permission_result(message="Tenant not found") + if not req.get("embd_id"): 
+ req["embd_id"] = t.embd_id + else: + ok, err = verify_embedding_availability(req["embd_id"], tenant_id) + if not ok: + return err + + try: + if not KnowledgebaseService.save(**req): + return get_error_data_result() + ok, k = KnowledgebaseService.get_by_id(req["id"]) + if not ok: + return get_error_data_result(message="Dataset created failed") + response_data = remap_dictionary_keys(k.to_dict()) + return get_result(data=response_data) + except Exception as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + + +@manager.route("/datasets", methods=["DELETE"]) # noqa: F821 +@token_required +async def delete(tenant_id): + """ + Delete datasets. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Dataset deletion parameters. + required: true + schema: + type: object + required: + - ids + properties: + ids: + type: array or null + items: + type: string + description: | + List of dataset IDs to delete. + If `null` or an empty array is provided, no datasets will be deleted + unless `delete_all` is set to `true`. + delete_all: + type: boolean + description: | + If `true` and `ids` is null or empty, delete all datasets owned by the current user. + Defaults to `false`. + responses: + 200: + description: Successful operation. 
+ schema: + type: object + """ + req, err = await validate_and_parse_json_request(request, DeleteDatasetReq) + if err is not None: + return get_error_argument_result(err) + + try: + kb_id_instance_pairs = [] + if req["ids"] is None or len(req["ids"]) == 0: + if req.get("delete_all"): + req["ids"] = [kb.id for kb in KnowledgebaseService.query(tenant_id=tenant_id)] + if not req["ids"]: + return get_result() + else: + return get_result() + + error_kb_ids = [] + for kb_id in req["ids"]: + kb = KnowledgebaseService.get_or_none(id=kb_id, tenant_id=tenant_id) + if kb is None: + error_kb_ids.append(kb_id) + continue + kb_id_instance_pairs.append((kb_id, kb)) + if len(error_kb_ids) > 0: + return get_error_permission_result(message=f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""") + + errors = [] + success_count = 0 + for kb_id, kb in kb_id_instance_pairs: + for doc in DocumentService.query(kb_id=kb_id): + if not DocumentService.remove_document(doc, tenant_id): + errors.append(f"Remove document '{doc.id}' error for dataset '{kb_id}'") + continue + f2d = File2DocumentService.get_by_document_id(doc.id) + FileService.filter_delete( + [ + File.source_type == FileSource.KNOWLEDGEBASE, + File.id == f2d[0].file_id, + ] + ) + File2DocumentService.delete_by_document_id(doc.id) + FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) + + # Drop index for this dataset + try: + from rag.nlp import search + + idxnm = search.index_name(kb.tenant_id) + settings.docStoreConn.delete_idx(idxnm, kb_id) + except Exception as e: + logging.warning(f"Failed to drop index for dataset {kb_id}: {e}") + + if not KnowledgebaseService.delete_by_id(kb_id): + errors.append(f"Delete dataset error for {kb_id}") + continue + success_count += 1 + + if not errors: + return get_result() + + error_message = f"Successfully deleted {success_count} datasets, {len(errors)} failed. 
Details: {'; '.join(errors)[:128]}..." + if success_count == 0: + return get_error_data_result(message=error_message) + + return get_result(data={"success_count": success_count, "errors": errors[:5]}, message=error_message) + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + + +@manager.route("/datasets/", methods=["PUT"]) # noqa: F821 +@token_required +async def update(tenant_id, dataset_id): + """ + Update a dataset. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset to update. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + - in: body + name: body + description: Dataset update parameters. + required: true + schema: + type: object + properties: + name: + type: string + description: New name of the dataset. + avatar: + type: string + description: Updated base64 encoding of the avatar. + description: + type: string + description: Updated description of the dataset. + embedding_model: + type: string + description: Updated embedding model Name. + permission: + type: string + enum: ['me', 'team'] + description: Updated dataset permission. + chunk_method: + type: string + enum: ["naive", "book", "email", "laws", "manual", "one", "paper", + "picture", "presentation", "qa", "table", "tag" + ] + description: Updated chunking method. + pagerank: + type: integer + description: Updated page rank. + parser_config: + type: object + description: Updated parser configuration. + responses: + 200: + description: Successful operation. 
+ schema: + type: object + """ + # Field name transformations during model dump: + # | Original | Dump Output | + # |----------------|-------------| + # | embedding_model| embd_id | + # | chunk_method | parser_id | + extras = {"dataset_id": dataset_id} + req, err = await validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True) + if err is not None: + return get_error_argument_result(err) + + if not req: + return get_error_argument_result(message="No properties were modified") + + try: + kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) + if kb is None: + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") + + # Map auto_metadata_config into parser_config if present + auto_meta = req.pop("auto_metadata_config", None) + if auto_meta: + parser_cfg = req.get("parser_config") or {} + fields = [] + for f in auto_meta.get("fields", []): + fields.append( + { + "name": f.get("name", ""), + "type": f.get("type", ""), + "description": f.get("description"), + "examples": f.get("examples"), + "restrict_values": f.get("restrict_values", False), + } + ) + parser_cfg["metadata"] = fields + parser_cfg["enable_metadata"] = auto_meta.get("enabled", True) + req["parser_config"] = parser_cfg + + if req.get("parser_config"): + req["parser_config"] = deep_merge(kb.parser_config, req["parser_config"]) + + if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id: + if not req.get("parser_config"): + req["parser_config"] = get_parser_config(chunk_method, None) + elif "parser_config" in req and not req["parser_config"]: + del req["parser_config"] + + if "name" in req and req["name"].lower() != kb.name.lower(): + exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value) + if exists: + return get_error_data_result(message=f"Dataset name '{req['name']}' already exists") + + if "embd_id" in req: + if not 
req["embd_id"]: + req["embd_id"] = kb.embd_id + if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: + return get_error_data_result(message=f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}") + ok, err = verify_embedding_availability(req["embd_id"], tenant_id) + if not ok: + return err + + if "pagerank" in req and req["pagerank"] != kb.pagerank: + if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity": + return get_error_argument_result(message="'pagerank' can only be set when doc_engine is elasticsearch") + + if req["pagerank"] > 0: + settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, search.index_name(kb.tenant_id), kb.id) + else: + # Elasticsearch requires PAGERANK_FLD be non-zero! + settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, search.index_name(kb.tenant_id), kb.id) + + if not KnowledgebaseService.update_by_id(kb.id, req): + return get_error_data_result(message="Update dataset error.(Database error)") + + ok, k = KnowledgebaseService.get_by_id(kb.id) + if not ok: + return get_error_data_result(message="Dataset updated failed") + + response_data = remap_dictionary_keys(k.to_dict()) + return get_result(data=response_data) + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + + +@manager.route("/datasets", methods=["GET"]) # noqa: F821 +@token_required +def list_datasets(tenant_id): + """ + List datasets. + --- + tags: + - Datasets + security: + - ApiKeyAuth: [] + parameters: + - in: query + name: id + type: string + required: false + description: Dataset ID to filter. + - in: query + name: name + type: string + required: false + description: Dataset name to filter. + - in: query + name: page + type: integer + required: false + default: 1 + description: Page number. + - in: query + name: page_size + type: integer + required: false + default: 30 + description: Number of items per page. 
+ - in: query + name: orderby + type: string + required: false + default: "create_time" + description: Field to order by. + - in: query + name: desc + type: boolean + required: false + default: true + description: Order in descending. + - in: query + name: include_parsing_status + type: boolean + required: false + default: false + description: | + Whether to include document parsing status counts in the response. + When true, each dataset object will include: unstart_count, running_count, + cancel_count, done_count, and fail_count. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Successful operation. + schema: + type: array + items: + type: object + """ + args, err = validate_and_parse_request_args(request, ListDatasetReq) + if err is not None: + return get_error_argument_result(err) + + include_parsing_status = args.get("include_parsing_status", False) + + try: + kb_id = request.args.get("id") + name = args.get("name") + # check whether user has permission for the dataset with specified id + if kb_id: + if not KnowledgebaseService.get_kb_by_id(kb_id, tenant_id): + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{kb_id}'") + # check whether user has permission for the dataset with specified name + if name: + if not KnowledgebaseService.get_kb_by_name(name, tenant_id): + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{name}'") + + tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) + kbs, total = KnowledgebaseService.get_list( + [m["tenant_id"] for m in tenants], + tenant_id, + args["page"], + args["page_size"], + args["orderby"], + args["desc"], + kb_id, + name, + ) + + parsing_status_map = {} + if include_parsing_status and kbs: + kb_ids = [kb["id"] for kb in kbs] + parsing_status_map = DocumentService.get_parsing_status_by_kb_ids(kb_ids) + + 
response_data_list = [] + for kb in kbs: + data = remap_dictionary_keys(kb) + if include_parsing_status: + data.update(parsing_status_map.get(kb["id"], {})) + response_data_list.append(data) + return get_result(data=response_data_list, total=total) + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + + +@manager.route("/datasets//auto_metadata", methods=["GET"]) # noqa: F821 +@token_required +def get_auto_metadata(tenant_id, dataset_id): + """ + Get auto-metadata configuration for a dataset. + """ + try: + kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) + if kb is None: + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") + + parser_cfg = kb.parser_config or {} + metadata = parser_cfg.get("metadata") or [] + enabled = parser_cfg.get("enable_metadata", bool(metadata)) + # Normalize to AutoMetadataConfig-like JSON + fields = [] + for f in metadata: + if not isinstance(f, dict): + continue + fields.append( + { + "name": f.get("name", ""), + "type": f.get("type", ""), + "description": f.get("description"), + "examples": f.get("examples"), + "restrict_values": f.get("restrict_values", False), + } + ) + return get_result(data={"enabled": enabled, "fields": fields}) + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + + +@manager.route("/datasets//auto_metadata", methods=["PUT"]) # noqa: F821 +@token_required +async def update_auto_metadata(tenant_id, dataset_id): + """ + Update auto-metadata configuration for a dataset. 
+ """ + cfg, err = await validate_and_parse_json_request(request, AutoMetadataConfig) + if err is not None: + return get_error_argument_result(err) + + try: + kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) + if kb is None: + return get_error_permission_result(message=f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'") + + parser_cfg = kb.parser_config or {} + fields = [] + for f in cfg.get("fields", []): + fields.append( + { + "name": f.get("name", ""), + "type": f.get("type", ""), + "description": f.get("description"), + "examples": f.get("examples"), + "restrict_values": f.get("restrict_values", False), + } + ) + parser_cfg["metadata"] = fields + parser_cfg["enable_metadata"] = cfg.get("enabled", True) + + if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": parser_cfg}): + return get_error_data_result(message="Update auto-metadata error.(Database error)") + + return get_result(data={"enabled": parser_cfg["enable_metadata"], "fields": fields}) + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + + +@manager.route("/datasets//knowledge_graph", methods=["GET"]) # noqa: F821 +@token_required +async def knowledge_graph(tenant_id, dataset_id): + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + _, kb = KnowledgebaseService.get_by_id(dataset_id) + req = {"kb_id": [dataset_id], "knowledge_graph_kwd": ["graph"]} + + obj = {"graph": {}, "mind_map": {}} + if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id): + return get_result(data=obj) + sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id]) + if not len(sres.ids): + return get_result(data=obj) + + for id in sres.ids[:1]: + ty = sres.field[id]["knowledge_graph_kwd"] + try: + content_json = 
json.loads(sres.field[id]["content_with_weight"]) + except Exception: + continue + + obj[ty] = content_json + + if "nodes" in obj["graph"]: + obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256] + if "edges" in obj["graph"]: + node_id_set = {o["id"] for o in obj["graph"]["nodes"]} + filtered_edges = [o for o in obj["graph"]["edges"] if o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set] + obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128] + return get_result(data=obj) + + +@manager.route("/datasets//knowledge_graph", methods=["DELETE"]) # noqa: F821 +@token_required +def delete_knowledge_graph(tenant_id, dataset_id): + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + _, kb = KnowledgebaseService.get_by_id(dataset_id) + settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), dataset_id) + + return get_result(data=True) + + +@manager.route("/datasets//run_graphrag", methods=["POST"]) # noqa: F821 +@token_required +def run_graphrag(tenant_id, dataset_id): + if not dataset_id: + return get_error_data_result(message='Lack of "Dataset ID"') + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return get_error_data_result(message="Invalid Dataset ID") + + task_id = kb.graphrag_task_id + if task_id: + ok, task = TaskService.get_by_id(task_id) + if not ok: + logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}") + + if task and task.progress not in [-1, 1]: + return get_error_data_result(message=f"Task {task_id} in progress with status 
{task.progress}. A Graph Task is already running.") + + documents, _ = DocumentService.get_by_kb_id( + kb_id=dataset_id, + page_number=0, + items_per_page=0, + orderby="create_time", + desc=False, + keywords="", + run_status=[], + types=[], + suffix=[], + ) + if not documents: + return get_error_data_result(message=f"No documents in Dataset {dataset_id}") + + sample_document = documents[0] + document_ids = [document["id"] for document in documents] + + task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) + + if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}): + logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}") + + return get_result(data={"graphrag_task_id": task_id}) + + +@manager.route("/datasets//trace_graphrag", methods=["GET"]) # noqa: F821 +@token_required +def trace_graphrag(tenant_id, dataset_id): + if not dataset_id: + return get_error_data_result(message='Lack of "Dataset ID"') + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return get_error_data_result(message="Invalid Dataset ID") + + task_id = kb.graphrag_task_id + if not task_id: + return get_result(data={}) + + ok, task = TaskService.get_by_id(task_id) + if not ok: + return get_result(data={}) + + return get_result(data=task.to_dict()) + + +@manager.route("/datasets//run_raptor", methods=["POST"]) # noqa: F821 +@token_required +def run_raptor(tenant_id, dataset_id): + if not dataset_id: + return get_error_data_result(message='Lack of "Dataset ID"') + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) 
+ if not ok: + return get_error_data_result(message="Invalid Dataset ID") + + task_id = kb.raptor_task_id + if task_id: + ok, task = TaskService.get_by_id(task_id) + if not ok: + logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}") + + if task and task.progress not in [-1, 1]: + return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.") + + documents, _ = DocumentService.get_by_kb_id( + kb_id=dataset_id, + page_number=0, + items_per_page=0, + orderby="create_time", + desc=False, + keywords="", + run_status=[], + types=[], + suffix=[], + ) + if not documents: + return get_error_data_result(message=f"No documents in Dataset {dataset_id}") + + sample_document = documents[0] + document_ids = [document["id"] for document in documents] + + task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) + + if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}): + logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}") + + return get_result(data={"raptor_task_id": task_id}) + + +@manager.route("/datasets//trace_raptor", methods=["GET"]) # noqa: F821 +@token_required +def trace_raptor(tenant_id, dataset_id): + if not dataset_id: + return get_error_data_result(message='Lack of "Dataset ID"') + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result( + data=False, + message='No authorization.', + code=RetCode.AUTHENTICATION_ERROR + ) + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return get_error_data_result(message="Invalid Dataset ID") + + task_id = kb.raptor_task_id + if not task_id: + return get_result(data={}) + + ok, task = TaskService.get_by_id(task_id) + if not ok: + return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred") + + return 
get_result(data=task.to_dict()) diff --git a/api/apps/services/dataset_api_service.py b/api/apps/services/dataset_api_service.py deleted file mode 100644 index 094570528..000000000 --- a/api/apps/services/dataset_api_service.py +++ /dev/null @@ -1,613 +0,0 @@ -# -# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import logging -import json -import os -from common.constants import PAGERANK_FLD -from common import settings -from api.db.db_models import File -from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks -from api.db.services.file2document_service import File2DocumentService -from api.db.services.file_service import FileService -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.connector_service import Connector2KbService -from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService -from api.db.services.user_service import TenantService, UserService -from common.constants import FileSource, StatusEnum -from api.utils.api_utils import deep_merge, get_parser_config, remap_dictionary_keys, verify_embedding_availability - - -async def create_dataset(tenant_id: str, req: dict): - """ - Create a new dataset. 
- - :param tenant_id: tenant ID - :param req: dataset creation request - :return: (success, result) or (success, error_message) - """ - # Extract ext field for additional parameters - ext_fields = req.pop("ext", {}) - - # Map auto_metadata_config (if provided) into parser_config structure - auto_meta = req.pop("auto_metadata_config", {}) - if auto_meta: - parser_cfg = req.get("parser_config") or {} - fields = [] - for f in auto_meta.get("fields", []): - fields.append( - { - "name": f.get("name", ""), - "type": f.get("type", ""), - "description": f.get("description"), - "examples": f.get("examples"), - "restrict_values": f.get("restrict_values", False), - } - ) - parser_cfg["metadata"] = fields - parser_cfg["enable_metadata"] = auto_meta.get("enabled", True) - req["parser_config"] = parser_cfg - req.update(ext_fields) - - e, create_dict = KnowledgebaseService.create_with_name( - name=req.pop("name", None), - tenant_id=tenant_id, - parser_id=req.pop("parser_id", None), - **req - ) - - if not e: - return False, create_dict - - # Insert embedding model(embd id) - ok, t = TenantService.get_by_id(tenant_id) - if not ok: - return False, "Tenant not found" - if not create_dict.get("embd_id"): - create_dict["embd_id"] = t.embd_id - else: - ok, err = verify_embedding_availability(create_dict["embd_id"], tenant_id) - if not ok: - return False, err - - if not KnowledgebaseService.save(**create_dict): - return False, "Failed to save dataset" - ok, k = KnowledgebaseService.get_by_id(create_dict["id"]) - if not ok: - return False, "Dataset created failed" - response_data = remap_dictionary_keys(k.to_dict()) - return True, response_data - - -async def delete_datasets(tenant_id: str, ids: list = None, delete_all: bool = False): - """ - Delete datasets. 
- - :param tenant_id: tenant ID - :param ids: list of dataset IDs - :param delete_all: whether to delete all datasets of the tenant (if ids is not provided) - :return: (success, result) or (success, error_message) - """ - kb_id_instance_pairs = [] - if not ids: - if not delete_all: - return True, {"success_count": 0} - else: - ids = [kb.id for kb in KnowledgebaseService.query(tenant_id=tenant_id)] - - error_kb_ids = [] - for kb_id in ids: - kb = KnowledgebaseService.get_or_none(id=kb_id, tenant_id=tenant_id) - if kb is None: - error_kb_ids.append(kb_id) - continue - kb_id_instance_pairs.append((kb_id, kb)) - if len(error_kb_ids) > 0: - return False, f"""User '{tenant_id}' lacks permission for datasets: '{", ".join(error_kb_ids)}'""" - - errors = [] - success_count = 0 - for kb_id, kb in kb_id_instance_pairs: - for doc in DocumentService.query(kb_id=kb_id): - if not DocumentService.remove_document(doc, tenant_id): - errors.append(f"Remove document '{doc.id}' error for dataset '{kb_id}'") - continue - f2d = File2DocumentService.get_by_document_id(doc.id) - FileService.filter_delete( - [ - File.source_type == FileSource.KNOWLEDGEBASE, - File.id == f2d[0].file_id, - ] - ) - File2DocumentService.delete_by_document_id(doc.id) - FileService.filter_delete( - [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kb.name]) - - # Drop index for this dataset - try: - from rag.nlp import search - idxnm = search.index_name(kb.tenant_id) - settings.docStoreConn.delete_idx(idxnm, kb_id) - except Exception as e: - errors.append(f"Failed to drop index for dataset {kb_id}: {e}") - - if not KnowledgebaseService.delete_by_id(kb_id): - errors.append(f"Delete dataset error for {kb_id}") - continue - success_count += 1 - - if not errors: - return True, {"success_count": success_count} - - error_message = f"Successfully deleted {success_count} datasets, {len(errors)} failed. Details: {'; '.join(errors)[:128]}..." 
- if success_count == 0: - return False, error_message - - return True, {"success_count": success_count, "errors": errors[:5]} - - -async def update_dataset(tenant_id: str, dataset_id: str, req: dict): - """ - Update a dataset. - - :param tenant_id: tenant ID - :param dataset_id: dataset ID - :param req: dataset update request - :return: (success, result) or (success, error_message) - """ - if not req: - return False, "No properties were modified" - - kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) - if kb is None: - return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" - - # Extract ext field for additional parameters - ext_fields = req.pop("ext", {}) - - # Map auto_metadata_config into parser_config if present - auto_meta = req.pop("auto_metadata_config", {}) - if auto_meta: - parser_cfg = req.get("parser_config") or {} - fields = [] - for f in auto_meta.get("fields", []): - fields.append( - { - "name": f.get("name", ""), - "type": f.get("type", ""), - "description": f.get("description"), - "examples": f.get("examples"), - "restrict_values": f.get("restrict_values", False), - } - ) - parser_cfg["metadata"] = fields - parser_cfg["enable_metadata"] = auto_meta.get("enabled", True) - req["parser_config"] = parser_cfg - - # Merge ext fields with req - req.update(ext_fields) - - # Extract connectors from request - connectors = [] - if "connectors" in req: - connectors = req["connectors"] - del req["connectors"] - - if req.get("parser_config"): - parser_config = req["parser_config"] - req_ext_fields = parser_config.pop("ext", {}) - parser_config.update(req_ext_fields) - req["parser_config"] = deep_merge(kb.parser_config, parser_config) - - if (chunk_method := req.get("parser_id")) and chunk_method != kb.parser_id: - if not req.get("parser_config"): - req["parser_config"] = get_parser_config(chunk_method, None) - elif "parser_config" in req and not req["parser_config"]: - del req["parser_config"] - - if "name" in req 
and req["name"].lower() != kb.name.lower(): - exists = KnowledgebaseService.get_or_none(name=req["name"], tenant_id=tenant_id, - status=StatusEnum.VALID.value) - if exists: - return False, f"Dataset name '{req['name']}' already exists" - - if "embd_id" in req: - if not req["embd_id"]: - req["embd_id"] = kb.embd_id - if kb.chunk_num != 0 and req["embd_id"] != kb.embd_id: - return False, f"When chunk_num ({kb.chunk_num}) > 0, embedding_model must remain {kb.embd_id}" - ok, err = verify_embedding_availability(req["embd_id"], tenant_id) - if not ok: - return False, err - - if "pagerank" in req and req["pagerank"] != kb.pagerank: - if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity": - return False, "'pagerank' can only be set when doc_engine is elasticsearch" - - if req["pagerank"] > 0: - from rag.nlp import search - settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, - search.index_name(kb.tenant_id), kb.id) - else: - # Elasticsearch requires PAGERANK_FLD be non-zero! - from rag.nlp import search - settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, - search.index_name(kb.tenant_id), kb.id) - - if not KnowledgebaseService.update_by_id(kb.id, req): - return False, "Update dataset error.(Database error)" - - ok, k = KnowledgebaseService.get_by_id(kb.id) - if not ok: - return False, "Dataset updated failed" - - # Link connectors to the dataset - errors = Connector2KbService.link_connectors(kb.id, [conn for conn in connectors], tenant_id) - if errors: - logging.error("Link KB errors: %s", errors) - - response_data = remap_dictionary_keys(k.to_dict()) - response_data["connectors"] = connectors - return True, response_data - - -def list_datasets(tenant_id: str, args: dict): - """ - List datasets. 
- - :param tenant_id: tenant ID - :param args: query arguments - :return: (success, result) or (success, error_message) - """ - kb_id = args.get("id") - name = args.get("name") - page = int(args.get("page", 1)) - page_size = int(args.get("page_size", 30)) - ext_fields = args.get("ext", {}) - parser_id = ext_fields.get("parser_id") - keywords = ext_fields.get("keywords", "") - orderby = args.get("orderby", "create_time") - desc_arg = args.get("desc", "true") - if isinstance(desc_arg, str): - desc = desc_arg.lower() != "false" - elif isinstance(desc_arg, bool): - desc = desc_arg - else: - # unknown type, default to True - desc = True - - if kb_id: - kbs = KnowledgebaseService.get_kb_by_id(kb_id, tenant_id) - if not kbs: - return False, f"User '{tenant_id}' lacks permission for dataset '{kb_id}'" - if name: - kbs = KnowledgebaseService.get_kb_by_name(name, tenant_id) - if not kbs: - return False, f"User '{tenant_id}' lacks permission for dataset '{name}'" - if ext_fields.get("owner_ids", []): - tenant_ids = ext_fields["owner_ids"] - else: - tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) - tenant_ids = [m["tenant_id"] for m in tenants] - kbs, total = KnowledgebaseService.get_list( - tenant_ids, - tenant_id, - page, - page_size, - orderby, - desc, - kb_id, - name, - keywords, - parser_id - ) - users = UserService.get_by_ids([m["tenant_id"] for m in kbs]) - user_map = {m.id: m.to_dict() for m in users} - response_data_list = [] - for kb in kbs: - user_dict = user_map.get(kb["tenant_id"], {}) - kb.update({ - "nickname": user_dict.get("nickname", ""), - "tenant_avatar": user_dict.get("avatar", "") - }) - response_data_list.append(remap_dictionary_keys(kb)) - return True, {"data": response_data_list, "total": total} - - -async def get_knowledge_graph(dataset_id: str, tenant_id: str): - """ - Get knowledge graph for a dataset. 
- - :param dataset_id: dataset ID - :param tenant_id: tenant ID - :return: (success, result) or (success, error_message) - """ - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return False, "No authorization." - _, kb = KnowledgebaseService.get_by_id(dataset_id) - - req = { - "kb_id": [dataset_id], - "knowledge_graph_kwd": ["graph"] - } - - obj = {"graph": {}, "mind_map": {}} - from rag.nlp import search - if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id): - return True, obj - sres = await settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id]) - if not len(sres.ids): - return True, obj - - for id in sres.ids[:1]: - ty = sres.field[id]["knowledge_graph_kwd"] - try: - content_json = json.loads(sres.field[id]["content_with_weight"]) - except Exception: - continue - - obj[ty] = content_json - - if "nodes" in obj["graph"]: - obj["graph"]["nodes"] = sorted(obj["graph"]["nodes"], key=lambda x: x.get("pagerank", 0), reverse=True)[:256] - if "edges" in obj["graph"]: - node_id_set = {o["id"] for o in obj["graph"]["nodes"]} - filtered_edges = [o for o in obj["graph"]["edges"] if - o["source"] != o["target"] and o["source"] in node_id_set and o["target"] in node_id_set] - obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128] - return True, obj - - -def delete_knowledge_graph(dataset_id: str, tenant_id: str): - """ - Delete knowledge graph for a dataset. - - :param dataset_id: dataset ID - :param tenant_id: tenant ID - :return: (success, result) or (success, error_message) - """ - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return False, "No authorization." 
- _, kb = KnowledgebaseService.get_by_id(dataset_id) - from rag.nlp import search - settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, - search.index_name(kb.tenant_id), dataset_id) - - return True, True - - -def run_graphrag(dataset_id: str, tenant_id: str): - """ - Run GraphRAG for a dataset. - - :param dataset_id: dataset ID - :param tenant_id: tenant ID - :return: (success, result) or (success, error_message) - """ - if not dataset_id: - return False, 'Lack of "Dataset ID"' - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return False, "No authorization." - - ok, kb = KnowledgebaseService.get_by_id(dataset_id) - if not ok: - return False, "Invalid Dataset ID" - - task_id = kb.graphrag_task_id - if task_id: - ok, task = TaskService.get_by_id(task_id) - if not ok: - logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}") - - if task and task.progress not in [-1, 1]: - return False, f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running." - - documents, _ = DocumentService.get_by_kb_id( - kb_id=dataset_id, - page_number=0, - items_per_page=0, - orderby="create_time", - desc=False, - keywords="", - run_status=[], - types=[], - suffix=[], - ) - if not documents: - return False, f"No documents in Dataset {dataset_id}" - - sample_document = documents[0] - document_ids = [document["id"] for document in documents] - - task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) - - if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}): - logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}") - - return True, {"graphrag_task_id": task_id} - - -def trace_graphrag(dataset_id: str, tenant_id: str): - """ - Trace GraphRAG task for a dataset. 
- - :param dataset_id: dataset ID - :param tenant_id: tenant ID - :return: (success, result) or (success, error_message) - """ - if not dataset_id: - return False, 'Lack of "Dataset ID"' - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return False, "No authorization." - - ok, kb = KnowledgebaseService.get_by_id(dataset_id) - if not ok: - return False, "Invalid Dataset ID" - - task_id = kb.graphrag_task_id - if not task_id: - return True, {} - - ok, task = TaskService.get_by_id(task_id) - if not ok: - return True, {} - - return True, task.to_dict() - - -def run_raptor(dataset_id: str, tenant_id: str): - """ - Run RAPTOR for a dataset. - - :param dataset_id: dataset ID - :param tenant_id: tenant ID - :return: (success, result) or (success, error_message) - """ - if not dataset_id: - return False, 'Lack of "Dataset ID"' - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return False, "No authorization." - - ok, kb = KnowledgebaseService.get_by_id(dataset_id) - if not ok: - return False, "Invalid Dataset ID" - - task_id = kb.raptor_task_id - if task_id: - ok, task = TaskService.get_by_id(task_id) - if not ok: - logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}") - - if task and task.progress not in [-1, 1]: - return False, f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running." 
- - documents, _ = DocumentService.get_by_kb_id( - kb_id=dataset_id, - page_number=0, - items_per_page=0, - orderby="create_time", - desc=False, - keywords="", - run_status=[], - types=[], - suffix=[], - ) - if not documents: - return False, f"No documents in Dataset {dataset_id}" - - sample_document = documents[0] - document_ids = [document["id"] for document in documents] - - task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) - - if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}): - logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}") - - return True, {"raptor_task_id": task_id} - - -def trace_raptor(dataset_id: str, tenant_id: str): - """ - Trace RAPTOR task for a dataset. - - :param dataset_id: dataset ID - :param tenant_id: tenant ID - :return: (success, result) or (success, error_message) - """ - if not dataset_id: - return False, 'Lack of "Dataset ID"' - - if not KnowledgebaseService.accessible(dataset_id, tenant_id): - return False, "No authorization." - - ok, kb = KnowledgebaseService.get_by_id(dataset_id) - if not ok: - return False, "Invalid Dataset ID" - - task_id = kb.raptor_task_id - if not task_id: - return True, {} - - ok, task = TaskService.get_by_id(task_id) - if not ok: - return False, "RAPTOR Task Not Found or Error Occurred" - - return True, task.to_dict() - - -def get_auto_metadata(dataset_id: str, tenant_id: str): - """ - Get auto-metadata configuration for a dataset. 
- - :param dataset_id: dataset ID - :param tenant_id: tenant ID - :return: (success, result) or (success, error_message) - """ - kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) - if kb is None: - return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" - - parser_cfg = kb.parser_config or {} - metadata = parser_cfg.get("metadata") or [] - enabled = parser_cfg.get("enable_metadata", bool(metadata)) - # Normalize to AutoMetadataConfig-like JSON - fields = [] - for f in metadata: - if not isinstance(f, dict): - continue - fields.append( - { - "name": f.get("name", ""), - "type": f.get("type", ""), - "description": f.get("description"), - "examples": f.get("examples"), - "restrict_values": f.get("restrict_values", False), - } - ) - return True, {"enabled": enabled, "fields": fields} - - -async def update_auto_metadata(dataset_id: str, tenant_id: str, cfg: dict): - """ - Update auto-metadata configuration for a dataset. - - :param dataset_id: dataset ID - :param tenant_id: tenant ID - :param cfg: auto-metadata configuration - :return: (success, result) or (success, error_message) - """ - kb = KnowledgebaseService.get_or_none(id=dataset_id, tenant_id=tenant_id) - if kb is None: - return False, f"User '{tenant_id}' lacks permission for dataset '{dataset_id}'" - - parser_cfg = kb.parser_config or {} - fields = [] - for f in cfg.get("fields", []): - fields.append( - { - "name": f.get("name", ""), - "type": f.get("type", ""), - "description": f.get("description"), - "examples": f.get("examples"), - "restrict_values": f.get("restrict_values", False), - } - ) - parser_cfg["metadata"] = fields - parser_cfg["enable_metadata"] = cfg.get("enabled", True) - - if not KnowledgebaseService.update_by_id(kb.id, {"parser_config": parser_cfg}): - return False, "Update auto-metadata error.(Database error)" - - return True, {"enabled": parser_cfg["enable_metadata"], "fields": fields} diff --git a/api/apps/services/memory_api_service.py 
b/api/apps/services/memory_api_service.py index 1b640cff6..ca6627d5a 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -1,5 +1,5 @@ # -# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index c66d66a68..dcd403887 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -433,7 +433,7 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def get_list(cls, joined_tenant_ids, user_id, - page_number, items_per_page, orderby, desc, id, name, keywords, parser_id=None): + page_number, items_per_page, orderby, desc, id, name): # Get list of knowledge bases with filtering and pagination # Args: # joined_tenant_ids: List of tenant IDs @@ -444,8 +444,6 @@ class KnowledgebaseService(CommonService): # desc: Boolean indicating descending order # id: Optional ID filter # name: Optional name filter - # keywords: Optional keywords filter - # parser_id: Optional parser ID filter # Returns: # List of knowledge bases # Total count of knowledge bases @@ -454,11 +452,6 @@ class KnowledgebaseService(CommonService): kbs = kbs.where(cls.model.id == id) if name: kbs = kbs.where(cls.model.name == name) - if keywords: - kbs = kbs.where(fn.LOWER(cls.model.name).contains(keywords.lower())) - if parser_id: - kbs = kbs.where(cls.model.parser_id == parser_id) - kbs = kbs.where( ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) | ( diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 9cf5e5a3f..b70ff2f9f 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -28,6 +28,7 @@ from typing import Any 
import requests from quart import ( + Response, jsonify, request, has_app_context, @@ -233,17 +234,6 @@ def active_required(func): return wrapper -def add_tenant_id_to_kwargs(func): - @wraps(func) - async def wrapper(**kwargs): - from api.apps import current_user - kwargs["tenant_id"] = current_user.id - if inspect.iscoroutinefunction(func): - return await func(**kwargs) - return func(**kwargs) - return wrapper - - def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=None): response = {"code": code, "message": message, "data": data} return _safe_jsonify(response) @@ -523,7 +513,7 @@ def check_duplicate_ids(ids, id_type="item"): return list(set(ids)), duplicate_messages -def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, str | None]: +def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]: from api.db.services.llm_service import LLMService from api.db.services.tenant_llm_service import TenantLLMService @@ -569,16 +559,13 @@ def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, s is_builtin_model = llm_factory == "Builtin" if not (is_builtin_model or is_tenant_model or in_llm_service): - return False, f"Unsupported model: <{embd_id}>" + return False, get_error_argument_result(f"Unsupported model: <{embd_id}>") if not (is_builtin_model or is_tenant_model): - return False, f"Unauthorized model: <{embd_id}>" + return False, get_error_argument_result(f"Unauthorized model: <{embd_id}>") except OperationalError as e: logging.exception(e) - return False, "Database operation failed" - except Exception as e: - logging.exception(e) - return False, "Internal server error" + return False, get_error_data_result(message="Database operation failed") return True, None diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index 35e0b91f5..54d5f67dc 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -27,7 +27,6 @@ 
from pydantic import ( ValidationError, field_validator, model_validator, - ValidationInfo ) from pydantic_core import PydanticCustomError from werkzeug.exceptions import BadRequest, UnsupportedMediaType @@ -163,15 +162,6 @@ def validate_and_parse_request_args(request: Request, validator: type[BaseModel] - Preserves type conversion from Pydantic validation """ args = request.args.to_dict(flat=True) - - # Handle ext parameter: parse JSON string to dict if it's a string - if 'ext' in args and isinstance(args['ext'], str): - import json - try: - args['ext'] = json.loads(args['ext']) - except json.JSONDecodeError: - pass # Keep the string and let validation handle the error - try: if extras is not None: args.update(extras) @@ -346,7 +336,6 @@ class RaptorConfig(Base): max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)] random_seed: Annotated[int, Field(default=0, ge=0)] auto_disable_for_structured_data: Annotated[bool, Field(default=True)] - ext: Annotated[dict, Field(default={})] class GraphragConfig(Base): @@ -388,7 +377,6 @@ class ParserConfig(Base): filename_embd_weight: Annotated[float | None, Field(default=0.1, ge=0.0, le=1.0)] task_page_size: Annotated[int | None, Field(default=None, ge=1)] pages: Annotated[list[list[int]] | None, Field(default=None)] - ext: Annotated[dict, Field(default={})] class CreateDatasetReq(Base): @@ -402,25 +390,6 @@ class CreateDatasetReq(Base): pipeline_id: Annotated[str | None, Field(default=None, min_length=32, max_length=32, serialization_alias="pipeline_id")] parser_config: Annotated[ParserConfig | None, Field(default=None)] auto_metadata_config: Annotated[AutoMetadataConfig | None, Field(default=None)] - ext: Annotated[dict, Field(default={})] - - @field_validator("pipeline_id", mode="before") - @classmethod - def handle_pipeline_id(cls, v: str | None, info: ValidationInfo): - if v is None: - return v - if info.data.get("chunk_method") is not None and isinstance(v, str): - v = None - return v - - 
@field_validator("parse_type", mode="before") - @classmethod - def handle_parse_type(cls, v: int | None, info: ValidationInfo): - if v is None: - return v - if info.data.get("chunk_method") is not None and isinstance(v, int): - v = None - return v @field_validator("avatar", mode="after") @classmethod @@ -778,4 +747,3 @@ class BaseListReq(BaseModel): class ListDatasetReq(BaseListReq): include_parsing_status: Annotated[bool, Field(default=False)] - ext: Annotated[dict, Field(default={})] diff --git a/sdk/python/ragflow_sdk/modules/dataset.py b/sdk/python/ragflow_sdk/modules/dataset.py index 158cebfa8..b686dceec 100644 --- a/sdk/python/ragflow_sdk/modules/dataset.py +++ b/sdk/python/ragflow_sdk/modules/dataset.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Any + from .base import Base from .document import Document @@ -151,23 +151,3 @@ class DataSet(Base): res = res.json() if res.get("code") != 0: raise Exception(res.get("message")) - - def get_auto_metadata(self) -> dict[str, Any]: - """ - Retrieve auto-metadata configuration for a dataset via SDK. - """ - res = self.get(f"/datasets/{self.id}/auto_metadata") - res = res.json() - if res.get("code") == 0: - return res["data"] - raise Exception(res["message"]) - - def update_auto_metadata(self, **config: Any) -> dict[str, Any]: - """ - Update auto-metadata configuration for a dataset via SDK. 
- """ - res = self.put(f"/datasets/{self.id}/auto_metadata", config) - res = res.json() - if res.get("code") == 0: - return res["data"] - raise Exception(res["message"]) diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index 15b571872..ff4f423c4 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -111,6 +111,26 @@ class RAGFlow: return result_list raise Exception(res["message"]) + def get_auto_metadata(self, dataset_id: str) -> dict[str, Any]: + """ + Retrieve auto-metadata configuration for a dataset via SDK. + """ + res = self.get(f"/datasets/{dataset_id}/auto_metadata") + res = res.json() + if res.get("code") == 0: + return res["data"] + raise Exception(res["message"]) + + def update_auto_metadata(self, dataset_id: str, **config: Any) -> dict[str, Any]: + """ + Update auto-metadata configuration for a dataset via SDK. + """ + res = self.put(f"/datasets/{dataset_id}/auto_metadata", config) + res = res.json() + if res.get("code") == 0: + return res["data"] + raise Exception(res["message"]) + def create_chat(self, name: str, avatar: str = "", dataset_ids=None, llm: Chat.LLM | None = None, prompt: Chat.Prompt | None = None) -> Chat: if dataset_ids is None: dataset_ids = [] diff --git a/test/testcases/test_http_api/test_dataset_management/test_create_dataset.py b/test/testcases/test_http_api/test_dataset_management/test_create_dataset.py index 18a3506b0..15bd9df1c 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_create_dataset.py +++ b/test/testcases/test_http_api/test_dataset_management/test_create_dataset.py @@ -23,7 +23,7 @@ from utils import encode_avatar from utils.file_utils import create_image_file from utils.hypothesis_utils import valid_names -from test_http_api.common import create_dataset +from common import create_dataset @pytest.mark.usefixtures("clear_datasets") @@ -32,11 +32,11 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, 
expected_message", [ - (None, 401, ""), + (None, 0, "`Authorization` can't be empty"), ( RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 401, - "", + 109, + "Authentication error: API key is invalid!", ), ], ids=["empty_auth", "invalid_api_token"], @@ -250,7 +250,7 @@ class TestDatasetCreate: def test_embedding_model_invalid(self, HttpApiAuth, name, embedding_model): payload = {"name": name, "embedding_model": embedding_model} res = create_dataset(HttpApiAuth, payload) - assert res["code"] == 102, res + assert res["code"] == 101, res if "tenant_no_auth" in name: assert res["message"] == f"Unauthorized model: <{embedding_model}>", res else: diff --git a/test/testcases/test_http_api/test_dataset_management/test_delete_datasets.py b/test/testcases/test_http_api/test_dataset_management/test_delete_datasets.py index 77e9e0f92..024085741 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_delete_datasets.py +++ b/test/testcases/test_http_api/test_dataset_management/test_delete_datasets.py @@ -31,11 +31,11 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 401, ""), + (None, 0, "`Authorization` can't be empty"), ( RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 401, - "", + 109, + "Authentication error: API key is invalid!", ), ], ) @@ -160,7 +160,7 @@ class TestDatasetsDelete: def test_id_wrong_uuid(self, HttpApiAuth): payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]} res = delete_datasets(HttpApiAuth, payload) - assert res["code"] == 102, res + assert res["code"] == 108, res assert "lacks permission for dataset" in res["message"], res res = list_datasets(HttpApiAuth) @@ -180,7 +180,7 @@ class TestDatasetsDelete: if callable(func): payload = func(dataset_ids) res = delete_datasets(HttpApiAuth, payload) - assert res["code"] == 102, res + assert res["code"] == 108, res assert "lacks permission for dataset" in res["message"], res res = list_datasets(HttpApiAuth) @@ -205,7 +205,7 @@ class 
TestDatasetsDelete: assert res["code"] == 0, res res = delete_datasets(HttpApiAuth, payload) - assert res["code"] == 102, res + assert res["code"] == 108, res assert "lacks permission for dataset" in res["message"], res @pytest.mark.p3 diff --git a/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py b/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py index 0398f7723..665635f16 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py +++ b/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py @@ -24,8 +24,8 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 401, ""), - (RAGFlowHttpApiAuth(INVALID_API_TOKEN), 401, ""), + (None, 0, "Authorization"), + (RAGFlowHttpApiAuth(INVALID_API_TOKEN), 109, "API key is invalid"), ], ) def test_invalid_auth(self, invalid_auth, expected_code, expected_message): diff --git a/test/testcases/test_http_api/test_dataset_management/test_list_datasets.py b/test/testcases/test_http_api/test_dataset_management/test_list_datasets.py index 794761ed8..7887ff1fd 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_list_datasets.py +++ b/test/testcases/test_http_api/test_dataset_management/test_list_datasets.py @@ -28,11 +28,11 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 401, ""), + (None, 0, "`Authorization` can't be empty"), ( RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 401, - "", + 109, + "Authentication error: API key is invalid!", ), ], ) @@ -237,7 +237,7 @@ class TestDatasetsList: def test_name_wrong(self, HttpApiAuth): params = {"name": "wrong name"} res = list_datasets(HttpApiAuth, params) - assert res["code"] == 102, res + assert res["code"] == 108, res assert "lacks permission for dataset" in res["message"], res @pytest.mark.p2 @@ -281,7 +281,7 @@ class TestDatasetsList: def 
test_id_wrong_uuid(self, HttpApiAuth): params = {"id": "d94a8dc02c9711f0930f7fbc369eab6d"} res = list_datasets(HttpApiAuth, params) - assert res["code"] == 102, res + assert res["code"] == 108, res assert "lacks permission for dataset" in res["message"], res @pytest.mark.p2 diff --git a/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py b/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py index a1fcefb66..0cafe1f67 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py +++ b/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py @@ -33,11 +33,11 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_code, expected_message", [ - (None, 401, ""), + (None, 0, "`Authorization` can't be empty"), ( RAGFlowHttpApiAuth(INVALID_API_TOKEN), - 401, - "", + 109, + "Authentication error: API key is invalid!", ), ], ids=["empty_auth", "invalid_api_token"], @@ -76,7 +76,7 @@ class TestRquest: def test_payload_empty(self, HttpApiAuth, add_dataset_func): dataset_id = add_dataset_func res = update_dataset(HttpApiAuth, dataset_id, {}) - assert res["code"] == 102, res + assert res["code"] == 101, res assert res["message"] == "No properties were modified", res @pytest.mark.p3 @@ -313,7 +313,7 @@ class TestDatasetUpdate: dataset_id = add_dataset_func payload = {"name": name, "embedding_model": embedding_model} res = update_dataset(HttpApiAuth, dataset_id, payload) - assert res["code"] == 102, res + assert res["code"] == 101, res if "tenant_no_auth" in name: assert res["message"] == f"Unauthorized model: <{embedding_model}>", res else: @@ -494,7 +494,7 @@ class TestDatasetUpdate: dataset_id = add_dataset_func payload = {"pagerank": 50} res = update_dataset(HttpApiAuth, dataset_id, payload) - assert res["code"] == 102, res + assert res["code"] == 101, res assert res["message"] == "'pagerank' can only be set when doc_engine is elasticsearch", res @pytest.mark.p2 
diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_auto_metadata.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_auto_metadata.py index 908d95dae..2d2dd9246 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_auto_metadata.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_auto_metadata.py @@ -44,7 +44,7 @@ class TestAutoMetadataOnCreate: dataset = client.create_dataset(**payload) # The SDK should expose parser_config via internal properties or metadata; # we rely on the HTTP API for verification via get_auto_metadata. - cfg = dataset.get_auto_metadata() + cfg = client.get_auto_metadata(dataset_id=dataset.id) assert cfg["enabled"] is True assert len(cfg["fields"]) == 2 names = {f["name"] for f in cfg["fields"]} @@ -74,7 +74,7 @@ class TestAutoMetadataOnUpdate: } dataset.update(payload) - cfg = dataset.get_auto_metadata() + cfg = client.get_auto_metadata(dataset_id=dataset.id) assert cfg["enabled"] is True assert len(cfg["fields"]) == 1 assert cfg["fields"][0]["name"] == "tags" @@ -93,9 +93,9 @@ class TestAutoMetadataOnUpdate: } ], } - dataset.update_auto_metadata(**update_cfg) + client.update_auto_metadata(dataset_id=dataset.id, **update_cfg) - cfg2 = dataset.get_auto_metadata() + cfg2 = client.get_auto_metadata(dataset_id=dataset.id) assert cfg2["enabled"] is False assert len(cfg2["fields"]) == 1 assert cfg2["fields"][0]["name"] == "year" diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py index 6b7679544..444b05d14 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py @@ -31,8 +31,8 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_message", [ - (None, ""), - (INVALID_API_TOKEN, ""), + (None, "Authentication error: API key is invalid!"), + 
(INVALID_API_TOKEN, "Authentication error: API key is invalid!"), ], ids=["empty_auth", "invalid_api_token"], ) diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py index 88e95742d..dbf0e588e 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_delete_datasets.py @@ -27,8 +27,8 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_message", [ - (None, ""), - (INVALID_API_TOKEN, ""), + (None, "Authentication error: API key is invalid!"), + (INVALID_API_TOKEN, "Authentication error: API key is invalid!"), ], ) def test_auth_invalid(self, invalid_auth, expected_message): diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_list_datasets.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_list_datasets.py index b2648d8fd..aa7e1b163 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_list_datasets.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_list_datasets.py @@ -26,8 +26,8 @@ class TestAuthorization: @pytest.mark.parametrize( "invalid_auth, expected_message", [ - (None, ""), - (INVALID_API_TOKEN, ""), + (None, "Authentication error: API key is invalid!"), + (INVALID_API_TOKEN, "Authentication error: API key is invalid!"), ], ) def test_auth_invalid(self, invalid_auth, expected_message): diff --git a/test/testcases/test_web_api/common.py b/test/testcases/test_web_api/common.py index 9b091539f..05e47eb18 100644 --- a/test/testcases/test_web_api/common.py +++ b/test/testcases/test_web_api/common.py @@ -27,7 +27,6 @@ from utils.file_utils import create_txt_file HEADERS = {"Content-Type": "application/json"} KB_APP_URL = f"/{VERSION}/kb" -DATASETS_URL = f"/api/{VERSION}/datasets" DOCUMENT_APP_URL = f"/{VERSION}/document" CHUNK_API_URL = f"/{VERSION}/chunk" DIALOG_APP_URL = 
f"/{VERSION}/dialog" @@ -169,28 +168,25 @@ def search_rm(auth, payload=None, *, headers=HEADERS, data=None): # KB APP -def create_dataset(auth, payload=None, *, headers=HEADERS, data=None): - res = requests.post(url=f"{HOST_ADDRESS}{DATASETS_URL}", headers=headers, auth=auth, json=payload, data=data) +def create_kb(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/create", headers=headers, auth=auth, json=payload, data=data) return res.json() -def list_datasets(auth, params=None, *, headers=HEADERS): - res = requests.get(url=f"{HOST_ADDRESS}{DATASETS_URL}", headers=headers, auth=auth, params=params) +def list_kbs(auth, params=None, payload=None, *, headers=HEADERS, data=None): + if payload is None: + payload = {} + res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/list", headers=headers, auth=auth, params=params, json=payload, data=data) return res.json() -def update_dataset(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): - res = requests.put(url=f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}", headers=headers, auth=auth, json=payload, data=data) +def update_kb(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/update", headers=headers, auth=auth, json=payload, data=data) return res.json() -def delete_datasets(auth, payload=None, *, headers=HEADERS, data=None): - """ - Delete datasets. - The endpoint is DELETE /api/{VERSION}/datasets with payload {"ids": [...]} - This is the standard SDK REST API endpoint for dataset deletion. 
- """ - res = requests.delete(url=f"{HOST_ADDRESS}{DATASETS_URL}", headers=headers, auth=auth, json=payload, data=data) +def rm_kb(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/rm", headers=headers, auth=auth, json=payload, data=data) return res.json() @@ -240,43 +236,23 @@ def kb_pipeline_log_detail(auth, params=None, *, headers=HEADERS): return res.json() -# DATASET GRAPH AND TASKS -def knowledge_graph(auth, dataset_id, params=None): - url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/knowledge_graph" - res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) +def kb_run_graphrag(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/run_graphrag", headers=headers, auth=auth, json=payload, data=data) return res.json() -def delete_knowledge_graph(auth, dataset_id, payload=None): - url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/knowledge_graph" - if payload is None: - res = requests.delete(url=url, headers=HEADERS, auth=auth) - else: - res = requests.delete(url=url, headers=HEADERS, auth=auth, json=payload) +def kb_trace_graphrag(auth, params=None, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/trace_graphrag", headers=headers, auth=auth, params=params) return res.json() -def run_graphrag(auth, dataset_id, payload=None): - url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/run_graphrag" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) +def kb_run_raptor(auth, payload=None, *, headers=HEADERS, data=None): + res = requests.post(url=f"{HOST_ADDRESS}{KB_APP_URL}/run_raptor", headers=headers, auth=auth, json=payload, data=data) return res.json() -def trace_graphrag(auth, dataset_id, params=None): - url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/trace_graphrag" - res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) - return res.json() - - -def run_raptor(auth, dataset_id, 
payload=None): - url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/run_raptor" - res = requests.post(url=url, headers=HEADERS, auth=auth, json=payload) - return res.json() - - -def trace_raptor(auth, dataset_id, params=None): - url = f"{HOST_ADDRESS}{DATASETS_URL}/{dataset_id}/trace_raptor" - res = requests.get(url=url, headers=HEADERS, auth=auth, params=params) +def kb_trace_raptor(auth, params=None, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/trace_raptor", headers=headers, auth=auth, params=params) return res.json() @@ -310,11 +286,21 @@ def rename_tags(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): return res.json() +def knowledge_graph(auth, dataset_id, params=None, *, headers=HEADERS): + res = requests.get(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/knowledge_graph", headers=headers, auth=auth, params=params) + return res.json() + + +def delete_knowledge_graph(auth, dataset_id, payload=None, *, headers=HEADERS, data=None): + res = requests.delete(url=f"{HOST_ADDRESS}{KB_APP_URL}/{dataset_id}/knowledge_graph", headers=headers, auth=auth, json=payload, data=data) + return res.json() + + def batch_create_datasets(auth, num): ids = [] for i in range(num): - res = create_dataset(auth, {"name": f"kb_{i}"}) - ids.append(res["data"]["id"]) + res = create_kb(auth, {"name": f"kb_{i}"}) + ids.append(res["data"]["kb_id"]) return ids diff --git a/test/testcases/test_web_api/conftest.py b/test/testcases/test_web_api/conftest.py index 2f3a2c628..51db85b3d 100644 --- a/test/testcases/test_web_api/conftest.py +++ b/test/testcases/test_web_api/conftest.py @@ -26,9 +26,9 @@ from common import ( delete_dialogs, list_chunks, list_documents, - list_datasets, + list_kbs, parse_documents, - delete_datasets, + rm_kb, ) from libs.auth import RAGFlowWebApiAuth from pytest import FixtureRequest @@ -104,9 +104,9 @@ def require_env_flag(): @pytest.fixture(scope="function") def clear_datasets(request: FixtureRequest, WebApiAuth: 
RAGFlowWebApiAuth): def cleanup(): - res = list_datasets(WebApiAuth, params={"page_size": 1000}) - kb_ids = [kb["id"] for kb in res["data"]] - delete_datasets(WebApiAuth, {"ids": kb_ids}) + res = list_kbs(WebApiAuth, params={"page_size": 1000}) + for kb in res["data"]["kbs"]: + rm_kb(WebApiAuth, {"kb_id": kb["id"]}) request.addfinalizer(cleanup) @@ -122,9 +122,9 @@ def clear_dialogs(request, WebApiAuth): @pytest.fixture(scope="class") def add_dataset(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth) -> str: def cleanup(): - res = list_datasets(WebApiAuth, params={"page_size": 1000}) - kb_ids = [kb["id"] for kb in res["data"]] - delete_datasets(WebApiAuth, {"ids": kb_ids}) + res = list_kbs(WebApiAuth, params={"page_size": 1000}) + for kb in res["data"]["kbs"]: + rm_kb(WebApiAuth, {"kb_id": kb["id"]}) request.addfinalizer(cleanup) return batch_create_datasets(WebApiAuth, 1)[0] @@ -133,9 +133,9 @@ def add_dataset(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth) -> str: @pytest.fixture(scope="function") def add_dataset_func(request: FixtureRequest, WebApiAuth: RAGFlowWebApiAuth) -> str: def cleanup(): - res = list_datasets(WebApiAuth, params={"page_size": 1000}) - kb_ids = [kb["id"] for kb in res["data"]] - delete_datasets(WebApiAuth, {"ids": kb_ids}) + res = list_kbs(WebApiAuth, params={"page_size": 1000}) + for kb in res["data"]["kbs"]: + rm_kb(WebApiAuth, {"kb_id": kb["id"]}) request.addfinalizer(cleanup) return batch_create_datasets(WebApiAuth, 1)[0] diff --git a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py index 445f74c6a..967c95ef7 100644 --- a/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py +++ b/test/testcases/test_web_api/test_dataset_management/test_dataset_sdk_routes_unit.py @@ -409,7 +409,7 @@ def _load_dataset_module(monkeypatch): rag_nlp_pkg.search = search_mod module_name = 
"test_dataset_sdk_routes_unit_module" - module_path = repo_root / "api" / "apps" / "restful_apis" / "dataset_api.py" + module_path = repo_root / "api" / "apps" / "sdk" / "dataset.py" spec = importlib.util.spec_from_file_location(module_name, module_path) module = importlib.util.module_from_spec(spec) module.manager = _DummyManager() @@ -418,7 +418,7 @@ def _load_dataset_module(monkeypatch): return module -@pytest.mark.p3 +@pytest.mark.p2 def test_create_route_error_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) req_state = {"name": "kb"} @@ -448,7 +448,7 @@ def test_create_route_error_matrix_unit(monkeypatch): assert res["message"] == "Database operation failed", res -@pytest.mark.p3 +@pytest.mark.p2 def test_delete_route_error_summary_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) req_state = {"ids": ["kb-1"]} @@ -476,7 +476,7 @@ def test_delete_route_error_summary_matrix_unit(monkeypatch): assert res["code"] == module.RetCode.SUCCESS, res -@pytest.mark.p3 +@pytest.mark.p2 def test_update_route_branch_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) req_state = {"name": "new"} @@ -556,7 +556,7 @@ def test_update_route_branch_matrix_unit(monkeypatch): assert res["message"] == "Database operation failed", res -@pytest.mark.p3 +@pytest.mark.p2 def test_list_knowledge_graph_delete_kg_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) @@ -629,7 +629,7 @@ def test_list_knowledge_graph_delete_kg_matrix_unit(monkeypatch): assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res -@pytest.mark.p3 +@pytest.mark.p2 def test_run_trace_graphrag_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) @@ -705,7 +705,7 @@ def test_run_trace_graphrag_matrix_unit(monkeypatch): assert res["data"]["id"] == "task-1", res -@pytest.mark.p3 +@pytest.mark.p2 def test_run_trace_raptor_matrix_unit(monkeypatch): module = _load_dataset_module(monkeypatch) diff --git 
a/test/testcases/test_web_api/test_document_app/test_create_document.py b/test/testcases/test_web_api/test_document_app/test_create_document.py index 4e590ba3d..8c39bdf4e 100644 --- a/test/testcases/test_web_api/test_document_app/test_create_document.py +++ b/test/testcases/test_web_api/test_document_app/test_create_document.py @@ -19,7 +19,7 @@ from types import SimpleNamespace from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from test_web_api.common import create_document, list_datasets +from common import create_document, list_kbs from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth from utils.file_utils import create_txt_file @@ -91,8 +91,8 @@ class TestDocumentCreate: assert len(responses) == count, responses assert all(future.result()["code"] == 0 for future in futures), responses - res = list_datasets(WebApiAuth, {"id": kb_id}) - assert res["data"][0]["document_count"] == count, res + res = list_kbs(WebApiAuth, {"id": kb_id}) + assert res["data"]["kbs"][0]["doc_num"] == count, res def _run(coro): diff --git a/test/testcases/test_web_api/test_document_app/test_upload_documents.py b/test/testcases/test_web_api/test_document_app/test_upload_documents.py index c8b82774e..b4a151551 100644 --- a/test/testcases/test_web_api/test_document_app/test_upload_documents.py +++ b/test/testcases/test_web_api/test_document_app/test_upload_documents.py @@ -20,7 +20,7 @@ from types import ModuleType, SimpleNamespace from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from common import list_datasets, upload_documents +from common import list_kbs, upload_documents from configs import DOCUMENT_NAME_LIMIT, INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth from utils.file_utils import create_txt_file @@ -172,8 +172,8 @@ class TestDocumentsUpload: res = upload_documents(WebApiAuth, {"kb_id": kb_id}, fps) assert res["code"] == 0, res - res = list_datasets(WebApiAuth) - 
assert res["data"][0]["document_count"] == expected_document_count, res + res = list_kbs(WebApiAuth) + assert res["data"]["kbs"][0]["doc_num"] == expected_document_count, res @pytest.mark.p3 def test_concurrent_upload(self, WebApiAuth, add_dataset_func, tmp_path): @@ -191,8 +191,8 @@ class TestDocumentsUpload: assert len(responses) == count, responses assert all(future.result()["code"] == 0 for future in futures), responses - res = list_datasets(WebApiAuth) - assert res["data"][0]["document_count"] == count, res + res = list_kbs(WebApiAuth) + assert res["data"]["kbs"][0]["doc_num"] == count, res class _AwaitableValue: diff --git a/test/testcases/test_web_api/test_kb_app/conftest.py b/test/testcases/test_web_api/test_kb_app/conftest.py index d51df5c21..8a2387391 100644 --- a/test/testcases/test_web_api/test_kb_app/conftest.py +++ b/test/testcases/test_web_api/test_kb_app/conftest.py @@ -14,7 +14,7 @@ # limitations under the License. # import pytest -from common import batch_create_datasets, list_datasets, delete_datasets +from common import batch_create_datasets, list_kbs, rm_kb from libs.auth import RAGFlowWebApiAuth from pytest import FixtureRequest from ragflow_sdk import RAGFlow @@ -26,10 +26,11 @@ def add_datasets(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGFlowWe def cleanup(): # Web KB cleanup cannot call SDK dataset bulk delete with empty ids; deletion must stay explicit. 
- res = list_datasets(WebApiAuth, params={"page_size": 1000}) - existing_ids = {kb["id"] for kb in res["data"]} - ids_to_delete = list({dataset_id for dataset_id in dataset_ids if dataset_id in existing_ids}) - delete_datasets(WebApiAuth, {"ids": ids_to_delete}) + res = list_kbs(WebApiAuth, params={"page_size": 1000}) + existing_ids = {kb["id"] for kb in res["data"]["kbs"]} + for dataset_id in dataset_ids: + if dataset_id in existing_ids: + rm_kb(WebApiAuth, {"kb_id": dataset_id}) request.addfinalizer(cleanup) return dataset_ids @@ -41,10 +42,11 @@ def add_datasets_func(request: FixtureRequest, client: RAGFlow, WebApiAuth: RAGF def cleanup(): # Web KB cleanup cannot call SDK dataset bulk delete with empty ids; deletion must stay explicit. - res = list_datasets(WebApiAuth, params={"page_size": 1000}) - existing_ids = {kb["id"] for kb in res["data"]} - ids_to_delete = list({dataset_id for dataset_id in dataset_ids if dataset_id in existing_ids}) - delete_datasets(WebApiAuth, {"ids": ids_to_delete}) + res = list_kbs(WebApiAuth, params={"page_size": 1000}) + existing_ids = {kb["id"] for kb in res["data"]["kbs"]} + for dataset_id in dataset_ids: + if dataset_id in existing_ids: + rm_kb(WebApiAuth, {"kb_id": dataset_id}) request.addfinalizer(cleanup) return dataset_ids diff --git a/test/testcases/test_web_api/test_kb_app/test_create_kb.py b/test/testcases/test_web_api/test_kb_app/test_create_kb.py index b0942cd6e..0e7fe0c55 100644 --- a/test/testcases/test_web_api/test_kb_app/test_create_kb.py +++ b/test/testcases/test_web_api/test_kb_app/test_create_kb.py @@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from common import create_dataset +from common import create_kb from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN from hypothesis import example, given, settings from libs.auth import RAGFlowWebApiAuth @@ -35,7 +35,7 @@ class TestAuthorization: ids=["empty_auth", "invalid_api_token"], ) def test_auth_invalid(self, 
invalid_auth, expected_code, expected_message): - res = create_dataset(invalid_auth, {"name": "auth_test"}) + res = create_kb(invalid_auth, {"name": "auth_test"}) assert res["code"] == expected_code, res assert res["message"] == expected_message, res @@ -46,14 +46,14 @@ class TestCapability: def test_create_kb_1k(self, WebApiAuth): for i in range(1_000): payload = {"name": f"dataset_{i}"} - res = create_dataset(WebApiAuth, payload) + res = create_kb(WebApiAuth, payload) assert res["code"] == 0, f"Failed to create dataset {i}" @pytest.mark.p3 def test_create_kb_concurrent(self, WebApiAuth): count = 100 with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(create_dataset, WebApiAuth, {"name": f"dataset_{i}"}) for i in range(count)] + futures = [executor.submit(create_kb, WebApiAuth, {"name": f"dataset_{i}"}) for i in range(count)] responses = list(as_completed(futures)) assert len(responses) == count, responses assert all(future.result()["code"] == 0 for future in futures) @@ -66,44 +66,44 @@ class TestDatasetCreate: @example("a" * 128) @settings(max_examples=20) def test_name(self, WebApiAuth, name): - res = create_dataset(WebApiAuth, {"name": name}) + res = create_kb(WebApiAuth, {"name": name}) assert res["code"] == 0, res @pytest.mark.p2 @pytest.mark.parametrize( "name, expected_message", [ - ("", "Field: - Message: "), - (" ", "Field: - Message: "), - ("a" * (DATASET_NAME_LIMIT + 1), "Field: - Message: "), - (0, "Field: - Message: "), - (None, "Field: - Message: "), + ("", "Dataset name can't be empty."), + (" ", "Dataset name can't be empty."), + ("a" * (DATASET_NAME_LIMIT + 1), "Dataset name length is 129 which is large than 128"), + (0, "Dataset name must be string."), + (None, "Dataset name must be string."), ], ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], ) def test_name_invalid(self, WebApiAuth, name, expected_message): payload = {"name": name} - res = create_dataset(WebApiAuth, payload) - assert 
res["code"] == 101, res + res = create_kb(WebApiAuth, payload) + assert res["code"] == 102, res assert expected_message in res["message"], res @pytest.mark.p3 def test_name_duplicated(self, WebApiAuth): name = "duplicated_name" payload = {"name": name} - res = create_dataset(WebApiAuth, payload) + res = create_kb(WebApiAuth, payload) assert res["code"] == 0, res - res = create_dataset(WebApiAuth, payload) + res = create_kb(WebApiAuth, payload) assert res["code"] == 0, res @pytest.mark.p3 def test_name_case_insensitive(self, WebApiAuth): name = "CaseInsensitive" payload = {"name": name.upper()} - res = create_dataset(WebApiAuth, payload) + res = create_kb(WebApiAuth, payload) assert res["code"] == 0, res payload = {"name": name.lower()} - res = create_dataset(WebApiAuth, payload) + res = create_kb(WebApiAuth, payload) assert res["code"] == 0, res diff --git a/test/testcases/test_web_api/test_kb_app/test_kb_pipeline_tasks.py b/test/testcases/test_web_api/test_kb_app/test_kb_pipeline_tasks.py index 6bf2da491..21fb0ec50 100644 --- a/test/testcases/test_web_api/test_kb_app/test_kb_pipeline_tasks.py +++ b/test/testcases/test_web_api/test_kb_app/test_kb_pipeline_tasks.py @@ -14,17 +14,17 @@ # limitations under the License. 
# import pytest -from test_web_api.common import ( +from common import ( kb_delete_pipeline_logs, kb_list_pipeline_dataset_logs, kb_list_pipeline_logs, kb_pipeline_log_detail, - run_graphrag, - trace_graphrag, - run_raptor, - trace_raptor, + kb_run_graphrag, kb_run_mindmap, + kb_run_raptor, + kb_trace_graphrag, kb_trace_mindmap, + kb_trace_raptor, list_documents, parse_documents, ) @@ -101,13 +101,13 @@ class TestKbPipelineTasks: @pytest.mark.p3 def test_graphrag_run_and_trace(self, WebApiAuth, add_chunks): kb_id, _, _ = add_chunks - run_res = run_graphrag(WebApiAuth, kb_id) + run_res = kb_run_graphrag(WebApiAuth, {"kb_id": kb_id}) assert run_res["code"] == 0, run_res task_id = run_res["data"]["graphrag_task_id"] assert task_id, run_res - _wait_for_task(trace_graphrag, WebApiAuth, kb_id, task_id) - trace_res = trace_graphrag(WebApiAuth, kb_id) + _wait_for_task(kb_trace_graphrag, WebApiAuth, kb_id, task_id) + trace_res = kb_trace_graphrag(WebApiAuth, {"kb_id": kb_id}) assert trace_res["code"] == 0, trace_res task = _find_task(trace_res["data"], task_id) assert task, trace_res @@ -118,13 +118,13 @@ class TestKbPipelineTasks: @pytest.mark.p3 def test_raptor_run_and_trace(self, WebApiAuth, add_chunks): kb_id, _, _ = add_chunks - run_res = run_raptor(WebApiAuth, kb_id) + run_res = kb_run_raptor(WebApiAuth, {"kb_id": kb_id}) assert run_res["code"] == 0, run_res task_id = run_res["data"]["raptor_task_id"] assert task_id, run_res - _wait_for_task(trace_raptor, WebApiAuth, kb_id, task_id) - trace_res = trace_raptor(WebApiAuth, kb_id) + _wait_for_task(kb_trace_raptor, WebApiAuth, kb_id, task_id) + trace_res = kb_trace_raptor(WebApiAuth, {"kb_id": kb_id}) assert trace_res["code"] == 0, trace_res task = _find_task(trace_res["data"], task_id) assert task, trace_res diff --git a/test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py b/test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py index 40a74ae12..d3a3cde43 100644 --- 
a/test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py +++ b/test/testcases/test_web_api/test_kb_app/test_kb_routes_unit.py @@ -181,7 +181,7 @@ def set_tenant_info(): return None -@pytest.mark.p3 +@pytest.mark.p2 def test_create_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -211,7 +211,7 @@ def test_create_branches(monkeypatch): assert "save boom" in res["message"], res -@pytest.mark.p3 +@pytest.mark.p2 def test_update_branches(monkeypatch): module = _load_kb_module(monkeypatch) update_route = _unwrap_route(module.update) @@ -326,7 +326,7 @@ def test_update_branches(monkeypatch): assert "update boom" in res["message"], res -@pytest.mark.p3 +@pytest.mark.p2 def test_update_metadata_setting_not_found(monkeypatch): module = _load_kb_module(monkeypatch) _set_request_json(monkeypatch, module, {"kb_id": "missing-kb", "metadata": {}}) @@ -336,7 +336,7 @@ def test_update_metadata_setting_not_found(monkeypatch): assert "Database error" in res["message"], res -@pytest.mark.p3 +@pytest.mark.p2 def test_detail_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -380,7 +380,7 @@ def test_detail_branches(monkeypatch): assert "detail boom" in res["message"], res -@pytest.mark.p3 +@pytest.mark.p2 def test_list_kbs_owner_ids_and_desc(monkeypatch): module = _load_kb_module(monkeypatch) @@ -414,7 +414,7 @@ def test_list_kbs_owner_ids_and_desc(monkeypatch): assert "list boom" in res["message"], res -@pytest.mark.p3 +@pytest.mark.p2 def test_rm_and_rm_sync_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -491,7 +491,7 @@ def test_rm_and_rm_sync_branches(monkeypatch): assert "rm boom" in res["message"], res -@pytest.mark.p3 +@pytest.mark.p2 def test_tags_and_meta_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -560,7 +560,7 @@ def test_tags_and_meta_branches(monkeypatch): assert res["data"]["finished"] == 1, res -@pytest.mark.p3 +@pytest.mark.p2 def test_knowledge_graph_branches(monkeypatch): module = 
_load_kb_module(monkeypatch) @@ -636,7 +636,7 @@ def test_knowledge_graph_branches(monkeypatch): assert res["data"] is True, res -@pytest.mark.p3 +@pytest.mark.p2 def test_list_pipeline_logs_validation_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -681,7 +681,7 @@ def test_list_pipeline_logs_validation_branches(monkeypatch): assert "Create data filter is abnormal." in res["message"], res -@pytest.mark.p3 +@pytest.mark.p2 def test_list_pipeline_logs_filter_and_exception_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -718,7 +718,7 @@ def test_list_pipeline_logs_filter_and_exception_branches(monkeypatch): assert "logs boom" in res["message"], res -@pytest.mark.p3 +@pytest.mark.p2 def test_list_pipeline_dataset_logs_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -792,7 +792,7 @@ def test_list_pipeline_dataset_logs_branches(monkeypatch): assert "dataset logs boom" in res["message"], res -@pytest.mark.p3 +@pytest.mark.p2 def test_pipeline_log_detail_and_delete_routes_branches(monkeypatch): module = _load_kb_module(monkeypatch) @@ -841,7 +841,7 @@ def test_pipeline_log_detail_and_delete_routes_branches(monkeypatch): assert res["data"]["id"] == "log-1", res -@pytest.mark.p3 +@pytest.mark.p2 @pytest.mark.parametrize( "route_name,task_attr,response_key,task_type", [ @@ -914,7 +914,7 @@ def test_run_pipeline_task_routes_branch_matrix(monkeypatch, route_name, task_at assert queue_calls["doc_ids"] == ["doc-1", "doc-2"], queue_calls -@pytest.mark.p3 +@pytest.mark.p2 @pytest.mark.parametrize( "route_name,task_attr,empty_on_missing_task,error_text", [ @@ -970,7 +970,7 @@ def test_trace_pipeline_task_routes_branch_matrix(monkeypatch, route_name, task_ assert res["data"]["id"] == "task-1", res -@pytest.mark.p3 +@pytest.mark.p2 def test_unbind_task_branch_matrix(monkeypatch): module = _load_kb_module(monkeypatch) route = inspect.unwrap(module.delete_kb_task) @@ -1060,7 +1060,7 @@ def test_unbind_task_branch_matrix(monkeypatch): 
assert "cannot delete task" in res["message"], res -@pytest.mark.p3 +@pytest.mark.p2 def test_check_embedding_similarity_threshold_matrix_unit(monkeypatch): module = _load_kb_module(monkeypatch) route = inspect.unwrap(module.check_embedding) @@ -1229,7 +1229,7 @@ def test_check_embedding_similarity_threshold_matrix_unit(monkeypatch): assert res["data"]["summary"]["avg_cos_sim"] > 0.9, res -@pytest.mark.p3 +@pytest.mark.p2 def test_check_embedding_error_and_empty_sample_paths_unit(monkeypatch): module = _load_kb_module(monkeypatch) route = inspect.unwrap(module.check_embedding) diff --git a/test/testcases/test_web_api/test_kb_app/test_kb_tags_meta.py b/test/testcases/test_web_api/test_kb_app/test_kb_tags_meta.py index 9b5bb5237..810b636de 100644 --- a/test/testcases/test_web_api/test_kb_app/test_kb_tags_meta.py +++ b/test/testcases/test_web_api/test_kb_app/test_kb_tags_meta.py @@ -16,7 +16,7 @@ import uuid import pytest -from test_web_api.common import ( +from common import ( delete_knowledge_graph, kb_basic_info, kb_get_meta, diff --git a/test/testcases/test_web_api/test_kb_app/test_list_kbs.py b/test/testcases/test_web_api/test_kb_app/test_list_kbs.py index b6ed92f76..530686788 100644 --- a/test/testcases/test_web_api/test_kb_app/test_list_kbs.py +++ b/test/testcases/test_web_api/test_kb_app/test_list_kbs.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import json from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from common import list_datasets +from common import list_kbs from configs import INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth from utils import is_sorted @@ -33,7 +32,7 @@ class TestAuthorization: ], ) def test_auth_invalid(self, invalid_auth, expected_code, expected_message): - res = list_datasets(invalid_auth) + res = list_kbs(invalid_auth) assert res["code"] == expected_code, res assert res["message"] == expected_message, res @@ -43,7 +42,7 @@ class TestCapability: def test_concurrent_list(self, WebApiAuth): count = 100 with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(list_datasets, WebApiAuth) for i in range(count)] + futures = [executor.submit(list_kbs, WebApiAuth) for i in range(count)] responses = list(as_completed(futures)) assert len(responses) == count, responses assert all(future.result()["code"] == 0 for future in futures) @@ -53,15 +52,15 @@ class TestCapability: class TestDatasetsList: @pytest.mark.p2 def test_params_unset(self, WebApiAuth): - res = list_datasets(WebApiAuth, None) + res = list_kbs(WebApiAuth, None) assert res["code"] == 0, res - assert len(res["data"]) == 5, res + assert len(res["data"]["kbs"]) == 5, res @pytest.mark.p2 def test_params_empty(self, WebApiAuth): - res = list_datasets(WebApiAuth, {}) + res = list_kbs(WebApiAuth, {}) assert res["code"] == 0, res - assert len(res["data"]) == 5, res + assert len(res["data"]["kbs"]) == 5, res @pytest.mark.p1 @pytest.mark.parametrize( @@ -76,9 +75,9 @@ class TestDatasetsList: ids=["normal_middle_page", "normal_last_partial_page", "beyond_max_page", "string_page_number", "full_data_single_page"], ) def test_page(self, WebApiAuth, params, expected_page_size): - res = list_datasets(WebApiAuth, params) + res = list_kbs(WebApiAuth, params) assert res["code"] == 0, res - assert len(res["data"]) == expected_page_size, res + assert len(res["data"]["kbs"]) == 
expected_page_size, res @pytest.mark.skip @pytest.mark.p2 @@ -91,16 +90,16 @@ class TestDatasetsList: ids=["page_0", "page_a"], ) def test_page_invalid(self, WebApiAuth, params, expected_code, expected_message): - res = list_datasets(WebApiAuth, params=params) + res = list_kbs(WebApiAuth, params=params) assert res["code"] == expected_code, res assert expected_message in res["message"], res @pytest.mark.p2 def test_page_none(self, WebApiAuth): params = {"page": None} - res = list_datasets(WebApiAuth, params) + res = list_kbs(WebApiAuth, params) assert res["code"] == 0, res - assert len(res["data"]) == 5, res + assert len(res["data"]["kbs"]) == 5, res @pytest.mark.p1 @pytest.mark.parametrize( @@ -115,9 +114,9 @@ class TestDatasetsList: ids=["min_valid_page_size", "medium_page_size", "page_size_equals_total", "page_size_exceeds_total", "string_type_page_size"], ) def test_page_size(self, WebApiAuth, params, expected_page_size): - res = list_datasets(WebApiAuth, params) + res = list_kbs(WebApiAuth, params) assert res["code"] == 0, res - assert len(res["data"]) == expected_page_size, res + assert len(res["data"]["kbs"]) == expected_page_size, res @pytest.mark.skip @pytest.mark.p2 @@ -129,27 +128,27 @@ class TestDatasetsList: ], ) def test_page_size_invalid(self, WebApiAuth, params, expected_code, expected_message): - res = list_datasets(WebApiAuth, params) + res = list_kbs(WebApiAuth, params) assert res["code"] == expected_code, res assert expected_message in res["message"], res @pytest.mark.p2 def test_page_size_none(self, WebApiAuth): params = {"page_size": None} - res = list_datasets(WebApiAuth, params) + res = list_kbs(WebApiAuth, params) assert res["code"] == 0, res - assert len(res["data"]) == 5, res + assert len(res["data"]["kbs"]) == 5, res @pytest.mark.p3 @pytest.mark.parametrize( "params, assertions", [ - ({"orderby": "update_time"}, lambda r: (is_sorted(r["data"], "update_time", True))), + ({"orderby": "update_time"}, lambda r: (is_sorted(r["data"]["kbs"], 
"update_time", True))), ], ids=["orderby_update_time"], ) def test_orderby(self, WebApiAuth, params, assertions): - res = list_datasets(WebApiAuth, params) + res = list_kbs(WebApiAuth, params) assert res["code"] == 0, res if callable(assertions): assert assertions(res), res @@ -158,13 +157,13 @@ class TestDatasetsList: @pytest.mark.parametrize( "params, assertions", [ - ({"desc": "True"}, lambda r: (is_sorted(r["data"], "update_time", True))), - ({"desc": "False"}, lambda r: (is_sorted(r["data"], "update_time", False))), + ({"desc": "True"}, lambda r: (is_sorted(r["data"]["kbs"], "update_time", True))), + ({"desc": "False"}, lambda r: (is_sorted(r["data"]["kbs"], "update_time", False))), ], ids=["desc=True", "desc=False"], ) def test_desc(self, WebApiAuth, params, assertions): - res = list_datasets(WebApiAuth, params) + res = list_kbs(WebApiAuth, params) assert res["code"] == 0, res if callable(assertions): @@ -174,28 +173,29 @@ class TestDatasetsList: @pytest.mark.parametrize( "params, expected_page_size", [ - ({"ext": json.dumps({"parser_id": "naive"})}, 5), - ({"ext": json.dumps({"parser_id": "qa"})}, 0), + ({"parser_id": "naive"}, 5), + ({"parser_id": "qa"}, 0), ], ids=["naive", "dqa"], ) def test_parser_id(self, WebApiAuth, params, expected_page_size): - res = list_datasets(WebApiAuth, params) + res = list_kbs(WebApiAuth, params) assert res["code"] == 0, res - assert len(res["data"]) == expected_page_size, res + assert len(res["data"]["kbs"]) == expected_page_size, res @pytest.mark.p2 def test_owner_ids_payload_mode(self, WebApiAuth): - base_res = list_datasets(WebApiAuth, {"page_size": 10}) + base_res = list_kbs(WebApiAuth, {"page_size": 10}) assert base_res["code"] == 0, base_res - assert base_res["data"], base_res - owner_id = base_res["data"][0]["tenant_id"] + assert base_res["data"]["kbs"], base_res + owner_id = base_res["data"]["kbs"][0]["tenant_id"] - res = list_datasets( + res = list_kbs( WebApiAuth, - params={"page": 1, "page_size": 2, "desc": 
"false", "ext": json.dumps({"owner_ids": [owner_id]})}, + params={"page": 1, "page_size": 2, "desc": "false"}, + payload={"owner_ids": [owner_id]}, ) assert res["code"] == 0, res - assert res["total_datasets"] >= len(res["data"]), res - assert len(res["data"]) <= 2, res - assert all(kb["tenant_id"] == owner_id for kb in res["data"]), res + assert res["data"]["total"] >= len(res["data"]["kbs"]), res + assert len(res["data"]["kbs"]) <= 2, res + assert all(kb["tenant_id"] == owner_id for kb in res["data"]["kbs"]), res diff --git a/test/testcases/test_web_api/test_kb_app/test_rm_kb.py b/test/testcases/test_web_api/test_kb_app/test_rm_kb.py index d421bb3aa..21ea624a6 100644 --- a/test/testcases/test_web_api/test_kb_app/test_rm_kb.py +++ b/test/testcases/test_web_api/test_kb_app/test_rm_kb.py @@ -16,8 +16,8 @@ import pytest from common import ( - list_datasets, - delete_datasets, + list_kbs, + rm_kb, ) from configs import INVALID_API_TOKEN from libs.auth import RAGFlowWebApiAuth @@ -33,7 +33,7 @@ class TestAuthorization: ], ) def test_auth_invalid(self, invalid_auth, expected_code, expected_message): - res = delete_datasets(invalid_auth) + res = rm_kb(invalid_auth) assert res["code"] == expected_code, res assert res["message"] == expected_message, res @@ -42,20 +42,20 @@ class TestDatasetsDelete: @pytest.mark.p1 def test_kb_id(self, WebApiAuth, add_datasets_func): kb_ids = add_datasets_func - payload = {"ids": [kb_ids[0]]} - res = delete_datasets(WebApiAuth, payload) + payload = {"kb_id": kb_ids[0]} + res = rm_kb(WebApiAuth, payload) assert res["code"] == 0, res - res = list_datasets(WebApiAuth) - assert len(res["data"]) == 2, res + res = list_kbs(WebApiAuth) + assert len(res["data"]["kbs"]) == 2, res @pytest.mark.p2 @pytest.mark.usefixtures("add_dataset_func") def test_id_wrong_uuid(self, WebApiAuth): - payload = {"ids": ["d94a8dc02c9711f0930f7fbc369eab6d"]} - res = delete_datasets(WebApiAuth, payload) - assert res["code"] == 102, res - assert "lacks permission" in 
res["message"], res + payload = {"kb_id": "d94a8dc02c9711f0930f7fbc369eab6d"} + res = rm_kb(WebApiAuth, payload) + assert res["code"] == 109, res + assert "No authorization." in res["message"], res - res = list_datasets(WebApiAuth) - assert len(res["data"]) == 1, res + res = list_kbs(WebApiAuth) + assert len(res["data"]["kbs"]) == 1, res diff --git a/test/testcases/test_web_api/test_kb_app/test_update_kb.py b/test/testcases/test_web_api/test_kb_app/test_update_kb.py index 7ee1bad5f..641ed3b1f 100644 --- a/test/testcases/test_web_api/test_kb_app/test_update_kb.py +++ b/test/testcases/test_web_api/test_kb_app/test_update_kb.py @@ -17,7 +17,7 @@ import os from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from test_web_api.common import update_dataset +from common import update_kb from configs import DATASET_NAME_LIMIT, INVALID_API_TOKEN from hypothesis import HealthCheck, example, given, settings from libs.auth import RAGFlowWebApiAuth @@ -37,7 +37,7 @@ class TestAuthorization: ids=["empty_auth", "invalid_api_token"], ) def test_auth_invalid(self, invalid_auth, expected_code, expected_message): - res = update_dataset(invalid_auth, "dataset_id") + res = update_kb(invalid_auth, "dataset_id") assert res["code"] == expected_code, res assert res["message"] == expected_message, res @@ -50,13 +50,13 @@ class TestCapability: with ThreadPoolExecutor(max_workers=5) as executor: futures = [ executor.submit( - update_dataset, + update_kb, WebApiAuth, - dataset_id, { + "kb_id": dataset_id, "name": f"dataset_{i}", "description": "", - "chunk_method": "naive", + "parser_id": "naive", }, ) for i in range(count) @@ -69,8 +69,8 @@ class TestCapability: class TestDatasetUpdate: @pytest.mark.p3 def test_dataset_id_not_uuid(self, WebApiAuth): - payload = {"name": "not uuid", "description": "", "chunk_method": "naive"} - res = update_dataset(WebApiAuth, "not_uuid", payload) + payload = {"name": "not uuid", "description": "", "parser_id": "naive", "kb_id": 
"not_uuid"} + res = update_kb(WebApiAuth, payload) assert res["code"] == 109, res assert "No authorization." in res["message"], res @@ -81,8 +81,8 @@ class TestDatasetUpdate: @settings(max_examples=20, suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None) def test_name(self, WebApiAuth, add_dataset_func, name): dataset_id = add_dataset_func - payload = {"name": name, "description": "", "chunk_method": "naive"} - res = update_dataset(WebApiAuth, dataset_id, payload) + payload = {"name": name, "description": "", "parser_id": "naive", "kb_id": dataset_id} + res = update_kb(WebApiAuth, payload) assert res["code"] == 0, res assert res["data"]["name"] == name, res @@ -90,27 +90,27 @@ class TestDatasetUpdate: @pytest.mark.parametrize( "name, expected_message", [ - ("", "Field: - Message: "), - (" ", "Field: - Message: "), - ("a" * (DATASET_NAME_LIMIT + 1), "Field: - Message: "), - (0, "Field: - Message: "), - (None, "Field: - Message: "), + ("", "Dataset name can't be empty."), + (" ", "Dataset name can't be empty."), + ("a" * (DATASET_NAME_LIMIT + 1), "Dataset name length is 129 which is large than 128"), + (0, "Dataset name must be string."), + (None, "Dataset name must be string."), ], ids=["empty_name", "space_name", "too_long_name", "invalid_name", "None_name"], ) def test_name_invalid(self, WebApiAuth, add_dataset_func, name, expected_message): kb_id = add_dataset_func - payload = {"name": name, "description": "", "chunk_method": "naive"} - res = update_dataset(WebApiAuth, kb_id, payload) - assert res["code"] == 101, res + payload = {"name": name, "description": "", "parser_id": "naive", "kb_id": kb_id} + res = update_kb(WebApiAuth, payload) + assert res["code"] == 102, res assert expected_message in res["message"], res @pytest.mark.p3 def test_name_duplicated(self, WebApiAuth, add_datasets_func): kb_id = add_datasets_func[0] name = "kb_1" - payload = {"name": name, "description": "", "chunk_method": "naive"} - res = update_dataset(WebApiAuth, 
kb_id, payload) + payload = {"name": name, "description": "", "parser_id": "naive", "kb_id": kb_id} + res = update_kb(WebApiAuth, payload) assert res["code"] == 102, res assert res["message"] == "Duplicated dataset name.", res @@ -118,8 +118,8 @@ class TestDatasetUpdate: def test_name_case_insensitive(self, WebApiAuth, add_datasets_func): kb_id = add_datasets_func[0] name = "KB_1" - payload = {"name": name, "description": "", "chunk_method": "naive"} - res = update_dataset(WebApiAuth, kb_id, payload) + payload = {"name": name, "description": "", "parser_id": "naive", "kb_id": kb_id} + res = update_kb(WebApiAuth, payload) assert res["code"] == 102, res assert res["message"] == "Duplicated dataset name.", res @@ -130,18 +130,19 @@ class TestDatasetUpdate: payload = { "name": "avatar", "description": "", - "chunk_method": "naive", + "parser_id": "naive", + "kb_id": kb_id, "avatar": f"data:image/png;base64,{encode_avatar(fn)}", } - res = update_dataset(WebApiAuth, kb_id, payload) + res = update_kb(WebApiAuth, payload) assert res["code"] == 0, res assert res["data"]["avatar"] == f"data:image/png;base64,{encode_avatar(fn)}", res @pytest.mark.p2 def test_description(self, WebApiAuth, add_dataset_func): kb_id = add_dataset_func - payload = {"name": "description", "description": "description", "chunk_method": "naive"} - res = update_dataset(WebApiAuth, kb_id, payload) + payload = {"name": "description", "description": "description", "parser_id": "naive", "kb_id": kb_id} + res = update_kb(WebApiAuth, payload) assert res["code"] == 0, res assert res["data"]["description"] == "description", res @@ -156,10 +157,10 @@ class TestDatasetUpdate: ) def test_embedding_model(self, WebApiAuth, add_dataset_func, embedding_model): kb_id = add_dataset_func - payload = {"name": "embedding_model", "description": "", "chunk_method": "naive", "embedding_model": embedding_model} - res = update_dataset(WebApiAuth, kb_id, payload) + payload = {"name": "embedding_model", "description": "", 
"parser_id": "naive", "kb_id": kb_id, "embd_id": embedding_model} + res = update_kb(WebApiAuth, payload) assert res["code"] == 0, res - assert res["data"]["embedding_model"] == embedding_model, res + assert res["data"]["embd_id"] == embedding_model, res @pytest.mark.p2 @pytest.mark.parametrize( @@ -172,8 +173,8 @@ class TestDatasetUpdate: ) def test_permission(self, WebApiAuth, add_dataset_func, permission): kb_id = add_dataset_func - payload = {"name": "permission", "description": "", "chunk_method": "naive", "permission": permission} - res = update_dataset(WebApiAuth, kb_id, payload) + payload = {"name": "permission", "description": "", "parser_id": "naive", "kb_id": kb_id, "permission": permission} + res = update_kb(WebApiAuth, payload) assert res["code"] == 0, res assert res["data"]["permission"] == permission.lower().strip(), res @@ -198,17 +199,17 @@ class TestDatasetUpdate: ) def test_chunk_method(self, WebApiAuth, add_dataset_func, chunk_method): kb_id = add_dataset_func - payload = {"name": "chunk_method", "description": "", "chunk_method": chunk_method} - res = update_dataset(WebApiAuth, kb_id, payload) + payload = {"name": "chunk_method", "description": "", "parser_id": chunk_method, "kb_id": kb_id} + res = update_kb(WebApiAuth, payload) assert res["code"] == 0, res - assert res["data"]["chunk_method"] == chunk_method, res + assert res["data"]["parser_id"] == chunk_method, res @pytest.mark.p1 @pytest.mark.skipif(os.getenv("DOC_ENGINE") != "infinity", reason="Infinity does not support parser_id=tag") def test_chunk_method_tag_with_infinity(self, WebApiAuth, add_dataset_func): kb_id = add_dataset_func - payload = {"name": "chunk_method", "description": "", "chunk_method": "tag"} - res = update_dataset(WebApiAuth, kb_id, payload) + payload = {"name": "chunk_method", "description": "", "parser_id": "tag", "kb_id": kb_id} + res = update_kb(WebApiAuth, payload) assert res["code"] == 103, res assert res["message"] == "The chunking method Tag has not been 
supported by Infinity yet.", res @@ -217,8 +218,8 @@ class TestDatasetUpdate: @pytest.mark.parametrize("pagerank", [0, 50, 100], ids=["min", "mid", "max"]) def test_pagerank(self, WebApiAuth, add_dataset_func, pagerank): kb_id = add_dataset_func - payload = {"name": "pagerank", "description": "", "chunk_method": "naive", "pagerank": pagerank} - res = update_dataset(WebApiAuth, kb_id, payload) + payload = {"name": "pagerank", "description": "", "parser_id": "naive", "kb_id": kb_id, "pagerank": pagerank} + res = update_kb(WebApiAuth, payload) assert res["code"] == 0, res assert res["data"]["pagerank"] == pagerank, res @@ -226,13 +227,13 @@ class TestDatasetUpdate: @pytest.mark.p2 def test_pagerank_set_to_0(self, WebApiAuth, add_dataset_func): kb_id = add_dataset_func - payload = {"name": "pagerank", "description": "", "chunk_method": "naive", "pagerank": 50} - res = update_dataset(WebApiAuth, kb_id, payload) + payload = {"name": "pagerank", "description": "", "parser_id": "naive", "kb_id": kb_id, "pagerank": 50} + res = update_kb(WebApiAuth, payload) assert res["code"] == 0, res assert res["data"]["pagerank"] == 50, res - payload = {"name": "pagerank", "description": "", "chunk_method": "naive", "pagerank": 0} - res = update_dataset(WebApiAuth, kb_id, payload) + payload = {"name": "pagerank", "description": "", "parser_id": "naive", "kb_id": kb_id, "pagerank": 0} + res = update_kb(WebApiAuth, payload) assert res["code"] == 0, res assert res["data"]["pagerank"] == 0, res @@ -240,8 +241,8 @@ class TestDatasetUpdate: @pytest.mark.p2 def test_pagerank_infinity(self, WebApiAuth, add_dataset_func): kb_id = add_dataset_func - payload = {"name": "pagerank", "description": "", "chunk_method": "naive", "pagerank": 50} - res = update_dataset(WebApiAuth, kb_id, payload) + payload = {"name": "pagerank", "description": "", "parser_id": "naive", "kb_id": kb_id, "pagerank": 50} + res = update_kb(WebApiAuth, payload) assert res["code"] == 102, res assert res["message"] == "'pagerank' 
can only be set when doc_engine is elasticsearch", res @@ -351,15 +352,10 @@ class TestDatasetUpdate: ) def test_parser_config(self, WebApiAuth, add_dataset_func, parser_config): kb_id = add_dataset_func - payload = {"name": "parser_config", "description": "", "chunk_method": "naive", "parser_config": parser_config} - res = update_dataset(WebApiAuth, kb_id, payload) + payload = {"name": "parser_config", "description": "", "parser_id": "naive", "kb_id": kb_id, "parser_config": parser_config} + res = update_kb(WebApiAuth, payload) assert res["code"] == 0, res - for key, value in parser_config.items(): - if not isinstance(value, dict): - assert res["data"]["parser_config"].get(key) == value, res - else: - for sub_key, sub_value in value.items(): - assert res["data"]["parser_config"].get(key, {}).get(sub_key) == sub_value, res + assert res["data"]["parser_config"] == parser_config, res @pytest.mark.p2 @pytest.mark.parametrize( @@ -376,7 +372,7 @@ class TestDatasetUpdate: ) def test_field_unsupported(self, WebApiAuth, add_dataset_func, payload): kb_id = add_dataset_func - full_payload = {"name": "field_unsupported", "description": "", "chunk_method": "naive", **payload} - res = update_dataset(WebApiAuth, kb_id, full_payload) + full_payload = {"name": "field_unsupported", "description": "", "parser_id": "naive", "kb_id": kb_id, **payload} + res = update_kb(WebApiAuth, full_payload) assert res["code"] == 101, res - assert "are not permitted" in res["message"], res + assert "isn't allowed" in res["message"], res diff --git a/web/src/components/chunk-method-dialog/index.tsx b/web/src/components/chunk-method-dialog/index.tsx index 34dfb60e2..c845fda35 100644 --- a/web/src/components/chunk-method-dialog/index.tsx +++ b/web/src/components/chunk-method-dialog/index.tsx @@ -181,7 +181,7 @@ export function ChunkMethodDialog({ }); const selectedTag = useWatch({ - name: 'chunk_method', + name: 'parser_id', control: form.control, }); const isMineruSelected = diff --git 
a/web/src/components/knowledge-base-item.tsx b/web/src/components/knowledge-base-item.tsx index 90faf38f3..9ced7bfc2 100644 --- a/web/src/components/knowledge-base-item.tsx +++ b/web/src/components/knowledge-base-item.tsx @@ -23,7 +23,7 @@ export function useDisableDifferenceEmbeddingDataset() { useEffect(() => { const datasetListMap = datasetListOrigin - .filter((x) => x.chunk_method !== DocumentParserType.Tag) + .filter((x) => x.parser_id !== DocumentParserType.Tag) .map((item: IKnowledge) => { return { label: item.name, @@ -36,12 +36,12 @@ export function useDisableDifferenceEmbeddingDataset() { ), suffix: (
- {item.embedding_model} + {item.embd_id}
), value: item.id, disabled: - item.embedding_model !== datasetSelectEmbedId && + item.embd_id !== datasetSelectEmbedId && datasetSelectEmbedId !== '', }; }); @@ -54,7 +54,7 @@ export function useDisableDifferenceEmbeddingDataset() { ) => { if (value.length) { const data = datasetListOrigin?.find((item) => item.id === value[0]); - setDatasetSelectEmbedId(data?.embedding_model ?? ''); + setDatasetSelectEmbedId(data?.embd_id ?? ''); } else { setDatasetSelectEmbedId(''); } diff --git a/web/src/components/ui/multi-select.tsx b/web/src/components/ui/multi-select.tsx index 287ec26e4..b4464aff8 100644 --- a/web/src/components/ui/multi-select.tsx +++ b/web/src/components/ui/multi-select.tsx @@ -242,9 +242,7 @@ export const MultiSelect = React.forwardRef< const disabledValueSet = React.useMemo(() => { return new Set( - flatOptions - .filter((option) => option.disabled) - .map((option) => option.value), + flatOptions.filter((option) => option.disabled).map((option) => option.value), ); }, [flatOptions]); diff --git a/web/src/hooks/use-knowledge-request.ts b/web/src/hooks/use-knowledge-request.ts index 4985a21b6..d7294b3a3 100644 --- a/web/src/hooks/use-knowledge-request.ts +++ b/web/src/hooks/use-knowledge-request.ts @@ -18,7 +18,6 @@ import kbService, { listTag, removeTag, renameTag, - updateKb, } from '@/services/knowledge-service'; import { useIsMutating, @@ -138,20 +137,22 @@ export const useFetchNextKnowledgeListByPage = () => { ], initialData: { kbs: [], - total_datasets: 0, + total: 0, }, gcTime: 0, queryFn: async () => { - const { data } = await listDataset({ - page_size: pagination.pageSize, - page: pagination.current, - ext: { + const { data } = await listDataset( + { keywords: debouncedSearchString, + page_size: pagination.pageSize, + page: pagination.current, + }, + { owner_ids: filterValue.owner, }, - }); + ); - return { kbs: data?.data, total_datasets: data?.total_datasets }; + return data?.data; }, }); @@ -167,7 +168,7 @@ export const 
useFetchNextKnowledgeListByPage = () => { ...data, searchString, handleInputChange: onInputChange, - pagination: { ...pagination, total: data?.total_datasets }, + pagination: { ...pagination, total: data?.total }, setPagination, loading, filterValue, @@ -183,18 +184,7 @@ export const useCreateKnowledge = () => { mutateAsync, } = useMutation({ mutationKey: [KnowledgeApiAction.CreateKnowledge], - mutationFn: async (params: { - id?: string; - name: string; - embedding_model?: string; - chunk_method?: string; - parseType?: number; - pipeline_id?: string | null; - ext?: { - language?: string; - [key: string]: any; - }; - }) => { + mutationFn: async (params: { id?: string; name: string }) => { const { data = {} } = await kbService.createKb(params); if (data.code === 0) { message.success( @@ -218,7 +208,7 @@ export const useDeleteKnowledge = () => { } = useMutation({ mutationKey: [KnowledgeApiAction.DeleteKnowledge], mutationFn: async (id: string) => { - const { data } = await kbService.rmKb({ ids: [id] }); + const { data } = await kbService.rmKb({ kb_id: id }); if (data.code === 0) { message.success(i18n.t(`message.deleted`)); queryClient.invalidateQueries({ @@ -235,119 +225,17 @@ export const useDeleteKnowledge = () => { export const useUpdateKnowledge = (shouldFetchList = false) => { const knowledgeBaseId = useKnowledgeBaseId(); const queryClient = useQueryClient(); - - const extractRaptorConfigExt = ( - raptorConfig: Record | undefined, - ) => { - if (!raptorConfig) return raptorConfig; - const { - use_raptor, - prompt, - max_token, - threshold, - max_cluster, - random_seed, - auto_disable_for_structured_data, - ext, - ...raptorExt - } = raptorConfig; - return { - use_raptor, - prompt, - max_token, - threshold, - max_cluster, - random_seed, - auto_disable_for_structured_data, - ext: { ...ext, ...raptorExt }, - }; - }; - - const extractParserConfigExt = ( - parserConfig: Record | undefined, - ) => { - if (!parserConfig) return parserConfig; - const { - auto_keywords, - 
auto_questions, - chunk_token_num, - delimiter, - graphrag, - html4excel, - layout_recognize, - raptor, - tag_kb_ids, - topn_tags, - filename_embd_weight, - task_page_size, - pages, - ext, - ...parserExt - } = parserConfig; - return { - auto_keywords, - auto_questions, - chunk_token_num, - delimiter, - graphrag, - html4excel, - layout_recognize, - raptor: extractRaptorConfigExt(raptor), - tag_kb_ids, - topn_tags, - filename_embd_weight, - task_page_size, - pages, - ext: { ...ext, ...parserExt }, - }; - }; - const { data, isPending: loading, mutateAsync, } = useMutation({ mutationKey: [KnowledgeApiAction.SaveKnowledge], - mutationFn: async (params: { - kb_id?: string; - name?: string; - embedding_model?: string; - chunk_method?: string; - pipeline_id?: string | null; - avatar?: string | null; - description?: string; - permission?: string; - pagerank?: number; - parser_config?: Record; - [key: string]: any; - }) => { - const kbId = params?.kb_id || knowledgeBaseId; - const { - kb_id, - name, - embedding_model, - chunk_method, - pipeline_id, - avatar, - description, - permission, - pagerank, - parser_config, - ...ext - } = params; - const requestBody: Record = { - name, - embedding_model, - chunk_method, - pipeline_id, - avatar, - description, - permission, - pagerank, - parser_config: extractParserConfigExt(parser_config), - ext, - }; - const { data = {} } = await updateKb(kbId, requestBody); + mutationFn: async (params: Record) => { + const { data = {} } = await kbService.updateKb({ + kb_id: params?.kb_id ? params?.kb_id : knowledgeBaseId, + ...params, + }); if (data.code === 0) { message.success(i18n.t(`message.updated`)); if (shouldFetchList) { @@ -471,9 +359,9 @@ export const useFetchKnowledgeList = ( gcTime: 0, // https://tanstack.com/query/latest/docs/framework/react/guides/caching?from=reactQueryV3 queryFn: async () => { const { data } = await listDataset(); - const list = data?.data ?? []; + const list = data?.data?.kbs ?? 
[]; return shouldFilterListWithoutDocument - ? list.filter((x: IKnowledge) => x.chunk_count > 0) + ? list.filter((x: IKnowledge) => x.chunk_num > 0) : list; }, }); diff --git a/web/src/interfaces/database/knowledge.ts b/web/src/interfaces/database/knowledge.ts index 6ef02986b..f68a6f5b0 100644 --- a/web/src/interfaces/database/knowledge.ts +++ b/web/src/interfaces/database/knowledge.ts @@ -11,16 +11,16 @@ export interface IConnector { // knowledge base export interface IKnowledge { avatar?: any; - chunk_count: number; + chunk_num: number; create_date: string; create_time: number; created_by: string; description: string; - document_count: number; + doc_num: number; id: string; name: string; parser_config: ParserConfig; - chunk_method: string; + parser_id: string; pipeline_id: string; pipeline_name: string; pipeline_avatar: string; @@ -32,7 +32,7 @@ export interface IKnowledge { update_date: string; update_time: number; vector_similarity_weight: number; - embedding_model: string; + embd_id: string; nickname: string; operator_permission: number; size: number; @@ -47,7 +47,7 @@ export interface IKnowledge { export interface IKnowledgeResult { kbs: IKnowledge[]; - total_datasets: number; + total: number; } export interface Raptor { diff --git a/web/src/interfaces/request/knowledge.ts b/web/src/interfaces/request/knowledge.ts index f93c5dd93..8690c1062 100644 --- a/web/src/interfaces/request/knowledge.ts +++ b/web/src/interfaces/request/knowledge.ts @@ -24,14 +24,10 @@ export interface IFetchKnowledgeListRequestBody { } export interface IFetchKnowledgeListRequestParams { - id?: string; + kb_id?: string; + keywords?: string; page?: number; page_size?: number; - ext?: { - keywords?: string; - owner_ids?: string[]; - parser_id?: string; - }; } export interface IFetchDocumentListRequestBody { diff --git a/web/src/pages/dataset/dataset-setting/chunk-method-form.tsx b/web/src/pages/dataset/dataset-setting/chunk-method-form.tsx index 3f48eb39c..8d6debc16 100644 --- 
a/web/src/pages/dataset/dataset-setting/chunk-method-form.tsx +++ b/web/src/pages/dataset/dataset-setting/chunk-method-form.tsx @@ -45,7 +45,7 @@ export function ChunkMethodForm() { const finalParserId: DocumentParserType = useWatch({ control: form.control, - name: 'chunk_method', + name: 'parser_id', }); const ConfigurationComponent = useMemo(() => { diff --git a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx index 70446e5a3..0ce24fc29 100644 --- a/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx +++ b/web/src/pages/dataset/dataset-setting/configuration/common-item.tsx @@ -69,7 +69,7 @@ export function ChunkMethodItem(props: IProps) { return ( (
@@ -121,7 +121,7 @@ export const EmbeddingSelect = ({ const { handleChange } = useHandleKbEmbedding(); const oldValue = useMemo(() => { - const embdStr = form.getValues(name || 'embedding_model'); + const embdStr = form.getValues(name || 'embd_id'); return embdStr || ''; }, [form]); const [loading, setLoading] = useState(false); @@ -165,7 +165,7 @@ export function EmbeddingModelItem({ line = 1, isEdit }: IProps) { <> (
0; + return knowledgeDetails.chunk_num > 0; } export const useFetchKnowledgeConfigurationOnMount = ( @@ -60,14 +60,14 @@ export const useFetchKnowledgeConfigurationOnMount = ( 'description', 'name', 'permission', + 'embd_id', + 'parser_id', 'language', 'parser_config', 'connectors', 'pagerank', 'avatar', ]), - embedding_model: knowledgeDetails.embd_id, - chunk_method: knowledgeDetails.parser_id, } as z.infer; form.reset(formValues); }, [form, knowledgeDetails]); diff --git a/web/src/pages/dataset/dataset-setting/index.tsx b/web/src/pages/dataset/dataset-setting/index.tsx index 53e2048f8..5cce7049a 100644 --- a/web/src/pages/dataset/dataset-setting/index.tsx +++ b/web/src/pages/dataset/dataset-setting/index.tsx @@ -219,7 +219,7 @@ export default function DatasetSettings() { defaultValue: knowledgeDetails.pipeline_id ? 2 : 1, }); const selectedTag = useWatch({ - name: 'chunk_method', + name: 'parser_id', control: form.control, }); useEffect(() => { diff --git a/web/src/pages/dataset/dataset-setting/saving-button.tsx b/web/src/pages/dataset/dataset-setting/saving-button.tsx index 76ceb1eef..8d77ca57a 100644 --- a/web/src/pages/dataset/dataset-setting/saving-button.tsx +++ b/web/src/pages/dataset/dataset-setting/saving-button.tsx @@ -16,7 +16,7 @@ export function GeneralSavingButton() { () => form.formState.defaultValues ?? 
{}, [form.formState.defaultValues], ); - const chunk_method = defaultValues['chunk_method']; + const parser_id = defaultValues['parser_id']; return ( { retryDelay: 1000, enabled: open, queryFn: async () => { - const { data } = await traceGraphRag(id); + const { data } = await kbService.traceGraphRag({ + kb_id: id, + }); return data?.data || {}; }, }); @@ -74,7 +70,9 @@ export const useTraceGenerate = ({ open }: { open: boolean }) => { retryDelay: 1000, enabled: open, queryFn: async () => { - const { data } = await traceRaptor(id); + const { data } = await kbService.traceRaptor({ + kb_id: id, + }); return data?.data || {}; }, }); @@ -135,8 +133,12 @@ export const useDatasetGenerate = () => { mutationKey: [DatasetKey.generate], mutationFn: async ({ type }: { type: GenerateType }) => { const func = - type === GenerateType.KnowledgeGraph ? runGraphRag : runRaptor; - const { data } = await func(id); + type === GenerateType.KnowledgeGraph + ? kbService.runGraphRag + : kbService.runRaptor; + const { data } = await func({ + kb_id: id, + }); if (data.code === 0) { message.success(t('message.operated')); queryClient.invalidateQueries({ diff --git a/web/src/pages/datasets/dataset-card.tsx b/web/src/pages/datasets/dataset-card.tsx index 338951b27..d49fd6308 100644 --- a/web/src/pages/datasets/dataset-card.tsx +++ b/web/src/pages/datasets/dataset-card.tsx @@ -23,7 +23,7 @@ export function DatasetCard({ ) { }) .trim(), parseType: z.number().optional(), - embedding_model: z + embd_id: z .string() .min(1, { message: t('knowledgeConfiguration.embeddingModelPlaceholder'), }) .trim(), - chunk_method: z.string().optional(), + parser_id: z.string().optional(), pipeline_id: z.string().optional(), }) .superRefine((data, ctx) => { - // When parseType === 1, chunk_method is required + // When parseType === 1, parser_id is required if ( data.parseType === 1 && - (!data.chunk_method || data.chunk_method.trim() === '') + (!data.parser_id || data.parser_id.trim() === '') ) { ctx.addIssue({ 
code: z.ZodIssueCode.custom, @@ -82,8 +82,8 @@ export function InputForm({ onOk }: IModalProps) { defaultValues: { name: '', parseType: 1, - chunk_method: '', - embedding_model: tenantInfo?.embd_id, + parser_id: '', + embd_id: tenantInfo?.embd_id, }, }); diff --git a/web/src/pages/datasets/hooks.ts b/web/src/pages/datasets/hooks.ts index 140976ee7..d194af9e1 100644 --- a/web/src/pages/datasets/hooks.ts +++ b/web/src/pages/datasets/hooks.ts @@ -16,14 +16,8 @@ export const useSearchKnowledge = () => { export interface Iknowledge { name: string; - embedding_model?: string; - chunk_method?: string; - parseType?: number; - pipeline_id?: string | null; - ext?: { - language?: string; - [key: string]: any; - }; + embd_id: string; + parser_id: string; } export const useSaveKnowledge = () => { const { visible: visible, hideModal, showModal } = useSetModalState(); @@ -36,7 +30,7 @@ export const useSaveKnowledge = () => { if (ret?.code === 0) { hideModal(); - navigateToDataset(ret.data.id)(); + navigateToDataset(ret.data.kb_id)(); } }, [createKnowledge, hideModal, navigateToDataset], diff --git a/web/src/pages/datasets/index.tsx b/web/src/pages/datasets/index.tsx index 85fa19226..05ff2a78d 100644 --- a/web/src/pages/datasets/index.tsx +++ b/web/src/pages/datasets/index.tsx @@ -30,7 +30,7 @@ export default function Datasets() { const { kbs, - total_datasets, + total, pagination, setPagination, handleInputChange, @@ -107,7 +107,7 @@ export default function Datasets() {
diff --git a/web/src/services/knowledge-service.ts b/web/src/services/knowledge-service.ts index e727fc500..4c1225990 100644 --- a/web/src/services/knowledge-service.ts +++ b/web/src/services/knowledge-service.ts @@ -1,6 +1,7 @@ import { IRenameTag } from '@/interfaces/database/knowledge'; import { IFetchDocumentListRequestBody, + IFetchKnowledgeListRequestBody, IFetchKnowledgeListRequestParams, } from '@/interfaces/request/knowledge'; import { ProcessingType } from '@/pages/dataset/dataset-overview/dataset-common'; @@ -10,6 +11,7 @@ import request, { post } from '@/utils/request'; const { create_kb, + update_kb, rm_kb, get_kb_detail, kb_list, @@ -40,6 +42,10 @@ const { getKnowledgeBasicInfo, fetchDataPipelineLog, fetchPipelineDatasetLogs, + runGraphRag, + traceGraphRag, + runRaptor, + traceRaptor, check_embedding, kbUpdateMetaData, documentUpdateMetaData, @@ -50,9 +56,13 @@ const methods = { url: create_kb, method: 'post', }, + updateKb: { + url: update_kb, + method: 'post', + }, rmKb: { url: rm_kb, - method: 'delete', + method: 'post', }, get_kb_detail: { url: get_kb_detail, @@ -60,7 +70,7 @@ const methods = { }, getList: { url: kb_list, - method: 'get', + method: 'post', }, // document manager get_document_list: { @@ -181,6 +191,22 @@ const methods = { method: 'get', }, + runGraphRag: { + url: runGraphRag, + method: 'post', + }, + traceGraphRag: { + url: traceGraphRag, + method: 'get', + }, + runRaptor: { + url: runRaptor, + method: 'post', + }, + traceRaptor: { + url: traceRaptor, + method: 'get', + }, pipelineRerun: { url: api.pipelineRerun, method: 'post', @@ -225,23 +251,10 @@ export function deleteKnowledgeGraph(knowledgeId: string) { return request.delete(api.getKnowledgeGraph(knowledgeId)); } -export const listDataset = (params?: IFetchKnowledgeListRequestParams) => - request.get(api.kb_list, { params }); - -export const updateKb = (datasetId: string, data: Record) => - request.put(api.update_kb(datasetId), { data }); - -export const runGraphRag = 
(datasetId: string) => - request.post(api.runGraphRag(datasetId)); - -export const traceGraphRag = (datasetId: string) => - request.get(api.traceGraphRag(datasetId)); - -export const runRaptor = (datasetId: string) => - request.post(api.runRaptor(datasetId)); - -export const traceRaptor = (datasetId: string) => - request.get(api.traceRaptor(datasetId)); +export const listDataset = ( + params?: IFetchKnowledgeListRequestParams, + body?: IFetchKnowledgeListRequestBody, +) => request.post(api.kb_list, { data: body || {}, params }); export const listDocument = ( params?: IFetchKnowledgeListRequestParams, diff --git a/web/src/utils/api.ts b/web/src/utils/api.ts index 5750c9aee..20440065a 100644 --- a/web/src/utils/api.ts +++ b/web/src/utils/api.ts @@ -57,30 +57,23 @@ export default { // knowledge base check_embedding: `${api_host}/kb/check_embedding`, - kb_list: `${ExternalApi}${api_host}/datasets`, - create_kb: `${ExternalApi}${api_host}/datasets`, - update_kb: (datasetId: string) => - `${ExternalApi}${api_host}/datasets/${datasetId}`, - rm_kb: `${ExternalApi}${api_host}/datasets`, + kb_list: `${api_host}/kb/list`, + create_kb: `${api_host}/kb/create`, + update_kb: `${api_host}/kb/update`, + rm_kb: `${api_host}/kb/rm`, get_kb_detail: `${api_host}/kb/detail`, getKnowledgeGraph: (knowledgeId: string) => - `${ExternalApi}${api_host}/datasets/${knowledgeId}/knowledge_graph`, - deleteKnowledgeGraph: (knowledgeId: string) => - `${ExternalApi}${api_host}/datasets/${knowledgeId}/knowledge_graph`, + `${api_host}/kb/${knowledgeId}/knowledge_graph`, getMeta: `${api_host}/kb/get_meta`, getKnowledgeBasicInfo: `${api_host}/kb/basic_info`, // data pipeline log fetchDataPipelineLog: `${api_host}/kb/list_pipeline_logs`, get_pipeline_detail: `${api_host}/kb/pipeline_log_detail`, fetchPipelineDatasetLogs: `${api_host}/kb/list_pipeline_dataset_logs`, - runGraphRag: (datasetId: string) => - `${ExternalApi}${api_host}/datasets/${datasetId}/run_graphrag`, - traceGraphRag: (datasetId: string) 
=> - `${ExternalApi}${api_host}/datasets/${datasetId}/trace_graphrag`, - runRaptor: (datasetId: string) => - `${ExternalApi}${api_host}/datasets/${datasetId}/run_raptor`, - traceRaptor: (datasetId: string) => - `${ExternalApi}${api_host}/datasets/${datasetId}/trace_raptor`, + runGraphRag: `${api_host}/kb/run_graphrag`, + traceGraphRag: `${api_host}/kb/trace_graphrag`, + runRaptor: `${api_host}/kb/run_raptor`, + traceRaptor: `${api_host}/kb/trace_raptor`, unbindPipelineTask: ({ kb_id, type }: { kb_id: string; type: string }) => `${api_host}/kb/unbind_task?kb_id=${kb_id}&pipeline_task_type=${type}`, pipelineRerun: `${api_host}/canvas/rerun`,