Knowledge optimization (#3755)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
@@ -53,5 +53,8 @@ from .explore import (
     workflow,
 )
 
+# Import tag controllers
+from .tag import tags
+
 # Import workspace controllers
 from .workspace import account, members, model_providers, models, tool_providers, workspace
@@ -1,18 +1,25 @@
 import json
 import uuid
 
 from flask_login import current_user
-from flask_restful import Resource, inputs, marshal_with, reqparse
-from werkzeug.exceptions import BadRequest, Forbidden
+from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
+from werkzeug.exceptions import BadRequest, Forbidden, abort
 
 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from core.tools.tool_manager import ToolManager
 from core.tools.utils.configuration import ToolParameterConfigurationManager
 from fields.app_fields import (
     app_detail_fields,
     app_detail_fields_with_site,
     app_pagination_fields,
 )
 from libs.login import login_required
 from models.model import App, AppMode, AppModelConfig
 from services.app_service import AppService
 from services.tag_service import TagService
 
 ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
@@ -22,21 +29,29 @@ class AppListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     @marshal_with(app_pagination_fields)
     def get(self):
         """Get app list"""
+        def uuid_list(value):
+            try:
+                return [str(uuid.UUID(v)) for v in value.split(',')]
+            except ValueError:
+                abort(400, message="Invalid UUID format in tag_ids.")
         parser = reqparse.RequestParser()
         parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
         parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
         parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
         parser.add_argument('name', type=str, location='args', required=False)
+        parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
 
         args = parser.parse_args()
 
         # get app list
         app_service = AppService()
         app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
+        if not app_pagination:
+            return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}
 
-        return app_pagination
+        return marshal(app_pagination, app_pagination_fields)
 
     @setup_required
     @login_required
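Note: the new `tag_ids` query argument is a comma-separated list of tag UUIDs; `uuid_list` normalizes each entry through `uuid.UUID` and rejects malformed input with a 400 before the service layer runs. A hedged client sketch follows; the host, the `/console/api` prefix, and the bearer token are assumptions, not part of this diff:

import requests

# Hypothetical console session; auth details vary by deployment.
resp = requests.get(
    'http://localhost:5001/console/api/apps',
    params={
        'page': 1,
        'limit': 20,
        'tag_ids': '7a4f4d57-2dfb-4dc9-9f4b-7c1f6dd3c0a1',  # must parse as UUIDs
    },
    headers={'Authorization': 'Bearer <console token>'},
)
apps = resp.json()['data']  # each app row now carries a 'tags' list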
@@ -48,11 +48,14 @@ class DatasetListApi(Resource):
         limit = request.args.get('limit', default=20, type=int)
         ids = request.args.getlist('ids')
         provider = request.args.get('provider', default="vendor")
+        search = request.args.get('keyword', default=None, type=str)
+        tag_ids = request.args.getlist('tag_ids')
 
         if ids:
             datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
         else:
             datasets, total = DatasetService.get_datasets(page, limit, provider,
-                                                          current_user.current_tenant_id, current_user)
+                                                          current_user.current_tenant_id, current_user, search, tag_ids)
 
         # check embedding setting
         provider_manager = ProviderManager()
@@ -184,6 +187,10 @@ class DatasetApi(Resource):
                             help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
             'only_me', 'all_team_members'), help='Invalid permission.')
+        parser.add_argument('embedding_model', type=str,
+                            location='json', help='Invalid embedding model.')
+        parser.add_argument('embedding_model_provider', type=str,
+                            location='json', help='Invalid embedding model provider.')
         parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
         args = parser.parse_args()
@@ -506,10 +513,27 @@ class DatasetRetrievalSettingMockApi(Resource):
         else:
             raise ValueError("Unsupported vector db type.")
 
+class DatasetErrorDocs(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id):
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
+
+        return {
+            'data': [marshal(item, document_status_fields) for item in results],
+            'total': len(results)
+        }, 200
 
 
 api.add_resource(DatasetListApi, '/datasets')
 api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
 api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
+api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
 api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
 api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
 api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
@@ -1,3 +1,4 @@
+import logging
 from datetime import datetime, timezone
 
 from flask import request
@@ -233,7 +234,7 @@ class DatasetDocumentListApi(Resource):
                             location='json')
         parser.add_argument('data_source', type=dict, required=False, location='json')
         parser.add_argument('process_rule', type=dict, required=False, location='json')
-        parser.add_argument('duplicate', type=bool, nullable=False, location='json')
+        parser.add_argument('duplicate', type=bool, default=True, nullable=False, location='json')
         parser.add_argument('original_document_id', type=str, required=False, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
@@ -883,6 +884,49 @@ class DocumentRecoverApi(DocumentResource):
         return {'result': 'success'}, 204
 
 
+class DocumentRetryApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, dataset_id):
+        """retry document."""
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('document_ids', type=list, required=True, nullable=False,
+                            location='json')
+        args = parser.parse_args()
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        retry_documents = []
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        for document_id in args['document_ids']:
+            try:
+                document_id = str(document_id)
+
+                document = DocumentService.get_document(dataset.id, document_id)
+
+                # 404 if document not found
+                if document is None:
+                    raise NotFound("Document does not exist.")
+
+                # 403 if document is archived
+                if DocumentService.check_archived(document):
+                    raise ArchivedDocumentImmutableError()
+
+                # 400 if document is completed
+                if document.indexing_status == 'completed':
+                    raise DocumentAlreadyFinishedError()
+                retry_documents.append(document)
+            except Exception as e:
+                logging.error(f"Document {document_id} retry failed: {str(e)}")
+                continue
+        # retry document
+        DocumentService.retry_document(dataset_id, retry_documents)
+
+        return {'result': 'success'}, 204
 
 
 api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
 api.add_resource(DatasetDocumentListApi,
                  '/datasets/<uuid:dataset_id>/documents')
@@ -908,3 +952,4 @@ api.add_resource(DocumentStatusApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
 api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
 api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
+api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
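Note: `DocumentRetryApi` accepts a batch of ids and silently skips documents that are missing, archived, or already completed (each failure is logged and the loop continues). A hedged sketch of the request shape; host, prefix, and auth are assumptions:

import requests

resp = requests.post(
    'http://localhost:5001/console/api/datasets/<dataset uuid>/retry',
    json={'document_ids': ['<doc uuid 1>', '<doc uuid 2>']},
    headers={'Authorization': 'Bearer <console token>'},
)
assert resp.status_code == 204  # retriable documents were requeued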
api/controllers/console/tag/tags.py (new file, 159 lines)
@@ -0,0 +1,159 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.tag_fields import tag_fields
from libs.login import login_required
from models.model import Tag
from services.tag_service import TagService


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 50:
        raise ValueError('Name must be between 1 and 50 characters.')
    return name


class TagListApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(tag_fields)
    def get(self):
        tag_type = request.args.get('type', type=str)
        keyword = request.args.get('keyword', default=None, type=str)
        tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)

        return tags, 200

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        # The role of the current user in the tenant must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument('name', nullable=False, required=True,
                            help='Name must be between 1 and 50 characters.',
                            type=_validate_name)
        parser.add_argument('type', type=str, location='json',
                            choices=Tag.TAG_TYPE_LIST,
                            nullable=True,
                            help='Invalid tag type.')
        args = parser.parse_args()
        tag = TagService.save_tags(args)

        response = {
            'id': tag.id,
            'name': tag.name,
            'type': tag.type,
            'binding_count': 0
        }

        return response, 200


class TagUpdateDeleteApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, tag_id):
        tag_id = str(tag_id)
        # The role of the current user in the tenant must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument('name', nullable=False, required=True,
                            help='Name must be between 1 and 50 characters.',
                            type=_validate_name)
        args = parser.parse_args()
        tag = TagService.update_tags(args, tag_id)

        binding_count = TagService.get_tag_binding_count(tag_id)

        response = {
            'id': tag.id,
            'name': tag.name,
            'type': tag.type,
            'binding_count': binding_count
        }

        return response, 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, tag_id):
        tag_id = str(tag_id)
        # The role of the current user in the tenant must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        TagService.delete_tag(tag_id)

        return 200


class TagBindingCreateApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        # The role of the current user in the tenant must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json',
                            help='Tag IDs are required.')
        parser.add_argument('target_id', type=str, nullable=False, required=True, location='json',
                            help='Target ID is required.')
        parser.add_argument('type', type=str, location='json',
                            choices=Tag.TAG_TYPE_LIST,
                            nullable=True,
                            help='Invalid tag type.')
        args = parser.parse_args()
        TagService.save_tag_binding(args)

        return 200


class TagBindingDeleteApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        # The role of the current user in the tenant must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument('tag_id', type=str, nullable=False, required=True,
                            help='Tag ID is required.')
        parser.add_argument('target_id', type=str, nullable=False, required=True,
                            help='Target ID is required.')
        parser.add_argument('type', type=str, location='json',
                            choices=Tag.TAG_TYPE_LIST,
                            nullable=True,
                            help='Invalid tag type.')
        args = parser.parse_args()
        TagService.delete_tag_binding(args)

        return 200


api.add_resource(TagListApi, '/tags')
api.add_resource(TagUpdateDeleteApi, '/tags/<uuid:tag_id>')
api.add_resource(TagBindingCreateApi, '/tag-bindings/create')
api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove')
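Note: taken together these resources form a small lifecycle: create a tag, attach it to a target, detach it, delete it. An illustrative sequence against the new routes; host, prefix, and auth are assumptions:

import requests

base = 'http://localhost:5001/console/api'
auth = {'Authorization': 'Bearer <console token>'}

# 1. create an app tag
tag = requests.post(f'{base}/tags', json={'name': 'sales', 'type': 'app'}, headers=auth).json()

# 2. bind it to an app (target_id is the app's UUID)
requests.post(f'{base}/tag-bindings/create',
              json={'tag_ids': [tag['id']], 'target_id': '<app uuid>', 'type': 'app'},
              headers=auth)

# 3. unbind, then delete the tag itself
requests.post(f'{base}/tag-bindings/remove',
              json={'tag_id': tag['id'], 'target_id': '<app uuid>', 'type': 'app'},
              headers=auth)
requests.delete(f'{base}/tags/{tag["id"]}', headers=auth)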
@@ -26,8 +26,11 @@ class DatasetApi(DatasetApiResource):
         page = request.args.get('page', default=1, type=int)
         limit = request.args.get('limit', default=20, type=int)
         provider = request.args.get('provider', default="vendor")
+        search = request.args.get('keyword', default=None, type=str)
+        tag_ids = request.args.getlist('tag_ids')
 
         datasets, total = DatasetService.get_datasets(page, limit, provider,
-                                                      tenant_id, current_user)
+                                                      tenant_id, current_user, search, tag_ids)
         # check embedding setting
         provider_manager = ProviderManager()
         configurations = provider_manager.get_configurations(
@@ -110,19 +110,37 @@ class MilvusVector(BaseVector):
         return None
 
     def delete_by_metadata_field(self, key: str, value: str):
+        alias = uuid4().hex
+        if self._client_config.secure:
+            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        else:
+            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
 
-        ids = self.get_ids_by_metadata_field(key, value)
-        if ids:
-            self._client.delete(collection_name=self._collection_name, pks=ids)
+        from pymilvus import utility
+        if utility.has_collection(self._collection_name, using=alias):
+
+            ids = self.get_ids_by_metadata_field(key, value)
+            if ids:
+                self._client.delete(collection_name=self._collection_name, pks=ids)
 
     def delete_by_ids(self, doc_ids: list[str]) -> None:
+        alias = uuid4().hex
+        if self._client_config.secure:
+            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        else:
+            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
 
-        result = self._client.query(collection_name=self._collection_name,
-                                    filter=f'metadata["doc_id"] in {doc_ids}',
-                                    output_fields=["id"])
-        if result:
-            ids = [item["id"] for item in result]
-            self._client.delete(collection_name=self._collection_name, pks=ids)
+        from pymilvus import utility
+        if utility.has_collection(self._collection_name, using=alias):
+
+            result = self._client.query(collection_name=self._collection_name,
+                                        filter=f'metadata["doc_id"] in {doc_ids}',
+                                        output_fields=["id"])
+            if result:
+                ids = [item["id"] for item in result]
+                self._client.delete(collection_name=self._collection_name, pks=ids)
 
     def delete(self) -> None:
         alias = uuid4().hex
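Note: both Milvus deletions now connect under a throwaway alias and check `utility.has_collection` first, so deleting from a dataset whose collection was never created becomes a no-op instead of an error. The guard pattern in isolation; the URI and credentials are placeholders:

from uuid import uuid4

from pymilvus import connections, utility

alias = uuid4().hex  # a fresh alias per call avoids clobbering a shared default connection
connections.connect(alias=alias, uri='http://127.0.0.1:19530', user='', password='')
if utility.has_collection('vector_index_demo', using=alias):
    pass  # safe to query/delete here; otherwise skip silently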
@@ -217,29 +217,38 @@ class QdrantVector(BaseVector):
     def delete_by_metadata_field(self, key: str, value: str):
 
         from qdrant_client.http import models
+        from qdrant_client.http.exceptions import UnexpectedResponse
 
-        filter = models.Filter(
-            must=[
-                models.FieldCondition(
-                    key=f"metadata.{key}",
-                    match=models.MatchValue(value=value),
-                ),
-            ],
-        )
-
-        self._reload_if_needed()
-
-        self._client.delete(
-            collection_name=self._collection_name,
-            points_selector=FilterSelector(
-                filter=filter
-            ),
-        )
+        try:
+            filter = models.Filter(
+                must=[
+                    models.FieldCondition(
+                        key=f"metadata.{key}",
+                        match=models.MatchValue(value=value),
+                    ),
+                ],
+            )
+
+            self._reload_if_needed()
+
+            self._client.delete(
+                collection_name=self._collection_name,
+                points_selector=FilterSelector(
+                    filter=filter
+                ),
+            )
+        except UnexpectedResponse as e:
+            # Collection does not exist, so return
+            if e.status_code == 404:
+                return
+            # Some other error occurred, so re-raise the exception
+            else:
+                raise e
 
     def delete(self):
         from qdrant_client.http import models
         from qdrant_client.http.exceptions import UnexpectedResponse
 
         try:
             filter = models.Filter(
                 must=[
@@ -257,29 +266,40 @@ class QdrantVector(BaseVector):
             )
         except UnexpectedResponse as e:
             # Collection does not exist, so return
             if e.status_code == 404:
                 return
             # Some other error occurred, so re-raise the exception
             else:
                 raise e
 
     def delete_by_ids(self, ids: list[str]) -> None:
 
         from qdrant_client.http import models
+        from qdrant_client.http.exceptions import UnexpectedResponse
 
         for node_id in ids:
-            filter = models.Filter(
-                must=[
-                    models.FieldCondition(
-                        key="metadata.doc_id",
-                        match=models.MatchValue(value=node_id),
-                    ),
-                ],
-            )
-            self._client.delete(
-                collection_name=self._collection_name,
-                points_selector=FilterSelector(
-                    filter=filter
-                ),
-            )
+            try:
+                filter = models.Filter(
+                    must=[
+                        models.FieldCondition(
+                            key="metadata.doc_id",
+                            match=models.MatchValue(value=node_id),
+                        ),
+                    ],
+                )
+                self._client.delete(
+                    collection_name=self._collection_name,
+                    points_selector=FilterSelector(
+                        filter=filter
+                    ),
+                )
+            except UnexpectedResponse as e:
+                # Collection does not exist, so return
+                if e.status_code == 404:
+                    return
+                # Some other error occurred, so re-raise the exception
+                else:
+                    raise e
 
     def text_exists(self, id: str) -> bool:
         all_collection_name = []
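Note: the Qdrant client raises `UnexpectedResponse` rather than a dedicated not-found error, so the deletes now treat HTTP 404 as "collection absent, nothing to do" and re-raise anything else. The same pattern reduced to its core; the URL and names are placeholders:

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse

client = QdrantClient(url='http://127.0.0.1:6333')
try:
    client.delete(
        collection_name='vector_index_demo',
        points_selector=models.FilterSelector(
            filter=models.Filter(must=[models.FieldCondition(
                key='metadata.doc_id', match=models.MatchValue(value='doc-1'))]),
        ),
    )
except UnexpectedResponse as e:
    if e.status_code != 404:  # 404 means the collection was never created
        raise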
@@ -121,18 +121,20 @@ class WeaviateVector(BaseVector):
         return ids
 
     def delete_by_metadata_field(self, key: str, value: str):
-        where_filter = {
-            "operator": "Equal",
-            "path": [key],
-            "valueText": value
-        }
-
-        self._client.batch.delete_objects(
-            class_name=self._collection_name,
-            where=where_filter,
-            output='minimal'
-        )
+        # check whether the index already exists
+        schema = self._default_schema(self._collection_name)
+        if self._client.schema.contains(schema):
+            where_filter = {
+                "operator": "Equal",
+                "path": [key],
+                "valueText": value
+            }
+
+            self._client.batch.delete_objects(
+                class_name=self._collection_name,
+                where=where_filter,
+                output='minimal'
+            )
 
     def delete(self):
         # check whether the index already exists
@@ -163,11 +165,14 @@ class WeaviateVector(BaseVector):
         return True
 
     def delete_by_ids(self, ids: list[str]) -> None:
-        for uuid in ids:
-            self._client.data_object.delete(
-                class_name=self._collection_name,
-                uuid=uuid,
-            )
+        # check whether the index already exists
+        schema = self._default_schema(self._collection_name)
+        if self._client.schema.contains(schema):
+            for uuid in ids:
+                self._client.data_object.delete(
+                    class_name=self._collection_name,
+                    uuid=uuid,
+                )
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         """Look up similar documents by embedding vector in Weaviate."""
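Note: Weaviate deletions are now wrapped in the same existence check used elsewhere in this class: `client.schema.contains(schema)` is consulted before issuing deletes, so an index that was never created is skipped. Reduced to its core with the v3 client; the URL, class name, and filter values are placeholders:

import weaviate

client = weaviate.Client('http://127.0.0.1:8080')
schema = {'classes': [{'class': 'VectorIndexDemo'}]}
if client.schema.contains(schema):  # skip deletes when the class was never created
    client.batch.delete_objects(
        class_name='VectorIndexDemo',
        where={'operator': 'Equal', 'path': ['doc_id'], 'valueText': 'doc-1'},
        output='minimal',
    )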
@@ -62,6 +62,12 @@ model_config_partial_fields = {
     'pre_prompt': fields.String,
 }
 
+tag_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'type': fields.String
+}
+
 app_partial_fields = {
     'id': fields.String,
     'name': fields.String,
@@ -70,9 +76,11 @@ app_partial_fields = {
     'icon': fields.String,
     'icon_background': fields.String,
     'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True),
-    'created_at': TimestampField
+    'created_at': TimestampField,
+    'tags': fields.List(fields.Nested(tag_fields))
 }
 
 
 app_pagination_fields = {
     'page': fields.Integer,
     'limit': fields.Integer(attribute='per_page'),
 
@@ -27,6 +27,11 @@ dataset_retrieval_model_fields = {
     'score_threshold': fields.Float
 }
 
+tag_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'type': fields.String
+}
+
 dataset_detail_fields = {
     'id': fields.String,
 
@@ -46,7 +51,8 @@ dataset_detail_fields = {
     'embedding_model': fields.String,
     'embedding_model_provider': fields.String,
     'embedding_available': fields.Boolean,
-    'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields)
+    'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields),
+    'tags': fields.List(fields.Nested(tag_fields))
 }
 
 dataset_query_detail_fields = {
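Note: `fields.List(fields.Nested(tag_fields))` serializes each bound tag through the nested schema, which is why every app and dataset row now carries its tags inline. A quick standalone illustration of how flask_restful renders it (the sample row is invented):

from flask_restful import fields, marshal

tag_fields = {'id': fields.String, 'name': fields.String, 'type': fields.String}
row = {'name': 'demo app', 'tags': [{'id': 't1', 'name': 'sales', 'type': 'app', 'extra': 'dropped'}]}
print(marshal(row, {'name': fields.String, 'tags': fields.List(fields.Nested(tag_fields))}))
# OrderedDict([('name', 'demo app'), ('tags', [OrderedDict([('id', 't1'), ('name', 'sales'), ('type', 'app')])])])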
api/fields/tag_fields.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from flask_restful import fields

tag_fields = {
    'id': fields.String,
    'name': fields.String,
    'type': fields.String,
    'binding_count': fields.String
}
@@ -0,0 +1,62 @@
"""add-tags-and-binding-table

Revision ID: 3c7cac9521c6
Revises: c3311b089690
Create Date: 2024-04-11 06:17:34.278594

"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '3c7cac9521c6'
down_revision = 'c3311b089690'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('tag_bindings',
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', postgresql.UUID(), nullable=True),
        sa.Column('tag_id', postgresql.UUID(), nullable=True),
        sa.Column('target_id', postgresql.UUID(), nullable=True),
        sa.Column('created_by', postgresql.UUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='tag_binding_pkey')
    )
    with op.batch_alter_table('tag_bindings', schema=None) as batch_op:
        batch_op.create_index('tag_bind_tag_id_idx', ['tag_id'], unique=False)
        batch_op.create_index('tag_bind_target_id_idx', ['target_id'], unique=False)

    op.create_table('tags',
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', postgresql.UUID(), nullable=True),
        sa.Column('type', sa.String(length=16), nullable=False),
        sa.Column('name', sa.String(length=255), nullable=False),
        sa.Column('created_by', postgresql.UUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='tag_pkey')
    )
    with op.batch_alter_table('tags', schema=None) as batch_op:
        batch_op.create_index('tag_name_idx', ['name'], unique=False)
        batch_op.create_index('tag_type_idx', ['type'], unique=False)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('tags', schema=None) as batch_op:
        batch_op.drop_index('tag_type_idx')
        batch_op.drop_index('tag_name_idx')

    op.drop_table('tags')
    with op.batch_alter_table('tag_bindings', schema=None) as batch_op:
        batch_op.drop_index('tag_bind_target_id_idx')
        batch_op.drop_index('tag_bind_tag_id_idx')

    op.drop_table('tag_bindings')
    # ### end Alembic commands ###
@@ -9,7 +9,7 @@ from sqlalchemy.dialects.postgresql import JSONB, UUID
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.account import Account
-from models.model import App, UploadFile
+from models.model import App, Tag, TagBinding, UploadFile
 
 
 class Dataset(db.Model):
@@ -118,6 +118,20 @@ class Dataset(db.Model):
         }
         return self.retrieval_model if self.retrieval_model else default_retrieval_model
 
+    @property
+    def tags(self):
+        tags = db.session.query(Tag).join(
+            TagBinding,
+            Tag.id == TagBinding.tag_id
+        ).filter(
+            TagBinding.target_id == self.id,
+            TagBinding.tenant_id == self.tenant_id,
+            Tag.tenant_id == self.tenant_id,
+            Tag.type == 'knowledge'
+        ).all()
+
+        return tags if tags else []
+
     @staticmethod
     def gen_collection_name_by_id(dataset_id: str) -> str:
         normalized_dataset_id = dataset_id.replace("-", "_")
@@ -148,7 +148,7 @@ class App(db.Model):
             return []
         agent_mode = app_model_config.agent_mode_dict
         tools = agent_mode.get('tools', [])
 
         provider_ids = []
 
         for tool in tools:
@@ -185,6 +185,20 @@ class App(db.Model):
 
         return deleted_tools
 
+    @property
+    def tags(self):
+        tags = db.session.query(Tag).join(
+            TagBinding,
+            Tag.id == TagBinding.tag_id
+        ).filter(
+            TagBinding.target_id == self.id,
+            TagBinding.tenant_id == self.tenant_id,
+            Tag.tenant_id == self.tenant_id,
+            Tag.type == 'app'
+        ).all()
+
+        return tags if tags else []
+
 
 class AppModelConfig(db.Model):
     __tablename__ = 'app_model_configs'
@@ -292,7 +306,8 @@ class AppModelConfig(db.Model):
 
     @property
     def agent_mode_dict(self) -> dict:
-        return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": [], "prompt": None}
+        return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": [],
+                                                                    "prompt": None}
 
     @property
     def chat_prompt_config_dict(self) -> dict:
@@ -463,6 +478,7 @@ class InstalledApp(db.Model):
         return tenant
 
 
 class Conversation(db.Model):
     __tablename__ = 'conversations'
     __table_args__ = (
@@ -1175,11 +1191,11 @@ class MessageAgentThought(db.Model):
             return json.loads(self.message_files)
         else:
             return []
 
     @property
     def tools(self) -> list[str]:
         return self.tool.split(";") if self.tool else []
 
     @property
     def tool_labels(self) -> dict:
         try:
@@ -1189,7 +1205,7 @@ class MessageAgentThought(db.Model):
             return {}
         except Exception as e:
             return {}
 
     @property
     def tool_meta(self) -> dict:
         try:
@@ -1199,7 +1215,7 @@ class MessageAgentThought(db.Model):
             return {}
         except Exception as e:
             return {}
 
     @property
     def tool_inputs_dict(self) -> dict:
         tools = self.tools
@@ -1222,7 +1238,7 @@ class MessageAgentThought(db.Model):
             }
         except Exception as e:
             return {}
 
     @property
     def tool_outputs_dict(self) -> dict:
         tools = self.tools
@@ -1249,6 +1265,7 @@ class MessageAgentThought(db.Model):
             tool: self.observation for tool in tools
         }
 
 
 class DatasetRetrieverResource(db.Model):
     __tablename__ = 'dataset_retriever_resources'
     __table_args__ = (
@@ -1274,3 +1291,37 @@ class DatasetRetrieverResource(db.Model):
     retriever_from = db.Column(db.Text, nullable=False)
     created_by = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+
+
+class Tag(db.Model):
+    __tablename__ = 'tags'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='tag_pkey'),
+        db.Index('tag_type_idx', 'type'),
+        db.Index('tag_name_idx', 'name'),
+    )
+
+    TAG_TYPE_LIST = ['knowledge', 'app']
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(UUID, nullable=True)
+    type = db.Column(db.String(16), nullable=False)
+    name = db.Column(db.String(255), nullable=False)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+
+class TagBinding(db.Model):
+    __tablename__ = 'tag_bindings'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='tag_binding_pkey'),
+        db.Index('tag_bind_target_id_idx', 'target_id'),
+        db.Index('tag_bind_tag_id_idx', 'tag_id'),
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(UUID, nullable=True)
+    tag_id = db.Column(UUID, nullable=True)
+    target_id = db.Column(UUID, nullable=True)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -21,11 +21,12 @@ from extensions.ext_database import db
 from models.account import Account
 from models.model import App, AppMode, AppModelConfig
 from models.tools import ApiToolProvider
+from services.tag_service import TagService
 from services.workflow_service import WorkflowService
 
 
 class AppService:
-    def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination:
+    def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
         """
         Get app list with pagination
         :param tenant_id: tenant id
@@ -49,6 +50,14 @@ class AppService:
         if 'name' in args and args['name']:
             name = args['name'][:30]
             filters.append(App.name.ilike(f'%{name}%'))
+        if 'tag_ids' in args and args['tag_ids']:
+            target_ids = TagService.get_target_ids_by_tag_ids('app',
+                                                              tenant_id,
+                                                              args['tag_ids'])
+            if target_ids:
+                filters.append(App.id.in_(target_ids))
+            else:
+                return None
 
         app_models = db.paginate(
             db.select(App).where(*filters).order_by(App.created_at.desc()),
@@ -38,28 +38,39 @@ from services.errors.dataset import DatasetNameDuplicateError
 from services.errors.document import DocumentIndexingError
 from services.errors.file import FileNotExistsError
 from services.feature_service import FeatureModel, FeatureService
+from services.tag_service import TagService
 from services.vector_service import VectorService
 from tasks.clean_notion_document_task import clean_notion_document_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.delete_segment_from_index_task import delete_segment_from_index_task
 from tasks.document_indexing_task import document_indexing_task
 from tasks.document_indexing_update_task import document_indexing_update_task
+from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
 from tasks.recover_document_indexing_task import recover_document_indexing_task
+from tasks.retry_document_indexing_task import retry_document_indexing_task
 
 
 class DatasetService:
 
     @staticmethod
-    def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None):
+    def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None):
         if user:
             permission_filter = db.or_(Dataset.created_by == user.id,
                                        Dataset.permission == 'all_team_members')
         else:
             permission_filter = Dataset.permission == 'all_team_members'
-        datasets = Dataset.query.filter(
+        query = Dataset.query.filter(
             db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
-            .order_by(Dataset.created_at.desc()) \
-            .paginate(
+            .order_by(Dataset.created_at.desc())
+        if search:
+            query = query.filter(db.and_(Dataset.name.ilike(f'%{search}%')))
+        if tag_ids:
+            target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids)
+            if target_ids:
+                query = query.filter(db.and_(Dataset.id.in_(target_ids)))
+            else:
+                return [], 0
+        datasets = query.paginate(
             page=page,
             per_page=per_page,
             max_per_page=100,
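Note: the reworked query builds its filters incrementally and short-circuits to `[], 0` when tag filtering matches no datasets, so callers never paginate over an impossible filter. Call shape for reference; the argument values are illustrative:

# Inside a request context with a logged-in user:
datasets, total = DatasetService.get_datasets(
    page=1,
    per_page=20,
    provider='vendor',
    tenant_id=current_user.current_tenant_id,
    user=current_user,
    search='faq',               # optional name filter (ilike)
    tag_ids=['<tag uuid>'],     # optional knowledge-tag filter
)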
@@ -165,9 +176,36 @@ class DatasetService:
             # get embedding model setting
             try:
                 model_manager = ModelManager()
-                embedding_model = model_manager.get_default_model_instance(
+                embedding_model = model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
-                    model_type=ModelType.TEXT_EMBEDDING
+                    provider=data['embedding_model_provider'],
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=data['embedding_model']
                 )
                 filtered_data['embedding_model'] = embedding_model.model
                 filtered_data['embedding_model_provider'] = embedding_model.provider
                 dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                     embedding_model.provider,
                     embedding_model.model
                 )
                 filtered_data['collection_binding_id'] = dataset_collection_binding.id
             except LLMBadRequestError:
                 raise ValueError(
                     "No Embedding Model available. Please configure a valid provider "
                     "in the Settings -> Model Provider.")
             except ProviderTokenNotInitError as ex:
                 raise ValueError(ex.description)
+        else:
+            if data['embedding_model_provider'] != dataset.embedding_model_provider or \
+                    data['embedding_model'] != dataset.embedding_model:
+                action = 'update'
+                try:
+                    model_manager = ModelManager()
+                    embedding_model = model_manager.get_model_instance(
+                        tenant_id=current_user.current_tenant_id,
+                        provider=data['embedding_model_provider'],
+                        model_type=ModelType.TEXT_EMBEDDING,
+                        model=data['embedding_model']
+                    )
+                    filtered_data['embedding_model'] = embedding_model.model
+                    filtered_data['embedding_model_provider'] = embedding_model.provider
@@ -376,6 +414,15 @@ class DocumentService:
 
         return documents
 
+    @staticmethod
+    def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
+        documents = db.session.query(Document).filter(
+            Document.dataset_id == dataset_id,
+            Document.indexing_status.in_(['error', 'paused'])
+        ).all()
+
+        return documents
+
     @staticmethod
     def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
         documents = db.session.query(Document).filter(
@@ -440,6 +487,20 @@ class DocumentService:
         # trigger async task
         recover_document_indexing_task.delay(document.dataset_id, document.id)
 
+    @staticmethod
+    def retry_document(dataset_id: str, documents: list[Document]):
+        for document in documents:
+            # retry document indexing
+            document.indexing_status = 'waiting'
+            db.session.add(document)
+            db.session.commit()
+            # add retry flag
+            retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
+            redis_client.setex(retry_indexing_cache_key, 600, 1)
+        # trigger async task
+        document_ids = [document.id for document in documents]
+        retry_document_indexing_task.delay(dataset_id, document_ids)
+
     @staticmethod
     def get_documents_position(dataset_id):
         document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
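Note: before dispatching the Celery job, each document is flagged in Redis for ten minutes; the key lets other code distinguish an in-flight retry from a fresh request. The flag convention in isolation; the plain `redis` client here stands in for dify's `redis_client` extension, which is assumed to behave the same:

from redis import Redis

redis_client = Redis()
document_id = '<doc uuid>'
key = 'document_{}_is_retried'.format(document_id)
redis_client.setex(key, 600, 1)   # expires after 600 s, so a stuck retry cannot block forever
print(redis_client.exists(key))   # 1 while the retry is considered in flight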
@@ -537,6 +598,7 @@ class DocumentService:
             db.session.commit()
         position = DocumentService.get_documents_position(dataset.id)
         document_ids = []
+        duplicate_document_ids = []
         if document_data["data_source"]["type"] == "upload_file":
             upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
             for file_id in upload_file_list:
@@ -553,6 +615,28 @@ class DocumentService:
                 data_source_info = {
                     "upload_file_id": file_id,
                 }
+                # check duplicate
+                if document_data.get('duplicate', False):
+                    document = Document.query.filter_by(
+                        dataset_id=dataset.id,
+                        tenant_id=current_user.current_tenant_id,
+                        data_source_type='upload_file',
+                        enabled=True,
+                        name=file_name
+                    ).first()
+                    if document:
+                        document.dataset_process_rule_id = dataset_process_rule.id
+                        document.updated_at = datetime.datetime.utcnow()
+                        document.created_from = created_from
+                        document.doc_form = document_data['doc_form']
+                        document.doc_language = document_data['doc_language']
+                        document.data_source_info = json.dumps(data_source_info)
+                        document.batch = batch
+                        document.indexing_status = 'waiting'
+                        db.session.add(document)
+                        documents.append(document)
+                        duplicate_document_ids.append(document.id)
+                        continue
                 document = DocumentService.build_document(dataset, dataset_process_rule.id,
                                                           document_data["data_source"]["type"],
                                                           document_data["doc_form"],
@@ -618,7 +702,10 @@ class DocumentService:
         db.session.commit()
 
         # trigger async task
-        document_indexing_task.delay(dataset.id, document_ids)
+        if document_ids:
+            document_indexing_task.delay(dataset.id, document_ids)
+        if duplicate_document_ids:
+            duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
 
         return documents, batch
 
@@ -626,7 +713,8 @@ class DocumentService:
     def check_documents_upload_quota(count: int, features: FeatureModel):
         can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
         if count > can_upload_size:
-            raise ValueError(f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')
+            raise ValueError(
+                f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')
 
     @staticmethod
     def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
@@ -752,7 +840,6 @@ class DocumentService:
         db.session.commit()
         # trigger async task
         document_indexing_update_task.delay(document.dataset_id, document.id)
 
         return document
 
     @staticmethod
api/services/tag_service.py (new file, 161 lines)
@@ -0,0 +1,161 @@
import uuid

from flask_login import current_user
from sqlalchemy import func
from werkzeug.exceptions import NotFound

from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, Tag, TagBinding


class TagService:
    @staticmethod
    def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
        query = db.session.query(
            Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count')
        ).outerjoin(
            TagBinding, Tag.id == TagBinding.tag_id
        ).filter(
            Tag.type == tag_type,
            Tag.tenant_id == current_tenant_id
        )
        if keyword:
            query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%')))
        query = query.group_by(
            Tag.id
        )
        results = query.order_by(Tag.created_at.desc()).all()
        return results

    @staticmethod
    def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
        tags = db.session.query(Tag).filter(
            Tag.id.in_(tag_ids),
            Tag.tenant_id == current_tenant_id,
            Tag.type == tag_type
        ).all()
        if not tags:
            return []
        tag_ids = [tag.id for tag in tags]
        tag_bindings = db.session.query(
            TagBinding.target_id
        ).filter(
            TagBinding.tag_id.in_(tag_ids),
            TagBinding.tenant_id == current_tenant_id
        ).all()
        if not tag_bindings:
            return []
        results = [tag_binding.target_id for tag_binding in tag_bindings]
        return results

    @staticmethod
    def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
        tags = db.session.query(Tag).join(
            TagBinding,
            Tag.id == TagBinding.tag_id
        ).filter(
            TagBinding.target_id == target_id,
            TagBinding.tenant_id == current_tenant_id,
            Tag.tenant_id == current_tenant_id,
            Tag.type == tag_type
        ).all()

        return tags if tags else []

    @staticmethod
    def save_tags(args: dict) -> Tag:
        tag = Tag(
            id=str(uuid.uuid4()),
            name=args['name'],
            type=args['type'],
            created_by=current_user.id,
            tenant_id=current_user.current_tenant_id
        )
        db.session.add(tag)
        db.session.commit()
        return tag

    @staticmethod
    def update_tags(args: dict, tag_id: str) -> Tag:
        tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
        if not tag:
            raise NotFound("Tag not found")
        tag.name = args['name']
        db.session.commit()
        return tag

    @staticmethod
    def get_tag_binding_count(tag_id: str) -> int:
        count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
        return count

    @staticmethod
    def delete_tag(tag_id: str):
        tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
        if not tag:
            raise NotFound("Tag not found")
        db.session.delete(tag)
        # delete tag binding
        tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
        if tag_bindings:
            for tag_binding in tag_bindings:
                db.session.delete(tag_binding)
        db.session.commit()

    @staticmethod
    def save_tag_binding(args):
        # check if target exists
        TagService.check_target_exists(args['type'], args['target_id'])
        # save tag binding
        for tag_id in args['tag_ids']:
            tag_binding = db.session.query(TagBinding).filter(
                TagBinding.tag_id == tag_id,
                TagBinding.target_id == args['target_id']
            ).first()
            if tag_binding:
                continue
            new_tag_binding = TagBinding(
                tag_id=tag_id,
                target_id=args['target_id'],
                tenant_id=current_user.current_tenant_id,
                created_by=current_user.id
            )
            db.session.add(new_tag_binding)
        db.session.commit()

    @staticmethod
    def delete_tag_binding(args):
        # check if target exists
        TagService.check_target_exists(args['type'], args['target_id'])
        # delete tag binding
        tag_bindings = db.session.query(TagBinding).filter(
            TagBinding.target_id == args['target_id'],
            TagBinding.tag_id == args['tag_id']
        ).first()
        if tag_bindings:
            db.session.delete(tag_bindings)
            db.session.commit()

    @staticmethod
    def check_target_exists(type: str, target_id: str):
        if type == 'knowledge':
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == current_user.current_tenant_id,
                Dataset.id == target_id
            ).first()
            if not dataset:
                raise NotFound("Dataset not found")
        elif type == 'app':
            app = db.session.query(App).filter(
                App.tenant_id == current_user.current_tenant_id,
                App.id == target_id
            ).first()
            if not app:
                raise NotFound("App not found")
        else:
            raise NotFound("Invalid binding type")
@@ -16,6 +16,7 @@ from models.dataset import (
 )
 
 
 # Add import statement for ValueError
 @shared_task(queue='dataset')
 def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
                        index_struct: str, collection_binding_id: str, doc_form: str):
@@ -48,6 +49,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
             logging.info(click.style('No documents found for dataset: {}'.format(dataset_id), fg='green'))
         else:
             logging.info(click.style('Cleaning documents for dataset: {}'.format(dataset_id), fg='green'))
+            # Specify the index type before initializing the index processor
+            if doc_form is None:
+                raise ValueError("Index type must be specified.")
             index_processor = IndexProcessorFactory(doc_form).init_index_processor()
             index_processor.clean(dataset, None)
@@ -64,6 +64,39 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
 
             # save vector index
             index_processor.load(dataset, documents, with_keywords=False)
+    elif action == 'update':
+        # clean index
+        index_processor.clean(dataset, None, with_keywords=False)
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset_id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+        # add new index
+        if dataset_documents:
+            documents = []
+            for dataset_document in dataset_documents:
+                # delete from vector index
+                segments = db.session.query(DocumentSegment).filter(
+                    DocumentSegment.document_id == dataset_document.id,
+                    DocumentSegment.enabled == True
+                ).order_by(DocumentSegment.position.asc()).all()
+                for segment in segments:
+                    document = Document(
+                        page_content=segment.content,
+                        metadata={
+                            "doc_id": segment.index_node_id,
+                            "doc_hash": segment.index_node_hash,
+                            "document_id": segment.document_id,
+                            "dataset_id": segment.dataset_id,
+                        }
+                    )
+
+                    documents.append(document)
+
+            # save vector index
+            index_processor.load(dataset, documents, with_keywords=False)
 
     end_at = time.perf_counter()
     logging.info(
api/tasks/duplicate_document_indexing_task.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import datetime
import logging
import time

import click
from celery import shared_task
from flask import current_app

from core.indexing_runner import DocumentIsPausedException, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService


@shared_task(queue='dataset')
def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
    """
    Async process document
    :param dataset_id:
    :param document_ids:

    Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids)
    """
    documents = []
    start_at = time.perf_counter()

    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()

    # check document limit
    features = FeatureService.get_features(dataset.tenant_id)
    try:
        if features.billing.enabled:
            vector_space = features.vector_space
            count = len(document_ids)
            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
            if 0 < vector_space.limit <= vector_space.size:
                raise ValueError("Your total number of documents plus the number of uploads has exceeded the limit of "
                                 "your subscription.")
    except Exception as e:
        for document_id in document_ids:
            document = db.session.query(Document).filter(
                Document.id == document_id,
                Document.dataset_id == dataset_id
            ).first()
            if document:
                document.indexing_status = 'error'
                document.error = str(e)
                document.stopped_at = datetime.datetime.utcnow()
                db.session.add(document)
        db.session.commit()
        return

    for document_id in document_ids:
        logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))

        document = db.session.query(Document).filter(
            Document.id == document_id,
            Document.dataset_id == dataset_id
        ).first()

        if document:
            # clean old data
            index_type = document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()

            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
            if segments:
                index_node_ids = [segment.index_node_id for segment in segments]

                # delete from vector index
                index_processor.clean(dataset, index_node_ids)

                for segment in segments:
                    db.session.delete(segment)
                db.session.commit()

            document.indexing_status = 'parsing'
            document.processing_started_at = datetime.datetime.utcnow()
            documents.append(document)
            db.session.add(document)
    db.session.commit()

    try:
        indexing_runner = IndexingRunner()
        indexing_runner.run(documents)
        end_at = time.perf_counter()
        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
    except DocumentIsPausedException as ex:
        logging.info(click.style(str(ex), fg='yellow'))
    except Exception:
        pass
api/tasks/retry_document_indexing_task.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import datetime
import logging
import time

import click
from celery import shared_task

from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService


@shared_task(queue='dataset')
def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
    """
    Async process document
    :param dataset_id:
    :param document_ids:

    Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
    """
    documents = []
    start_at = time.perf_counter()

    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
    for document_id in document_ids:
        retry_indexing_cache_key = 'document_{}_is_retried'.format(document_id)
        # check document limit
        features = FeatureService.get_features(dataset.tenant_id)
        try:
            if features.billing.enabled:
                vector_space = features.vector_space
                if 0 < vector_space.limit <= vector_space.size:
                    raise ValueError("Your total number of documents plus the number of uploads has exceeded the limit of "
                                     "your subscription.")
        except Exception as e:
            document = db.session.query(Document).filter(
                Document.id == document_id,
                Document.dataset_id == dataset_id
            ).first()
            if document:
                document.indexing_status = 'error'
                document.error = str(e)
                document.stopped_at = datetime.datetime.utcnow()
                db.session.add(document)
                db.session.commit()
            redis_client.delete(retry_indexing_cache_key)
            return

        logging.info(click.style('Start retry document: {}'.format(document_id), fg='green'))
        document = db.session.query(Document).filter(
            Document.id == document_id,
            Document.dataset_id == dataset_id
        ).first()
        try:
            if document:
                # clean old data
                index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

                segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
                if segments:
                    index_node_ids = [segment.index_node_id for segment in segments]
                    # delete from vector index
                    index_processor.clean(dataset, index_node_ids)

                    for segment in segments:
                        db.session.delete(segment)
                    db.session.commit()

                document.indexing_status = 'parsing'
                document.processing_started_at = datetime.datetime.utcnow()
                db.session.add(document)
                db.session.commit()

                indexing_runner = IndexingRunner()
                indexing_runner.run([document])
                redis_client.delete(retry_indexing_cache_key)
        except Exception as ex:
            document.indexing_status = 'error'
            document.error = str(ex)
            document.stopped_at = datetime.datetime.utcnow()
            db.session.add(document)
            db.session.commit()
            logging.info(click.style(str(ex), fg='yellow'))
            redis_client.delete(retry_indexing_cache_key)
    end_at = time.perf_counter()
    logging.info(click.style('Retry dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
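Note: both new tasks are plain Celery tasks on the 'dataset' queue, so producers only ever hand over ids, never ORM objects. Dispatch looks the same for either; the ids below are placeholders:

from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.retry_document_indexing_task import retry_document_indexing_task

duplicate_document_indexing_task.delay('<dataset uuid>', ['<doc uuid 1>', '<doc uuid 2>'])
retry_document_indexing_task.delay('<dataset uuid>', ['<doc uuid 3>'])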