Mirror of https://github.com/langgenius/dify.git, synced 2026-05-03 17:08:03 +08:00
Merge branch 'main' into feat/r2
@@ -10,6 +10,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.embedding.retrieval import RetrievalSegments
+from core.rag.entities.metadata_entities import MetadataCondition
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.models.document import Document
 from core.rag.rerank.rerank_type import RerankMode
@@ -119,12 +120,25 @@ class RetrievalService:
         return all_documents

     @classmethod
-    def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
+    def external_retrieve(
+        cls,
+        dataset_id: str,
+        query: str,
+        external_retrieval_model: Optional[dict] = None,
+        metadata_filtering_conditions: Optional[dict] = None,
+    ):
         dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
         if not dataset:
             return []
+        metadata_condition = (
+            MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
+        )
         all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
-            dataset.tenant_id, dataset_id, query, external_retrieval_model or {}
+            dataset.tenant_id,
+            dataset_id,
+            query,
+            external_retrieval_model or {},
+            metadata_condition=metadata_condition,
         )
         return all_documents
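For orientation, a minimal caller-side sketch of the extended signature (not part of this commit; the import path, dataset id, and filter payload are illustrative, and the filter keys follow the MetadataCondition/Condition fields used elsewhere in this diff):

from core.rag.datasource.retrieval_service import RetrievalService  # assumed module path

# Hypothetical values, for illustration only.
docs = RetrievalService.external_retrieve(
    dataset_id="<dataset-uuid>",
    query="quarterly revenue",
    external_retrieval_model={"top_k": 5, "score_threshold": 0.5},
    metadata_filtering_conditions={
        "logical_operator": "and",
        "conditions": [
            {"name": "year", "comparison_operator": "=", "value": "2024"},
        ],
    },
)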
@@ -156,8 +156,8 @@ class AnalyticdbVectorBySql:
         values = []
         id_prefix = str(uuid.uuid4()) + "_"
         sql = f"""
-        INSERT INTO {self.table_name}
-        (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
+        INSERT INTO {self.table_name}
+        (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
         VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
         """
         for i, doc in enumerate(documents):
@@ -242,7 +242,7 @@ class AnalyticdbVectorBySql:
             where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
         with self._get_cursor() as cur:
             cur.execute(
-                f"""SELECT id, vector, page_content, metadata_,
+                f"""SELECT id, vector, page_content, metadata_,
                 ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                 FROM {self.table_name}
                 WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
@@ -27,8 +27,8 @@ class MilvusConfig(BaseModel):

     uri: str  # Milvus server URI
     token: Optional[str] = None  # Optional token for authentication
-    user: str  # Username for authentication
-    password: str  # Password for authentication
+    user: Optional[str] = None  # Username for authentication
+    password: Optional[str] = None  # Password for authentication
     batch_size: int = 100  # Batch size for operations
     database: str = "default"  # Database name
     enable_hybrid_search: bool = False  # Flag to enable hybrid search
@@ -43,10 +43,11 @@ class MilvusConfig(BaseModel):
         """
         if not values.get("uri"):
             raise ValueError("config MILVUS_URI is required")
-        if not values.get("user"):
-            raise ValueError("config MILVUS_USER is required")
-        if not values.get("password"):
-            raise ValueError("config MILVUS_PASSWORD is required")
+        if not values.get("token"):
+            if not values.get("user"):
+                raise ValueError("config MILVUS_USER is required")
+            if not values.get("password"):
+                raise ValueError("config MILVUS_PASSWORD is required")
         return values

     def to_milvus_params(self):
@@ -356,11 +357,14 @@ class MilvusVector(BaseVector):
             )
             redis_client.set(collection_exist_cache_key, 1, ex=3600)

-    def _init_client(self, config) -> MilvusClient:
+    def _init_client(self, config: MilvusConfig) -> MilvusClient:
         """
         Initialize and return a Milvus client.
         """
-        client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
+        if config.token:
+            client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
+        else:
+            client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
         return client
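A standalone sketch of the two authentication paths the updated config supports, mirroring the branch in _init_client above (illustrative only; not taken from the diff):

from typing import Optional

from pymilvus import MilvusClient

def build_milvus_client(
    uri: str,
    database: str = "default",
    token: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
) -> MilvusClient:
    # A token (for example an API key or a "user:password" string) takes priority;
    # otherwise fall back to explicit user/password credentials.
    if token:
        return MilvusClient(uri=uri, token=token, db_name=database)
    return MilvusClient(uri=uri, user=user, password=password, db_name=database)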
@@ -203,7 +203,7 @@ class OceanBaseVector(BaseVector):

         full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score
             FROM {self._collection_name}
-            WHERE MATCH (text) AGAINST (:query) > 0
+            WHERE MATCH (text) AGAINST (:query) > 0
             {where_clause}
             ORDER BY score DESC
             LIMIT {top_k}"""
@@ -59,12 +59,12 @@ CREATE TABLE IF NOT EXISTS {table_name} (
 """

 SQL_CREATE_INDEX_PQ = """
-CREATE INDEX IF NOT EXISTS embedding_{table_name}_pq_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS embedding_{table_name}_pq_idx ON {table_name}
 USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64, enable_pq=on, pq_m={pq_m});
 """

 SQL_CREATE_INDEX = """
-CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
 USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
 """
@@ -1,10 +1,9 @@
 import json
 import logging
 import ssl
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 from uuid import uuid4

-from opensearchpy import OpenSearch, helpers
+from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
 from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator
@@ -24,9 +23,12 @@ logger = logging.getLogger(__name__)
 class OpenSearchConfig(BaseModel):
     host: str
     port: int
-    secure: bool = False
+    auth_method: Literal["basic", "aws_managed_iam"] = "basic"
     user: Optional[str] = None
     password: Optional[str] = None
+    secure: bool = False
+    aws_region: Optional[str] = None
+    aws_service: Optional[str] = None

     @model_validator(mode="before")
     @classmethod
@@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
             raise ValueError("config OPENSEARCH_HOST is required")
         if not values.get("port"):
             raise ValueError("config OPENSEARCH_PORT is required")
+        if values.get("auth_method") == "aws_managed_iam":
+            if not values.get("aws_region"):
+                raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
+            if not values.get("aws_service"):
+                raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
         return values

     def create_ssl_context(self) -> ssl.SSLContext:
         ssl_context = ssl.create_default_context()
         ssl_context.check_hostname = False
         ssl_context.verify_mode = ssl.CERT_NONE  # Disable Certificate Validation
         return ssl_context

+    def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
+        import boto3  # type: ignore
+
+        return Urllib3AWSV4SignerAuth(
+            credentials=boto3.Session().get_credentials(),
+            region=self.aws_region,
+            service=self.aws_service,  # type: ignore[arg-type]
+        )
+
     def to_opensearch_params(self) -> dict[str, Any]:
         params = {
             "hosts": [{"host": self.host, "port": self.port}],
             "use_ssl": self.secure,
             "verify_certs": self.secure,
+            "connection_class": Urllib3HttpConnection,
+            "pool_maxsize": 20,
         }
-        if self.user and self.password:
+
+        if self.auth_method == "basic":
+            logger.info("Using basic authentication for OpenSearch Vector DB")
+
             params["http_auth"] = (self.user, self.password)
             if self.secure:
                 params["ssl_context"] = self.create_ssl_context()
+
+        elif self.auth_method == "aws_managed_iam":
+            logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
+
+            params["http_auth"] = self.create_aws_managed_iam_auth()
+
         return params
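As a rough usage sketch (not from the diff), the returned params feed directly into the OpenSearch client constructor; boto3 credentials are only resolved when the aws_managed_iam method is selected. The host, region, and service values below are placeholders:

from opensearchpy import OpenSearch

config = OpenSearchConfig(  # as defined above
    host="search.example.com",
    port=443,
    secure=True,
    auth_method="aws_managed_iam",
    aws_region="us-east-1",
    aws_service="es",  # "aoss" would target OpenSearch Serverless
)
client = OpenSearch(**config.to_opensearch_params())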
@@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
             action = {
                 "_op_type": "index",
                 "_index": self._collection_name.lower(),
-                "_id": uuid4().hex,
                 "_source": {
                     Field.CONTENT_KEY.value: documents[i].page_content,
                     Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
                     Field.METADATA_KEY.value: documents[i].metadata,
                 },
             }
+            # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
+            if self._client_config.aws_service not in ["aoss"]:
+                action["_id"] = uuid4().hex
             actions.append(action)

-        helpers.bulk(self._client, actions)
+        helpers.bulk(
+            client=self._client,
+            actions=actions,
+            timeout=30,
+            max_retries=3,
+        )

     def get_ids_by_metadata_field(self, key: str, value: str):
         query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
@@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
                 },
             }

+            logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
             self._client.indices.create(index=self._collection_name.lower(), body=index_body)

             redis_client.set(collection_exist_cache_key, 1, ex=3600)
@@ -252,9 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
         open_search_config = OpenSearchConfig(
             host=dify_config.OPENSEARCH_HOST or "localhost",
             port=dify_config.OPENSEARCH_PORT,
-            secure=dify_config.OPENSEARCH_SECURE,
+            auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
             user=dify_config.OPENSEARCH_USER,
             password=dify_config.OPENSEARCH_PASSWORD,
+            secure=dify_config.OPENSEARCH_SECURE,
+            aws_region=dify_config.OPENSEARCH_AWS_REGION,
+            aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
         )

         return OpenSearchVector(collection_name=collection_name, config=open_search_config)
@@ -59,8 +59,8 @@ CREATE TABLE IF NOT EXISTS {table_name} (
 )
 """
 SQL_CREATE_INDEX = """
-CREATE INDEX IF NOT EXISTS idx_docs_{table_name} ON {table_name}(text)
-INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS
+CREATE INDEX IF NOT EXISTS idx_docs_{table_name} ON {table_name}(text)
+INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS
 ('FILTER CTXSYS.NULL_FILTER SECTION GROUP CTXSYS.HTML_SECTION_GROUP LEXER world_lexer')
 """
@@ -164,7 +164,7 @@ class OracleVector(BaseVector):
             with conn.cursor() as cur:
                 try:
                     cur.execute(
-                        f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
+                        f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
                         VALUES (:1, :2, :3, :4)""",
                         value,
                     )
@@ -227,8 +227,8 @@ class OracleVector(BaseVector):
             conn.outputtypehandler = self.output_type_handler
             with conn.cursor() as cur:
                 cur.execute(
-                    f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
-                    AS distance FROM {self.table_name}
+                    f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
+                    AS distance FROM {self.table_name}
                    {where_clause} ORDER BY distance fetch first {top_k} rows only""",
                     [numpy.array(query_vector)],
                 )
@@ -290,7 +290,7 @@ class OracleVector(BaseVector):
                 document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
                 where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
             cur.execute(
-                f"""select meta, text, embedding FROM {self.table_name}
+                f"""select meta, text, embedding FROM {self.table_name}
                 WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
                 order by score(1) desc fetch first {top_k} rows only""",
                 kk=" ACCUM ".join(entities),
@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import uuid
@@ -61,12 +62,12 @@ CREATE TABLE IF NOT EXISTS {table_name} (
 """

 SQL_CREATE_INDEX = """
-CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx_{index_hash} ON {table_name}
 USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
 """

 SQL_CREATE_INDEX_PG_BIGM = """
-CREATE INDEX IF NOT EXISTS bigm_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS bigm_idx_{index_hash} ON {table_name}
 USING gin (text gin_bigm_ops);
 """
@@ -76,6 +77,7 @@ class PGVector(BaseVector):
         super().__init__(collection_name)
         self.pool = self._create_connection_pool(config)
         self.table_name = f"embedding_{collection_name}"
+        self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8]
         self.pg_bigm = config.pg_bigm

     def get_type(self) -> str:
@@ -256,10 +258,9 @@ class PGVector(BaseVector):
                     # PG hnsw index only support 2000 dimension or less
                     # ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
                     if dimension <= 2000:
-                        cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
+                        cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name, index_hash=self.index_hash))
                     if self.pg_bigm:
                         cur.execute("CREATE EXTENSION IF NOT EXISTS pg_bigm")
-                        cur.execute(SQL_CREATE_INDEX_PG_BIGM.format(table_name=self.table_name))
+                        cur.execute(SQL_CREATE_INDEX_PG_BIGM.format(table_name=self.table_name, index_hash=self.index_hash))
                 redis_client.set(collection_exist_cache_key, 1, ex=3600)
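The index renaming above matters because index names share one namespace per PostgreSQL schema: a fixed name such as embedding_cosine_v1_idx combined with IF NOT EXISTS silently skips index creation for every embedding table after the first. A small sketch of the per-table naming scheme (illustrative only; the helper name is made up):

import hashlib

def pgvector_index_names(collection_name: str) -> tuple[str, str]:
    table_name = f"embedding_{collection_name}"
    index_hash = hashlib.md5(table_name.encode()).hexdigest()[:8]
    # One distinct HNSW index name and one pg_bigm index name per table.
    return (f"embedding_cosine_v1_idx_{index_hash}", f"bigm_idx_{index_hash}")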
@@ -58,7 +58,7 @@ CREATE TABLE IF NOT EXISTS {table_name} (
 """

 SQL_CREATE_INDEX = """
-CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
 USING hnsw (embedding floatvector_cosine_ops) WITH (m = 16, ef_construction = 64);
 """
@@ -205,9 +205,9 @@ class TiDBVector(BaseVector):

         with Session(self._engine) as session:
             select_statement = sql_text(f"""
-                SELECT meta, text, distance
+                SELECT meta, text, distance
                 FROM (
-                    SELECT
+                    SELECT
                         meta,
                         text,
                         {tidb_dist_func}(vector, :query_vector_str) AS distance
@@ -317,7 +317,7 @@ class NotionExtractor(BaseExtractor):
             data_source_info["last_edited_time"] = last_edited_time
             update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)}

-            DocumentModel.query.filter_by(id=document_model.id).update(update_params)
+            db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params)
             db.session.commit()

     def get_notion_last_edited_time(self) -> str:
@@ -347,14 +347,18 @@ class NotionExtractor(BaseExtractor):

     @classmethod
     def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
-        data_source_binding = DataSourceOauthBinding.query.filter(
-            db.and_(
-                DataSourceOauthBinding.tenant_id == tenant_id,
-                DataSourceOauthBinding.provider == "notion",
-                DataSourceOauthBinding.disabled == False,
-                DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
-            )
-        ).first()
+        data_source_binding = (
+            db.session.query(DataSourceOauthBinding)
+            .filter(
+                db.and_(
+                    DataSourceOauthBinding.tenant_id == tenant_id,
+                    DataSourceOauthBinding.provider == "notion",
+                    DataSourceOauthBinding.disabled == False,
+                    DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
+                )
+            )
+            .first()
+        )

         if not data_source_binding:
             raise Exception(
@@ -76,8 +76,7 @@ class WordExtractor(BaseExtractor):
         parsed = urlparse(url)
         return bool(parsed.netloc) and bool(parsed.scheme)

-    def _extract_images_from_docx(self, doc, image_folder):
-        os.makedirs(image_folder, exist_ok=True)
+    def _extract_images_from_docx(self, doc):
         image_count = 0
         image_map = {}
@@ -210,7 +209,7 @@ class WordExtractor(BaseExtractor):

         content = []

-        image_map = self._extract_images_from_docx(doc, image_folder)
+        image_map = self._extract_images_from_docx(doc)

         hyperlinks_url = None
         url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
@@ -225,7 +224,7 @@ class WordExtractor(BaseExtractor):
                     xml = ElementTree.XML(run.element.xml)
                     x_child = [c for c in xml.iter() if c is not None]
                     for x in x_child:
-                        if x_child is None:
+                        if x is None:
                             continue
                         if x.tag.endswith("instrText"):
                             if x.text is None:
@@ -52,14 +52,16 @@ class RerankModelRunner(BaseRerankRunner):
         rerank_documents = []

         for result in rerank_result.docs:
-            # format document
-            rerank_document = Document(
-                page_content=result.text,
-                metadata=documents[result.index].metadata,
-                provider=documents[result.index].provider,
-            )
-            if rerank_document.metadata is not None:
-                rerank_document.metadata["score"] = result.score
-                rerank_documents.append(rerank_document)
+            if score_threshold is None or result.score >= score_threshold:
+                # format document
+                rerank_document = Document(
+                    page_content=result.text,
+                    metadata=documents[result.index].metadata,
+                    provider=documents[result.index].provider,
+                )
+                if rerank_document.metadata is not None:
+                    rerank_document.metadata["score"] = result.score
+                    rerank_documents.append(rerank_document)

-        return rerank_documents
+        rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
+        return rerank_documents[:top_n] if top_n else rerank_documents
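The reworked loop filters on score_threshold before constructing a Document, then sorts by score and truncates to top_n. A self-contained sketch of that post-processing, using a simplified stand-in type rather than the runner's own classes:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ScoredDoc:
    text: str
    score: float
    metadata: dict = field(default_factory=dict)

def postprocess(results: list[ScoredDoc], score_threshold: Optional[float], top_n: Optional[int]) -> list[ScoredDoc]:
    kept = []
    for r in results:
        # Skip results below the threshold instead of appending everything.
        if score_threshold is None or r.score >= score_threshold:
            r.metadata["score"] = r.score
            kept.append(r)
    # Highest score first, then cut to top_n when one is given.
    kept.sort(key=lambda d: d.metadata.get("score", 0.0), reverse=True)
    return kept[:top_n] if top_n else kept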
@@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
 from typing import Any, Optional, Union, cast

 from flask import Flask, current_app
-from sqlalchemy import Integer, and_, or_, text
+from sqlalchemy import Float, and_, or_, text
 from sqlalchemy import cast as sqlalchemy_cast

 from core.app.app_config.entities import (
@@ -149,7 +149,7 @@ class DatasetRetrieval:
         else:
             inputs = {}
         available_datasets_ids = [dataset.id for dataset in available_datasets]
-        metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
+        metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
            available_datasets_ids,
            query,
            tenant_id,
@@ -237,12 +237,16 @@ class DatasetRetrieval:
         if show_retrieve_source:
             for record in records:
                 segment = record.segment
-                dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
-                document = DatasetDocument.query.filter(
-                    DatasetDocument.id == segment.document_id,
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
-                ).first()
+                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
+                document = (
+                    db.session.query(DatasetDocument)
+                    .filter(
+                        DatasetDocument.id == segment.document_id,
+                        DatasetDocument.enabled == True,
+                        DatasetDocument.archived == False,
+                    )
+                    .first()
+                )
                 if dataset and document:
                     source = {
                         "dataset_id": dataset.id,
@@ -506,19 +510,30 @@ class DatasetRetrieval:
         dify_documents = [document for document in documents if document.provider == "dify"]
         for document in dify_documents:
             if document.metadata is not None:
-                dataset_document = DatasetDocument.query.filter(
-                    DatasetDocument.id == document.metadata["document_id"]
-                ).first()
+                dataset_document = (
+                    db.session.query(DatasetDocument)
+                    .filter(DatasetDocument.id == document.metadata["document_id"])
+                    .first()
+                )
                 if dataset_document:
                     if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                        child_chunk = ChildChunk.query.filter(
-                            ChildChunk.index_node_id == document.metadata["doc_id"],
-                            ChildChunk.dataset_id == dataset_document.dataset_id,
-                            ChildChunk.document_id == dataset_document.id,
-                        ).first()
+                        child_chunk = (
+                            db.session.query(ChildChunk)
+                            .filter(
+                                ChildChunk.index_node_id == document.metadata["doc_id"],
+                                ChildChunk.dataset_id == dataset_document.dataset_id,
+                                ChildChunk.document_id == dataset_document.id,
+                            )
+                            .first()
+                        )
                         if child_chunk:
-                            segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
-                                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
+                            segment = (
+                                db.session.query(DocumentSegment)
+                                .filter(DocumentSegment.id == child_chunk.segment_id)
+                                .update(
+                                    {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
+                                    synchronize_session=False,
+                                )
                             )
                             db.session.commit()
                 else:
@@ -649,6 +664,8 @@ class DatasetRetrieval:
         return_resource: bool,
         invoke_from: InvokeFrom,
         hit_callback: DatasetIndexToolCallbackHandler,
+        user_id: str,
+        inputs: dict,
     ) -> Optional[list[DatasetRetrieverBaseTool]]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
@@ -706,6 +723,9 @@ class DatasetRetrieval:
                 hit_callbacks=[hit_callback],
                 return_resource=return_resource,
                 retriever_from=invoke_from.to_source(),
+                retrieve_config=retrieve_config,
+                user_id=user_id,
+                inputs=inputs,
             )

             tools.append(tool)
@@ -826,7 +846,7 @@ class DatasetRetrieval:
         )
         return filter_documents[:top_k] if top_k else filter_documents

-    def _get_metadata_filter_condition(
+    def get_metadata_filter_condition(
         self,
         dataset_ids: list,
         query: str,
@@ -876,20 +896,31 @@ class DatasetRetrieval:
                 )
         elif metadata_filtering_mode == "manual":
             if metadata_filtering_conditions:
-                metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
+                conditions = []
                 for sequence, condition in enumerate(metadata_filtering_conditions.conditions):  # type: ignore
                     metadata_name = condition.name
                     expected_value = condition.value
-                    if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
+                    if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
                         if isinstance(expected_value, str):
                             expected_value = self._replace_metadata_filter_value(expected_value, inputs)
-                        filters = self._process_metadata_filter_func(
-                            sequence,
-                            condition.comparison_operator,
-                            metadata_name,
-                            expected_value,
-                            filters,
-                        )
+                    conditions.append(
+                        Condition(
+                            name=metadata_name,
+                            comparison_operator=condition.comparison_operator,
+                            value=expected_value,
+                        )
+                    )
+                    filters = self._process_metadata_filter_func(
+                        sequence,
+                        condition.comparison_operator,
+                        metadata_name,
+                        expected_value,
+                        filters,
+                    )
+                metadata_condition = MetadataCondition(
+                    logical_operator=metadata_filtering_conditions.logical_operator,
+                    conditions=conditions,
+                )
         else:
             raise ValueError("Invalid metadata filtering mode")
         if filters:
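Net effect of the manual branch above: each user-supplied condition is resolved against the workflow inputs, contributes a SQL filter via _process_metadata_filter_func, and is re-emitted as a Condition, so the MetadataCondition passed downstream carries the resolved values rather than the raw payload. A compact sketch of the resulting object (Condition is assumed to live next to MetadataCondition in core.rag.entities.metadata_entities; the values are placeholders):

from core.rag.entities.metadata_entities import Condition, MetadataCondition  # assumed import

metadata_condition = MetadataCondition(
    logical_operator="and",
    conditions=[
        Condition(name="year", comparison_operator="=", value="2024"),
        Condition(name="author", comparison_operator="contains", value="smith"),
    ],
)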
@@ -1005,28 +1036,24 @@ class DatasetRetrieval:
                 if isinstance(value, str):
                     filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
                 else:
-                    filters.append(
-                        sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value
-                    )
+                    filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) == value)
             case "is not" | "≠":
                 if isinstance(value, str):
                     filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
                 else:
-                    filters.append(
-                        sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value
-                    )
+                    filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) != value)
             case "empty":
                 filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
             case "not empty":
                 filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
             case "before" | "<":
-                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value)
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) < value)
             case "after" | ">":
-                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value)
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) > value)
             case "≤" | "<=":
-                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value)
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) <= value)
             case "≥" | ">=":
-                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value)
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) >= value)
             case _:
                 pass
         return filters
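Casting the JSON text with Float instead of Integer lets numeric metadata comparisons handle non-integer values (a rating of 8.5, for instance) as well as whole numbers. A reduced sketch of the pattern with a placeholder model, not the real DatasetDocument:

from sqlalchemy import Column, Float, Integer, select
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import DeclarativeBase

class Base(DeclarativeBase):
    pass

class Doc(Base):  # placeholder model for illustration
    __tablename__ = "docs"
    id = Column(Integer, primary_key=True)
    doc_metadata = Column(JSONB)

# Extract the JSON value as text, then cast to Float so "8.5" compares numerically.
stmt = select(Doc).where(sqlalchemy_cast(Doc.doc_metadata["rating"].astext, Float) >= 9)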
@@ -2,7 +2,7 @@ METADATA_FILTER_SYSTEM_PROMPT = """
 ### Job Description',
 You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
 ### Task
-Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
+Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤", "before", "after"] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
 ### Format
 The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
 ### Constraint
@@ -50,7 +50,7 @@ You are a text metadata extract engine that extract text's metadata based on use
 # Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
 ### Format
 The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
-### Constraint
+### Constraint
 DO NOT include anything other than the JSON array in your response.
 ### Example
 Here is the chat example between human and assistant, inside <example></example> XML tags.
@@ -59,7 +59,7 @@ User:{{"input_text": ["I want to know which company’s email address test@examp
 Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
 User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
 Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
-</example>
+</example>
 ### User Input
 {{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
 ### Assistant Output
@@ -159,50 +159,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         )
         return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)

-    @classmethod
-    def from_tiktoken_encoder(
-        cls: type[TS],
-        encoding_name: str = "gpt2",
-        model_name: Optional[str] = None,
-        allowed_special: Union[Literal["all"], Set[str]] = set(),
-        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-        **kwargs: Any,
-    ) -> TS:
-        """Text splitter that uses tiktoken encoder to count length."""
-        try:
-            import tiktoken
-        except ImportError:
-            raise ImportError(
-                "Could not import tiktoken python package. "
-                "This is needed in order to calculate max_tokens_for_prompt. "
-                "Please install it with `pip install tiktoken`."
-            )
-
-        if model_name is not None:
-            enc = tiktoken.encoding_for_model(model_name)
-        else:
-            enc = tiktoken.get_encoding(encoding_name)
-
-        def _tiktoken_encoder(text: str) -> int:
-            return len(
-                enc.encode(
-                    text,
-                    allowed_special=allowed_special,
-                    disallowed_special=disallowed_special,
-                )
-            )
-
-        if issubclass(cls, TokenTextSplitter):
-            extra_kwargs = {
-                "encoding_name": encoding_name,
-                "model_name": model_name,
-                "allowed_special": allowed_special,
-                "disallowed_special": disallowed_special,
-            }
-            kwargs = {**kwargs, **extra_kwargs}
-
-        return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
-
     def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
         """Transform sequence of documents by splitting them."""
         return self.split_documents(list(documents))