Merge main into feat/plugin

This commit is contained in:
Yeuoly
2024-08-29 13:09:13 +08:00
1405 changed files with 48109 additions and 23346 deletions

View File

@ -145,7 +145,7 @@ class ApiTool(Tool):
path_params[parameter['name']] = value
elif parameter['in'] == 'query':
params[parameter['name']] = value
if value !='': params[parameter['name']] = value
elif parameter['in'] == 'cookie':
cookies[parameter['name']] = value

View File

@ -177,10 +177,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model['score_threshold']
score_threshold=retrieval_model.get('score_threshold', .0)
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model['reranking_model']
reranking_model=retrieval_model.get('reranking_model', None)
if retrieval_model['reranking_enable'] else None,
reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model',
weights=retrieval_model.get('weights', None),
)

View File

@ -14,6 +14,7 @@ default_retrieval_model = {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'reranking_mode': 'reranking_model',
'top_k': 2,
'score_threshold_enabled': False
}
@ -71,14 +72,16 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model['score_threshold']
score_threshold=retrieval_model.get('score_threshold', .0)
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model['reranking_model']
reranking_model=retrieval_model.get('reranking_model', None)
if retrieval_model['reranking_enable'] else None,
reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model',
weights=retrieval_model.get('weights', None),
)
else:

View File

@ -2,13 +2,12 @@ from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from copy import deepcopy
from enum import Enum
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileVar
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
@ -23,6 +22,9 @@ from core.tools.entities.tool_entities import (
from core.tools.tool_file_manager import ToolFileManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
if TYPE_CHECKING:
from core.file.file_obj import FileVar
class Tool(BaseModel, ABC):
identity: Optional[ToolIdentity] = None
@ -286,12 +288,17 @@ class Tool(BaseModel, ABC):
:param image: the url of the image
:return: the image message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as)
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
message=image,
save_as=save_as)
def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.FILE_VAR, message='', meta={'file_var': file_var}, save_as=''
)
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
message=None,
meta={
'file_var': file_var
},
save_as='')
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
"""
@ -300,7 +307,9 @@ class Tool(BaseModel, ABC):
:param link: the url of the link
:return: the link message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, message=link, save_as=save_as)
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
message=link,
save_as=save_as)
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
"""
@ -309,7 +318,11 @@ class Tool(BaseModel, ABC):
:param text: the text
:return: the text message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as)
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=text,
save_as=save_as
)
def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = '') -> ToolInvokeMessage:
"""

View File

@ -72,6 +72,7 @@ class WorkflowTool(Tool):
yield self.create_file_var_message(file)
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs)
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
"""