Merge branch 'main' into fix/chore-fix

This commit is contained in:
Yeuoly
2024-12-24 21:28:56 +08:00
734 changed files with 7911 additions and 5007 deletions

View File

@ -5,7 +5,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity
class BuiltinToolProviderSort:
_position = {}
_position: dict[str, int] = {}
@classmethod
def sort(cls, providers: list[ToolProviderApiEntity]) -> list[ToolProviderApiEntity]:

View File

@ -23,8 +23,10 @@ class TTSTool(BuiltinTool):
provider, model = tool_parameters.get("model").split("#") # type: ignore
voice = tool_parameters.get(f"voice#{provider}#{model}")
model_manager = ModelManager()
if not self.runtime:
raise ValueError("Runtime is required")
model_instance = model_manager.get_model_instance(
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
provider=provider,
model_type=ModelType.TTS,
model=model,
@ -47,8 +49,11 @@ class TTSTool(BuiltinTool):
)
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
if not self.runtime:
raise ValueError("Runtime is required")
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
tid: str = self.runtime.tenant_id or ""
models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts")
items = []
for provider_model in models:
provider = provider_model.provider
@ -68,6 +73,8 @@ class TTSTool(BuiltinTool):
ToolParameter(
name=f"voice#{provider}#{model}",
label=I18nObject(en_US=f"Voice of {model}({provider})"),
human_description=I18nObject(en_US=f"Select a voice for {model} model"),
placeholder=I18nObject(en_US="Select a voice"),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
options=[
@ -89,6 +96,7 @@ class TTSTool(BuiltinTool):
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=True,
placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"),
options=options,
),
)

View File

@ -49,9 +49,12 @@ class BuiltinTool(Tool):
:return: the model result
"""
# invoke model
if self.runtime is None or self.identity is None:
raise ValueError("runtime and identity are required")
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
tool_type="builtin",
tool_name=self.entity.identity.name,
prompt_messages=prompt_messages,
@ -67,8 +70,11 @@ class BuiltinTool(Tool):
:param model_config: the model config
:return: the max tokens
"""
if self.runtime is None:
raise ValueError("runtime is required")
return ModelInvocationUtils.get_max_llm_context_tokens(
tenant_id=self.runtime.tenant_id,
tenant_id=self.runtime.tenant_id or "",
)
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
@ -78,7 +84,12 @@ class BuiltinTool(Tool):
:param prompt_messages: the prompt messages
:return: the tokens
"""
return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
if self.runtime is None:
raise ValueError("runtime is required")
return ModelInvocationUtils.calculate_tokens(
tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
)
def summary(self, user_id: str, content: str) -> str:
max_tokens = self.get_max_tokens()
@ -120,16 +131,16 @@ class BuiltinTool(Tool):
# merge lines into messages with max tokens
messages: list[str] = []
for i in new_lines:
for j in new_lines:
if len(messages) == 0:
messages.append(i)
messages.append(j)
else:
if len(messages[-1]) + len(i) < max_tokens * 0.5:
messages[-1] += i
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
messages.append(i)
if len(messages[-1]) + len(j) < max_tokens * 0.5:
messages[-1] += j
if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
messages.append(j)
else:
messages[-1] += i
messages[-1] += j
summaries = []
for i in range(len(messages)):

View File

@ -130,7 +130,7 @@ class ApiToolProviderController(ToolProviderController):
runtime=ToolRuntime(tenant_id=self.tenant_id),
)
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
def load_bundled_tools(self, tools: list[ApiToolBundle]):
"""
load bundled tools
@ -151,6 +151,8 @@ class ApiToolProviderController(ToolProviderController):
"""
if len(self.tools) > 0:
return self.tools
if self.identity is None:
return None
tools: list[ApiTool] = []
@ -170,7 +172,7 @@ class ApiToolProviderController(ToolProviderController):
self.tools = tools
return tools
def get_tool(self, tool_name: str) -> ApiTool:
def get_tool(self, tool_name: str):
"""
get tool by name

View File

@ -40,6 +40,8 @@ class ApiTool(Tool):
:param meta: the meta data of a tool call processing, tenant_id is required
:return: the new tool
"""
if self.api_bundle is None:
raise ValueError("api_bundle is required")
return self.__class__(
entity=self.entity,
api_bundle=self.api_bundle.model_copy(),
@ -67,10 +69,12 @@ class ApiTool(Tool):
return ToolProviderType.API
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
if self.runtime == None:
if self.runtime is None:
raise ToolProviderCredentialValidationError("runtime not initialized")
headers = {}
if self.runtime is None:
raise ValueError("runtime is required")
credentials = self.runtime.credentials or {}
if "auth_type" not in credentials:
@ -121,9 +125,9 @@ class ApiTool(Tool):
response = response.json()
try:
return json.dumps(response, ensure_ascii=False)
except Exception as e:
except Exception:
return json.dumps(response)
except Exception as e:
except Exception:
return response.text
else:
raise ValueError(f"Invalid response type {type(response)}")
@ -147,7 +151,8 @@ class ApiTool(Tool):
params = {}
path_params = {}
body = {}
# FIXME: body should be a dict[str, Any] but it changed a lot in this function
body: Any = {}
cookies = {}
files = []
@ -208,7 +213,7 @@ class ApiTool(Tool):
body = body
if method in {"get", "head", "post", "put", "delete", "patch"}:
response = getattr(ssrf_proxy, method)(
response: httpx.Response = getattr(ssrf_proxy, method)(
url,
params=params,
headers=headers,
@ -291,7 +296,7 @@ class ApiTool(Tool):
raise ValueError(f"Invalid type {property['type']} for property {property}")
elif "anyOf" in property and isinstance(property["anyOf"], list):
return self._convert_body_property_any_of(property, value, property["anyOf"])
except ValueError as e:
except ValueError:
return value
def _invoke(
@ -305,6 +310,7 @@ class ApiTool(Tool):
"""
invoke http request
"""
response: httpx.Response | str = ""
# assemble request
headers = self.assembling_request(tool_parameters)

View File

@ -77,6 +77,8 @@ class ToolEngine:
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
# invoke the tool
if tool.identity is None:
raise ValueError("tool identity is not set")
try:
# hit the callback handler
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
@ -205,6 +207,8 @@ class ToolEngine:
"""
Invoke the tool with the given arguments.
"""
if tool.identity is None:
raise ValueError("tool identity is not set")
started_at = datetime.now(UTC)
meta = ToolInvokeMeta(
time_cost=0.0,
@ -250,7 +254,7 @@ class ToolEngine:
text = json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False)
result += f"tool response: {text}."
else:
result += f"tool response: {response.message}."
result += f"tool response: {response.message!r}."
return result

View File

@ -8,6 +8,8 @@ from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
@ -96,9 +98,8 @@ class ToolFileManager:
response = ssrf_proxy.get(file_url)
response.raise_for_status()
blob = response.content
except Exception as e:
logger.exception(f"Failed to download file from {file_url}")
raise
except httpx.TimeoutException:
raise ValueError(f"timeout when downloading file from {file_url}")
mimetype = guess_type(file_url)[0] or "octet/stream"
extension = guess_extension(mimetype) or ".bin"
@ -217,6 +218,6 @@ class ToolFileManager:
# init tool_file_parser
from core.file.tool_file_parser import tool_file_manager
from core.file.tool_file_parser import tool_file_manager # noqa: E402
tool_file_manager["manager"] = ToolFileManager

View File

@ -93,7 +93,7 @@ class ToolLabelManager:
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
)
tool_labels = {label.tool_id: [] for label in labels}
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
for label in labels:
tool_labels[label.tool_id].append(label.label_name)

View File

@ -4,16 +4,18 @@ import mimetypes
from collections.abc import Generator
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from yarl import URL
import contexts
from core.plugin.entities.plugin import GenericProviderID
from core.plugin.manager.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
@ -39,7 +41,7 @@ from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderType,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ProviderConfigEncrypter,
@ -57,7 +59,7 @@ class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers = {}
_builtin_providers_loaded = False
_builtin_tools_labels = {}
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
@ -140,6 +142,8 @@ class ToolManager:
"""
provider_controller = cls.get_builtin_provider(provider, tenant_id)
tool = provider_controller.get_tool(tool_name)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool
@ -266,6 +270,11 @@ class ToolManager:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
controller_tools: Optional[list[Tool]] = controller.get_tools(
user_id="", tenant_id=workflow_provider.tenant_id
)
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return cast(
WorkflowTool,
@ -333,6 +342,8 @@ class ToolManager:
identity_id=f"AGENT.{app_id}",
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
raise ValueError("runtime not found or runtime parameters not found")
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@ -583,9 +594,11 @@ class ToolManager:
# append builtin providers
for provider in builtin_providers:
# handle include, exclude
if provider.identity is None:
continue
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
data=provider,
name_func=lambda x: x.identity.name,
):
@ -609,7 +622,7 @@ class ToolManager:
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
)
api_provider_controllers = [
api_provider_controllers: list[dict[str, Any]] = [
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
for provider in db_api_providers
]
@ -632,7 +645,7 @@ class ToolManager:
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
)
workflow_provider_controllers = []
workflow_provider_controllers: list[WorkflowToolProviderController] = []
for provider in workflow_providers:
try:
workflow_provider_controllers.append(
@ -642,7 +655,9 @@ class ToolManager:
# app has been deleted
pass
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
labels = ToolLabelManager.get_tools_labels(
[cast(ToolProviderController, controller) for controller in workflow_provider_controllers]
)
for provider_controller in workflow_provider_controllers:
user_provider = ToolTransformService.workflow_provider_to_user_provider(
@ -693,7 +708,7 @@ class ToolManager:
get tool provider
"""
provider_name = provider
provider_obj: ApiToolProvider = (
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
@ -707,7 +722,7 @@ class ToolManager:
try:
credentials = json.loads(provider_obj.credentials_str) or {}
except:
except Exception:
credentials = {}
# package tool provider controller
@ -728,7 +743,7 @@ class ToolManager:
try:
icon = json.loads(provider_obj.icon)
except:
except Exception:
icon = {"background": "#252525", "content": "\ud83d\ude01"}
# add tool labels
@ -783,7 +798,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return json.loads(workflow_provider.icon)
except:
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
@ -799,7 +814,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
return json.loads(api_provider.icon)
except:
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
@ -824,7 +839,7 @@ class ToolManager:
if isinstance(provider, PluginToolProviderController):
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except:
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
return cls.generate_builtin_tool_icon_url(provider_id)
elif provider_type == ToolProviderType.API:
@ -836,7 +851,7 @@ class ToolManager:
if isinstance(provider, PluginToolProviderController):
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except:
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
raise ValueError(f"plugin provider {provider_id} not found")
else:

View File

@ -101,7 +101,7 @@ class ProviderConfigEncrypter(BaseModel):
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except:
except Exception:
pass
cache.set(data)
@ -221,6 +221,9 @@ class ToolParameterConfigurationManager:
return a deep copy of parameters with decrypted values
"""
if self.tool_runtime is None or self.tool_runtime.identity is None:
raise ValueError("tool_runtime is required")
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
@ -245,7 +248,7 @@ class ToolParameterConfigurationManager:
try:
has_secret_input = True
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except:
except Exception:
pass
if has_secret_input:

View File

@ -1,4 +1,5 @@
import threading
from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
@ -7,13 +8,14 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model = {
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -44,7 +46,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
def _run(self, query: str) -> str:
threads = []
all_documents = []
all_documents: list[RagDocument] = []
for dataset_id in self.dataset_ids:
retrieval_thread = threading.Thread(
target=self._retriever,
@ -77,8 +79,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
document_score_list = {}
for item in all_documents:
assert item.metadata
if item.metadata.get("score"):
if item.metadata and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
@ -87,7 +88,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.enabled is True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
@ -108,8 +109,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
Document.enabled is True,
Document.archived is False,
).first()
if dataset and document:
source = {
@ -140,6 +141,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
return ""
raise RuntimeError("not segments found")

View File

@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import Any, Optional
from msal_extensions.persistence import ABC
from msal_extensions.persistence import ABC # type: ignore
from pydantic import BaseModel, ConfigDict
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler

View File

@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService
@ -69,25 +71,27 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
metadata=external_document.get("metadata"),
provider="external",
)
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset.id
document.metadata["dataset_name"] = dataset.name
results.append(document)
if document.metadata is not None:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset.id
document.metadata["dataset_name"] = dataset.name
results.append(document)
# deal with external documents
context_list = []
for position, item in enumerate(results, start=1):
source = {
"position": position,
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": self.retriever_from,
"score": item.metadata.get("score"),
"title": item.metadata.get("title"),
"content": item.page_content,
}
if item.metadata is not None:
source = {
"position": position,
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": self.retriever_from,
"score": item.metadata.get("score"),
"title": item.metadata.get("title"),
"content": item.page_content,
}
context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
@ -95,7 +99,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return str("\n".join([item.page_content for item in results]))
else:
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
@ -113,11 +117,11 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
reranking_model=retrieval_model.get("reranking_model")
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
weights=retrieval_model.get("weights"),
)
else:
documents = []
@ -127,7 +131,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata.get("score"):
if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in documents]
@ -155,20 +159,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
context_list = []
resource_number = 1
for segment in sorted_segments:
context = {}
document = Document.query.filter(
document_segment = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
if not document_segment:
continue
if dataset and document_segment:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"document_id": document_segment.id,
"document_name": document_segment.name,
"data_source_type": document_segment.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),

View File

@ -94,6 +94,7 @@ class DatasetRetrieverTool(Tool):
llm_description="Query for the dataset to be used to retrieve the dataset.",
required=True,
default="",
placeholder=I18nObject(en_US="", zh_Hans=""),
),
]
@ -112,7 +113,9 @@ class DatasetRetrieverTool(Tool):
result = self.retrieval_tool._run(query=query)
yield self.create_text_message(text=result)
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
def validate_credentials(
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
) -> str | None:
"""
validate the credentials for dataset retriever tool
"""

View File

@ -5,7 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models.
"""
import json
from typing import cast
from typing import Optional, cast
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
@ -51,7 +51,7 @@ class ModelInvocationUtils:
if not schema:
raise InvokeModelError("No model schema found")
max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
if max_tokens is None:
return 2048
@ -133,14 +133,17 @@ class ModelInvocationUtils:
db.session.commit()
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
response: LLMResult = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
),
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")

View File

@ -5,7 +5,7 @@ from json import loads as json_loads
from json.decoder import JSONDecodeError
from requests import get
from yaml import YAMLError, safe_load
from yaml import YAMLError, safe_load # type: ignore
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
@ -63,6 +63,9 @@ class ApiBasedToolSchemaParser:
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
@ -107,6 +110,9 @@ class ApiBasedToolSchemaParser:
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
@ -157,9 +163,9 @@ class ApiBasedToolSchemaParser:
return bundles
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
parameter = parameter or {}
typ = None
typ: Optional[str] = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE
@ -174,6 +180,8 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
else:
return None
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
@ -236,7 +244,8 @@ class ApiBasedToolSchemaParser:
if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
openapi["paths"][path][method] = {
"operationId": operation["operationId"],

View File

@ -9,13 +9,13 @@ import tempfile
import unicodedata
from contextlib import contextmanager
from pathlib import Path
from typing import Optional
from typing import Any, Literal, Optional, cast
from urllib.parse import unquote
import chardet
import cloudscraper
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from regex import regex
import cloudscraper # type: ignore
from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore
from regex import regex # type: ignore
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
@ -68,7 +68,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return ExtractProcessor.load_from_url(url, return_text=True)
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
@ -125,7 +125,7 @@ def extract_using_readabilipy(html):
os.unlink(article_json_path)
os.unlink(html_path)
article_json = {
article_json: dict[str, Any] = {
"title": None,
"byline": None,
"date": None,
@ -300,7 +300,7 @@ def strip_control_characters(text):
def normalize_unicode(text):
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
normal_form = "NFKC"
normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
text = unicodedata.normalize(normal_form, text)
return text
@ -332,6 +332,7 @@ def add_content_digest(element):
def content_digest(element):
digest: Any
if is_text(element):
# Hash
trimmed_string = element.string.strip()

View File

@ -7,7 +7,7 @@ from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
) -> None:
) -> bool:
"""
check is synced

View File

@ -2,7 +2,7 @@ import logging
from pathlib import Path
from typing import Any
import yaml
import yaml # type: ignore
from yaml import YAMLError
logger = logging.getLogger(__name__)

View File

@ -18,6 +18,7 @@ from core.tools.entities.tool_entities import (
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.tool.tool import Tool
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
@ -130,6 +131,7 @@ class WorkflowToolProviderController(ToolProviderController):
llm_description=parameter.description,
required=variable.required,
options=options,
placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
elif features.file_upload:
@ -142,6 +144,7 @@ class WorkflowToolProviderController(ToolProviderController):
llm_description=parameter.description,
required=False,
form=parameter.form,
placeholder=I18nObject(en_US="", zh_Hans=""),
)
)
else:
@ -198,6 +201,8 @@ class WorkflowToolProviderController(ToolProviderController):
if not db_providers:
return []
if not db_providers.app:
raise ValueError("app not found")
app = db_providers.app
if not app:
@ -207,7 +212,7 @@ class WorkflowToolProviderController(ToolProviderController):
return self.tools
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
def get_tool(self, tool_name: str) -> Optional[Tool]:
"""
get tool by name

View File

@ -102,7 +102,7 @@ class WorkflowTool(Tool):
raise Exception(data.get("error"))
outputs = data.get("outputs")
if outputs == None:
if outputs is None:
outputs = {}
else:
outputs, files = self._extract_files(outputs)