refactor: tool

This commit is contained in:
Yeuoly
2024-09-20 02:25:14 +08:00
parent c472ea6c67
commit 661392eaef
524 changed files with 338 additions and 31279 deletions

View File

@ -0,0 +1,193 @@
import threading
from flask import Flask, current_app
from pydantic import BaseModel, Field
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
class DatasetMultiRetrieverToolInput(BaseModel):
query: str = Field(..., description="dataset multi retriever and rerank")
class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying multi dataset."""
name: str = "dataset_"
args_schema: type[BaseModel] = DatasetMultiRetrieverToolInput
description: str = "dataset multi retriever and rerank. "
dataset_ids: list[str]
reranking_provider_name: str
reranking_model_name: str
@classmethod
def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs):
return cls(
name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs
)
def _run(self, query: str) -> str:
threads = []
all_documents = []
for dataset_id in self.dataset_ids:
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(),
"dataset_id": dataset_id,
"query": query,
"all_documents": all_documents,
"hit_callbacks": self.hit_callbacks,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
# do rerank for searched documents
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.reranking_provider_name,
model_type=ModelType.RERANK,
model=self.reranking_model_name,
)
rerank_runner = RerankModelRunner(rerank_model_instance)
all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(all_documents)
document_score_list = {}
for item in all_documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
resource_number = 1
for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),
}
if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
context_list.append(source)
resource_number += 1
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
def _retriever(
self,
flask_app: Flask,
dataset_id: str,
query: str,
all_documents: list,
hit_callbacks: list[DatasetIndexToolCallbackHandler],
):
with flask_app.app_context():
dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
)
if not dataset:
return []
for hit_callback in hit_callbacks:
hit_callback.on_query(query, dataset.id)
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search",
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
)
if documents:
all_documents.extend(documents)
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)
all_documents.extend(documents)

View File

@ -0,0 +1,33 @@
from abc import abstractmethod
from typing import Any, Optional
from msal_extensions.persistence import ABC
from pydantic import BaseModel, ConfigDict
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
class DatasetRetrieverBaseTool(BaseModel, ABC):
"""Tool for querying a Dataset."""
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
top_k: int = 2
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
def _run(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Use the tool.
Add run_manager: Optional[CallbackManagerForToolRun] = None
to child implementations to enable tracing,
"""

View File

@ -0,0 +1,149 @@
from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"reranking_mode": "reranking_model",
"top_k": 2,
"score_threshold_enabled": False,
}
class DatasetRetrieverToolInput(BaseModel):
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
class DatasetRetrieverTool(DatasetRetrieverBaseTool):
"""Tool for querying a Dataset."""
name: str = "dataset"
args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
dataset_id: str
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description
if not description:
description = "useful for when you want to answer queries about the " + dataset.name
description = description.replace("\n", "").replace("\r", "")
return cls(
name=f"dataset_{dataset.id.replace('-', '_')}",
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
**kwargs,
)
def _run(self, query: str) -> str:
dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
)
if not dataset:
return ""
for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
)
return str("\n".join([document.page_content for document in documents]))
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)
else:
documents = []
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
resource_number = 1
for segment in sorted_segments:
context = {}
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),
}
if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
context_list.append(source)
resource_number += 1
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))