Merge branch 'main' into feat/r2

jyong
2025-05-15 15:15:23 +08:00
1025 changed files with 17699 additions and 4959 deletions

View File

@@ -10,6 +10,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.embedding.retrieval import RetrievalSegments
+from core.rag.entities.metadata_entities import MetadataCondition
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.models.document import Document
 from core.rag.rerank.rerank_type import RerankMode
@@ -119,12 +120,25 @@ class RetrievalService:
         return all_documents
     @classmethod
-    def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
+    def external_retrieve(
+        cls,
+        dataset_id: str,
+        query: str,
+        external_retrieval_model: Optional[dict] = None,
+        metadata_filtering_conditions: Optional[dict] = None,
+    ):
         dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
         if not dataset:
             return []
+        metadata_condition = (
+            MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
+        )
         all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
-            dataset.tenant_id, dataset_id, query, external_retrieval_model or {}
+            dataset.tenant_id,
+            dataset_id,
+            query,
+            external_retrieval_model or {},
+            metadata_condition=metadata_condition,
         )
         return all_documents
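Note: the new metadata_filtering_conditions dict is unpacked straight into MetadataCondition, so external-knowledge callers can scope retrieval by document metadata without building entity objects themselves. A minimal sketch of the calling convention, with hypothetical ids and filter values (the dict mirrors the logical_operator/conditions shape MetadataCondition uses elsewhere in this diff):

docs = RetrievalService.external_retrieve(
    dataset_id="<dataset-id>",
    query="quarterly revenue",
    external_retrieval_model={"top_k": 5, "score_threshold": 0.5},
    metadata_filtering_conditions={
        "logical_operator": "and",
        "conditions": [
            {"name": "year", "comparison_operator": "=", "value": "2024"},
        ],
    },
)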

View File

@@ -156,8 +156,8 @@ class AnalyticdbVectorBySql:
         values = []
         id_prefix = str(uuid.uuid4()) + "_"
         sql = f"""
-        INSERT INTO {self.table_name}
-        (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
+        INSERT INTO {self.table_name}
+        (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
         VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
         """
         for i, doc in enumerate(documents):
for i, doc in enumerate(documents):
@@ -242,7 +242,7 @@ class AnalyticdbVectorBySql:
             where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
         with self._get_cursor() as cur:
             cur.execute(
-                f"""SELECT id, vector, page_content, metadata_,
+                f"""SELECT id, vector, page_content, metadata_,
                 ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                 FROM {self.table_name}
                 WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}

View File

@@ -27,8 +27,8 @@ class MilvusConfig(BaseModel):
     uri: str  # Milvus server URI
     token: Optional[str] = None  # Optional token for authentication
-    user: str  # Username for authentication
-    password: str  # Password for authentication
+    user: Optional[str] = None  # Username for authentication
+    password: Optional[str] = None  # Password for authentication
     batch_size: int = 100  # Batch size for operations
     database: str = "default"  # Database name
     enable_hybrid_search: bool = False  # Flag to enable hybrid search
@@ -43,10 +43,11 @@ class MilvusConfig(BaseModel):
         """
         if not values.get("uri"):
             raise ValueError("config MILVUS_URI is required")
-        if not values.get("user"):
-            raise ValueError("config MILVUS_USER is required")
-        if not values.get("password"):
-            raise ValueError("config MILVUS_PASSWORD is required")
+        if not values.get("token"):
+            if not values.get("user"):
+                raise ValueError("config MILVUS_USER is required")
+            if not values.get("password"):
+                raise ValueError("config MILVUS_PASSWORD is required")
         return values
     def to_milvus_params(self):
@@ -356,11 +357,14 @@ class MilvusVector(BaseVector):
             )
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
-    def _init_client(self, config) -> MilvusClient:
+    def _init_client(self, config: MilvusConfig) -> MilvusClient:
         """
         Initialize and return a Milvus client.
         """
-        client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
+        if config.token:
+            client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
+        else:
+            client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
         return client
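Note: with user/password now optional, token-based deployments (Zilliz Cloud-style API keys) pass validation and connect without placeholder credentials. A minimal sketch of the two config shapes this enables, using the MilvusConfig fields above (URIs and secrets are placeholders):

# Token auth: the validator skips the user/password checks when a token is set.
token_cfg = MilvusConfig(uri="http://milvus:19530", token="<api-key>")
# Basic auth: without a token, user and password are still required.
basic_cfg = MilvusConfig(uri="http://milvus:19530", user="root", password="<password>")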

View File

@@ -203,7 +203,7 @@ class OceanBaseVector(BaseVector):
             full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score
             FROM {self._collection_name}
-            WHERE MATCH (text) AGAINST (:query) > 0
+            WHERE MATCH (text) AGAINST (:query) > 0
             {where_clause}
             ORDER BY score DESC
             LIMIT {top_k}"""

View File

@@ -59,12 +59,12 @@ CREATE TABLE IF NOT EXISTS {table_name} (
 """
 SQL_CREATE_INDEX_PQ = """
-CREATE INDEX IF NOT EXISTS embedding_{table_name}_pq_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS embedding_{table_name}_pq_idx ON {table_name}
 USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64, enable_pq=on, pq_m={pq_m});
 """
 SQL_CREATE_INDEX = """
-CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
 USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
 """

View File

@@ -1,10 +1,9 @@
 import json
 import logging
 import ssl
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 from uuid import uuid4
-from opensearchpy import OpenSearch, helpers
+from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
 from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator
@@ -24,9 +23,12 @@ logger = logging.getLogger(__name__)
 class OpenSearchConfig(BaseModel):
     host: str
     port: int
-    secure: bool = False
+    auth_method: Literal["basic", "aws_managed_iam"] = "basic"
     user: Optional[str] = None
     password: Optional[str] = None
+    secure: bool = False
+    aws_region: Optional[str] = None
+    aws_service: Optional[str] = None
     @model_validator(mode="before")
     @classmethod
@@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
             raise ValueError("config OPENSEARCH_HOST is required")
         if not values.get("port"):
             raise ValueError("config OPENSEARCH_PORT is required")
+        if values.get("auth_method") == "aws_managed_iam":
+            if not values.get("aws_region"):
+                raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
+            if not values.get("aws_service"):
+                raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
         return values
     def create_ssl_context(self) -> ssl.SSLContext:
         ssl_context = ssl.create_default_context()
         ssl_context.check_hostname = False
         ssl_context.verify_mode = ssl.CERT_NONE  # Disable Certificate Validation
         return ssl_context
+    def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
+        import boto3  # type: ignore
+        return Urllib3AWSV4SignerAuth(
+            credentials=boto3.Session().get_credentials(),
+            region=self.aws_region,
+            service=self.aws_service,  # type: ignore[arg-type]
+        )
     def to_opensearch_params(self) -> dict[str, Any]:
         params = {
             "hosts": [{"host": self.host, "port": self.port}],
             "use_ssl": self.secure,
             "verify_certs": self.secure,
+            "connection_class": Urllib3HttpConnection,
             "pool_maxsize": 20,
         }
-        if self.user and self.password:
+        if self.auth_method == "basic":
+            logger.info("Using basic authentication for OpenSearch Vector DB")
             params["http_auth"] = (self.user, self.password)
             if self.secure:
                 params["ssl_context"] = self.create_ssl_context()
+        elif self.auth_method == "aws_managed_iam":
+            logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
+            params["http_auth"] = self.create_aws_managed_iam_auth()
         return params
@@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
             action = {
                 "_op_type": "index",
                 "_index": self._collection_name.lower(),
-                "_id": uuid4().hex,
                 "_source": {
                     Field.CONTENT_KEY.value: documents[i].page_content,
                     Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
                     Field.METADATA_KEY.value: documents[i].metadata,
                 },
             }
+            # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
+            if self._client_config.aws_service not in ["aoss"]:
+                action["_id"] = uuid4().hex
             actions.append(action)
-        helpers.bulk(self._client, actions)
+        helpers.bulk(
+            client=self._client,
+            actions=actions,
+            timeout=30,
+            max_retries=3,
+        )
     def get_ids_by_metadata_field(self, key: str, value: str):
         query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
@@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
                 },
             }
+            logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
             self._client.indices.create(index=self._collection_name.lower(), body=index_body)
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
@@ -252,9 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
         open_search_config = OpenSearchConfig(
             host=dify_config.OPENSEARCH_HOST or "localhost",
             port=dify_config.OPENSEARCH_PORT,
-            secure=dify_config.OPENSEARCH_SECURE,
+            auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
             user=dify_config.OPENSEARCH_USER,
             password=dify_config.OPENSEARCH_PASSWORD,
+            secure=dify_config.OPENSEARCH_SECURE,
+            aws_region=dify_config.OPENSEARCH_AWS_REGION,
+            aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
         )
         return OpenSearchVector(collection_name=collection_name, config=open_search_config)
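Note: auth_method="aws_managed_iam" signs requests with SigV4 instead of basic auth, and the bulk path above stops assigning explicit _id values when aws_service is "aoss", since OpenSearch Serverless rejects client-supplied ids. A hedged sketch of an IAM-auth config under the fields above (endpoint, region, and service are illustrative):

# Hypothetical OpenSearch Serverless endpoint; credentials come from the
# ambient boto3 session rather than user/password.
iam_config = OpenSearchConfig(
    host="my-collection.us-east-1.aoss.amazonaws.com",
    port=443,
    secure=True,
    auth_method="aws_managed_iam",
    aws_region="us-east-1",
    aws_service="aoss",
)
client = OpenSearch(**iam_config.to_opensearch_params())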

View File

@@ -59,8 +59,8 @@ CREATE TABLE IF NOT EXISTS {table_name} (
 )
 """
 SQL_CREATE_INDEX = """
-CREATE INDEX IF NOT EXISTS idx_docs_{table_name} ON {table_name}(text)
-INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS
+CREATE INDEX IF NOT EXISTS idx_docs_{table_name} ON {table_name}(text)
+INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS
 ('FILTER CTXSYS.NULL_FILTER SECTION GROUP CTXSYS.HTML_SECTION_GROUP LEXER world_lexer')
 """
@@ -164,7 +164,7 @@ class OracleVector(BaseVector):
             with conn.cursor() as cur:
                 try:
                     cur.execute(
-                        f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
+                        f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
                         VALUES (:1, :2, :3, :4)""",
                         value,
                     )
@@ -227,8 +227,8 @@ class OracleVector(BaseVector):
             conn.outputtypehandler = self.output_type_handler
             with conn.cursor() as cur:
                 cur.execute(
-                    f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
-                    AS distance FROM {self.table_name}
+                    f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
+                    AS distance FROM {self.table_name}
                     {where_clause} ORDER BY distance fetch first {top_k} rows only""",
                     [numpy.array(query_vector)],
                 )
@@ -290,7 +290,7 @@ class OracleVector(BaseVector):
                 document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
                 where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
             cur.execute(
-                f"""select meta, text, embedding FROM {self.table_name}
+                f"""select meta, text, embedding FROM {self.table_name}
                 WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
                 order by score(1) desc fetch first {top_k} rows only""",
                 kk=" ACCUM ".join(entities),

View File

@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import uuid
@@ -61,12 +62,12 @@ CREATE TABLE IF NOT EXISTS {table_name} (
 """
 SQL_CREATE_INDEX = """
-CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx_{index_hash} ON {table_name}
 USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
 """
 SQL_CREATE_INDEX_PG_BIGM = """
-CREATE INDEX IF NOT EXISTS bigm_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS bigm_idx_{index_hash} ON {table_name}
 USING gin (text gin_bigm_ops);
 """
@@ -76,6 +77,7 @@ class PGVector(BaseVector):
         super().__init__(collection_name)
         self.pool = self._create_connection_pool(config)
         self.table_name = f"embedding_{collection_name}"
+        self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8]
         self.pg_bigm = config.pg_bigm
     def get_type(self) -> str:
@@ -256,10 +258,9 @@ class PGVector(BaseVector):
                 # PG hnsw index only support 2000 dimension or less
                 # ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
                 if dimension <= 2000:
-                    cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
+                    cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name, index_hash=self.index_hash))
                 if self.pg_bigm:
                     cur.execute("CREATE EXTENSION IF NOT EXISTS pg_bigm")
-                    cur.execute(SQL_CREATE_INDEX_PG_BIGM.format(table_name=self.table_name))
+                    cur.execute(SQL_CREATE_INDEX_PG_BIGM.format(table_name=self.table_name, index_hash=self.index_hash))
             redis_client.set(collection_exist_cache_key, 1, ex=3600)
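Note: the old index names (embedding_cosine_v1_idx, bigm_idx) were fixed strings, and PostgreSQL index names are schema-scoped, so CREATE INDEX IF NOT EXISTS silently no-oped for every collection after the first one, leaving later tables unindexed. Suffixing a hash of the table name makes each index name unique per collection. A small sketch of the naming scheme (the table name is hypothetical):

import hashlib

table_name = "embedding_my_collection"  # hypothetical collection table
index_hash = hashlib.md5(table_name.encode()).hexdigest()[:8]
# e.g. embedding_cosine_v1_idx_1a2b3c4d, unique per table, so
# CREATE INDEX IF NOT EXISTS no longer matches an index on another table
print(f"embedding_cosine_v1_idx_{index_hash}")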

View File

@@ -58,7 +58,7 @@ CREATE TABLE IF NOT EXISTS {table_name} (
 """
 SQL_CREATE_INDEX = """
-CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
 USING hnsw (embedding floatvector_cosine_ops) WITH (m = 16, ef_construction = 64);
 """

View File

@@ -205,9 +205,9 @@ class TiDBVector(BaseVector):
         with Session(self._engine) as session:
             select_statement = sql_text(f"""
-                SELECT meta, text, distance
+                SELECT meta, text, distance
                 FROM (
-                    SELECT
+                    SELECT
                         meta,
                         text,
                         {tidb_dist_func}(vector, :query_vector_str) AS distance

View File

@@ -317,7 +317,7 @@ class NotionExtractor(BaseExtractor):
         data_source_info["last_edited_time"] = last_edited_time
         update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)}
-        DocumentModel.query.filter_by(id=document_model.id).update(update_params)
+        db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params)
         db.session.commit()
     def get_notion_last_edited_time(self) -> str:
@@ -347,14 +347,18 @@
     @classmethod
     def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
-        data_source_binding = DataSourceOauthBinding.query.filter(
-            db.and_(
-                DataSourceOauthBinding.tenant_id == tenant_id,
-                DataSourceOauthBinding.provider == "notion",
-                DataSourceOauthBinding.disabled == False,
-                DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
-            )
-        ).first()
+        data_source_binding = (
+            db.session.query(DataSourceOauthBinding)
+            .filter(
+                db.and_(
+                    DataSourceOauthBinding.tenant_id == tenant_id,
+                    DataSourceOauthBinding.provider == "notion",
+                    DataSourceOauthBinding.disabled == False,
+                    DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
+                )
+            )
+            .first()
+        )
         if not data_source_binding:
             raise Exception(

View File

@@ -76,8 +76,7 @@ class WordExtractor(BaseExtractor):
         parsed = urlparse(url)
         return bool(parsed.netloc) and bool(parsed.scheme)
-    def _extract_images_from_docx(self, doc, image_folder):
-        os.makedirs(image_folder, exist_ok=True)
+    def _extract_images_from_docx(self, doc):
         image_count = 0
         image_map = {}
@@ -210,7 +209,7 @@
         content = []
-        image_map = self._extract_images_from_docx(doc, image_folder)
+        image_map = self._extract_images_from_docx(doc)
         hyperlinks_url = None
         url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
@@ -225,7 +224,7 @@
                 xml = ElementTree.XML(run.element.xml)
                 x_child = [c for c in xml.iter() if c is not None]
                 for x in x_child:
-                    if x_child is None:
+                    if x is None:
                         continue
                     if x.tag.endswith("instrText"):
                         if x.text is None:
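Note: besides dropping the unused image_folder parameter, this fixes a guard that tested the list instead of the element: inside for x in x_child, the condition x_child is None can never be true once iteration has begun, so None children were never skipped. A toy illustration of the corrected check (the list contents are made up):

children = ["w:t", None, "w:instrText"]  # hypothetical parsed XML children
for x in children:
    if x is None:  # old code tested `x_child is None` here, which never fired
        continue
    print(x)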

View File

@@ -52,14 +52,16 @@ class RerankModelRunner(BaseRerankRunner):
         rerank_documents = []
         for result in rerank_result.docs:
-            # format document
-            rerank_document = Document(
-                page_content=result.text,
-                metadata=documents[result.index].metadata,
-                provider=documents[result.index].provider,
-            )
-            if rerank_document.metadata is not None:
-                rerank_document.metadata["score"] = result.score
-                rerank_documents.append(rerank_document)
+            if score_threshold is None or result.score >= score_threshold:
+                # format document
+                rerank_document = Document(
+                    page_content=result.text,
+                    metadata=documents[result.index].metadata,
+                    provider=documents[result.index].provider,
+                )
+                if rerank_document.metadata is not None:
+                    rerank_document.metadata["score"] = result.score
+                    rerank_documents.append(rerank_document)
-        return rerank_documents
+        rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
+        return rerank_documents[:top_n] if top_n else rerank_documents
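Note: the runner now enforces score_threshold before building each Document and returns results sorted by score and truncated to top_n, instead of returning every reranked document. A toy walk-through of the new semantics (the scores are invented):

scores = [0.91, 0.42, 0.77, 0.30]  # hypothetical rerank scores
score_threshold, top_n = 0.5, 2
kept = sorted((s for s in scores if s >= score_threshold), reverse=True)
print(kept[:top_n])  # [0.91, 0.77]: filtered, then sorted, then truncated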

View File

@@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
 from typing import Any, Optional, Union, cast
 from flask import Flask, current_app
-from sqlalchemy import Integer, and_, or_, text
+from sqlalchemy import Float, and_, or_, text
 from sqlalchemy import cast as sqlalchemy_cast
 from core.app.app_config.entities import (
@@ -149,7 +149,7 @@ class DatasetRetrieval:
         else:
             inputs = {}
         available_datasets_ids = [dataset.id for dataset in available_datasets]
-        metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
+        metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
             available_datasets_ids,
             query,
             tenant_id,
@@ -237,12 +237,16 @@
         if show_retrieve_source:
             for record in records:
                 segment = record.segment
-                dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
-                document = DatasetDocument.query.filter(
-                    DatasetDocument.id == segment.document_id,
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
-                ).first()
+                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
+                document = (
+                    db.session.query(DatasetDocument)
+                    .filter(
+                        DatasetDocument.id == segment.document_id,
+                        DatasetDocument.enabled == True,
+                        DatasetDocument.archived == False,
+                    )
+                    .first()
+                )
                 if dataset and document:
                     source = {
                         "dataset_id": dataset.id,
@@ -506,19 +510,30 @@
         dify_documents = [document for document in documents if document.provider == "dify"]
         for document in dify_documents:
             if document.metadata is not None:
-                dataset_document = DatasetDocument.query.filter(
-                    DatasetDocument.id == document.metadata["document_id"]
-                ).first()
+                dataset_document = (
+                    db.session.query(DatasetDocument)
+                    .filter(DatasetDocument.id == document.metadata["document_id"])
+                    .first()
+                )
                 if dataset_document:
                     if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                        child_chunk = ChildChunk.query.filter(
-                            ChildChunk.index_node_id == document.metadata["doc_id"],
-                            ChildChunk.dataset_id == dataset_document.dataset_id,
-                            ChildChunk.document_id == dataset_document.id,
-                        ).first()
+                        child_chunk = (
+                            db.session.query(ChildChunk)
+                            .filter(
+                                ChildChunk.index_node_id == document.metadata["doc_id"],
+                                ChildChunk.dataset_id == dataset_document.dataset_id,
+                                ChildChunk.document_id == dataset_document.id,
+                            )
+                            .first()
+                        )
                         if child_chunk:
-                            segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
-                                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
-                            )
+                            segment = (
+                                db.session.query(DocumentSegment)
+                                .filter(DocumentSegment.id == child_chunk.segment_id)
+                                .update(
+                                    {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
+                                    synchronize_session=False,
+                                )
+                            )
db.session.commit()
else:
@@ -649,6 +664,8 @@
         return_resource: bool,
         invoke_from: InvokeFrom,
         hit_callback: DatasetIndexToolCallbackHandler,
+        user_id: str,
+        inputs: dict,
     ) -> Optional[list[DatasetRetrieverBaseTool]]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
@@ -706,6 +723,9 @@
                 hit_callbacks=[hit_callback],
                 return_resource=return_resource,
                 retriever_from=invoke_from.to_source(),
+                retrieve_config=retrieve_config,
+                user_id=user_id,
+                inputs=inputs,
             )
             tools.append(tool)
@@ -826,7 +846,7 @@
         )
         return filter_documents[:top_k] if top_k else filter_documents
-    def _get_metadata_filter_condition(
+    def get_metadata_filter_condition(
         self,
         dataset_ids: list,
         query: str,
@@ -876,20 +896,31 @@
                 )
             elif metadata_filtering_mode == "manual":
                 if metadata_filtering_conditions:
-                    metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
+                    conditions = []
                     for sequence, condition in enumerate(metadata_filtering_conditions.conditions):  # type: ignore
                         metadata_name = condition.name
                         expected_value = condition.value
-                        if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
+                        if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
                             if isinstance(expected_value, str):
                                 expected_value = self._replace_metadata_filter_value(expected_value, inputs)
-                            filters = self._process_metadata_filter_func(
-                                sequence,
-                                condition.comparison_operator,
-                                metadata_name,
-                                expected_value,
-                                filters,
-                            )
+                            conditions.append(
+                                Condition(
+                                    name=metadata_name,
+                                    comparison_operator=condition.comparison_operator,
+                                    value=expected_value,
+                                )
+                            )
+                        filters = self._process_metadata_filter_func(
+                            sequence,
+                            condition.comparison_operator,
+                            metadata_name,
+                            expected_value,
+                            filters,
+                        )
+                    metadata_condition = MetadataCondition(
+                        logical_operator=metadata_filtering_conditions.logical_operator,
+                        conditions=conditions,
+                    )
             else:
                 raise ValueError("Invalid metadata filtering mode")
             if filters:
@@ -1005,28 +1036,24 @@
                 if isinstance(value, str):
                     filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
                 else:
-                    filters.append(
-                        sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value
-                    )
+                    filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) == value)
             case "is not" | "≠":
                 if isinstance(value, str):
                     filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
                 else:
-                    filters.append(
-                        sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value
-                    )
+                    filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) != value)
             case "empty":
                 filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
             case "not empty":
                 filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
             case "before" | "<":
-                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value)
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) < value)
             case "after" | ">":
-                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value)
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) > value)
             case "≤" | "<=":
-                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value)
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) <= value)
             case "≥" | ">=":
-                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value)
+                filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) >= value)
             case _:
                 pass
         return filters
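Note: the metadata values live as JSON text and are not guaranteed to be integral; in PostgreSQL, casting a stored "8.5" to Integer raises "invalid input syntax for type integer", so any numeric comparison against fractional metadata failed outright, while Float accepts both whole and fractional numbers. A hedged sketch of the changed predicate in isolation (the DatasetDocument model is assumed imported as in this module):

from sqlalchemy import Float
from sqlalchemy import cast as sqlalchemy_cast

# With Integer this cast errors out for a stored "8.5"; with Float it compares.
predicate = sqlalchemy_cast(DatasetDocument.doc_metadata["rating"].astext, Float) >= 8.5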

View File

@@ -2,7 +2,7 @@ METADATA_FILTER_SYSTEM_PROMPT = """
 ### Job Description',
 You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
 ### Task
-Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
+Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤", "before", "after"] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
 ### Format
 The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
 ### Constraint
@@ -50,7 +50,7 @@ You are a text metadata extract engine that extract text's metadata based on use
 # Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
 ### Format
 The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
-### Constraint
+### Constraint
 DO NOT include anything other than the JSON array in your response.
 ### Example
 Here is the chat example between human and assistant, inside <example></example> XML tags.
@@ -59,7 +59,7 @@ User:{{"input_text": ["I want to know which company's email address test@examp
 Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
 User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
 Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
-</example>
+</example>
 ### User Input
 {{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
 ### Assistant Output

View File

@@ -159,50 +159,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         )
         return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
-    @classmethod
-    def from_tiktoken_encoder(
-        cls: type[TS],
-        encoding_name: str = "gpt2",
-        model_name: Optional[str] = None,
-        allowed_special: Union[Literal["all"], Set[str]] = set(),
-        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-        **kwargs: Any,
-    ) -> TS:
-        """Text splitter that uses tiktoken encoder to count length."""
-        try:
-            import tiktoken
-        except ImportError:
-            raise ImportError(
-                "Could not import tiktoken python package. "
-                "This is needed in order to calculate max_tokens_for_prompt. "
-                "Please install it with `pip install tiktoken`."
-            )
-        if model_name is not None:
-            enc = tiktoken.encoding_for_model(model_name)
-        else:
-            enc = tiktoken.get_encoding(encoding_name)
-        def _tiktoken_encoder(text: str) -> int:
-            return len(
-                enc.encode(
-                    text,
-                    allowed_special=allowed_special,
-                    disallowed_special=disallowed_special,
-                )
-            )
-        if issubclass(cls, TokenTextSplitter):
-            extra_kwargs = {
-                "encoding_name": encoding_name,
-                "model_name": model_name,
-                "allowed_special": allowed_special,
-                "disallowed_special": disallowed_special,
-            }
-            kwargs = {**kwargs, **extra_kwargs}
-        return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
     def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
         """Transform sequence of documents by splitting them."""
         return self.split_documents(list(documents))