Merge main

This commit is contained in:
Yeuoly
2024-09-10 14:05:20 +08:00
650 changed files with 15950 additions and 4747 deletions

View File

@ -8,7 +8,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrival_methods import RetrievalMethod
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(retrival_method='keyword_search',
documents = RetrievalService.retrieve(retrieval_method='keyword_search',
dataset_id=dataset.id,
query=query,
top_k=self.top_k
@ -173,7 +173,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'],
dataset_id=dataset.id,
query=query,
top_k=self.top_k,

View File

@ -2,7 +2,7 @@
from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval.retrival_methods import RetrievalMethod
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@ -63,7 +63,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(retrival_method='keyword_search',
documents = RetrievalService.retrieve(retrieval_method='keyword_search',
dataset_id=dataset.id,
query=query,
top_k=self.top_k
@ -72,7 +72,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
documents = RetrievalService.retrieve(retrieval_method=retrieval_model.get('search_method', 'semantic_search'),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,

View File

@ -18,7 +18,7 @@ from core.tools.tool.tool import Tool
class DatasetRetrieverTool(Tool):
retrival_tool: DatasetRetrieverBaseTool
retrieval_tool: DatasetRetrieverBaseTool
@staticmethod
def get_dataset_tools(tenant_id: str,
@ -43,7 +43,7 @@ class DatasetRetrieverTool(Tool):
# Agent only support SINGLE mode
original_retriever_mode = retrieve_config.retrieve_strategy
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
retrival_tools = feature.to_dataset_retriever_tool(
retrieval_tools = feature.to_dataset_retriever_tool(
tenant_id=tenant_id,
dataset_ids=dataset_ids,
retrieve_config=retrieve_config,
@ -51,20 +51,23 @@ class DatasetRetrieverTool(Tool):
invoke_from=invoke_from,
hit_callback=hit_callback
)
if retrieval_tools is None or len(retrieval_tools) == 0:
return []
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
# convert retrival tools to Tools
# convert retrieval tools to Tools
tools = []
for retrival_tool in retrival_tools:
for retrieval_tool in retrieval_tools:
tool = DatasetRetrieverTool(
retrival_tool=retrival_tool,
identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')),
retrieval_tool=retrieval_tool,
identity=ToolIdentity(provider='', author='', name=retrieval_tool.name, label=I18nObject(en_US='', zh_Hans='')),
parameters=[],
is_team_authorization=True,
description=ToolDescription(
human=I18nObject(en_US='', zh_Hans=''),
llm=retrival_tool.description),
llm=retrieval_tool.description),
runtime=DatasetRetrieverTool.Runtime()
)
@ -96,8 +99,7 @@ class DatasetRetrieverTool(Tool):
yield self.create_text_message(text='please input query')
else:
# invoke dataset retriever tool
result = self.retrival_tool._run(query=query)
result = self.retrieval_tool._run(query=query)
yield self.create_text_message(text=result)
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: