Mirror of https://github.com/langgenius/dify.git (synced 2026-04-24 12:55:49 +08:00)
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
# Conflicts:
#	api/core/memory/token_buffer_memory.py
#	api/core/rag/extractor/notion_extractor.py
#	api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
#	api/core/variables/variables.py
#	api/core/workflow/graph/graph.py
#	api/core/workflow/graph_engine/entities/event.py
#	api/services/dataset_service.py
#	web/app/components/app-sidebar/index.tsx
#	web/app/components/base/tag-management/selector.tsx
#	web/app/components/base/toast/index.tsx
#	web/app/components/datasets/create/website/index.tsx
#	web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx
#	web/app/components/workflow/header/version-history-button.tsx
#	web/app/components/workflow/hooks/use-inspect-vars-crud-common.ts
#	web/app/components/workflow/hooks/use-workflow-interactions.ts
#	web/app/components/workflow/panel/version-history-panel/index.tsx
#	web/service/base.ts
@@ -3,6 +3,7 @@ from typing import Any, Optional
import orjson
from pydantic import BaseModel
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler

@@ -212,11 +213,10 @@ class Jieba(BaseKeyword):
return sorted_chunk_indices[:k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
)
document_segment = db.session.scalar(stmt)
if document_segment:
document_segment.keywords = keywords
db.session.add(document_segment)
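Note: the hunk above is one instance of the legacy Query-API to SQLAlchemy 2.0-style migration that recurs throughout this diff. A minimal sketch of the pattern, using an illustrative MyModel/some_id rather than names from this repository:

    from sqlalchemy import select

    # before: legacy Query API
    row = db.session.query(MyModel).where(MyModel.id == some_id).first()

    # after: 2.0-style statement executed through the session
    stmt = select(MyModel).where(MyModel.id == some_id)
    row = db.session.scalar(stmt)                  # first matching row or None
    rows = db.session.scalars(stmt).all()          # all matching rows
    one = db.session.scalars(stmt).one_or_none()   # at most one row expected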
@@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config

@@ -24,7 +25,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}

@@ -127,7 +128,8 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
return []
metadata_condition = (

@@ -316,10 +318,8 @@ class RetrievalService:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")
child_chunk = (
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
)
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = db.session.scalar(child_chunk_stmt)
if not child_chunk:
continue

@@ -378,17 +378,13 @@ class RetrievalService:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
segment = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
if not segment:
continue
@@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None,
content=None,
vector=None, # ty: ignore [invalid-argument-type]
content=None, # ty: ignore [invalid-argument-type]
top_k=1,
filter=f"ref_doc_id='{id}'",
)

@@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)

@@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)

@@ -249,14 +249,14 @@ class AnalyticdbVectorOpenAPI:
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None,
content=None, # ty: ignore [invalid-argument-type]
top_k=kwargs.get("top_k", 4),
filter=where_clause,
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(

@@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None,
vector=None, # ty: ignore [invalid-argument-type]
content=query,
top_k=kwargs.get("top_k", 4),
filter=where_clause,

@@ -293,7 +293,7 @@ class AnalyticdbVectorOpenAPI:
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
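Note: the ">" to ">=" changes in the vector stores above and below make the score threshold inclusive, so a hit scoring exactly at the configured threshold is now kept instead of dropped. A small illustrative check, not code from this repository:

    score_threshold = 0.5
    scores = [0.4, 0.5, 0.7]
    kept = [s for s in scores if s >= score_threshold]  # [0.5, 0.7]; with ">" only 0.7 survived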
@@ -228,8 +228,8 @@ class AnalyticdbVectorBySql:
)
documents = []
for record in cur:
id, vector, score, page_content, metadata = record
if score > score_threshold:
_, vector, score, page_content, metadata = record
if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=page_content,

@@ -260,7 +260,7 @@ class AnalyticdbVectorBySql:
)
documents = []
for record in cur:
id, vector, page_content, metadata, score = record
_, vector, page_content, metadata, score = record
metadata["score"] = score
doc = Document(
page_content=page_content,
@@ -157,7 +157,7 @@ class BaiduVector(BaseVector):
if meta is not None:
meta = json.loads(meta)
score = row.get("score", 0.0)
if score > score_threshold:
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
docs.append(doc)
@@ -120,7 +120,7 @@ class ChromaVector(BaseVector):
distance = distances[index]
metadata = dict(metadatas[index])
score = 1 - distance
if score > score_threshold:
if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=documents[index],
@@ -12,7 +12,7 @@ import clickzetta # type: ignore
from pydantic import BaseModel, model_validator
if TYPE_CHECKING:
from clickzetta import Connection
from clickzetta.connector.v0.connection import Connection # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import Field

@@ -701,7 +701,7 @@ class ClickzettaVector(BaseVector):
len(data_rows),
vector_dimension,
)
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
except (RuntimeError, ValueError, TypeError, ConnectionError):
logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
logger.exception("SQL template: %s", insert_sql)
logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")

@@ -787,7 +787,7 @@ class ClickzettaVector(BaseVector):
document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {})
_ = kwargs.get("filter", {})
# Build filter clause
filter_clauses = []

@@ -879,7 +879,7 @@ class ClickzettaVector(BaseVector):
document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {})
_ = kwargs.get("filter", {})
# Build filter clause
filter_clauses = []

@@ -938,7 +938,7 @@ class ClickzettaVector(BaseVector):
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, TypeError) as e:
except (json.JSONDecodeError, TypeError):
logger.exception("JSON parsing failed")
# Fallback: extract document_id with regex

@@ -956,7 +956,7 @@ class ClickzettaVector(BaseVector):
metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores
doc = Document(page_content=row[1], metadata=metadata)
documents.append(doc)
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
except (RuntimeError, ValueError, TypeError, ConnectionError):
logger.exception("Full-text search failed")
# Fallback to LIKE search if full-text search fails
return self._search_by_like(query, **kwargs)

@@ -978,7 +978,7 @@ class ClickzettaVector(BaseVector):
document_ids_filter = kwargs.get("document_ids_filter")
# Handle filter parameter from canvas (workflow)
filter_param = kwargs.get("filter", {})
_ = kwargs.get("filter", {})
# Build filter clause
filter_clauses = []
@@ -212,10 +212,10 @@ class CouchbaseVector(BaseVector):
documents_to_insert = [
{"text": text, "embedding": vector, "metadata": metadata}
for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
for _, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
]
for doc, id in zip(documents_to_insert, uuids):
result = self._scope.collection(self._collection_name).upsert(id, doc)
_ = self._scope.collection(self._collection_name).upsert(id, doc)
doc_ids.extend(uuids)

@@ -241,7 +241,7 @@ class CouchbaseVector(BaseVector):
"""
try:
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
except Exception as e:
except Exception:
logger.exception("Failed to delete documents, ids: %s", ids)
def delete_by_document_id(self, document_id: str):

@@ -304,9 +304,9 @@ class CouchbaseVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 2)
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # ty: ignore [too-many-positional-arguments]
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)
@@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
if not client.ping():
raise ConnectionError("Failed to connect to Elasticsearch")
except requests.exceptions.ConnectionError as e:
except requests.ConnectionError as e:
raise ConnectionError(f"Vector database connection error: {str(e)}")
except Exception as e:
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")

@@ -216,7 +216,7 @@ class ElasticSearchVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
@@ -127,7 +127,7 @@ class HuaweiCloudVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
@@ -275,7 +275,7 @@ class LindormVectorStore(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
@@ -99,7 +99,7 @@ class MatrixoneVector(BaseVector):
return client
try:
client.create_full_text_index()
except Exception as e:
except Exception:
logger.exception("Failed to create full text index")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
return client
@@ -376,7 +376,12 @@ class MilvusVector(BaseVector):
if config.token:
client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
else:
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
client = MilvusClient(
uri=config.uri,
user=config.user or "",
password=config.password or "",
db_name=config.database,
)
return client
@@ -194,7 +194,7 @@ class OpenGauss(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
@@ -197,7 +197,7 @@ class OpenSearchVector(BaseVector):
try:
response = self._client.search(index=self._collection_name.lower(), body=query)
except Exception as e:
except Exception:
logger.exception("Error executing vector search, query: %s", query)
raise

@@ -211,7 +211,7 @@ class OpenSearchVector(BaseVector):
metadata["score"] = hit["_score"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] > score_threshold:
if hit["_score"] >= score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
@@ -261,7 +261,7 @@ class OracleVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
conn.close()
return docs
@@ -202,7 +202,7 @@ class PGVectoRS(BaseVector):
score = 1 - dis
metadata["score"] = score
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
if score >= score_threshold:
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)
return docs
@@ -195,7 +195,7 @@ class PGVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
@@ -170,7 +170,7 @@ class VastbaseVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
@@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union
import qdrant_client
from flask import current_app

@@ -18,6 +18,7 @@ from qdrant_client.http.models import (
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field

@@ -369,7 +370,7 @@ class QdrantVector(BaseVector):
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
if result.score > score_threshold:
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),

@@ -426,7 +427,6 @@ class QdrantVector(BaseVector):
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod

@@ -446,11 +446,8 @@ class QdrantVector(BaseVector):
class QdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id)
dataset_collection_binding = db.session.scalars(stmt).one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
@@ -233,7 +233,7 @@ class RelytVector(BaseVector):
docs = []
for document, score in results:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if 1 - score > score_threshold:
if 1 - score >= score_threshold:
docs.append(document)
return docs
@@ -71,7 +71,7 @@ class TableStoreVector(BaseVector):
table_result = result.get_result_by_table(self._table_name)
for item in table_result:
if item.is_ok and item.row:
kv = {k: v for k, v, t in item.row.attribute_columns}
kv = {k: v for k, v, _ in item.row.attribute_columns}
docs.append(
Document(
page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value])

@@ -300,7 +300,7 @@ class TableStoreVector(BaseVector):
)
documents = []
for search_hit in search_response.search_hits:
if search_hit.score > score_threshold:
if search_hit.score >= score_threshold:
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
@@ -39,6 +39,9 @@ class TencentConfig(BaseModel):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}

bm25 = BM25Encoder.default("zh")

class TencentVector(BaseVector):
field_id: str = "id"
field_vector: str = "vector"

@@ -53,7 +56,6 @@ class TencentVector(BaseVector):
self._dimension = 1024
self._init_database()
self._load_collection()
self._bm25 = BM25Encoder.default("zh")
def _load_collection(self):
"""

@@ -186,7 +188,7 @@ class TencentVector(BaseVector):
metadata=metadata,
)
if self._enable_hybrid_search:
doc.__dict__["sparse_vector"] = self._bm25.encode_texts(texts[i])
doc.__dict__["sparse_vector"] = bm25.encode_texts(texts[i])
docs.append(doc)
self._client.upsert(
database_name=self._client_config.database,

@@ -264,7 +266,7 @@ class TencentVector(BaseVector):
match=[
KeywordSearch(
field_name="sparse_vector",
data=self._bm25.encode_queries(query),
data=bm25.encode_queries(query),
),
],
rerank=WeightedRerank(

@@ -291,7 +293,7 @@ class TencentVector(BaseVector):
score = 1 - result.get("score", 0.0)
else:
score = result.get("score", 0.0)
if score > score_threshold:
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)
docs.append(doc)
@@ -20,6 +20,7 @@ from qdrant_client.http.models import (
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field

@@ -351,7 +352,7 @@ class TidbOnQdrantVector(BaseVector):
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score > score_threshold:
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),

@@ -416,16 +417,12 @@ class TidbOnQdrantVector(BaseVector):
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.where(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
@@ -110,7 +110,7 @@ class UpstashVector(BaseVector):
score = record.score
if metadata is not None and text is not None:
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
@@ -3,6 +3,8 @@ import time
from abc import ABC, abstractmethod
from typing import Any, Optional
from sqlalchemy import select
from configs import dify_config
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

@@ -45,11 +47,10 @@ class Vector:
vector_type = self._dataset.index_struct_dict["type"]
else:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
stmt = select(Whitelist).where(
Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db"
)
whitelist = db.session.scalars(stmt).one_or_none()
if whitelist:
vector_type = VectorType.TIDB_ON_QDRANT
@@ -32,9 +32,9 @@ class VikingDBConfig(BaseModel):
scheme: str
connection_timeout: int
socket_timeout: int
index_type: str = IndexType.HNSW
distance: str = DistanceType.L2
quant: str = QuantType.Float
index_type: str = str(IndexType.HNSW)
distance: str = str(DistanceType.L2)
quant: str = str(QuantType.Float)

class VikingDBVector(BaseVector):

@@ -192,7 +192,7 @@ class VikingDBVector(BaseVector):
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
if metadata is not None:
metadata = json.loads(metadata)
if result.score > score_threshold:
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
@@ -37,22 +37,22 @@ class WeaviateVector(BaseVector):
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
auth_config = weaviate.AuthApiKey(api_key=config.api_key or "")
weaviate.connect.connection.has_grpc = False
weaviate.connect.connection.has_grpc = False # ty: ignore [unresolved-attribute]
# Fix to minimize the performance impact of the deprecation check in weaviate-client 3.24.0,
# by changing the connection timeout to pypi.org from 1 second to 0.001 seconds.
# TODO: This can be removed once weaviate-client is updated to 3.26.7 or higher,
# which does not contain the deprecation check.
if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"):
weaviate.connect.connection.PYPI_TIMEOUT = 0.001
if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"): # ty: ignore [unresolved-attribute]
weaviate.connect.connection.PYPI_TIMEOUT = 0.001 # ty: ignore [unresolved-attribute]
try:
client = weaviate.Client(
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
)
except requests.exceptions.ConnectionError:
except requests.ConnectionError:
raise ConnectionError("Vector database connection error")
client.batch.configure(

@@ -220,7 +220,7 @@ class WeaviateVector(BaseVector):
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
@@ -1,7 +1,7 @@
from collections.abc import Sequence
from typing import Any, Optional
from sqlalchemy import func
from sqlalchemy import func, select
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

@@ -41,9 +41,8 @@ class DatasetDocumentStore:
@property
def docs(self) -> dict[str, Document]:
document_segments = (
db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
)
stmt = select(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id)
document_segments = db.session.scalars(stmt).all()
output = {}
for document_segment in document_segments:

@@ -228,10 +227,9 @@ class DatasetDocumentStore:
return data
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.first()
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id
)
document_segment = db.session.scalar(stmt)
return document_segment
@@ -107,7 +107,7 @@ class Blob(BaseModel):
Blob instance
"""
if mime_type is None and guess_type:
_mimetype = mimetypes.guess_type(path)[0] if guess_type else None
_mimetype = mimetypes.guess_type(path)[0]
else:
_mimetype = mime_type
# We do not load the data immediately, instead we treat the blob as a
@@ -45,7 +45,7 @@ class ExtractProcessor:
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
) -> Union[list[Document], str]:
extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=upload_file, document_model="text_model"
datasource_type=DatasourceType.FILE.value, upload_file=upload_file, document_model="text_model"
)
if return_text:
delimiter = "\n"

@@ -76,7 +76,7 @@ class ExtractProcessor:
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}"
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE.value, document_model="text_model")
if return_text:
delimiter = "\n"
return delimiter.join(
@@ -2,7 +2,7 @@
import re
from pathlib import Path
from typing import Optional, cast
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings

@@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor):
markdown_tups.append((current_header, current_text))
markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
(re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]
@@ -2,7 +2,7 @@
import contextlib
from collections.abc import Iterator
from typing import Optional, cast
from typing import Optional
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor

@@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
plaintext_file_exists = False
if self._file_cache_key:
with contextlib.suppress(FileNotFoundError):
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
text = storage.load(self._file_cache_key).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
documents = list(self.load())
@@ -23,7 +23,7 @@ class UnstructuredWordExtractor(BaseExtractor):
unstructured_version = tuple(int(x) for x in __unstructured_version__.split("."))
# check the file extension
try:
import magic # noqa: F401
import magic # noqa: F401 # pyright: ignore[reportUnusedImport]
is_doc = detect_filetype(self._file_path) == FileType.DOC
except ImportError:
@@ -34,6 +34,7 @@ class BaseIndexProcessor(ABC):
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
raise NotImplementedError
@abstractmethod
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
raise NotImplementedError
@@ -127,7 +127,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
@@ -135,13 +135,16 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if delete_child_chunks:
db.session.query(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete()
).delete(synchronize_session=False)
db.session.commit()
else:
vector.delete()
if delete_child_chunks:
db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete()
# Use existing compound index: (tenant_id, dataset_id, ...)
db.session.query(ChildChunk).where(
ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
).delete(synchronize_session=False)
db.session.commit()
def retrieve(
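Note: the clean-up above now filters the bulk delete on the existing (tenant_id, dataset_id) compound index and passes synchronize_session=False, so SQLAlchemy skips reconciling in-memory objects and relies on the following commit. A minimal sketch of the pattern, assuming ChildChunk and the ids are in scope:

    db.session.query(ChildChunk).where(
        ChildChunk.tenant_id == tenant_id, ChildChunk.dataset_id == dataset_id
    ).delete(synchronize_session=False)
    db.session.commit()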
@@ -167,7 +170,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
@@ -117,7 +117,7 @@ class QAIndexProcessor(BaseIndexProcessor):
# Skip the first row
df = pd.read_csv(file)
text_docs = []
for index, row in df.iterrows():
for _, row in df.iterrows():
data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
text_docs.append(data)
if len(text_docs) == 0:

@@ -162,7 +162,7 @@ class QAIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

@@ -221,7 +221,7 @@ class QAIndexProcessor(BaseIndexProcessor):
qa_document.metadata["doc_hash"] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
except Exception:
logger.exception("Failed to format qa document")
all_qa_documents.extend(format_documents)
@@ -7,9 +7,8 @@ from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast
from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, text
from sqlalchemy import Float, and_, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
DatasetEntity,

@@ -65,7 +64,7 @@ default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}

@@ -135,7 +134,8 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
# pass if dataset is not available
if not dataset:

@@ -240,15 +240,12 @@ class DatasetRetrieval:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,

@@ -327,7 +324,8 @@ class DatasetRetrieval:
if dataset_id:
# get retrieval model config
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
if dataset:
results = []
if dataset.provider == "external":

@@ -514,24 +512,20 @@ class DatasetRetrieval:
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
if document.metadata is not None:
dataset_document = (
db.session.query(DatasetDocument)
.where(DatasetDocument.id == document.metadata["document_id"])
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = db.session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
_ = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(

@@ -539,7 +533,6 @@ class DatasetRetrieval:
synchronize_session=False,
)
)
db.session.commit()
else:
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]

@@ -599,8 +592,8 @@ class DatasetRetrieval:
metadata_condition: Optional[MetadataCondition] = None,
):
with flask_app.app_context():
with Session(db.engine) as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
if not dataset:
return []

@@ -647,7 +640,7 @@ class DatasetRetrieval:
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,

@@ -685,7 +678,8 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
# pass if dataset is not available
if not dataset:

@@ -743,7 +737,7 @@ class DatasetRetrieval:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
top_k=retrieve_config.top_k or 4,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,

@@ -958,7 +952,8 @@ class DatasetRetrieval:
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(metadata_stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
if metadata_model_config is None:

@@ -990,7 +985,7 @@ class DatasetRetrieval:
)
# handle invoke result
result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)
result_text, _ = self._handle_invoke_result(invoke_result=invoke_result)
result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []

@@ -1005,7 +1000,7 @@ class DatasetRetrieval:
"condition": item.get("comparison_operator"),
}
)
except Exception as e:
except Exception:
return None
return automatic_metadata_filters
@@ -19,5 +19,5 @@ class StructuredChatOutputParser:
return ReactAction(response["action"], response.get("action_input", {}), text)
else:
return ReactFinish({"output": text}, text)
except Exception as e:
except Exception:
raise ValueError(f"Could not parse LLM output: {text}")
@@ -1,4 +1,4 @@
from typing import Union, cast
from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance

@@ -28,18 +28,15 @@ class FunctionCallMultiDatasetRouter:
SystemPromptMessage(content="You are a helpful AI assistant."),
UserPromptMessage(content=query),
]
result = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
),
result: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
if result.message.tool_calls:
# get retrieval model config
return result.message.tool_calls[0].function.name
return None
except Exception as e:
except Exception:
return None
@@ -1,5 +1,5 @@
from collections.abc import Generator, Sequence
from typing import Union, cast
from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance

@@ -77,7 +77,7 @@ class ReactMultiDatasetRouter:
user_id=user_id,
tenant_id=tenant_id,
)
except Exception as e:
except Exception:
return None
def _react_invoke(

@@ -120,7 +120,7 @@ class ReactMultiDatasetRouter:
memory=None,
model_config=model_config,
)
result_text, usage = self._invoke_llm(
result_text, _ = self._invoke_llm(
completion_param=model_config.parameters,
model_instance=model_instance,
prompt_messages=prompt_messages,

@@ -150,15 +150,12 @@ class ReactMultiDatasetRouter:
:param stop: stop
:return:
"""
invoke_result = cast(
Generator[LLMResult, None, None],
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=completion_param,
stop=stop,
stream=True,
user=user_id,
),
invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=completion_param,
stop=stop,
stream=True,
user=user_id,
)
# handle invoke result