Feat/assistant app (#2086)

Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: Pascal M <11357019+perzeuss@users.noreply.github.com>
This commit is contained in:
Yeuoly
2024-01-23 19:58:23 +08:00
committed by GitHub
parent 7bbe12b2bd
commit 86286e1ac8
175 changed files with 11619 additions and 1235 deletions

25
api/core/tools/README.md Normal file
View File

@ -0,0 +1,25 @@
# Tools
This module implements built-in tools used in Agent Assistants and Workflows within Dify. You could define and display your own tools in this module, without modifying the frontend logic. This decoupling allows for easier horizontal scaling of Dify's capabilities.
## Feature Introduction
The tools provided for Agents and Workflows are currently divided into two categories:
- `Built-in Tools` are internally implemented within our product and are hardcoded for use in Agents and Workflows.
- `Api-Based Tools` leverage third-party APIs for implementation. You don't need to code to integrate these -- simply provide interface definitions in formats like `OpenAPI` , `Swagger`, or the `OpenAI-plugin` on the front-end.
### Built-in Tool Providers
![Alt text](docs/zh_Hans/images/index/image.png)
### API Tool Providers
![Alt text](docs/zh_Hans/images/index/image-1.png)
## Tool Integration
To enable developers to build flexible and powerful tools, we provide two guides:
### [Quick Integration 👈🏻](./docs/en_US/tool_scale_out.md)
Quick integration aims at quickly getting you up to speed with tool integration by walking over an example Google Search tool.
### [Advanced Integration 👈🏻](./docs/en_US/advanced_scale_out.md)
Advanced integration will offer a deeper dive into the module interfaces, and explain how to implement more complex capabilities, such as generating images, combining multiple tools, and managing the flow of parameters, images, and files between different tools.

View File

@ -0,0 +1,27 @@
# Tools
该模块提供了各Agent和Workflow中会使用的内置工具的调用、鉴权接口并为 Dify 提供了统一的工具供应商的信息和凭据表单规则。
- 一方面将工具和业务代码解耦,方便开发者对模型横向扩展,
- 另一方面提供了只需在后端定义供应商和工具,即可在前端页面直接展示,无需修改前端逻辑。
## 功能介绍
对于给Agent和Workflow提供的工具我们当前将其分为两类
- `Built-in Tools` 内置工具即Dify内部实现的工具通过硬编码的方式提供给Agent和Workflow使用。
- `Api-Based Tools` 基于API的工具即通过调用第三方API实现的工具`Api-Based Tool`不需要再额外定义,只需提供`OpenAPI` `Swagger` `OpenAI plugin`等接口文档即可。
### 内置工具供应商
![Alt text](docs/zh_Hans/images/index/image.png)
### API工具供应商
![Alt text](docs/zh_Hans/images/index/image-1.png)
## 工具接入
为了实现更灵活更强大的功能Tools提供了一系列的接口帮助开发者快速构建想要的工具本文作为开发者的入门指南将会以[快速接入](./docs/zh_Hans/tool_scale_out.md)和[高级接入](./docs/zh_Hans/advanced_scale_out.md)两部分介绍如何接入工具。
### [快速接入 👈🏻](./docs/zh_Hans/tool_scale_out.md)
快速接入可以帮助你在10~20分钟内完成工具的接入但是这种接入方式只能实现简单的功能如果你想要实现更复杂的功能可以参考下面的高级接入。
### [高级接入 👈🏻](./docs/zh_Hans/advanced_scale_out.md)
高级接入将介绍如何实现更复杂的功能配置,包括实现图生图、实现多个工具的组合、实现参数、图片、文件在多个工具之间的流转。

View File

@ -0,0 +1,266 @@
# Advanced Tool Integration
Before starting with this advanced guide, please make sure you have a basic understanding of the tool integration process in Dify. Check out [Quick Integration](./tool_scale_out.md) for a quick runthrough.
## Tool Interface
We have defined a series of helper methods in the `Tool` class to help developers quickly build more complex tools.
### Message Return
Dify supports various message types such as `text`, `link`, `image`, and `file BLOB`. You can return different types of messages to the LLM and users through the following interfaces.
Please note, some parameters in the following interfaces will be introduced in later sections.
#### Image URL
You only need to pass the URL of the image, and Dify will automatically download the image and return it to the user.
```python
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
"""
create an image message
:param image: the url of the image
:return: the image message
"""
```
#### Link
If you need to return a link, you can use the following interface.
```python
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a link message
:param link: the url of the link
:return: the link message
"""
```
#### Text
If you need to return a text message, you can use the following interface.
```python
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a text message
:param text: the text of the message
:return: the text message
"""
```
#### File BLOB
If you need to return the raw data of a file, such as images, audio, video, PPT, Word, Excel, etc., you can use the following interface.
- `blob` The raw data of the file, of bytes type
- `meta` The metadata of the file, if you know the type of the file, it is best to pass a `mime_type`, otherwise Dify will use `octet/stream` as the default type
```python
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
"""
create a blob message
:param blob: the blob
:return: the blob message
"""
```
### Shortcut Tools
In large model applications, we have two common needs:
- First, summarize a long text in advance, and then pass the summary content to the LLM to prevent the original text from being too long for the LLM to handle
- The content obtained by the tool is a link, and the web page information needs to be crawled before it can be returned to the LLM
To help developers quickly implement these two needs, we provide the following two shortcut tools.
#### Text Summary Tool
This tool takes in an user_id and the text to be summarized, and returns the summarized text. Dify will use the default model of the current workspace to summarize the long text.
```python
def summary(self, user_id: str, content: str) -> str:
"""
summary the content
:param user_id: the user id
:param content: the content
:return: the summary
"""
```
#### Web Page Crawling Tool
This tool takes in web page link to be crawled and a user_agent (which can be empty), and returns a string containing the information of the web page. The `user_agent` is an optional parameter that can be used to identify the tool. If not passed, Dify will use the default `user_agent`.
```python
def get_url(self, url: str, user_agent: str = None) -> str:
"""
get url
""" the crawled result
```
### Variable Pool
We have introduced a variable pool in `Tool` to store variables, files, etc. generated during the tool's operation. These variables can be used by other tools during the tool's operation.
Next, we will use `DallE3` and `Vectorizer.AI` as examples to introduce how to use the variable pool.
- `DallE3` is an image generation tool that can generate images based on text. Here, we will let `DallE3` generate a logo for a coffee shop
- `Vectorizer.AI` is a vector image conversion tool that can convert images into vector images, so that the images can be infinitely enlarged without distortion. Here, we will convert the PNG icon generated by `DallE3` into a vector image, so that it can be truly used by designers.
#### DallE3
First, we use DallE3. After creating the image, we save the image to the variable pool. The code is as follows:
```python
from typing import Any, Dict, List, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
from base64 import b64decode
from openai import OpenAI
class DallE3Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],
)
# prompt
prompt = tool_paramters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')
# call openapi dalle3
response = client.images.generate(
prompt=prompt, model='dall-e-3',
size='1024x1024', n=1, style='vivid', quality='standard',
response_format='b64_json'
)
result = []
for image in response.data:
# Save all images to the variable pool through the save_as parameter. The variable name is self.VARIABLE_KEY.IMAGE.value. If new images are generated later, they will overwrite the previous images.
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={ 'mime_type': 'image/png' },
save_as=self.VARIABLE_KEY.IMAGE.value))
return result
```
Note that we used `self.VARIABLE_KEY.IMAGE.value` as the variable name of the image. In order for developers' tools to cooperate with each other, we defined this `KEY`. You can use it freely, or you can choose not to use this `KEY`. Passing a custom KEY is also acceptable.
#### Vectorizer.AI
Next, we use Vectorizer.AI to convert the PNG icon generated by DallE3 into a vector image. Let's go through the functions we defined here. The code is as follows:
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter
from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union
from httpx import post
from base64 import b64decode
class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
Tool invocation, the image variable name needs to be passed in from here, so that we can get the image from the variable pool
"""
def get_runtime_parameters(self) -> List[ToolParamter]:
"""
Override the tool parameter list, we can dynamically generate the parameter list based on the actual situation in the current variable pool, so that the LLM can generate the form based on the parameter list
"""
def is_tool_avaliable(self) -> bool:
"""
Whether the current tool is available, if there is no image in the current variable pool, then we don't need to display this tool, just return False here
"""
```
Next, let's implement these three functions
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter
from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union
from httpx import post
from base64 import b64decode
class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
api_key_name = self.runtime.credentials.get('api_key_name', None)
api_key_value = self.runtime.credentials.get('api_key_value', None)
if not api_key_name or not api_key_value:
raise ToolProviderCredentialValidationError('Please input api key name and value')
# Get image_id, the definition of image_id can be found in get_runtime_parameters
image_id = tool_paramters.get('image_id', '')
if not image_id:
return self.create_text_message('Please input image id')
# Get the image generated by DallE from the variable pool
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
if not image_binary:
return self.create_text_message('Image not found, please request user to generate image firstly.')
# Generate vector image
response = post(
'https://vectorizer.ai/api/v1/vectorize',
files={ 'image': image_binary },
data={ 'mode': 'test' },
auth=(api_key_name, api_key_value),
timeout=30
)
if response.status_code != 200:
raise Exception(response.text)
return [
self.create_text_message('the vectorized svg is saved as an image.'),
self.create_blob_message(blob=response.content,
meta={'mime_type': 'image/svg+xml'})
]
def get_runtime_parameters(self) -> List[ToolParamter]:
"""
override the runtime parameters
"""
# Here, we override the tool parameter list, define the image_id, and set its option list to all images in the current variable pool. The configuration here is consistent with the configuration in yaml.
return [
ToolParamter.get_simple_instance(
name='image_id',
llm_description=f'the image id that you want to vectorize, \
and the image id should be specified in \
{[i.name for i in self.list_default_image_variables()]}',
type=ToolParamter.ToolParameterType.SELECT,
required=True,
options=[i.name for i in self.list_default_image_variables()]
)
]
def is_tool_avaliable(self) -> bool:
# Only when there are images in the variable pool, the LLM needs to use this tool
return len(self.list_default_image_variables()) > 0
```
It's worth noting that we didn't actually use `image_id` here. We assumed that there must be an image in the default variable pool when calling this tool, so we directly used `image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)` to get the image. In cases where the model's capabilities are weak, we recommend developers to do the same, which can effectively improve fault tolerance and avoid the model passing incorrect parameters.

View File

@ -0,0 +1,212 @@
# Quick Tool Integration
Here, we will use GoogleSearch as an example to demonstrate how to quickly integrate a tool.
## 1. Prepare the Tool Provider yaml
### Introduction
This yaml declares a new tool provider, and includes information like the provider's name, icon, author, and other details that are fetched by the frontend for display.
### Example
We need to create a `google` module (folder) under `core/tools/provider/builtin`, and create `google.yaml`. The name must be consistent with the module name.
Subsequently, all operations related to this tool will be carried out under this module.
```yaml
identity: # Basic information of the tool provider
author: Dify # Author
name: google # Name, unique, no duplication with other providers
label: # Label for frontend display
en_US: Google # English label
zh_Hans: Google # Chinese label
description: # Description for frontend display
en_US: Google # English description
zh_Hans: Google # Chinese description
icon: icon.svg # Icon, needs to be placed in the _assets folder of the current module
```
- The `identity` field is mandatory, it contains the basic information of the tool provider, including author, name, label, description, icon, etc.
- The icon needs to be placed in the `_assets` folder of the current module, you can refer to [here](../../provider/builtin/google/_assets/icon.svg).
## 2. Prepare Provider Credentials
Google, as a third-party tool, uses the API provided by SerpApi, which requires an API Key to use. This means that this tool needs a credential to use. For tools like `wikipedia`, there is no need to fill in the credential field, you can refer to [here](../../provider/builtin/wikipedia/wikipedia.yaml).
After configuring the credential field, the effect is as follows:
```yaml
identity:
author: Dify
name: google
label:
en_US: Google
zh_Hans: Google
description:
en_US: Google
zh_Hans: Google
icon: icon.svg
credentails_for_provider: # Credential field
serpapi_api_key: # Credential field name
type: secret-input # Credential field type
required: true # Required or not
label: # Credential field label
en_US: SerpApi API key # English label
zh_Hans: SerpApi API key # Chinese label
placeholder: # Credential field placeholder
en_US: Please input your SerpApi API key # English placeholder
zh_Hans: 请输入你的 SerpApi API key # Chinese placeholder
help: # Credential field help text
en_US: Get your SerpApi API key from SerpApi # English help text
zh_Hans: 从 SerpApi 获取您的 SerpApi API key # Chinese help text
url: https://serpapi.com/manage-api-key # Credential field help link
```
- `type`: Credential field type, currently can be either `secret-input`, `text-input`, or `select` , corresponding to password input box, text input box, and drop-down box, respectively. If set to `secret-input`, it will mask the input content on the frontend, and the backend will encrypt the input content.
## 3. Prepare Tool yaml
A provider can have multiple tools, each tool needs a yaml file to describe, this file contains the basic information, parameters, output, etc. of the tool.
Still taking GoogleSearch as an example, we need to create a `tools` module under the `google` module, and create `tools/google_search.yaml`, the content is as follows.
```yaml
identity: # Basic information of the tool
name: google_search # Tool name, unique, no duplication with other tools
author: Dify # Author
label: # Label for frontend display
en_US: GoogleSearch # English label
zh_Hans: 谷歌搜索 # Chinese label
description: # Description for frontend display
human: # Introduction for frontend display, supports multiple languages
en_US: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query.
zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
llm: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query. # Introduction passed to LLM, in order to make LLM better understand this tool, we suggest to write as detailed information about this tool as possible here, so that LLM can understand and use this tool
parameters: # Parameter list
- name: query # Parameter name
type: string # Parameter type
required: true # Required or not
label: # Parameter label
en_US: Query string # English label
zh_Hans: 查询语句 # Chinese label
human_description: # Introduction for frontend display, supports multiple languages
en_US: used for searching
zh_Hans: 用于搜索网页内容
llm_description: key words for searching # Introduction passed to LLM, similarly, in order to make LLM better understand this parameter, we suggest to write as detailed information about this parameter as possible here, so that LLM can understand this parameter
form: llm # Form type, llm means this parameter needs to be inferred by Agent, the frontend will not display this parameter
- name: result_type
type: select # Parameter type
required: true
options: # Drop-down box options
- value: text
label:
en_US: text
zh_Hans: 文本
- value: link
label:
en_US: link
zh_Hans: 链接
default: link
label:
en_US: Result type
zh_Hans: 结果类型
human_description:
en_US: used for selecting the result type, text or link
zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
form: form # Form type, form means this parameter needs to be filled in by the user on the frontend before the conversation starts
```
- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc.
- `parameters` Parameter list
- `name` Parameter name, unique, no duplication with other parameters
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box
- `required` Required or not
- In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
- In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
- `options` Parameter options
- In `llm` mode, Dify will pass all options to LLM, LLM can infer based on these options
- In `form` mode, when `type` is `select`, the frontend will display these options
- `default` Default value
- `label` Parameter label, for frontend display
- `human_description` Introduction for frontend display, supports multiple languages
- `llm_description` Introduction passed to LLM, in order to make LLM better understand this parameter, we suggest to write as detailed information about this parameter as possible here, so that LLM can understand this parameter
- `form` Form type, currently supports `llm`, `form` two types, corresponding to Agent self-inference and frontend filling
## 4. Add Tool Logic
After completing the tool configuration, we can start writing the tool code that defines how it is invoked.
Create `google_search.py` under the `google/tools` module, the content is as follows.
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from typing import Any, Dict, List, Union
class GoogleSearchTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
query = tool_paramters['query']
result_type = tool_paramters['result_type']
api_key = self.runtime.credentials['serpapi_api_key']
# TODO: search with serpapi
result = SerpAPI(api_key).run(query, result_type=result_type)
if result_type == 'text':
return self.create_text_message(text=result)
return self.create_link_message(link=result)
```
### Parameters
The overall logic of the tool is in the `_invoke` method, this method accepts two parameters: `user_id` and `tool_paramters`, which represent the user ID and tool parameters respectively
### Return Data
When the tool returns, you can choose to return one message or multiple messages, here we return one message, using `create_text_message` and `create_link_message` can create a text message or a link message.
## 5. Add Provider Code
Finally, we need to create a provider class under the provider module to implement the provider's credential verification logic. If the credential verification fails, it will throw a `ToolProviderCredentialValidationError` exception.
Create `google.py` under the `google` module, the content is as follows.
```python
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.tool.tool import Tool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool
from typing import Any, Dict
class GoogleProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try:
# 1. Here you need to instantiate a GoogleSearchTool with GoogleSearchTool(), it will automatically load the yaml configuration of GoogleSearchTool, but at this time it does not have credential information inside
# 2. Then you need to use the fork_tool_runtime method to pass the current credential information to GoogleSearchTool
# 3. Finally, invoke it, the parameters need to be passed according to the parameter rules configured in the yaml of GoogleSearchTool
GoogleSearchTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
"query": "test",
"result_type": "link"
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
```
## Completion
After the above steps are completed, we can see this tool on the frontend, and it can be used in the Agent.
Of course, because google_search needs a credential, before using it, you also need to input your credentials on the frontend.
![Alt text](../zh_Hans/images/index/image-2.png)

View File

@ -0,0 +1,266 @@
# 高级接入Tool
在开始高级接入之前,请确保你已经阅读过[快速接入](./tool_scale_out.md)并对Dify的工具接入流程有了基本的了解。
## 工具接口
我们在`Tool`类中定义了一系列快捷方法,用于帮助开发者快速构较为复杂的工具
### 消息返回
Dify支持`文本` `链接` `图片` `文件BLOB` 等多种消息类型你可以通过以下几个接口返回不同类型的消息给LLM和用户。
注意,在下面的接口中的部分参数将在后面的章节中介绍。
#### 图片URL
只需要传递图片的URL即可Dify会自动下载图片并返回给用户。
```python
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
"""
create an image message
:param image: the url of the image
:return: the image message
"""
```
#### 链接
如果你需要返回一个链接,可以使用以下接口。
```python
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a link message
:param link: the url of the link
:return: the link message
"""
```
#### 文本
如果你需要返回一个文本消息,可以使用以下接口。
```python
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a text message
:param text: the text of the message
:return: the text message
"""
```
#### 文件BLOB
如果你需要返回文件的原始数据如图片、音频、视频、PPT、Word、Excel等可以使用以下接口。
- `blob` 文件的原始数据bytes类型
- `meta` 文件的元数据,如果你知道该文件的类型,最好传递一个`mime_type`否则Dify将使用`octet/stream`作为默认类型
```python
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
"""
create a blob message
:param blob: the blob
:return: the blob message
"""
```
### 快捷工具
在大模型应用中,我们有两种常见的需求:
- 先将很长的文本进行提前总结然后再将总结内容传递给LLM以防止原文本过长导致LLM无法处理
- 工具获取到的内容是一个链接需要爬取网页信息后再返回给LLM
为了帮助开发者快速实现这两种需求,我们提供了以下两个快捷工具。
#### 文本总结工具
该工具需要传入user_id和需要进行总结的文本返回一个总结后的文本Dify会使用当前工作空间的默认模型对长文本进行总结。
```python
def summary(self, user_id: str, content: str) -> str:
"""
summary the content
:param user_id: the user id
:param content: the content
:return: the summary
"""
```
#### 网页爬取工具
该工具需要传入需要爬取的网页链接和一个user_agent可为空返回一个包含该网页信息的字符串其中`user_agent`是可选参数可以用来识别工具如果不传递Dify将使用默认的`user_agent`
```python
def get_url(self, url: str, user_agent: str = None) -> str:
"""
get url
""" the crawled result
```
### 变量池
我们在`Tool`中引入了一个变量池,用于存储工具运行过程中产生的变量、文件等,这些变量可以在工具运行过程中被其他工具使用。
下面,我们以`DallE3``Vectorizer.AI`为例,介绍如何使用变量池。
- `DallE3`是一个图片生成工具,它可以根据文本生成图片,在这里,我们将让`DallE3`生成一个咖啡厅的Logo
- `Vectorizer.AI`是一个矢量图转换工具,它可以将图片转换为矢量图,使得图片可以无限放大而不失真,在这里,我们将`DallE3`生成的PNG图标转换为矢量图从而可以真正被设计师使用。
#### DallE3
首先我们使用DallE3在创建完图片以后我们将图片保存到变量池中代码如下
```python
from typing import Any, Dict, List, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
from base64 import b64decode
from openai import OpenAI
class DallE3Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],
)
# prompt
prompt = tool_paramters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')
# call openapi dalle3
response = client.images.generate(
prompt=prompt, model='dall-e-3',
size='1024x1024', n=1, style='vivid', quality='standard',
response_format='b64_json'
)
result = []
for image in response.data:
# 将所有图片通过save_as参数保存到变量池中变量名为self.VARIABLE_KEY.IMAGE.value如果如果后续有新的图片生成那么将会覆盖之前的图片
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={ 'mime_type': 'image/png' },
save_as=self.VARIABLE_KEY.IMAGE.value))
return result
```
我们可以注意到这里我们使用了`self.VARIABLE_KEY.IMAGE.value`作为图片的变量名,为了便于开发者们的工具能够互相配合,我们定义了这个`KEY`,大家可以自由使用,也可以不使用这个`KEY`传递一个自定义的KEY也是可以的。
#### Vectorizer.AI
接下来我们使用Vectorizer.AI将DallE3生成的PNG图标转换为矢量图我们先来过一遍我们在这里定义的函数代码如下
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter
from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union
from httpx import post
from base64 import b64decode
class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
工具调用,图片变量名需要从这里传递进来,从而我们就可以从变量池中获取到图片
"""
def get_runtime_parameters(self) -> List[ToolParamter]:
"""
重写工具参数列表我们可以根据当前变量池里的实际情况来动态生成参数列表从而LLM可以根据参数列表来生成表单
"""
def is_tool_avaliable(self) -> bool:
"""
当前工具是否可用如果当前变量池中没有图片那么我们就不需要展示这个工具这里返回False即可
"""
```
接下来我们来实现这三个函数
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter
from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union
from httpx import post
from base64 import b64decode
class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
api_key_name = self.runtime.credentials.get('api_key_name', None)
api_key_value = self.runtime.credentials.get('api_key_value', None)
if not api_key_name or not api_key_value:
raise ToolProviderCredentialValidationError('Please input api key name and value')
# 获取image_idimage_id的定义可以在get_runtime_parameters中找到
image_id = tool_paramters.get('image_id', '')
if not image_id:
return self.create_text_message('Please input image id')
# 从变量池中获取到之前DallE生成的图片
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
if not image_binary:
return self.create_text_message('Image not found, please request user to generate image firstly.')
# 生成矢量图
response = post(
'https://vectorizer.ai/api/v1/vectorize',
files={ 'image': image_binary },
data={ 'mode': 'test' },
auth=(api_key_name, api_key_value),
timeout=30
)
if response.status_code != 200:
raise Exception(response.text)
return [
self.create_text_message('the vectorized svg is saved as an image.'),
self.create_blob_message(blob=response.content,
meta={'mime_type': 'image/svg+xml'})
]
def get_runtime_parameters(self) -> List[ToolParamter]:
"""
override the runtime parameters
"""
# 这里我们重写了工具参数列表定义了image_id并设置了它的选项列表为当前变量池中的所有图片这里的配置与yaml中的配置是一致的
return [
ToolParamter.get_simple_instance(
name='image_id',
llm_description=f'the image id that you want to vectorize, \
and the image id should be specified in \
{[i.name for i in self.list_default_image_variables()]}',
type=ToolParamter.ToolParameterType.SELECT,
required=True,
options=[i.name for i in self.list_default_image_variables()]
)
]
def is_tool_avaliable(self) -> bool:
# 只有当变量池中有图片时LLM才需要使用这个工具
return len(self.list_default_image_variables()) > 0
```
可以注意到的是,我们这里其实并没有使用到`image_id`,我们已经假设了调用这个工具的时候一定有一张图片在默认的变量池中,所以直接使用了`image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)`来获取图片,在模型能力较弱的情况下,我们建议开发者们也这样做,可以有效提升容错率,避免模型传递错误的参数。

Binary file not shown.

After

Width:  |  Height:  |  Size: 242 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 407 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 266 KiB

View File

@ -0,0 +1,212 @@
# 快速接入Tool
这里我们以GoogleSearch为例介绍如何快速接入一个工具。
## 1. 准备工具供应商yaml
### 介绍
这个yaml将包含工具供应商的信息包括供应商名称、图标、作者等详细信息以帮助前端灵活展示。
### 示例
我们需要在 `core/tools/provider/builtin`下创建一个`google`模块(文件夹),并创建`google.yaml`,名称必须与模块名称一致。
后续,我们关于这个工具的所有操作都将在这个模块下进行。
```yaml
identity: # 工具供应商的基本信息
author: Dify # 作者
name: google # 名称,唯一,不允许和其他供应商重名
label: # 标签,用于前端展示
en_US: Google # 英文标签
zh_Hans: Google # 中文标签
description: # 描述,用于前端展示
en_US: Google # 英文描述
zh_Hans: Google # 中文描述
icon: icon.svg # 图标需要放置在当前模块的_assets文件夹下
```
- `identity` 字段是必须的,它包含了工具供应商的基本信息,包括作者、名称、标签、描述、图标等
- 图标需要放置在当前模块的`_assets`文件夹下,可以参考[这里](../../provider/builtin/google/_assets/icon.svg)。
## 2. 准备供应商凭据
Google作为一个第三方工具使用了SerpApi提供的API而SerpApi需要一个API Key才能使用那么就意味着这个工具需要一个凭据才可以使用而像`wikipedia`这样的工具,就不需要填写凭据字段,可以参考[这里](../../provider/builtin/wikipedia/wikipedia.yaml)。
配置好凭据字段后效果如下:
```yaml
identity:
author: Dify
name: google
label:
en_US: Google
zh_Hans: Google
description:
en_US: Google
zh_Hans: Google
icon: icon.svg
credentails_for_provider: # 凭据字段
serpapi_api_key: # 凭据字段名称
type: secret-input # 凭据字段类型
required: true # 是否必填
label: # 凭据字段标签
en_US: SerpApi API key # 英文标签
zh_Hans: SerpApi API key # 中文标签
placeholder: # 凭据字段占位符
en_US: Please input your SerpApi API key # 英文占位符
zh_Hans: 请输入你的 SerpApi API key # 中文占位符
help: # 凭据字段帮助文本
en_US: Get your SerpApi API key from SerpApi # 英文帮助文本
zh_Hans: 从 SerpApi 获取您的 SerpApi API key # 中文帮助文本
url: https://serpapi.com/manage-api-key # 凭据字段帮助链接
```
- `type`:凭据字段类型,目前支持`secret-input``text-input``select` 三种类型,分别对应密码输入框、文本输入框、下拉框,如果为`secret-input`,则会在前端隐藏输入内容,并且后端会对输入内容进行加密。
## 3. 准备工具yaml
一个供应商底下可以有多个工具每个工具都需要一个yaml文件来描述这个文件包含了工具的基本信息、参数、输出等。
仍然以GoogleSearch为例我们需要在`google`模块下创建一个`tools`模块,并创建`tools/google_search.yaml`,内容如下。
```yaml
identity: # 工具的基本信息
name: google_search # 工具名称,唯一,不允许和其他工具重名
author: Dify # 作者
label: # 标签,用于前端展示
en_US: GoogleSearch # 英文标签
zh_Hans: 谷歌搜索 # 中文标签
description: # 描述,用于前端展示
human: # 用于前端展示的介绍,支持多语言
en_US: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query.
zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
llm: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query. # 传递给LLM的介绍为了使得LLM更好理解这个工具我们建议在这里写上关于这个工具尽可能详细的信息让LLM能够理解并使用这个工具
parameters: # 参数列表
- name: query # 参数名称
type: string # 参数类型
required: true # 是否必填
label: # 参数标签
en_US: Query string # 英文标签
zh_Hans: 查询语句 # 中文标签
human_description: # 用于前端展示的介绍,支持多语言
en_US: used for searching
zh_Hans: 用于搜索网页内容
llm_description: key words for searching # 传递给LLM的介绍同上为了使得LLM更好理解这个参数我们建议在这里写上关于这个参数尽可能详细的信息让LLM能够理解这个参数
form: llm # 表单类型llm表示这个参数需要由Agent自行推理出来前端将不会展示这个参数
- name: result_type
type: select # 参数类型
required: true
options: # 下拉框选项
- value: text
label:
en_US: text
zh_Hans: 文本
- value: link
label:
en_US: link
zh_Hans: 链接
default: link
label:
en_US: Result type
zh_Hans: 结果类型
human_description:
en_US: used for selecting the result type, text or link
zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
form: form # 表单类型form表示这个参数需要由用户在对话开始前在前端填写
```
- `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等
- `parameters` 参数列表
- `name` 参数名称,唯一,不允许和其他参数重名
- `type` 参数类型,目前支持`string``number``boolean``select` 四种类型,分别对应字符串、数字、布尔值、下拉框
- `required` 是否必填
-`llm`模式下如果参数为必填则会要求Agent必须要推理出这个参数
-`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数
- `options` 参数选项
-`llm`模式下Dify会将所有选项传递给LLMLLM可以根据这些选项进行推理
-`form`模式下,`type``select`时,前端会展示这些选项
- `default` 默认值
- `label` 参数标签,用于前端展示
- `human_description` 用于前端展示的介绍,支持多语言
- `llm_description` 传递给LLM的介绍为了使得LLM更好理解这个参数我们建议在这里写上关于这个参数尽可能详细的信息让LLM能够理解这个参数
- `form` 表单类型,目前支持`llm``form`两种类型分别对应Agent自行推理和前端填写
## 4. 准备工具代码
当完成工具的配置以后,我们就可以开始编写工具代码了,主要用于实现工具的逻辑。
`google/tools`模块下创建`google_search.py`,内容如下。
```python
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from typing import Any, Dict, List, Union
class GoogleSearchTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
query = tool_paramters['query']
result_type = tool_paramters['result_type']
api_key = self.runtime.credentials['serpapi_api_key']
# TODO: search with serpapi
result = SerpAPI(api_key).run(query, result_type=result_type)
if result_type == 'text':
return self.create_text_message(text=result)
return self.create_link_message(link=result)
```
### 参数
工具的整体逻辑都在`_invoke`方法中,这个方法接收两个参数:`user_id``tool_paramters`分别表示用户ID和工具参数
### 返回数据
在工具返回时,你可以选择返回一个消息或者多个消息,这里我们返回一个消息,使用`create_text_message``create_link_message`可以创建一个文本消息或者一个链接消息。
## 5. 准备供应商代码
最后,我们需要在供应商模块下创建一个供应商类,用于实现供应商的凭据验证逻辑,如果凭据验证失败,将会抛出`ToolProviderCredentialValidationError`异常。
`google`模块下创建`google.py`,内容如下。
```python
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.tool.tool import Tool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool
from typing import Any, Dict
class GoogleProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try:
# 1. 此处需要使用GoogleSearchTool()实例化一个GoogleSearchTool它会自动加载GoogleSearchTool的yaml配置但是此时它内部没有凭据信息
# 2. 随后需要使用fork_tool_runtime方法将当前的凭据信息传递给GoogleSearchTool
# 3. 最后invoke即可参数需要根据GoogleSearchTool的yaml中配置的参数规则进行传递
GoogleSearchTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
"query": "test",
"result_type": "link"
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
```
## 完成
当上述步骤完成以后我们就可以在前端看到这个工具了并且可以在Agent中使用这个工具。
当然因为google_search需要一个凭据在使用之前还需要在前端配置它的凭据。
![Alt text](images/index/image-2.png)

View File

@ -0,0 +1,22 @@
from typing import Optional
from pydantic import BaseModel
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
zh_Hans: Optional[str] = None
en_US: str
def __init__(self, **data):
super().__init__(**data)
if not self.zh_Hans:
self.zh_Hans = self.en_US
def to_dict(self) -> dict:
return {
'zh_Hans': self.zh_Hans,
'en_US': self.en_US,
}

View File

@ -0,0 +1,3 @@
class DEFAULT_PROVIDERS:
API_BASED = '__api_based'
APP_BASED = '__app_based'

View File

@ -0,0 +1,34 @@
from pydantic import BaseModel
from typing import Dict, Optional, Any, List
from core.tools.entities.tool_entities import ToolProviderType, ToolParamter
class ApiBasedToolBundle(BaseModel):
"""
This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc.
"""
# server_url
server_url: str
# method
method: str
# summary
summary: Optional[str] = None
# operation_id
operation_id: str = None
# parameters
parameters: Optional[List[ToolParamter]] = None
# author
author: str
# icon
icon: Optional[str] = None
# openapi operation
openapi: dict
class AppToolBundle(BaseModel):
"""
This class is used to store the schema information of an tool for an app.
"""
type: ToolProviderType
credential: Optional[Dict[str, Any]] = None
provider_id: str
tool_name: str

View File

@ -0,0 +1,305 @@
from pydantic import BaseModel, Field
from enum import Enum
from typing import Optional, List, Dict, Any, Union, cast
from core.tools.entities.common_entities import I18nObject
class ToolProviderType(Enum):
"""
Enum class for tool provider
"""
BUILT_IN = "built-in"
APP_BASED = "app-based"
API_BASED = "api-based"
@classmethod
def value_of(cls, value: str) -> 'ToolProviderType':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
class ApiProviderSchemaType(Enum):
"""
Enum class for api provider schema type.
"""
OPENAPI = "openapi"
SWAGGER = "swagger"
OPENAI_PLUGIN = "openai_plugin"
OPENAI_ACTIONS = "openai_actions"
@classmethod
def value_of(cls, value: str) -> 'ApiProviderSchemaType':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
class ApiProviderAuthType(Enum):
"""
Enum class for api provider auth type.
"""
NONE = "none"
API_KEY = "api_key"
@classmethod
def value_of(cls, value: str) -> 'ApiProviderAuthType':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
class ToolInvokeMessage(BaseModel):
class MessageType(Enum):
TEXT = "text"
IMAGE = "image"
LINK = "link"
BLOB = "blob"
IMAGE_LINK = "image_link"
type: MessageType = MessageType.TEXT
"""
plain text, image url or link url
"""
message: Union[str, bytes] = None
meta: Dict[str, Any] = None
save_as: str = ''
class ToolInvokeMessageBinary(BaseModel):
mimetype: str = Field(..., description="The mimetype of the binary")
url: str = Field(..., description="The url of the binary")
save_as: str = ''
class ToolParamterOption(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
class ToolParamter(BaseModel):
class ToolParameterType(Enum):
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
SELECT = "select"
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
FORM = "form" # should be set before invoking tool
LLM = "llm" # will be set by LLM
name: str = Field(..., description="The name of the parameter")
label: I18nObject = Field(..., description="The label presented to the user")
human_description: I18nObject = Field(..., description="The description presented to the user")
type: ToolParameterType = Field(..., description="The type of the parameter")
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: Optional[str] = None
required: Optional[bool] = False
default: Optional[str] = None
min: Optional[Union[float, int]] = None
max: Optional[Union[float, int]] = None
options: Optional[List[ToolParamterOption]] = None
@classmethod
def get_simple_instance(cls,
name: str, llm_description: str, type: ToolParameterType,
required: bool, options: Optional[List[str]] = None) -> 'ToolParamter':
"""
get a simple tool parameter
:param name: the name of the parameter
:param llm_description: the description presented to the LLM
:param type: the type of the parameter
:param required: if the parameter is required
:param options: the options of the parameter
"""
# convert options to ToolParamterOption
if options:
options = [ToolParamterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options]
return cls(
name=name,
label=I18nObject(en_US='', zh_Hans=''),
human_description=I18nObject(en_US='', zh_Hans=''),
type=type,
form=cls.ToolParameterForm.LLM,
llm_description=llm_description,
required=required,
options=options,
)
class ToolProviderIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
class ToolDescription(BaseModel):
human: I18nObject = Field(..., description="The description presented to the user")
llm: str = Field(..., description="The description presented to the LLM")
class ToolIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
class ToolCredentialsOption(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
class ToolProviderCredentials(BaseModel):
class CredentialsType(Enum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
@classmethod
def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
@staticmethod
def defaut(value: str) -> str:
return ""
name: str = Field(..., description="The name of the credentials")
type: CredentialsType = Field(..., description="The type of the credentials")
required: bool = False
default: Optional[str] = None
options: Optional[List[ToolCredentialsOption]] = None
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
url: Optional[str] = None
placeholder: Optional[I18nObject] = None
def to_dict(self) -> dict:
return {
'name': self.name,
'type': self.type.value,
'required': self.required,
'default': self.default,
'options': self.options,
'help': self.help.to_dict() if self.help else None,
'label': self.label.to_dict(),
'url': self.url,
'placeholder': self.placeholder.to_dict() if self.placeholder else None,
}
class ToolRuntimeVariableType(Enum):
TEXT = "text"
IMAGE = "image"
class ToolRuntimeVariable(BaseModel):
type: ToolRuntimeVariableType = Field(..., description="The type of the variable")
name: str = Field(..., description="The name of the variable")
position: int = Field(..., description="The position of the variable")
tool_name: str = Field(..., description="The name of the tool")
class ToolRuntimeTextVariable(ToolRuntimeVariable):
value: str = Field(..., description="The value of the variable")
class ToolRuntimeImageVariable(ToolRuntimeVariable):
value: str = Field(..., description="The path of the image")
class ToolRuntimeVariablePool(BaseModel):
conversation_id: str = Field(..., description="The conversation id")
user_id: str = Field(..., description="The user id")
tenant_id: str = Field(..., description="The tenant id of assistant")
pool: List[ToolRuntimeVariable] = Field(..., description="The pool of variables")
def __init__(self, **data: Any):
pool = data.get('pool', [])
# convert pool into correct type
for index, variable in enumerate(pool):
if variable['type'] == ToolRuntimeVariableType.TEXT.value:
pool[index] = ToolRuntimeTextVariable(**variable)
elif variable['type'] == ToolRuntimeVariableType.IMAGE.value:
pool[index] = ToolRuntimeImageVariable(**variable)
super().__init__(**data)
def dict(self) -> dict:
return {
'conversation_id': self.conversation_id,
'user_id': self.user_id,
'tenant_id': self.tenant_id,
'pool': [variable.dict() for variable in self.pool],
}
def set_text(self, tool_name: str, name: str, value: str) -> None:
"""
set a text variable
"""
for variable in self.pool:
if variable.name == name:
if variable.type == ToolRuntimeVariableType.TEXT:
variable = cast(ToolRuntimeTextVariable, variable)
variable.value = value
return
variable = ToolRuntimeTextVariable(
type=ToolRuntimeVariableType.TEXT,
name=name,
position=len(self.pool),
tool_name=tool_name,
value=value,
)
self.pool.append(variable)
def set_file(self, tool_name: str, value: str, name: str = None) -> None:
"""
set an image variable
:param tool_name: the name of the tool
:param value: the id of the file
"""
# check how many image variables are there
image_variable_count = 0
for variable in self.pool:
if variable.type == ToolRuntimeVariableType.IMAGE:
image_variable_count += 1
if name is None:
name = f"file_{image_variable_count}"
for variable in self.pool:
if variable.name == name:
if variable.type == ToolRuntimeVariableType.IMAGE:
variable = cast(ToolRuntimeImageVariable, variable)
variable.value = value
return
variable = ToolRuntimeImageVariable(
type=ToolRuntimeVariableType.IMAGE,
name=name,
position=len(self.pool),
tool_name=tool_name,
value=value,
)
self.pool.append(variable)

View File

@ -0,0 +1,48 @@
from pydantic import BaseModel
from enum import Enum
from typing import List, Dict, Optional
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.tool.tool import ToolParamter
class UserToolProvider(BaseModel):
class ProviderType(Enum):
BUILTIN = "builtin"
APP = "app"
API = "api"
id: str
author: str
name: str # identifier
description: I18nObject
icon: str
label: I18nObject # label
type: ProviderType
team_credentials: dict = None
is_team_authorization: bool = False
allow_delete: bool = True
def to_dict(self) -> dict:
return {
'id': self.id,
'author': self.author,
'name': self.name,
'description': self.description.to_dict(),
'icon': self.icon,
'label': self.label.to_dict(),
'type': self.type.value,
'team_credentials': self.team_credentials,
'is_team_authorization': self.is_team_authorization,
'allow_delete': self.allow_delete
}
class UserToolProviderCredentials(BaseModel):
credentails: Dict[str, ToolProviderCredentials]
class UserTool(BaseModel):
author: str
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[List[ToolParamter]]

20
api/core/tools/errors.py Normal file
View File

@ -0,0 +1,20 @@
class ToolProviderNotFoundError(ValueError):
pass
class ToolNotFoundError(ValueError):
pass
class ToolParamterValidationError(ValueError):
pass
class ToolProviderCredentialValidationError(ValueError):
pass
class ToolNotSupportedError(ValueError):
pass
class ToolInvokeError(ValueError):
pass
class ToolApiSchemaError(ValueError):
pass

View File

@ -0,0 +1,2 @@
class InvokeModelError(Exception):
pass

View File

@ -0,0 +1,174 @@
"""
For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
Therefore, a model manager is needed to list/invoke/validate models.
"""
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeRateLimitError, InvokeBadRequestError, \
InvokeConnectionError, InvokeAuthorizationError, InvokeServerUnavailableError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.model_manager import ModelManager
from core.tools.model.errors import InvokeModelError
from extensions.ext_database import db
from models.tools import ToolModelInvoke
from typing import List, cast
import json
class ToolModelManager:
@staticmethod
def get_max_llm_context_tokens(
tenant_id: str,
) -> int:
"""
get max llm context tokens of the model
"""
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM,
)
if not model_instance:
raise InvokeModelError(f'Model not found')
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not schema:
raise InvokeModelError(f'No model schema found')
max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
if max_tokens is None:
return 2048
return max_tokens
@staticmethod
def calculate_tokens(
tenant_id: str,
prompt_messages: List[PromptMessage]
) -> int:
"""
calculate tokens from prompt messages and model parameters
"""
# get model instance
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM
)
if not model_instance:
raise InvokeModelError(f'Model not found')
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
# get tokens
tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages)
return tokens
@staticmethod
def invoke(
user_id: str, tenant_id: str,
tool_type: str, tool_name: str,
prompt_messages: List[PromptMessage]
) -> LLMResult:
"""
invoke model with parameters in user's own context
:param user_id: user id
:param tenant_id: tenant id, the tenant id of the creator of the tool
:param tool_provider: tool provider
:param tool_id: tool id
:param tool_name: tool name
:param provider: model provider
:param model: model name
:param model_parameters: model parameters
:param prompt_messages: prompt messages
:return: AssistantPromptMessage
"""
# get model manager
model_manager = ModelManager()
# get model instance
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM,
)
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
# get model credentials
model_credentials = model_instance.credentials
# get prompt tokens
prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages)
model_parameters = {
'temperature': 0.8,
'top_p': 0.8,
}
# create tool model invoke
tool_model_invoke = ToolModelInvoke(
user_id=user_id,
tenant_id=tenant_id,
provider=model_instance.provider,
tool_type=tool_type,
tool_name=tool_name,
model_parameters=json.dumps(model_parameters),
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
model_response='',
prompt_tokens=prompt_tokens,
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency='USD',
)
db.session.add(tool_model_invoke)
db.session.commit()
try:
response: LLMResult = llm_model.invoke(
model=model_instance.model,
credentials=model_credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[], stop=[], stream=False, user=user_id, callbacks=[]
)
except InvokeRateLimitError as e:
raise InvokeModelError(f'Invoke rate limit error: {e}')
except InvokeBadRequestError as e:
raise InvokeModelError(f'Invoke bad request error: {e}')
except InvokeConnectionError as e:
raise InvokeModelError(f'Invoke connection error: {e}')
except InvokeAuthorizationError as e:
raise InvokeModelError(f'Invoke authorization error')
except InvokeServerUnavailableError as e:
raise InvokeModelError(f'Invoke server unavailable error: {e}')
except Exception as e:
raise InvokeModelError(f'Invoke error: {e}')
# update tool model invoke
tool_model_invoke.model_response = response.message.content
if response.usage:
tool_model_invoke.answer_tokens = response.usage.completion_tokens
tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
tool_model_invoke.provider_response_latency = response.usage.latency
tool_model_invoke.total_price = response.usage.total_price
tool_model_invoke.currency = response.usage.currency
db.session.commit()
return response

View File

@ -0,0 +1,102 @@
ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
{{instruction}}
You have access to the following tools:
{{tools}}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {{tool_names}}
Provide only ONE action per $JSON_BLOB, as shown:
```
{
"action": $TOOL_NAME,
"action_input": $ACTION_INPUT
}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{
"action": "Final Answer",
"action_input": "Final response to human"
}
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {{query}}
Thought: {{agent_scratchpad}}"""
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
Thought:"""
ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
{{instruction}}
You have access to the following tools:
{{tools}}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {{tool_names}}
Provide only ONE action per $JSON_BLOB, as shown:
```
{
"action": $TOOL_NAME,
"action_input": $ACTION_INPUT
}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{
"action": "Final Answer",
"action_input": "Final response to human"
}
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
"""
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
REACT_PROMPT_TEMPLATES = {
'english': {
'chat': {
'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
},
'completion': {
'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
}
}
}

View File

@ -0,0 +1,169 @@
from typing import Any, Dict, List
from core.tools.entities.tool_entities import ToolProviderType, ApiProviderAuthType, ToolProviderCredentials, ToolCredentialsOption
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.tool.tool import Tool
from core.tools.tool.api_tool import ApiTool
from core.tools.provider.tool_provider import ToolProviderController
from extensions.ext_database import db
from models.tools import ApiToolProvider
class ApiBasedToolProviderController(ToolProviderController):
@staticmethod
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiBasedToolProviderController':
credentials_schema = {
'auth_type': ToolProviderCredentials(
name='auth_type',
required=True,
type=ToolProviderCredentials.CredentialsType.SELECT,
options=[
ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='')),
ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
],
default='none',
help=I18nObject(
en_US='The auth type of the api provider',
zh_Hans='api provider 的认证类型'
)
)
}
if auth_type == ApiProviderAuthType.API_KEY:
credentials_schema = {
**credentials_schema,
'api_key_header': ToolProviderCredentials(
name='api_key_header',
required=False,
default='api_key',
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
help=I18nObject(
en_US='The header name of the api key',
zh_Hans='携带 api key 的 header 名称'
)
),
'api_key_value': ToolProviderCredentials(
name='api_key_value',
required=True,
type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
help=I18nObject(
en_US='The api key',
zh_Hans='api key的值'
)
)
}
elif auth_type == ApiProviderAuthType.NONE:
pass
else:
raise ValueError(f'invalid auth type {auth_type}')
return ApiBasedToolProviderController(**{
'identity': {
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
'name': db_provider.name,
'label': {
'en_US': db_provider.name,
'zh_Hans': db_provider.name
},
'description': {
'en_US': db_provider.description,
'zh_Hans': db_provider.description
},
'icon': db_provider.icon
},
'credentials_schema': credentials_schema
})
@property
def app_type(self) -> ToolProviderType:
return ToolProviderType.API_BASED
def _validate_credentials(self, tool_name: str, credentials: Dict[str, Any]) -> None:
pass
def validate_parameters(self, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
pass
def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool:
"""
parse tool bundle to tool
:param tool_bundle: the tool bundle
:return: the tool
"""
return ApiTool(**{
'api_bundle': tool_bundle,
'identity' : {
'author': tool_bundle.author,
'name': tool_bundle.operation_id,
'label': {
'en_US': tool_bundle.operation_id,
'zh_Hans': tool_bundle.operation_id
},
'icon': tool_bundle.icon if tool_bundle.icon else ''
},
'description': {
'human': {
'en_US': tool_bundle.summary or '',
'zh_Hans': tool_bundle.summary or ''
},
'llm': tool_bundle.summary or ''
},
'parameters' : tool_bundle.parameters if tool_bundle.parameters else [],
})
def load_bundled_tools(self, tools: List[ApiBasedToolBundle]) -> List[ApiTool]:
"""
load bundled tools
:param tools: the bundled tools
:return: the tools
"""
self.tools = [self._parse_tool_bundle(tool) for tool in tools]
return self.tools
def get_tools(self, user_id: str, tanent_id: str) -> List[ApiTool]:
"""
fetch tools from database
:param user_id: the user id
:param tanent_id: the tanent id
:return: the tools
"""
if self.tools is not None:
return self.tools
tools: List[Tool] = []
# get tanent api providers
db_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tanent_id,
ApiToolProvider.name == self.identity.name
).all()
if db_providers and len(db_providers) != 0:
for db_provider in db_providers:
for tool in db_provider.tools:
assistant_tool = self._parse_tool_bundle(tool)
assistant_tool.is_team_authorization = True
tools.append(assistant_tool)
self.tools = tools
return tools
def get_tool(self, tool_name: str) -> ApiTool:
"""
get tool by name
:param tool_name: the name of the tool
:return: the tool
"""
if self.tools is None:
self.get_tools()
for tool in self.tools:
if tool.identity.name == tool_name:
return tool
raise ValueError(f'tool {tool_name} not found')

View File

@ -0,0 +1,116 @@
from typing import Any, Dict, List
from core.tools.entities.tool_entities import ToolProviderType, ToolParamter, ToolParamterOption
from core.tools.tool.tool import Tool
from core.tools.entities.common_entities import I18nObject
from core.tools.provider.tool_provider import ToolProviderController
from extensions.ext_database import db
from models.tools import PublishedAppTool
from models.model import App, AppModelConfig
import logging
logger = logging.getLogger(__name__)
class AppBasedToolProviderEntity(ToolProviderController):
@property
def app_type(self) -> ToolProviderType:
return ToolProviderType.APP_BASED
def _validate_credentials(self, tool_name: str, credentials: Dict[str, Any]) -> None:
pass
def validate_parameters(self, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
pass
def get_tools(self, user_id: str) -> List[Tool]:
db_tools: List[PublishedAppTool] = db.session.query(PublishedAppTool).filter(
PublishedAppTool.user_id == user_id,
).all()
if not db_tools or len(db_tools) == 0:
return []
tools: List[Tool] = []
for db_tool in db_tools:
tool = {
'identity': {
'author': db_tool.author,
'name': db_tool.tool_name,
'label': {
'en_US': db_tool.tool_name,
'zh_Hans': db_tool.tool_name
},
'icon': ''
},
'description': {
'human': {
'en_US': db_tool.description_i18n.en_US,
'zh_Hans': db_tool.description_i18n.zh_Hans
},
'llm': db_tool.llm_description
},
'parameters': []
}
# get app from db
app: App = db_tool.app
if not app:
logger.error(f"app {db_tool.app_id} not found")
continue
app_model_config: AppModelConfig = app.app_model_config
user_input_form_list = app_model_config.user_input_form_list
for input_form in user_input_form_list:
# get type
form_type = input_form.keys()[0]
default = input_form[form_type]['default']
required = input_form[form_type]['required']
label = input_form[form_type]['label']
variable_name = input_form[form_type]['variable_name']
options = input_form[form_type].get('options', [])
if form_type == 'paragraph' or form_type == 'text-input':
tool['parameters'].append(ToolParamter(
name=variable_name,
label=I18nObject(
en_US=label,
zh_Hans=label
),
human_description=I18nObject(
en_US=label,
zh_Hans=label
),
llm_description=label,
form=ToolParamter.ToolParameterForm.FORM,
type=ToolParamter.ToolParameterType.STRING,
required=required,
default=default
))
elif form_type == 'select':
tool['parameters'].append(ToolParamter(
name=variable_name,
label=I18nObject(
en_US=label,
zh_Hans=label
),
human_description=I18nObject(
en_US=label,
zh_Hans=label
),
llm_description=label,
form=ToolParamter.ToolParameterForm.FORM,
type=ToolParamter.ToolParameterType.SELECT,
required=required,
default=default,
options=[ToolParamterOption(
value=option,
label=I18nObject(
en_US=option,
zh_Hans=option
)
) for option in options]
))
tools.append(Tool(**tool))
return tools

View File

@ -0,0 +1,26 @@
from core.tools.entities.user_entities import UserToolProvider
from typing import List
position = {
'google': 1,
'wikipedia': 2,
'dalle': 3,
'webscraper': 4,
'wolframalpha': 5,
'chart': 6,
'time': 7,
'yahoo': 8,
'stablediffusion': 9,
'vectorizer': 10,
'youtube': 11,
}
class BuiltinToolProviderSort:
@staticmethod
def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
def sort_compare(provider: UserToolProvider) -> int:
return position.get(provider.name, 10000)
sorted_providers = sorted(providers, key=sort_compare)
return sorted_providers

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 KiB

View File

@ -0,0 +1,24 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.chart.tools.line import LinearChartTool
import matplotlib.pyplot as plt
# use a business theme
plt.style.use('seaborn-v0_8-darkgrid')
class ChartProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
LinearChartTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
"data": "1,3,5,7,9,2,4,6,8,10",
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,11 @@
identity:
author: Dify
name: chart
label:
en_US: ChartGenerator
zh_Hans: 图表生成
description:
en_US: Chart Generator is a tool for generating statistical charts like bar chart, line chart, pie chart, etc.
zh_Hans: 图表生成是一个用于生成可视化图表的工具,你可以通过它来生成柱状图、折线图、饼图等各类图表
icon: icon.png
credentails_for_provider:

View File

@ -0,0 +1,47 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
import matplotlib.pyplot as plt
import io
from typing import Any, Dict, List, Union
class BarChartTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
data = tool_paramters.get('data', '')
if not data:
return self.create_text_message('Please input data')
data = data.split(';')
# if all data is int, convert to int
if all([i.isdigit() for i in data]):
data = [int(i) for i in data]
else:
data = [float(i) for i in data]
axis = tool_paramters.get('x_axis', None) or None
if axis:
axis = axis.split(';')
if len(axis) != len(data):
axis = None
flg, ax = plt.subplots(figsize=(10, 8))
if axis:
axis = [label[:10] + '...' if len(label) > 10 else label for label in axis]
ax.set_xticklabels(axis, rotation=45, ha='right')
ax.bar(axis, data)
else:
ax.bar(range(len(data)), data)
buf = io.BytesIO()
flg.savefig(buf, format='png')
buf.seek(0)
plt.close(flg)
return [
self.create_text_message('the bar chart is saved as an image.'),
self.create_blob_message(blob=buf.read(),
meta={'mime_type': 'image/png'})
]

View File

@ -0,0 +1,35 @@
identity:
name: bar_chart
author: Dify
label:
en_US: Bar Chart
zh_Hans: 柱状图
icon: icon.svg
description:
human:
en_US: Bar chart
zh_Hans: 柱状图
llm: generate a bar chart with input data
parameters:
- name: data
type: string
required: true
label:
en_US: data
zh_Hans: 数据
human_description:
en_US: data for generating bar chart
zh_Hans: 用于生成柱状图的数据
llm_description: data for generating bar chart, data should be a string contains a list of numbers like "1;2;3;4;5"
form: llm
- name: x_axis
type: string
required: false
label:
en_US: X Axis
zh_Hans: x 轴
human_description:
en_US: X axis for bar chart
zh_Hans: 柱状图的 x 轴
llm_description: x axis for bar chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data
form: llm

View File

@ -0,0 +1,49 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
import matplotlib.pyplot as plt
import io
from typing import Any, Dict, List, Union
class LinearChartTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
data = tool_paramters.get('data', '')
if not data:
return self.create_text_message('Please input data')
data = data.split(';')
axis = tool_paramters.get('x_axis', None) or None
if axis:
axis = axis.split(';')
if len(axis) != len(data):
axis = None
# if all data is int, convert to int
if all([i.isdigit() for i in data]):
data = [int(i) for i in data]
else:
data = [float(i) for i in data]
flg, ax = plt.subplots(figsize=(10, 8))
if axis:
axis = [label[:10] + '...' if len(label) > 10 else label for label in axis]
ax.set_xticklabels(axis, rotation=45, ha='right')
ax.plot(axis, data)
else:
ax.plot(data)
buf = io.BytesIO()
flg.savefig(buf, format='png')
buf.seek(0)
plt.close(flg)
return [
self.create_text_message('the linear chart is saved as an image.'),
self.create_blob_message(blob=buf.read(),
meta={'mime_type': 'image/png'})
]

View File

@ -0,0 +1,35 @@
identity:
name: line_chart
author: Dify
label:
en_US: Linear Chart
zh_Hans: 线性图表
icon: icon.svg
description:
human:
en_US: linear chart
zh_Hans: 线性图表
llm: generate a linear chart with input data
parameters:
- name: data
type: string
required: true
label:
en_US: data
zh_Hans: 数据
human_description:
en_US: data for generating linear chart
zh_Hans: 用于生成线性图表的数据
llm_description: data for generating linear chart, data should be a string contains a list of numbers like "1;2;3;4;5"
form: llm
- name: x_axis
type: string
required: false
label:
en_US: X Axis
zh_Hans: x 轴
human_description:
en_US: X axis for linear chart
zh_Hans: 线性图表的 x 轴
llm_description: x axis for linear chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data
form: llm

View File

@ -0,0 +1,46 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
import matplotlib.pyplot as plt
import io
from typing import Any, Dict, List, Union
class PieChartTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
data = tool_paramters.get('data', '')
if not data:
return self.create_text_message('Please input data')
data = data.split(';')
categories = tool_paramters.get('categories', None) or None
# if all data is int, convert to int
if all([i.isdigit() for i in data]):
data = [int(i) for i in data]
else:
data = [float(i) for i in data]
flg, ax = plt.subplots()
if categories:
categories = categories.split(';')
if len(categories) != len(data):
categories = None
if categories:
ax.pie(data, labels=categories)
else:
ax.pie(data)
buf = io.BytesIO()
flg.savefig(buf, format='png')
buf.seek(0)
plt.close(flg)
return [
self.create_text_message('the pie chart is saved as an image.'),
self.create_blob_message(blob=buf.read(),
meta={'mime_type': 'image/png'})
]

View File

@ -0,0 +1,35 @@
identity:
name: pie_chart
author: Dify
label:
en_US: Pie Chart
zh_Hans: 饼图
icon: icon.svg
description:
human:
en_US: Pie chart
zh_Hans: 饼图
llm: generate a pie chart with input data
parameters:
- name: data
type: string
required: true
label:
en_US: data
zh_Hans: 数据
human_description:
en_US: data for generating pie chart
zh_Hans: 用于生成饼图的数据
llm_description: data for generating pie chart, data should be a string contains a list of numbers like "1;2;3;4;5"
form: llm
- name: categories
type: string
required: true
label:
en_US: Categories
zh_Hans: 分类
human_description:
en_US: Categories for pie chart
zh_Hans: 饼图的分类
llm_description: categories for pie chart, categories should be a string contains a list of texts like "a;b;c;1;2" in order to match the data, each category should be split by ";"
form: llm

Binary file not shown.

After

Width:  |  Height:  |  Size: 153 KiB

View File

@ -0,0 +1,23 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.provider.builtin.dalle.tools.dalle2 import DallE2Tool
from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict
class DALLEProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try:
DallE2Tool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
"prompt": "cute girl, blue eyes, white hair, anime style",
"size": "small",
"n": 1
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,47 @@
identity:
author: Dify
name: dalle
label:
en_US: DALL-E
zh_Hans: DALL-E 绘画
description:
en_US: DALL-E art
zh_Hans: DALL-E 绘画
icon: icon.png
credentails_for_provider:
openai_api_key:
type: secret-input
required: true
label:
en_US: OpenAI API key
zh_Hans: OpenAI API key
help:
en_US: Please input your OpenAI API key
zh_Hans: 请输入你的 OpenAI API key
placeholder:
en_US: Please input your OpenAI API key
zh_Hans: 请输入你的 OpenAI API key
openai_organizaion_id:
type: text-input
required: false
label:
en_US: OpenAI organization ID
zh_Hans: OpenAI organization ID
help:
en_US: Please input your OpenAI organization ID
zh_Hans: 请输入你的 OpenAI organization ID
placeholder:
en_US: Please input your OpenAI organization ID
zh_Hans: 请输入你的 OpenAI organization ID
openai_base_url:
type: text-input
required: false
label:
en_US: OpenAI base URL
zh_Hans: OpenAI base URL
help:
en_US: Please input your OpenAI base URL
zh_Hans: 请输入你的 OpenAI base URL
placeholder:
en_US: Please input your OpenAI base URL
zh_Hans: 请输入你的 OpenAI base URL

View File

@ -0,0 +1,66 @@
from typing import Any, Dict, List, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
from base64 import b64decode
from os.path import join
from openai import OpenAI
class DallE2Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
openai_organization = self.runtime.credentials.get('openai_organizaion_id', None)
if not openai_organization:
openai_organization = None
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
if not openai_base_url:
openai_base_url = None
else:
openai_base_url = join(openai_base_url, 'v1')
client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],
base_url=openai_base_url,
organization=openai_organization
)
SIZE_MAPPING = {
'small': '256x256',
'medium': '512x512',
'large': '1024x1024',
}
# prompt
prompt = tool_paramters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')
# get size
size = SIZE_MAPPING[tool_paramters.get('size', 'large')]
# get n
n = tool_paramters.get('n', 1)
# call openapi dalle2
response = client.images.generate(
prompt=prompt,
model='dall-e-2',
size=size,
n=n,
response_format='b64_json'
)
result = []
for image in response.data:
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={ 'mime_type': 'image/png' },
save_as=self.VARIABLE_KEY.IMAGE.value))
return result

View File

@ -0,0 +1,63 @@
identity:
name: dalle2
author: Dify
label:
en_US: DALL-E 2
zh_Hans: DALL-E 2 绘画
description:
en_US: DALL-E 2 is a powerful drawing tool that can draw the image you want based on your prompt
zh_Hans: DALL-E 2 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像
description:
human:
en_US: DALL-E is a text to image tool
zh_Hans: DALL-E 是一个文本到图像的工具
llm: DALL-E is a tool used to generate images from text
parameters:
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
human_description:
en_US: Image prompt, you can check the official documentation of DallE 2
zh_Hans: 图像提示词您可以查看DallE 2 的官方文档
llm_description: Image prompt of DallE 2, you should describe the image you want to generate as a list of words as possible as detailed
form: llm
- name: size
type: select
required: true
human_description:
en_US: used for selecting the image size
zh_Hans: 用于选择图像大小
label:
en_US: Image size
zh_Hans: 图像大小
form: form
options:
- value: small
label:
en_US: Small(256x256)
zh_Hans: 小(256x256)
- value: medium
label:
en_US: Medium(512x512)
zh_Hans: 中(512x512)
- value: large
label:
en_US: Large(1024x1024)
zh_Hans: 大(1024x1024)
default: large
- name: n
type: number
required: true
human_description:
en_US: used for selecting the number of images
zh_Hans: 用于选择图像数量
label:
en_US: Number of images
zh_Hans: 图像数量
form: form
default: 1
min: 1
max: 10

View File

@ -0,0 +1,74 @@
from typing import Any, Dict, List, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
from base64 import b64decode
from os.path import join
from openai import OpenAI
class DallE3Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
openai_organization = self.runtime.credentials.get('openai_organizaion_id', None)
if not openai_organization:
openai_organization = None
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
if not openai_base_url:
openai_base_url = None
else:
openai_base_url = join(openai_base_url, 'v1')
client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],
base_url=openai_base_url,
organization=openai_organization
)
SIZE_MAPPING = {
'square': '1024x1024',
'vertical': '1024x1792',
'horizontal': '1792x1024',
}
# prompt
prompt = tool_paramters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')
# get size
size = SIZE_MAPPING[tool_paramters.get('size', 'square')]
# get n
n = tool_paramters.get('n', 1)
# get quality
quality = tool_paramters.get('quality', 'standard')
if quality not in ['standard', 'hd']:
return self.create_text_message('Invalid quality')
# get style
style = tool_paramters.get('style', 'vivid')
if style not in ['natural', 'vivid']:
return self.create_text_message('Invalid style')
# call openapi dalle3
response = client.images.generate(
prompt=prompt,
model='dall-e-3',
size=size,
n=n,
style=style,
quality=quality,
response_format='b64_json'
)
result = []
for image in response.data:
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={ 'mime_type': 'image/png' },
save_as=self.VARIABLE_KEY.IMAGE.value))
return result

View File

@ -0,0 +1,103 @@
identity:
name: dalle3
author: Dify
label:
en_US: DALL-E 3
zh_Hans: DALL-E 3 绘画
description:
en_US: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources
zh_Hans: DALL-E 3 是一个强大的绘画工具它可以根据您的提示词绘制出您想要的图像相比于DallE 2 DallE 3拥有更强的绘画能力但会消耗更多的资源
description:
human:
en_US: DALL-E is a text to image tool
zh_Hans: DALL-E 是一个文本到图像的工具
llm: DALL-E is a tool used to generate images from text
parameters:
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
human_description:
en_US: Image prompt, you can check the official documentation of DallE 3
zh_Hans: 图像提示词您可以查看DallE 3 的官方文档
llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
form: llm
- name: size
type: select
required: true
human_description:
en_US: selecting the image size
zh_Hans: 选择图像大小
label:
en_US: Image size
zh_Hans: 图像大小
form: form
options:
- value: square
label:
en_US: Squre(1024x1024)
zh_Hans: 方(1024x1024)
- value: vertical
label:
en_US: Vertical(1024x1792)
zh_Hans: 竖屏(1024x1792)
- value: horizontal
label:
en_US: Horizontal(1792x1024)
zh_Hans: 横屏(1792x1024)
default: square
- name: n
type: number
required: true
human_description:
en_US: selecting the number of images
zh_Hans: 选择图像数量
label:
en_US: Number of images
zh_Hans: 图像数量
form: form
min: 1
max: 1
default: 1
- name: quality
type: select
required: true
human_description:
en_US: selecting the image quality
zh_Hans: 选择图像质量
label:
en_US: Image quality
zh_Hans: 图像质量
form: form
options:
- value: standard
label:
en_US: Standard
zh_Hans: 标准
- value: hd
label:
en_US: HD
zh_Hans: 高清
default: standard
- name: style
type: select
required: true
human_description:
en_US: selecting the image style
zh_Hans: 选择图像风格
label:
en_US: Image style
zh_Hans: 图像风格
form: form
options:
- value: vivid
label:
en_US: Vivid
zh_Hans: 生动
- value: natural
label:
en_US: Natural
zh_Hans: 自然
default: vivid

View File

@ -0,0 +1,6 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="25" viewBox="0 0 24 25" fill="none">
<path d="M22.501 12.7332C22.501 11.8699 22.4296 11.2399 22.2748 10.5865H12.2153V14.4832H18.12C18.001 15.4515 17.3582 16.9099 15.9296 17.8898L15.9096 18.0203L19.0902 20.435L19.3106 20.4565C21.3343 18.6249 22.501 15.9298 22.501 12.7332Z" fill="#4285F4"/>
<path d="M12.214 23C15.1068 23 17.5353 22.0666 19.3092 20.4567L15.9282 17.8899C15.0235 18.5083 13.8092 18.9399 12.214 18.9399C9.38069 18.9399 6.97596 17.1083 6.11874 14.5766L5.99309 14.5871L2.68583 17.0954L2.64258 17.2132C4.40446 20.6433 8.0235 23 12.214 23Z" fill="#34A853"/>
<path d="M6.12046 14.5766C5.89428 13.9233 5.76337 13.2233 5.76337 12.5C5.76337 11.7766 5.89428 11.0766 6.10856 10.4233L6.10257 10.2841L2.75386 7.7355L2.64429 7.78658C1.91814 9.20993 1.50146 10.8083 1.50146 12.5C1.50146 14.1916 1.91814 15.7899 2.64429 17.2132L6.12046 14.5766Z" fill="#FBBC05"/>
<path d="M12.2141 6.05997C14.2259 6.05997 15.583 6.91163 16.3569 7.62335L19.3807 4.73C17.5236 3.03834 15.1069 2 12.2141 2C8.02353 2 4.40447 4.35665 2.64258 7.78662L6.10686 10.4233C6.97598 7.89166 9.38073 6.05997 12.2141 6.05997Z" fill="#EB4335"/>
</svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@ -0,0 +1,23 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool
from typing import Any, Dict, List
class GoogleProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try:
GoogleSearchTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
"query": "test",
"result_type": "link"
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,24 @@
identity:
author: Dify
name: google
label:
en_US: Google
zh_Hans: Google
description:
en_US: Google
zh_Hans: GoogleSearch
icon: icon.svg
credentails_for_provider:
serpapi_api_key:
type: secret-input
required: true
label:
en_US: SerpApi API key
zh_Hans: SerpApi API key
placeholder:
en_US: Please input your SerpApi API key
zh_Hans: 请输入你的 SerpApi API key
help:
en_US: Get your SerpApi API key from SerpApi
zh_Hans: 从 SerpApi 获取您的 SerpApi API key
url: https://serpapi.com/manage-api-key

View File

@ -0,0 +1,163 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from typing import Any, Dict, List, Union
import os
import sys
from serpapi import GoogleSearch
class HiddenPrints:
"""Context manager to hide prints."""
def __enter__(self) -> None:
"""Open file to pipe stdout to."""
self._original_stdout = sys.stdout
sys.stdout = open(os.devnull, "w")
def __exit__(self, *_: Any) -> None:
"""Close file that stdout was piped to."""
sys.stdout.close()
sys.stdout = self._original_stdout
class SerpAPI:
"""
SerpAPI tool provider.
"""
search_engine: Any #: :meta private:
serpapi_api_key: str = None
def __init__(self, api_key: str) -> None:
"""Initialize SerpAPI tool provider."""
self.serpapi_api_key = api_key
self.search_engine = GoogleSearch
def run(self, query: str, **kwargs: Any) -> str:
"""Run query through SerpAPI and parse result."""
typ = kwargs.get("result_type", "text")
return self._process_response(self.results(query), typ=typ)
def results(self, query: str) -> dict:
"""Run query through SerpAPI and return the raw result."""
params = self.get_params(query)
with HiddenPrints():
search = self.search_engine(params)
res = search.get_dict()
return res
def get_params(self, query: str) -> Dict[str, str]:
"""Get parameters for SerpAPI."""
_params = {
"api_key": self.serpapi_api_key,
"q": query,
}
params = {
"engine": "google",
"google_domain": "google.com",
"gl": "us",
"hl": "en",
**_params
}
return params
@staticmethod
def _process_response(res: dict, typ: str) -> str:
"""Process response from SerpAPI."""
if "error" in res.keys():
raise ValueError(f"Got error from SerpAPI: {res['error']}")
if typ == "text":
if "answer_box" in res.keys() and type(res["answer_box"]) == list:
res["answer_box"] = res["answer_box"][0]
if "answer_box" in res.keys() and "answer" in res["answer_box"].keys():
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]
elif (
"answer_box" in res.keys()
and "snippet_highlighted_words" in res["answer_box"].keys()
):
toret = res["answer_box"]["snippet_highlighted_words"][0]
elif (
"sports_results" in res.keys()
and "game_spotlight" in res["sports_results"].keys()
):
toret = res["sports_results"]["game_spotlight"]
elif (
"shopping_results" in res.keys()
and "title" in res["shopping_results"][0].keys()
):
toret = res["shopping_results"][:3]
elif (
"knowledge_graph" in res.keys()
and "description" in res["knowledge_graph"].keys()
):
toret = res["knowledge_graph"]["description"]
elif "snippet" in res["organic_results"][0].keys():
toret = res["organic_results"][0]["snippet"]
elif "link" in res["organic_results"][0].keys():
toret = res["organic_results"][0]["link"]
elif (
"images_results" in res.keys()
and "thumbnail" in res["images_results"][0].keys()
):
thumbnails = [item["thumbnail"] for item in res["images_results"][:10]]
toret = thumbnails
else:
toret = "No good search result found"
elif typ == "link":
if "knowledge_graph" in res.keys() and "title" in res["knowledge_graph"].keys() \
and "description_link" in res["knowledge_graph"].keys():
toret = res["knowledge_graph"]["description_link"]
elif "knowledge_graph" in res.keys() and "see_results_about" in res["knowledge_graph"].keys() \
and len(res["knowledge_graph"]["see_results_about"]) > 0:
see_result_about = res["knowledge_graph"]["see_results_about"]
toret = ""
for item in see_result_about:
if "name" not in item.keys() or "link" not in item.keys():
continue
toret += f"[{item['name']}]({item['link']})\n"
elif "organic_results" in res.keys() and len(res["organic_results"]) > 0:
organic_results = res["organic_results"]
toret = ""
for item in organic_results:
if "title" not in item.keys() or "link" not in item.keys():
continue
toret += f"[{item['title']}]({item['link']})\n"
elif "related_questions" in res.keys() and len(res["related_questions"]) > 0:
related_questions = res["related_questions"]
toret = ""
for item in related_questions:
if "question" not in item.keys() or "link" not in item.keys():
continue
toret += f"[{item['question']}]({item['link']})\n"
elif "related_searches" in res.keys() and len(res["related_searches"]) > 0:
related_searches = res["related_searches"]
toret = ""
for item in related_searches:
if "query" not in item.keys() or "link" not in item.keys():
continue
toret += f"[{item['query']}]({item['link']})\n"
else:
toret = "No good search result found"
return toret
class GoogleSearchTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
query = tool_paramters['query']
result_type = tool_paramters['result_type']
api_key = self.runtime.credentials['serpapi_api_key']
result = SerpAPI(api_key).run(query, result_type=result_type)
if result_type == 'text':
return self.create_text_message(text=result)
return self.create_link_message(link=result)

View File

@ -0,0 +1,43 @@
identity:
name: google_search
author: Dify
label:
en_US: GoogleSearch
zh_Hans: 谷歌搜索
description:
human:
en_US: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query.
zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
llm: A tool for performing a Google SERP search and extracting snippets and webpages.Input should be a search query.
parameters:
- name: query
type: string
required: true
label:
en_US: Query string
zh_Hans: 查询语句
human_description:
en_US: used for searching
zh_Hans: 用于搜索网页内容
llm_description: key words for searching
form: llm
- name: result_type
type: select
required: true
options:
- value: text
label:
en_US: text
zh_Hans: 文本
- value: link
label:
en_US: link
zh_Hans: 链接
default: link
label:
en_US: Result type
zh_Hans: 结果类型
human_description:
en_US: used for selecting the result type, text or link
zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
form: form

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

View File

@ -0,0 +1,26 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import StableDiffusionTool
from typing import Any, Dict
class StableDiffusionProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try:
StableDiffusionTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
"prompt": "cat",
"lora": "",
"steps": 1,
"width": 512,
"height": 512,
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,29 @@
identity:
author: Dify
name: stablediffusion
label:
en_US: Stable Diffusion
zh_Hans: Stable Diffusion
description:
en_US: Stable Diffusion is a tool for generating images which can be deployed locally.
zh_Hans: Stable Diffusion 是一个可以在本地部署的图片生成的工具。
icon: icon.png
credentails_for_provider:
base_url:
type: secret-input
required: true
label:
en_US: Base URL
zh_Hans: StableDiffusion服务器的Base URL
placeholder:
en_US: Please input your StableDiffusion server's Base URL
zh_Hans: 请输入你的 StableDiffusion 服务器的 Base URL
model:
type: text-input
required: true
label:
en_US: Model
zh_Hans: 模型
placeholder:
en_US: Please input your model
zh_Hans: 请输入你的模型名称

View File

@ -0,0 +1,244 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter, ToolParamterOption
from core.tools.entities.common_entities import I18nObject
from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union
from httpx import post
from os.path import join
from base64 import b64decode, b64encode
from PIL import Image
import json
import io
from copy import deepcopy
DRAW_TEXT_OPTIONS = {
"prompt": "",
"negative_prompt": "",
"seed": -1,
"subseed": -1,
"subseed_strength": 0,
"seed_resize_from_h": -1,
'sampler_index': 'DPM++ SDE Karras',
"seed_resize_from_w": -1,
"batch_size": 1,
"n_iter": 1,
"steps": 10,
"cfg_scale": 7,
"width": 1024,
"height": 1024,
"restore_faces": False,
"do_not_save_samples": False,
"do_not_save_grid": False,
"eta": 0,
"denoising_strength": 0,
"s_min_uncond": 0,
"s_churn": 0,
"s_tmax": 0,
"s_tmin": 0,
"s_noise": 0,
"override_settings": {},
"override_settings_restore_afterwards": True,
"refiner_switch_at": 0,
"disable_extra_networks": False,
"comments": {},
"enable_hr": False,
"firstphase_width": 0,
"firstphase_height": 0,
"hr_scale": 2,
"hr_second_pass_steps": 0,
"hr_resize_x": 0,
"hr_resize_y": 0,
"hr_prompt": "",
"hr_negative_prompt": "",
"script_args": [],
"send_images": True,
"save_images": False,
"alwayson_scripts": {}
}
class StableDiffusionTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
# base url
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
return self.create_text_message('Please input base_url')
model = self.runtime.credentials.get('model', None)
if not model:
return self.create_text_message('Please input model')
# set model
try:
url = join(base_url, 'sdapi/v1/options')
response = post(url, data=json.dumps({
'sd_model_checkpoint': model
}))
if response.status_code != 200:
raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model')
except Exception as e:
raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model')
# prompt
prompt = tool_paramters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')
# get negative prompt
negative_prompt = tool_paramters.get('negative_prompt', '')
# get size
width = tool_paramters.get('width', 1024)
height = tool_paramters.get('height', 1024)
# get steps
steps = tool_paramters.get('steps', 1)
# get lora
lora = tool_paramters.get('lora', '')
# get image id
image_id = tool_paramters.get('image_id', '')
if image_id.strip():
image_variable = self.get_default_image_variable()
if image_variable:
image_binary = self.get_variable_file(image_variable.name)
if not image_binary:
return self.create_text_message('Image not found, please request user to generate image firstly.')
# convert image to RGB
image = Image.open(io.BytesIO(image_binary))
image = image.convert("RGB")
buffer = io.BytesIO()
image.save(buffer, format="PNG")
image_binary = buffer.getvalue()
image.close()
return self.img2img(base_url=base_url,
lora=lora,
image_binary=image_binary,
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
steps=steps)
return self.text2img(base_url=base_url,
lora=lora,
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
steps=steps)
def img2img(self, base_url: str, lora: str, image_binary: bytes,
prompt: str, negative_prompt: str,
width: int, height: int, steps: int) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
generate image
"""
draw_options = {
"init_images": [b64encode(image_binary).decode('utf-8')],
"prompt": "",
"negative_prompt": negative_prompt,
"denoising_strength": 0.9,
"width": width,
"height": height,
"cfg_scale": 7,
"sampler_name": "Euler a",
"restore_faces": False,
"steps": steps,
"script_args": ["outpainting mk2"]
}
if lora:
draw_options['prompt'] = f'{lora},{prompt}'
else:
draw_options['prompt'] = prompt
try:
url = join(base_url, 'sdapi/v1/img2img')
response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200:
return self.create_text_message('Failed to generate image')
image = response.json()['images'][0]
return self.create_blob_message(blob=b64decode(image),
meta={ 'mime_type': 'image/png' },
save_as=self.VARIABLE_KEY.IMAGE.value)
except Exception as e:
return self.create_text_message('Failed to generate image')
def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
generate image
"""
# copy draw options
draw_options = deepcopy(DRAW_TEXT_OPTIONS)
if lora:
draw_options['prompt'] = f'{lora},{prompt}'
draw_options['width'] = width
draw_options['height'] = height
draw_options['steps'] = steps
draw_options['negative_prompt'] = negative_prompt
try:
url = join(base_url, 'sdapi/v1/txt2img')
response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200:
return self.create_text_message('Failed to generate image')
image = response.json()['images'][0]
return self.create_blob_message(blob=b64decode(image),
meta={ 'mime_type': 'image/png' },
save_as=self.VARIABLE_KEY.IMAGE.value)
except Exception as e:
return self.create_text_message('Failed to generate image')
def get_runtime_parameters(self) -> List[ToolParamter]:
parameters = [
ToolParamter(name='prompt',
label=I18nObject(en_US='Prompt', zh_Hans='Prompt'),
human_description=I18nObject(
en_US='Image prompt, you can check the official documentation of Stable Diffusion',
zh_Hans='图像提示词,您可以查看 Stable Diffusion 的官方文档',
),
type=ToolParamter.ToolParameterType.STRING,
form=ToolParamter.ToolParameterForm.LLM,
llm_description='Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.',
required=True),
]
if len(self.list_default_image_variables()) != 0:
parameters.append(
ToolParamter(name='image_id',
label=I18nObject(en_US='image_id', zh_Hans='image_id'),
human_description=I18nObject(
en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.',
zh_Hans='您想要生成的图像的图像 ID如果您想要基于默认图像生成图像则可以将此字段留空。',
),
type=ToolParamter.ToolParameterType.STRING,
form=ToolParamter.ToolParameterForm.LLM,
llm_description='Image id of the original image, you can leave this field empty if you want to generate a new image.',
required=True,
options=[ToolParamterOption(
value=i.name,
label=I18nObject(en_US=i.name, zh_Hans=i.name)
) for i in self.list_default_image_variables()])
)
return parameters

View File

@ -0,0 +1,77 @@
identity:
name: stable_diffusion
author: Dify
label:
en_US: Stable Diffusion WebUI
zh_Hans: Stable Diffusion WebUI
description:
human:
en_US: A tool for generating images which can be deployed locally, you can use stable-diffusion-webui to deploy it.
zh_Hans: 一个可以在本地部署的图片生成的工具,您可以使用 stable-diffusion-webui 来部署它。
llm: draw the image you want based on your prompt.
parameters:
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
human_description:
en_US: Image prompt, you can check the official documentation of Stable Diffusion
zh_Hans: 图像提示词,您可以查看 Stable Diffusion 的官方文档
llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.
form: llm
- name: lora
type: string
required: false
label:
en_US: Lora
zh_Hans: Lora
human_description:
en_US: Lora
zh_Hans: Lora
form: form
- name: steps
type: number
required: false
label:
en_US: Steps
zh_Hans: Steps
human_description:
en_US: Steps
zh_Hans: Steps
form: form
default: 10
- name: width
type: number
required: false
label:
en_US: Width
zh_Hans: Width
human_description:
en_US: Width
zh_Hans: Width
form: form
default: 1024
- name: height
type: number
required: false
label:
en_US: Height
zh_Hans: Height
human_description:
en_US: Height
zh_Hans: Height
form: form
default: 1024
- name: negative_prompt
type: string
required: false
label:
en_US: Negative prompt
zh_Hans: Negative prompt
human_description:
en_US: Negative prompt
zh_Hans: Negative prompt
form: form
default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines

View File

@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path fill-rule="evenodd" clip-rule="evenodd" d="M0.666992 8.00008C0.666992 3.94999 3.95024 0.666748 8.00033 0.666748C12.0504 0.666748 15.3337 3.94999 15.3337 8.00008C15.3337 12.0502 12.0504 15.3334 8.00033 15.3334C3.95024 15.3334 0.666992 12.0502 0.666992 8.00008ZM8.66699 4.00008C8.66699 3.63189 8.36852 3.33341 8.00033 3.33341C7.63213 3.33341 7.33366 3.63189 7.33366 4.00008V8.00008C7.33366 8.2526 7.47633 8.48344 7.70218 8.59637L10.3688 9.9297C10.6982 10.0944 11.0986 9.96088 11.2633 9.63156C11.4279 9.30224 11.2945 8.90179 10.9651 8.73713L8.66699 7.58806V4.00008Z" fill="#EC4A0A"/>
</svg>

After

Width:  |  Height:  |  Size: 691 B

View File

@ -0,0 +1,16 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.time.tools.current_time import CurrentTimeTool
from typing import Any, Dict
class WikiPediaProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try:
CurrentTimeTool().invoke(
user_id='',
tool_paramters={},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,11 @@
identity:
author: Dify
name: time
label:
en_US: CurrentTime
zh_Hans: 时间
description:
en_US: A tool for getting the current time.
zh_Hans: 一个用于获取当前时间的工具。
icon: icon.svg
credentails_for_provider:

View File

@ -0,0 +1,17 @@
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
from typing import Any, Dict, List, Union
from datetime import datetime, timezone
class CurrentTimeTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
return self.create_text_message(f'{datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")}')

View File

@ -0,0 +1,12 @@
identity:
name: current_time
author: Dify
label:
en_US: Current Time
zh_Hans: 获取当前时间
description:
human:
en_US: A tool for getting the current time.
zh_Hans: 一个用于获取当前时间的工具。
llm: A tool for getting the current time.
parameters:

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

View File

@ -0,0 +1 @@
VECTORIZER_ICON_PNG = 'iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC'

View File

@ -0,0 +1,74 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter
from core.tools.provider.builtin.vectorizer.tools.test_data import VECTORIZER_ICON_PNG
from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union
from httpx import post
from base64 import b64decode
class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
api_key_name = self.runtime.credentials.get('api_key_name', None)
api_key_value = self.runtime.credentials.get('api_key_value', None)
mode = tool_paramters.get('mode', 'test')
if mode == 'production':
mode = 'preview'
if not api_key_name or not api_key_value:
raise ToolProviderCredentialValidationError('Please input api key name and value')
image_id = tool_paramters.get('image_id', '')
if not image_id:
return self.create_text_message('Please input image id')
if image_id.startswith('__test_'):
image_binary = b64decode(VECTORIZER_ICON_PNG)
else:
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
if not image_binary:
return self.create_text_message('Image not found, please request user to generate image firstly.')
response = post(
'https://vectorizer.ai/api/v1/vectorize',
files={
'image': image_binary
},
data={
'mode': mode
} if mode == 'test' else {},
auth=(api_key_name, api_key_value),
timeout=30
)
if response.status_code != 200:
raise Exception(response.text)
return [
self.create_text_message('the vectorized svg is saved as an image.'),
self.create_blob_message(blob=response.content,
meta={'mime_type': 'image/svg+xml'})
]
def get_runtime_parameters(self) -> List[ToolParamter]:
"""
override the runtime parameters
"""
return [
ToolParamter.get_simple_instance(
name='image_id',
llm_description=f'the image id that you want to vectorize, \
and the image id should be specified in \
{[i.name for i in self.list_default_image_variables()]}',
type=ToolParamter.ToolParameterType.SELECT,
required=True,
options=[i.name for i in self.list_default_image_variables()]
)
]
def is_tool_avaliable(self) -> bool:
return len(self.list_default_image_variables()) > 0

View File

@ -0,0 +1,32 @@
identity:
name: vectorizer
author: Dify
label:
en_US: Vectorizer.AI
zh_Hans: Vectorizer.AI
description:
human:
en_US: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI.
zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。
llm: A tool for converting images to SVG vectors. you should input the image id as the input of this tool. the image id can be got from parameters.
parameters:
- name: mode
type: select
required: true
options:
- value: production
label:
en_US: production
zh_Hans: 生产模式
- value: test
label:
en_US: test
zh_Hans: 测试模式
default: test
label:
en_US: Mode
zh_Hans: 模式
human_description:
en_US: It is free to integrate with and test out the API in test mode, no subscription required.
zh_Hans: 在测试模式下可以免费测试API。
form: form

View File

@ -0,0 +1,23 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool
from typing import Any, Dict
class VectorizerProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try:
VectorizerTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
"mode": "test",
"image_id": "__test_123"
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,36 @@
identity:
author: Dify
name: vectorizer
label:
en_US: Vectorizer.AI
zh_Hans: Vectorizer.AI
description:
en_US: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI.
zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。
icon: icon.png
credentails_for_provider:
api_key_name:
type: secret-input
required: true
label:
en_US: Vectorizer.AI API Key name
zh_Hans: Vectorizer.AI API Key name
placeholder:
en_US: Please input your Vectorizer.AI ApiKey name
zh_Hans: 请输入你的 Vectorizer.AI ApiKey name
help:
en_US: Get your Vectorizer.AI API Key from Vectorizer.AI.
zh_Hans: 从 Vectorizer.AI 获取您的 Vectorizer.AI API Key。
url: https://vectorizer.ai/api
api_key_value:
type: secret-input
required: true
label:
en_US: Vectorizer.AI API Key
zh_Hans: Vectorizer.AI API Key
placeholder:
en_US: Please input your Vectorizer.AI ApiKey
zh_Hans: 请输入你的 Vectorizer.AI ApiKey
help:
en_US: Get your Vectorizer.AI API Key from Vectorizer.AI.
zh_Hans: 从 Vectorizer.AI 获取您的 Vectorizer.AI API Key。

View File

@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="17" viewBox="0 0 16 17" fill="none">
<path fill-rule="evenodd" clip-rule="evenodd" d="M2.6665 1.16667C1.56193 1.16667 0.666504 2.0621 0.666504 3.16667C0.666504 4.27124 1.56193 5.16667 2.6665 5.16667C2.79161 5.16667 2.91403 5.15519 3.03277 5.13321C2.3808 6.09319 1.99984 7.25211 1.99984 8.5C1.99984 9.7479 2.3808 10.9068 3.03277 11.8668C2.91403 11.8448 2.79161 11.8333 2.6665 11.8333C1.56193 11.8333 0.666504 12.7288 0.666504 13.8333C0.666504 14.9379 1.56193 15.8333 2.6665 15.8333C3.77107 15.8333 4.6665 14.9379 4.6665 13.8333C4.6665 13.7082 4.65502 13.5858 4.63304 13.4671C5.59302 14.119 6.75194 14.5 7.99984 14.5C9.24773 14.5 10.4066 14.119 11.3666 13.4671C11.3447 13.5858 11.3332 13.7082 11.3332 13.8333C11.3332 14.9379 12.2286 15.8333 13.3332 15.8333C14.4377 15.8333 15.3332 14.9379 15.3332 13.8333C15.3332 12.7288 14.4377 11.8333 13.3332 11.8333C13.2081 11.8333 13.0856 11.8448 12.9669 11.8668C13.6189 10.9068 13.9998 9.7479 13.9998 8.5C13.9998 7.25211 13.6189 6.09319 12.9669 5.13321C13.0856 5.15519 13.2081 5.16667 13.3332 5.16667C14.4377 5.16667 15.3332 4.27124 15.3332 3.16667C15.3332 2.0621 14.4377 1.16667 13.3332 1.16667C12.2286 1.16667 11.3332 2.0621 11.3332 3.16667C11.3332 3.29177 11.3447 3.41419 11.3666 3.53293C10.4066 2.88097 9.24773 2.50001 7.99984 2.50001C6.75194 2.50001 5.59302 2.88097 4.63304 3.53293C4.65502 3.41419 4.6665 3.29177 4.6665 3.16667C4.6665 2.0621 3.77107 1.16667 2.6665 1.16667ZM3.38043 7.83334C3.63081 6.08287 4.85262 4.64578 6.48223 4.08565C5.79223 5.22099 5.36488 6.50185 5.23815 7.83334H3.38043ZM6.48228 12.9144C4.85264 12.3543 3.63082 10.9172 3.38043 9.16667H5.23815C5.3649 10.4982 5.79226 11.779 6.48228 12.9144ZM12.6192 9.16667C12.3689 10.9168 11.1475 12.3537 9.5183 12.9141C10.2082 11.7788 10.6355 10.498 10.7622 9.16667H12.6192ZM9.51834 4.08596C11.1475 4.64631 12.3689 6.0832 12.6192 7.83334H10.7622C10.6355 6.50197 10.2082 5.22123 9.51834 4.08596ZM9.4218 7.83334C9.27457 6.52262 8.78381 5.27411 8.00019 4.2145C7.21658 5.27411 6.72582 6.52262 6.57859 7.83334H9.4218ZM6.5786 9.16667C6.72583 10.4774 7.21659 11.7259 8.00019 12.7855C8.7838 11.7259 9.27456 10.4774 9.42179 9.16667H6.5786Z" fill="#DD2590"/>
</svg>

After

Width:  |  Height:  |  Size: 2.2 KiB

View File

@ -0,0 +1,28 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolInvokeError
from typing import Any, Dict, List, Union
class WebscraperTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
try:
url = tool_paramters.get('url', '')
user_agent = tool_paramters.get('user_agent', '')
if not url:
return self.create_text_message('Please input url')
# get webpage
result = self.get_url(url, user_agent=user_agent)
# summarize and return
return self.create_text_message(self.summary(user_id=user_id, content=result))
except Exception as e:
raise ToolInvokeError(str(e))

View File

@ -0,0 +1,34 @@
identity:
name: webscraper
author: Dify
label:
en_US: Web Scraper
zh_Hans: 网页爬虫
description:
human:
en_US: A tool for scraping webpages.
zh_Hans: 一个用于爬取网页的工具。
llm: A tool for scraping webpages. Input should be a URL.
parameters:
- name: url
type: string
required: true
label:
en_US: URL
zh_Hans: 网页链接
human_description:
en_US: used for linking to webpages
zh_Hans: 用于链接到网页
llm_description: url for scraping
form: llm
- name: user_agent
type: string
required: false
label:
en_US: User Agent
zh_Hans: User Agent
human_description:
en_US: used for identifying the browser.
zh_Hans: 用于识别浏览器。
form: form
default: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.1000.0 Safari/537.36

View File

@ -0,0 +1,23 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.webscraper.tools.webscraper import WebscraperTool
from typing import Any, Dict, List
class WebscraperProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try:
WebscraperTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
'url': 'https://www.google.com',
'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,11 @@
identity:
author: Dify
name: webscraper
label:
en_US: WebScraper
zh_Hans: 网页抓取
description:
en_US: Web Scrapper tool kit is used to scrape web
zh_Hans: 一个用于抓取网页的工具。
icon: icon.svg
credentails_for_provider:

View File

@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="25" viewBox="0 0 24 25" fill="none">
<path d="M22.3627 6.50009H18.3156H18.1783V6.63743V7.07751V7.21484H18.3156H18.5969C18.924 7.21484 19.2189 7.38272 19.386 7.66394C19.553 7.94516 19.5593 8.28448 19.4028 8.57169L14.9027 16.8317L12.8532 11.9459L14.7837 8.40336C15.1832 7.67026 15.95 7.21484 16.7849 7.21484H16.8761H17.0134V7.07751V6.63743V6.50009H16.8761H12.829H12.6917V6.63743V7.07751V7.21484H12.829H13.1102C13.4373 7.21484 13.7323 7.38272 13.8993 7.66394C14.0663 7.94516 14.0726 8.28448 13.9162 8.57169L12.5159 11.1419L11.268 8.16696C11.1776 7.95134 11.1999 7.71594 11.3294 7.52124C11.4589 7.32654 11.6673 7.21484 11.9011 7.21484H12.221H12.3583V7.07751V6.63743V6.50009H12.221H7.3808H7.24347V6.63743V7.07751V7.21484H7.3808H7.44737C8.40218 7.21484 9.25775 7.78379 9.62715 8.66426L11.471 13.0599L9.4161 16.8317L5.78141 8.16696C5.69095 7.95134 5.71334 7.71594 5.8428 7.52124C5.97227 7.32654 6.18065 7.21484 6.41449 7.21484H6.90603H7.04337V7.07751V6.63743V6.50009H6.90603H1.63734H1.5V6.63743V7.07751V7.21484H1.63734H1.96072C2.91554 7.21484 3.77116 7.78379 4.1405 8.66426L8.33049 18.6529C8.40379 18.8276 8.57372 18.9405 8.76347 18.9405C8.93762 18.9405 9.09139 18.849 9.17485 18.6958L9.72141 17.6928L11.8081 13.8635L13.8171 18.6528C13.8904 18.8275 14.0603 18.9404 14.2501 18.9404C14.4242 18.9404 14.578 18.849 14.6614 18.6958L15.208 17.6928L20.2703 8.40327C20.6698 7.67016 21.4366 7.21475 22.2715 7.21475H22.3627H22.5V7.07741V6.63734V6.5H22.3627V6.50009Z" fill="#222A30"/>
</svg>

After

Width:  |  Height:  |  Size: 1.5 KiB

View File

@ -0,0 +1,37 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from pydantic import BaseModel, Field
from typing import Any, Dict, List, Union
from langchain import WikipediaAPIWrapper
from langchain.tools import WikipediaQueryRun
class WikipediaInput(BaseModel):
query: str = Field(..., description="search query.")
class WikiPediaSearchTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
query = tool_paramters.get('query', '')
if not query:
return self.create_text_message('Please input query')
tool = WikipediaQueryRun(
name="wikipedia",
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
args_schema=WikipediaInput
)
result = tool.run(tool_input={
'query': query
})
return self.create_text_message(self.summary(user_id=user_id,content=result))

View File

@ -0,0 +1,24 @@
identity:
name: wikipedia_search
author: Dify
label:
en_US: WikipediaSearch
zh_Hans: 维基百科搜索
icon: icon.svg
description:
human:
en_US: A tool for performing a Wikipedia search and extracting snippets and webpages.
zh_Hans: 一个用于执行维基百科搜索并提取片段和网页的工具。
llm: A tool for performing a Wikipedia search and extracting snippets and webpages. Input should be a search query.
parameters:
- name: query
type: string
required: true
label:
en_US: Query string
zh_Hans: 查询语句
human_description:
en_US: key words for searching
zh_Hans: 查询关键词
llm_description: key words for searching
form: llm

View File

@ -0,0 +1,20 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.wikipedia.tools.wikipedia_search import WikiPediaSearchTool
class WikiPediaProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
WikiPediaSearchTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
"query": "misaka mikoto",
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,11 @@
identity:
author: Dify
name: wikipedia
label:
en_US: Wikipedia
zh_Hans: 维基百科
description:
en_US: Wikipedia is a free online encyclopedia, created and edited by volunteers around the world.
zh_Hans: 维基百科是一个由全世界的志愿者创建和编辑的免费在线百科全书。
icon: icon.svg
credentails_for_provider:

View File

@ -0,0 +1,23 @@
<svg xmlns="http://www.w3.org/2000/svg" width="22" height="23" viewBox="0 0 22 23" fill="none">
<path d="M21.4992 14.615L17.0326 15.5683L17.4865 20.0065L13.3037 18.2132L10.9994 22.0669L8.69549 18.2132L4.51225 20.0065L4.96656 15.5678L0.5 14.615L3.542 11.2832L0.5 7.9519L4.96656 6.99854L4.51225 2.55981L8.69549 4.35369L10.9994 0.5L13.3037 4.35369L17.4865 2.56031L17.0326 7.00401L21.4992 7.9519L18.4572 11.2832L21.4992 14.615Z" fill="#F16850"/>
<path d="M10.9993 7.23111L8.69495 4.35315L4.51221 2.56026L7.00396 5.84084L10.9993 7.23111Z" fill="#FD694F"/>
<path d="M4.96656 6.99847L0.5 7.95183L3.542 11.2831L7.11734 9.9838L4.96656 6.99847Z" fill="#FF3413"/>
<path d="M7.00346 5.84037L4.51221 2.55978L4.96602 6.99851L7.11729 9.98384L7.00346 5.84037Z" fill="#DC1D23"/>
<path d="M13.3031 4.35369L10.9987 0.5L8.69434 4.35369L10.9987 7.23116L13.3031 4.35369Z" fill="#FF9281"/>
<path d="M18.4577 11.2831L21.4997 7.95183L17.0331 6.99847L14.8818 9.9838L18.4577 11.2831Z" fill="#FF8B79"/>
<path d="M14.8823 9.98384L17.0331 6.99851L17.4929 2.55978L14.9957 5.84037L14.8818 9.98384H14.8823Z" fill="#FD694F"/>
<path d="M14.9954 5.84034L17.4926 2.55975L13.3044 4.35364L11 7.23111L14.9954 5.84034Z" fill="#EF5240"/>
<path d="M17.47 13.2694L21.4997 14.6149L18.4577 11.2831L14.8818 9.98383L17.47 13.2694Z" fill="#FF482C"/>
<path d="M7.11783 9.98383L3.542 11.2831L0.5 14.6149L4.52965 13.2699L7.11783 9.98383Z" fill="#EC2101"/>
<path d="M11 17.8612V22.0664L13.3044 18.2132L13.4008 14.439L11 17.8612Z" fill="#D21C22"/>
<path d="M17.4703 13.2693L13.4009 14.4389L17.0334 15.5682L21.4999 14.6149L17.4703 13.2693Z" fill="#C90901"/>
<path d="M13.3042 18.2132L17.4874 20.0065L17.0331 15.5683L13.4011 14.439L13.3042 18.2132Z" fill="#EC2101"/>
<path d="M4.52965 13.2693L0.5 14.6154L4.96656 15.5632L8.59906 14.4394L4.52965 13.2703V13.2693Z" fill="#B6171E"/>
<path d="M8.59912 14.439L8.69555 18.2132L10.9999 22.0669V17.8612L8.59912 14.439Z" fill="#B4151B"/>
<path d="M4.96602 15.5623L4.51221 20.006L8.69495 18.2131L8.59852 14.439L4.96602 15.5623Z" fill="#D21C22"/>
<path d="M14.882 9.98384L14.9954 5.84036L11 7.23113V11.2608L14.882 9.98384Z" fill="#E63320"/>
<path d="M11.0003 7.23113L7.00391 5.84036L7.11773 9.98384L10.9998 11.2608L11.0003 7.23113Z" fill="#FF4527"/>
<path d="M8.59912 14.439L10.9999 17.8613L13.4007 14.439L10.9999 11.2608L8.59912 14.439Z" fill="#FF9281"/>
<path d="M11 11.2608L13.4008 14.439L17.4702 13.2699L14.882 9.98386L11 11.2608Z" fill="#FD684D"/>
<path d="M7.1165 9.9839L4.52832 13.2694L8.59773 14.439L10.9985 11.2608L7.1165 9.9839Z" fill="#FD745C"/>
</svg>

After

Width:  |  Height:  |  Size: 2.5 KiB

View File

@ -0,0 +1,77 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.errors import ToolProviderCredentialValidationError, ToolInvokeError
from typing import Any, Dict, List, Union
from httpx import get
class WolframAlphaTool(BuiltinTool):
_base_url = 'https://api.wolframalpha.com/v2/query'
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
query = tool_paramters.get('query', '')
if not query:
return self.create_text_message('Please input query')
appid = self.runtime.credentials.get('appid', '')
if not appid:
raise ToolProviderCredentialValidationError('Please input appid')
params = {
'appid': appid,
'input': query,
'includepodid': 'Result',
'format': 'plaintext',
'output': 'json'
}
finished = False
result = None
# try 3 times at most
counter = 0
while not finished and counter < 3:
counter += 1
try:
response = get(self._base_url, params=params, timeout=20)
response.raise_for_status()
response_data = response.json()
except Exception as e:
raise ToolInvokeError(str(e))
if 'success' not in response_data['queryresult'] or response_data['queryresult']['success'] != True:
query_result = response_data.get('queryresult', {})
if 'error' in query_result and query_result['error']:
if 'msg' in query_result['error']:
if query_result['error']['msg'] == 'Invalid appid':
raise ToolProviderCredentialValidationError('Invalid appid')
raise ToolInvokeError('Failed to invoke tool')
if 'didyoumeans' in response_data['queryresult']:
# get the most likely interpretation
query = ''
max_score = 0
for didyoumean in response_data['queryresult']['didyoumeans']:
if float(didyoumean['score']) > max_score:
query = didyoumean['val']
max_score = float(didyoumean['score'])
params['input'] = query
else:
finished = True
if 'souces' in response_data['queryresult']:
return self.create_link_message(response_data['queryresult']['sources']['url'])
elif 'pods' in response_data['queryresult']:
result = response_data['queryresult']['pods'][0]['subpods'][0]['plaintext']
if not finished or not result:
return self.create_text_message('No result found')
return self.create_text_message(result)

View File

@ -0,0 +1,23 @@
identity:
name: wolframalpha
author: Dify
label:
en_US: WolframAlpha
zh_Hans: WolframAlpha
description:
human:
en_US: WolframAlpha is a powerful computational knowledge engine.
zh_Hans: WolframAlpha 是一个强大的计算知识引擎。
llm: WolframAlpha is a powerful computational knowledge engine. one single query can get the answer of a question.
parameters:
- name: query
type: string
required: true
label:
en_US: Query string
zh_Hans: 计算语句
human_description:
en_US: used for calculating
zh_Hans: 用于计算最终结果
llm_description: a single query for calculating
form: llm

View File

@ -0,0 +1,24 @@
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.tool.tool import Tool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.wolframalpha.tools.wolframalpha import WolframAlphaTool
from typing import Any, Dict, List
class GoogleProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try:
WolframAlphaTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
"query": "1+2+....+111",
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,24 @@
identity:
author: Dify
name: wolframalpha
label:
en_US: WolframAlpha
zh_Hans: WolframAlpha
description:
en_US: WolframAlpha is a powerful computational knowledge engine.
zh_Hans: WolframAlpha 是一个强大的计算知识引擎。
icon: icon.svg
credentails_for_provider:
appid:
type: secret-input
required: true
label:
en_US: WolframAlpha AppID
zh_Hans: WolframAlpha AppID
placeholder:
en_US: Please input your WolframAlpha AppID
zh_Hans: 请输入你的 WolframAlpha AppID
help:
en_US: Get your WolframAlpha AppID from WolframAlpha, please use "full results" api access.
zh_Hans: 从 WolframAlpha 获取您的 WolframAlpha AppID请使用 "full results" API。
url: https://products.wolframalpha.com/api

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.5 KiB

View File

@ -0,0 +1,69 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from typing import Any, Dict, List, Union
from requests.exceptions import HTTPError, ReadTimeout
from datetime import datetime
from yfinance import download
import pandas as pd
class YahooFinanceAnalyticsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
symbol = tool_paramters.get('symbol', '')
if not symbol:
return self.create_text_message('Please input symbol')
time_range = [None, None]
start_date = tool_paramters.get('start_date', '')
if start_date:
time_range[0] = start_date
else:
time_range[0] = '1800-01-01'
end_date = tool_paramters.get('end_date', '')
if end_date:
time_range[1] = end_date
else:
time_range[1] = datetime.now().strftime('%Y-%m-%d')
stock_data = download(symbol, start=time_range[0], end=time_range[1])
max_segments = min(15, len(stock_data))
rows_per_segment = len(stock_data) // max_segments
summary_data = []
for i in range(max_segments):
start_idx = i * rows_per_segment
end_idx = (i + 1) * rows_per_segment if i < max_segments - 1 else len(stock_data)
segment_data = stock_data.iloc[start_idx:end_idx]
segment_summary = {
'Start Date': segment_data.index[0],
'End Date': segment_data.index[-1],
'Average Close': segment_data['Close'].mean(),
'Average Volume': segment_data['Volume'].mean(),
'Average Open': segment_data['Open'].mean(),
'Average High': segment_data['High'].mean(),
'Average Low': segment_data['Low'].mean(),
'Average Adj Close': segment_data['Adj Close'].mean(),
'Max Close': segment_data['Close'].max(),
'Min Close': segment_data['Close'].min(),
'Max Volume': segment_data['Volume'].max(),
'Min Volume': segment_data['Volume'].min(),
'Max Open': segment_data['Open'].max(),
'Min Open': segment_data['Open'].min(),
'Max High': segment_data['High'].max(),
'Min High': segment_data['High'].min(),
}
summary_data.append(segment_summary)
summary_df = pd.DataFrame(summary_data)
try:
return self.create_text_message(str(summary_df.to_dict()))
except (HTTPError, ReadTimeout):
return self.create_text_message(f'There is a internet connection problem. Please try again later.')

View File

@ -0,0 +1,46 @@
identity:
name: yahoo_finance_analytics
author: Dify
label:
en_US: Analytics
zh_Hans: 分析
icon: icon.svg
description:
human:
en_US: A tool for get analytics about a ticker from Yahoo Finance.
zh_Hans: 一个用于从雅虎财经获取分析数据的工具。
llm: A tool for get analytics from Yahoo Finance. Input should be the ticker symbol like AAPL.
parameters:
- name: symbol
type: string
required: true
label:
en_US: Ticker symbol
zh_Hans: 股票代码
human_description:
en_US: The ticker symbol of the company you want to analyze.
zh_Hans: 你想要搜索的公司的股票代码。
llm_description: The ticker symbol of the company you want to analyze.
form: llm
- name: start_date
type: string
required: false
label:
en_US: Start date
zh_Hans: 开始日期
human_description:
en_US: The start date of the analytics.
zh_Hans: 分析的开始日期。
llm_description: The start date of the analytics, the format of the date must be YYYY-MM-DD like 2020-01-01.
form: llm
- name: end_date
type: string
required: false
label:
en_US: End date
zh_Hans: 结束日期
human_description:
en_US: The end date of the analytics.
zh_Hans: 分析的结束日期。
llm_description: The end date of the analytics, the format of the date must be YYYY-MM-DD like 2024-01-01.
form: llm

View File

@ -0,0 +1,46 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from typing import Any, Dict, List, Union
from requests.exceptions import HTTPError, ReadTimeout
import yfinance
class YahooFinanceSearchTickerTool(BuiltinTool):
def _invoke(self,user_id: str, tool_paramters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
'''
invoke tools
'''
query = tool_paramters.get('symbol', '')
if not query:
return self.create_text_message('Please input symbol')
try:
return self.run(ticker=query, user_id=user_id)
except (HTTPError, ReadTimeout):
return self.create_text_message(f'There is a internet connection problem. Please try again later.')
def run(self, ticker: str, user_id: str) -> ToolInvokeMessage:
company = yfinance.Ticker(ticker)
try:
if company.isin is None:
return self.create_text_message(f'Company ticker {ticker} not found.')
except (HTTPError, ReadTimeout, ConnectionError):
return self.create_text_message(f'Company ticker {ticker} not found.')
links = []
try:
links = [n['link'] for n in company.news if n['type'] == 'STORY']
except (HTTPError, ReadTimeout, ConnectionError):
if not links:
return self.create_text_message(f'There is nothing about {ticker} ticker')
if not links:
return self.create_text_message(f'No news found for company that searched with {ticker} ticker.')
result = '\n\n'.join([
self.get_url(link) for link in links
])
return self.create_text_message(self.summary(user_id=user_id, content=result))

View File

@ -0,0 +1,24 @@
identity:
name: yahoo_finance_news
author: Dify
label:
en_US: News
zh_Hans: 新闻
icon: icon.svg
description:
human:
en_US: A tool for get news about a ticker from Yahoo Finance.
zh_Hans: 一个用于从雅虎财经获取新闻的工具。
llm: A tool for get news from Yahoo Finance. Input should be the ticker symbol like AAPL.
parameters:
- name: symbol
type: string
required: true
label:
en_US: Ticker symbol
zh_Hans: 股票代码
human_description:
en_US: The ticker symbol of the company you want to search.
zh_Hans: 你想要搜索的公司的股票代码。
llm_description: The ticker symbol of the company you want to search.
form: llm

View File

@ -0,0 +1,25 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from typing import Any, Dict, List, Union
from requests.exceptions import HTTPError, ReadTimeout
from yfinance import Ticker
class YahooFinanceSearchTickerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
query = tool_paramters.get('symbol', '')
if not query:
return self.create_text_message('Please input symbol')
try:
return self.create_text_message(self.run(ticker=query))
except (HTTPError, ReadTimeout):
return self.create_text_message(f'There is a internet connection problem. Please try again later.')
def run(self, ticker: str) -> str:
return str(Ticker(ticker).info)

View File

@ -0,0 +1,24 @@
identity:
name: yahoo_finance_ticker
author: Dify
label:
en_US: Ticker
zh_Hans: 股票信息
icon: icon.svg
description:
human:
en_US: A tool for search ticker information from Yahoo Finance.
zh_Hans: 一个用于从雅虎财经搜索股票信息的工具。
llm: A tool for search ticker information from Yahoo Finance. Input should be the ticker symbol like AAPL.
parameters:
- name: symbol
type: string
required: true
label:
en_US: Ticker symbol
zh_Hans: 股票代码
human_description:
en_US: The ticker symbol of the company you want to search.
zh_Hans: 你想要搜索的公司的股票代码。
llm_description: The ticker symbol of the company you want to search.
form: llm

View File

@ -0,0 +1,20 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.yahoo.tools.ticker import YahooFinanceSearchTickerTool
class YahooFinanceProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
YahooFinanceSearchTickerTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
"ticker": "MSFT",
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,11 @@
identity:
author: Dify
name: yahoo
label:
en_US: YahooFinance
zh_Hans: 雅虎财经
description:
en_US: Finance, and Yahoo! get the latest news, stock quotes, and interactive chart with Yahoo!
zh_Hans: 雅虎财经,获取并整理出最新的新闻、股票报价等一切你想要的财经信息。
icon: icon.png
credentails_for_provider:

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 KiB

View File

@ -0,0 +1,66 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
from typing import Any, Dict, List, Union
from datetime import datetime
from googleapiclient.discovery import build
class YoutubeVideosAnalyticsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
channel = tool_paramters.get('channel', '')
if not channel:
return self.create_text_message('Please input symbol')
time_range = [None, None]
start_date = tool_paramters.get('start_date', '')
if start_date:
time_range[0] = start_date
else:
time_range[0] = '1800-01-01'
end_date = tool_paramters.get('end_date', '')
if end_date:
time_range[1] = end_date
else:
time_range[1] = datetime.now().strftime('%Y-%m-%d')
if 'google_api_key' not in self.runtime.credentials or not self.runtime.credentials['google_api_key']:
return self.create_text_message('Please input api key')
youtube = build('youtube', 'v3', developerKey=self.runtime.credentials['google_api_key'])
# try to get channel id
search_results = youtube.search().list(q='mrbeast', type='channel', order='relevance', part='id').execute()
channel_id = search_results['items'][0]['id']['channelId']
start_date, end_date = time_range
start_date = datetime.strptime(start_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ')
end_date = datetime.strptime(end_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ')
# get videos
time_range_videos = youtube.search().list(
part='snippet', channelId=channel_id, order='date', type='video',
publishedAfter=start_date,
publishedBefore=end_date
).execute()
def extract_video_data(video_list):
data = []
for video in video_list['items']:
video_id = video['id']['videoId']
video_info = youtube.videos().list(part='snippet,statistics', id=video_id).execute()
title = video_info['items'][0]['snippet']['title']
views = video_info['items'][0]['statistics']['viewCount']
data.append({'Title': title, 'Views': views})
return data
summary = extract_video_data(time_range_videos)
return self.create_text_message(str(summary))

View File

@ -0,0 +1,46 @@
identity:
name: youtube_video_statistics
author: Dify
label:
en_US: Video statistics
zh_Hans: 视频统计
icon: icon.svg
description:
human:
en_US: A tool for get statistics about a channel's videos.
zh_Hans: 一个用于获取油管频道视频统计数据的工具。
llm: A tool for get statistics about a channel's videos. Input should be the name of the channel like PewDiePie.
parameters:
- name: channel
type: string
required: true
label:
en_US: Channel name
zh_Hans: 频道名
human_description:
en_US: The name of the channel you want to search.
zh_Hans: 你想要搜索的油管频道名。
llm_description: The name of the channel you want to search.
form: llm
- name: start_date
type: string
required: false
label:
en_US: Start date
zh_Hans: 开始日期
human_description:
en_US: The start date of the analytics.
zh_Hans: 分析的开始日期。
llm_description: The start date of the analytics, the format of the date must be YYYY-MM-DD like 2020-01-01.
form: llm
- name: end_date
type: string
required: false
label:
en_US: End date
zh_Hans: 结束日期
human_description:
en_US: The end date of the analytics.
zh_Hans: 分析的结束日期。
llm_description: The end date of the analytics, the format of the date must be YYYY-MM-DD like 2024-01-01.
form: llm

View File

@ -0,0 +1,22 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.youtube.tools.videos import YoutubeVideosAnalyticsTool
class YahooFinanceProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
YoutubeVideosAnalyticsTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_paramters={
"channel": "TOKYO GIRLS COLLECTION",
"start_date": "2020-01-01",
"end_date": "2024-12-31",
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,24 @@
identity:
author: Dify
name: youtube
label:
en_US: Youtube
zh_Hans: Youtube
description:
en_US: Youtube
zh_Hans: Youtube油管是全球最大的视频分享网站用户可以在上面上传、观看和分享视频。
icon: icon.png
credentails_for_provider:
google_api_key:
type: secret-input
required: true
label:
en_US: Google API key
zh_Hans: Google API key
placeholder:
en_US: Please input your Google API key
zh_Hans: 请输入你的 Google API key
help:
en_US: Get your Google API key from Google
zh_Hans: 从 Google 获取您的 Google API key
url: https://console.developers.google.com/apis/credentials

View File

@ -0,0 +1,286 @@
from abc import abstractmethod
from typing import List, Dict, Any
from os import path, listdir
from yaml import load, FullLoader
from core.tools.entities.tool_entities import ToolProviderType, \
ToolParamter, ToolProviderCredentials
from core.tools.tool.tool import Tool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError, \
ToolParamterValidationError, ToolProviderCredentialValidationError
import importlib
class BuiltinToolProviderController(ToolProviderController):
def __init__(self, **data: Any) -> None:
if self.app_type == ToolProviderType.API_BASED or self.app_type == ToolProviderType.APP_BASED:
super().__init__(**data)
return
# load provider yaml
provider = self.__class__.__module__.split('.')[-1]
yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml')
try:
with open(yaml_path, 'r') as f:
provider_yaml = load(f.read(), FullLoader)
except:
raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}')
if 'credentails_for_provider' in provider_yaml and provider_yaml['credentails_for_provider'] is not None:
# set credentials name
for credential_name in provider_yaml['credentails_for_provider']:
provider_yaml['credentails_for_provider'][credential_name]['name'] = credential_name
super().__init__(**{
'identity': provider_yaml['identity'],
'credentials_schema': provider_yaml['credentails_for_provider'] if 'credentails_for_provider' in provider_yaml else None,
})
def _get_bulitin_tools(self) -> List[Tool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
if self.tools:
return self.tools
provider = self.identity.name
tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools")
# get all the yaml files in the tool path
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
tools = []
for tool_file in tool_files:
with open(path.join(tool_path, tool_file), "r") as f:
# get tool name
tool_name = tool_file.split(".")[0]
tool = load(f.read(), FullLoader)
# get tool class, import the module
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, 'tools', f'{tool_name}.py')
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.tools.{tool_name}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
# get all the classes in the module
classes = [x for _, x in vars(mod).items()
if isinstance(x, type) and x not in [BuiltinTool, Tool] and issubclass(x, BuiltinTool)
]
assistant_tool_class = classes[0]
tools.append(assistant_tool_class(**tool))
self.tools = tools
return tools
def get_credentails_schema(self) -> Dict[str, ToolProviderCredentials]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
if not self.credentials_schema:
return {}
return self.credentials_schema.copy()
def user_get_credentails_schema(self) -> UserToolProviderCredentials:
"""
returns the credentials schema of the provider, this method is used for user
:return: the credentials schema
"""
credentials = self.credentials_schema.copy()
return UserToolProviderCredentials(credentails=credentials)
def get_tools(self) -> List[Tool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
return self._get_bulitin_tools()
def get_tool(self, tool_name: str) -> Tool:
"""
returns the tool that the provider can provide
"""
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
def get_parameters(self, tool_name: str) -> List[ToolParamter]:
"""
returns the parameters of the tool
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
if tool is None:
raise ToolNotFoundError(f'tool {tool_name} not found')
return tool.parameters
@property
def need_credentials(self) -> bool:
"""
returns whether the provider needs credentials
:return: whether the provider needs credentials
"""
return self.credentials_schema is not None and len(self.credentials_schema) != 0
@property
def app_type(self) -> ToolProviderType:
"""
returns the type of the provider
:return: type of the provider
"""
return ToolProviderType.BUILT_IN
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
"""
validate the parameters of the tool and set the default value if needed
:param tool_name: the name of the tool, defined in `get_tools`
:param tool_parameters: the parameters of the tool
"""
tool_parameters_schema = self.get_parameters(tool_name)
tool_parameters_need_to_validate: Dict[str, ToolParamter] = {}
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate:
raise ToolParamterValidationError(f'parameter {parameter} not found in tool {tool_name}')
# check type
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.type == ToolParamter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str):
raise ToolParamterValidationError(f'parameter {parameter} should be string')
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], (int, float)):
raise ToolParamterValidationError(f'parameter {parameter} should be number')
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
raise ToolParamterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
raise ToolParamterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool):
raise ToolParamterValidationError(f'parameter {parameter} should be boolean')
elif parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str):
raise ToolParamterValidationError(f'parameter {parameter} should be string')
options = parameter_schema.options
if not isinstance(options, list):
raise ToolParamterValidationError(f'parameter {parameter} options should be list')
if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParamterValidationError(f'parameter {parameter} should be one of {options}')
tool_parameters_need_to_validate.pop(parameter)
for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.required:
raise ToolParamterValidationError(f'parameter {parameter} is required')
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
default_value = parameter_schema.default
# parse default value into the correct type
if parameter_schema.type == ToolParamter.ToolParameterType.STRING or \
parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
default_value = str(default_value)
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
default_value = float(default_value)
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
default_value = bool(default_value)
tool_parameters[parameter] = default_value
def validate_credentials_format(self, credentials: Dict[str, Any]) -> None:
"""
validate the format of the credentials of the provider and set the default value if needed
:param credentials: the credentials of the tool
"""
credentials_schema = self.credentials_schema
if credentials_schema is None:
return
credentials_need_to_validate: Dict[str, ToolProviderCredentials] = {}
for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')
# check type
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
options = credential_schema.options
if not isinstance(options, list):
raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list')
if credentials[credential_name] not in [x.value for x in options]:
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}')
credentials_need_to_validate.pop(credential_name)
for credential_name in credentials_need_to_validate:
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema.required:
raise ToolProviderCredentialValidationError(f'credential {credential_name} is required')
# the credential is not set currently, set the default value if needed
if credential_schema.default is not None:
default_value = credential_schema.default
# parse default value into the correct type
if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
default_value = str(default_value)
credentials[credential_name] = default_value
def validate_credentials(self, credentials: Dict[str, Any]) -> None:
"""
validate the credentials of the provider
:param tool_name: the name of the tool, defined in `get_tools`
:param credentials: the credentials of the tool
"""
# validate credentials format
self.validate_credentials_format(credentials)
# validate credentials
self._validate_credentials(credentials)
@abstractmethod
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
"""
validate the credentials of the provider
:param tool_name: the name of the tool, defined in `get_tools`
:param credentials: the credentials of the tool
"""
pass

View File

@ -0,0 +1,218 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolProviderType, \
ToolProviderIdentity, ToolParamter, ToolProviderCredentials
from core.tools.tool.tool import Tool
from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import ToolNotFoundError, \
ToolParamterValidationError, ToolProviderCredentialValidationError
class ToolProviderController(BaseModel, ABC):
identity: Optional[ToolProviderIdentity] = None
tools: Optional[List[Tool]] = None
credentials_schema: Optional[Dict[str, ToolProviderCredentials]] = None
def get_credentails_schema(self) -> Dict[str, ToolProviderCredentials]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
return self.credentials_schema.copy()
def user_get_credentails_schema(self) -> UserToolProviderCredentials:
"""
returns the credentials schema of the provider, this method is used for user
:return: the credentials schema
"""
credentials = self.credentials_schema.copy()
return UserToolProviderCredentials(credentails=credentials)
@abstractmethod
def get_tools(self) -> List[Tool]:
"""
returns a list of tools that the provider can provide
:return: list of tools
"""
pass
@abstractmethod
def get_tool(self, tool_name: str) -> Tool:
"""
returns a tool that the provider can provide
:return: tool
"""
pass
def get_parameters(self, tool_name: str) -> List[ToolParamter]:
"""
returns the parameters of the tool
:param tool_name: the name of the tool, defined in `get_tools`
:return: list of parameters
"""
tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
if tool is None:
raise ToolNotFoundError(f'tool {tool_name} not found')
return tool.parameters
@property
def app_type(self) -> ToolProviderType:
"""
returns the type of the provider
:return: type of the provider
"""
return ToolProviderType.BUILT_IN
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: Dict[str, Any]) -> None:
"""
validate the parameters of the tool and set the default value if needed
:param tool_name: the name of the tool, defined in `get_tools`
:param tool_parameters: the parameters of the tool
"""
tool_parameters_schema = self.get_parameters(tool_name)
tool_parameters_need_to_validate: Dict[str, ToolParamter] = {}
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate:
raise ToolParamterValidationError(f'parameter {parameter} not found in tool {tool_name}')
# check type
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.type == ToolParamter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str):
raise ToolParamterValidationError(f'parameter {parameter} should be string')
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], (int, float)):
raise ToolParamterValidationError(f'parameter {parameter} should be number')
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
raise ToolParamterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
raise ToolParamterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool):
raise ToolParamterValidationError(f'parameter {parameter} should be boolean')
elif parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str):
raise ToolParamterValidationError(f'parameter {parameter} should be string')
options = parameter_schema.options
if not isinstance(options, list):
raise ToolParamterValidationError(f'parameter {parameter} options should be list')
if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParamterValidationError(f'parameter {parameter} should be one of {options}')
tool_parameters_need_to_validate.pop(parameter)
for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.required:
raise ToolParamterValidationError(f'parameter {parameter} is required')
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
default_value = parameter_schema.default
# parse default value into the correct type
if parameter_schema.type == ToolParamter.ToolParameterType.STRING or \
parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
default_value = str(default_value)
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
default_value = float(default_value)
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
default_value = bool(default_value)
tool_parameters[parameter] = default_value
def validate_credentials_format(self, credentials: Dict[str, Any]) -> None:
"""
validate the format of the credentials of the provider and set the default value if needed
:param credentials: the credentials of the tool
"""
credentials_schema = self.credentials_schema
if credentials_schema is None:
return
credentials_need_to_validate: Dict[str, ToolProviderCredentials] = {}
for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')
# check type
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
options = credential_schema.options
if not isinstance(options, list):
raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list')
if credentials[credential_name] not in [x.value for x in options]:
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}')
credentials_need_to_validate.pop(credential_name)
for credential_name in credentials_need_to_validate:
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema.required:
raise ToolProviderCredentialValidationError(f'credential {credential_name} is required')
# the credential is not set currently, set the default value if needed
if credential_schema.default is not None:
default_value = credential_schema.default
# parse default value into the correct type
if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
default_value = str(default_value)
credentials[credential_name] = default_value
def validate_credentials(self, credentials: Dict[str, Any]) -> None:
"""
validate the credentials of the provider
:param tool_name: the name of the tool, defined in `get_tools`
:param credentials: the credentials of the tool
"""
# validate credentials format
self.validate_credentials_format(credentials)
# validate credentials
self._validate_credentials(credentials)
@abstractmethod
def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
"""
validate the credentials of the provider
:param tool_name: the name of the tool, defined in `get_tools`
:param credentials: the credentials of the tool
"""
pass

View File

@ -0,0 +1,222 @@
from typing import Any, Dict, List, Union
from json import dumps
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.tool import Tool
from core.tools.errors import ToolProviderCredentialValidationError
import httpx
import requests
class ApiTool(Tool):
api_bundle: ApiBasedToolBundle
"""
Api tool
"""
def fork_tool_runtime(self, meta: Dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
:param meta: the meta data of a tool call processing, tenant_id is required
:return: the new tool
"""
return self.__class__(
identity=self.identity.copy() if self.identity else None,
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
runtime=Tool.Runtime(**meta)
)
def validate_credentials(self, credentails: Dict[str, Any], parameters: Dict[str, Any], format_only: bool = False) -> None:
"""
validate the credentials for Api tool
"""
# assemble validate request and request parameters
headers = self.assembling_request(parameters)
if format_only:
return
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
# validate response
self.validate_and_parse_response(response)
def assembling_request(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
headers = {}
credentials = self.runtime.credentials or {}
if 'auth_type' not in credentials:
raise ToolProviderCredentialValidationError('Missing auth_type')
if credentials['auth_type'] == 'api_key':
api_key_header = 'api_key'
if 'api_key_header' in credentials:
api_key_header = credentials['api_key_header']
if 'api_key_value' not in credentials:
raise ToolProviderCredentialValidationError('Missing api_key_value')
headers[api_key_header] = credentials['api_key_value']
needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
for parameter in needed_parameters:
if parameter.required and parameter.name not in parameters:
raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter.name}")
if parameter.default is not None and parameter.name not in parameters:
parameters[parameter.name] = parameter.default
return headers
def validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> str:
"""
validate the response
"""
if isinstance(response, httpx.Response):
if response.status_code >= 400:
raise ToolProviderCredentialValidationError(f"Request failed with status code {response.status_code}")
return response.text
elif isinstance(response, requests.Response):
if not response.ok:
raise ToolProviderCredentialValidationError(f"Request failed with status code {response.status_code}")
return response.text
else:
raise ValueError(f'Invalid response type {type(response)}')
def do_http_request(self, url: str, method: str, headers: Dict[str, Any], parameters: Dict[str, Any]) -> httpx.Response:
"""
do http request depending on api bundle
"""
method = method.lower()
params = {}
path_params = {}
body = {}
cookies = {}
# check parameters
for parameter in self.api_bundle.openapi.get('parameters', []):
if parameter['in'] == 'path':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter['required']:
raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")
path_params[parameter['name']] = value
elif parameter['in'] == 'query':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter['required']:
raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")
params[parameter['name']] = value
elif parameter['in'] == 'cookie':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter['required']:
raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")
cookies[parameter['name']] = value
elif parameter['in'] == 'header':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter['required']:
raise ToolProviderCredentialValidationError(f"Missing required parameter {parameter['name']}")
headers[parameter['name']] = value
# check if there is a request body and handle it
if 'requestBody' in self.api_bundle.openapi and self.api_bundle.openapi['requestBody'] is not None:
# handle json request body
if 'content' in self.api_bundle.openapi['requestBody']:
for content_type in self.api_bundle.openapi['requestBody']['content']:
headers['Content-Type'] = content_type
body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema']
required = body_schema['required'] if 'required' in body_schema else []
properties = body_schema['properties'] if 'properties' in body_schema else {}
for name, property in properties.items():
if name in parameters:
# convert type
try:
value = parameters[name]
if property['type'] == 'integer':
value = int(value)
elif property['type'] == 'number':
# check if it is a float
if '.' in value:
value = float(value)
else:
value = int(value)
elif property['type'] == 'boolean':
value = bool(value)
body[name] = value
except ValueError as e:
body[name] = parameters[name]
elif name in required:
raise ToolProviderCredentialValidationError(
f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
)
elif 'default' in property:
body[name] = property['default']
else:
body[name] = None
break
# replace path parameters
for name, value in path_params.items():
url = url.replace(f'{{{name}}}', value)
# parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
if 'Content-Type' in headers:
if headers['Content-Type'] == 'application/json':
body = dumps(body)
else:
body = body
# do http request
if method == 'get':
response = httpx.get(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True)
elif method == 'post':
response = httpx.post(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True)
elif method == 'put':
response = httpx.put(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True)
elif method == 'delete':
"""
request body data is unsupported for DELETE method in standard http protocol
however, OpenAPI 3.0 supports request body data for DELETE method, so we support it here by using requests
"""
response = requests.delete(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, allow_redirects=True)
elif method == 'patch':
response = httpx.patch(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True)
elif method == 'head':
response = httpx.head(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True)
elif method == 'options':
response = httpx.options(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True)
else:
raise ValueError(f'Invalid http method {method}')
return response
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> ToolInvokeMessage | List[ToolInvokeMessage]:
"""
invoke http request
"""
# assemble request
headers = self.assembling_request(tool_paramters)
# do http request
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_paramters)
# validate response
response = self.validate_and_parse_response(response)
# assemble invoke message
return self.create_text_message(response)

View File

@ -0,0 +1,140 @@
from core.tools.tool.tool import Tool
from core.tools.model.tool_model_manager import ToolModelManager
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.tools.utils.web_reader_tool import get_url
from typing import List
from enum import Enum
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
retain the original meaning and keep the key points.
however, the text you got is too long, what you got is possible a part of the text.
Please summarize the text you got.
"""
class BuiltinTool(Tool):
"""
Builtin tool
:param meta: the meta data of a tool call processing
"""
def invoke_model(
self, user_id: str, prompt_messages: List[PromptMessage], stop: List[str]
) -> LLMResult:
"""
invoke model
:param model_config: the model config
:param prompt_messages: the prompt messages
:param stop: the stop words
:return: the model result
"""
# invoke model
return ToolModelManager.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id,
tool_type='builtin',
tool_name=self.identity.name,
prompt_messages=prompt_messages,
)
def get_max_tokens(self) -> int:
"""
get max tokens
:param model_config: the model config
:return: the max tokens
"""
return ToolModelManager.get_max_llm_context_tokens(
tenant_id=self.runtime.tenant_id,
)
def get_prompt_tokens(self, prompt_messages: List[PromptMessage]) -> int:
"""
get prompt tokens
:param prompt_messages: the prompt messages
:return: the tokens
"""
return ToolModelManager.calculate_tokens(
tenant_id=self.runtime.tenant_id,
prompt_messages=prompt_messages
)
def summary(self, user_id: str, content: str) -> str:
max_tokens = self.get_max_tokens()
if self.get_prompt_tokens(prompt_messages=[
UserPromptMessage(content=content)
]) < max_tokens * 0.6:
return content
def get_prompt_tokens(content: str) -> int:
return self.get_prompt_tokens(prompt_messages=[
SystemPromptMessage(content=_SUMMARY_PROMPT),
UserPromptMessage(content=content)
])
def summarize(content: str) -> str:
summary = self.invoke_model(user_id=user_id, prompt_messages=[
SystemPromptMessage(content=_SUMMARY_PROMPT),
UserPromptMessage(content=content)
], stop=[])
return summary.message.content
lines = content.split('\n')
new_lines = []
# split long line into multiple lines
for i in range(len(lines)):
line = lines[i]
if not line.strip():
continue
if len(line) < max_tokens * 0.5:
new_lines.append(line)
elif get_prompt_tokens(line) > max_tokens * 0.7:
while get_prompt_tokens(line) > max_tokens * 0.7:
new_lines.append(line[:int(max_tokens * 0.5)])
line = line[int(max_tokens * 0.5):]
new_lines.append(line)
else:
new_lines.append(line)
# merge lines into messages with max tokens
messages: List[str] = []
for i in new_lines:
if len(messages) == 0:
messages.append(i)
else:
if len(messages[-1]) + len(i) < max_tokens * 0.5:
messages[-1] += i
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
messages.append(i)
else:
messages[-1] += i
summaries = []
for i in range(len(messages)):
message = messages[i]
summary = summarize(message)
summaries.append(summary)
result = '\n'.join(summaries)
if self.get_prompt_tokens(prompt_messages=[
UserPromptMessage(content=result)
]) > max_tokens * 0.7:
return self.summary(user_id=user_id, content=result)
return result
def get_url(self, url: str, user_agent: str = None) -> str:
"""
get url
"""
return get_url(url, user_agent=user_agent)

View File

@ -0,0 +1,249 @@
import json
import threading
from typing import List, Optional, Type
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rerank.rerank import RerankRunner
from extensions.ext_database import db
from flask import Flask, current_app
from langchain.tools import BaseTool
from models.dataset import Dataset, Document, DocumentSegment
from pydantic import BaseModel, Field
from services.retrieval_service import RetrievalService
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
class DatasetMultiRetrieverToolInput(BaseModel):
query: str = Field(..., description="dataset multi retriever and rerank")
class DatasetMultiRetrieverTool(BaseTool):
"""Tool for querying multi dataset."""
name: str = "dataset-"
args_schema: Type[BaseModel] = DatasetMultiRetrieverToolInput
description: str = "dataset multi retriever and rerank. "
tenant_id: str
dataset_ids: List[str]
top_k: int = 2
score_threshold: Optional[float] = None
reranking_provider_name: str
reranking_model_name: str
return_resource: bool
retriever_from: str
hit_callbacks: List[DatasetIndexToolCallbackHandler] = []
@classmethod
def from_dataset(cls, dataset_ids: List[str], tenant_id: str, **kwargs):
return cls(
name=f'dataset-{tenant_id}',
tenant_id=tenant_id,
dataset_ids=dataset_ids,
**kwargs
)
def _run(self, query: str) -> str:
threads = []
all_documents = []
for dataset_id in self.dataset_ids:
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'all_documents': all_documents,
'hit_callbacks': self.hit_callbacks
})
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
# do rerank for searched documents
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=self.reranking_provider_name,
model_type=ModelType.RERANK,
model=self.reranking_model_name
)
rerank_runner = RerankRunner(rerank_model_instance)
all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(all_documents)
document_score_list = {}
for item in all_documents:
if 'score' in item.metadata and item.metadata['score']:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids)
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(segments,
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
float('inf')))
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
else:
document_context_list.append(segment.content)
if self.return_resource:
context_list = []
resource_number = 1
for segment in sorted_segments:
dataset = Dataset.query.filter_by(
id=segment.dataset_id
).first()
document = Document.query.filter(Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
'position': resource_number,
'dataset_id': dataset.id,
'dataset_name': dataset.name,
'document_id': document.id,
'document_name': document.name,
'data_source_type': document.data_source_type,
'segment_id': segment.id,
'retriever_from': self.retriever_from,
'score': document_score_list.get(segment.index_node_id, None)
}
if self.retriever_from == 'dev':
source['hit_count'] = segment.hit_count
source['word_count'] = segment.word_count
source['segment_position'] = segment.position
source['index_node_hash'] = segment.index_node_hash
if segment.answer:
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
else:
source['content'] = segment.content
context_list.append(source)
resource_number += 1
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List,
hit_callbacks: List[DatasetIndexToolCallbackHandler]):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
return []
for hit_callback in hit_callbacks:
hit_callback.on_query(query, dataset.id)
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=5
)
)
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
if documents:
all_documents.extend(documents)
else:
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except LLMBadRequestError:
return []
except ProviderTokenNotInitError:
return []
embeddings = CacheEmbedding(embedding_model)
documents = []
threads = []
if self.top_k > 0:
# retrieval_model source with semantic
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[
'search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': str(dataset.id),
'query': query,
'top_k': self.top_k,
'score_threshold': self.score_threshold,
'reranking_model': None,
'all_documents': documents,
'search_method': 'hybrid_search',
'embeddings': embeddings
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval_model source with full text
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[
'search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': str(dataset.id),
'query': query,
'search_method': 'hybrid_search',
'embeddings': embeddings,
'score_threshold': retrieval_model[
'score_threshold'] if retrieval_model[
'score_threshold_enabled'] else None,
'top_k': self.top_k,
'reranking_model': retrieval_model[
'reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
all_documents.extend(documents)

View File

@ -0,0 +1,236 @@
import threading
from typing import List, Optional, Type
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rerank.rerank import RerankRunner
from extensions.ext_database import db
from flask import current_app
from langchain.tools import BaseTool
from models.dataset import Dataset, Document, DocumentSegment
from pydantic import BaseModel, Field
from services.retrieval_service import RetrievalService
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
class DatasetRetrieverToolInput(BaseModel):
query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
class DatasetRetrieverTool(BaseTool):
"""Tool for querying a Dataset."""
name: str = "dataset"
args_schema: Type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
tenant_id: str
dataset_id: str
top_k: int = 2
score_threshold: Optional[float] = None
hit_callbacks: List[DatasetIndexToolCallbackHandler] = []
return_resource: bool
retriever_from: str
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
description = description.replace('\n', '').replace('\r', '')
return cls(
name=f'dataset-{dataset.id}',
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
**kwargs
)
def _run(self, query: str) -> str:
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == self.dataset_id
).first()
if not dataset:
return ''
for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=5
)
)
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
return str("\n".join([document.page_content for document in documents]))
else:
# get embedding model instance
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except InvokeAuthorizationError:
return ''
embeddings = CacheEmbedding(embedding_model)
documents = []
threads = []
if self.top_k > 0:
# retrieval source with semantic
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': str(dataset.id),
'query': query,
'top_k': self.top_k,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enabled'] else None,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings
})
threads.append(embedding_thread)
embedding_thread.start()
# retrieval_model source with full text
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': str(dataset.id),
'query': query,
'search_method': retrieval_model['search_method'],
'embeddings': embeddings,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enabled'] else None,
'top_k': self.top_k,
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
'reranking_enable'] else None,
'all_documents': documents
})
threads.append(full_text_index_thread)
full_text_index_thread.start()
for thread in threads:
thread.join()
# hybrid search: rerank after all documents have been searched
if retrieval_model['search_method'] == 'hybrid_search':
# get rerank model instance
try:
model_manager = ModelManager()
rerank_model_instance = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=retrieval_model['reranking_model']['reranking_provider_name'],
model_type=ModelType.RERANK,
model=retrieval_model['reranking_model']['reranking_model_name']
)
except InvokeAuthorizationError:
return ''
rerank_runner = RerankRunner(rerank_model_instance)
documents = rerank_runner.run(
query=query,
documents=documents,
score_threshold=retrieval_model['score_threshold'] if retrieval_model[
'score_threshold_enabled'] else None,
top_n=self.top_k
)
else:
documents = []
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if 'score' in item.metadata and item.metadata['score']:
document_score_list[item.metadata['doc_id']] = item.metadata['score']
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in documents]
segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids)
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(segments,
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
float('inf')))
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
else:
document_context_list.append(segment.content)
if self.return_resource:
context_list = []
resource_number = 1
for segment in sorted_segments:
context = {}
document = Document.query.filter(Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
'position': resource_number,
'dataset_id': dataset.id,
'dataset_name': dataset.name,
'document_id': document.id,
'document_name': document.name,
'data_source_type': document.data_source_type,
'segment_id': segment.id,
'retriever_from': self.retriever_from,
'score': document_score_list.get(segment.index_node_id, None)
}
if self.retriever_from == 'dev':
source['hit_count'] = segment.hit_count
source['word_count'] = segment.word_count
source['segment_position'] = segment.position
source['index_node_hash'] = segment.index_node_hash
if segment.answer:
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
else:
source['content'] = segment.content
context_list.append(source)
resource_number += 1
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
async def _arun(self, tool_input: str) -> str:
raise NotImplementedError()

View File

@ -0,0 +1,95 @@
from typing import Any, Dict, List, Union
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter, ToolIdentity, ToolDescription
from core.tools.tool.tool import Tool
from core.tools.entities.common_entities import I18nObject
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
from langchain.tools import BaseTool
class DatasetRetrieverTool(Tool):
langchain_tool: BaseTool
@staticmethod
def get_dataset_tools(tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler
) -> List['DatasetRetrieverTool']:
"""
get dataset tool
"""
# check if retrieve_config is valid
if dataset_ids is None or len(dataset_ids) == 0:
return []
if retrieve_config is None:
return []
feature = DatasetRetrievalFeature()
# save original retrieve strategy, and set retrieve strategy to SINGLE
# Agent only support SINGLE mode
original_retriever_mode = retrieve_config.retrieve_strategy
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
langchain_tools = feature.to_dataset_retriever_tool(
tenant_id=tenant_id,
dataset_ids=dataset_ids,
retrieve_config=retrieve_config,
return_resource=return_resource,
invoke_from=invoke_from,
hit_callback=hit_callback
)
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
# convert langchain tools to Tools
tools = []
for langchain_tool in langchain_tools:
tool = DatasetRetrieverTool(
langchain_tool=langchain_tool,
identity=ToolIdentity(author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
parameters=[],
is_team_authorization=True,
description=ToolDescription(
human=I18nObject(en_US='', zh_Hans=''),
llm=langchain_tool.description),
runtime=DatasetRetrieverTool.Runtime()
)
tools.append(tool)
return tools
def get_runtime_parameters(self) -> List[ToolParamter]:
return [
ToolParamter(name='query',
label=I18nObject(en_US='', zh_Hans=''),
human_description=I18nObject(en_US='', zh_Hans=''),
type=ToolParamter.ToolParameterType.STRING,
form=ToolParamter.ToolParameterForm.LLM,
llm_description='Query for the dataset to be used to retrieve the dataset.',
required=True,
default=''),
]
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> ToolInvokeMessage | List[ToolInvokeMessage]:
"""
invoke dataset retriever tool
"""
query = tool_paramters.get('query', None)
if not query:
return self.create_text_message(text='please input query')
# invoke dataset retriever tool
result = self.langchain_tool._run(query=query)
return self.create_text_message(text=result)
def validate_credentials(self, credentails: Dict[str, Any], parameters: Dict[str, Any]) -> None:
"""
validate the credentials for dataset retriever tool
"""
pass

302
api/core/tools/tool/tool.py Normal file
View File

@ -0,0 +1,302 @@
from pydantic import BaseModel
from typing import List, Dict, Any, Union, Optional
from abc import abstractmethod, ABC
from enum import Enum
from core.tools.entities.tool_entities import ToolIdentity, ToolInvokeMessage,\
ToolParamter, ToolDescription, ToolRuntimeVariablePool, ToolRuntimeVariable, ToolRuntimeImageVariable
from core.tools.tool_file_manager import ToolFileManager
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
class Tool(BaseModel, ABC):
identity: ToolIdentity = None
parameters: Optional[List[ToolParamter]] = None
description: ToolDescription = None
is_team_authorization: bool = False
agent_callback: Optional[DifyAgentCallbackHandler] = None
use_callback: bool = False
class Runtime(BaseModel):
"""
Meta data of a tool call processing
"""
def __init__(self, **data: Any):
super().__init__(**data)
if not self.runtime_parameters:
self.runtime_parameters = {}
tenant_id: str = None
tool_id: str = None
credentials: Dict[str, Any] = None
runtime_parameters: Dict[str, Any] = None
runtime: Runtime = None
variables: ToolRuntimeVariablePool = None
def __init__(self, **data: Any):
super().__init__(**data)
if not self.agent_callback:
self.use_callback = False
else:
self.use_callback = True
class VARIABLE_KEY(Enum):
IMAGE = 'image'
def fork_tool_runtime(self, meta: Dict[str, Any], agent_callback: DifyAgentCallbackHandler = None) -> 'Tool':
"""
fork a new tool with meta data
:param meta: the meta data of a tool call processing, tenant_id is required
:return: the new tool
"""
return self.__class__(
identity=self.identity.copy() if self.identity else None,
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
runtime=Tool.Runtime(**meta),
agent_callback=agent_callback
)
def load_variables(self, variables: ToolRuntimeVariablePool):
"""
load variables from database
:param conversation_id: the conversation id
"""
self.variables = variables
def set_image_variable(self, variable_name: str, image_key: str) -> None:
"""
set an image variable
"""
if not self.variables:
return
self.variables.set_file(self.identity.name, variable_name, image_key)
def set_text_variable(self, variable_name: str, text: str) -> None:
"""
set a text variable
"""
if not self.variables:
return
self.variables.set_text(self.identity.name, variable_name, text)
def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
"""
get a variable
:param name: the name of the variable
:return: the variable
"""
if not self.variables:
return None
if isinstance(name, Enum):
name = name.value
for variable in self.variables.pool:
if variable.name == name:
return variable
return None
def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
"""
get the default image variable
:return: the image variable
"""
if not self.variables:
return None
return self.get_variable(self.VARIABLE_KEY.IMAGE)
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
"""
get a variable file
:param name: the name of the variable
:return: the variable file
"""
variable = self.get_variable(name)
if not variable:
return None
if not isinstance(variable, ToolRuntimeImageVariable):
return None
message_file_id = variable.value
# get file binary
file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
if not file_binary:
return None
return file_binary[0]
def list_variables(self) -> List[ToolRuntimeVariable]:
"""
list all variables
:return: the variables
"""
if not self.variables:
return []
return self.variables.pool
def list_default_image_variables(self) -> List[ToolRuntimeVariable]:
"""
list all image variables
:return: the image variables
"""
if not self.variables:
return []
result = []
for variable in self.variables.pool:
if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
result.append(variable)
return result
def invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> List[ToolInvokeMessage]:
# update tool_paramters
if self.runtime.runtime_parameters:
tool_paramters.update(self.runtime.runtime_parameters)
# hit callback
if self.use_callback:
self.agent_callback.on_tool_start(
tool_name=self.identity.name,
tool_inputs=tool_paramters
)
try:
result = self._invoke(
user_id=user_id,
tool_paramters=tool_paramters,
)
except Exception as e:
if self.use_callback:
self.agent_callback.on_tool_error(e)
raise e
if not isinstance(result, list):
result = [result]
# hit callback
if self.use_callback:
self.agent_callback.on_tool_end(
tool_name=self.identity.name,
tool_inputs=tool_paramters,
tool_outputs=self._convert_tool_response_to_str(result)
)
return result
def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
"""
Handle tool response
"""
result = ''
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please dirct user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result += f"image has been created and sent to user already, you should tell user to check it now."
elif response.type == ToolInvokeMessage.MessageType.BLOB:
if len(response.message) > 114:
result += str(response.message[:114]) + '...'
else:
result += str(response.message)
else:
result += f"tool response: {response.message}."
return result
@abstractmethod
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
pass
def validate_credentials(self, credentails: Dict[str, Any], parameters: Dict[str, Any]) -> None:
"""
validate the credentials
:param credentails: the credentials
:param parameters: the parameters
"""
pass
def get_runtime_parameters(self) -> List[ToolParamter]:
"""
get the runtime parameters
interface for developer to dynamic change the parameters of a tool depends on the variables pool
:return: the runtime parameters
"""
return self.parameters
def is_tool_avaliable(self) -> bool:
"""
check if the tool is avaliable
:return: if the tool is avaliable
"""
return True
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
"""
create an image message
:param image: the url of the image
:return: the image message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
message=image,
save_as=save_as)
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a link message
:param link: the url of the link
:return: the link message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
message=link,
save_as=save_as)
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
"""
create a text message
:param text: the text
:return: the text message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT,
message=text,
save_as=save_as
)
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
"""
create a blob message
:param blob: the blob
:return: the blob message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB,
message=blob, meta=meta,
save_as=save_as
)

View File

@ -0,0 +1,197 @@
import logging
import time
import os
import hmac
import base64
import hashlib
from typing import Union, Tuple, Generator
from uuid import uuid4
from mimetypes import guess_extension, guess_type
from httpx import get
from flask import current_app
from models.tools import ToolFile
from models.model import MessageFile
from extensions.ext_database import db
from extensions.ext_storage import storage
logger = logging.getLogger(__name__)
class ToolFileManager:
@staticmethod
def sign_file(file_id: str, extension: str) -> str:
"""
sign file to get a temporary url
"""
base_url = current_app.config.get('FILES_URL')
file_preview_url = f'{base_url}/files/tools/{file_id}{extension}'
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
secret_key = current_app.config['SECRET_KEY'].encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@staticmethod
def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
secret_key = current_app.config['SECRET_KEY'].encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= 300 # expired after 5 minutes
@staticmethod
def create_file_by_raw(user_id: str, tenant_id: str,
conversation_id: str, file_binary: bytes,
mimetype: str
) -> ToolFile:
"""
create file
"""
extension = guess_extension(mimetype) or '.bin'
unique_name = uuid4().hex
filename = f"/tools/{tenant_id}/{unique_name}{extension}"
storage.save(filename, file_binary)
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
conversation_id=conversation_id, file_key=filename, mimetype=mimetype)
db.session.add(tool_file)
db.session.commit()
return tool_file
@staticmethod
def create_file_by_url(user_id: str, tenant_id: str,
conversation_id: str, file_url: str,
) -> ToolFile:
"""
create file
"""
# try to download image
response = get(file_url)
response.raise_for_status()
blob = response.content
mimetype = guess_type(file_url)[0] or 'octet/stream'
extension = guess_extension(mimetype) or '.bin'
unique_name = uuid4().hex
filename = f"/tools/{tenant_id}/{unique_name}{extension}"
storage.save(filename, blob)
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
conversation_id=conversation_id, file_key=filename,
mimetype=mimetype, original_url=file_url)
db.session.add(tool_file)
db.session.commit()
return tool_file
@staticmethod
def create_file_by_key(user_id: str, tenant_id: str,
conversation_id: str, file_key: str,
mimetype: str
) -> ToolFile:
"""
create file
"""
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
conversation_id=conversation_id, file_key=file_key, mimetype=mimetype)
return tool_file
@staticmethod
def get_file_binary(id: str) -> Union[Tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
tool_file: ToolFile = db.session.query(ToolFile).filter(
ToolFile.id == id,
).first()
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
@staticmethod
def get_file_binary_by_message_file_id(id: str) -> Union[Tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
message_file: MessageFile = db.session.query(MessageFile).filter(
MessageFile.id == id,
).first()
# get tool file id
tool_file_id = message_file.url.split('/')[-1]
# trim extension
tool_file_id = tool_file_id.split('.')[0]
tool_file: ToolFile = db.session.query(ToolFile).filter(
ToolFile.id == tool_file_id,
).first()
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
@staticmethod
def get_file_generator_by_message_file_id(id: str) -> Union[Tuple[Generator, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
message_file: MessageFile = db.session.query(MessageFile).filter(
MessageFile.id == id,
).first()
# get tool file id
tool_file_id = message_file.url.split('/')[-1]
# trim extension
tool_file_id = tool_file_id.split('.')[0]
tool_file: ToolFile = db.session.query(ToolFile).filter(
ToolFile.id == tool_file_id,
).first()
if not tool_file:
return None
generator = storage.load_stream(tool_file.file_key)
return generator, tool_file.mimetype
# init tool_file_parser
from core.file.tool_file_parser import tool_file_manager
tool_file_manager['manager'] = ToolFileManager

View File

@ -0,0 +1,448 @@
from typing import List, Dict, Any, Tuple, Union
from os import listdir, path
from core.tools.entities.tool_entities import ToolInvokeMessage, ApiProviderAuthType, ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.api_tool import ApiTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.entities.constant import DEFAULT_PROVIDERS
from core.tools.entities.common_entities import I18nObject
from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
from core.tools.entities.user_entities import UserToolProvider
from core.tools.utils.configration import ToolConfiguration
from core.tools.utils.encoder import serialize_base_model_dict
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.model_runtime.entities.message_entities import PromptMessage
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider
import importlib
import logging
import json
import mimetypes
logger = logging.getLogger(__name__)
_builtin_providers = {}
class ToolManager:
@staticmethod
def invoke(
provider: str,
tool_id: str,
tool_name: str,
tool_parameters: Dict[str, Any],
credentials: Dict[str, Any],
prompt_messages: List[PromptMessage],
) -> List[ToolInvokeMessage]:
"""
invoke the assistant
:param provider: the name of the provider
:param tool_id: the id of the tool
:param tool_name: the name of the tool, defined in `get_tools`
:param tool_parameters: the parameters of the tool
:param credentials: the credentials of the tool
:param prompt_messages: the prompt messages that the tool can use
:return: the messages that the tool wants to send to the user
"""
provider_entity: ToolProviderController = None
if provider == DEFAULT_PROVIDERS.API_BASED:
provider_entity = ApiBasedToolProviderController()
elif provider == DEFAULT_PROVIDERS.APP_BASED:
provider_entity = AppBasedToolProviderEntity()
if provider_entity is None:
# fetch the provider from .provider.builtin
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py')
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
# get all the classes in the module
classes = [ x for _, x in vars(mod).items()
if isinstance(x, type) and x != ToolProviderController and issubclass(x, ToolProviderController)
]
if len(classes) == 0:
raise ToolProviderNotFoundError(f'provider {provider} not found')
if len(classes) > 1:
raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
provider_entity = classes[0]()
return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
@staticmethod
def get_builtin_provider(provider: str) -> BuiltinToolProviderController:
global _builtin_providers
"""
get the builtin provider
:param provider: the name of the provider
:return: the provider
"""
if len(_builtin_providers) == 0:
# init the builtin providers
ToolManager.list_builtin_providers()
if provider not in _builtin_providers:
raise ToolProviderNotFoundError(f'builtin provider {provider} not found')
return _builtin_providers[provider]
@staticmethod
def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool:
"""
get the builtin tool
:param provider: the name of the provider
:param tool_name: the name of the tool
:return: the provider, the tool
"""
provider_controller = ToolManager.get_builtin_provider(provider)
tool = provider_controller.get_tool(tool_name)
return tool
@staticmethod
def get_tool(provider_type: str, provider_id: str, tool_name: str, tanent_id: str = None) \
-> Union[BuiltinTool, ApiTool]:
"""
get the tool
:param provider_type: the type of the provider
:param provider_name: the name of the provider
:param tool_name: the name of the tool
:return: the tool
"""
if provider_type == 'builtin':
return ToolManager.get_builtin_tool(provider_id, tool_name)
elif provider_type == 'api':
if tanent_id is None:
raise ValueError('tanent id is required for api provider')
api_provider, _ = ToolManager.get_api_provider_controller(tanent_id, provider_id)
return api_provider.get_tool(tool_name)
elif provider_type == 'app':
raise NotImplementedError('app provider not implemented')
else:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tanent_id,
agent_callback: DifyAgentCallbackHandler = None) \
-> Union[BuiltinTool, ApiTool]:
"""
get the tool runtime
:param provider_type: the type of the provider
:param provider_name: the name of the provider
:param tool_name: the name of the tool
:return: the tool
"""
if provider_type == 'builtin':
builtin_tool = ToolManager.get_builtin_tool(provider_name, tool_name)
# check if the builtin tool need credentials
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(meta={
'tenant_id': tanent_id,
'credentials': {},
}, agent_callback=agent_callback)
# get credentials
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tanent_id,
BuiltinToolProvider.provider == provider_name,
).first()
if builtin_provider is None:
raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
# decrypt the credentials
credentials = builtin_provider.credentials
controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfiguration(tenant_id=tanent_id, provider_controller=controller)
decrypted_credentails = tool_configuration.decrypt_tool_credentials(credentials)
return builtin_tool.fork_tool_runtime(meta={
'tenant_id': tanent_id,
'credentials': decrypted_credentails,
'runtime_parameters': {}
}, agent_callback=agent_callback)
elif provider_type == 'api':
if tanent_id is None:
raise ValueError('tanent id is required for api provider')
api_provider, credentials = ToolManager.get_api_provider_controller(tanent_id, provider_name)
# decrypt the credentials
tool_configuration = ToolConfiguration(tenant_id=tanent_id, provider_controller=api_provider)
decrypted_credentails = tool_configuration.decrypt_tool_credentials(credentials)
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
'tenant_id': tanent_id,
'credentials': decrypted_credentails,
})
elif provider_type == 'app':
raise NotImplementedError('app provider not implemented')
else:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod
def get_builtin_provider_icon(provider: str) -> Tuple[str, str]:
"""
get the absolute path of the icon of the builtin provider
:param provider: the name of the provider
:return: the absolute path of the icon, the mime type of the icon
"""
# get provider
provider_controller = ToolManager.get_builtin_provider(provider)
absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', provider_controller.identity.icon)
# check if the icon exists
if not path.exists(absolute_path):
raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')
# get the mime type
mime_type, _ = mimetypes.guess_type(absolute_path)
mime_type = mime_type or 'application/octet-stream'
return absolute_path, mime_type
@staticmethod
def list_builtin_providers() -> List[BuiltinToolProviderController]:
global _builtin_providers
# use cache first
if len(_builtin_providers) > 0:
return list(_builtin_providers.values())
builtin_providers = []
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
if provider.startswith('__'):
continue
if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)):
if provider.startswith('__'):
continue
py_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py')
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
# load all classes
classes = [
obj for name, obj in vars(mod).items()
if isinstance(obj, type) and obj != BuiltinToolProviderController and issubclass(obj, BuiltinToolProviderController)
]
if len(classes) == 0:
raise ToolProviderNotFoundError(f'provider {provider} not found')
if len(classes) > 1:
raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
# init provider
provider_class = classes[0]
builtin_providers.append(provider_class())
# cache the builtin providers
for provider in builtin_providers:
_builtin_providers[provider.identity.name] = provider
return builtin_providers
@staticmethod
def user_list_providers(
user_id: str,
tenant_id: str,
) -> List[UserToolProvider]:
result_providers: Dict[str, UserToolProvider] = {}
# get builtin providers
builtin_providers = ToolManager.list_builtin_providers()
# append builtin providers
for provider in builtin_providers:
result_providers[provider.identity.name] = UserToolProvider(
id=provider.identity.name,
author=provider.identity.author,
name=provider.identity.name,
description=I18nObject(
en_US=provider.identity.description.en_US,
zh_Hans=provider.identity.description.zh_Hans,
),
icon=provider.identity.icon,
label=I18nObject(
en_US=provider.identity.label.en_US,
zh_Hans=provider.identity.label.zh_Hans,
),
type=UserToolProvider.ProviderType.BUILTIN,
team_credentials={},
is_team_authorization=False,
)
# get credentials schema
schema = provider.get_credentails_schema()
for name, value in schema.items():
result_providers[provider.identity.name].team_credentials[name] = \
ToolProviderCredentials.CredentialsType.defaut(value.type)
# check if the provider need credentials
if not provider.need_credentials:
result_providers[provider.identity.name].is_team_authorization = True
result_providers[provider.identity.name].allow_delete = False
# get db builtin providers
db_builtin_providers: List[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
for db_builtin_provider in db_builtin_providers:
# add provider into providers
credentails = db_builtin_provider.credentials
provider_name = db_builtin_provider.provider
result_providers[provider_name].is_team_authorization = True
# package builtin tool provider controller
controller = ToolManager.get_builtin_provider(provider_name)
# init tool configuration
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
# decrypt the credentials and mask the credentials
decrypted_credentails = tool_configuration.decrypt_tool_credentials(credentials=credentails)
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentails)
result_providers[provider_name].team_credentials = masked_credentials
# get db api providers
db_api_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider). \
filter(ApiToolProvider.tenant_id == tenant_id).all()
for db_api_provider in db_api_providers:
username = 'Anonymous'
try:
username = db_api_provider.user.name
except Exception as e:
logger.error(f'failed to get user name for api provider {db_api_provider.id}: {str(e)}')
# add provider into providers
credentails = db_api_provider.credentials
provider_name = db_api_provider.name
result_providers[provider_name] = UserToolProvider(
id=db_api_provider.id,
author=username,
name=db_api_provider.name,
description=I18nObject(
en_US=db_api_provider.description,
zh_Hans=db_api_provider.description,
),
icon=db_api_provider.icon,
label=I18nObject(
en_US=db_api_provider.name,
zh_Hans=db_api_provider.name,
),
type=UserToolProvider.ProviderType.API,
team_credentials={},
is_team_authorization=True,
)
# package tool provider controller
controller = ApiBasedToolProviderController.from_db(
db_provider=db_api_provider,
auth_type=ApiProviderAuthType.API_KEY if db_api_provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
)
# init tool configuration
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
# decrypt the credentials and mask the credentials
decrypted_credentails = tool_configuration.decrypt_tool_credentials(credentials=credentails)
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentails)
result_providers[provider_name].team_credentials = masked_credentials
return BuiltinToolProviderSort.sort(list(result_providers.values()))
@staticmethod
def get_api_provider_controller(tanent_id: str, provider_id: str) -> Tuple[ApiBasedToolProviderController, Dict[str, Any]]:
"""
get the api provider
:param provider_name: the name of the provider
:return: the provider controller, the credentials
"""
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tanent_id,
).first()
if provider is None:
raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
controller = ApiBasedToolProviderController.from_db(
provider, ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
)
controller.load_bundled_tools(provider.tools)
return controller, provider.credentials
@staticmethod
def user_get_api_provider(provider: str, tenant_id: str) -> dict:
"""
get api provider
"""
"""
get tool provider
"""
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
).first()
if provider is None:
raise ValueError(f'yout have not added provider {provider}')
try:
credentials = json.loads(provider.credentials_str) or {}
except:
credentials = {}
# package tool provider controller
controller = ApiBasedToolProviderController.from_db(
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
)
# init tool configuration
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentails = tool_configuration.decrypt_tool_credentials(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentails)
try:
icon = json.loads(provider.icon)
except:
icon = {
"background": "#252525",
"content": "\ud83d\ude01"
}
return json.loads(serialize_base_model_dict({
'schema_type': provider.schema_type,
'schema': provider.schema,
'tools': provider.tools,
'icon': icon,
'description': provider.description,
'credentials': masked_credentials,
'privacy_policy': provider.privacy_policy
}))

View File

@ -0,0 +1,77 @@
from typing import Dict, Any
from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController
from core.helper import encrypter
class ToolConfiguration(BaseModel):
tenant_id: str
provider_controller: ToolProviderController
def _deep_copy(self, credentails: Dict[str, str]) -> Dict[str, str]:
"""
deep copy credentials
"""
return {key: value for key, value in credentails.items()}
def encrypt_tool_credentials(self, credentails: Dict[str, str]) -> Dict[str, str]:
"""
encrypt tool credentials with tanent id
return a deep copy of credentials with encrypted values
"""
credentials = self._deep_copy(credentails)
# get fields need to be decrypted
fields = self.provider_controller.get_credentails_schema()
for field_name, field in fields.items():
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
if field_name in credentials:
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
credentials[field_name] = encrypted
return credentials
def mask_tool_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
credentials = self._deep_copy(credentials)
# get fields need to be decrypted
fields = self.provider_controller.get_credentails_schema()
for field_name, field in fields.items():
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
if field_name in credentials:
if len(credentials[field_name]) > 6:
credentials[field_name] = \
credentials[field_name][:2] + \
'*' * (len(credentials[field_name]) - 4) +\
credentials[field_name][-2:]
else:
credentials[field_name] = '*' * len(credentials[field_name])
return credentials
def decrypt_tool_credentials(self, credentials: Dict[str, str]) -> Dict[str, str]:
"""
decrypt tool credentials with tanent id
return a deep copy of credentials with decrypted values
"""
credentials = self._deep_copy(credentials)
# get fields need to be decrypted
fields = self.provider_controller.get_credentails_schema()
for field_name, field in fields.items():
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
if field_name in credentials:
try:
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
except:
pass
return credentials

Some files were not shown because too many files have changed in this diff Show More