mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
Merge main into feat/plugin
This commit is contained in:
@ -145,7 +145,7 @@ class ApiTool(Tool):
|
||||
path_params[parameter['name']] = value
|
||||
|
||||
elif parameter['in'] == 'query':
|
||||
params[parameter['name']] = value
|
||||
if value !='': params[parameter['name']] = value
|
||||
|
||||
elif parameter['in'] == 'cookie':
|
||||
cookies[parameter['name']] = value
|
||||
|
||||
@ -177,10 +177,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
score_threshold=retrieval_model['score_threshold']
|
||||
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model['reranking_model']
|
||||
reranking_model=retrieval_model.get('reranking_model', None)
|
||||
if retrieval_model['reranking_enable'] else None,
|
||||
reranking_mode=retrieval_model.get('reranking_mode')
|
||||
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||
weights=retrieval_model.get('weights', None),
|
||||
)
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ default_retrieval_model = {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'reranking_mode': 'reranking_model',
|
||||
'top_k': 2,
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
@ -71,14 +72,16 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
else:
|
||||
if self.top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
||||
documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
score_threshold=retrieval_model['score_threshold']
|
||||
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model['reranking_model']
|
||||
reranking_model=retrieval_model.get('reranking_model', None)
|
||||
if retrieval_model['reranking_enable'] else None,
|
||||
reranking_mode=retrieval_model.get('reranking_mode')
|
||||
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||
weights=retrieval_model.get('weights', None),
|
||||
)
|
||||
else:
|
||||
|
||||
@ -2,13 +2,12 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileVar
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
@ -23,6 +22,9 @@ from core.tools.entities.tool_entities import (
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
class Tool(BaseModel, ABC):
|
||||
identity: Optional[ToolIdentity] = None
|
||||
@ -286,12 +288,17 @@ class Tool(BaseModel, ABC):
|
||||
:param image: the url of the image
|
||||
:return: the image message
|
||||
"""
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as)
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
|
||||
message=image,
|
||||
save_as=save_as)
|
||||
|
||||
def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.FILE_VAR, message='', meta={'file_var': file_var}, save_as=''
|
||||
)
|
||||
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
|
||||
message=None,
|
||||
meta={
|
||||
'file_var': file_var
|
||||
},
|
||||
save_as='')
|
||||
|
||||
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
@ -300,7 +307,9 @@ class Tool(BaseModel, ABC):
|
||||
:param link: the url of the link
|
||||
:return: the link message
|
||||
"""
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, message=link, save_as=save_as)
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=link,
|
||||
save_as=save_as)
|
||||
|
||||
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
@ -309,7 +318,11 @@ class Tool(BaseModel, ABC):
|
||||
:param text: the text
|
||||
:return: the text message
|
||||
"""
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as)
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=text,
|
||||
save_as=save_as
|
||||
)
|
||||
|
||||
def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
|
||||
@ -72,6 +72,7 @@ class WorkflowTool(Tool):
|
||||
yield self.create_file_var_message(file)
|
||||
|
||||
yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
|
||||
yield self.create_json_message(outputs)
|
||||
|
||||
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user