Knowledge optimization (#3755)

Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Jyong
2024-04-24 15:02:29 +08:00
committed by GitHub
parent 3cd8e6f5c6
commit f257f2c396
75 changed files with 2756 additions and 266 deletions

View File

@ -53,5 +53,8 @@ from .explore import (
workflow,
)
# Import tag controllers
from .tag import tags
# Import workspace controllers
from .workspace import account, members, model_providers, models, tool_providers, workspace

View File

@ -1,18 +1,25 @@
import json
import uuid
from flask_login import current_user
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from fields.app_fields import (
app_detail_fields,
app_detail_fields_with_site,
app_pagination_fields,
)
from libs.login import login_required
from models.model import App, AppMode, AppModelConfig
from services.app_service import AppService
from services.tag_service import TagService
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
@ -22,21 +29,29 @@ class AppListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""Get app list"""
def uuid_list(value):
try:
return [str(uuid.UUID(v)) for v in value.split(',')]
except ValueError:
abort(400, description="Invalid UUID format in tag_ids.")
parser = reqparse.RequestParser()
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
parser.add_argument('name', type=str, location='args', required=False)
parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
args = parser.parse_args()
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
if not app_pagination:
return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}
return marshal(app_pagination, app_pagination_fields)
@setup_required
@login_required

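For illustration, a minimal sketch of exercising the new tag filter from a client. Only the query parameters come from the parser above; the base URL, session cookie, and UUID values are placeholders:

import requests

BASE_URL = 'http://localhost:5001/console/api'  # hypothetical console host
params = {
    'page': 1,
    'limit': 20,
    'mode': 'all',
    # Comma-separated UUID string, split and validated server-side by uuid_list();
    # a malformed UUID aborts the request with a 400.
    'tag_ids': '3f2b6e1a-0c4d-4e8f-9a1b-5c6d7e8f9a0b,8d7c6b5a-4e3f-2a1b-9c8d-0e1f2a3b4c5d',
}
resp = requests.get(f'{BASE_URL}/apps', params=params, cookies={'session': '<session-cookie>'})
print(resp.json())  # marshaled with app_pagination_fields; an empty page if no app carries the tags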
View File

@ -48,11 +48,14 @@ class DatasetListApi(Resource):
limit = request.args.get('limit', default=20, type=int)
ids = request.args.getlist('ids')
provider = request.args.get('provider', default="vendor")
search = request.args.get('keyword', default=None, type=str)
tag_ids = request.args.getlist('tag_ids')
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(page, limit, provider,
current_user.current_tenant_id, current_user, search, tag_ids)
# check embedding setting
provider_manager = ProviderManager()
@ -184,6 +187,10 @@ class DatasetApi(Resource):
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members'), help='Invalid permission.')
parser.add_argument('embedding_model', type=str,
location='json', help='Invalid embedding model.')
parser.add_argument('embedding_model_provider', type=str,
location='json', help='Invalid embedding model provider.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
args = parser.parse_args()
@ -506,10 +513,27 @@ class DatasetRetrievalSettingMockApi(Resource):
else:
raise ValueError("Unsupported vector db type.")
class DatasetErrorDocs(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
return {
'data': [marshal(item, document_status_fields) for item in results],
'total': len(results)
}, 200
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
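A hedged sketch of calling the new error-docs endpoint. The path and response shape come from DatasetErrorDocs above; the host and auth are placeholders:

import requests

dataset_id = '1a2b3c4d-5e6f-4a7b-8c9d-0e1f2a3b4c5d'  # hypothetical dataset UUID
resp = requests.get(
    f'http://localhost:5001/console/api/datasets/{dataset_id}/error-docs',
    cookies={'session': '<session-cookie>'},
)
payload = resp.json()
print(payload['total'])       # count of documents stuck in 'error' or 'paused'
for doc in payload['data']:   # each entry is marshaled with document_status_fields
    print(doc)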

View File

@ -1,3 +1,4 @@
import logging
from datetime import datetime, timezone
from flask import request
@ -233,7 +234,7 @@ class DatasetDocumentListApi(Resource):
location='json')
parser.add_argument('data_source', type=dict, required=False, location='json')
parser.add_argument('process_rule', type=dict, required=False, location='json')
parser.add_argument('duplicate', type=bool, default=True, nullable=False, location='json')
parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
@ -883,6 +884,49 @@ class DocumentRecoverApi(DocumentResource):
return {'result': 'success'}, 204
class DocumentRetryApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
"""retry document."""
parser = reqparse.RequestParser()
parser.add_argument('document_ids', type=list, required=True, nullable=False,
location='json')
args = parser.parse_args()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
retry_documents = []
if not dataset:
raise NotFound('Dataset not found.')
for document_id in args['document_ids']:
try:
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
raise NotFound("Document Not Exists.")
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
# 400 if document is completed
if document.indexing_status == 'completed':
raise DocumentAlreadyFinishedError()
retry_documents.append(document)
except Exception as e:
logging.error(f"Document {document_id} retry failed: {str(e)}")
continue
# retry document
DocumentService.retry_document(dataset_id, retry_documents)
return {'result': 'success'}, 204
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
'/datasets/<uuid:dataset_id>/documents')
@ -908,3 +952,4 @@ api.add_resource(DocumentStatusApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
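A matching sketch for the retry endpoint registered above; the JSON body and status code come from DocumentRetryApi, everything else is assumed:

import requests

dataset_id = '1a2b3c4d-5e6f-4a7b-8c9d-0e1f2a3b4c5d'      # hypothetical dataset UUID
document_ids = ['9f8e7d6c-5b4a-4f3e-8d2c-1b0a9f8e7d6c']  # hypothetical document UUIDs
resp = requests.post(
    f'http://localhost:5001/console/api/datasets/{dataset_id}/retry',
    json={'document_ids': document_ids},
    cookies={'session': '<session-cookie>'},
)
assert resp.status_code == 204  # archived or already-completed documents are skipped with a log entry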

View File

@ -0,0 +1,159 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.tag_fields import tag_fields
from libs.login import login_required
from models.model import Tag
from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 50:
raise ValueError('Name must be between 1 and 50 characters.')
return name
class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(tag_fields)
def get(self):
tag_type = request.args.get('type', type=str)
keyword = request.args.get('keyword', default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
return tags, 200
@setup_required
@login_required
@account_initialization_required
def post(self):
# The current user's role in the tenant must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False, required=True,
help='Name must be between 1 and 50 characters.',
type=_validate_name)
parser.add_argument('type', type=str, location='json',
choices=Tag.TAG_TYPE_LIST,
nullable=True,
help='Invalid tag type.')
args = parser.parse_args()
tag = TagService.save_tags(args)
response = {
'id': tag.id,
'name': tag.name,
'type': tag.type,
'binding_count': 0
}
return response, 200
class TagUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, tag_id):
tag_id = str(tag_id)
# The current user's role in the tenant must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False, required=True,
help='Name must be between 1 and 50 characters.',
type=_validate_name)
args = parser.parse_args()
tag = TagService.update_tags(args, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
response = {
'id': tag.id,
'name': tag.name,
'type': tag.type,
'binding_count': binding_count
}
return response, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, tag_id):
tag_id = str(tag_id)
# The current user's role in the tenant must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
TagService.delete_tag(tag_id)
return {'result': 'success'}, 200
class TagBindingCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The current user's role in the tenant must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json',
help='Tag IDs are required.')
parser.add_argument('target_id', type=str, nullable=False, required=True, location='json',
help='Target ID is required.')
parser.add_argument('type', type=str, location='json',
choices=Tag.TAG_TYPE_LIST,
nullable=True,
help='Invalid tag type.')
args = parser.parse_args()
TagService.save_tag_binding(args)
return {'result': 'success'}, 200
class TagBindingDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The current user's role in the tenant must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('tag_id', type=str, nullable=False, required=True,
help='Tag ID is required.')
parser.add_argument('target_id', type=str, nullable=False, required=True,
help='Target ID is required.')
parser.add_argument('type', type=str, location='json',
choices=Tag.TAG_TYPE_LIST,
nullable=True,
help='Invalid tag type.')
args = parser.parse_args()
TagService.delete_tag_binding(args)
return {'result': 'success'}, 200
api.add_resource(TagListApi, '/tags')
api.add_resource(TagUpdateDeleteApi, '/tags/<uuid:tag_id>')
api.add_resource(TagBindingCreateApi, '/tag-bindings/create')
api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove')
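Taken together, a minimal sketch of the tag lifecycle against the routes registered above (create, bind, unbind). The host, session cookie, and target UUID are placeholders:

import requests

BASE_URL = 'http://localhost:5001/console/api'
session = {'session': '<session-cookie>'}
app_id = '1a2b3c4d-5e6f-4a7b-8c9d-0e1f2a3b4c5d'  # hypothetical app UUID

# Create an 'app' tag; requires an admin or owner role.
tag = requests.post(f'{BASE_URL}/tags',
                    json={'name': 'internal', 'type': 'app'},
                    cookies=session).json()

# Bind the tag to the app, then remove the binding again.
requests.post(f'{BASE_URL}/tag-bindings/create',
              json={'tag_ids': [tag['id']], 'target_id': app_id, 'type': 'app'},
              cookies=session)
requests.post(f'{BASE_URL}/tag-bindings/remove',
              json={'tag_id': tag['id'], 'target_id': app_id, 'type': 'app'},
              cookies=session)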

View File

@ -26,8 +26,11 @@ class DatasetApi(DatasetApiResource):
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
provider = request.args.get('provider', default="vendor")
search = request.args.get('keyword', default=None, type=str)
tag_ids = request.args.getlist('tag_ids')
datasets, total = DatasetService.get_datasets(page, limit, provider,
tenant_id, current_user, search, tag_ids)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(

View File

@ -110,19 +110,37 @@ class MilvusVector(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
from pymilvus import utility
if utility.has_collection(self._collection_name, using=alias):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, doc_ids: list[str]) -> None:
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
from pymilvus import utility
if utility.has_collection(self._collection_name, using=alias):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {doc_ids}',
output_fields=["id"])
if result:
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None:
alias = uuid4().hex

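The Milvus changes above guard every delete with a has_collection check on a throwaway connection alias, so deletes against an already-dropped collection become no-ops. A standalone sketch of that guard, with a placeholder host and collection name:

from uuid import uuid4

from pymilvus import connections, utility

alias = uuid4().hex  # one-off alias so the check does not interfere with other connections
connections.connect(alias=alias, uri='http://localhost:19530', user='', password='')
try:
    if utility.has_collection('example_collection', using=alias):
        # safe to issue query/delete calls against the collection here
        pass
finally:
    connections.disconnect(alias)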
View File

@ -217,29 +217,38 @@ class QdrantVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
try:
filter = models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
)
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete(self):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
try:
filter = models.Filter(
must=[
@ -257,29 +266,40 @@ class QdrantVector(BaseVector):
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete_by_ids(self, ids: list[str]) -> None:
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
for node_id in ids:
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []

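The Qdrant changes apply the same idea via exception handling: a 404 UnexpectedResponse means the collection is already gone and is swallowed, while anything else re-raises. A compact sketch with a placeholder client and collection:

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse

client = QdrantClient(url='http://localhost:6333')
try:
    client.delete(
        collection_name='example_collection',
        points_selector=models.FilterSelector(
            filter=models.Filter(must=[
                models.FieldCondition(key='metadata.doc_id',
                                      match=models.MatchValue(value='doc-1')),
            ]),
        ),
    )
except UnexpectedResponse as e:
    if e.status_code != 404:  # 404: collection no longer exists, nothing to delete
        raise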
View File

@ -121,18 +121,20 @@ class WeaviateVector(BaseVector):
return ids
def delete_by_metadata_field(self, key: str, value: str):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
where_filter = {
"operator": "Equal",
"path": [key],
"valueText": value
}
self._client.batch.delete_objects(
class_name=self._collection_name,
where=where_filter,
output='minimal'
)
def delete(self):
# check whether the index already exists
@ -163,11 +165,14 @@ class WeaviateVector(BaseVector):
return True
def delete_by_ids(self, ids: list[str]) -> None:
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
for uuid in ids:
self._client.data_object.delete(
class_name=self._collection_name,
uuid=uuid,
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""

View File

@ -62,6 +62,12 @@ model_config_partial_fields = {
'pre_prompt': fields.String,
}
tag_fields = {
'id': fields.String,
'name': fields.String,
'type': fields.String
}
app_partial_fields = {
'id': fields.String,
'name': fields.String,
@ -70,9 +76,11 @@ app_partial_fields = {
'icon': fields.String,
'icon_background': fields.String,
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True),
'created_at': TimestampField,
'tags': fields.List(fields.Nested(tag_fields))
}
app_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),

View File

@ -27,6 +27,11 @@ dataset_retrieval_model_fields = {
'score_threshold': fields.Float
}
tag_fields = {
'id': fields.String,
'name': fields.String,
'type': fields.String
}
dataset_detail_fields = {
'id': fields.String,
@ -46,7 +51,8 @@ dataset_detail_fields = {
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean,
'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields),
'tags': fields.List(fields.Nested(tag_fields))
}
dataset_query_detail_fields = {

api/fields/tag_fields.py Normal file
View File

@ -0,0 +1,8 @@
from flask_restful import fields
tag_fields = {
'id': fields.String,
'name': fields.String,
'type': fields.String,
'binding_count': fields.String
}

View File

@ -0,0 +1,62 @@
"""add-tags-and-binding-table
Revision ID: 3c7cac9521c6
Revises: c3311b089690
Create Date: 2024-04-11 06:17:34.278594
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '3c7cac9521c6'
down_revision = 'c3311b089690'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tag_bindings',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', postgresql.UUID(), nullable=True),
sa.Column('tag_id', postgresql.UUID(), nullable=True),
sa.Column('target_id', postgresql.UUID(), nullable=True),
sa.Column('created_by', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tag_binding_pkey')
)
with op.batch_alter_table('tag_bindings', schema=None) as batch_op:
batch_op.create_index('tag_bind_tag_id_idx', ['tag_id'], unique=False)
batch_op.create_index('tag_bind_target_id_idx', ['target_id'], unique=False)
op.create_table('tags',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', postgresql.UUID(), nullable=True),
sa.Column('type', sa.String(length=16), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('created_by', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tag_pkey')
)
with op.batch_alter_table('tags', schema=None) as batch_op:
batch_op.create_index('tag_name_idx', ['name'], unique=False)
batch_op.create_index('tag_type_idx', ['type'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tags', schema=None) as batch_op:
batch_op.drop_index('tag_type_idx')
batch_op.drop_index('tag_name_idx')
op.drop_table('tags')
with op.batch_alter_table('tag_bindings', schema=None) as batch_op:
batch_op.drop_index('tag_bind_target_id_idx')
batch_op.drop_index('tag_bind_tag_id_idx')
op.drop_table('tag_bindings')
# ### end Alembic commands ###

View File

@ -9,7 +9,7 @@ from sqlalchemy.dialects.postgresql import JSONB, UUID
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Account
from models.model import App, Tag, TagBinding, UploadFile
class Dataset(db.Model):
@ -118,6 +118,20 @@ class Dataset(db.Model):
}
return self.retrieval_model if self.retrieval_model else default_retrieval_model
@property
def tags(self):
tags = db.session.query(Tag).join(
TagBinding,
Tag.id == TagBinding.tag_id
).filter(
TagBinding.target_id == self.id,
TagBinding.tenant_id == self.tenant_id,
Tag.tenant_id == self.tenant_id,
Tag.type == 'knowledge'
).all()
return tags if tags else []
@staticmethod
def gen_collection_name_by_id(dataset_id: str) -> str:
normalized_dataset_id = dataset_id.replace("-", "_")

View File

@ -148,7 +148,7 @@ class App(db.Model):
return []
agent_mode = app_model_config.agent_mode_dict
tools = agent_mode.get('tools', [])
provider_ids = []
for tool in tools:
@ -185,6 +185,20 @@ class App(db.Model):
return deleted_tools
@property
def tags(self):
tags = db.session.query(Tag).join(
TagBinding,
Tag.id == TagBinding.tag_id
).filter(
TagBinding.target_id == self.id,
TagBinding.tenant_id == self.tenant_id,
Tag.tenant_id == self.tenant_id,
Tag.type == 'app'
).all()
return tags if tags else []
class AppModelConfig(db.Model):
__tablename__ = 'app_model_configs'
@ -292,7 +306,8 @@ class AppModelConfig(db.Model):
@property
def agent_mode_dict(self) -> dict:
return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": [],
"prompt": None}
@property
def chat_prompt_config_dict(self) -> dict:
@ -463,6 +478,7 @@ class InstalledApp(db.Model):
return tenant
class Conversation(db.Model):
__tablename__ = 'conversations'
__table_args__ = (
@ -1175,11 +1191,11 @@ class MessageAgentThought(db.Model):
return json.loads(self.message_files)
else:
return []
@property
def tools(self) -> list[str]:
return self.tool.split(";") if self.tool else []
@property
def tool_labels(self) -> dict:
try:
@ -1189,7 +1205,7 @@ class MessageAgentThought(db.Model):
return {}
except Exception as e:
return {}
@property
def tool_meta(self) -> dict:
try:
@ -1199,7 +1215,7 @@ class MessageAgentThought(db.Model):
return {}
except Exception as e:
return {}
@property
def tool_inputs_dict(self) -> dict:
tools = self.tools
@ -1222,7 +1238,7 @@ class MessageAgentThought(db.Model):
}
except Exception as e:
return {}
@property
def tool_outputs_dict(self) -> dict:
tools = self.tools
@ -1249,6 +1265,7 @@ class MessageAgentThought(db.Model):
tool: self.observation for tool in tools
}
class DatasetRetrieverResource(db.Model):
__tablename__ = 'dataset_retriever_resources'
__table_args__ = (
@ -1274,3 +1291,37 @@ class DatasetRetrieverResource(db.Model):
retriever_from = db.Column(db.Text, nullable=False)
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class Tag(db.Model):
__tablename__ = 'tags'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='tag_pkey'),
db.Index('tag_type_idx', 'type'),
db.Index('tag_name_idx', 'name'),
)
TAG_TYPE_LIST = ['knowledge', 'app']
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=True)
type = db.Column(db.String(16), nullable=False)
name = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
class TagBinding(db.Model):
__tablename__ = 'tag_bindings'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='tag_binding_pkey'),
db.Index('tag_bind_target_id_idx', 'target_id'),
db.Index('tag_bind_tag_id_idx', 'tag_id'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=True)
tag_id = db.Column(UUID, nullable=True)
target_id = db.Column(UUID, nullable=True)
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
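Both App.tags and Dataset.tags added in this commit resolve through the same two-table join. A hedged standalone equivalent (assumes a Flask app context and the models above):

from extensions.ext_database import db
from models.model import Tag, TagBinding


def tags_for_target(target_id: str, tenant_id: str, tag_type: str) -> list:
    # Mirrors the `tags` properties: bindings joined to tags, scoped to one tenant and tag type.
    return db.session.query(Tag).join(
        TagBinding, Tag.id == TagBinding.tag_id
    ).filter(
        TagBinding.target_id == target_id,
        TagBinding.tenant_id == tenant_id,
        Tag.tenant_id == tenant_id,
        Tag.type == tag_type,
    ).all()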

View File

@ -21,11 +21,12 @@ from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode, AppModelConfig
from models.tools import ApiToolProvider
from services.tag_service import TagService
from services.workflow_service import WorkflowService
class AppService:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
"""
Get app list with pagination
:param tenant_id: tenant id
@ -49,6 +50,14 @@ class AppService:
if 'name' in args and args['name']:
name = args['name'][:30]
filters.append(App.name.ilike(f'%{name}%'))
if 'tag_ids' in args and args['tag_ids']:
target_ids = TagService.get_target_ids_by_tag_ids('app',
tenant_id,
args['tag_ids'])
if target_ids:
filters.append(App.id.in_(target_ids))
else:
return None
app_models = db.paginate(
db.select(App).where(*filters).order_by(App.created_at.desc()),

View File

@ -38,28 +38,39 @@ from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.feature_service import FeatureModel, FeatureService
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
class DatasetService:
@staticmethod
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None):
if user:
permission_filter = db.or_(Dataset.created_by == user.id,
Dataset.permission == 'all_team_members')
else:
permission_filter = Dataset.permission == 'all_team_members'
query = Dataset.query.filter(
db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
.order_by(Dataset.created_at.desc())
if search:
query = query.filter(db.and_(Dataset.name.ilike(f'%{search}%')))
if tag_ids:
target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids)
if target_ids:
query = query.filter(db.and_(Dataset.id.in_(target_ids)))
else:
return [], 0
datasets = query.paginate(
page=page,
per_page=per_page,
max_per_page=100,
@ -165,9 +176,36 @@ class DatasetService:
# get embedding model setting
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data['embedding_model_provider'],
model_type=ModelType.TEXT_EMBEDDING,
model=data['embedding_model']
)
filtered_data['embedding_model'] = embedding_model.model
filtered_data['embedding_model_provider'] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider,
embedding_model.model
)
filtered_data['collection_binding_id'] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
if data['embedding_model_provider'] != dataset.embedding_model_provider or \
data['embedding_model'] != dataset.embedding_model:
action = 'update'
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data['embedding_model_provider'],
model_type=ModelType.TEXT_EMBEDDING,
model=data['embedding_model']
)
filtered_data['embedding_model'] = embedding_model.model
filtered_data['embedding_model_provider'] = embedding_model.provider
@ -376,6 +414,15 @@ class DocumentService:
return documents
@staticmethod
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = db.session.query(Document).filter(
Document.dataset_id == dataset_id,
Document.indexing_status.in_(['error', 'paused'])
).all()
return documents
@staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
documents = db.session.query(Document).filter(
@ -440,6 +487,20 @@ class DocumentService:
# trigger async task
recover_document_indexing_task.delay(document.dataset_id, document.id)
@staticmethod
def retry_document(dataset_id: str, documents: list[Document]):
for document in documents:
# retry document indexing
document.indexing_status = 'waiting'
db.session.add(document)
db.session.commit()
# add retry flag
retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
redis_client.setex(retry_indexing_cache_key, 600, 1)
# trigger async task
document_ids = [document.id for document in documents]
retry_document_indexing_task.delay(dataset_id, document_ids)
@staticmethod
def get_documents_position(dataset_id):
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
@ -537,6 +598,7 @@ class DocumentService:
db.session.commit()
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if document_data["data_source"]["type"] == "upload_file":
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
for file_id in upload_file_list:
@ -553,6 +615,28 @@ class DocumentService:
data_source_info = {
"upload_file_id": file_id,
}
# check duplicate
if document_data.get('duplicate', False):
document = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type='upload_file',
enabled=True,
name=file_name
).first()
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = datetime.datetime.utcnow()
document.created_from = created_from
document.doc_form = document_data['doc_form']
document.doc_language = document_data['doc_language']
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = 'waiting'
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
@ -618,7 +702,10 @@ class DocumentService:
db.session.commit()
# trigger async task
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
return documents, batch
@ -626,7 +713,8 @@ class DocumentService:
def check_documents_upload_quota(count: int, features: FeatureModel):
can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
if count > can_upload_size:
raise ValueError(
f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')
@staticmethod
def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
@ -752,7 +840,6 @@ class DocumentService:
db.session.commit()
# trigger async task
document_indexing_update_task.delay(document.dataset_id, document.id)
return document
@staticmethod

api/services/tag_service.py Normal file
View File

@ -0,0 +1,161 @@
import uuid
from flask_login import current_user
from sqlalchemy import func
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, Tag, TagBinding
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
query = db.session.query(
Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count')
).outerjoin(
TagBinding, Tag.id == TagBinding.tag_id
).filter(
Tag.type == tag_type,
Tag.tenant_id == current_tenant_id
)
if keyword:
query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%')))
query = query.group_by(
Tag.id
)
results = query.order_by(Tag.created_at.desc()).all()
return results
@staticmethod
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
tags = db.session.query(Tag).filter(
Tag.id.in_(tag_ids),
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type
).all()
if not tags:
return []
tag_ids = [tag.id for tag in tags]
tag_bindings = db.session.query(
TagBinding.target_id
).filter(
TagBinding.tag_id.in_(tag_ids),
TagBinding.tenant_id == current_tenant_id
).all()
if not tag_bindings:
return []
results = [tag_binding.target_id for tag_binding in tag_bindings]
return results
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
tags = db.session.query(Tag).join(
TagBinding,
Tag.id == TagBinding.tag_id
).filter(
TagBinding.target_id == target_id,
TagBinding.tenant_id == current_tenant_id,
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type
).all()
return tags if tags else []
@staticmethod
def save_tags(args: dict) -> Tag:
tag = Tag(
id=str(uuid.uuid4()),
name=args['name'],
type=args['type'],
created_by=current_user.id,
tenant_id=current_user.current_tenant_id
)
db.session.add(tag)
db.session.commit()
return tag
@staticmethod
def update_tags(args: dict, tag_id: str) -> Tag:
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
tag.name = args['name']
db.session.commit()
return tag
@staticmethod
def get_tag_binding_count(tag_id: str) -> int:
count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
return count
@staticmethod
def delete_tag(tag_id: str):
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
db.session.delete(tag)
# delete tag binding
tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)
db.session.commit()
@staticmethod
def save_tag_binding(args):
# check if target exists
TagService.check_target_exists(args['type'], args['target_id'])
# save tag binding
for tag_id in args['tag_ids']:
tag_binding = db.session.query(TagBinding).filter(
TagBinding.tag_id == tag_id,
TagBinding.target_id == args['target_id']
).first()
if tag_binding:
continue
new_tag_binding = TagBinding(
tag_id=tag_id,
target_id=args['target_id'],
tenant_id=current_user.current_tenant_id,
created_by=current_user.id
)
db.session.add(new_tag_binding)
db.session.commit()
@staticmethod
def delete_tag_binding(args):
# check if target exists
TagService.check_target_exists(args['type'], args['target_id'])
# delete tag binding
tag_bindings = db.session.query(TagBinding).filter(
TagBinding.target_id == args['target_id'],
TagBinding.tag_id == args['tag_id']
).first()
if tag_bindings:
db.session.delete(tag_bindings)
db.session.commit()
@staticmethod
def check_target_exists(type: str, target_id: str):
if type == 'knowledge':
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == current_user.current_tenant_id,
Dataset.id == target_id
).first()
if not dataset:
raise NotFound("Dataset not found")
elif type == 'app':
app = db.session.query(App).filter(
App.tenant_id == current_user.current_tenant_id,
App.id == target_id
).first()
if not app:
raise NotFound("App not found")
else:
raise NotFound("Invalid binding type")

View File

@ -16,6 +16,7 @@ from models.dataset import (
)
@shared_task(queue='dataset')
def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
index_struct: str, collection_binding_id: str, doc_form: str):
@ -48,6 +49,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
logging.info(click.style('No documents found for dataset: {}'.format(dataset_id), fg='green'))
else:
logging.info(click.style('Cleaning documents for dataset: {}'.format(dataset_id), fg='green'))
# Specify the index type before initializing the index processor
if doc_form is None:
raise ValueError("Index type must be specified.")
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None)

View File

@ -64,6 +64,39 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
elif action == 'update':
# clean index
index_processor.clean(dataset, None, with_keywords=False)
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
# add new index
if dataset_documents:
documents = []
for dataset_document in dataset_documents:
# rebuild index documents from the enabled segments
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True
).order_by(DocumentSegment.position.asc()).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
end_at = time.perf_counter()
logging.info(

View File

@ -0,0 +1,94 @@
import datetime
import logging
import time
import click
from celery import shared_task
from flask import current_app
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@shared_task(queue='dataset')
def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
"""
Async process document
:param dataset_id:
:param document_ids:
Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if 0 < vector_space.limit <= vector_space.size:
raise ValueError("Your total number of documents plus the number of uploads have over the limit of "
"your subscription.")
except Exception as e:
for document_id in document_ids:
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
if document:
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
return
for document_id in document_ids:
logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
if document:
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = 'parsing'
document.processing_started_at = datetime.datetime.utcnow()
documents.append(document)
db.session.add(document)
db.session.commit()
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
except DocumentIsPausedException as ex:
logging.info(click.style(str(ex), fg='yellow'))
except Exception:
pass

View File

@ -0,0 +1,91 @@
import datetime
import logging
import time
import click
from celery import shared_task
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@shared_task(queue='dataset')
def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
"""
Async process document
:param dataset_id:
:param document_ids:
Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
for document_id in document_ids:
retry_indexing_cache_key = 'document_{}_is_retried'.format(document_id)
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError("Your total number of documents plus the number of uploads have over the limit of "
"your subscription.")
except Exception as e:
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
if document:
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
redis_client.delete(retry_indexing_cache_key)
return
logging.info(click.style('Start retry document: {}'.format(document_id), fg='green'))
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
try:
if document:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = 'parsing'
document.processing_started_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = 'error'
document.error = str(ex)
document.stopped_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg='yellow'))
redis_client.delete(retry_indexing_cache_key)
end_at = time.perf_counter()
logging.info(click.style('Retry dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))