mirror of
https://github.com/langgenius/dify.git
synced 2026-05-01 16:08:04 +08:00
Merge branch 'main' into fix/chore-fix
This commit is contained in:
@ -5,7 +5,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
_position = {}
|
||||
_position: dict[str, int] = {}
|
||||
|
||||
@classmethod
|
||||
def sort(cls, providers: list[ToolProviderApiEntity]) -> list[ToolProviderApiEntity]:
|
||||
|
||||
@ -23,8 +23,10 @@ class TTSTool(BuiltinTool):
|
||||
provider, model = tool_parameters.get("model").split("#") # type: ignore
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}")
|
||||
model_manager = ModelManager()
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
provider=provider,
|
||||
model_type=ModelType.TTS,
|
||||
model=model,
|
||||
@ -47,8 +49,11 @@ class TTSTool(BuiltinTool):
|
||||
)
|
||||
|
||||
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
|
||||
tid: str = self.runtime.tenant_id or ""
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
|
||||
items = []
|
||||
for provider_model in models:
|
||||
provider = provider_model.provider
|
||||
@ -68,6 +73,8 @@ class TTSTool(BuiltinTool):
|
||||
ToolParameter(
|
||||
name=f"voice#{provider}#{model}",
|
||||
label=I18nObject(en_US=f"Voice of {model}({provider})"),
|
||||
human_description=I18nObject(en_US=f"Select a voice for {model} model"),
|
||||
placeholder=I18nObject(en_US="Select a voice"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
options=[
|
||||
@ -89,6 +96,7 @@ class TTSTool(BuiltinTool):
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=True,
|
||||
placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
|
||||
options=options,
|
||||
),
|
||||
)
|
||||
|
||||
@ -49,9 +49,12 @@ class BuiltinTool(Tool):
|
||||
:return: the model result
|
||||
"""
|
||||
# invoke model
|
||||
if self.runtime is None or self.identity is None:
|
||||
raise ValueError("runtime and identity are required")
|
||||
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tool_type="builtin",
|
||||
tool_name=self.entity.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
@ -67,8 +70,11 @@ class BuiltinTool(Tool):
|
||||
:param model_config: the model config
|
||||
:return: the max tokens
|
||||
"""
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
)
|
||||
|
||||
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
|
||||
@ -78,7 +84,12 @@ class BuiltinTool(Tool):
|
||||
:param prompt_messages: the prompt messages
|
||||
:return: the tokens
|
||||
"""
|
||||
return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
|
||||
return ModelInvocationUtils.calculate_tokens(
|
||||
tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
def summary(self, user_id: str, content: str) -> str:
|
||||
max_tokens = self.get_max_tokens()
|
||||
@ -120,16 +131,16 @@ class BuiltinTool(Tool):
|
||||
|
||||
# merge lines into messages with max tokens
|
||||
messages: list[str] = []
|
||||
for i in new_lines:
|
||||
for j in new_lines:
|
||||
if len(messages) == 0:
|
||||
messages.append(i)
|
||||
messages.append(j)
|
||||
else:
|
||||
if len(messages[-1]) + len(i) < max_tokens * 0.5:
|
||||
messages[-1] += i
|
||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
|
||||
messages.append(i)
|
||||
if len(messages[-1]) + len(j) < max_tokens * 0.5:
|
||||
messages[-1] += j
|
||||
if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
|
||||
messages.append(j)
|
||||
else:
|
||||
messages[-1] += i
|
||||
messages[-1] += j
|
||||
|
||||
summaries = []
|
||||
for i in range(len(messages)):
|
||||
|
||||
@ -130,7 +130,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
)
|
||||
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]):
|
||||
"""
|
||||
load bundled tools
|
||||
|
||||
@ -151,6 +151,8 @@ class ApiToolProviderController(ToolProviderController):
|
||||
"""
|
||||
if len(self.tools) > 0:
|
||||
return self.tools
|
||||
if self.identity is None:
|
||||
return None
|
||||
|
||||
tools: list[ApiTool] = []
|
||||
|
||||
@ -170,7 +172,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> ApiTool:
|
||||
def get_tool(self, tool_name: str):
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
|
||||
@ -40,6 +40,8 @@ class ApiTool(Tool):
|
||||
:param meta: the meta data of a tool call processing, tenant_id is required
|
||||
:return: the new tool
|
||||
"""
|
||||
if self.api_bundle is None:
|
||||
raise ValueError("api_bundle is required")
|
||||
return self.__class__(
|
||||
entity=self.entity,
|
||||
api_bundle=self.api_bundle.model_copy(),
|
||||
@ -67,10 +69,12 @@ class ApiTool(Tool):
|
||||
return ToolProviderType.API
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
if self.runtime == None:
|
||||
if self.runtime is None:
|
||||
raise ToolProviderCredentialValidationError("runtime not initialized")
|
||||
|
||||
headers = {}
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
credentials = self.runtime.credentials or {}
|
||||
|
||||
if "auth_type" not in credentials:
|
||||
@ -121,9 +125,9 @@ class ApiTool(Tool):
|
||||
response = response.json()
|
||||
try:
|
||||
return json.dumps(response, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return json.dumps(response)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return response.text
|
||||
else:
|
||||
raise ValueError(f"Invalid response type {type(response)}")
|
||||
@ -147,7 +151,8 @@ class ApiTool(Tool):
|
||||
|
||||
params = {}
|
||||
path_params = {}
|
||||
body = {}
|
||||
# FIXME: body should be a dict[str, Any] but it changed a lot in this function
|
||||
body: Any = {}
|
||||
cookies = {}
|
||||
files = []
|
||||
|
||||
@ -208,7 +213,7 @@ class ApiTool(Tool):
|
||||
body = body
|
||||
|
||||
if method in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
response = getattr(ssrf_proxy, method)(
|
||||
response: httpx.Response = getattr(ssrf_proxy, method)(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
@ -291,7 +296,7 @@ class ApiTool(Tool):
|
||||
raise ValueError(f"Invalid type {property['type']} for property {property}")
|
||||
elif "anyOf" in property and isinstance(property["anyOf"], list):
|
||||
return self._convert_body_property_any_of(property, value, property["anyOf"])
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
def _invoke(
|
||||
@ -305,6 +310,7 @@ class ApiTool(Tool):
|
||||
"""
|
||||
invoke http request
|
||||
"""
|
||||
response: httpx.Response | str = ""
|
||||
# assemble request
|
||||
headers = self.assembling_request(tool_parameters)
|
||||
|
||||
|
||||
@ -77,6 +77,8 @@ class ToolEngine:
|
||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||
|
||||
# invoke the tool
|
||||
if tool.identity is None:
|
||||
raise ValueError("tool identity is not set")
|
||||
try:
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
||||
@ -205,6 +207,8 @@ class ToolEngine:
|
||||
"""
|
||||
Invoke the tool with the given arguments.
|
||||
"""
|
||||
if tool.identity is None:
|
||||
raise ValueError("tool identity is not set")
|
||||
started_at = datetime.now(UTC)
|
||||
meta = ToolInvokeMeta(
|
||||
time_cost=0.0,
|
||||
@ -250,7 +254,7 @@ class ToolEngine:
|
||||
text = json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False)
|
||||
result += f"tool response: {text}."
|
||||
else:
|
||||
result += f"tool response: {response.message}."
|
||||
result += f"tool response: {response.message!r}."
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@ -8,6 +8,8 @@ from mimetypes import guess_extension, guess_type
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
@ -96,9 +98,8 @@ class ToolFileManager:
|
||||
response = ssrf_proxy.get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to download file from {file_url}")
|
||||
raise
|
||||
except httpx.TimeoutException:
|
||||
raise ValueError(f"timeout when downloading file from {file_url}")
|
||||
|
||||
mimetype = guess_type(file_url)[0] or "octet/stream"
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
@ -217,6 +218,6 @@ class ToolFileManager:
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
from core.file.tool_file_parser import tool_file_manager
|
||||
from core.file.tool_file_parser import tool_file_manager # noqa: E402
|
||||
|
||||
tool_file_manager["manager"] = ToolFileManager
|
||||
|
||||
@ -93,7 +93,7 @@ class ToolLabelManager:
|
||||
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
|
||||
)
|
||||
|
||||
tool_labels = {label.tool_id: [] for label in labels}
|
||||
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
|
||||
|
||||
for label in labels:
|
||||
tool_labels[label.tool_id].append(label.label_name)
|
||||
|
||||
@ -4,16 +4,18 @@ import mimetypes
|
||||
from collections.abc import Generator
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
@ -39,7 +41,7 @@ from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import (
|
||||
ProviderConfigEncrypter,
|
||||
@ -57,7 +59,7 @@ class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels = {}
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||
@ -140,6 +142,8 @@ class ToolManager:
|
||||
"""
|
||||
provider_controller = cls.get_builtin_provider(provider, tenant_id)
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
|
||||
return tool
|
||||
|
||||
@ -266,6 +270,11 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
|
||||
controller_tools: Optional[list[Tool]] = controller.get_tools(
|
||||
user_id="", tenant_id=workflow_provider.tenant_id
|
||||
)
|
||||
if controller_tools is None or len(controller_tools) == 0:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return cast(
|
||||
WorkflowTool,
|
||||
@ -333,6 +342,8 @@ class ToolManager:
|
||||
identity_id=f"AGENT.{app_id}",
|
||||
)
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
|
||||
raise ValueError("runtime not found or runtime parameters not found")
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
@ -583,9 +594,11 @@ class ToolManager:
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if provider.identity is None:
|
||||
continue
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
||||
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
||||
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
@ -609,7 +622,7 @@ class ToolManager:
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
api_provider_controllers = [
|
||||
api_provider_controllers: list[dict[str, Any]] = [
|
||||
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
|
||||
for provider in db_api_providers
|
||||
]
|
||||
@ -632,7 +645,7 @@ class ToolManager:
|
||||
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
)
|
||||
|
||||
workflow_provider_controllers = []
|
||||
workflow_provider_controllers: list[WorkflowToolProviderController] = []
|
||||
for provider in workflow_providers:
|
||||
try:
|
||||
workflow_provider_controllers.append(
|
||||
@ -642,7 +655,9 @@ class ToolManager:
|
||||
# app has been deleted
|
||||
pass
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
|
||||
labels = ToolLabelManager.get_tools_labels(
|
||||
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
|
||||
)
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
user_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
@ -693,7 +708,7 @@ class ToolManager:
|
||||
get tool provider
|
||||
"""
|
||||
provider_name = provider
|
||||
provider_obj: ApiToolProvider = (
|
||||
provider_obj: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
@ -707,7 +722,7 @@ class ToolManager:
|
||||
|
||||
try:
|
||||
credentials = json.loads(provider_obj.credentials_str) or {}
|
||||
except:
|
||||
except Exception:
|
||||
credentials = {}
|
||||
|
||||
# package tool provider controller
|
||||
@ -728,7 +743,7 @@ class ToolManager:
|
||||
|
||||
try:
|
||||
icon = json.loads(provider_obj.icon)
|
||||
except:
|
||||
except Exception:
|
||||
icon = {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
# add tool labels
|
||||
@ -783,7 +798,7 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return json.loads(workflow_provider.icon)
|
||||
except:
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
@ -799,7 +814,7 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
||||
|
||||
return json.loads(api_provider.icon)
|
||||
except:
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
@ -824,7 +839,7 @@ class ToolManager:
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except:
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
return cls.generate_builtin_tool_icon_url(provider_id)
|
||||
elif provider_type == ToolProviderType.API:
|
||||
@ -836,7 +851,7 @@ class ToolManager:
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except:
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
raise ValueError(f"plugin provider {provider_id} not found")
|
||||
else:
|
||||
|
||||
@ -101,7 +101,7 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cache.set(data)
|
||||
@ -221,6 +221,9 @@ class ToolParameterConfigurationManager:
|
||||
|
||||
return a deep copy of parameters with decrypted values
|
||||
"""
|
||||
if self.tool_runtime is None or self.tool_runtime.identity is None:
|
||||
raise ValueError("tool_runtime is required")
|
||||
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
@ -245,7 +248,7 @@ class ToolParameterConfigurationManager:
|
||||
try:
|
||||
has_secret_input = True
|
||||
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if has_secret_input:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel, Field
|
||||
@ -7,13 +8,14 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
default_retrieval_model = {
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
@ -44,7 +46,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
threads = []
|
||||
all_documents = []
|
||||
all_documents: list[RagDocument] = []
|
||||
for dataset_id in self.dataset_ids:
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
@ -77,8 +79,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
document_score_list = {}
|
||||
for item in all_documents:
|
||||
assert item.metadata
|
||||
if item.metadata.get("score"):
|
||||
if item.metadata and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
document_context_list = []
|
||||
@ -87,7 +88,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.enabled is True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
).all()
|
||||
|
||||
@ -108,8 +109,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
document = Document.query.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
Document.enabled is True,
|
||||
Document.archived is False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
source = {
|
||||
@ -140,6 +141,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
return ""
|
||||
|
||||
raise RuntimeError("not segments found")
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from msal_extensions.persistence import ABC
|
||||
from msal_extensions.persistence import ABC # type: ignore
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
@ -69,25 +71,27 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
metadata=external_document.get("metadata"),
|
||||
provider="external",
|
||||
)
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset.id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
if document.metadata is not None:
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset.id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
# deal with external documents
|
||||
context_list = []
|
||||
for position, item in enumerate(results, start=1):
|
||||
source = {
|
||||
"position": position,
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": item.metadata.get("score"),
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
if item.metadata is not None:
|
||||
source = {
|
||||
"position": position,
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": item.metadata.get("score"),
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
context_list.append(source)
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
@ -95,7 +99,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
return str("\n".join([item.page_content for item in results]))
|
||||
else:
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
@ -113,11 +117,11 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
reranking_model=retrieval_model.get("reranking_model")
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
weights=retrieval_model.get("weights"),
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
@ -127,7 +131,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
document_score_list = {}
|
||||
if dataset.indexing_technique != "economy":
|
||||
for item in documents:
|
||||
if item.metadata.get("score"):
|
||||
if item.metadata is not None and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata["doc_id"] for document in documents]
|
||||
@ -155,20 +159,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
context_list = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
context = {}
|
||||
document = Document.query.filter(
|
||||
document_segment = Document.query.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
if not document_segment:
|
||||
continue
|
||||
if dataset and document_segment:
|
||||
source = {
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"document_id": document_segment.id,
|
||||
"document_name": document_segment.name,
|
||||
"data_source_type": document_segment.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
|
||||
@ -94,6 +94,7 @@ class DatasetRetrieverTool(Tool):
|
||||
llm_description="Query for the dataset to be used to retrieve the dataset.",
|
||||
required=True,
|
||||
default="",
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
),
|
||||
]
|
||||
|
||||
@ -112,7 +113,9 @@ class DatasetRetrieverTool(Tool):
|
||||
result = self.retrieval_tool._run(query=query)
|
||||
yield self.create_text_message(text=result)
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||
def validate_credentials(
|
||||
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
|
||||
) -> str | None:
|
||||
"""
|
||||
validate the credentials for dataset retriever tool
|
||||
"""
|
||||
|
||||
@ -5,7 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
@ -51,7 +51,7 @@ class ModelInvocationUtils:
|
||||
if not schema:
|
||||
raise InvokeModelError("No model schema found")
|
||||
|
||||
max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
|
||||
max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
|
||||
if max_tokens is None:
|
||||
return 2048
|
||||
|
||||
@ -133,14 +133,17 @@ class ModelInvocationUtils:
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=False,
|
||||
user=user_id,
|
||||
callbacks=[],
|
||||
response: LLMResult = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=False,
|
||||
user=user_id,
|
||||
callbacks=[],
|
||||
),
|
||||
)
|
||||
except InvokeRateLimitError as e:
|
||||
raise InvokeModelError(f"Invoke rate limit error: {e}")
|
||||
|
||||
@ -5,7 +5,7 @@ from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
|
||||
from requests import get
|
||||
from yaml import YAMLError, safe_load
|
||||
from yaml import YAMLError, safe_load # type: ignore
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
@ -63,6 +63,9 @@ class ApiBasedToolSchemaParser:
|
||||
default=parameter["schema"]["default"]
|
||||
if "schema" in parameter and "default" in parameter["schema"]
|
||||
else None,
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
@ -107,6 +110,9 @@ class ApiBasedToolSchemaParser:
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
@ -157,9 +163,9 @@ class ApiBasedToolSchemaParser:
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
|
||||
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
|
||||
parameter = parameter or {}
|
||||
typ = None
|
||||
typ: Optional[str] = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
@ -174,6 +180,8 @@ class ApiBasedToolSchemaParser:
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
elif typ == "string":
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
@ -236,7 +244,8 @@ class ApiBasedToolSchemaParser:
|
||||
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||
"description" not in operation or len(operation["description"]) == 0
|
||||
):
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = {
|
||||
"operationId": operation["operationId"],
|
||||
|
||||
@ -9,13 +9,13 @@ import tempfile
|
||||
import unicodedata
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Literal, Optional, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import chardet
|
||||
import cloudscraper
|
||||
from bs4 import BeautifulSoup, CData, Comment, NavigableString
|
||||
from regex import regex
|
||||
import cloudscraper # type: ignore
|
||||
from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore
|
||||
from regex import regex # type: ignore
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.extractor import extract_processor
|
||||
@ -68,7 +68,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
|
||||
return "Unsupported content-type [{}] of URL.".format(main_content_type)
|
||||
|
||||
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
|
||||
return ExtractProcessor.load_from_url(url, return_text=True)
|
||||
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
|
||||
|
||||
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
elif response.status_code == 403:
|
||||
@ -125,7 +125,7 @@ def extract_using_readabilipy(html):
|
||||
os.unlink(article_json_path)
|
||||
os.unlink(html_path)
|
||||
|
||||
article_json = {
|
||||
article_json: dict[str, Any] = {
|
||||
"title": None,
|
||||
"byline": None,
|
||||
"date": None,
|
||||
@ -300,7 +300,7 @@ def strip_control_characters(text):
|
||||
|
||||
def normalize_unicode(text):
|
||||
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
|
||||
normal_form = "NFKC"
|
||||
normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
|
||||
text = unicodedata.normalize(normal_form, text)
|
||||
return text
|
||||
|
||||
@ -332,6 +332,7 @@ def add_content_digest(element):
|
||||
|
||||
|
||||
def content_digest(element):
|
||||
digest: Any
|
||||
if is_text(element):
|
||||
# Hash
|
||||
trimmed_string = element.string.strip()
|
||||
|
||||
@ -7,7 +7,7 @@ from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
) -> None:
|
||||
) -> bool:
|
||||
"""
|
||||
check is synced
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
import yaml # type: ignore
|
||||
from yaml import YAMLError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -18,6 +18,7 @@ from core.tools.entities.tool_entities import (
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
@ -130,6 +131,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
llm_description=parameter.description,
|
||||
required=variable.required,
|
||||
options=options,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
)
|
||||
elif features.file_upload:
|
||||
@ -142,6 +144,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
llm_description=parameter.description,
|
||||
required=False,
|
||||
form=parameter.form,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -198,6 +201,8 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
if not db_providers:
|
||||
return []
|
||||
if not db_providers.app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
app = db_providers.app
|
||||
if not app:
|
||||
@ -207,7 +212,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
|
||||
def get_tool(self, tool_name: str) -> Optional[Tool]:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
|
||||
@ -102,7 +102,7 @@ class WorkflowTool(Tool):
|
||||
raise Exception(data.get("error"))
|
||||
|
||||
outputs = data.get("outputs")
|
||||
if outputs == None:
|
||||
if outputs is None:
|
||||
outputs = {}
|
||||
else:
|
||||
outputs, files = self._extract_files(outputs)
|
||||
|
||||
Reference in New Issue
Block a user