refactor: tool

This commit is contained in:
Yeuoly
2024-09-20 02:25:14 +08:00
parent c472ea6c67
commit 661392eaef
524 changed files with 338 additions and 31279 deletions

View File

@ -0,0 +1,3 @@
- code
- time
- qrcode

View File

@ -0,0 +1,159 @@
from abc import abstractmethod
from os import listdir, path
from typing import Any
from pydantic import Field
from core.entities.provider_entities import ProviderConfig
from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
)
from core.tools.utils.yaml_utils import load_yaml_file
class BuiltinToolProviderController(ToolProviderController):
tools: list[BuiltinTool] = Field(default_factory=list)
def __init__(self, **data: Any) -> None:
if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}:
super().__init__(**data)
return
# load provider yaml
provider = self.__class__.__module__.split(".")[-1]
yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml")
try:
provider_yaml = load_yaml_file(yaml_path, ignore_error=False)
except Exception as e:
raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}")
if "credentials_for_provider" in provider_yaml and provider_yaml["credentials_for_provider"] is not None:
# set credentials name
for credential_name in provider_yaml["credentials_for_provider"]:
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
super().__init__(**{
'identity': provider_yaml['identity'],
'credentials_schema': provider_yaml.get('credentials_for_provider', {}) or {},
})
def _get_builtin_tools(self) -> list[BuiltinTool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
if self.tools:
return self.tools
provider = self.identity.name
tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")
# get all the yaml files in the tool path
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
tools = []
for tool_file in tool_files:
# get tool name
tool_name = tool_file.split(".")[0]
tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
# get tool class, import the module
assistant_tool_class = load_single_subclass_from_source(
module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
script_path=path.join(
path.dirname(path.realpath(__file__)),
"builtin_tool", "providers", provider, "tools", f"{tool_name}.py"
),
parent_type=BuiltinTool,
)
tool["identity"]["provider"] = provider
tools.append(assistant_tool_class(**tool))
self.tools = tools
return tools
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
if not self.credentials_schema:
return {}
return self.credentials_schema.copy()
def get_tools(self) -> list[BuiltinTool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
return self._get_builtin_tools()
def get_tool(self, tool_name: str) -> BuiltinTool | None:
"""
returns the tool that the provider can provide
"""
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
@property
def need_credentials(self) -> bool:
"""
returns whether the provider needs credentials
:return: whether the provider needs credentials
"""
return self.credentials_schema is not None and len(self.credentials_schema) != 0
@property
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
:return: type of the provider
"""
return ToolProviderType.BUILT_IN
@property
def tool_labels(self) -> list[str]:
"""
returns the labels of the provider
:return: labels of the provider
"""
label_enums = self._get_tool_labels()
return [default_tool_label_dict[label].name for label in label_enums]
def _get_tool_labels(self) -> list[ToolLabelEnum]:
"""
returns the labels of the provider
"""
return self.identity.tags or []
def validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
:param tool_name: the name of the tool, defined in `get_tools`
:param credentials: the credentials of the tool
"""
# validate credentials format
self.validate_credentials_format(credentials)
# validate credentials
self._validate_credentials(credentials)
@abstractmethod
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
:param tool_name: the name of the tool, defined in `get_tools`
:param credentials: the credentials of the tool
"""
pass

View File

@ -0,0 +1,20 @@
import os.path
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
from core.tools.entities.api_entities import UserToolProvider
class BuiltinToolProviderSort:
_position = {}
@classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position:
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), ".."))
def name_func(provider: UserToolProvider) -> str:
return provider.name
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
return sorted_providers

View File

@ -0,0 +1 @@
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg" class="w-3.5 h-3.5" data-icon="Code" aria-hidden="true"><g id="icons/code"><path id="Vector (Stroke)" fill-rule="evenodd" clip-rule="evenodd" d="M8.32593 1.69675C8.67754 1.78466 8.89132 2.14096 8.80342 2.49257L6.47009 11.8259C6.38218 12.1775 6.02588 12.3913 5.67427 12.3034C5.32265 12.2155 5.10887 11.8592 5.19678 11.5076L7.53011 2.17424C7.61801 1.82263 7.97431 1.60885 8.32593 1.69675ZM3.96414 4.20273C4.22042 4.45901 4.22042 4.87453 3.96413 5.13081L2.45578 6.63914C2.45577 6.63915 2.45578 6.63914 2.45578 6.63914C2.25645 6.83851 2.25643 7.16168 2.45575 7.36103C2.45574 7.36103 2.45576 7.36104 2.45575 7.36103L3.96413 8.86936C4.22041 9.12564 4.22042 9.54115 3.96414 9.79744C3.70787 10.0537 3.29235 10.0537 3.03607 9.79745L1.52769 8.28913C0.815811 7.57721 0.815803 6.42302 1.52766 5.7111L3.03606 4.20272C3.29234 3.94644 3.70786 3.94644 3.96414 4.20273ZM10.0361 4.20273C10.2923 3.94644 10.7078 3.94644 10.9641 4.20272L12.4725 5.71108C13.1843 6.423 13.1844 7.57717 12.4725 8.28909L10.9641 9.79745C10.7078 10.0537 10.2923 10.0537 10.036 9.79744C9.77977 9.54115 9.77978 9.12564 10.0361 8.86936L11.5444 7.36107C11.7437 7.16172 11.7438 6.83854 11.5444 6.63917C11.5444 6.63915 11.5445 6.63918 11.5444 6.63917L10.0361 5.13081C9.77978 4.87453 9.77978 4.45901 10.0361 4.20273Z" fill="currentColor"></path></g></svg>

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

@ -0,0 +1,8 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController
class CodeToolProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
pass

View File

@ -0,0 +1,15 @@
identity:
author: Dify
name: code
label:
en_US: Code Interpreter
zh_Hans: 代码解释器
pt_BR: Interpretador de Código
description:
en_US: Run a piece of code and get the result back.
zh_Hans: 运行一段代码并返回结果。
pt_BR: Execute um trecho de código e obtenha o resultado de volta.
icon: icon.svg
tags:
- productivity
credentials_for_provider:

View File

@ -0,0 +1,22 @@
from typing import Any
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
class SimpleCode(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
invoke simple code
"""
language = tool_parameters.get("language", CodeLanguage.PYTHON3)
code = tool_parameters.get("code", "")
if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
raise ValueError(f"Only python3 and javascript are supported, not {language}")
result = CodeExecutor.execute_code(language, "", code)
return self.create_text_message(result)

View File

@ -0,0 +1,51 @@
identity:
name: simple_code
author: Dify
label:
en_US: Code Interpreter
zh_Hans: 代码解释器
pt_BR: Interpretador de Código
description:
human:
en_US: Run code and get the result back. When you're using a lower quality model, please make sure there are some tips help LLM to understand how to write the code.
zh_Hans: 运行一段代码并返回结果。当您使用较低质量的模型时请确保有一些提示帮助LLM理解如何编写代码。
pt_BR: Execute um trecho de código e obtenha o resultado de volta. quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código.
llm: A tool for running code and getting the result back. Only native packages are allowed, network/IO operations are disabled. and you must use print() or console.log() to output the result or result will be empty.
parameters:
- name: language
type: string
required: true
label:
en_US: Language
zh_Hans: 语言
pt_BR: Idioma
human_description:
en_US: The programming language of the code
zh_Hans: 代码的编程语言
pt_BR: A linguagem de programação do código
llm_description: language of the code, only "python3" and "javascript" are supported
form: llm
options:
- value: python3
label:
en_US: Python3
zh_Hans: Python3
pt_BR: Python3
- value: javascript
label:
en_US: JavaScript
zh_Hans: JavaScript
pt_BR: JavaScript
- name: code
type: string
required: true
label:
en_US: Code
zh_Hans: 代码
pt_BR: Código
human_description:
en_US: The code to be executed
zh_Hans: 要执行的代码
pt_BR: O código a ser executado
llm_description: code to be executed, only native packages are allowed, network/IO operations are disabled.
form: llm

View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<svg width="800px" height="800px" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
<g>
<path fill="none" d="M0 0h24v24H0z"/>
<path d="M16 17v-1h-3v-3h3v2h2v2h-1v2h-2v2h-2v-3h2v-1h1zm5 4h-4v-2h2v-2h2v4zM3 3h8v8H3V3zm2 2v4h4V5H5zm8-2h8v8h-8V3zm2 2v4h4V5h-4zM3 13h8v8H3v-8zm2 2v4h4v-4H5zm13-2h3v2h-3v-2zM6 6h2v2H6V6zm0 10h2v2H6v-2zM16 6h2v2h-2V6z"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 428 B

View File

@ -0,0 +1,13 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
from core.tools.errors import ToolProviderCredentialValidationError
class QRCodeProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
QRCodeGeneratorTool().invoke(user_id="", tool_parameters={"content": "Dify 123 😊"})
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,14 @@
identity:
author: Bowen Liang
name: qrcode
label:
en_US: QRCode
zh_Hans: 二维码工具
pt_BR: QRCode
description:
en_US: A tool for generating QR code (quick-response code) image.
zh_Hans: 一个二维码工具
pt_BR: A tool for generating QR code (quick-response code) image.
icon: icon.svg
tags:
- utilities

View File

@ -0,0 +1,70 @@
import io
import logging
from typing import Any, Union
from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q
from qrcode.image.base import BaseImage
from qrcode.image.pure import PyPNGImage
from qrcode.main import QRCode
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
class QRCodeGeneratorTool(BuiltinTool):
error_correction_levels: dict[str, int] = {
"L": ERROR_CORRECT_L, # <=7%
"M": ERROR_CORRECT_M, # <=15%
"Q": ERROR_CORRECT_Q, # <=25%
"H": ERROR_CORRECT_H, # <=30%
}
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
# get text content
content = tool_parameters.get("content", "")
if not content:
return self.create_text_message("Invalid parameter content")
# get border size
border = tool_parameters.get("border", 0)
if border < 0 or border > 100:
return self.create_text_message("Invalid parameter border")
# get error_correction
error_correction = tool_parameters.get("error_correction", "")
if error_correction not in self.error_correction_levels:
return self.create_text_message("Invalid parameter error_correction")
try:
image = self._generate_qrcode(content, border, error_correction)
image_bytes = self._image_to_byte_array(image)
return self.create_blob_message(
blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
)
except Exception:
logging.exception(f"Failed to generate QR code for content: {content}")
return self.create_text_message("Failed to generate QR code")
def _generate_qrcode(self, content: str, border: int, error_correction: str) -> BaseImage:
qr = QRCode(
image_factory=PyPNGImage,
error_correction=self.error_correction_levels.get(error_correction),
border=border,
)
qr.add_data(data=content)
qr.make(fit=True)
img = qr.make_image()
return img
@staticmethod
def _image_to_byte_array(image: BaseImage) -> bytes:
byte_stream = io.BytesIO()
image.save(byte_stream)
return byte_stream.getvalue()

View File

@ -0,0 +1,76 @@
identity:
name: qrcode_generator
author: Bowen Liang
label:
en_US: Generate QR Code
zh_Hans: 生成二维码
pt_BR: Generate QR Code
description:
human:
en_US: A tool for generating QR code image
zh_Hans: 一个用于生成二维码的工具
pt_BR: A tool for generating QR code image
llm: A tool for generating QR code image
parameters:
- name: content
type: string
required: true
label:
en_US: content text for QR code
zh_Hans: 二维码文本内容
pt_BR: content text for QR code
human_description:
en_US: content text for QR code
zh_Hans: 二维码文本内容
pt_BR: 二维码文本内容
form: llm
- name: error_correction
type: select
required: true
default: M
label:
en_US: Error Correction
zh_Hans: 容错等级
pt_BR: Error Correction
human_description:
en_US: Error Correction in L, M, Q or H, from low to high, the bigger size of generated QR code with the better error correction effect
zh_Hans: 容错等级,可设置为低、中、偏高或高,从低到高,生成的二维码越大且容错效果越好
pt_BR: Error Correction in L, M, Q or H, from low to high, the bigger size of generated QR code with the better error correction effect
options:
- value: L
label:
en_US: Low
zh_Hans:
pt_BR: Low
- value: M
label:
en_US: Medium
zh_Hans:
pt_BR: Medium
- value: Q
label:
en_US: Quartile
zh_Hans: 偏高
pt_BR: Quartile
- value: H
label:
en_US: High
zh_Hans:
pt_BR: High
form: form
- name: border
type: number
required: true
default: 2
min: 0
max: 100
label:
en_US: border size
zh_Hans: 边框粗细
pt_BR: border size
human_description:
en_US: border sizedefault to 2
zh_Hans: 边框粗细的格数默认为2
pt_BR: border sizedefault to 2
llm: border size, default to 2
form: form

View File

@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path fill-rule="evenodd" clip-rule="evenodd" d="M0.666992 8.00008C0.666992 3.94999 3.95024 0.666748 8.00033 0.666748C12.0504 0.666748 15.3337 3.94999 15.3337 8.00008C15.3337 12.0502 12.0504 15.3334 8.00033 15.3334C3.95024 15.3334 0.666992 12.0502 0.666992 8.00008ZM8.66699 4.00008C8.66699 3.63189 8.36852 3.33341 8.00033 3.33341C7.63213 3.33341 7.33366 3.63189 7.33366 4.00008V8.00008C7.33366 8.2526 7.47633 8.48344 7.70218 8.59637L10.3688 9.9297C10.6982 10.0944 11.0986 9.96088 11.2633 9.63156C11.4279 9.30224 11.2945 8.90179 10.9651 8.73713L8.66699 7.58806V4.00008Z" fill="#EC4A0A"/>
</svg>

After

Width:  |  Height:  |  Size: 691 B

View File

@ -0,0 +1,16 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool
from core.tools.errors import ToolProviderCredentialValidationError
class WikiPediaProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
CurrentTimeTool().invoke(
user_id="",
tool_parameters={},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,15 @@
identity:
author: Dify
name: time
label:
en_US: CurrentTime
zh_Hans: 时间
pt_BR: CurrentTime
description:
en_US: A tool for getting the current time.
zh_Hans: 一个用于获取当前时间的工具。
pt_BR: A tool for getting the current time.
icon: icon.svg
tags:
- utilities
credentials_for_provider:

View File

@ -0,0 +1,29 @@
from datetime import datetime, timezone
from typing import Any, Union
from pytz import timezone as pytz_timezone
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
class CurrentTimeTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
# get timezone
tz = tool_parameters.get("timezone", "UTC")
fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z"
if tz == "UTC":
return self.create_text_message(f"{datetime.now(timezone.utc).strftime(fm)}")
try:
tz = pytz_timezone(tz)
except:
return self.create_text_message(f"Invalid timezone: {tz}")
return self.create_text_message(f"{datetime.now(tz).strftime(fm)}")

View File

@ -0,0 +1,131 @@
identity:
name: current_time
author: Dify
label:
en_US: Current Time
zh_Hans: 获取当前时间
pt_BR: Current Time
description:
human:
en_US: A tool for getting the current time.
zh_Hans: 一个用于获取当前时间的工具。
pt_BR: A tool for getting the current time.
llm: A tool for getting the current time.
parameters:
- name: format
type: string
required: false
label:
en_US: Format
zh_Hans: 格式
pt_BR: Format
human_description:
en_US: Time format in strftime standard.
zh_Hans: strftime 标准的时间格式。
pt_BR: Time format in strftime standard.
form: form
default: "%Y-%m-%d %H:%M:%S"
- name: timezone
type: select
required: false
label:
en_US: Timezone
zh_Hans: 时区
pt_BR: Timezone
human_description:
en_US: Timezone
zh_Hans: 时区
pt_BR: Timezone
form: form
default: UTC
options:
- value: UTC
label:
en_US: UTC
zh_Hans: UTC
pt_BR: UTC
- value: America/New_York
label:
en_US: America/New_York
zh_Hans: 美洲/纽约
pt_BR: America/New_York
- value: America/Los_Angeles
label:
en_US: America/Los_Angeles
zh_Hans: 美洲/洛杉矶
pt_BR: America/Los_Angeles
- value: America/Chicago
label:
en_US: America/Chicago
zh_Hans: 美洲/芝加哥
pt_BR: America/Chicago
- value: America/Sao_Paulo
label:
en_US: America/Sao_Paulo
zh_Hans: 美洲/圣保罗
pt_BR: América/São Paulo
- value: Asia/Shanghai
label:
en_US: Asia/Shanghai
zh_Hans: 亚洲/上海
pt_BR: Asia/Shanghai
- value: Asia/Ho_Chi_Minh
label:
en_US: Asia/Ho_Chi_Minh
zh_Hans: 亚洲/胡志明市
pt_BR: Ásia/Ho Chi Minh
- value: Asia/Tokyo
label:
en_US: Asia/Tokyo
zh_Hans: 亚洲/东京
pt_BR: Asia/Tokyo
- value: Asia/Dubai
label:
en_US: Asia/Dubai
zh_Hans: 亚洲/迪拜
pt_BR: Asia/Dubai
- value: Asia/Kolkata
label:
en_US: Asia/Kolkata
zh_Hans: 亚洲/加尔各答
pt_BR: Asia/Kolkata
- value: Asia/Seoul
label:
en_US: Asia/Seoul
zh_Hans: 亚洲/首尔
pt_BR: Asia/Seoul
- value: Asia/Singapore
label:
en_US: Asia/Singapore
zh_Hans: 亚洲/新加坡
pt_BR: Asia/Singapore
- value: Europe/London
label:
en_US: Europe/London
zh_Hans: 欧洲/伦敦
pt_BR: Europe/London
- value: Europe/Berlin
label:
en_US: Europe/Berlin
zh_Hans: 欧洲/柏林
pt_BR: Europe/Berlin
- value: Europe/Moscow
label:
en_US: Europe/Moscow
zh_Hans: 欧洲/莫斯科
pt_BR: Europe/Moscow
- value: Australia/Sydney
label:
en_US: Australia/Sydney
zh_Hans: 澳大利亚/悉尼
pt_BR: Australia/Sydney
- value: Pacific/Auckland
label:
en_US: Pacific/Auckland
zh_Hans: 太平洋/奥克兰
pt_BR: Pacific/Auckland
- value: Africa/Cairo
label:
en_US: Africa/Cairo
zh_Hans: 非洲/开罗
pt_BR: Africa/Cairo

View File

@ -0,0 +1,45 @@
import calendar
from datetime import datetime
from typing import Any, Union
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
class WeekdayTool(BuiltinTool):
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Calculate the day of the week for a given date
"""
year = tool_parameters.get("year")
month = tool_parameters.get("month")
if month is None:
raise ValueError("Month is required")
day = tool_parameters.get("day")
date_obj = self.convert_datetime(year, month, day)
if not date_obj:
return self.create_text_message(f"Invalid date: Year {year}, Month {month}, Day {day}.")
weekday_name = calendar.day_name[date_obj.weekday()]
month_name = calendar.month_name[month]
readable_date = f"{month_name} {date_obj.day}, {date_obj.year}"
return self.create_text_message(f"{readable_date} is {weekday_name}.")
@staticmethod
def convert_datetime(year, month, day) -> datetime | None:
try:
# allowed range in datetime module
if not (year >= 1 and 1 <= month <= 12 and 1 <= day <= 31):
return None
year = int(year)
month = int(month)
day = int(day)
return datetime(year, month, day)
except ValueError:
return None

View File

@ -0,0 +1,42 @@
identity:
name: weekday
author: Bowen Liang
label:
en_US: Weekday Calculator
zh_Hans: 星期几计算器
description:
human:
en_US: A tool for calculating the weekday of a given date.
zh_Hans: 计算指定日期为星期几的工具。
llm: A tool for calculating the weekday of a given date by year, month and day.
parameters:
- name: year
type: number
required: true
form: llm
label:
en_US: Year
zh_Hans:
human_description:
en_US: Year
zh_Hans:
- name: month
type: number
required: true
form: llm
label:
en_US: Month
zh_Hans:
human_description:
en_US: Month
zh_Hans:
- name: day
type: number
required: true
form: llm
label:
en_US: day
zh_Hans:
human_description:
en_US: day
zh_Hans:

View File

@ -0,0 +1,131 @@
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.tools.utils.web_reader_tool import get_url
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
retain the original meaning and keep the key points.
however, the text you got is too long, what you got is possible a part of the text.
Please summarize the text you got.
"""
class BuiltinTool(Tool):
"""
Builtin tool
:param meta: the meta data of a tool call processing
"""
def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult:
"""
invoke model
:param model_config: the model config
:param prompt_messages: the prompt messages
:param stop: the stop words
:return: the model result
"""
# invoke model
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id or "",
tool_type="builtin",
tool_name=self.identity.name,
prompt_messages=prompt_messages,
)
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.BUILT_IN
def get_max_tokens(self) -> int:
"""
get max tokens
:param model_config: the model config
:return: the max tokens
"""
return ModelInvocationUtils.get_max_llm_context_tokens(
tenant_id=self.runtime.tenant_id,
)
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
"""
get prompt tokens
:param prompt_messages: the prompt messages
:return: the tokens
"""
return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
def summary(self, user_id: str, content: str) -> str:
max_tokens = self.get_max_tokens()
if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=content)]) < max_tokens * 0.6:
return content
def get_prompt_tokens(content: str) -> int:
return self.get_prompt_tokens(
prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)]
)
def summarize(content: str) -> str:
summary = self.invoke_model(
user_id=user_id,
prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)],
stop=[],
)
return summary.message.content
lines = content.split("\n")
new_lines = []
# split long line into multiple lines
for i in range(len(lines)):
line = lines[i]
if not line.strip():
continue
if len(line) < max_tokens * 0.5:
new_lines.append(line)
elif get_prompt_tokens(line) > max_tokens * 0.7:
while get_prompt_tokens(line) > max_tokens * 0.7:
new_lines.append(line[: int(max_tokens * 0.5)])
line = line[int(max_tokens * 0.5) :]
new_lines.append(line)
else:
new_lines.append(line)
# merge lines into messages with max tokens
messages: list[str] = []
for i in new_lines:
if len(messages) == 0:
messages.append(i)
else:
if len(messages[-1]) + len(i) < max_tokens * 0.5:
messages[-1] += i
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
messages.append(i)
else:
messages[-1] += i
summaries = []
for i in range(len(messages)):
message = messages[i]
summary = summarize(message)
summaries.append(summary)
result = "\n".join(summaries)
if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=result)]) > max_tokens * 0.7:
return self.summary(user_id=user_id, content=result)
return result
def get_url(self, url: str, user_agent: str | None = None) -> str:
"""
get url
"""
return get_url(url, user_agent=user_agent)