refactor: tool
@ -1,38 +0,0 @@
|
||||
- google
|
||||
- bing
|
||||
- perplexity
|
||||
- duckduckgo
|
||||
- searchapi
|
||||
- serper
|
||||
- searxng
|
||||
- dalle
|
||||
- azuredalle
|
||||
- stability
|
||||
- wikipedia
|
||||
- nominatim
|
||||
- yahoo
|
||||
- alphavantage
|
||||
- arxiv
|
||||
- pubmed
|
||||
- stablediffusion
|
||||
- webscraper
|
||||
- jina
|
||||
- aippt
|
||||
- youtube
|
||||
- code
|
||||
- wolframalpha
|
||||
- maths
|
||||
- github
|
||||
- chart
|
||||
- time
|
||||
- vectorizer
|
||||
- gaode
|
||||
- wecom
|
||||
- qrcode
|
||||
- dingtalk
|
||||
- feishu
|
||||
- feishu_base
|
||||
- feishu_document
|
||||
- feishu_message
|
||||
- slack
|
||||
- tianditu
|
||||
@ -1,171 +0,0 @@
|
||||
from pydantic import Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
|
||||
class ApiToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
tools: list[ApiTool] = Field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
|
||||
credentials_schema = {
|
||||
"auth_type": ProviderConfig(
|
||||
name="auth_type",
|
||||
required=True,
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
options=[
|
||||
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ProviderConfig.Option(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
|
||||
],
|
||||
default="none",
|
||||
help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
|
||||
)
|
||||
}
|
||||
if auth_type == ApiProviderAuthType.API_KEY:
|
||||
credentials_schema = {
|
||||
**credentials_schema,
|
||||
"api_key_header": ProviderConfig(
|
||||
name="api_key_header",
|
||||
required=False,
|
||||
default="api_key",
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
|
||||
),
|
||||
"api_key_value": ProviderConfig(
|
||||
name="api_key_value",
|
||||
required=True,
|
||||
type=ProviderConfig.Type.SECRET_INPUT,
|
||||
help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
|
||||
),
|
||||
"api_key_header_prefix": ProviderConfig(
|
||||
name="api_key_header_prefix",
|
||||
required=False,
|
||||
default="basic",
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
|
||||
options=[
|
||||
ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
|
||||
ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
|
||||
ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
|
||||
],
|
||||
),
|
||||
}
|
||||
elif auth_type == ApiProviderAuthType.NONE:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"invalid auth type {auth_type}")
|
||||
|
||||
user_name = db_provider.user.name if db_provider.user_id else ""
|
||||
|
||||
return ApiToolProviderController(
|
||||
**{
|
||||
"identity": {
|
||||
"author": user_name,
|
||||
"name": db_provider.name,
|
||||
"label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
|
||||
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
|
||||
"icon": db_provider.icon,
|
||||
},
|
||||
"credentials_schema": credentials_schema,
|
||||
"provider_id": db_provider.id or "",
|
||||
"tenant_id": db_provider.tenant_id or "",
|
||||
},
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.API
|
||||
|
||||
def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
|
||||
"""
|
||||
parse tool bundle to tool
|
||||
|
||||
:param tool_bundle: the tool bundle
|
||||
:return: the tool
|
||||
"""
|
||||
return ApiTool(
|
||||
**{
|
||||
"api_bundle": tool_bundle,
|
||||
"identity": {
|
||||
"author": tool_bundle.author,
|
||||
"name": tool_bundle.operation_id,
|
||||
"label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
|
||||
"icon": self.identity.icon,
|
||||
"provider": self.provider_id,
|
||||
},
|
||||
"description": {
|
||||
"human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
|
||||
"llm": tool_bundle.summary or "",
|
||||
},
|
||||
"parameters": tool_bundle.parameters or [],
|
||||
}
|
||||
)
|
||||
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
|
||||
"""
|
||||
load bundled tools
|
||||
|
||||
:param tools: the bundled tools
|
||||
:return: the tools
|
||||
"""
|
||||
self.tools = [self._parse_tool_bundle(tool) for tool in tools]
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tools(self, tenant_id: str) -> list[ApiTool]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:return: the tools
|
||||
"""
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
tools: list[ApiTool] = []
|
||||
|
||||
# get tenant api providers
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name)
|
||||
.all()
|
||||
)
|
||||
|
||||
if db_providers and len(db_providers) != 0:
|
||||
for db_provider in db_providers:
|
||||
for tool in db_provider.tools:
|
||||
assistant_tool = self._parse_tool_bundle(tool)
|
||||
assistant_tool.is_team_authorization = True
|
||||
tools.append(assistant_tool)
|
||||
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> ApiTool:
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:return: the tool
|
||||
"""
|
||||
if self.tools is None:
|
||||
self.get_tools(self.tenant_id)
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
raise ValueError(f"tool {tool_name} not found")
|
||||
@ -1,103 +0,0 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
from models.tools import PublishedAppTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppToolProviderEntity(ToolProviderController):
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.APP
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def get_tools(self, user_id: str) -> list[Tool]:
|
||||
db_tools: list[PublishedAppTool] = (
|
||||
db.session.query(PublishedAppTool)
|
||||
.filter(
|
||||
PublishedAppTool.user_id == user_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not db_tools or len(db_tools) == 0:
|
||||
return []
|
||||
|
||||
tools: list[Tool] = []
|
||||
|
||||
for db_tool in db_tools:
|
||||
tool = {
|
||||
"identity": {
|
||||
"author": db_tool.author,
|
||||
"name": db_tool.tool_name,
|
||||
"label": {"en_US": db_tool.tool_name, "zh_Hans": db_tool.tool_name},
|
||||
"icon": "",
|
||||
},
|
||||
"description": {
|
||||
"human": {"en_US": db_tool.description_i18n.en_US, "zh_Hans": db_tool.description_i18n.zh_Hans},
|
||||
"llm": db_tool.llm_description,
|
||||
},
|
||||
"parameters": [],
|
||||
}
|
||||
# get app from db
|
||||
app: App = db_tool.app
|
||||
|
||||
if not app:
|
||||
logger.error(f"app {db_tool.app_id} not found")
|
||||
continue
|
||||
|
||||
app_model_config: AppModelConfig = app.app_model_config
|
||||
user_input_form_list = app_model_config.user_input_form_list
|
||||
for input_form in user_input_form_list:
|
||||
# get type
|
||||
form_type = input_form.keys()[0]
|
||||
default = input_form[form_type]["default"]
|
||||
required = input_form[form_type]["required"]
|
||||
label = input_form[form_type]["label"]
|
||||
variable_name = input_form[form_type]["variable_name"]
|
||||
options = input_form[form_type].get("options", [])
|
||||
if form_type in {"paragraph", "text-input"}:
|
||||
tool["parameters"].append(
|
||||
ToolParameter(
|
||||
name=variable_name,
|
||||
label=I18nObject(en_US=label, zh_Hans=label),
|
||||
human_description=I18nObject(en_US=label, zh_Hans=label),
|
||||
llm_description=label,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=required,
|
||||
default=default,
|
||||
)
|
||||
)
|
||||
elif form_type == "select":
|
||||
tool["parameters"].append(
|
||||
ToolParameter(
|
||||
name=variable_name,
|
||||
label=I18nObject(en_US=label, zh_Hans=label),
|
||||
human_description=I18nObject(en_US=label, zh_Hans=label),
|
||||
llm_description=label,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
required=required,
|
||||
default=default,
|
||||
options=[
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
tools.append(Tool(**tool))
|
||||
return tools
|
||||
@ -1,20 +0,0 @@
|
||||
import os.path
|
||||
|
||||
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
|
||||
from core.tools.entities.api_entities import UserToolProvider
|
||||
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
_position = {}
|
||||
|
||||
@classmethod
|
||||
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
||||
if not cls._position:
|
||||
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
def name_func(provider: UserToolProvider) -> str:
|
||||
return provider.name
|
||||
|
||||
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
|
||||
|
||||
return sorted_providers
|
||||
|
Before Width: | Height: | Size: 1.9 KiB |
@ -1,11 +0,0 @@
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AIPPTProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
AIPPTGenerateTool._get_api_token(credentials, user_id="__dify_system__")
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,45 +0,0 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: aippt
|
||||
label:
|
||||
en_US: AIPPT
|
||||
zh_Hans: AIPPT
|
||||
description:
|
||||
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
|
||||
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
|
||||
icon: icon.png
|
||||
tags:
|
||||
- productivity
|
||||
- design
|
||||
credentials_for_provider:
|
||||
aippt_access_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: AIPPT API key
|
||||
zh_Hans: AIPPT API key
|
||||
pt_BR: AIPPT API key
|
||||
help:
|
||||
en_US: Please input your AIPPT API key
|
||||
zh_Hans: 请输入你的 AIPPT API key
|
||||
pt_BR: Please input your AIPPT API key
|
||||
placeholder:
|
||||
en_US: Please input your AIPPT API key
|
||||
zh_Hans: 请输入你的 AIPPT API key
|
||||
pt_BR: Please input your AIPPT API key
|
||||
url: https://www.aippt.cn
|
||||
aippt_secret_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: AIPPT Secret key
|
||||
zh_Hans: AIPPT Secret key
|
||||
pt_BR: AIPPT Secret key
|
||||
help:
|
||||
en_US: Please input your AIPPT Secret key
|
||||
zh_Hans: 请输入你的 AIPPT Secret key
|
||||
pt_BR: Please input your AIPPT Secret key
|
||||
placeholder:
|
||||
en_US: Please input your AIPPT Secret key
|
||||
zh_Hans: 请输入你的 AIPPT Secret key
|
||||
pt_BR: Please input your AIPPT Secret key
|
||||
@ -1,498 +0,0 @@
|
||||
from base64 import b64encode
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
from json import loads as json_loads
|
||||
from threading import Lock
|
||||
from time import sleep, time
|
||||
from typing import Any, Optional
|
||||
|
||||
from httpx import get, post
|
||||
from requests import get as requests_get
|
||||
from yarl import URL
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AIPPTGenerateTool(BuiltinTool):
|
||||
"""
|
||||
A tool for generating a ppt
|
||||
"""
|
||||
|
||||
_api_base_url = URL("https://co.aippt.cn/api")
|
||||
_api_token_cache = {}
|
||||
_api_token_cache_lock: Optional[Lock] = None
|
||||
_style_cache = {}
|
||||
_style_cache_lock: Optional[Lock] = None
|
||||
|
||||
_task = {}
|
||||
_task_type_map = {
|
||||
"auto": 1,
|
||||
"markdown": 7,
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._api_token_cache_lock = Lock()
|
||||
self._style_cache_lock = Lock()
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invokes the AIPPT generate tool with the given user ID and tool parameters.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user invoking the tool.
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
|
||||
which can be a single message or a list of messages.
|
||||
"""
|
||||
title = tool_parameters.get("title", "")
|
||||
if not title:
|
||||
return self.create_text_message("Please provide a title for the ppt")
|
||||
|
||||
model = tool_parameters.get("model", "aippt")
|
||||
if not model:
|
||||
return self.create_text_message("Please provide a model for the ppt")
|
||||
|
||||
outline = tool_parameters.get("outline", "")
|
||||
|
||||
# create task
|
||||
task_id = self._create_task(
|
||||
type=self._task_type_map["auto" if not outline else "markdown"],
|
||||
title=title,
|
||||
content=outline,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# get suit
|
||||
color = tool_parameters.get("color")
|
||||
style = tool_parameters.get("style")
|
||||
|
||||
if color == "__default__":
|
||||
color_id = ""
|
||||
else:
|
||||
color_id = int(color.split("-")[1])
|
||||
|
||||
if style == "__default__":
|
||||
style_id = ""
|
||||
else:
|
||||
style_id = int(style.split("-")[1])
|
||||
|
||||
suit_id = self._get_suit(style_id=style_id, colour_id=color_id)
|
||||
|
||||
# generate outline
|
||||
if not outline:
|
||||
self._generate_outline(task_id=task_id, model=model, user_id=user_id)
|
||||
|
||||
# generate content
|
||||
self._generate_content(task_id=task_id, model=model, user_id=user_id)
|
||||
|
||||
# generate ppt
|
||||
_, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id)
|
||||
|
||||
return self.create_text_message(
|
||||
"""the ppt has been created successfully,"""
|
||||
f"""the ppt url is {ppt_url}"""
|
||||
"""please give the ppt url to user and direct user to download it."""
|
||||
)
|
||||
|
||||
def _create_task(self, type: int, title: str, content: str, user_id: str) -> str:
|
||||
"""
|
||||
Create a task
|
||||
|
||||
:param type: the task type
|
||||
:param title: the task title
|
||||
:param content: the task content
|
||||
|
||||
:return: the task ID
|
||||
"""
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
response = post(
|
||||
str(self._api_base_url / "ai" / "chat" / "v2" / "task"),
|
||||
headers=headers,
|
||||
files={"type": ("", str(type)), "title": ("", title), "content": ("", content)},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to create task: {response.get("msg")}')
|
||||
|
||||
return response.get("data", {}).get("id")
|
||||
|
||||
def _generate_outline(self, task_id: str, model: str, user_id: str) -> str:
|
||||
api_url = (
|
||||
self._api_base_url / "ai" / "chat" / "outline"
|
||||
if model == "aippt"
|
||||
else self._api_base_url / "ai" / "chat" / "wx" / "outline"
|
||||
)
|
||||
api_url %= {"task_id": task_id}
|
||||
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
outline = ""
|
||||
for chunk in response.iter_lines(delimiter=b"\n\n"):
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
event = ""
|
||||
lines = chunk.decode("utf-8").split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("event:"):
|
||||
event = line[6:]
|
||||
elif line.startswith("data:"):
|
||||
data = line[5:]
|
||||
if event == "message":
|
||||
try:
|
||||
data = json_loads(data)
|
||||
outline += data.get("content", "")
|
||||
except Exception as e:
|
||||
pass
|
||||
elif event == "close":
|
||||
break
|
||||
elif event in {"error", "filter"}:
|
||||
raise Exception(f"Failed to generate outline: {data}")
|
||||
|
||||
return outline
|
||||
|
||||
def _generate_content(self, task_id: str, model: str, user_id: str) -> str:
|
||||
api_url = (
|
||||
self._api_base_url / "ai" / "chat" / "content"
|
||||
if model == "aippt"
|
||||
else self._api_base_url / "ai" / "chat" / "wx" / "content"
|
||||
)
|
||||
api_url %= {"task_id": task_id}
|
||||
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
if model == "aippt":
|
||||
content = ""
|
||||
for chunk in response.iter_lines(delimiter=b"\n\n"):
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
event = ""
|
||||
lines = chunk.decode("utf-8").split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("event:"):
|
||||
event = line[6:]
|
||||
elif line.startswith("data:"):
|
||||
data = line[5:]
|
||||
if event == "message":
|
||||
try:
|
||||
data = json_loads(data)
|
||||
content += data.get("content", "")
|
||||
except Exception as e:
|
||||
pass
|
||||
elif event == "close":
|
||||
break
|
||||
elif event in {"error", "filter"}:
|
||||
raise Exception(f"Failed to generate content: {data}")
|
||||
|
||||
return content
|
||||
elif model == "wenxin":
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to generate content: {response.get("msg")}')
|
||||
|
||||
return response.get("data", "")
|
||||
|
||||
return ""
|
||||
|
||||
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
|
||||
"""
|
||||
Generate a ppt
|
||||
|
||||
:param task_id: the task ID
|
||||
:param suit_id: the suit ID
|
||||
:return: the cover url of the ppt and the ppt url
|
||||
"""
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = post(
|
||||
str(self._api_base_url / "design" / "v2" / "save"),
|
||||
headers=headers,
|
||||
data={"task_id": task_id, "template_id": suit_id},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||
|
||||
id = response.get("data", {}).get("id")
|
||||
cover_url = response.get("data", {}).get("cover_url")
|
||||
|
||||
response = post(
|
||||
str(self._api_base_url / "download" / "export" / "file"),
|
||||
headers=headers,
|
||||
data={"id": id, "format": "ppt", "files_to_zip": False, "edit": True},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||
|
||||
export_code = response.get("data")
|
||||
if not export_code:
|
||||
raise Exception("Failed to generate ppt, the export code is empty")
|
||||
|
||||
current_iteration = 0
|
||||
while current_iteration < 50:
|
||||
# get ppt url
|
||||
response = post(
|
||||
str(self._api_base_url / "download" / "export" / "file" / "result"),
|
||||
headers=headers,
|
||||
data={"task_key": export_code},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||
|
||||
if response.get("msg") == "导出中":
|
||||
current_iteration += 1
|
||||
sleep(2)
|
||||
continue
|
||||
|
||||
ppt_url = response.get("data", [])
|
||||
if len(ppt_url) == 0:
|
||||
raise Exception("Failed to generate ppt, the ppt url is empty")
|
||||
|
||||
return cover_url, ppt_url[0]
|
||||
|
||||
raise Exception("Failed to generate ppt, the export is timeout")
|
||||
|
||||
@classmethod
|
||||
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
|
||||
"""
|
||||
Get API token
|
||||
|
||||
:param credentials: the credentials
|
||||
:return: the API token
|
||||
"""
|
||||
access_key = credentials["aippt_access_key"]
|
||||
secret_key = credentials["aippt_secret_key"]
|
||||
|
||||
cache_key = f"{access_key}#@#{user_id}"
|
||||
|
||||
with cls._api_token_cache_lock:
|
||||
# clear expired tokens
|
||||
now = time()
|
||||
for key in list(cls._api_token_cache.keys()):
|
||||
if cls._api_token_cache[key]["expire"] < now:
|
||||
del cls._api_token_cache[key]
|
||||
|
||||
if cache_key in cls._api_token_cache:
|
||||
return cls._api_token_cache[cache_key]["token"]
|
||||
|
||||
# get token
|
||||
headers = {
|
||||
"x-api-key": access_key,
|
||||
"x-timestamp": str(int(now)),
|
||||
"x-signature": cls._calculate_sign(access_key, secret_key, int(now)),
|
||||
}
|
||||
|
||||
param = {"uid": user_id, "channel": ""}
|
||||
|
||||
response = get(str(cls._api_base_url / "grant" / "token"), params=param, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||
|
||||
token = response.get("data", {}).get("token")
|
||||
expire = response.get("data", {}).get("time_expire")
|
||||
|
||||
with cls._api_token_cache_lock:
|
||||
cls._api_token_cache[cache_key] = {"token": token, "expire": now + expire}
|
||||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
|
||||
return b64encode(
|
||||
hmac_new(
|
||||
key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1
|
||||
).digest()
|
||||
).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
"""
|
||||
Get styles
|
||||
"""
|
||||
|
||||
# check cache
|
||||
with cls._style_cache_lock:
|
||||
# clear expired styles
|
||||
now = time()
|
||||
for key in list(cls._style_cache.keys()):
|
||||
if cls._style_cache[key]["expire"] < now:
|
||||
del cls._style_cache[key]
|
||||
|
||||
key = f'{credentials["aippt_access_key"]}#@#{user_id}'
|
||||
if key in cls._style_cache:
|
||||
return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"]
|
||||
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": credentials["aippt_access_key"],
|
||||
"x-token": cls._get_api_token(credentials=credentials, user_id=user_id),
|
||||
}
|
||||
response = get(str(cls._api_base_url / "template_component" / "suit" / "select"), headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||
|
||||
colors = [
|
||||
{
|
||||
"id": f'id-{item.get("id")}',
|
||||
"name": item.get("name"),
|
||||
"en_name": item.get("en_name", item.get("name")),
|
||||
}
|
||||
for item in response.get("data", {}).get("colour") or []
|
||||
]
|
||||
styles = [
|
||||
{
|
||||
"id": f'id-{item.get("id")}',
|
||||
"name": item.get("title"),
|
||||
}
|
||||
for item in response.get("data", {}).get("suit_style") or []
|
||||
]
|
||||
|
||||
with cls._style_cache_lock:
|
||||
cls._style_cache[key] = {"colors": colors, "styles": styles, "expire": now + 60 * 60}
|
||||
|
||||
return colors, styles
|
||||
|
||||
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
"""
|
||||
Get styles
|
||||
|
||||
:param credentials: the credentials
|
||||
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
|
||||
"""
|
||||
if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"):
|
||||
raise Exception("Please provide aippt credentials")
|
||||
|
||||
return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
|
||||
|
||||
def _get_suit(self, style_id: int, colour_id: int) -> int:
|
||||
"""
|
||||
Get suit
|
||||
"""
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"),
|
||||
}
|
||||
response = get(
|
||||
str(self._api_base_url / "template_component" / "suit" / "search"),
|
||||
headers=headers,
|
||||
params={"style_id": style_id, "colour_id": colour_id, "page": 1, "page_size": 1},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||
|
||||
if len(response.get("data", {}).get("list") or []) > 0:
|
||||
return response.get("data", {}).get("list")[0].get("id")
|
||||
|
||||
raise Exception("Failed to get suit, the suit does not exist, please check the style and color")
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
Get runtime parameters
|
||||
|
||||
Override this method to add runtime parameters to the tool.
|
||||
"""
|
||||
try:
|
||||
colors, styles = self.get_styles(user_id="__dify_system__")
|
||||
except Exception as e:
|
||||
colors, styles = (
|
||||
[{"id": "-1", "name": "__default__", "en_name": "__default__"}],
|
||||
[{"id": "-1", "name": "__default__", "en_name": "__default__"}],
|
||||
)
|
||||
|
||||
return [
|
||||
ToolParameter(
|
||||
name="color",
|
||||
label=I18nObject(zh_Hans="颜色", en_US="Color"),
|
||||
human_description=I18nObject(zh_Hans="颜色", en_US="Color"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=False,
|
||||
default=colors[0]["id"],
|
||||
options=[
|
||||
ToolParameterOption(
|
||||
value=color["id"], label=I18nObject(zh_Hans=color["name"], en_US=color["en_name"])
|
||||
)
|
||||
for color in colors
|
||||
],
|
||||
),
|
||||
ToolParameter(
|
||||
name="style",
|
||||
label=I18nObject(zh_Hans="风格", en_US="Style"),
|
||||
human_description=I18nObject(zh_Hans="风格", en_US="Style"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=False,
|
||||
default=styles[0]["id"],
|
||||
options=[
|
||||
ToolParameterOption(value=style["id"], label=I18nObject(zh_Hans=style["name"], en_US=style["name"]))
|
||||
for style in styles
|
||||
],
|
||||
),
|
||||
]
|
||||
@ -1,54 +0,0 @@
|
||||
identity:
|
||||
name: aippt
|
||||
author: Dify
|
||||
label:
|
||||
en_US: AIPPT
|
||||
zh_Hans: AIPPT
|
||||
description:
|
||||
human:
|
||||
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
|
||||
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
|
||||
llm: A tool used to generate PPT with AI, input your content topic, and let AI generate PPT for you.
|
||||
parameters:
|
||||
- name: title
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Title
|
||||
zh_Hans: 标题
|
||||
human_description:
|
||||
en_US: The title of the PPT.
|
||||
zh_Hans: PPT的标题。
|
||||
llm_description: The title of the PPT, which will be used to generate the PPT outline.
|
||||
form: llm
|
||||
- name: outline
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Outline
|
||||
zh_Hans: 大纲
|
||||
human_description:
|
||||
en_US: The outline of the PPT
|
||||
zh_Hans: PPT的大纲
|
||||
llm_description: The outline of the PPT, which will be used to generate the PPT content. provide it if you have.
|
||||
form: llm
|
||||
- name: llm
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: LLM model
|
||||
zh_Hans: 生成大纲的LLM
|
||||
options:
|
||||
- value: aippt
|
||||
label:
|
||||
en_US: AIPPT default model
|
||||
zh_Hans: AIPPT默认模型
|
||||
- value: wenxin
|
||||
label:
|
||||
en_US: Wenxin ErnieBot
|
||||
zh_Hans: 文心一言
|
||||
default: aippt
|
||||
human_description:
|
||||
en_US: The LLM model used for generating PPT outline.
|
||||
zh_Hans: 用于生成PPT大纲的LLM模型。
|
||||
form: form
|
||||
@ -1,7 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="56px" height="56px" viewBox="0 0 56 56" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>形状结合</title>
|
||||
<g id="设计规范" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<path d="M56,0 L56,56 L0,56 L0,0 L56,0 Z M31.6063018,12 L24.3936982,12 L24.1061064,12.7425499 L12.6071308,42.4324141 L12,44 L19.7849972,44 L20.0648488,43.2391815 L22.5196173,36.5567427 L33.4780427,36.5567427 L35.9351512,43.2391815 L36.2150028,44 L44,44 L43.3928692,42.4324141 L31.8938936,12.7425499 L31.6063018,12 Z M28.0163803,21.5755126 L31.1613993,30.2523823 L24.8432808,30.2523823 L28.0163803,21.5755126 Z" id="形状结合" fill="#2F4F4F"></path>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 780 B |
@ -1,22 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.alphavantage.tools.query_stock import QueryStockTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AlphaVantageProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
QueryStockTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"code": "AAPL", # Apple Inc.
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,31 +0,0 @@
|
||||
identity:
|
||||
author: zhuhao
|
||||
name: alphavantage
|
||||
label:
|
||||
en_US: AlphaVantage
|
||||
zh_Hans: AlphaVantage
|
||||
pt_BR: AlphaVantage
|
||||
description:
|
||||
en_US: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis.
|
||||
zh_Hans: AlphaVantage是一个在线平台,它提供金融市场数据和API,便于个人投资者和开发者获取股票报价、技术指标和股票分析。
|
||||
pt_BR: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- finance
|
||||
credentials_for_provider:
|
||||
api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: AlphaVantage API key
|
||||
zh_Hans: AlphaVantage API key
|
||||
pt_BR: AlphaVantage API key
|
||||
placeholder:
|
||||
en_US: Please input your AlphaVantage API key
|
||||
zh_Hans: 请输入你的 AlphaVantage API key
|
||||
pt_BR: Please input your AlphaVantage API key
|
||||
help:
|
||||
en_US: Get your AlphaVantage API key from AlphaVantage
|
||||
zh_Hans: 从 AlphaVantage 获取您的 AlphaVantage API key
|
||||
pt_BR: Get your AlphaVantage API key from AlphaVantage
|
||||
url: https://www.alphavantage.co/support/#api-key
|
||||
@ -1,48 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query"
|
||||
|
||||
|
||||
class QueryStockTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
stock_code = tool_parameters.get("code", "")
|
||||
if not stock_code:
|
||||
return self.create_text_message("Please tell me your stock code")
|
||||
|
||||
if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
|
||||
return self.create_text_message("Alpha Vantage API key is required.")
|
||||
|
||||
params = {
|
||||
"function": "TIME_SERIES_DAILY",
|
||||
"symbol": stock_code,
|
||||
"outputsize": "compact",
|
||||
"datatype": "json",
|
||||
"apikey": self.runtime.credentials["api_key"],
|
||||
}
|
||||
response = requests.get(url=ALPHAVANTAGE_API_URL, params=params)
|
||||
response.raise_for_status()
|
||||
result = self._handle_response(response.json())
|
||||
return self.create_json_message(result)
|
||||
|
||||
def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]:
|
||||
result = response.get("Time Series (Daily)", {})
|
||||
if not result:
|
||||
return {}
|
||||
stock_result = {}
|
||||
for k, v in result.items():
|
||||
stock_result[k] = {}
|
||||
stock_result[k]["open"] = v.get("1. open")
|
||||
stock_result[k]["high"] = v.get("2. high")
|
||||
stock_result[k]["low"] = v.get("3. low")
|
||||
stock_result[k]["close"] = v.get("4. close")
|
||||
stock_result[k]["volume"] = v.get("5. volume")
|
||||
return stock_result
|
||||
@ -1,27 +0,0 @@
|
||||
identity:
|
||||
name: query_stock
|
||||
author: zhuhao
|
||||
label:
|
||||
en_US: query_stock
|
||||
zh_Hans: query_stock
|
||||
pt_BR: query_stock
|
||||
description:
|
||||
human:
|
||||
en_US: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol.
|
||||
zh_Hans: 获取指定股票代码的每日开盘价、每日最高价、每日最低价、每日收盘价和每日交易量等信息。
|
||||
pt_BR: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol
|
||||
llm: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol
|
||||
parameters:
|
||||
- name: code
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: stock code
|
||||
zh_Hans: 股票代码
|
||||
pt_BR: stock code
|
||||
human_description:
|
||||
en_US: stock code
|
||||
zh_Hans: 股票代码
|
||||
pt_BR: stock code
|
||||
llm_description: stock code for query from alphavantage
|
||||
form: llm
|
||||
@ -1 +0,0 @@
|
||||
<svg id="logomark" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 17.732 24.269"><g id="tiny"><path d="M573.549,280.916l2.266,2.738,6.674-7.84c.353-.47.52-.717.353-1.117a1.218,1.218,0,0,0-1.061-.748h0a.953.953,0,0,0-.712.262Z" transform="translate(-566.984 -271.548)" fill="#bdb9b4"/><path d="M579.525,282.225l-10.606-10.174a1.413,1.413,0,0,0-.834-.5,1.09,1.09,0,0,0-1.027.66c-.167.4-.047.681.319,1.206l8.44,10.242h0l-6.282,7.716a1.336,1.336,0,0,0-.323,1.3,1.114,1.114,0,0,0,1.04.69A.992.992,0,0,0,571,293l8.519-7.92A1.924,1.924,0,0,0,579.525,282.225Z" transform="translate(-566.984 -271.548)" fill="#b31b1b"/><path d="M584.32,293.912l-8.525-10.275,0,0L573.53,280.9l-1.389,1.254a2.063,2.063,0,0,0,0,2.965l10.812,10.419a.925.925,0,0,0,.742.282,1.039,1.039,0,0,0,.953-.667A1.261,1.261,0,0,0,584.32,293.912Z" transform="translate(-566.984 -271.548)" fill="#bdb9b4"/></g></svg>
|
||||
|
Before Width: | Height: | Size: 874 B |
@ -1,20 +0,0 @@
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.arxiv.tools.arxiv_search import ArxivSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class ArxivProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
ArxivSearchTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"query": "John Doe",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,12 +0,0 @@
|
||||
identity:
|
||||
author: Yash Parmar
|
||||
name: arxiv
|
||||
label:
|
||||
en_US: ArXiv
|
||||
zh_Hans: ArXiv
|
||||
description:
|
||||
en_US: Access to a vast repository of scientific papers and articles in various fields of research.
|
||||
zh_Hans: 访问各个研究领域大量科学论文和文章的存储库。
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
@ -1,119 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import arxiv
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ArxivAPIWrapper(BaseModel):
|
||||
"""Wrapper around ArxivAPI.
|
||||
|
||||
To use, you should have the ``arxiv`` python package installed.
|
||||
https://lukasschwab.me/arxiv.py/index.html
|
||||
This wrapper will use the Arxiv API to conduct searches and
|
||||
fetch document summaries. By default, it will return the document summaries
|
||||
of the top-k results.
|
||||
It limits the Document content by doc_content_chars_max.
|
||||
Set doc_content_chars_max=None if you don't want to limit the content size.
|
||||
|
||||
Args:
|
||||
top_k_results: number of the top-scored document used for the arxiv tool
|
||||
ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
|
||||
load_max_docs: a limit to the number of loaded documents
|
||||
load_all_available_meta:
|
||||
if True: the `metadata` of the loaded Documents contains all available
|
||||
meta info (see https://lukasschwab.me/arxiv.py/index.html#Result),
|
||||
if False: the `metadata` contains only the published date, title,
|
||||
authors and summary.
|
||||
doc_content_chars_max: an optional cut limit for the length of a document's
|
||||
content
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
arxiv = ArxivAPIWrapper(
|
||||
top_k_results = 3,
|
||||
ARXIV_MAX_QUERY_LENGTH = 300,
|
||||
load_max_docs = 3,
|
||||
load_all_available_meta = False,
|
||||
doc_content_chars_max = 40000
|
||||
)
|
||||
arxiv.run("tree of thought llm)
|
||||
"""
|
||||
|
||||
arxiv_search: type[arxiv.Search] = arxiv.Search #: :meta private:
|
||||
arxiv_http_error: tuple[type[Exception]] = (arxiv.ArxivError, arxiv.UnexpectedEmptyPageError, arxiv.HTTPError)
|
||||
top_k_results: int = 3
|
||||
ARXIV_MAX_QUERY_LENGTH: int = 300
|
||||
load_max_docs: int = 100
|
||||
load_all_available_meta: bool = False
|
||||
doc_content_chars_max: Optional[int] = 4000
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""
|
||||
Performs an arxiv search and A single string
|
||||
with the publish date, title, authors, and summary
|
||||
for each article separated by two newlines.
|
||||
|
||||
If an error occurs or no documents found, error text
|
||||
is returned instead. Wrapper for
|
||||
https://lukasschwab.me/arxiv.py/index.html#Search
|
||||
|
||||
Args:
|
||||
query: a plaintext search query
|
||||
"""
|
||||
try:
|
||||
results = self.arxiv_search( # type: ignore
|
||||
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
|
||||
).results()
|
||||
except arxiv_http_error as ex:
|
||||
return f"Arxiv exception: {ex}"
|
||||
docs = [
|
||||
f"Published: {result.updated.date()}\n"
|
||||
f"Title: {result.title}\n"
|
||||
f"Authors: {', '.join(a.name for a in result.authors)}\n"
|
||||
f"Summary: {result.summary}"
|
||||
for result in results
|
||||
]
|
||||
if docs:
|
||||
return "\n\n".join(docs)[: self.doc_content_chars_max]
|
||||
else:
|
||||
return "No good Arxiv Result was found"
|
||||
|
||||
|
||||
class ArxivSearchInput(BaseModel):
|
||||
query: str = Field(..., description="Search query.")
|
||||
|
||||
|
||||
class ArxivSearchTool(BuiltinTool):
|
||||
"""
|
||||
A tool for searching articles on Arxiv.
|
||||
"""
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invokes the Arxiv search tool with the given user ID and tool parameters.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user invoking the tool.
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool, including the 'query' parameter.
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
|
||||
which can be a single message or a list of messages.
|
||||
"""
|
||||
query = tool_parameters.get("query", "")
|
||||
|
||||
if not query:
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
arxiv = ArxivAPIWrapper()
|
||||
|
||||
response = arxiv.run(query)
|
||||
|
||||
return self.create_text_message(self.summary(user_id=user_id, content=response))
|
||||
@ -1,23 +0,0 @@
|
||||
identity:
|
||||
name: arxiv_search
|
||||
author: Yash Parmar
|
||||
label:
|
||||
en_US: Arxiv Search
|
||||
zh_Hans: Arxiv 搜索
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for searching scientific papers and articles from the Arxiv repository. Input can be an Arxiv ID or an author's name.
|
||||
zh_Hans: 一个用于从Arxiv存储库搜索科学论文和文章的工具。 输入可以是Arxiv ID或作者姓名。
|
||||
llm: A tool for searching scientific papers and articles from the Arxiv repository. Input can be an Arxiv ID or an author's name.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询字符串
|
||||
human_description:
|
||||
en_US: The Arxiv ID or author's name used for searching.
|
||||
zh_Hans: 用于搜索的Arxiv ID或作者姓名。
|
||||
llm_description: The Arxiv ID or author's name used for searching.
|
||||
form: llm
|
||||
@ -1,9 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
|
||||
<!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 0 16 16" xmlns="http://www.w3.org/2000/svg" fill="none">
|
||||
|
||||
<path fill="#252F3E" d="M4.51 7.687c0 .197.02.357.058.475.042.117.096.245.17.384a.233.233 0 01.037.123c0 .053-.032.107-.1.16l-.336.224a.255.255 0 01-.138.048c-.054 0-.107-.026-.16-.074a1.652 1.652 0 01-.192-.251 4.137 4.137 0 01-.165-.315c-.415.491-.936.737-1.564.737-.447 0-.804-.129-1.064-.385-.261-.256-.394-.598-.394-1.025 0-.454.16-.822.484-1.1.325-.278.756-.416 1.304-.416.18 0 .367.016.564.042.197.027.4.07.612.118v-.39c0-.406-.085-.689-.25-.854-.17-.166-.458-.246-.868-.246-.186 0-.377.022-.574.07a4.23 4.23 0 00-.575.181 1.525 1.525 0 01-.186.07.326.326 0 01-.085.016c-.075 0-.112-.054-.112-.166v-.262c0-.085.01-.15.037-.186a.399.399 0 01.15-.113c.185-.096.409-.176.67-.24.26-.07.537-.101.83-.101.633 0 1.096.144 1.394.432.293.288.442.726.442 1.314v1.73h.01zm-2.161.811c.175 0 .356-.032.548-.096.191-.064.362-.182.505-.342a.848.848 0 00.181-.341c.032-.129.054-.283.054-.465V7.03a4.43 4.43 0 00-.49-.09 3.996 3.996 0 00-.5-.033c-.357 0-.618.07-.793.214-.176.144-.26.347-.26.614 0 .25.063.437.196.566.128.133.314.197.559.197zm4.273.577c-.096 0-.16-.016-.202-.054-.043-.032-.08-.106-.112-.208l-1.25-4.127a.938.938 0 01-.049-.214c0-.085.043-.133.128-.133h.522c.1 0 .17.016.207.053.043.032.075.107.107.208l.894 3.535.83-3.535c.026-.106.058-.176.1-.208a.365.365 0 01.214-.053h.425c.102 0 .17.016.213.053.043.032.08.107.101.208l.841 3.578.92-3.578a.458.458 0 01.107-.208.346.346 0 01.208-.053h.495c.085 0 .133.043.133.133 0 .027-.006.054-.01.086a.76.76 0 01-.038.133l-1.283 4.127c-.032.107-.069.177-.111.209a.34.34 0 01-.203.053h-.457c-.101 0-.17-.016-.213-.053-.043-.038-.08-.107-.101-.214L8.213 5.37l-.82 3.439c-.026.107-.058.176-.1.213-.043.038-.118.054-.213.054h-.458zm6.838.144a3.51 3.51 0 01-.82-.096c-.266-.064-.473-.134-.612-.214-.085-.048-.143-.101-.165-.15a.378.378 0 01-.031-.149v-.272c0-.112.042-.166.122-.166a.3.3 0 01.096.016c.032.011.08.032.133.054.18.08.378.144.585.187.213.042.42.064.633.064.336 0 .596-.059.777-.176a.575.575 0 00.277-.508.52.52 0 00-.144-.373c-.095-.102-.276-.193-.537-.278l-.772-.24c-.388-.123-.676-.305-.851-.545a1.275 1.275 0 01-.266-.774c0-.224.048-.422.143-.593.096-.17.224-.32.384-.438.16-.122.34-.213.553-.277.213-.064.436-.091.67-.091.118 0 .24.005.357.021.122.016.234.038.346.06.106.026.208.052.303.085.096.032.17.064.224.096a.46.46 0 01.16.133.289.289 0 01.047.176v.251c0 .112-.042.171-.122.171a.552.552 0 01-.202-.064 2.427 2.427 0 00-1.022-.208c-.303 0-.543.048-.708.15-.165.1-.25.256-.25.475 0 .149.053.277.16.379.106.101.303.202.585.293l.756.24c.383.123.66.294.825.513.165.219.244.47.244.748 0 .23-.047.437-.138.619a1.436 1.436 0 01-.388.47c-.165.133-.362.23-.591.299-.24.075-.49.112-.761.112z"/>
|
||||
|
||||
<g fill="#F90" fill-rule="evenodd" clip-rule="evenodd">
|
||||
|
||||
|
Before Width: | Height: | Size: 3.3 KiB |
@ -1,24 +0,0 @@
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.aws.tools.sagemaker_text_rerank import SageMakerReRankTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class SageMakerProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
SageMakerReRankTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"sagemaker_endpoint": "",
|
||||
"query": "misaka mikoto",
|
||||
"candidate_texts": "hello$$$hello world",
|
||||
"topk": 5,
|
||||
"aws_region": "",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,15 +0,0 @@
|
||||
identity:
|
||||
author: AWS
|
||||
name: aws
|
||||
label:
|
||||
en_US: AWS
|
||||
zh_Hans: 亚马逊云科技
|
||||
pt_BR: AWS
|
||||
description:
|
||||
en_US: Services on AWS.
|
||||
zh_Hans: 亚马逊云科技的各类服务
|
||||
pt_BR: Services on AWS.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
credentials_for_provider:
|
||||
@ -1,90 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuardrailParameters(BaseModel):
|
||||
guardrail_id: str = Field(..., description="The identifier of the guardrail")
|
||||
guardrail_version: str = Field(..., description="The version of the guardrail")
|
||||
source: str = Field(..., description="The source of the content")
|
||||
text: str = Field(..., description="The text to apply the guardrail to")
|
||||
aws_region: str = Field(..., description="AWS region for the Bedrock client")
|
||||
|
||||
|
||||
class ApplyGuardrailTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke the ApplyGuardrail tool
|
||||
"""
|
||||
try:
|
||||
# Validate and parse input parameters
|
||||
params = GuardrailParameters(**tool_parameters)
|
||||
|
||||
# Initialize AWS client
|
||||
bedrock_client = boto3.client("bedrock-runtime", region_name=params.aws_region)
|
||||
|
||||
# Apply guardrail
|
||||
response = bedrock_client.apply_guardrail(
|
||||
guardrailIdentifier=params.guardrail_id,
|
||||
guardrailVersion=params.guardrail_version,
|
||||
source=params.source,
|
||||
content=[{"text": {"text": params.text}}],
|
||||
)
|
||||
|
||||
logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}")
|
||||
|
||||
# Check for empty response
|
||||
if not response:
|
||||
return self.create_text_message(text="Received empty response from AWS Bedrock.")
|
||||
|
||||
# Process the result
|
||||
action = response.get("action", "No action specified")
|
||||
outputs = response.get("outputs", [])
|
||||
output = outputs[0].get("text", "No output received") if outputs else "No output received"
|
||||
assessments = response.get("assessments", [])
|
||||
|
||||
# Format assessments
|
||||
formatted_assessments = []
|
||||
for assessment in assessments:
|
||||
for policy_type, policy_data in assessment.items():
|
||||
if isinstance(policy_data, dict) and "topics" in policy_data:
|
||||
for topic in policy_data["topics"]:
|
||||
formatted_assessments.append(
|
||||
f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']},"
|
||||
f" Action: {topic['action']}"
|
||||
)
|
||||
else:
|
||||
formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}")
|
||||
|
||||
result = f"Action: {action}\n "
|
||||
result += f"Output: {output}\n "
|
||||
if formatted_assessments:
|
||||
result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n "
|
||||
# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except BotoCoreError as e:
|
||||
error_message = f"AWS service error: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
except json.JSONDecodeError as e:
|
||||
error_message = f"JSON parsing error: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
except Exception as e:
|
||||
error_message = f"An unexpected error occurred: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
@ -1,67 +0,0 @@
|
||||
identity:
|
||||
name: apply_guardrail
|
||||
author: AWS
|
||||
label:
|
||||
en_US: Content Moderation Guardrails
|
||||
zh_Hans: 内容审查护栏
|
||||
description:
|
||||
human:
|
||||
en_US: Content Moderation Guardrails utilizes the ApplyGuardrail API, a feature of Guardrails for Amazon Bedrock. This API is capable of evaluating input prompts and model responses for all Foundation Models (FMs), including those on Amazon Bedrock, custom FMs, and third-party FMs. By implementing this functionality, organizations can achieve centralized governance across all their generative AI applications, thereby enhancing control and consistency in content moderation.
|
||||
zh_Hans: 内容审查护栏采用 Guardrails for Amazon Bedrock 功能中的 ApplyGuardrail API 。ApplyGuardrail 可以评估所有基础模型(FMs)的输入提示和模型响应,包括 Amazon Bedrock 上的 FMs、自定义 FMs 和第三方 FMs。通过实施这一功能, 组织可以在所有生成式 AI 应用程序中实现集中化的治理,从而增强内容审核的控制力和一致性。
|
||||
llm: Content Moderation Guardrails utilizes the ApplyGuardrail API, a feature of Guardrails for Amazon Bedrock. This API is capable of evaluating input prompts and model responses for all Foundation Models (FMs), including those on Amazon Bedrock, custom FMs, and third-party FMs. By implementing this functionality, organizations can achieve centralized governance across all their generative AI applications, thereby enhancing control and consistency in content moderation.
|
||||
parameters:
|
||||
- name: guardrail_id
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Guardrail ID
|
||||
zh_Hans: Guardrail ID
|
||||
human_description:
|
||||
en_US: Please enter the ID of the Guardrail that has already been created on Amazon Bedrock, for example 'qk5nk0e4b77b'.
|
||||
zh_Hans: 请输入已经在 Amazon Bedrock 上创建好的 Guardrail ID, 例如 'qk5nk0e4b77b'.
|
||||
llm_description: Please enter the ID of the Guardrail that has already been created on Amazon Bedrock, for example 'qk5nk0e4b77b'.
|
||||
form: form
|
||||
- name: guardrail_version
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Guardrail Version Number
|
||||
zh_Hans: Guardrail 版本号码
|
||||
human_description:
|
||||
en_US: Please enter the published version of the Guardrail ID that has already been created on Amazon Bedrock. This is typically a version number, such as 2.
|
||||
zh_Hans: 请输入已经在Amazon Bedrock 上创建好的Guardrail ID发布的版本, 通常使用版本号, 例如2.
|
||||
llm_description: Please enter the published version of the Guardrail ID that has already been created on Amazon Bedrock. This is typically a version number, such as 2.
|
||||
form: form
|
||||
- name: source
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Content Source (INPUT or OUTPUT)
|
||||
zh_Hans: 内容来源 (INPUT or OUTPUT)
|
||||
human_description:
|
||||
en_US: The source of data used in the request to apply the guardrail. Valid Values "INPUT | OUTPUT"
|
||||
zh_Hans: 用于应用护栏的请求中所使用的数据来源。有效值为 "INPUT | OUTPUT"
|
||||
llm_description: The source of data used in the request to apply the guardrail. Valid Values "INPUT | OUTPUT"
|
||||
form: form
|
||||
- name: text
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Content to be reviewed
|
||||
zh_Hans: 待审查内容
|
||||
human_description:
|
||||
en_US: The content used for requesting guardrail review, which can be either user input or LLM output.
|
||||
zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。
|
||||
llm_description: The content used for requesting guardrail review, which can be either user input or LLM output.
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
human_description:
|
||||
en_US: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
|
||||
zh_Hans: 请输入 Bedrock 客户端的 AWS 区域,例如 'us-east-1'。
|
||||
llm_description: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
|
||||
form: form
|
||||
@ -1,91 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class LambdaTranslateUtilsTool(BuiltinTool):
|
||||
lambda_client: Any = None
|
||||
|
||||
def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name):
|
||||
msg = {
|
||||
"src_content": text_content,
|
||||
"src_lang": src_lang,
|
||||
"dest_lang": dest_lang,
|
||||
"dictionary_id": dictionary_name,
|
||||
"request_type": request_type,
|
||||
"model_id": model_id,
|
||||
}
|
||||
|
||||
invoke_response = self.lambda_client.invoke(
|
||||
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
|
||||
)
|
||||
response_body = invoke_response["Payload"]
|
||||
|
||||
response_str = response_body.read().decode("unicode_escape")
|
||||
|
||||
return response_str
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
if not self.lambda_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.lambda_client = boto3.client("lambda", region_name=aws_region)
|
||||
else:
|
||||
self.lambda_client = boto3.client("lambda")
|
||||
|
||||
line = 1
|
||||
text_content = tool_parameters.get("text_content", "")
|
||||
if not text_content:
|
||||
return self.create_text_message("Please input text_content")
|
||||
|
||||
line = 2
|
||||
src_lang = tool_parameters.get("src_lang", "")
|
||||
if not src_lang:
|
||||
return self.create_text_message("Please input src_lang")
|
||||
|
||||
line = 3
|
||||
dest_lang = tool_parameters.get("dest_lang", "")
|
||||
if not dest_lang:
|
||||
return self.create_text_message("Please input dest_lang")
|
||||
|
||||
line = 4
|
||||
lambda_name = tool_parameters.get("lambda_name", "")
|
||||
if not lambda_name:
|
||||
return self.create_text_message("Please input lambda_name")
|
||||
|
||||
line = 5
|
||||
request_type = tool_parameters.get("request_type", "")
|
||||
if not request_type:
|
||||
return self.create_text_message("Please input request_type")
|
||||
|
||||
line = 6
|
||||
model_id = tool_parameters.get("model_id", "")
|
||||
if not model_id:
|
||||
return self.create_text_message("Please input model_id")
|
||||
|
||||
line = 7
|
||||
dictionary_name = tool_parameters.get("dictionary_name", "")
|
||||
if not dictionary_name:
|
||||
return self.create_text_message("Please input dictionary_name")
|
||||
|
||||
result = self._invoke_lambda(
|
||||
text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name
|
||||
)
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
@ -1,134 +0,0 @@
|
||||
identity:
|
||||
name: lambda_translate_utils
|
||||
author: AWS
|
||||
label:
|
||||
en_US: TranslateTool
|
||||
zh_Hans: 翻译工具
|
||||
pt_BR: TranslateTool
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag
|
||||
zh_Hans: 大语言模型翻译工具(专词映射获取),需要在AWS上进行额外部署,可参考Github Repo - https://github.com/ybalbert001/dynamodb-rag
|
||||
pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag
|
||||
llm: A util tools for translation.
|
||||
parameters:
|
||||
- name: text_content
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: source content for translation
|
||||
zh_Hans: 待翻译原文
|
||||
pt_BR: source content for translation
|
||||
human_description:
|
||||
en_US: source content for translation
|
||||
zh_Hans: 待翻译原文
|
||||
pt_BR: source content for translation
|
||||
llm_description: source content for translation
|
||||
form: llm
|
||||
- name: src_lang
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: source language code
|
||||
zh_Hans: 原文语言代号
|
||||
pt_BR: source language code
|
||||
human_description:
|
||||
en_US: source language code
|
||||
zh_Hans: 原文语言代号
|
||||
pt_BR: source language code
|
||||
llm_description: source language code
|
||||
form: llm
|
||||
- name: dest_lang
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: target language code
|
||||
zh_Hans: 目标语言代号
|
||||
pt_BR: target language code
|
||||
human_description:
|
||||
en_US: target language code
|
||||
zh_Hans: 目标语言代号
|
||||
pt_BR: target language code
|
||||
llm_description: target language code
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of Lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of Lambda
|
||||
human_description:
|
||||
en_US: region of Lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of Lambda
|
||||
llm_description: region of Lambda
|
||||
form: form
|
||||
- name: model_id
|
||||
type: string
|
||||
required: false
|
||||
default: anthropic.claude-3-sonnet-20240229-v1:0
|
||||
label:
|
||||
en_US: LLM model_id in bedrock
|
||||
zh_Hans: bedrock上的大语言模型model_id
|
||||
pt_BR: LLM model_id in bedrock
|
||||
human_description:
|
||||
en_US: LLM model_id in bedrock
|
||||
zh_Hans: bedrock上的大语言模型model_id
|
||||
pt_BR: LLM model_id in bedrock
|
||||
llm_description: LLM model_id in bedrock
|
||||
form: form
|
||||
- name: dictionary_name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: dictionary name for term mapping
|
||||
zh_Hans: 专词映射表名称
|
||||
pt_BR: dictionary name for term mapping
|
||||
human_description:
|
||||
en_US: dictionary name for term mapping
|
||||
zh_Hans: 专词映射表名称
|
||||
pt_BR: dictionary name for term mapping
|
||||
llm_description: dictionary name for term mapping
|
||||
form: form
|
||||
- name: request_type
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: request type
|
||||
zh_Hans: 请求类型
|
||||
pt_BR: request type
|
||||
human_description:
|
||||
en_US: request type
|
||||
zh_Hans: 请求类型
|
||||
pt_BR: request type
|
||||
default: term_mapping
|
||||
options:
|
||||
- value: term_mapping
|
||||
label:
|
||||
en_US: term_mapping
|
||||
zh_Hans: 专词映射
|
||||
- value: segment_only
|
||||
label:
|
||||
en_US: segment_only
|
||||
zh_Hans: 仅切词
|
||||
- value: translate
|
||||
label:
|
||||
en_US: translate
|
||||
zh_Hans: 翻译内容
|
||||
form: form
|
||||
- name: lambda_name
|
||||
type: string
|
||||
default: "translate_tool"
|
||||
required: true
|
||||
label:
|
||||
en_US: AWS Lambda for term mapping retrieval
|
||||
zh_Hans: 专词召回映射 - AWS Lambda
|
||||
pt_BR: lambda name for term mapping retrieval
|
||||
human_description:
|
||||
en_US: AWS Lambda for term mapping retrieval
|
||||
zh_Hans: 专词召回映射 - AWS Lambda
|
||||
pt_BR: AWS Lambda for term mapping retrieval
|
||||
llm_description: AWS Lambda for term mapping retrieval
|
||||
form: form
|
||||
@ -1,70 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
|
||||
class LambdaYamlToJsonTool(BuiltinTool):
|
||||
lambda_client: Any = None
|
||||
|
||||
def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
|
||||
msg = {"body": yaml_content}
|
||||
logger.info(json.dumps(msg))
|
||||
|
||||
invoke_response = self.lambda_client.invoke(
|
||||
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
|
||||
)
|
||||
response_body = invoke_response["Payload"]
|
||||
|
||||
response_str = response_body.read().decode("utf-8")
|
||||
resp_json = json.loads(response_str)
|
||||
|
||||
logger.info(resp_json)
|
||||
if resp_json["statusCode"] != 200:
|
||||
raise Exception(f"Invalid status code: {response_str}")
|
||||
|
||||
return resp_json["body"]
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
if not self.lambda_client:
|
||||
aws_region = tool_parameters.get("aws_region") # todo: move aws_region out, and update client region
|
||||
if aws_region:
|
||||
self.lambda_client = boto3.client("lambda", region_name=aws_region)
|
||||
else:
|
||||
self.lambda_client = boto3.client("lambda")
|
||||
|
||||
yaml_content = tool_parameters.get("yaml_content", "")
|
||||
if not yaml_content:
|
||||
return self.create_text_message("Please input yaml_content")
|
||||
|
||||
lambda_name = tool_parameters.get("lambda_name", "")
|
||||
if not lambda_name:
|
||||
return self.create_text_message("Please input lambda_name")
|
||||
logger.debug(f"{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}")
|
||||
|
||||
result = self._invoke_lambda(lambda_name, yaml_content)
|
||||
logger.debug(result)
|
||||
|
||||
return self.create_text_message(result)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception: {str(e)}")
|
||||
|
||||
console_handler.flush()
|
||||
@ -1,53 +0,0 @@
|
||||
identity:
|
||||
name: lambda_yaml_to_json
|
||||
author: AWS
|
||||
label:
|
||||
en_US: LambdaYamlToJson
|
||||
zh_Hans: LambdaYamlToJson
|
||||
pt_BR: LambdaYamlToJson
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool to convert yaml to json using AWS Lambda.
|
||||
zh_Hans: 将 YAML 转为 JSON 的工具(通过AWS Lambda)。
|
||||
pt_BR: A tool to convert yaml to json using AWS Lambda.
|
||||
llm: A tool to convert yaml to json.
|
||||
parameters:
|
||||
- name: yaml_content
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: YAML content to convert for
|
||||
zh_Hans: YAML 内容
|
||||
pt_BR: YAML content to convert for
|
||||
human_description:
|
||||
en_US: YAML content to convert for
|
||||
zh_Hans: YAML 内容
|
||||
pt_BR: YAML content to convert for
|
||||
llm_description: YAML content to convert for
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of lambda
|
||||
human_description:
|
||||
en_US: region of lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of lambda
|
||||
llm_description: region of lambda
|
||||
form: form
|
||||
- name: lambda_name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: name of lambda
|
||||
zh_Hans: Lambda 名称
|
||||
pt_BR: name of lambda
|
||||
human_description:
|
||||
en_US: name of lambda
|
||||
zh_Hans: Lambda 名称
|
||||
pt_BR: name of lambda
|
||||
form: form
|
||||
@ -1,81 +0,0 @@
|
||||
import json
|
||||
import operator
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class SageMakerReRankTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint: str = None
|
||||
topk: int = None
|
||||
|
||||
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
|
||||
inputs = [query_input] * len(docs)
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=rerank_endpoint,
|
||||
Body=json.dumps({"inputs": inputs, "docs": docs}),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model["Body"].read().decode("utf8")
|
||||
json_obj = json.loads(json_str)
|
||||
scores = json_obj["scores"]
|
||||
return scores if isinstance(scores, list) else [scores]
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
if not self.sagemaker_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
|
||||
line = 1
|
||||
if not self.sagemaker_endpoint:
|
||||
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
|
||||
|
||||
line = 2
|
||||
if not self.topk:
|
||||
self.topk = tool_parameters.get("topk", 5)
|
||||
|
||||
line = 3
|
||||
query = tool_parameters.get("query", "")
|
||||
if not query:
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
line = 4
|
||||
candidate_texts = tool_parameters.get("candidate_texts")
|
||||
if not candidate_texts:
|
||||
return self.create_text_message("Please input candidate_texts")
|
||||
|
||||
line = 5
|
||||
candidate_docs = json.loads(candidate_texts)
|
||||
docs = [item.get("content") for item in candidate_docs]
|
||||
|
||||
line = 6
|
||||
scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint)
|
||||
|
||||
line = 7
|
||||
for idx in range(len(candidate_docs)):
|
||||
candidate_docs[idx]["score"] = scores[idx]
|
||||
|
||||
line = 8
|
||||
sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
|
||||
|
||||
line = 9
|
||||
return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
@ -1,82 +0,0 @@
|
||||
identity:
|
||||
name: sagemaker_text_rerank
|
||||
author: AWS
|
||||
label:
|
||||
en_US: SagemakerRerank
|
||||
zh_Hans: Sagemaker重排序
|
||||
pt_BR: SagemakerRerank
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for performing text similarity ranking. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
zh_Hans: Sagemaker重排序工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本
|
||||
pt_BR: A tool for performing text similarity ranking.
|
||||
llm: A tool for performing text similarity ranking. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
parameters:
|
||||
- name: sagemaker_endpoint
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: sagemaker endpoint for reranking
|
||||
zh_Hans: 重排序的SageMaker 端点
|
||||
pt_BR: sagemaker endpoint for reranking
|
||||
human_description:
|
||||
en_US: sagemaker endpoint for reranking
|
||||
zh_Hans: 重排序的SageMaker 端点
|
||||
pt_BR: sagemaker endpoint for reranking
|
||||
llm_description: sagemaker endpoint for reranking
|
||||
form: form
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询语句
|
||||
pt_BR: Query string
|
||||
human_description:
|
||||
en_US: key words for searching
|
||||
zh_Hans: 查询关键词
|
||||
pt_BR: key words for searching
|
||||
llm_description: key words for searching
|
||||
form: llm
|
||||
- name: candidate_texts
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: text candidates
|
||||
zh_Hans: 候选文本
|
||||
pt_BR: text candidates
|
||||
human_description:
|
||||
en_US: searched candidates by query
|
||||
zh_Hans: 查询文本搜到候选文本
|
||||
pt_BR: searched candidates by query
|
||||
llm_description: searched candidates by query
|
||||
form: llm
|
||||
- name: topk
|
||||
type: number
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Limit for results count
|
||||
zh_Hans: 返回个数限制
|
||||
pt_BR: Limit for results count
|
||||
human_description:
|
||||
en_US: Limit for results count
|
||||
zh_Hans: 返回个数限制
|
||||
pt_BR: Limit for results count
|
||||
min: 1
|
||||
max: 10
|
||||
default: 5
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
pt_BR: region of sagemaker endpoint
|
||||
human_description:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
pt_BR: region of sagemaker endpoint
|
||||
llm_description: region of sagemaker endpoint
|
||||
form: form
|
||||
@ -1,101 +0,0 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class TTSModelType(Enum):
|
||||
PresetVoice = "PresetVoice"
|
||||
CloneVoice = "CloneVoice"
|
||||
CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
|
||||
InstructVoice = "InstructVoice"
|
||||
|
||||
|
||||
class SageMakerTTSTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint: str = None
|
||||
s3_client: Any = None
|
||||
comprehend_client: Any = None
|
||||
|
||||
def _detect_lang_code(self, content: str, map_dict: dict = None):
|
||||
map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"}
|
||||
|
||||
response = self.comprehend_client.detect_dominant_language(Text=content)
|
||||
language_code = response["Languages"][0]["LanguageCode"]
|
||||
return map_dict.get(language_code, "<|zh|>")
|
||||
|
||||
def _build_tts_payload(
|
||||
self,
|
||||
model_type: str,
|
||||
content_text: str,
|
||||
model_role: str,
|
||||
prompt_text: str,
|
||||
prompt_audio: str,
|
||||
instruct_text: str,
|
||||
):
|
||||
if model_type == TTSModelType.PresetVoice.value and model_role:
|
||||
return {"tts_text": content_text, "role": model_role}
|
||||
if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
|
||||
return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
|
||||
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
|
||||
lang_tag = self._detect_lang_code(content_text)
|
||||
return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag}
|
||||
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
|
||||
return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}
|
||||
|
||||
raise RuntimeError(f"Invalid params for {model_type}")
|
||||
|
||||
def _invoke_sagemaker(self, payload: dict, endpoint: str):
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=endpoint,
|
||||
Body=json.dumps(payload),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model["Body"].read().decode("utf8")
|
||||
json_obj = json.loads(json_str)
|
||||
return json_obj
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
if not self.sagemaker_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
self.s3_client = boto3.client("s3", region_name=aws_region)
|
||||
self.comprehend_client = boto3.client("comprehend", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
self.s3_client = boto3.client("s3")
|
||||
self.comprehend_client = boto3.client("comprehend")
|
||||
|
||||
if not self.sagemaker_endpoint:
|
||||
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
|
||||
|
||||
tts_text = tool_parameters.get("tts_text")
|
||||
tts_infer_type = tool_parameters.get("tts_infer_type")
|
||||
|
||||
voice = tool_parameters.get("voice")
|
||||
mock_voice_audio = tool_parameters.get("mock_voice_audio")
|
||||
mock_voice_text = tool_parameters.get("mock_voice_text")
|
||||
voice_instruct_prompt = tool_parameters.get("voice_instruct_prompt")
|
||||
payload = self._build_tts_payload(
|
||||
tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt
|
||||
)
|
||||
|
||||
result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)
|
||||
|
||||
return self.create_text_message(text=result["s3_presign_url"])
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception {str(e)}")
|
||||
@ -1,149 +0,0 @@
|
||||
identity:
|
||||
name: sagemaker_tts
|
||||
author: AWS
|
||||
label:
|
||||
en_US: SagemakerTTS
|
||||
zh_Hans: Sagemaker语音合成
|
||||
pt_BR: SagemakerTTS
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for Speech synthesis - https://github.com/aws-samples/dify-aws-tool
|
||||
zh_Hans: Sagemaker语音合成工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本
|
||||
pt_BR: A tool for Speech synthesis.
|
||||
llm: A tool for Speech synthesis. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
parameters:
|
||||
- name: sagemaker_endpoint
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: sagemaker endpoint for tts
|
||||
zh_Hans: 语音生成的SageMaker端点
|
||||
pt_BR: sagemaker endpoint for tts
|
||||
human_description:
|
||||
en_US: sagemaker endpoint for tts
|
||||
zh_Hans: 语音生成的SageMaker端点
|
||||
pt_BR: sagemaker endpoint for tts
|
||||
llm_description: sagemaker endpoint for tts
|
||||
form: form
|
||||
- name: tts_text
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: tts text
|
||||
zh_Hans: 语音合成原文
|
||||
pt_BR: tts text
|
||||
human_description:
|
||||
en_US: tts text
|
||||
zh_Hans: 语音合成原文
|
||||
pt_BR: tts text
|
||||
llm_description: tts text
|
||||
form: llm
|
||||
- name: tts_infer_type
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: tts infer type
|
||||
zh_Hans: 合成方式
|
||||
pt_BR: tts infer type
|
||||
human_description:
|
||||
en_US: tts infer type
|
||||
zh_Hans: 合成方式
|
||||
pt_BR: tts infer type
|
||||
llm_description: tts infer type
|
||||
options:
|
||||
- value: PresetVoice
|
||||
label:
|
||||
en_US: preset voice
|
||||
zh_Hans: 预置音色
|
||||
- value: CloneVoice
|
||||
label:
|
||||
en_US: clone voice
|
||||
zh_Hans: 克隆音色
|
||||
- value: CloneVoice_CrossLingual
|
||||
label:
|
||||
en_US: clone crossLingual voice
|
||||
zh_Hans: 克隆音色(跨语言)
|
||||
- value: InstructVoice
|
||||
label:
|
||||
en_US: instruct voice
|
||||
zh_Hans: 指令音色
|
||||
form: form
|
||||
- name: voice
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: preset voice
|
||||
zh_Hans: 预置音色
|
||||
pt_BR: preset voice
|
||||
human_description:
|
||||
en_US: preset voice
|
||||
zh_Hans: 预置音色
|
||||
pt_BR: preset voice
|
||||
llm_description: preset voice
|
||||
options:
|
||||
- value: 中文男
|
||||
label:
|
||||
en_US: zh-cn male
|
||||
zh_Hans: 中文男
|
||||
- value: 中文女
|
||||
label:
|
||||
en_US: zh-cn female
|
||||
zh_Hans: 中文女
|
||||
- value: 粤语女
|
||||
label:
|
||||
en_US: zh-TW female
|
||||
zh_Hans: 粤语女
|
||||
form: form
|
||||
- name: mock_voice_audio
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: clone voice link
|
||||
zh_Hans: 克隆音频链接
|
||||
pt_BR: clone voice link
|
||||
human_description:
|
||||
en_US: clone voice link
|
||||
zh_Hans: 克隆音频链接
|
||||
pt_BR: clone voice link
|
||||
llm_description: clone voice link
|
||||
form: llm
|
||||
- name: mock_voice_text
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: text of clone voice
|
||||
zh_Hans: 克隆音频对应文本
|
||||
pt_BR: text of clone voice
|
||||
human_description:
|
||||
en_US: text of clone voice
|
||||
zh_Hans: 克隆音频对应文本
|
||||
pt_BR: text of clone voice
|
||||
llm_description: text of clone voice
|
||||
form: llm
|
||||
- name: voice_instruct_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: instruct prompt for voice
|
||||
zh_Hans: 音色指令文本
|
||||
pt_BR: instruct prompt for voice
|
||||
human_description:
|
||||
en_US: instruct prompt for voice
|
||||
zh_Hans: 音色指令文本
|
||||
pt_BR: instruct prompt for voice
|
||||
llm_description: instruct prompt for voice
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
pt_BR: region of sagemaker endpoint
|
||||
human_description:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
pt_BR: region of sagemaker endpoint
|
||||
llm_description: region of sagemaker endpoint
|
||||
form: form
|
||||
|
Before Width: | Height: | Size: 50 KiB |
@ -1,20 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.azuredalle.tools.dalle3 import DallE3Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AzureDALLEProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
DallE3Tool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "square", "n": 1},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,76 +0,0 @@
|
||||
identity:
|
||||
author: Leslie
|
||||
name: azuredalle
|
||||
label:
|
||||
en_US: Azure DALL-E
|
||||
zh_Hans: Azure DALL-E 绘画
|
||||
pt_BR: Azure DALL-E
|
||||
description:
|
||||
en_US: Azure DALL-E art
|
||||
zh_Hans: Azure DALL-E 绘画
|
||||
pt_BR: Azure DALL-E art
|
||||
icon: icon.png
|
||||
tags:
|
||||
- image
|
||||
- productivity
|
||||
credentials_for_provider:
|
||||
azure_openai_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API key
|
||||
zh_Hans: 密钥
|
||||
pt_BR: API key
|
||||
help:
|
||||
en_US: Please input your Azure OpenAI API key
|
||||
zh_Hans: 请输入你的 Azure OpenAI API key
|
||||
pt_BR: Introduza a sua chave de API OpenAI do Azure
|
||||
placeholder:
|
||||
en_US: Please input your Azure OpenAI API key
|
||||
zh_Hans: 请输入你的 Azure OpenAI API key
|
||||
pt_BR: Introduza a sua chave de API OpenAI do Azure
|
||||
azure_openai_api_model_name:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Deployment Name
|
||||
zh_Hans: 部署名称
|
||||
pt_BR: Nome da Implantação
|
||||
help:
|
||||
en_US: Please input the name of your Azure Openai DALL-E API deployment
|
||||
zh_Hans: 请输入你的 Azure Openai DALL-E API 部署名称
|
||||
pt_BR: Insira o nome da implantação da API DALL-E do Azure Openai
|
||||
placeholder:
|
||||
en_US: Please input the name of your Azure Openai DALL-E API deployment
|
||||
zh_Hans: 请输入你的 Azure Openai DALL-E API 部署名称
|
||||
pt_BR: Insira o nome da implantação da API DALL-E do Azure Openai
|
||||
azure_openai_base_url:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API Endpoint URL
|
||||
zh_Hans: API 域名
|
||||
pt_BR: API Endpoint URL
|
||||
help:
|
||||
en_US: Please input your Azure OpenAI Endpoint URL, e.g. https://xxx.openai.azure.com/
|
||||
zh_Hans: 请输入你的 Azure OpenAI API域名,例如:https://xxx.openai.azure.com/
|
||||
pt_BR: Introduza a URL do Azure OpenAI Endpoint, e.g. https://xxx.openai.azure.com/
|
||||
placeholder:
|
||||
en_US: Please input your Azure OpenAI Endpoint URL, e.g. https://xxx.openai.azure.com/
|
||||
zh_Hans: 请输入你的 Azure OpenAI API域名,例如:https://xxx.openai.azure.com/
|
||||
pt_BR: Introduza a URL do Azure OpenAI Endpoint, e.g. https://xxx.openai.azure.com/
|
||||
azure_openai_api_version:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API Version
|
||||
zh_Hans: API 版本
|
||||
pt_BR: API Version
|
||||
help:
|
||||
en_US: Please input your Azure OpenAI API Version,e.g. 2023-12-01-preview
|
||||
zh_Hans: 请输入你的 Azure OpenAI API 版本,例如:2023-12-01-preview
|
||||
pt_BR: Introduza a versão da API OpenAI do Azure,e.g. 2023-12-01-preview
|
||||
placeholder:
|
||||
en_US: Please input your Azure OpenAI API Version,e.g. 2023-12-01-preview
|
||||
zh_Hans: 请输入你的 Azure OpenAI API 版本,例如:2023-12-01-preview
|
||||
pt_BR: Introduza a versão da API OpenAI do Azure,e.g. 2023-12-01-preview
|
||||
@ -1,83 +0,0 @@
|
||||
import random
|
||||
from base64 import b64decode
|
||||
from typing import Any, Union
|
||||
|
||||
from openai import AzureOpenAI
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DallE3Tool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
client = AzureOpenAI(
|
||||
api_version=self.runtime.credentials["azure_openai_api_version"],
|
||||
azure_endpoint=self.runtime.credentials["azure_openai_base_url"],
|
||||
api_key=self.runtime.credentials["azure_openai_api_key"],
|
||||
)
|
||||
|
||||
SIZE_MAPPING = {
|
||||
"square": "1024x1024",
|
||||
"vertical": "1024x1792",
|
||||
"horizontal": "1792x1024",
|
||||
}
|
||||
|
||||
# prompt
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message("Please input prompt")
|
||||
# get size
|
||||
size = SIZE_MAPPING[tool_parameters.get("size", "square")]
|
||||
# get n
|
||||
n = tool_parameters.get("n", 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
if quality not in {"standard", "hd"}:
|
||||
return self.create_text_message("Invalid quality")
|
||||
# get style
|
||||
style = tool_parameters.get("style", "vivid")
|
||||
if style not in {"natural", "vivid"}:
|
||||
return self.create_text_message("Invalid style")
|
||||
# set extra body
|
||||
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
||||
extra_body = {"seed": seed_id}
|
||||
|
||||
# call openapi dalle3
|
||||
model = self.runtime.credentials["azure_openai_api_model_name"]
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
size=size,
|
||||
n=n,
|
||||
extra_body=extra_body,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format="b64_json",
|
||||
)
|
||||
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
result.append(
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image.b64_json),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}"))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
random_id = "".join(random.choices(characters, k=length))
|
||||
return random_id
|
||||
@ -1,136 +0,0 @@
|
||||
identity:
|
||||
name: azure_dalle3
|
||||
author: Leslie
|
||||
label:
|
||||
en_US: Azure DALL-E 3
|
||||
zh_Hans: Azure DALL-E 3 绘画
|
||||
pt_BR: Azure DALL-E 3
|
||||
description:
|
||||
en_US: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources
|
||||
zh_Hans: DALL-E 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像,相比于DallE 2, DallE 3拥有更强的绘画能力,但会消耗更多的资源
|
||||
pt_BR: DALL-E 3 é uma poderosa ferramenta de desenho que pode desenhar a imagem que você deseja com base em seu prompt, em comparação com DallE 2, DallE 3 tem uma capacidade de desenho mais forte, mas consumirá mais recursos
|
||||
description:
|
||||
human:
|
||||
en_US: DALL-E is a text to image tool
|
||||
zh_Hans: DALL-E 是一个文本到图像的工具
|
||||
pt_BR: DALL-E é uma ferramenta de texto para imagem
|
||||
llm: DALL-E is a tool used to generate images from text
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: Image prompt, you can check the official documentation of DallE 3
|
||||
zh_Hans: 图像提示词,您可以查看 DallE 3 的官方文档
|
||||
pt_BR: Imagem prompt, você pode verificar a documentação oficial do DallE 3
|
||||
llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
|
||||
form: llm
|
||||
- name: seed_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Seed ID
|
||||
zh_Hans: 种子ID
|
||||
pt_BR: ID da semente
|
||||
human_description:
|
||||
en_US: Image generation seed ID to ensure consistency of series generated images
|
||||
zh_Hans: 图像生成种子ID,确保系列生成图像的一致性
|
||||
pt_BR: ID de semente de geração de imagem para garantir a consistência das imagens geradas em série
|
||||
llm_description: If the user requests image consistency, extract the seed ID from the user's question or context.The seed id consists of an 8-bit string containing uppercase and lowercase letters and numbers
|
||||
form: llm
|
||||
- name: size
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image size
|
||||
zh_Hans: 选择图像大小
|
||||
pt_BR: seleccionar o tamanho da imagem
|
||||
label:
|
||||
en_US: Image size
|
||||
zh_Hans: 图像大小
|
||||
pt_BR: Tamanho da imagem
|
||||
form: form
|
||||
options:
|
||||
- value: square
|
||||
label:
|
||||
en_US: Squre(1024x1024)
|
||||
zh_Hans: 方(1024x1024)
|
||||
pt_BR: Squire(1024x1024)
|
||||
- value: vertical
|
||||
label:
|
||||
en_US: Vertical(1024x1792)
|
||||
zh_Hans: 竖屏(1024x1792)
|
||||
pt_BR: Vertical(1024x1792)
|
||||
- value: horizontal
|
||||
label:
|
||||
en_US: Horizontal(1792x1024)
|
||||
zh_Hans: 横屏(1792x1024)
|
||||
pt_BR: Horizontal(1792x1024)
|
||||
default: square
|
||||
- name: n
|
||||
type: number
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the number of images
|
||||
zh_Hans: 选择图像数量
|
||||
pt_BR: seleccionar o número de imagens
|
||||
label:
|
||||
en_US: Number of images
|
||||
zh_Hans: 图像数量
|
||||
pt_BR: Número de imagens
|
||||
form: form
|
||||
min: 1
|
||||
max: 1
|
||||
default: 1
|
||||
- name: quality
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image quality
|
||||
zh_Hans: 选择图像质量
|
||||
pt_BR: seleccionar a qualidade da imagem
|
||||
label:
|
||||
en_US: Image quality
|
||||
zh_Hans: 图像质量
|
||||
pt_BR: Qualidade da imagem
|
||||
form: form
|
||||
options:
|
||||
- value: standard
|
||||
label:
|
||||
en_US: Standard
|
||||
zh_Hans: 标准
|
||||
pt_BR: Normal
|
||||
- value: hd
|
||||
label:
|
||||
en_US: HD
|
||||
zh_Hans: 高清
|
||||
pt_BR: HD
|
||||
default: standard
|
||||
- name: style
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image style
|
||||
zh_Hans: 选择图像风格
|
||||
pt_BR: seleccionar o estilo da imagem
|
||||
label:
|
||||
en_US: Image style
|
||||
zh_Hans: 图像风格
|
||||
pt_BR: Estilo da imagem
|
||||
form: form
|
||||
options:
|
||||
- value: vivid
|
||||
label:
|
||||
en_US: Vivid
|
||||
zh_Hans: 生动
|
||||
pt_BR: Vívido
|
||||
- value: natural
|
||||
label:
|
||||
en_US: Natural
|
||||
zh_Hans: 自然
|
||||
pt_BR: Natural
|
||||
default: vivid
|
||||
@ -1,40 +0,0 @@
|
||||
<svg viewBox="-29.62167543756803 0.1 574.391675437568 799.8100000000002" xmlns="http://www.w3.org/2000/svg" width="1888"
|
||||
height="2500">
|
||||
<linearGradient id="a" gradientUnits="userSpaceOnUse" x1="286.383" x2="542.057" y1="284.169" y2="569.112">
|
||||
<stop offset="0" stop-color="#37bdff"/>
|
||||
<stop offset=".25" stop-color="#26c6f4"/>
|
||||
<stop offset=".5" stop-color="#15d0e9"/>
|
||||
<stop offset=".75" stop-color="#3bd6df"/>
|
||||
<stop offset="1" stop-color="#62dcd4"/>
|
||||
</linearGradient>
|
||||
<linearGradient id="b" gradientUnits="userSpaceOnUse" x1="108.979" x2="100.756" y1="675.98" y2="43.669">
|
||||
<stop offset="0" stop-color="#1b48ef"/>
|
||||
<stop offset=".5" stop-color="#2080f1"/>
|
||||
<stop offset="1" stop-color="#26b8f4"/>
|
||||
</linearGradient>
|
||||
<linearGradient id="c" gradientUnits="userSpaceOnUse" x1="256.823" x2="875.632" y1="649.719" y2="649.719">
|
||||
<stop offset="0" stop-color="#39d2ff"/>
|
||||
<stop offset=".5" stop-color="#248ffa"/>
|
||||
<stop offset="1" stop-color="#104cf5"/>
|
||||
</linearGradient>
|
||||
<linearGradient id="d" gradientUnits="userSpaceOnUse" x1="256.823" x2="875.632" y1="649.719" y2="649.719">
|
||||
<stop offset="0" stop-color="#fff"/>
|
||||
<stop offset="1"/>
|
||||
</linearGradient>
|
||||
<path d="M249.97 277.48c-.12.96-.12 2.05-.12 3.12 0 4.16.83 8.16 2.33 11.84l1.34 2.76 5.3 13.56 27.53 70.23 24.01 61.33c6.85 12.38 17.82 22.1 31.05 27.28l4.11 1.51c.16.05.43.05.65.11l65.81 22.63v.05l25.16 8.64 1.72.58c.06 0 .16.06.22.06 4.96 1.25 9.82 2.93 14.46 4.98 10.73 4.63 20.46 11.23 28.77 19.28 3.35 3.2 6.43 6.65 9.28 10.33a88.64 88.64 0 0 1 6.64 9.72c8.78 14.58 13.82 31.72 13.82 49.97 0 3.26-.16 6.41-.49 9.61-.11 1.41-.28 2.77-.49 4.12v.11c-.22 1.43-.49 2.91-.76 4.36-.28 1.41-.54 2.81-.86 4.21-.05.16-.11.33-.17.49-.3 1.42-.68 2.82-1.07 4.23-.35 1.33-.79 2.7-1.28 3.99a42.96 42.96 0 0 1-1.51 4.16c-.49 1.4-1.07 2.82-1.72 4.16-1.78 4.11-3.9 8.06-6.28 11.83a97.889 97.889 0 0 1-10.47 13.95c30.88-33.2 51.41-76.07 56.52-123.51.86-7.78 1.3-15.67 1.3-23.61 0-5.07-.22-10.09-.55-15.13-3.89-56.89-29.79-107.77-69.32-144.08-10.9-10.09-22.81-19.07-35.62-26.69l-24.2-12.37-122.63-62.93a30.15 30.15 0 0 0-11.93-2.44c-15.88 0-28.99 12.11-30.55 27.56z"
|
||||
fill="#7f7f7f"/>
|
||||
<path d="M249.97 277.48c-.12.96-.12 2.05-.12 3.12 0 4.16.83 8.16 2.33 11.84l1.34 2.76 5.3 13.56 27.53 70.23 24.01 61.33c6.85 12.38 17.82 22.1 31.05 27.28l4.11 1.51c.16.05.43.05.65.11l65.81 22.63v.05l25.16 8.64 1.72.58c.06 0 .16.06.22.06 4.96 1.25 9.82 2.93 14.46 4.98 10.73 4.63 20.46 11.23 28.77 19.28 3.35 3.2 6.43 6.65 9.28 10.33a88.64 88.64 0 0 1 6.64 9.72c8.78 14.58 13.82 31.72 13.82 49.97 0 3.26-.16 6.41-.49 9.61-.11 1.41-.28 2.77-.49 4.12v.11c-.22 1.43-.49 2.91-.76 4.36-.28 1.41-.54 2.81-.86 4.21-.05.16-.11.33-.17.49-.3 1.42-.68 2.82-1.07 4.23-.35 1.33-.79 2.7-1.28 3.99a42.96 42.96 0 0 1-1.51 4.16c-.49 1.4-1.07 2.82-1.72 4.16-1.78 4.11-3.9 8.06-6.28 11.83a97.889 97.889 0 0 1-10.47 13.95c30.88-33.2 51.41-76.07 56.52-123.51.86-7.78 1.3-15.67 1.3-23.61 0-5.07-.22-10.09-.55-15.13-3.89-56.89-29.79-107.77-69.32-144.08-10.9-10.09-22.81-19.07-35.62-26.69l-24.2-12.37-122.63-62.93a30.15 30.15 0 0 0-11.93-2.44c-15.88 0-28.99 12.11-30.55 27.56z"
|
||||
fill="url(#a)"/>
|
||||
<path d="M31.62.1C14.17.41.16 14.69.16 32.15v559.06c.07 3.9.29 7.75.57 11.66.25 2.06.52 4.2.9 6.28 7.97 44.87 47.01 78.92 94.15 78.92 16.53 0 32.03-4.21 45.59-11.53.08-.06.22-.14.29-.14l4.88-2.95 19.78-11.64 25.16-14.93.06-496.73c0-33.01-16.52-62.11-41.81-79.4-.6-.36-1.18-.74-1.71-1.17L50.12 5.56C45.16 2.28 39.18.22 32.77.1z"
|
||||
fill="#7f7f7f"/>
|
||||
<path d="M31.62.1C14.17.41.16 14.69.16 32.15v559.06c.07 3.9.29 7.75.57 11.66.25 2.06.52 4.2.9 6.28 7.97 44.87 47.01 78.92 94.15 78.92 16.53 0 32.03-4.21 45.59-11.53.08-.06.22-.14.29-.14l4.88-2.95 19.78-11.64 25.16-14.93.06-496.73c0-33.01-16.52-62.11-41.81-79.4-.6-.36-1.18-.74-1.71-1.17L50.12 5.56C45.16 2.28 39.18.22 32.77.1z"
|
||||
fill="url(#b)"/>
|
||||
<path d="M419.81 510.84L194.72 644.26l-3.24 1.95v.71l-25.16 14.9-19.77 11.67-4.85 2.93-.33.16c-13.53 7.35-29.04 11.51-45.56 11.51-47.13 0-86.22-34.03-94.16-78.92 3.77 32.84 14.96 63.41 31.84 90.04 34.76 54.87 93.54 93.04 161.54 99.67h41.58c36.78-3.84 67.49-18.57 99.77-38.46l49.64-30.36c22.36-14.33 83.05-49.58 100.93-69.36 3.89-4.33 7.4-8.97 10.47-13.94 2.38-3.78 4.5-7.73 6.28-11.84.6-1.4 1.17-2.76 1.72-4.15.52-1.38 1.01-2.77 1.51-4.18.93-2.7 1.67-5.41 2.38-8.2.36-1.59.69-3.16 1.02-4.72 1.08-5.89 1.67-11.94 1.67-18.21 0-18.25-5.04-35.39-13.77-49.95-2-3.4-4.2-6.65-6.64-9.72-2.85-3.7-5.93-7.13-9.28-10.33-8.31-8.05-18.01-14.65-28.77-19.29-4.64-2.05-9.48-3.74-14.46-4.97-.06 0-.16-.06-.22-.06l-1.72-.58z"
|
||||
fill="#7f7f7f"/>
|
||||
<path d="M419.81 510.84L194.72 644.26l-3.24 1.95v.71l-25.16 14.9-19.77 11.67-4.85 2.93-.33.16c-13.53 7.35-29.04 11.51-45.56 11.51-47.13 0-86.22-34.03-94.16-78.92 3.77 32.84 14.96 63.41 31.84 90.04 34.76 54.87 93.54 93.04 161.54 99.67h41.58c36.78-3.84 67.49-18.57 99.77-38.46l49.64-30.36c22.36-14.33 83.05-49.58 100.93-69.36 3.89-4.33 7.4-8.97 10.47-13.94 2.38-3.78 4.5-7.73 6.28-11.84.6-1.4 1.17-2.76 1.72-4.15.52-1.38 1.01-2.77 1.51-4.18.93-2.7 1.67-5.41 2.38-8.2.36-1.59.69-3.16 1.02-4.72 1.08-5.89 1.67-11.94 1.67-18.21 0-18.25-5.04-35.39-13.77-49.95-2-3.4-4.2-6.65-6.64-9.72-2.85-3.7-5.93-7.13-9.28-10.33-8.31-8.05-18.01-14.65-28.77-19.29-4.64-2.05-9.48-3.74-14.46-4.97-.06 0-.16-.06-.22-.06l-1.72-.58z"
|
||||
fill="url(#c)"/>
|
||||
<path d="M512 595.46c0 6.27-.59 12.33-1.68 18.22-.32 1.56-.65 3.12-1.02 4.7-.7 2.8-1.44 5.51-2.37 8.22-.49 1.4-.99 2.8-1.51 4.16-.54 1.4-1.12 2.76-1.73 4.16a87.873 87.873 0 0 1-6.26 11.83 96.567 96.567 0 0 1-10.48 13.94c-17.88 19.79-78.57 55.04-100.93 69.37l-49.64 30.36c-36.39 22.42-70.77 38.29-114.13 39.38-2.05.06-4.06.11-6.05.11-2.8 0-5.56-.05-8.33-.16-73.42-2.8-137.45-42.25-174.38-100.54a213.368 213.368 0 0 1-31.84-90.04c7.94 44.89 47.03 78.92 94.16 78.92 16.52 0 32.03-4.17 45.56-11.51l.33-.17 4.85-2.92 19.77-11.67 25.16-14.9v-.71l3.24-1.95 225.09-133.43 17.33-10.27 1.72.58c.05 0 .16.06.22.06 4.98 1.23 9.83 2.92 14.46 4.97 10.76 4.64 20.45 11.24 28.77 19.29a92.13 92.13 0 0 1 9.28 10.33c2.44 3.07 4.64 6.32 6.64 9.72 8.73 14.56 13.77 31.7 13.77 49.95z"
|
||||
fill="#7f7f7f" opacity=".15"/>
|
||||
<path d="M512 595.46c0 6.27-.59 12.33-1.68 18.22-.32 1.56-.65 3.12-1.02 4.7-.7 2.8-1.44 5.51-2.37 8.22-.49 1.4-.99 2.8-1.51 4.16-.54 1.4-1.12 2.76-1.73 4.16a87.873 87.873 0 0 1-6.26 11.83 96.567 96.567 0 0 1-10.48 13.94c-17.88 19.79-78.57 55.04-100.93 69.37l-49.64 30.36c-36.39 22.42-70.77 38.29-114.13 39.38-2.05.06-4.06.11-6.05.11-2.8 0-5.56-.05-8.33-.16-73.42-2.8-137.45-42.25-174.38-100.54a213.368 213.368 0 0 1-31.84-90.04c7.94 44.89 47.03 78.92 94.16 78.92 16.52 0 32.03-4.17 45.56-11.51l.33-.17 4.85-2.92 19.77-11.67 25.16-14.9v-.71l3.24-1.95 225.09-133.43 17.33-10.27 1.72.58c.05 0 .16.06.22.06 4.98 1.23 9.83 2.92 14.46 4.97 10.76 4.64 20.45 11.24 28.77 19.29a92.13 92.13 0 0 1 9.28 10.33c2.44 3.07 4.64 6.32 6.64 9.72 8.73 14.56 13.77 31.7 13.77 49.95z"
|
||||
fill="url(#d)" opacity=".15"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 6.9 KiB |
@ -1,23 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.bing.tools.bing_web_search import BingSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class BingProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
BingSearchTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).validate_credentials(
|
||||
credentials=credentials,
|
||||
tool_parameters={
|
||||
"query": "test",
|
||||
"result_type": "link",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,107 +0,0 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: bing
|
||||
label:
|
||||
en_US: Bing
|
||||
zh_Hans: Bing
|
||||
pt_BR: Bing
|
||||
description:
|
||||
en_US: Bing Search
|
||||
zh_Hans: Bing 搜索
|
||||
pt_BR: Bing Search
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
credentials_for_provider:
|
||||
subscription_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Bing subscription key
|
||||
zh_Hans: Bing subscription key
|
||||
pt_BR: Bing subscription key
|
||||
placeholder:
|
||||
en_US: Please input your Bing subscription key
|
||||
zh_Hans: 请输入你的 Bing subscription key
|
||||
pt_BR: Please input your Bing subscription key
|
||||
help:
|
||||
en_US: Get your Bing subscription key from Bing
|
||||
zh_Hans: 从 Bing 获取您的 Bing subscription key
|
||||
pt_BR: Get your Bing subscription key from Bing
|
||||
url: https://www.microsoft.com/cognitive-services/en-us/bing-web-search-api
|
||||
server_url:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: Bing endpoint
|
||||
zh_Hans: Bing endpoint
|
||||
pt_BR: Bing endpoint
|
||||
placeholder:
|
||||
en_US: Please input your Bing endpoint
|
||||
zh_Hans: 请输入你的 Bing 端点
|
||||
pt_BR: Please input your Bing endpoint
|
||||
help:
|
||||
en_US: An endpoint is like "https://api.bing.microsoft.com/v7.0/search"
|
||||
zh_Hans: 例如 "https://api.bing.microsoft.com/v7.0/search"
|
||||
pt_BR: An endpoint is like "https://api.bing.microsoft.com/v7.0/search"
|
||||
default: https://api.bing.microsoft.com/v7.0/search
|
||||
allow_entities:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Entities Search
|
||||
zh_Hans: 支持实体搜索
|
||||
pt_BR: Allow Entities Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow entity search
|
||||
zh_Hans: 您的订阅计划是否支持实体搜索
|
||||
pt_BR: Does your subscription plan allow entity search
|
||||
default: true
|
||||
allow_web_pages:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Web Pages Search
|
||||
zh_Hans: 支持网页搜索
|
||||
pt_BR: Allow Web Pages Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow web pages search
|
||||
zh_Hans: 您的订阅计划是否支持网页搜索
|
||||
pt_BR: Does your subscription plan allow web pages search
|
||||
default: true
|
||||
allow_computation:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Computation Search
|
||||
zh_Hans: 支持计算搜索
|
||||
pt_BR: Allow Computation Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow computation search
|
||||
zh_Hans: 您的订阅计划是否支持计算搜索
|
||||
pt_BR: Does your subscription plan allow computation search
|
||||
default: false
|
||||
allow_news:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow News Search
|
||||
zh_Hans: 支持新闻搜索
|
||||
pt_BR: Allow News Search
|
||||
help:
|
||||
en_US: Does your subscription plan allow news search
|
||||
zh_Hans: 您的订阅计划是否支持新闻搜索
|
||||
pt_BR: Does your subscription plan allow news search
|
||||
default: false
|
||||
allow_related_searches:
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Allow Related Searches
|
||||
zh_Hans: 支持相关搜索
|
||||
pt_BR: Allow Related Searches
|
||||
help:
|
||||
en_US: Does your subscription plan allow related searches
|
||||
zh_Hans: 您的订阅计划是否支持相关搜索
|
||||
pt_BR: Does your subscription plan allow related searches
|
||||
default: false
|
||||
@ -1,202 +0,0 @@
|
||||
from typing import Any, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from requests import get
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class BingSearchTool(BuiltinTool):
|
||||
url: str = "https://api.bing.microsoft.com/v7.0/search"
|
||||
|
||||
def _invoke_bing(
|
||||
self,
|
||||
user_id: str,
|
||||
server_url: str,
|
||||
subscription_key: str,
|
||||
query: str,
|
||||
limit: int,
|
||||
result_type: str,
|
||||
market: str,
|
||||
lang: str,
|
||||
filters: list[str],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke bing search
|
||||
"""
|
||||
market_code = f"{lang}-{market}"
|
||||
accept_language = f"{lang},{market_code};q=0.9"
|
||||
headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language}
|
||||
|
||||
query = quote(query)
|
||||
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
|
||||
response = get(server_url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error {response.status_code}: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
search_results = response["webPages"]["value"][:limit] if "webPages" in response else []
|
||||
related_searches = response["relatedSearches"]["value"] if "relatedSearches" in response else []
|
||||
entities = response["entities"]["value"] if "entities" in response else []
|
||||
news = response["news"]["value"] if "news" in response else []
|
||||
computation = response["computation"]["value"] if "computation" in response else None
|
||||
|
||||
if result_type == "link":
|
||||
results = []
|
||||
if search_results:
|
||||
for result in search_results:
|
||||
url = f': {result["url"]}' if "url" in result else ""
|
||||
results.append(self.create_text_message(text=f'{result["name"]}{url}'))
|
||||
|
||||
if entities:
|
||||
for entity in entities:
|
||||
url = f': {entity["url"]}' if "url" in entity else ""
|
||||
results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}'))
|
||||
|
||||
if news:
|
||||
for news_item in news:
|
||||
url = f': {news_item["url"]}' if "url" in news_item else ""
|
||||
results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}'))
|
||||
|
||||
if related_searches:
|
||||
for related in related_searches:
|
||||
url = f': {related["displayText"]}' if "displayText" in related else ""
|
||||
results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}'))
|
||||
|
||||
return results
|
||||
else:
|
||||
# construct text
|
||||
text = ""
|
||||
if search_results:
|
||||
for i, result in enumerate(search_results):
|
||||
text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
|
||||
|
||||
if computation and "expression" in computation and "value" in computation:
|
||||
text += "\nComputation:\n"
|
||||
text += f'{computation["expression"]} = {computation["value"]}\n'
|
||||
|
||||
if entities:
|
||||
text += "\nEntities:\n"
|
||||
for entity in entities:
|
||||
url = f'- {entity["url"]}' if "url" in entity else ""
|
||||
text += f'{entity.get("name", "")}{url}\n'
|
||||
|
||||
if news:
|
||||
text += "\nNews:\n"
|
||||
for news_item in news:
|
||||
url = f'- {news_item["url"]}' if "url" in news_item else ""
|
||||
text += f'{news_item.get("name", "")}{url}\n'
|
||||
|
||||
if related_searches:
|
||||
text += "\n\nRelated Searches:\n"
|
||||
for related in related_searches:
|
||||
url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else ""
|
||||
text += f'{related.get("displayText", "")}{url}\n'
|
||||
|
||||
return self.create_text_message(text=self.summary(user_id=user_id, content=text))
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None:
|
||||
key = credentials.get("subscription_key")
|
||||
if not key:
|
||||
raise Exception("subscription_key is required")
|
||||
|
||||
server_url = credentials.get("server_url")
|
||||
if not server_url:
|
||||
server_url = self.url
|
||||
|
||||
query = tool_parameters.get("query")
|
||||
if not query:
|
||||
raise Exception("query is required")
|
||||
|
||||
limit = min(tool_parameters.get("limit", 5), 10)
|
||||
result_type = tool_parameters.get("result_type", "text") or "text"
|
||||
|
||||
market = tool_parameters.get("market", "US")
|
||||
lang = tool_parameters.get("language", "en")
|
||||
filter = []
|
||||
|
||||
if credentials.get("allow_entities", False):
|
||||
filter.append("Entities")
|
||||
|
||||
if credentials.get("allow_computation", False):
|
||||
filter.append("Computation")
|
||||
|
||||
if credentials.get("allow_news", False):
|
||||
filter.append("News")
|
||||
|
||||
if credentials.get("allow_related_searches", False):
|
||||
filter.append("RelatedSearches")
|
||||
|
||||
if credentials.get("allow_web_pages", False):
|
||||
filter.append("WebPages")
|
||||
|
||||
if not filter:
|
||||
raise Exception("At least one filter is required")
|
||||
|
||||
self._invoke_bing(
|
||||
user_id="test",
|
||||
server_url=server_url,
|
||||
subscription_key=key,
|
||||
query=query,
|
||||
limit=limit,
|
||||
result_type=result_type,
|
||||
market=market,
|
||||
lang=lang,
|
||||
filters=filter,
|
||||
)
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
|
||||
key = self.runtime.credentials.get("subscription_key", None)
|
||||
if not key:
|
||||
raise Exception("subscription_key is required")
|
||||
|
||||
server_url = self.runtime.credentials.get("server_url", None)
|
||||
if not server_url:
|
||||
server_url = self.url
|
||||
|
||||
query = tool_parameters.get("query")
|
||||
if not query:
|
||||
raise Exception("query is required")
|
||||
|
||||
limit = min(tool_parameters.get("limit", 5), 10)
|
||||
result_type = tool_parameters.get("result_type", "text") or "text"
|
||||
|
||||
market = tool_parameters.get("market", "US")
|
||||
lang = tool_parameters.get("language", "en")
|
||||
filter = []
|
||||
|
||||
if tool_parameters.get("enable_computation", False):
|
||||
filter.append("Computation")
|
||||
if tool_parameters.get("enable_entities", False):
|
||||
filter.append("Entities")
|
||||
if tool_parameters.get("enable_news", False):
|
||||
filter.append("News")
|
||||
if tool_parameters.get("enable_related_search", False):
|
||||
filter.append("RelatedSearches")
|
||||
if tool_parameters.get("enable_webpages", False):
|
||||
filter.append("WebPages")
|
||||
|
||||
if not filter:
|
||||
raise Exception("At least one filter is required")
|
||||
|
||||
return self._invoke_bing(
|
||||
user_id=user_id,
|
||||
server_url=server_url,
|
||||
subscription_key=key,
|
||||
query=query,
|
||||
limit=limit,
|
||||
result_type=result_type,
|
||||
market=market,
|
||||
lang=lang,
|
||||
filters=filter,
|
||||
)
|
||||
@ -1,584 +0,0 @@
|
||||
identity:
|
||||
name: bing_web_search
|
||||
author: Dify
|
||||
label:
|
||||
en_US: BingWebSearch
|
||||
zh_Hans: 必应网页搜索
|
||||
pt_BR: BingWebSearch
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for performing a Bing SERP search and extracting snippets and webpages.Input should be a search query.
|
||||
zh_Hans: 一个用于执行 Bing SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。
|
||||
pt_BR: A tool for performing a Bing SERP search and extracting snippets and webpages.Input should be a search query.
|
||||
llm: A tool for performing a Bing SERP search and extracting snippets and webpages.Input should be a search query.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询语句
|
||||
pt_BR: Query string
|
||||
human_description:
|
||||
en_US: used for searching
|
||||
zh_Hans: 用于搜索网页内容
|
||||
pt_BR: used for searching
|
||||
llm_description: key words for searching
|
||||
- name: enable_computation
|
||||
type: boolean
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Enable computation
|
||||
zh_Hans: 启用计算
|
||||
pt_BR: Enable computation
|
||||
human_description:
|
||||
en_US: enable computation
|
||||
zh_Hans: 启用计算
|
||||
pt_BR: enable computation
|
||||
default: false
|
||||
- name: enable_entities
|
||||
type: boolean
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Enable entities
|
||||
zh_Hans: 启用实体搜索
|
||||
pt_BR: Enable entities
|
||||
human_description:
|
||||
en_US: enable entities
|
||||
zh_Hans: 启用实体搜索
|
||||
pt_BR: enable entities
|
||||
default: true
|
||||
- name: enable_news
|
||||
type: boolean
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Enable news
|
||||
zh_Hans: 启用新闻搜索
|
||||
pt_BR: Enable news
|
||||
human_description:
|
||||
en_US: enable news
|
||||
zh_Hans: 启用新闻搜索
|
||||
pt_BR: enable news
|
||||
default: false
|
||||
- name: enable_related_search
|
||||
type: boolean
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Enable related search
|
||||
zh_Hans: 启用相关搜索
|
||||
pt_BR: Enable related search
|
||||
human_description:
|
||||
en_US: enable related search
|
||||
zh_Hans: 启用相关搜索
|
||||
pt_BR: enable related search
|
||||
default: false
|
||||
- name: enable_webpages
|
||||
type: boolean
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Enable webpages search
|
||||
zh_Hans: 启用网页搜索
|
||||
pt_BR: Enable webpages search
|
||||
human_description:
|
||||
en_US: enable webpages search
|
||||
zh_Hans: 启用网页搜索
|
||||
pt_BR: enable webpages search
|
||||
default: true
|
||||
- name: limit
|
||||
type: number
|
||||
required: true
|
||||
form: form
|
||||
label:
|
||||
en_US: Limit for results length
|
||||
zh_Hans: 返回长度限制
|
||||
pt_BR: Limit for results length
|
||||
human_description:
|
||||
en_US: limit the number of results
|
||||
zh_Hans: 限制返回结果的数量
|
||||
pt_BR: limit the number of results
|
||||
min: 1
|
||||
max: 10
|
||||
default: 5
|
||||
- name: result_type
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: result type
|
||||
zh_Hans: 结果类型
|
||||
pt_BR: result type
|
||||
human_description:
|
||||
en_US: return a list of links or texts
|
||||
zh_Hans: 返回一个连接列表还是纯文本内容
|
||||
pt_BR: return a list of links or texts
|
||||
default: text
|
||||
options:
|
||||
- value: link
|
||||
label:
|
||||
en_US: Link
|
||||
zh_Hans: 链接
|
||||
pt_BR: Link
|
||||
- value: text
|
||||
label:
|
||||
en_US: Text
|
||||
zh_Hans: 文本
|
||||
pt_BR: Text
|
||||
form: form
|
||||
- name: market
|
||||
type: select
|
||||
label:
|
||||
en_US: Market
|
||||
zh_Hans: 市场
|
||||
pt_BR: Market
|
||||
human_description:
|
||||
en_US: market takes responsibility for the region
|
||||
zh_Hans: 市场决定了搜索结果的地区
|
||||
pt_BR: market takes responsibility for the region
|
||||
required: false
|
||||
form: form
|
||||
default: US
|
||||
options:
|
||||
- value: AR
|
||||
label:
|
||||
en_US: Argentina
|
||||
zh_Hans: 阿根廷
|
||||
pt_BR: Argentina
|
||||
- value: AU
|
||||
label:
|
||||
en_US: Australia
|
||||
zh_Hans: 澳大利亚
|
||||
pt_BR: Australia
|
||||
- value: AT
|
||||
label:
|
||||
en_US: Austria
|
||||
zh_Hans: 奥地利
|
||||
pt_BR: Austria
|
||||
- value: BE
|
||||
label:
|
||||
en_US: Belgium
|
||||
zh_Hans: 比利时
|
||||
pt_BR: Belgium
|
||||
- value: BR
|
||||
label:
|
||||
en_US: Brazil
|
||||
zh_Hans: 巴西
|
||||
pt_BR: Brazil
|
||||
- value: CA
|
||||
label:
|
||||
en_US: Canada
|
||||
zh_Hans: 加拿大
|
||||
pt_BR: Canada
|
||||
- value: CL
|
||||
label:
|
||||
en_US: Chile
|
||||
zh_Hans: 智利
|
||||
pt_BR: Chile
|
||||
- value: CO
|
||||
label:
|
||||
en_US: Colombia
|
||||
zh_Hans: 哥伦比亚
|
||||
pt_BR: Colombia
|
||||
- value: CN
|
||||
label:
|
||||
en_US: China
|
||||
zh_Hans: 中国
|
||||
pt_BR: China
|
||||
- value: CZ
|
||||
label:
|
||||
en_US: Czech Republic
|
||||
zh_Hans: 捷克共和国
|
||||
pt_BR: Czech Republic
|
||||
- value: DK
|
||||
label:
|
||||
en_US: Denmark
|
||||
zh_Hans: 丹麦
|
||||
pt_BR: Denmark
|
||||
- value: FI
|
||||
label:
|
||||
en_US: Finland
|
||||
zh_Hans: 芬兰
|
||||
pt_BR: Finland
|
||||
- value: FR
|
||||
label:
|
||||
en_US: France
|
||||
zh_Hans: 法国
|
||||
pt_BR: France
|
||||
- value: DE
|
||||
label:
|
||||
en_US: Germany
|
||||
zh_Hans: 德国
|
||||
pt_BR: Germany
|
||||
- value: HK
|
||||
label:
|
||||
en_US: Hong Kong
|
||||
zh_Hans: 香港
|
||||
pt_BR: Hong Kong
|
||||
- value: IN
|
||||
label:
|
||||
en_US: India
|
||||
zh_Hans: 印度
|
||||
pt_BR: India
|
||||
- value: ID
|
||||
label:
|
||||
en_US: Indonesia
|
||||
zh_Hans: 印度尼西亚
|
||||
pt_BR: Indonesia
|
||||
- value: IT
|
||||
label:
|
||||
en_US: Italy
|
||||
zh_Hans: 意大利
|
||||
pt_BR: Italy
|
||||
- value: JP
|
||||
label:
|
||||
en_US: Japan
|
||||
zh_Hans: 日本
|
||||
pt_BR: Japan
|
||||
- value: KR
|
||||
label:
|
||||
en_US: Korea
|
||||
zh_Hans: 韩国
|
||||
pt_BR: Korea
|
||||
- value: MY
|
||||
label:
|
||||
en_US: Malaysia
|
||||
zh_Hans: 马来西亚
|
||||
pt_BR: Malaysia
|
||||
- value: MX
|
||||
label:
|
||||
en_US: Mexico
|
||||
zh_Hans: 墨西哥
|
||||
pt_BR: Mexico
|
||||
- value: NL
|
||||
label:
|
||||
en_US: Netherlands
|
||||
zh_Hans: 荷兰
|
||||
pt_BR: Netherlands
|
||||
- value: NZ
|
||||
label:
|
||||
en_US: New Zealand
|
||||
zh_Hans: 新西兰
|
||||
pt_BR: New Zealand
|
||||
- value: 'NO'
|
||||
label:
|
||||
en_US: Norway
|
||||
zh_Hans: 挪威
|
||||
pt_BR: Norway
|
||||
- value: PH
|
||||
label:
|
||||
en_US: Philippines
|
||||
zh_Hans: 菲律宾
|
||||
pt_BR: Philippines
|
||||
- value: PL
|
||||
label:
|
||||
en_US: Poland
|
||||
zh_Hans: 波兰
|
||||
pt_BR: Poland
|
||||
- value: PT
|
||||
label:
|
||||
en_US: Portugal
|
||||
zh_Hans: 葡萄牙
|
||||
pt_BR: Portugal
|
||||
- value: RU
|
||||
label:
|
||||
en_US: Russia
|
||||
zh_Hans: 俄罗斯
|
||||
pt_BR: Russia
|
||||
- value: SA
|
||||
label:
|
||||
en_US: Saudi Arabia
|
||||
zh_Hans: 沙特阿拉伯
|
||||
pt_BR: Saudi Arabia
|
||||
- value: SG
|
||||
label:
|
||||
en_US: Singapore
|
||||
zh_Hans: 新加坡
|
||||
pt_BR: Singapore
|
||||
- value: ZA
|
||||
label:
|
||||
en_US: South Africa
|
||||
zh_Hans: 南非
|
||||
pt_BR: South Africa
|
||||
- value: ES
|
||||
label:
|
||||
en_US: Spain
|
||||
zh_Hans: 西班牙
|
||||
pt_BR: Spain
|
||||
- value: SE
|
||||
label:
|
||||
en_US: Sweden
|
||||
zh_Hans: 瑞典
|
||||
pt_BR: Sweden
|
||||
- value: CH
|
||||
label:
|
||||
en_US: Switzerland
|
||||
zh_Hans: 瑞士
|
||||
pt_BR: Switzerland
|
||||
- value: TW
|
||||
label:
|
||||
en_US: Taiwan
|
||||
zh_Hans: 台湾
|
||||
pt_BR: Taiwan
|
||||
- value: TH
|
||||
label:
|
||||
en_US: Thailand
|
||||
zh_Hans: 泰国
|
||||
pt_BR: Thailand
|
||||
- value: TR
|
||||
label:
|
||||
en_US: Turkey
|
||||
zh_Hans: 土耳其
|
||||
pt_BR: Turkey
|
||||
- value: GB
|
||||
label:
|
||||
en_US: United Kingdom
|
||||
zh_Hans: 英国
|
||||
pt_BR: United Kingdom
|
||||
- value: US
|
||||
label:
|
||||
en_US: United States
|
||||
zh_Hans: 美国
|
||||
pt_BR: United States
|
||||
- name: language
|
||||
type: select
|
||||
label:
|
||||
en_US: Language
|
||||
zh_Hans: 语言
|
||||
pt_BR: Language
|
||||
human_description:
|
||||
en_US: language takes responsibility for the language of the search result
|
||||
zh_Hans: 语言决定了搜索结果的语言
|
||||
pt_BR: language takes responsibility for the language of the search result
|
||||
required: false
|
||||
default: en
|
||||
form: form
|
||||
options:
|
||||
- value: ar
|
||||
label:
|
||||
en_US: Arabic
|
||||
zh_Hans: 阿拉伯语
|
||||
pt_BR: Arabic
|
||||
- value: bg
|
||||
label:
|
||||
en_US: Bulgarian
|
||||
zh_Hans: 保加利亚语
|
||||
pt_BR: Bulgarian
|
||||
- value: ca
|
||||
label:
|
||||
en_US: Catalan
|
||||
zh_Hans: 加泰罗尼亚语
|
||||
pt_BR: Catalan
|
||||
- value: zh-hans
|
||||
label:
|
||||
en_US: Chinese (Simplified)
|
||||
zh_Hans: 中文(简体)
|
||||
pt_BR: Chinese (Simplified)
|
||||
- value: zh-hant
|
||||
label:
|
||||
en_US: Chinese (Traditional)
|
||||
zh_Hans: 中文(繁体)
|
||||
pt_BR: Chinese (Traditional)
|
||||
- value: cs
|
||||
label:
|
||||
en_US: Czech
|
||||
zh_Hans: 捷克语
|
||||
pt_BR: Czech
|
||||
- value: da
|
||||
label:
|
||||
en_US: Danish
|
||||
zh_Hans: 丹麦语
|
||||
pt_BR: Danish
|
||||
- value: nl
|
||||
label:
|
||||
en_US: Dutch
|
||||
zh_Hans: 荷兰语
|
||||
pt_BR: Dutch
|
||||
- value: en
|
||||
label:
|
||||
en_US: English
|
||||
zh_Hans: 英语
|
||||
pt_BR: English
|
||||
- value: et
|
||||
label:
|
||||
en_US: Estonian
|
||||
zh_Hans: 爱沙尼亚语
|
||||
pt_BR: Estonian
|
||||
- value: fi
|
||||
label:
|
||||
en_US: Finnish
|
||||
zh_Hans: 芬兰语
|
||||
pt_BR: Finnish
|
||||
- value: fr
|
||||
label:
|
||||
en_US: French
|
||||
zh_Hans: 法语
|
||||
pt_BR: French
|
||||
- value: de
|
||||
label:
|
||||
en_US: German
|
||||
zh_Hans: 德语
|
||||
pt_BR: German
|
||||
- value: el
|
||||
label:
|
||||
en_US: Greek
|
||||
zh_Hans: 希腊语
|
||||
pt_BR: Greek
|
||||
- value: he
|
||||
label:
|
||||
en_US: Hebrew
|
||||
zh_Hans: 希伯来语
|
||||
pt_BR: Hebrew
|
||||
- value: hi
|
||||
label:
|
||||
en_US: Hindi
|
||||
zh_Hans: 印地语
|
||||
pt_BR: Hindi
|
||||
- value: hu
|
||||
label:
|
||||
en_US: Hungarian
|
||||
zh_Hans: 匈牙利语
|
||||
pt_BR: Hungarian
|
||||
- value: id
|
||||
label:
|
||||
en_US: Indonesian
|
||||
zh_Hans: 印尼语
|
||||
pt_BR: Indonesian
|
||||
- value: it
|
||||
label:
|
||||
en_US: Italian
|
||||
zh_Hans: 意大利语
|
||||
pt_BR: Italian
|
||||
- value: jp
|
||||
label:
|
||||
en_US: Japanese
|
||||
zh_Hans: 日语
|
||||
pt_BR: Japanese
|
||||
- value: kn
|
||||
label:
|
||||
en_US: Kannada
|
||||
zh_Hans: 卡纳达语
|
||||
pt_BR: Kannada
|
||||
- value: ko
|
||||
label:
|
||||
en_US: Korean
|
||||
zh_Hans: 韩语
|
||||
pt_BR: Korean
|
||||
- value: lv
|
||||
label:
|
||||
en_US: Latvian
|
||||
zh_Hans: 拉脱维亚语
|
||||
pt_BR: Latvian
|
||||
- value: lt
|
||||
label:
|
||||
en_US: Lithuanian
|
||||
zh_Hans: 立陶宛语
|
||||
pt_BR: Lithuanian
|
||||
- value: ms
|
||||
label:
|
||||
en_US: Malay
|
||||
zh_Hans: 马来语
|
||||
pt_BR: Malay
|
||||
- value: ml
|
||||
label:
|
||||
en_US: Malayalam
|
||||
zh_Hans: 马拉雅拉姆语
|
||||
pt_BR: Malayalam
|
||||
- value: mr
|
||||
label:
|
||||
en_US: Marathi
|
||||
zh_Hans: 马拉地语
|
||||
pt_BR: Marathi
|
||||
- value: nb
|
||||
label:
|
||||
en_US: Norwegian
|
||||
zh_Hans: 挪威语
|
||||
pt_BR: Norwegian
|
||||
- value: pl
|
||||
label:
|
||||
en_US: Polish
|
||||
zh_Hans: 波兰语
|
||||
pt_BR: Polish
|
||||
- value: pt-br
|
||||
label:
|
||||
en_US: Portuguese (Brazil)
|
||||
zh_Hans: 葡萄牙语(巴西)
|
||||
pt_BR: Portuguese (Brazil)
|
||||
- value: pt-pt
|
||||
label:
|
||||
en_US: Portuguese (Portugal)
|
||||
zh_Hans: 葡萄牙语(葡萄牙)
|
||||
pt_BR: Portuguese (Portugal)
|
||||
- value: pa
|
||||
label:
|
||||
en_US: Punjabi
|
||||
zh_Hans: 旁遮普语
|
||||
pt_BR: Punjabi
|
||||
- value: ro
|
||||
label:
|
||||
en_US: Romanian
|
||||
zh_Hans: 罗马尼亚语
|
||||
pt_BR: Romanian
|
||||
- value: ru
|
||||
label:
|
||||
en_US: Russian
|
||||
zh_Hans: 俄语
|
||||
pt_BR: Russian
|
||||
- value: sr
|
||||
label:
|
||||
en_US: Serbian
|
||||
zh_Hans: 塞尔维亚语
|
||||
pt_BR: Serbian
|
||||
- value: sk
|
||||
label:
|
||||
en_US: Slovak
|
||||
zh_Hans: 斯洛伐克语
|
||||
pt_BR: Slovak
|
||||
- value: sl
|
||||
label:
|
||||
en_US: Slovenian
|
||||
zh_Hans: 斯洛文尼亚语
|
||||
pt_BR: Slovenian
|
||||
- value: es
|
||||
label:
|
||||
en_US: Spanish
|
||||
zh_Hans: 西班牙语
|
||||
pt_BR: Spanish
|
||||
- value: sv
|
||||
label:
|
||||
en_US: Swedish
|
||||
zh_Hans: 瑞典语
|
||||
pt_BR: Swedish
|
||||
- value: ta
|
||||
label:
|
||||
en_US: Tamil
|
||||
zh_Hans: 泰米尔语
|
||||
pt_BR: Tamil
|
||||
- value: te
|
||||
label:
|
||||
en_US: Telugu
|
||||
zh_Hans: 泰卢固语
|
||||
pt_BR: Telugu
|
||||
- value: th
|
||||
label:
|
||||
en_US: Thai
|
||||
zh_Hans: 泰语
|
||||
pt_BR: Thai
|
||||
- value: tr
|
||||
label:
|
||||
en_US: Turkish
|
||||
zh_Hans: 土耳其语
|
||||
pt_BR: Turkish
|
||||
- value: uk
|
||||
label:
|
||||
en_US: Ukrainian
|
||||
zh_Hans: 乌克兰语
|
||||
pt_BR: Ukrainian
|
||||
- value: vi
|
||||
label:
|
||||
en_US: Vietnamese
|
||||
zh_Hans: 越南语
|
||||
pt_BR: Vietnamese
|
||||
@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" width="48px" height="48px" clip-rule="evenodd" baseProfile="basic"><linearGradient id="yG17B1EwMCiUUe9ON9hI5a" x1="-329.441" x2="-329.276" y1="-136.877" y2="-136.877" gradientTransform="matrix(217.6 0 0 -255.4727 71694.719 -34944.293)" gradientUnits="userSpaceOnUse"><stop offset="0" stop-color="#e68e00"/><stop offset=".437" stop-color="#d75500"/><stop offset=".562" stop-color="#cf3600"/><stop offset=".89" stop-color="#d22900"/><stop offset="1" stop-color="#d42400"/></linearGradient><path fill="url(#yG17B1EwMCiUUe9ON9hI5a)" fill-rule="evenodd" d="M40.635,13.075l0.984-2.418c0,0-1.252-1.343-2.772-2.865 s-4.74-0.627-4.74-0.627L30.439,3H24h-6.439l-3.667,4.165c0,0-3.22-0.895-4.74,0.627s-2.772,2.865-2.772,2.865l0.984,2.418 l-1.252,3.582c0,0,3.682,13.965,4.114,15.671c0.85,3.358,1.431,4.656,3.846,6.358c2.415,1.701,6.797,4.656,7.512,5.104 C22.301,44.237,23.195,45,24,45c0.805,0,1.699-0.763,2.415-1.21c0.715-0.448,5.098-3.403,7.512-5.104 c2.415-1.701,2.996-3,3.846-6.358c0.431-1.705,4.114-15.671,4.114-15.671L40.635,13.075z" clip-rule="evenodd"/><linearGradient id="yG17B1EwMCiUUe9ON9hI5b" x1="19.087" x2="31.755" y1="7.685" y2="32.547" gradientUnits="userSpaceOnUse"><stop offset="0" stop-color="#fff"/><stop offset=".24" stop-color="#f8f8f7"/><stop offset="1" stop-color="#e3e3e1"/></linearGradient><path fill="url(#yG17B1EwMCiUUe9ON9hI5b)" fill-rule="evenodd" d="M33.078,9.807c0,0,4.716,5.709,4.716,6.929 s-0.593,1.542-1.19,2.176c-0.597,0.634-3.202,3.404-3.536,3.76c-0.335,0.356-1.031,0.895-0.621,1.866 c0.41,0.971,1.014,2.206,0.342,3.459c-0.672,1.253-1.824,2.089-2.561,1.951c-0.738-0.138-2.471-1.045-3.108-1.459 c-0.637-0.414-2.657-2.082-2.657-2.72c0-0.638,2.088-1.784,2.473-2.044c0.386-0.26,2.145-1.268,2.181-1.663 c0.036-0.396,0.022-0.511-0.497-1.489c-0.519-0.977-1.454-2.281-1.298-3.149c0.156-0.868,1.663-1.319,2.74-1.726 c1.076-0.407,3.148-1.175,3.406-1.295c0.259-0.12,0.192-0.233-0.592-0.308c-0.784-0.074-3.009-0.37-4.012-0.09 c-1.003,0.28-2.717,0.706-2.855,0.932c-0.139,0.226-0.261,0.233-0.119,1.012c0.142,0.779,0.876,4.517,0.948,5.181 c0.071,0.664,0.211,1.103-0.504,1.267c-0.715,0.164-1.919,0.448-2.332,0.448s-1.617-0.284-2.332-0.448 c-0.715-0.164-0.576-0.603-0.504-1.267s0.805-4.402,0.948-5.181c0.142-0.779,0.02-0.787-0.119-1.012 c-0.139-0.226-1.852-0.652-2.855-0.932c-1.003-0.28-3.228,0.016-4.012,0.09c-0.784,0.074-0.851,0.188-0.592,0.308 c0.259,0.119,2.331,0.888,3.406,1.295c1.076,0.407,2.584,0.858,2.74,1.726c0.156,0.868-0.779,2.172-1.298,3.149 c-0.519,0.977-0.533,1.093-0.497,1.489c0.036,0.395,1.795,1.403,2.181,1.663c0.386,0.26,2.473,1.406,2.473,2.044 c0,0.638-2.02,2.306-2.657,2.72c-0.637,0.414-2.37,1.321-3.108,1.459c-0.738,0.138-1.889-0.698-2.561-1.951 c-0.672-1.253-0.068-2.488,0.342-3.459c0.41-0.971-0.287-1.51-0.621-1.866c-0.334-0.356-2.94-3.126-3.536-3.76 c-0.597-0.634-1.19-0.956-1.19-2.176s4.716-6.929,4.716-6.929s3.98,0.761,4.516,0.761c0.537,0,1.699-0.448,2.772-0.806 C23.285,9.404,24,9.401,24,9.401s0.715,0.003,1.789,0.361c1.073,0.358,2.236,0.806,2.772,0.806 C29.098,10.568,33.078,9.807,33.078,9.807z M29.542,31.643c0.292,0.183,0.114,0.528-0.152,0.716 c-0.266,0.188-3.84,2.959-4.187,3.265c-0.347,0.306-0.857,0.812-1.203,0.812c-0.347,0-0.856-0.506-1.203-0.812 c-0.347-0.306-3.921-3.077-4.187-3.265c-0.266-0.188-0.444-0.533-0.152-0.716c0.292-0.183,1.205-0.645,2.466-1.298 c1.26-0.653,2.831-1.208,3.076-1.208c0.245,0,1.816,0.555,3.076,1.208C28.336,30.999,29.25,31.46,29.542,31.643z" clip-rule="evenodd"/><linearGradient id="yG17B1EwMCiUUe9ON9hI5c" x1="-329.279" x2="-329.074" y1="-140.492" y2="-140.492" gradientTransform="matrix(180.608 0 0 -46.0337 59468.86 -6460.583)" gradientUnits="userSpaceOnUse"><stop offset="0" stop-color="#e68e00"/><stop offset="1" stop-color="#d42400"/></linearGradient><path fill="url(#yG17B1EwMCiUUe9ON9hI5c)" fill-rule="evenodd" d="M34.106,7.165L30.439,3H24h-6.439l-3.667,4.165 c0,0-3.22-0.895-4.74,0.627c0,0,4.293-0.388,5.769,2.015c0,0,3.98,0.761,4.516,0.761c0.537,0,1.699-0.448,2.772-0.806 C23.285,9.404,24,9.401,24,9.401s0.715,0.003,1.789,0.361c1.073,0.358,2.236,0.806,2.772,0.806c0.537,0,4.516-0.761,4.516-0.761 c1.476-2.403,5.769-2.015,5.769-2.015C37.326,6.27,34.106,7.165,34.106,7.165" clip-rule="evenodd"/></svg>
|
||||
|
Before Width: | Height: | Size: 4.1 KiB |
@ -1,22 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.brave.tools.brave_search import BraveSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class BraveProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
BraveSearchTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"query": "Sachin Tendulkar",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,39 +0,0 @@
|
||||
identity:
|
||||
author: Yash Parmar
|
||||
name: brave
|
||||
label:
|
||||
en_US: Brave
|
||||
zh_Hans: Brave
|
||||
pt_BR: Brave
|
||||
description:
|
||||
en_US: Brave
|
||||
zh_Hans: Brave
|
||||
pt_BR: Brave
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
credentials_for_provider:
|
||||
brave_search_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Brave Search API key
|
||||
zh_Hans: Brave Search API key
|
||||
pt_BR: Brave Search API key
|
||||
placeholder:
|
||||
en_US: Please input your Brave Search API key
|
||||
zh_Hans: 请输入你的 Brave Search API key
|
||||
pt_BR: Please input your Brave Search API key
|
||||
help:
|
||||
en_US: Get your Brave Search API key from Brave
|
||||
zh_Hans: 从 Brave 获取您的 Brave Search API key
|
||||
pt_BR: Get your Brave Search API key from Brave
|
||||
url: https://brave.com/search/api/
|
||||
base_url:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: Brave server's Base URL
|
||||
zh_Hans: Brave服务器的API URL
|
||||
placeholder:
|
||||
en_US: https://api.search.brave.com/res/v1/web/search
|
||||
@ -1,138 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
BRAVE_BASE_URL = "https://api.search.brave.com/res/v1/web/search"
|
||||
|
||||
|
||||
class BraveSearchWrapper(BaseModel):
|
||||
"""Wrapper around the Brave search engine."""
|
||||
|
||||
api_key: str
|
||||
"""The API key to use for the Brave search engine."""
|
||||
search_kwargs: dict = Field(default_factory=dict)
|
||||
"""Additional keyword arguments to pass to the search request."""
|
||||
base_url: str = BRAVE_BASE_URL
|
||||
"""The base URL for the Brave search engine."""
|
||||
ensure_ascii: bool = True
|
||||
"""Ensure the JSON output is ASCII encoded."""
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""Query the Brave search engine and return the results as a JSON string.
|
||||
|
||||
Args:
|
||||
query: The query to search for.
|
||||
|
||||
Returns: The results as a JSON string.
|
||||
|
||||
"""
|
||||
web_search_results = self._search_request(query=query)
|
||||
final_results = [
|
||||
{
|
||||
"title": item.get("title"),
|
||||
"link": item.get("url"),
|
||||
"snippet": item.get("description"),
|
||||
}
|
||||
for item in web_search_results
|
||||
]
|
||||
return json.dumps(final_results, ensure_ascii=self.ensure_ascii)
|
||||
|
||||
def _search_request(self, query: str) -> list[dict]:
|
||||
headers = {
|
||||
"X-Subscription-Token": self.api_key,
|
||||
"Accept": "application/json",
|
||||
}
|
||||
req = requests.PreparedRequest()
|
||||
params = {**self.search_kwargs, **{"q": query}}
|
||||
req.prepare_url(self.base_url, params)
|
||||
if req.url is None:
|
||||
raise ValueError("prepared url is None, this should not happen")
|
||||
|
||||
response = requests.get(req.url, headers=headers)
|
||||
if not response.ok:
|
||||
raise Exception(f"HTTP error {response.status_code}")
|
||||
|
||||
return response.json().get("web", {}).get("results", [])
|
||||
|
||||
|
||||
class BraveSearch(BaseModel):
|
||||
"""Tool that queries the BraveSearch."""
|
||||
|
||||
name: str = "brave_search"
|
||||
description: str = (
|
||||
"a search engine. "
|
||||
"useful for when you need to answer questions about current events."
|
||||
" input should be a search query."
|
||||
)
|
||||
search_wrapper: BraveSearchWrapper
|
||||
|
||||
@classmethod
|
||||
def from_api_key(
|
||||
cls, api_key: str, base_url: str, search_kwargs: Optional[dict] = None, ensure_ascii: bool = True, **kwargs: Any
|
||||
) -> "BraveSearch":
|
||||
"""Create a tool from an api key.
|
||||
|
||||
Args:
|
||||
api_key: The api key to use.
|
||||
search_kwargs: Any additional kwargs to pass to the search wrapper.
|
||||
**kwargs: Any additional kwargs to pass to the tool.
|
||||
|
||||
Returns:
|
||||
A tool.
|
||||
"""
|
||||
wrapper = BraveSearchWrapper(
|
||||
api_key=api_key, base_url=base_url, search_kwargs=search_kwargs or {}, ensure_ascii=ensure_ascii
|
||||
)
|
||||
return cls(search_wrapper=wrapper, **kwargs)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
) -> str:
|
||||
"""Use the tool."""
|
||||
return self.search_wrapper.run(query)
|
||||
|
||||
|
||||
class BraveSearchTool(BuiltinTool):
|
||||
"""
|
||||
Tool for performing a search using Brave search engine.
|
||||
"""
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invoke the Brave search tool.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user invoking the tool.
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool invocation.
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation.
|
||||
"""
|
||||
query = tool_parameters.get("query", "")
|
||||
count = tool_parameters.get("count", 3)
|
||||
api_key = self.runtime.credentials["brave_search_api_key"]
|
||||
base_url = self.runtime.credentials.get("base_url", BRAVE_BASE_URL)
|
||||
ensure_ascii = tool_parameters.get("ensure_ascii", True)
|
||||
|
||||
if len(base_url) == 0:
|
||||
base_url = BRAVE_BASE_URL
|
||||
|
||||
if not query:
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
tool = BraveSearch.from_api_key(
|
||||
api_key=api_key, base_url=base_url, search_kwargs={"count": count}, ensure_ascii=ensure_ascii
|
||||
)
|
||||
|
||||
results = tool._run(query)
|
||||
|
||||
if not results:
|
||||
return self.create_text_message(f"No results found for '{query}' in Tavily")
|
||||
else:
|
||||
return self.create_text_message(text=results)
|
||||
@ -1,53 +0,0 @@
|
||||
identity:
|
||||
name: brave_search
|
||||
author: Yash Parmar
|
||||
label:
|
||||
en_US: BraveSearch
|
||||
zh_Hans: BraveSearch
|
||||
pt_BR: BraveSearch
|
||||
description:
|
||||
human:
|
||||
en_US: BraveSearch is a privacy-focused search engine that leverages its own index to deliver unbiased, independent, and fast search results. It's designed to respect user privacy by not tracking searches or personal information, making it a secure choice for those concerned about online privacy.
|
||||
zh_Hans: BraveSearch 是一个注重隐私的搜索引擎,它利用自己的索引来提供公正、独立和快速的搜索结果。它旨在通过不跟踪搜索或个人信息来尊重用户隐私,为那些关注在线隐私的用户提供了一个安全的选择。
|
||||
pt_BR: BraveSearch é um mecanismo de busca focado na privacidade que utiliza seu próprio índice para entregar resultados de busca imparciais, independentes e rápidos. Ele é projetado para respeitar a privacidade do usuário, não rastreando buscas ou informações pessoais, tornando-se uma escolha segura para aqueles preocupados com a privacidade online.
|
||||
llm: BraveSearch is a privacy-centric search engine utilizing its unique index to offer unbiased, independent, and swift search results. It aims to protect user privacy by avoiding the tracking of search activities or personal data, presenting a secure option for users mindful of their online privacy.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询语句
|
||||
pt_BR: Query string
|
||||
human_description:
|
||||
en_US: The text input used for initiating searches on the web, focusing on delivering relevant and accurate results without compromising user privacy.
|
||||
zh_Hans: 用于在网上启动搜索的文本输入,专注于提供相关且准确的结果,同时不妨碍用户隐私。
|
||||
pt_BR: A entrada de texto usada para iniciar pesquisas na web, focada em entregar resultados relevantes e precisos sem comprometer a privacidade do usuário.
|
||||
llm_description: Keywords or phrases entered to perform searches, aimed at providing relevant and precise results while ensuring the privacy of the user is maintained.
|
||||
form: llm
|
||||
- name: count
|
||||
type: number
|
||||
required: false
|
||||
default: 3
|
||||
label:
|
||||
en_US: Result count
|
||||
zh_Hans: 结果数量
|
||||
pt_BR: Contagem de resultados
|
||||
human_description:
|
||||
en_US: The number of search results to return, allowing users to control the breadth of their search output.
|
||||
zh_Hans: 要返回的搜索结果数量,允许用户控制他们搜索输出的广度。
|
||||
pt_BR: O número de resultados de pesquisa a serem retornados, permitindo que os usuários controlem a amplitude de sua saída de pesquisa.
|
||||
llm_description: Specifies the amount of search results to be displayed, offering users the ability to adjust the scope of their search findings.
|
||||
form: llm
|
||||
- name: ensure_ascii
|
||||
type: boolean
|
||||
default: true
|
||||
label:
|
||||
en_US: Ensure ASCII
|
||||
zh_Hans: 确保 ASCII
|
||||
pt_BR: Ensure ASCII
|
||||
human_description:
|
||||
en_US: Ensure the JSON output is ASCII encoded
|
||||
zh_Hans: 确保输出的 JSON 是 ASCII 编码
|
||||
pt_BR: Ensure the JSON output is ASCII encoded
|
||||
form: form
|
||||
|
Before Width: | Height: | Size: 1.3 KiB |
@ -1,77 +0,0 @@
|
||||
import matplotlib.pyplot as plt
|
||||
from fontTools.ttLib import TTFont
|
||||
from matplotlib.font_manager import findSystemFonts
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.chart.tools.line import LinearChartTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
# use a business theme
|
||||
plt.style.use("seaborn-v0_8-darkgrid")
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
|
||||
def init_fonts():
|
||||
fonts = findSystemFonts()
|
||||
|
||||
popular_unicode_fonts = [
|
||||
"Arial Unicode MS",
|
||||
"DejaVu Sans",
|
||||
"DejaVu Sans Mono",
|
||||
"DejaVu Serif",
|
||||
"FreeMono",
|
||||
"FreeSans",
|
||||
"FreeSerif",
|
||||
"Liberation Mono",
|
||||
"Liberation Sans",
|
||||
"Liberation Serif",
|
||||
"Noto Mono",
|
||||
"Noto Sans",
|
||||
"Noto Serif",
|
||||
"Open Sans",
|
||||
"Roboto",
|
||||
"Source Code Pro",
|
||||
"Source Sans Pro",
|
||||
"Source Serif Pro",
|
||||
"Ubuntu",
|
||||
"Ubuntu Mono",
|
||||
]
|
||||
|
||||
supported_fonts = []
|
||||
|
||||
for font_path in fonts:
|
||||
try:
|
||||
font = TTFont(font_path)
|
||||
# get family name
|
||||
family_name = font["name"].getName(1, 3, 1).toUnicode()
|
||||
if family_name in popular_unicode_fonts:
|
||||
supported_fonts.append(family_name)
|
||||
except:
|
||||
pass
|
||||
|
||||
plt.rcParams["font.family"] = "sans-serif"
|
||||
# sort by order of popular_unicode_fonts
|
||||
for font in popular_unicode_fonts:
|
||||
if font in supported_fonts:
|
||||
plt.rcParams["font.sans-serif"] = font
|
||||
break
|
||||
|
||||
|
||||
init_fonts()
|
||||
|
||||
|
||||
class ChartProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
LinearChartTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"data": "1,3,5,7,9,2,4,6,8,10",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,17 +0,0 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: chart
|
||||
label:
|
||||
en_US: ChartGenerator
|
||||
zh_Hans: 图表生成
|
||||
pt_BR: Gerador de gráficos
|
||||
description:
|
||||
en_US: Chart Generator is a tool for generating statistical charts like bar chart, line chart, pie chart, etc.
|
||||
zh_Hans: 图表生成是一个用于生成可视化图表的工具,你可以通过它来生成柱状图、折线图、饼图等各类图表
|
||||
pt_BR: O Gerador de gráficos é uma ferramenta para gerar gráficos estatísticos como gráfico de barras, gráfico de linhas, gráfico de pizza, etc.
|
||||
icon: icon.png
|
||||
tags:
|
||||
- design
|
||||
- productivity
|
||||
- utilities
|
||||
credentials_for_provider:
|
||||
@ -1,48 +0,0 @@
|
||||
import io
|
||||
from typing import Any, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class BarChartTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
data = tool_parameters.get("data", "")
|
||||
if not data:
|
||||
return self.create_text_message("Please input data")
|
||||
data = data.split(";")
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all(i.isdigit() for i in data):
|
||||
data = [int(i) for i in data]
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
axis = tool_parameters.get("x_axis") or None
|
||||
if axis:
|
||||
axis = axis.split(";")
|
||||
if len(axis) != len(data):
|
||||
axis = None
|
||||
|
||||
flg, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
if axis:
|
||||
axis = [label[:10] + "..." if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha="right")
|
||||
ax.bar(axis, data)
|
||||
else:
|
||||
ax.bar(range(len(data)), data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message("the bar chart is saved as an image."),
|
||||
self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
|
||||
]
|
||||
@ -1,41 +0,0 @@
|
||||
identity:
|
||||
name: bar_chart
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Bar Chart
|
||||
zh_Hans: 柱状图
|
||||
pt_BR: Gráfico de barras
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Bar chart
|
||||
zh_Hans: 柱状图
|
||||
pt_BR: Gráfico de barras
|
||||
llm: generate a bar chart with input data
|
||||
parameters:
|
||||
- name: data
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: data
|
||||
zh_Hans: 数据
|
||||
pt_BR: dados
|
||||
human_description:
|
||||
en_US: data for generating chart, each number should be separated by ";"
|
||||
zh_Hans: 用于生成柱状图的数据,每个数字之间用 ";" 分隔
|
||||
pt_BR: dados para gerar gráfico de barras, cada número deve ser separado por ";"
|
||||
llm_description: data for generating bar chart, data should be a string contains a list of numbers like "1;2;3;4;5"
|
||||
form: llm
|
||||
- name: x_axis
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: X Axis
|
||||
zh_Hans: x 轴
|
||||
pt_BR: Eixo X
|
||||
human_description:
|
||||
en_US: X axis for chart, each text should be separated by ";"
|
||||
zh_Hans: 柱状图的 x 轴,每个文本之间用 ";" 分隔
|
||||
pt_BR: Eixo X para gráfico de barras, cada texto deve ser separado por ";"
|
||||
llm_description: x axis for bar chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data
|
||||
form: llm
|
||||
@ -1,50 +0,0 @@
|
||||
import io
|
||||
from typing import Any, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class LinearChartTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
data = tool_parameters.get("data", "")
|
||||
if not data:
|
||||
return self.create_text_message("Please input data")
|
||||
data = data.split(";")
|
||||
|
||||
axis = tool_parameters.get("x_axis") or None
|
||||
if axis:
|
||||
axis = axis.split(";")
|
||||
if len(axis) != len(data):
|
||||
axis = None
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all(i.isdigit() for i in data):
|
||||
data = [int(i) for i in data]
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
flg, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
if axis:
|
||||
axis = [label[:10] + "..." if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha="right")
|
||||
ax.plot(axis, data)
|
||||
else:
|
||||
ax.plot(data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message("the linear chart is saved as an image."),
|
||||
self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
|
||||
]
|
||||
@ -1,41 +0,0 @@
|
||||
identity:
|
||||
name: line_chart
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Linear Chart
|
||||
zh_Hans: 线性图表
|
||||
pt_BR: Gráfico linear
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: linear chart
|
||||
zh_Hans: 线性图表
|
||||
pt_BR: Gráfico linear
|
||||
llm: generate a linear chart with input data
|
||||
parameters:
|
||||
- name: data
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: data
|
||||
zh_Hans: 数据
|
||||
pt_BR: dados
|
||||
human_description:
|
||||
en_US: data for generating chart, each number should be separated by ";"
|
||||
zh_Hans: 用于生成线性图表的数据,每个数字之间用 ";" 分隔
|
||||
pt_BR: dados para gerar gráfico linear, cada número deve ser separado por ";"
|
||||
llm_description: data for generating linear chart, data should be a string contains a list of numbers like "1;2;3;4;5"
|
||||
form: llm
|
||||
- name: x_axis
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: X Axis
|
||||
zh_Hans: x 轴
|
||||
pt_BR: Eixo X
|
||||
human_description:
|
||||
en_US: X axis for chart, each text should be separated by ";"
|
||||
zh_Hans: 线性图表的 x 轴,每个文本之间用 ";" 分隔
|
||||
pt_BR: Eixo X para gráfico linear, cada texto deve ser separado por ";"
|
||||
llm_description: x axis for linear chart, x axis should be a string contains a list of texts like "a;b;c;1;2" in order to match the data
|
||||
form: llm
|
||||
@ -1,48 +0,0 @@
|
||||
import io
|
||||
from typing import Any, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class PieChartTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
data = tool_parameters.get("data", "")
|
||||
if not data:
|
||||
return self.create_text_message("Please input data")
|
||||
data = data.split(";")
|
||||
categories = tool_parameters.get("categories") or None
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all(i.isdigit() for i in data):
|
||||
data = [int(i) for i in data]
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
flg, ax = plt.subplots()
|
||||
|
||||
if categories:
|
||||
categories = categories.split(";")
|
||||
if len(categories) != len(data):
|
||||
categories = None
|
||||
|
||||
if categories:
|
||||
ax.pie(data, labels=categories)
|
||||
else:
|
||||
ax.pie(data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message("the pie chart is saved as an image."),
|
||||
self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
|
||||
]
|
||||
@ -1,41 +0,0 @@
|
||||
identity:
|
||||
name: pie_chart
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Pie Chart
|
||||
zh_Hans: 饼图
|
||||
pt_BR: Gráfico de pizza
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Pie chart
|
||||
zh_Hans: 饼图
|
||||
pt_BR: Gráfico de pizza
|
||||
llm: generate a pie chart with input data
|
||||
parameters:
|
||||
- name: data
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: data
|
||||
zh_Hans: 数据
|
||||
pt_BR: dados
|
||||
human_description:
|
||||
en_US: data for generating chart, each number should be separated by ";"
|
||||
zh_Hans: 用于生成饼图的数据,每个数字之间用 ";" 分隔
|
||||
pt_BR: dados para gerar gráfico de pizza, cada número deve ser separado por ";"
|
||||
llm_description: data for generating pie chart, data should be a string contains a list of numbers like "1;2;3;4;5"
|
||||
form: llm
|
||||
- name: categories
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Categories
|
||||
zh_Hans: 分类
|
||||
pt_BR: Categorias
|
||||
human_description:
|
||||
en_US: Categories for chart, each category should be separated by ";"
|
||||
zh_Hans: 饼图的分类,每个分类之间用 ";" 分隔
|
||||
pt_BR: Categorias para gráfico de pizza, cada categoria deve ser separada por ";"
|
||||
llm_description: categories for pie chart, categories should be a string contains a list of texts like "a;b;c;1;2" in order to match the data, each category should be split by ";"
|
||||
form: llm
|
||||
@ -1 +0,0 @@
|
||||
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg" class="w-3.5 h-3.5" data-icon="Code" aria-hidden="true"><g id="icons/code"><path id="Vector (Stroke)" fill-rule="evenodd" clip-rule="evenodd" d="M8.32593 1.69675C8.67754 1.78466 8.89132 2.14096 8.80342 2.49257L6.47009 11.8259C6.38218 12.1775 6.02588 12.3913 5.67427 12.3034C5.32265 12.2155 5.10887 11.8592 5.19678 11.5076L7.53011 2.17424C7.61801 1.82263 7.97431 1.60885 8.32593 1.69675ZM3.96414 4.20273C4.22042 4.45901 4.22042 4.87453 3.96413 5.13081L2.45578 6.63914C2.45577 6.63915 2.45578 6.63914 2.45578 6.63914C2.25645 6.83851 2.25643 7.16168 2.45575 7.36103C2.45574 7.36103 2.45576 7.36104 2.45575 7.36103L3.96413 8.86936C4.22041 9.12564 4.22042 9.54115 3.96414 9.79744C3.70787 10.0537 3.29235 10.0537 3.03607 9.79745L1.52769 8.28913C0.815811 7.57721 0.815803 6.42302 1.52766 5.7111L3.03606 4.20272C3.29234 3.94644 3.70786 3.94644 3.96414 4.20273ZM10.0361 4.20273C10.2923 3.94644 10.7078 3.94644 10.9641 4.20272L12.4725 5.71108C13.1843 6.423 13.1844 7.57717 12.4725 8.28909L10.9641 9.79745C10.7078 10.0537 10.2923 10.0537 10.036 9.79744C9.77977 9.54115 9.77978 9.12564 10.0361 8.86936L11.5444 7.36107C11.7437 7.16172 11.7438 6.83854 11.5444 6.63917C11.5444 6.63915 11.5445 6.63918 11.5444 6.63917L10.0361 5.13081C9.77978 4.87453 9.77978 4.45901 10.0361 4.20273Z" fill="currentColor"></path></g></svg>
|
||||
|
Before Width: | Height: | Size: 1.4 KiB |
@ -1,8 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CodeToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
@ -1,15 +0,0 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: code
|
||||
label:
|
||||
en_US: Code Interpreter
|
||||
zh_Hans: 代码解释器
|
||||
pt_BR: Interpretador de Código
|
||||
description:
|
||||
en_US: Run a piece of code and get the result back.
|
||||
zh_Hans: 运行一段代码并返回结果。
|
||||
pt_BR: Execute um trecho de código e obtenha o resultado de volta.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- productivity
|
||||
credentials_for_provider:
|
||||
@ -1,22 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class SimpleCode(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
invoke simple code
|
||||
"""
|
||||
|
||||
language = tool_parameters.get("language", CodeLanguage.PYTHON3)
|
||||
code = tool_parameters.get("code", "")
|
||||
|
||||
if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
|
||||
raise ValueError(f"Only python3 and javascript are supported, not {language}")
|
||||
|
||||
result = CodeExecutor.execute_code(language, "", code)
|
||||
|
||||
return self.create_text_message(result)
|
||||
@ -1,51 +0,0 @@
|
||||
identity:
|
||||
name: simple_code
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Code Interpreter
|
||||
zh_Hans: 代码解释器
|
||||
pt_BR: Interpretador de Código
|
||||
description:
|
||||
human:
|
||||
en_US: Run code and get the result back. When you're using a lower quality model, please make sure there are some tips help LLM to understand how to write the code.
|
||||
zh_Hans: 运行一段代码并返回结果。当您使用较低质量的模型时,请确保有一些提示帮助LLM理解如何编写代码。
|
||||
pt_BR: Execute um trecho de código e obtenha o resultado de volta. quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código.
|
||||
llm: A tool for running code and getting the result back. Only native packages are allowed, network/IO operations are disabled. and you must use print() or console.log() to output the result or result will be empty.
|
||||
parameters:
|
||||
- name: language
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Language
|
||||
zh_Hans: 语言
|
||||
pt_BR: Idioma
|
||||
human_description:
|
||||
en_US: The programming language of the code
|
||||
zh_Hans: 代码的编程语言
|
||||
pt_BR: A linguagem de programação do código
|
||||
llm_description: language of the code, only "python3" and "javascript" are supported
|
||||
form: llm
|
||||
options:
|
||||
- value: python3
|
||||
label:
|
||||
en_US: Python3
|
||||
zh_Hans: Python3
|
||||
pt_BR: Python3
|
||||
- value: javascript
|
||||
label:
|
||||
en_US: JavaScript
|
||||
zh_Hans: JavaScript
|
||||
pt_BR: JavaScript
|
||||
- name: code
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Code
|
||||
zh_Hans: 代码
|
||||
pt_BR: Código
|
||||
human_description:
|
||||
en_US: The code to be executed
|
||||
zh_Hans: 要执行的代码
|
||||
pt_BR: O código a ser executado
|
||||
llm_description: code to be executed, only native packages are allowed, network/IO operations are disabled.
|
||||
form: llm
|
||||
|
Before Width: | Height: | Size: 22 KiB |
@ -1,28 +0,0 @@
|
||||
"""Provide the input parameters type for the cogview provider class"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.cogview.tools.cogview3 import CogView3Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class COGVIEWProvider(BuiltinToolProviderController):
|
||||
"""cogview provider"""
|
||||
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
CogView3Tool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。",
|
||||
"size": "square",
|
||||
"n": 1,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e)) from e
|
||||
@ -1,61 +0,0 @@
|
||||
identity:
|
||||
author: Waffle
|
||||
name: cogview
|
||||
label:
|
||||
en_US: CogView
|
||||
zh_Hans: CogView 绘画
|
||||
pt_BR: CogView
|
||||
description:
|
||||
en_US: CogView art
|
||||
zh_Hans: CogView 绘画
|
||||
pt_BR: CogView art
|
||||
icon: icon.png
|
||||
tags:
|
||||
- image
|
||||
- productivity
|
||||
credentials_for_provider:
|
||||
zhipuai_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: ZhipuAI API key
|
||||
zh_Hans: ZhipuAI API key
|
||||
pt_BR: ZhipuAI API key
|
||||
help:
|
||||
en_US: Please input your ZhipuAI API key
|
||||
zh_Hans: 请输入你的 ZhipuAI API key
|
||||
pt_BR: Please input your ZhipuAI API key
|
||||
placeholder:
|
||||
en_US: Please input your ZhipuAI API key
|
||||
zh_Hans: 请输入你的 ZhipuAI API key
|
||||
pt_BR: Please input your ZhipuAI API key
|
||||
zhipuai_organizaion_id:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: ZhipuAI organization ID
|
||||
zh_Hans: ZhipuAI organization ID
|
||||
pt_BR: ZhipuAI organization ID
|
||||
help:
|
||||
en_US: Please input your ZhipuAI organization ID
|
||||
zh_Hans: 请输入你的 ZhipuAI organization ID
|
||||
pt_BR: Please input your ZhipuAI organization ID
|
||||
placeholder:
|
||||
en_US: Please input your ZhipuAI organization ID
|
||||
zh_Hans: 请输入你的 ZhipuAI organization ID
|
||||
pt_BR: Please input your ZhipuAI organization ID
|
||||
zhipuai_base_url:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: ZhipuAI base URL
|
||||
zh_Hans: ZhipuAI base URL
|
||||
pt_BR: ZhipuAI base URL
|
||||
help:
|
||||
en_US: Please input your ZhipuAI base URL
|
||||
zh_Hans: 请输入你的 ZhipuAI base URL
|
||||
pt_BR: Please input your ZhipuAI base URL
|
||||
placeholder:
|
||||
en_US: Please input your ZhipuAI base URL
|
||||
zh_Hans: 请输入你的 ZhipuAI base URL
|
||||
pt_BR: Please input your ZhipuAI base URL
|
||||
@ -1,72 +0,0 @@
|
||||
import random
|
||||
from typing import Any, Union
|
||||
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class CogView3Tool(BuiltinTool):
|
||||
"""CogView3 Tool"""
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke CogView3 tool
|
||||
"""
|
||||
client = ZhipuAI(
|
||||
base_url=self.runtime.credentials["zhipuai_base_url"],
|
||||
api_key=self.runtime.credentials["zhipuai_api_key"],
|
||||
)
|
||||
size_mapping = {
|
||||
"square": "1024x1024",
|
||||
"vertical": "1024x1792",
|
||||
"horizontal": "1792x1024",
|
||||
}
|
||||
# prompt
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message("Please input prompt")
|
||||
# get size
|
||||
size = size_mapping[tool_parameters.get("size", "square")]
|
||||
# get n
|
||||
n = tool_parameters.get("n", 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
if quality not in {"standard", "hd"}:
|
||||
return self.create_text_message("Invalid quality")
|
||||
# get style
|
||||
style = tool_parameters.get("style", "vivid")
|
||||
if style not in {"natural", "vivid"}:
|
||||
return self.create_text_message("Invalid style")
|
||||
# set extra body
|
||||
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
||||
extra_body = {"seed": seed_id}
|
||||
response = client.images.generations(
|
||||
prompt=prompt,
|
||||
model="cogview-3",
|
||||
size=size,
|
||||
n=n,
|
||||
extra_body=extra_body,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format="b64_json",
|
||||
)
|
||||
result = []
|
||||
for image in response.data:
|
||||
result.append(self.create_image_message(image=image.url))
|
||||
result.append(
|
||||
self.create_json_message(
|
||||
{
|
||||
"url": image.url,
|
||||
}
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
random_id = "".join(random.choices(characters, k=length))
|
||||
return random_id
|
||||
@ -1,123 +0,0 @@
|
||||
identity:
|
||||
name: cogview3
|
||||
author: Waffle
|
||||
label:
|
||||
en_US: CogView 3
|
||||
zh_Hans: CogView 3 绘画
|
||||
pt_BR: CogView 3
|
||||
description:
|
||||
en_US: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt
|
||||
zh_Hans: CogView 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像
|
||||
pt_BR: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt
|
||||
description:
|
||||
human:
|
||||
en_US: CogView 3 is a text to image tool
|
||||
zh_Hans: CogView 3 是一个文本到图像的工具
|
||||
pt_BR: CogView 3 is a text to image tool
|
||||
llm: CogView 3 is a tool used to generate images from text
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: Image prompt, you can check the official documentation of CogView 3
|
||||
zh_Hans: 图像提示词,您可以查看 CogView 3 的官方文档
|
||||
pt_BR: Image prompt, you can check the official documentation of CogView 3
|
||||
llm_description: Image prompt of CogView 3, you should describe the image you want to generate as a list of words as possible as detailed
|
||||
form: llm
|
||||
- name: size
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image size
|
||||
zh_Hans: 选择图像大小
|
||||
pt_BR: selecting the image size
|
||||
label:
|
||||
en_US: Image size
|
||||
zh_Hans: 图像大小
|
||||
pt_BR: Image size
|
||||
form: form
|
||||
options:
|
||||
- value: square
|
||||
label:
|
||||
en_US: Squre(1024x1024)
|
||||
zh_Hans: 方(1024x1024)
|
||||
pt_BR: Squre(1024x1024)
|
||||
- value: vertical
|
||||
label:
|
||||
en_US: Vertical(1024x1792)
|
||||
zh_Hans: 竖屏(1024x1792)
|
||||
pt_BR: Vertical(1024x1792)
|
||||
- value: horizontal
|
||||
label:
|
||||
en_US: Horizontal(1792x1024)
|
||||
zh_Hans: 横屏(1792x1024)
|
||||
pt_BR: Horizontal(1792x1024)
|
||||
default: square
|
||||
- name: n
|
||||
type: number
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the number of images
|
||||
zh_Hans: 选择图像数量
|
||||
pt_BR: selecting the number of images
|
||||
label:
|
||||
en_US: Number of images
|
||||
zh_Hans: 图像数量
|
||||
pt_BR: Number of images
|
||||
form: form
|
||||
min: 1
|
||||
max: 1
|
||||
default: 1
|
||||
- name: quality
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image quality
|
||||
zh_Hans: 选择图像质量
|
||||
pt_BR: selecting the image quality
|
||||
label:
|
||||
en_US: Image quality
|
||||
zh_Hans: 图像质量
|
||||
pt_BR: Image quality
|
||||
form: form
|
||||
options:
|
||||
- value: standard
|
||||
label:
|
||||
en_US: Standard
|
||||
zh_Hans: 标准
|
||||
pt_BR: Standard
|
||||
- value: hd
|
||||
label:
|
||||
en_US: HD
|
||||
zh_Hans: 高清
|
||||
pt_BR: HD
|
||||
default: standard
|
||||
- name: style
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image style
|
||||
zh_Hans: 选择图像风格
|
||||
pt_BR: selecting the image style
|
||||
label:
|
||||
en_US: Image style
|
||||
zh_Hans: 图像风格
|
||||
pt_BR: Image style
|
||||
form: form
|
||||
options:
|
||||
- value: vivid
|
||||
label:
|
||||
en_US: Vivid
|
||||
zh_Hans: 生动
|
||||
pt_BR: Vivid
|
||||
- value: natural
|
||||
label:
|
||||
en_US: Natural
|
||||
zh_Hans: 自然
|
||||
pt_BR: Natural
|
||||
default: vivid
|
||||
|
Before Width: | Height: | Size: 209 KiB |
@ -1,17 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.comfyui.tools.comfyui_stable_diffusion import ComfyuiStableDiffusionTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class ComfyUIProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
ComfyuiStableDiffusionTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).validate_models()
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,42 +0,0 @@
|
||||
identity:
|
||||
author: Qun
|
||||
name: comfyui
|
||||
label:
|
||||
en_US: ComfyUI
|
||||
zh_Hans: ComfyUI
|
||||
pt_BR: ComfyUI
|
||||
description:
|
||||
en_US: ComfyUI is a tool for generating images which can be deployed locally.
|
||||
zh_Hans: ComfyUI 是一个可以在本地部署的图片生成的工具。
|
||||
pt_BR: ComfyUI is a tool for generating images which can be deployed locally.
|
||||
icon: icon.png
|
||||
tags:
|
||||
- image
|
||||
credentials_for_provider:
|
||||
base_url:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Base URL
|
||||
zh_Hans: ComfyUI服务器的Base URL
|
||||
pt_BR: Base URL
|
||||
placeholder:
|
||||
en_US: Please input your ComfyUI server's Base URL
|
||||
zh_Hans: 请输入你的 ComfyUI 服务器的 Base URL
|
||||
pt_BR: Please input your ComfyUI server's Base URL
|
||||
model:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Model with suffix
|
||||
zh_Hans: 模型, 需要带后缀
|
||||
pt_BR: Model with suffix
|
||||
placeholder:
|
||||
en_US: Please input your model
|
||||
zh_Hans: 请输入你的模型名称
|
||||
pt_BR: Please input your model
|
||||
help:
|
||||
en_US: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors
|
||||
zh_Hans: ComfyUI服务器的模型名称, 比如 xxx.safetensors
|
||||
pt_BR: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors
|
||||
url: https://docs.dify.ai/tutorials/tool-configuration/comfyui
|
||||
@ -1,475 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Union
|
||||
|
||||
import websocket
|
||||
from httpx import get, post
|
||||
from yarl import URL
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
SD_TXT2IMG_OPTIONS = {}
|
||||
LORA_NODE = {
|
||||
"inputs": {"lora_name": "", "strength_model": 1, "strength_clip": 1, "model": ["11", 0], "clip": ["11", 1]},
|
||||
"class_type": "LoraLoader",
|
||||
"_meta": {"title": "Load LoRA"},
|
||||
}
|
||||
FluxGuidanceNode = {
|
||||
"inputs": {"guidance": 3.5, "conditioning": ["6", 0]},
|
||||
"class_type": "FluxGuidance",
|
||||
"_meta": {"title": "FluxGuidance"},
|
||||
}
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
SD15 = 1
|
||||
SDXL = 2
|
||||
SD3 = 3
|
||||
FLUX = 4
|
||||
|
||||
|
||||
class ComfyuiStableDiffusionTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
# base url
|
||||
base_url = self.runtime.credentials.get("base_url", "")
|
||||
if not base_url:
|
||||
return self.create_text_message("Please input base_url")
|
||||
|
||||
if tool_parameters.get("model"):
|
||||
self.runtime.credentials["model"] = tool_parameters["model"]
|
||||
|
||||
model = self.runtime.credentials.get("model", None)
|
||||
if not model:
|
||||
return self.create_text_message("Please input model")
|
||||
|
||||
# prompt
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message("Please input prompt")
|
||||
|
||||
# get negative prompt
|
||||
negative_prompt = tool_parameters.get("negative_prompt", "")
|
||||
|
||||
# get size
|
||||
width = tool_parameters.get("width", 1024)
|
||||
height = tool_parameters.get("height", 1024)
|
||||
|
||||
# get steps
|
||||
steps = tool_parameters.get("steps", 1)
|
||||
|
||||
# get sampler_name
|
||||
sampler_name = tool_parameters.get("sampler_name", "euler")
|
||||
|
||||
# scheduler
|
||||
scheduler = tool_parameters.get("scheduler", "normal")
|
||||
|
||||
# get cfg
|
||||
cfg = tool_parameters.get("cfg", 7.0)
|
||||
|
||||
# get model type
|
||||
model_type = tool_parameters.get("model_type", ModelType.SD15.name)
|
||||
|
||||
# get lora
|
||||
# supports up to 3 loras
|
||||
lora_list = []
|
||||
lora_strength_list = []
|
||||
if tool_parameters.get("lora_1"):
|
||||
lora_list.append(tool_parameters["lora_1"])
|
||||
lora_strength_list.append(tool_parameters.get("lora_strength_1", 1))
|
||||
if tool_parameters.get("lora_2"):
|
||||
lora_list.append(tool_parameters["lora_2"])
|
||||
lora_strength_list.append(tool_parameters.get("lora_strength_2", 1))
|
||||
if tool_parameters.get("lora_3"):
|
||||
lora_list.append(tool_parameters["lora_3"])
|
||||
lora_strength_list.append(tool_parameters.get("lora_strength_3", 1))
|
||||
|
||||
return self.text2img(
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
steps=steps,
|
||||
sampler_name=sampler_name,
|
||||
scheduler=scheduler,
|
||||
cfg=cfg,
|
||||
lora_list=lora_list,
|
||||
lora_strength_list=lora_strength_list,
|
||||
)
|
||||
|
||||
def get_checkpoints(self) -> list[str]:
|
||||
"""
|
||||
get checkpoints
|
||||
"""
|
||||
try:
|
||||
base_url = self.runtime.credentials.get("base_url", None)
|
||||
if not base_url:
|
||||
return []
|
||||
api_url = str(URL(base_url) / "models" / "checkpoints")
|
||||
response = get(url=api_url, timeout=(2, 10))
|
||||
if response.status_code != 200:
|
||||
return []
|
||||
else:
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
return []
|
||||
|
||||
def get_loras(self) -> list[str]:
|
||||
"""
|
||||
get loras
|
||||
"""
|
||||
try:
|
||||
base_url = self.runtime.credentials.get("base_url", None)
|
||||
if not base_url:
|
||||
return []
|
||||
api_url = str(URL(base_url) / "models" / "loras")
|
||||
response = get(url=api_url, timeout=(2, 10))
|
||||
if response.status_code != 200:
|
||||
return []
|
||||
else:
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
return []
|
||||
|
||||
def get_sample_methods(self) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
get sample method
|
||||
"""
|
||||
try:
|
||||
base_url = self.runtime.credentials.get("base_url", None)
|
||||
if not base_url:
|
||||
return [], []
|
||||
api_url = str(URL(base_url) / "object_info" / "KSampler")
|
||||
response = get(url=api_url, timeout=(2, 10))
|
||||
if response.status_code != 200:
|
||||
return [], []
|
||||
else:
|
||||
data = response.json()["KSampler"]["input"]["required"]
|
||||
return data["sampler_name"][0], data["scheduler"][0]
|
||||
except Exception as e:
|
||||
return [], []
|
||||
|
||||
def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
validate models
|
||||
"""
|
||||
try:
|
||||
base_url = self.runtime.credentials.get("base_url", None)
|
||||
if not base_url:
|
||||
raise ToolProviderCredentialValidationError("Please input base_url")
|
||||
model = self.runtime.credentials.get("model", None)
|
||||
if not model:
|
||||
raise ToolProviderCredentialValidationError("Please input model")
|
||||
|
||||
api_url = str(URL(base_url) / "models" / "checkpoints")
|
||||
response = get(url=api_url, timeout=(2, 10))
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderCredentialValidationError("Failed to get models")
|
||||
else:
|
||||
models = response.json()
|
||||
if len([d for d in models if d == model]) > 0:
|
||||
return self.create_text_message(json.dumps(models))
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError(f"model {model} does not exist")
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(f"Failed to get models, {e}")
|
||||
|
||||
def get_history(self, base_url, prompt_id):
|
||||
"""
|
||||
get history
|
||||
"""
|
||||
url = str(URL(base_url) / "history")
|
||||
respond = get(url, params={"prompt_id": prompt_id}, timeout=(2, 10))
|
||||
return respond.json()
|
||||
|
||||
def download_image(self, base_url, filename, subfolder, folder_type):
|
||||
"""
|
||||
download image
|
||||
"""
|
||||
url = str(URL(base_url) / "view")
|
||||
response = get(url, params={"filename": filename, "subfolder": subfolder, "type": folder_type}, timeout=(2, 10))
|
||||
return response.content
|
||||
|
||||
def queue_prompt_image(self, base_url, client_id, prompt):
|
||||
"""
|
||||
send prompt task and rotate
|
||||
"""
|
||||
# initiate task execution
|
||||
url = str(URL(base_url) / "prompt")
|
||||
respond = post(url, data=json.dumps({"client_id": client_id, "prompt": prompt}), timeout=(2, 10))
|
||||
prompt_id = respond.json()["prompt_id"]
|
||||
|
||||
ws = websocket.WebSocket()
|
||||
if "https" in base_url:
|
||||
ws_url = base_url.replace("https", "ws")
|
||||
else:
|
||||
ws_url = base_url.replace("http", "ws")
|
||||
ws.connect(str(URL(f"{ws_url}") / "ws") + f"?clientId={client_id}", timeout=120)
|
||||
|
||||
# websocket rotate execution status
|
||||
output_images = {}
|
||||
while True:
|
||||
out = ws.recv()
|
||||
if isinstance(out, str):
|
||||
message = json.loads(out)
|
||||
if message["type"] == "executing":
|
||||
data = message["data"]
|
||||
if data["node"] is None and data["prompt_id"] == prompt_id:
|
||||
break # Execution is done
|
||||
elif message["type"] == "status":
|
||||
data = message["data"]
|
||||
if data["status"]["exec_info"]["queue_remaining"] == 0 and data.get("sid"):
|
||||
break # Execution is done
|
||||
else:
|
||||
continue # previews are binary data
|
||||
|
||||
# download image when execution finished
|
||||
history = self.get_history(base_url, prompt_id)[prompt_id]
|
||||
for o in history["outputs"]:
|
||||
for node_id in history["outputs"]:
|
||||
node_output = history["outputs"][node_id]
|
||||
if "images" in node_output:
|
||||
images_output = []
|
||||
for image in node_output["images"]:
|
||||
image_data = self.download_image(base_url, image["filename"], image["subfolder"], image["type"])
|
||||
images_output.append(image_data)
|
||||
output_images[node_id] = images_output
|
||||
|
||||
ws.close()
|
||||
|
||||
return output_images
|
||||
|
||||
def text2img(
|
||||
self,
|
||||
base_url: str,
|
||||
model: str,
|
||||
model_type: str,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
width: int,
|
||||
height: int,
|
||||
steps: int,
|
||||
sampler_name: str,
|
||||
scheduler: str,
|
||||
cfg: float,
|
||||
lora_list: list,
|
||||
lora_strength_list: list,
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
generate image
|
||||
"""
|
||||
if not SD_TXT2IMG_OPTIONS:
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
with open(os.path.join(current_dir, "txt2img.json")) as file:
|
||||
SD_TXT2IMG_OPTIONS.update(json.load(file))
|
||||
|
||||
draw_options = deepcopy(SD_TXT2IMG_OPTIONS)
|
||||
draw_options["3"]["inputs"]["steps"] = steps
|
||||
draw_options["3"]["inputs"]["sampler_name"] = sampler_name
|
||||
draw_options["3"]["inputs"]["scheduler"] = scheduler
|
||||
draw_options["3"]["inputs"]["cfg"] = cfg
|
||||
# generate different image when using same prompt next time
|
||||
draw_options["3"]["inputs"]["seed"] = random.randint(0, 100000000)
|
||||
draw_options["4"]["inputs"]["ckpt_name"] = model
|
||||
draw_options["5"]["inputs"]["width"] = width
|
||||
draw_options["5"]["inputs"]["height"] = height
|
||||
draw_options["6"]["inputs"]["text"] = prompt
|
||||
draw_options["7"]["inputs"]["text"] = negative_prompt
|
||||
# if the model is SD3 or FLUX series, the Latent class should be corresponding to SD3 Latent
|
||||
if model_type in {ModelType.SD3.name, ModelType.FLUX.name}:
|
||||
draw_options["5"]["class_type"] = "EmptySD3LatentImage"
|
||||
|
||||
if lora_list:
|
||||
# last Lora node link to KSampler node
|
||||
draw_options["3"]["inputs"]["model"][0] = "10"
|
||||
# last Lora node link to positive and negative Clip node
|
||||
draw_options["6"]["inputs"]["clip"][0] = "10"
|
||||
draw_options["7"]["inputs"]["clip"][0] = "10"
|
||||
# every Lora node link to next Lora node, and Checkpoints node link to first Lora node
|
||||
for i, (lora, strength) in enumerate(zip(lora_list, lora_strength_list), 10):
|
||||
if i - 10 == len(lora_list) - 1:
|
||||
next_node_id = "4"
|
||||
else:
|
||||
next_node_id = str(i + 1)
|
||||
lora_node = deepcopy(LORA_NODE)
|
||||
lora_node["inputs"]["lora_name"] = lora
|
||||
lora_node["inputs"]["strength_model"] = strength
|
||||
lora_node["inputs"]["strength_clip"] = strength
|
||||
lora_node["inputs"]["model"][0] = next_node_id
|
||||
lora_node["inputs"]["clip"][0] = next_node_id
|
||||
draw_options[str(i)] = lora_node
|
||||
|
||||
# FLUX need to add FluxGuidance Node
|
||||
if model_type == ModelType.FLUX.name:
|
||||
last_node_id = str(10 + len(lora_list))
|
||||
draw_options[last_node_id] = deepcopy(FluxGuidanceNode)
|
||||
draw_options[last_node_id]["inputs"]["conditioning"][0] = "6"
|
||||
draw_options["3"]["inputs"]["positive"][0] = last_node_id
|
||||
|
||||
try:
|
||||
client_id = str(uuid.uuid4())
|
||||
result = self.queue_prompt_image(base_url, client_id, prompt=draw_options)
|
||||
|
||||
# get first image
|
||||
image = b""
|
||||
for node in result:
|
||||
for img in result[node]:
|
||||
if img:
|
||||
image = img
|
||||
break
|
||||
|
||||
return self.create_blob_message(
|
||||
blob=image, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to generate image: {str(e)}")
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
parameters = [
|
||||
ToolParameter(
|
||||
name="prompt",
|
||||
label=I18nObject(en_US="Prompt", zh_Hans="Prompt"),
|
||||
human_description=I18nObject(
|
||||
en_US="Image prompt, you can check the official documentation of Stable Diffusion",
|
||||
zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description="Image prompt of Stable Diffusion, you should describe the image "
|
||||
"you want to generate as a list of words as possible as detailed, "
|
||||
"the prompt must be written in English.",
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
if self.runtime.credentials:
|
||||
try:
|
||||
models = self.get_checkpoints()
|
||||
if len(models) != 0:
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name="model",
|
||||
label=I18nObject(en_US="Model", zh_Hans="Model"),
|
||||
human_description=I18nObject(
|
||||
en_US="Model of Stable Diffusion or FLUX, "
|
||||
"you can check the official documentation of Stable Diffusion or FLUX",
|
||||
zh_Hans="Stable Diffusion 或者 FLUX 的模型,您可以查看 Stable Diffusion 的官方文档",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
llm_description="Model of Stable Diffusion or FLUX, "
|
||||
"you can check the official documentation of Stable Diffusion or FLUX",
|
||||
required=True,
|
||||
default=models[0],
|
||||
options=[
|
||||
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models
|
||||
],
|
||||
)
|
||||
)
|
||||
loras = self.get_loras()
|
||||
if len(loras) != 0:
|
||||
for n in range(1, 4):
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name=f"lora_{n}",
|
||||
label=I18nObject(en_US=f"Lora {n}", zh_Hans=f"Lora {n}"),
|
||||
human_description=I18nObject(
|
||||
en_US="Lora of Stable Diffusion, "
|
||||
"you can check the official documentation of Stable Diffusion",
|
||||
zh_Hans="Stable Diffusion 的 Lora 模型,您可以查看 Stable Diffusion 的官方文档",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
llm_description="Lora of Stable Diffusion, "
|
||||
"you can check the official documentation of "
|
||||
"Stable Diffusion",
|
||||
required=False,
|
||||
options=[
|
||||
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in loras
|
||||
],
|
||||
)
|
||||
)
|
||||
sample_methods, schedulers = self.get_sample_methods()
|
||||
if len(sample_methods) != 0:
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name="sampler_name",
|
||||
label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"),
|
||||
human_description=I18nObject(
|
||||
en_US="Sampling method of Stable Diffusion, "
|
||||
"you can check the official documentation of Stable Diffusion",
|
||||
zh_Hans="Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
llm_description="Sampling method of Stable Diffusion, "
|
||||
"you can check the official documentation of Stable Diffusion",
|
||||
required=True,
|
||||
default=sample_methods[0],
|
||||
options=[
|
||||
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i))
|
||||
for i in sample_methods
|
||||
],
|
||||
)
|
||||
)
|
||||
if len(schedulers) != 0:
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name="scheduler",
|
||||
label=I18nObject(en_US="Scheduler", zh_Hans="Scheduler"),
|
||||
human_description=I18nObject(
|
||||
en_US="Scheduler of Stable Diffusion, "
|
||||
"you can check the official documentation of Stable Diffusion",
|
||||
zh_Hans="Stable Diffusion 的Scheduler,您可以查看 Stable Diffusion 的官方文档",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
llm_description="Scheduler of Stable Diffusion, "
|
||||
"you can check the official documentation of Stable Diffusion",
|
||||
required=True,
|
||||
default=schedulers[0],
|
||||
options=[
|
||||
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in schedulers
|
||||
],
|
||||
)
|
||||
)
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
name="model_type",
|
||||
label=I18nObject(en_US="Model Type", zh_Hans="Model Type"),
|
||||
human_description=I18nObject(
|
||||
en_US="Model Type of Stable Diffusion or Flux, "
|
||||
"you can check the official documentation of Stable Diffusion or Flux",
|
||||
zh_Hans="Stable Diffusion 或 FLUX 的模型类型,"
|
||||
"您可以查看 Stable Diffusion 或 Flux 的官方文档",
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
llm_description="Model Type of Stable Diffusion or Flux, "
|
||||
"you can check the official documentation of Stable Diffusion or Flux",
|
||||
required=True,
|
||||
default=ModelType.SD15.name,
|
||||
options=[
|
||||
ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i))
|
||||
for i in ModelType.__members__
|
||||
],
|
||||
)
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
return parameters
|
||||
@ -1,212 +0,0 @@
|
||||
identity:
|
||||
name: txt2img workflow
|
||||
author: Qun
|
||||
label:
|
||||
en_US: Txt2Img Workflow
|
||||
zh_Hans: Txt2Img Workflow
|
||||
pt_BR: Txt2Img Workflow
|
||||
description:
|
||||
human:
|
||||
en_US: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that requires a triple clip loader.
|
||||
zh_Hans: 一个预定义的 ComfyUI 工作流,可以使用一个模型和最多3个loras来生成图像。支持包含文本编码器/clip的SD1.5、SDXL、SD3和FLUX,但不支持需要clip加载器的模型。
|
||||
pt_BR: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that requires a triple clip loader.
|
||||
llm: draw the image you want based on your prompt.
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: Image prompt, you can check the official documentation of Stable Diffusion or FLUX
|
||||
zh_Hans: 图像提示词,您可以查看 Stable Diffusion 或者 FLUX 的官方文档
|
||||
pt_BR: Image prompt, you can check the official documentation of Stable Diffusion or FLUX
|
||||
llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.
|
||||
form: llm
|
||||
- name: model
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
pt_BR: Model Name
|
||||
human_description:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
pt_BR: Model Name
|
||||
form: form
|
||||
- name: model_type
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Model Type
|
||||
zh_Hans: 模型类型
|
||||
pt_BR: Model Type
|
||||
human_description:
|
||||
en_US: Model Type
|
||||
zh_Hans: 模型类型
|
||||
pt_BR: Model Type
|
||||
form: form
|
||||
- name: lora_1
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora 1
|
||||
zh_Hans: Lora 1
|
||||
pt_BR: Lora 1
|
||||
human_description:
|
||||
en_US: Lora 1
|
||||
zh_Hans: Lora 1
|
||||
pt_BR: Lora 1
|
||||
form: form
|
||||
- name: lora_strength_1
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora Strength 1
|
||||
zh_Hans: Lora Strength 1
|
||||
pt_BR: Lora Strength 1
|
||||
human_description:
|
||||
en_US: Lora Strength 1
|
||||
zh_Hans: Lora模型的权重
|
||||
pt_BR: Lora Strength 1
|
||||
form: form
|
||||
- name: steps
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Steps
|
||||
zh_Hans: Steps
|
||||
pt_BR: Steps
|
||||
human_description:
|
||||
en_US: Steps
|
||||
zh_Hans: Steps
|
||||
pt_BR: Steps
|
||||
form: form
|
||||
default: 20
|
||||
- name: width
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Width
|
||||
zh_Hans: Width
|
||||
pt_BR: Width
|
||||
human_description:
|
||||
en_US: Width
|
||||
zh_Hans: Width
|
||||
pt_BR: Width
|
||||
form: form
|
||||
default: 1024
|
||||
- name: height
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Height
|
||||
zh_Hans: Height
|
||||
pt_BR: Height
|
||||
human_description:
|
||||
en_US: Height
|
||||
zh_Hans: Height
|
||||
pt_BR: Height
|
||||
form: form
|
||||
default: 1024
|
||||
- name: negative_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Negative prompt
|
||||
zh_Hans: Negative prompt
|
||||
pt_BR: Negative prompt
|
||||
human_description:
|
||||
en_US: Negative prompt
|
||||
zh_Hans: Negative prompt
|
||||
pt_BR: Negative prompt
|
||||
form: form
|
||||
default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines
|
||||
- name: cfg
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: CFG Scale
|
||||
zh_Hans: CFG Scale
|
||||
pt_BR: CFG Scale
|
||||
human_description:
|
||||
en_US: CFG Scale
|
||||
zh_Hans: 提示词相关性(CFG Scale)
|
||||
pt_BR: CFG Scale
|
||||
form: form
|
||||
default: 7.0
|
||||
- name: sampler_name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Sampling method
|
||||
zh_Hans: Sampling method
|
||||
pt_BR: Sampling method
|
||||
human_description:
|
||||
en_US: Sampling method
|
||||
zh_Hans: Sampling method
|
||||
pt_BR: Sampling method
|
||||
form: form
|
||||
- name: scheduler
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Scheduler
|
||||
zh_Hans: Scheduler
|
||||
pt_BR: Scheduler
|
||||
human_description:
|
||||
en_US: Scheduler
|
||||
zh_Hans: Scheduler
|
||||
pt_BR: Scheduler
|
||||
form: form
|
||||
- name: lora_2
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora 2
|
||||
zh_Hans: Lora 2
|
||||
pt_BR: Lora 2
|
||||
human_description:
|
||||
en_US: Lora 2
|
||||
zh_Hans: Lora 2
|
||||
pt_BR: Lora 2
|
||||
form: form
|
||||
- name: lora_strength_2
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora Strength 2
|
||||
zh_Hans: Lora Strength 2
|
||||
pt_BR: Lora Strength 2
|
||||
human_description:
|
||||
en_US: Lora Strength 2
|
||||
zh_Hans: Lora模型的权重
|
||||
pt_BR: Lora Strength 2
|
||||
form: form
|
||||
- name: lora_3
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora 3
|
||||
zh_Hans: Lora 3
|
||||
pt_BR: Lora 3
|
||||
human_description:
|
||||
en_US: Lora 3
|
||||
zh_Hans: Lora 3
|
||||
pt_BR: Lora 3
|
||||
form: form
|
||||
- name: lora_strength_3
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Lora Strength 3
|
||||
zh_Hans: Lora Strength 3
|
||||
pt_BR: Lora Strength 3
|
||||
human_description:
|
||||
en_US: Lora Strength 3
|
||||
zh_Hans: Lora模型的权重
|
||||
pt_BR: Lora Strength 3
|
||||
form: form
|
||||
@ -1,107 +0,0 @@
|
||||
{
|
||||
"3": {
|
||||
"inputs": {
|
||||
"seed": 156680208700286,
|
||||
"steps": 20,
|
||||
"cfg": 8,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"denoise": 1,
|
||||
"model": [
|
||||
"4",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"7",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"5",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSampler",
|
||||
"_meta": {
|
||||
"title": "KSampler"
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"inputs": {
|
||||
"ckpt_name": "3dAnimationDiffusion_v10.safetensors"
|
||||
},
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"_meta": {
|
||||
"title": "Load Checkpoint"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"inputs": {
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "EmptyLatentImage",
|
||||
"_meta": {
|
||||
"title": "Empty Latent Image"
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"inputs": {
|
||||
"text": "text, watermark",
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Prompt)"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"3",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"4",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE Decode"
|
||||
}
|
||||
},
|
||||
"9": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": [
|
||||
"8",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,49 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- Generator: Adobe Illustrator 19.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
|
||||
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
|
||||
viewBox="0 0 200 130.2" style="enable-background:new 0 0 200 130.2;" xml:space="preserve">
|
||||
<style type="text/css">
|
||||
.st0{fill:#3EB1C8;}
|
||||
.st1{fill:#D8D2C4;}
|
||||
.st2{fill:#4F5858;}
|
||||
.st3{fill:#FFC72C;}
|
||||
.st4{fill:#EF3340;}
|
||||
</style>
|
||||
<g>
|
||||
<polygon class="st0" points="111.8,95.5 111.8,66.8 135.4,59 177.2,73.3 "/>
|
||||
<polygon class="st1" points="153.6,36.8 111.8,51.2 135.4,59 177.2,44.6 "/>
|
||||
<polygon class="st2" points="135.4,59 177.2,44.6 177.2,73.3 "/>
|
||||
<polygon class="st3" points="177.2,0.3 177.2,29 153.6,36.8 111.8,22.5 "/>
|
||||
<polygon class="st4" points="153.6,36.8 111.8,51.2 111.8,22.5 "/>
|
||||
<g>
|
||||
<g>
|
||||
<g>
|
||||
<g>
|
||||
<path class="st2" d="M26.3,104.8c-0.5-3.7-4.1-6.5-8.1-6.5c-7.3,0-10.1,6.2-10.1,12.7c0,6.2,2.8,12.4,10.1,12.4
|
||||
c5,0,7.8-3.4,8.4-8.3h7.9c-0.8,9.2-7.2,15.2-16.3,15.2C6.8,130.2,0,121.7,0,111c0-11,6.8-19.6,18.2-19.6c8.2,0,15,4.8,16,13.3
|
||||
H26.3z"/>
|
||||
<path class="st2" d="M37.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
|
||||
c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
|
||||
<path class="st2" d="M68.7,101.8c8.5,0,13.9,5.6,13.9,14.2c0,8.5-5.5,14.1-13.9,14.1c-8.4,0-13.9-5.6-13.9-14.1
|
||||
C54.9,107.4,60.3,101.8,68.7,101.8z M68.7,124.5c5,0,6.5-4.3,6.5-8.6c0-4.3-1.5-8.6-6.5-8.6c-5,0-6.5,4.3-6.5,8.6
|
||||
C62.2,120.2,63.8,124.5,68.7,124.5z"/>
|
||||
<path class="st2" d="M91.2,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2c-4.3-0.9-8.5-2.4-8.5-7.2
|
||||
c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5c0,2.6,4.2,3,8.4,4
|
||||
c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H91.2z"/>
|
||||
<path class="st2" d="M118.1,120.6c0.1,3.2,2.8,4.5,5.7,4.5c2.1,0,4.8-0.8,4.8-3.4c0-2.2-3.1-3-8.4-4.2
|
||||
c-4.3-0.9-8.5-2.4-8.5-7.2c0-6.9,5.9-8.6,11.7-8.6c5.9,0,11.3,2,11.8,8.6h-7c-0.2-2.9-2.4-3.6-5-3.6c-1.7,0-4.1,0.3-4.1,2.5
|
||||
c0,2.6,4.2,3,8.4,4c4.3,1,8.5,2.5,8.5,7.5c0,7.1-6.1,9.3-12.3,9.3c-6.2,0-12.3-2.3-12.6-9.5H118.1z"/>
|
||||
<path class="st2" d="M138.4,102.5h7v5h0.1c1.4-3.4,5-5.7,8.6-5.7c0.5,0,1.1,0.1,1.6,0.3v6.9c-0.7-0.2-1.8-0.3-2.6-0.3
|
||||
c-5.4,0-7.3,3.9-7.3,8.6v12.1h-7.4V102.5z"/>
|
||||
<path class="st2" d="M163.7,117.7c0.2,4.7,2.5,6.8,6.6,6.8c3,0,5.3-1.8,5.8-3.5h6.5c-2.1,6.3-6.5,9-12.6,9
|
||||
c-8.5,0-13.7-5.8-13.7-14.1c0-8,5.6-14.2,13.7-14.2c9.1,0,13.6,7.7,13,15.9H163.7z M175.7,113.1c-0.7-3.7-2.3-5.7-5.9-5.7
|
||||
c-4.7,0-6,3.6-6.1,5.7H175.7z"/>
|
||||
<path class="st2" d="M187.2,107.5h-4.4v-4.9h4.4v-2.1c0-4.7,3-8.2,9-8.2c1.3,0,2.6,0.2,3.9,0.2V98c-0.9-0.1-1.8-0.2-2.7-0.2
|
||||
c-2,0-2.8,0.8-2.8,3.1v1.6h5.1v4.9h-5.1v21.9h-7.4V107.5z"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 3.0 KiB |
@ -1,20 +0,0 @@
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.crossref.tools.query_doi import CrossRefQueryDOITool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CrossRefProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
CrossRefQueryDOITool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"doi": "10.1007/s00894-022-05373-8",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,29 +0,0 @@
|
||||
identity:
|
||||
author: Sakura4036
|
||||
name: crossref
|
||||
label:
|
||||
en_US: CrossRef
|
||||
zh_Hans: CrossRef
|
||||
description:
|
||||
en_US: Crossref is a cross-publisher reference linking registration query system using DOI technology created in 2000. Crossref establishes cross-database links between the reference list and citation full text of papers, making it very convenient for readers to access the full text of papers.
|
||||
zh_Hans: Crossref是于2000年创建的使用DOI技术的跨出版商参考文献链接注册查询系统。Crossref建立了在论文的参考文献列表和引文全文之间的跨数据库链接,使得读者能够非常便捷地获取文献全文。
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
credentials_for_provider:
|
||||
mailto:
|
||||
type: text-input
|
||||
required: true
|
||||
label:
|
||||
en_US: email address
|
||||
zh_Hans: email地址
|
||||
pt_BR: email address
|
||||
placeholder:
|
||||
en_US: Please input your email address
|
||||
zh_Hans: 请输入你的email地址
|
||||
pt_BR: Please input your email address
|
||||
help:
|
||||
en_US: According to the requirements of Crossref, an email address is required
|
||||
zh_Hans: 根据Crossref的要求,需要提供一个邮箱地址
|
||||
pt_BR: According to the requirements of Crossref, an email address is required
|
||||
url: https://api.crossref.org/swagger-ui/index.html
|
||||
@ -1,28 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolParameterValidationError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class CrossRefQueryDOITool(BuiltinTool):
|
||||
"""
|
||||
Tool for querying the metadata of a publication using its DOI.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
doi = tool_parameters.get("doi")
|
||||
if not doi:
|
||||
raise ToolParameterValidationError("doi is required.")
|
||||
# doc: https://github.com/CrossRef/rest-api-doc
|
||||
url = f"https://api.crossref.org/works/{doi}"
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
response = response.json()
|
||||
message = response.get("message", {})
|
||||
|
||||
return self.create_json_message(message)
|
||||
@ -1,23 +0,0 @@
|
||||
identity:
|
||||
name: crossref_query_doi
|
||||
author: Sakura4036
|
||||
label:
|
||||
en_US: CrossRef Query DOI
|
||||
zh_Hans: CrossRef DOI 查询
|
||||
pt_BR: CrossRef Query DOI
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for searching literature information using CrossRef by DOI.
|
||||
zh_Hans: 一个使用CrossRef通过DOI获取文献信息的工具。
|
||||
pt_BR: A tool for searching literature information using CrossRef by DOI.
|
||||
llm: A tool for searching literature information using CrossRef by DOI.
|
||||
parameters:
|
||||
- name: doi
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: DOI
|
||||
zh_Hans: DOI
|
||||
pt_BR: DOI
|
||||
llm_description: DOI for searching in CrossRef
|
||||
form: llm
|
||||
@ -1,143 +0,0 @@
|
||||
import time
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
def convert_time_str_to_seconds(time_str: str) -> int:
|
||||
"""
|
||||
Convert a time string to seconds.
|
||||
example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430
|
||||
"""
|
||||
time_str = time_str.lower().strip().replace(" ", "")
|
||||
seconds = 0
|
||||
if "h" in time_str:
|
||||
hours, time_str = time_str.split("h")
|
||||
seconds += int(hours) * 3600
|
||||
if "m" in time_str:
|
||||
minutes, time_str = time_str.split("m")
|
||||
seconds += int(minutes) * 60
|
||||
if "s" in time_str:
|
||||
seconds += int(time_str.replace("s", ""))
|
||||
return seconds
|
||||
|
||||
|
||||
class CrossRefQueryTitleAPI:
|
||||
"""
|
||||
Tool for querying the metadata of a publication using its title.
|
||||
Crossref API doc: https://github.com/CrossRef/rest-api-doc
|
||||
"""
|
||||
|
||||
query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}"
|
||||
rate_limit: int = 50
|
||||
rate_interval: float = 1
|
||||
max_limit: int = 1000
|
||||
|
||||
def __init__(self, mailto: str):
|
||||
self.mailto = mailto
|
||||
|
||||
def _query(
|
||||
self,
|
||||
query: str,
|
||||
rows: int = 5,
|
||||
offset: int = 0,
|
||||
sort: str = "relevance",
|
||||
order: str = "desc",
|
||||
fuzzy_query: bool = False,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Query the metadata of a publication using its title.
|
||||
:param query: the title of the publication
|
||||
:param rows: the number of results to return
|
||||
:param sort: the sort field
|
||||
:param order: the sort order
|
||||
:param fuzzy_query: whether to return all items that match the query
|
||||
"""
|
||||
url = self.query_url_template.format(
|
||||
query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto
|
||||
)
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
rate_limit = int(response.headers["x-ratelimit-limit"])
|
||||
# convert time string to seconds
|
||||
rate_interval = convert_time_str_to_seconds(response.headers["x-ratelimit-interval"])
|
||||
|
||||
self.rate_limit = rate_limit
|
||||
self.rate_interval = rate_interval
|
||||
|
||||
response = response.json()
|
||||
if response["status"] != "ok":
|
||||
return []
|
||||
|
||||
message = response["message"]
|
||||
if fuzzy_query:
|
||||
# fuzzy query return all items
|
||||
return message["items"]
|
||||
else:
|
||||
for paper in message["items"]:
|
||||
title = paper["title"][0]
|
||||
if title.lower() != query.lower():
|
||||
continue
|
||||
return [paper]
|
||||
return []
|
||||
|
||||
def query(
|
||||
self, query: str, rows: int = 5, sort: str = "relevance", order: str = "desc", fuzzy_query: bool = False
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Query the metadata of a publication using its title.
|
||||
:param query: the title of the publication
|
||||
:param rows: the number of results to return
|
||||
:param sort: the sort field
|
||||
:param order: the sort order
|
||||
:param fuzzy_query: whether to return all items that match the query
|
||||
"""
|
||||
rows = min(rows, self.max_limit)
|
||||
if rows > self.rate_limit:
|
||||
# query multiple times
|
||||
query_times = rows // self.rate_limit + 1
|
||||
results = []
|
||||
|
||||
for i in range(query_times):
|
||||
result = self._query(
|
||||
query,
|
||||
rows=self.rate_limit,
|
||||
offset=i * self.rate_limit,
|
||||
sort=sort,
|
||||
order=order,
|
||||
fuzzy_query=fuzzy_query,
|
||||
)
|
||||
if fuzzy_query:
|
||||
results.extend(result)
|
||||
else:
|
||||
# fuzzy_query=False, only one result
|
||||
if result:
|
||||
return result
|
||||
time.sleep(self.rate_interval)
|
||||
return results
|
||||
else:
|
||||
# query once
|
||||
return self._query(query, rows, sort=sort, order=order, fuzzy_query=fuzzy_query)
|
||||
|
||||
|
||||
class CrossRefQueryTitleTool(BuiltinTool):
|
||||
"""
|
||||
Tool for querying the metadata of a publication using its title.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
query = tool_parameters.get("query")
|
||||
fuzzy_query = tool_parameters.get("fuzzy_query", False)
|
||||
rows = tool_parameters.get("rows", 3)
|
||||
sort = tool_parameters.get("sort", "relevance")
|
||||
order = tool_parameters.get("order", "desc")
|
||||
mailto = self.runtime.credentials["mailto"]
|
||||
|
||||
result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query)
|
||||
|
||||
return [self.create_json_message(r) for r in result]
|
||||
@ -1,105 +0,0 @@
|
||||
identity:
|
||||
name: crossref_query_title
|
||||
author: Sakura4036
|
||||
label:
|
||||
en_US: CrossRef Title Query
|
||||
zh_Hans: CrossRef 标题查询
|
||||
pt_BR: CrossRef Title Query
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for querying literature information using CrossRef by title.
|
||||
zh_Hans: 一个使用CrossRef通过标题搜索文献信息的工具。
|
||||
pt_BR: A tool for querying literature information using CrossRef by title.
|
||||
llm: A tool for querying literature information using CrossRef by title.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: 标题
|
||||
zh_Hans: 查询语句
|
||||
pt_BR: 标题
|
||||
human_description:
|
||||
en_US: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years
|
||||
zh_Hans: 用于搜索文献信息,有助于查找引用。包括标题,作者,ISSN和出版年份
|
||||
pt_BR: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years
|
||||
llm_description: key words for querying in Web of Science
|
||||
form: llm
|
||||
- name: fuzzy_query
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: Whether to fuzzy search
|
||||
zh_Hans: 是否模糊搜索
|
||||
pt_BR: Whether to fuzzy search
|
||||
human_description:
|
||||
en_US: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none
|
||||
zh_Hans: 用于选择搜索类型,模糊搜索返回更多结果,精确搜索返回1条结果或无
|
||||
pt_BR: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none
|
||||
form: form
|
||||
- name: limit
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: max query number
|
||||
zh_Hans: 最大搜索数
|
||||
pt_BR: max query number
|
||||
human_description:
|
||||
en_US: max query number(fuzzy search returns the maximum number of results or precise search the maximum number of matches)
|
||||
zh_Hans: 最大搜索数(模糊搜索返回的最大结果数或精确搜索最大匹配数)
|
||||
pt_BR: max query number(fuzzy search returns the maximum number of results or precise search the maximum number of matches)
|
||||
form: llm
|
||||
default: 50
|
||||
- name: sort
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- value: relevance
|
||||
label:
|
||||
en_US: relevance
|
||||
zh_Hans: 相关性
|
||||
pt_BR: relevance
|
||||
- value: published
|
||||
label:
|
||||
en_US: publication date
|
||||
zh_Hans: 出版日期
|
||||
pt_BR: publication date
|
||||
- value: references-count
|
||||
label:
|
||||
en_US: references-count
|
||||
zh_Hans: 引用次数
|
||||
pt_BR: references-count
|
||||
default: relevance
|
||||
label:
|
||||
en_US: sorting field
|
||||
zh_Hans: 排序字段
|
||||
pt_BR: sorting field
|
||||
human_description:
|
||||
en_US: Sorting of query results
|
||||
zh_Hans: 检索结果的排序字段
|
||||
pt_BR: Sorting of query results
|
||||
form: form
|
||||
- name: order
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- value: desc
|
||||
label:
|
||||
en_US: descending
|
||||
zh_Hans: 降序
|
||||
pt_BR: descending
|
||||
- value: asc
|
||||
label:
|
||||
en_US: ascending
|
||||
zh_Hans: 升序
|
||||
pt_BR: ascending
|
||||
default: desc
|
||||
label:
|
||||
en_US: Order
|
||||
zh_Hans: 排序
|
||||
pt_BR: Order
|
||||
human_description:
|
||||
en_US: Order of query results
|
||||
zh_Hans: 检索结果的排序方式
|
||||
pt_BR: Order of query results
|
||||
form: form
|
||||
|
Before Width: | Height: | Size: 153 KiB |
@ -1,20 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.dalle.tools.dalle2 import DallE2Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class DALLEProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
DallE2Tool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "small", "n": 1},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,61 +0,0 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: dalle
|
||||
label:
|
||||
en_US: DALL-E
|
||||
zh_Hans: DALL-E 绘画
|
||||
pt_BR: DALL-E
|
||||
description:
|
||||
en_US: DALL-E art
|
||||
zh_Hans: DALL-E 绘画
|
||||
pt_BR: DALL-E art
|
||||
icon: icon.png
|
||||
tags:
|
||||
- image
|
||||
- productivity
|
||||
credentials_for_provider:
|
||||
openai_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: OpenAI API key
|
||||
zh_Hans: OpenAI API key
|
||||
pt_BR: OpenAI API key
|
||||
help:
|
||||
en_US: Please input your OpenAI API key
|
||||
zh_Hans: 请输入你的 OpenAI API key
|
||||
pt_BR: Please input your OpenAI API key
|
||||
placeholder:
|
||||
en_US: Please input your OpenAI API key
|
||||
zh_Hans: 请输入你的 OpenAI API key
|
||||
pt_BR: Please input your OpenAI API key
|
||||
openai_organization_id:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: OpenAI organization ID
|
||||
zh_Hans: OpenAI organization ID
|
||||
pt_BR: OpenAI organization ID
|
||||
help:
|
||||
en_US: Please input your OpenAI organization ID
|
||||
zh_Hans: 请输入你的 OpenAI organization ID
|
||||
pt_BR: Please input your OpenAI organization ID
|
||||
placeholder:
|
||||
en_US: Please input your OpenAI organization ID
|
||||
zh_Hans: 请输入你的 OpenAI organization ID
|
||||
pt_BR: Please input your OpenAI organization ID
|
||||
openai_base_url:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: OpenAI base URL
|
||||
zh_Hans: OpenAI base URL
|
||||
pt_BR: OpenAI base URL
|
||||
help:
|
||||
en_US: Please input your OpenAI base URL
|
||||
zh_Hans: 请输入你的 OpenAI base URL
|
||||
pt_BR: Please input your OpenAI base URL
|
||||
placeholder:
|
||||
en_US: Please input your OpenAI base URL
|
||||
zh_Hans: 请输入你的 OpenAI base URL
|
||||
pt_BR: Please input your OpenAI base URL
|
||||
@ -1,66 +0,0 @@
|
||||
from base64 import b64decode
|
||||
from typing import Any, Union
|
||||
|
||||
from openai import OpenAI
|
||||
from yarl import URL
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DallE2Tool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
openai_organization = self.runtime.credentials.get("openai_organization_id", None)
|
||||
if not openai_organization:
|
||||
openai_organization = None
|
||||
openai_base_url = self.runtime.credentials.get("openai_base_url", None)
|
||||
if not openai_base_url:
|
||||
openai_base_url = None
|
||||
else:
|
||||
openai_base_url = str(URL(openai_base_url) / "v1")
|
||||
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials["openai_api_key"],
|
||||
base_url=openai_base_url,
|
||||
organization=openai_organization,
|
||||
)
|
||||
|
||||
SIZE_MAPPING = {
|
||||
"small": "256x256",
|
||||
"medium": "512x512",
|
||||
"large": "1024x1024",
|
||||
}
|
||||
|
||||
# prompt
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message("Please input prompt")
|
||||
|
||||
# get size
|
||||
size = SIZE_MAPPING[tool_parameters.get("size", "large")]
|
||||
|
||||
# get n
|
||||
n = tool_parameters.get("n", 1)
|
||||
|
||||
# call openapi dalle2
|
||||
response = client.images.generate(prompt=prompt, model="dall-e-2", size=size, n=n, response_format="b64_json")
|
||||
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
result.append(
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image.b64_json),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
@ -1,74 +0,0 @@
|
||||
identity:
|
||||
name: dalle2
|
||||
author: Dify
|
||||
label:
|
||||
en_US: DALL-E 2
|
||||
zh_Hans: DALL-E 2 绘画
|
||||
description:
|
||||
en_US: DALL-E 2 is a powerful drawing tool that can draw the image you want based on your prompt
|
||||
zh_Hans: DALL-E 2 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像
|
||||
pt_BR: DALL-E 2 is a powerful drawing tool that can draw the image you want based on your prompt
|
||||
description:
|
||||
human:
|
||||
en_US: DALL-E is a text to image tool
|
||||
zh_Hans: DALL-E 是一个文本到图像的工具
|
||||
pt_BR: DALL-E is a text to image tool
|
||||
llm: DALL-E is a tool used to generate images from text
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: Image prompt, you can check the official documentation of DallE 2
|
||||
zh_Hans: 图像提示词,您可以查看 DallE 2 的官方文档
|
||||
pt_BR: Image prompt, you can check the official documentation of DallE 2
|
||||
llm_description: Image prompt of DallE 2, you should describe the image you want to generate as a list of words as possible as detailed
|
||||
form: llm
|
||||
- name: size
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: used for selecting the image size
|
||||
zh_Hans: 用于选择图像大小
|
||||
pt_BR: used for selecting the image size
|
||||
label:
|
||||
en_US: Image size
|
||||
zh_Hans: 图像大小
|
||||
pt_BR: Image size
|
||||
form: form
|
||||
options:
|
||||
- value: small
|
||||
label:
|
||||
en_US: Small(256x256)
|
||||
zh_Hans: 小(256x256)
|
||||
pt_BR: Small(256x256)
|
||||
- value: medium
|
||||
label:
|
||||
en_US: Medium(512x512)
|
||||
zh_Hans: 中(512x512)
|
||||
pt_BR: Medium(512x512)
|
||||
- value: large
|
||||
label:
|
||||
en_US: Large(1024x1024)
|
||||
zh_Hans: 大(1024x1024)
|
||||
pt_BR: Large(1024x1024)
|
||||
default: large
|
||||
- name: n
|
||||
type: number
|
||||
required: true
|
||||
human_description:
|
||||
en_US: used for selecting the number of images
|
||||
zh_Hans: 用于选择图像数量
|
||||
pt_BR: used for selecting the number of images
|
||||
label:
|
||||
en_US: Number of images
|
||||
zh_Hans: 图像数量
|
||||
pt_BR: Number of images
|
||||
form: form
|
||||
default: 1
|
||||
min: 1
|
||||
max: 10
|
||||
@ -1,115 +0,0 @@
|
||||
import base64
|
||||
import random
|
||||
from typing import Any, Union
|
||||
|
||||
from openai import OpenAI
|
||||
from yarl import URL
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DallE3Tool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
openai_organization = self.runtime.credentials.get("openai_organization_id", None)
|
||||
if not openai_organization:
|
||||
openai_organization = None
|
||||
openai_base_url = self.runtime.credentials.get("openai_base_url", None)
|
||||
if not openai_base_url:
|
||||
openai_base_url = None
|
||||
else:
|
||||
openai_base_url = str(URL(openai_base_url) / "v1")
|
||||
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials["openai_api_key"],
|
||||
base_url=openai_base_url,
|
||||
organization=openai_organization,
|
||||
)
|
||||
|
||||
SIZE_MAPPING = {
|
||||
"square": "1024x1024",
|
||||
"vertical": "1024x1792",
|
||||
"horizontal": "1792x1024",
|
||||
}
|
||||
|
||||
# prompt
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message("Please input prompt")
|
||||
# get size
|
||||
size = SIZE_MAPPING[tool_parameters.get("size", "square")]
|
||||
# get n
|
||||
n = tool_parameters.get("n", 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
if quality not in {"standard", "hd"}:
|
||||
return self.create_text_message("Invalid quality")
|
||||
# get style
|
||||
style = tool_parameters.get("style", "vivid")
|
||||
if style not in {"natural", "vivid"}:
|
||||
return self.create_text_message("Invalid style")
|
||||
|
||||
# call openapi dalle3
|
||||
response = client.images.generate(
|
||||
prompt=prompt, model="dall-e-3", size=size, n=n, style=style, quality=quality, response_format="b64_json"
|
||||
)
|
||||
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
|
||||
blob_message = self.create_blob_message(
|
||||
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
|
||||
)
|
||||
result.append(blob_message)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _decode_image(base64_image: str) -> tuple[str, bytes]:
|
||||
"""
|
||||
Decode a base64 encoded image. If the image is not prefixed with a MIME type,
|
||||
it assumes 'image/png' as the default.
|
||||
|
||||
:param base64_image: Base64 encoded image string
|
||||
:return: A tuple containing the MIME type and the decoded image bytes
|
||||
"""
|
||||
if DallE3Tool._is_plain_base64(base64_image):
|
||||
return "image/png", base64.b64decode(base64_image)
|
||||
else:
|
||||
return DallE3Tool._extract_mime_and_data(base64_image)
|
||||
|
||||
@staticmethod
|
||||
def _is_plain_base64(encoded_str: str) -> bool:
|
||||
"""
|
||||
Check if the given encoded string is plain base64 without a MIME type prefix.
|
||||
|
||||
:param encoded_str: Base64 encoded image string
|
||||
:return: True if the string is plain base64, False otherwise
|
||||
"""
|
||||
return not encoded_str.startswith("data:image")
|
||||
|
||||
@staticmethod
|
||||
def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]:
|
||||
"""
|
||||
Extract MIME type and image data from a base64 encoded string with a MIME type prefix.
|
||||
|
||||
:param encoded_str: Base64 encoded image string with MIME type prefix
|
||||
:return: A tuple containing the MIME type and the decoded image bytes
|
||||
"""
|
||||
mime_type = encoded_str.split(";")[0].split(":")[1]
|
||||
image_data_base64 = encoded_str.split(",")[1]
|
||||
decoded_data = base64.b64decode(image_data_base64)
|
||||
return mime_type, decoded_data
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
random_id = "".join(random.choices(characters, k=length))
|
||||
return random_id
|
||||
@ -1,123 +0,0 @@
|
||||
identity:
|
||||
name: dalle3
|
||||
author: Dify
|
||||
label:
|
||||
en_US: DALL-E 3
|
||||
zh_Hans: DALL-E 3 绘画
|
||||
pt_BR: DALL-E 3
|
||||
description:
|
||||
en_US: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources
|
||||
zh_Hans: DALL-E 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像,相比于DallE 2, DallE 3拥有更强的绘画能力,但会消耗更多的资源
|
||||
pt_BR: DALL-E 3 is a powerful drawing tool that can draw the image you want based on your prompt, compared to DallE 2, DallE 3 has stronger drawing ability, but it will consume more resources
|
||||
description:
|
||||
human:
|
||||
en_US: DALL-E is a text to image tool
|
||||
zh_Hans: DALL-E 是一个文本到图像的工具
|
||||
pt_BR: DALL-E is a text to image tool
|
||||
llm: DALL-E is a tool used to generate images from text
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: Image prompt, you can check the official documentation of DallE 3
|
||||
zh_Hans: 图像提示词,您可以查看 DallE 3 的官方文档
|
||||
pt_BR: Image prompt, you can check the official documentation of DallE 3
|
||||
llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed
|
||||
form: llm
|
||||
- name: size
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image size
|
||||
zh_Hans: 选择图像大小
|
||||
pt_BR: selecting the image size
|
||||
label:
|
||||
en_US: Image size
|
||||
zh_Hans: 图像大小
|
||||
pt_BR: Image size
|
||||
form: form
|
||||
options:
|
||||
- value: square
|
||||
label:
|
||||
en_US: Squre(1024x1024)
|
||||
zh_Hans: 方(1024x1024)
|
||||
pt_BR: Squre(1024x1024)
|
||||
- value: vertical
|
||||
label:
|
||||
en_US: Vertical(1024x1792)
|
||||
zh_Hans: 竖屏(1024x1792)
|
||||
pt_BR: Vertical(1024x1792)
|
||||
- value: horizontal
|
||||
label:
|
||||
en_US: Horizontal(1792x1024)
|
||||
zh_Hans: 横屏(1792x1024)
|
||||
pt_BR: Horizontal(1792x1024)
|
||||
default: square
|
||||
- name: n
|
||||
type: number
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the number of images
|
||||
zh_Hans: 选择图像数量
|
||||
pt_BR: selecting the number of images
|
||||
label:
|
||||
en_US: Number of images
|
||||
zh_Hans: 图像数量
|
||||
pt_BR: Number of images
|
||||
form: form
|
||||
min: 1
|
||||
max: 1
|
||||
default: 1
|
||||
- name: quality
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image quality
|
||||
zh_Hans: 选择图像质量
|
||||
pt_BR: selecting the image quality
|
||||
label:
|
||||
en_US: Image quality
|
||||
zh_Hans: 图像质量
|
||||
pt_BR: Image quality
|
||||
form: form
|
||||
options:
|
||||
- value: standard
|
||||
label:
|
||||
en_US: Standard
|
||||
zh_Hans: 标准
|
||||
pt_BR: Standard
|
||||
- value: hd
|
||||
label:
|
||||
en_US: HD
|
||||
zh_Hans: 高清
|
||||
pt_BR: HD
|
||||
default: standard
|
||||
- name: style
|
||||
type: select
|
||||
required: true
|
||||
human_description:
|
||||
en_US: selecting the image style
|
||||
zh_Hans: 选择图像风格
|
||||
pt_BR: selecting the image style
|
||||
label:
|
||||
en_US: Image style
|
||||
zh_Hans: 图像风格
|
||||
pt_BR: Image style
|
||||
form: form
|
||||
options:
|
||||
- value: vivid
|
||||
label:
|
||||
en_US: Vivid
|
||||
zh_Hans: 生动
|
||||
pt_BR: Vivid
|
||||
- value: natural
|
||||
label:
|
||||
en_US: Natural
|
||||
zh_Hans: 自然
|
||||
pt_BR: Natural
|
||||
default: vivid
|
||||
@ -1,4 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M15.6111 1.5837C17.2678 1.34703 18.75 2.63255 18.75 4.30606V5.68256C19.9395 6.31131 20.75 7.56102 20.75 9.00004V19C20.75 21.0711 19.0711 22.75 17 22.75H7C4.92893 22.75 3.25 21.0711 3.25 19V5.00004C3.25 4.99074 3.25017 4.98148 3.2505 4.97227C3.25017 4.95788 3.25 4.94344 3.25 4.92897C3.25 4.02272 3.91638 3.25437 4.81353 3.12621L15.6111 1.5837ZM4.75 6.75004V19C4.75 20.2427 5.75736 21.25 7 21.25H17C18.2426 21.25 19.25 20.2427 19.25 19V9.00004C19.25 7.7574 18.2426 6.75004 17 6.75004H4.75ZM5.07107 5.25004H17.25V4.30606C17.25 3.54537 16.5763 2.96104 15.8232 3.06862L5.02566 4.61113C4.86749 4.63373 4.75 4.76919 4.75 4.92897C4.75 5.10629 4.89375 5.25004 5.07107 5.25004ZM7.25 12C7.25 11.5858 7.58579 11.25 8 11.25H16C16.4142 11.25 16.75 11.5858 16.75 12C16.75 12.4143 16.4142 12.75 16 12.75H8C7.58579 12.75 7.25 12.4143 7.25 12ZM7.25 15.5C7.25 15.0858 7.58579 14.75 8 14.75H13.5C13.9142 14.75 14.25 15.0858 14.25 15.5C14.25 15.9143 13.9142 16.25 13.5 16.25H8C7.58579 16.25 7.25 15.9143 7.25 15.5Z" fill="#1C274D"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.2 KiB |
@ -1,21 +0,0 @@
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.devdocs.tools.searchDevDocs import SearchDevDocsTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class DevDocsProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
SearchDevDocsTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"doc": "python~3.12",
|
||||
"topic": "library/code",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,13 +0,0 @@
|
||||
identity:
|
||||
author: Richards Tu
|
||||
name: devdocs
|
||||
label:
|
||||
en_US: DevDocs
|
||||
zh_Hans: DevDocs
|
||||
description:
|
||||
en_US: Get official developer documentations on DevDocs.
|
||||
zh_Hans: 从DevDocs获取官方开发者文档。
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
- productivity
|
||||
@ -1,47 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class SearchDevDocsInput(BaseModel):
|
||||
doc: str = Field(..., description="The name of the documentation.")
|
||||
topic: str = Field(..., description="The path of the section/topic.")
|
||||
|
||||
|
||||
class SearchDevDocsTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invokes the DevDocs search tool with the given user ID and tool parameters.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user invoking the tool.
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool, including 'doc' and 'topic'.
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
|
||||
which can be a single message or a list of messages.
|
||||
"""
|
||||
doc = tool_parameters.get("doc", "")
|
||||
topic = tool_parameters.get("topic", "")
|
||||
|
||||
if not doc:
|
||||
return self.create_text_message("Please provide the documentation name.")
|
||||
if not topic:
|
||||
return self.create_text_message("Please provide the topic path.")
|
||||
|
||||
url = f"https://documents.devdocs.io/{doc}/{topic}.html"
|
||||
response = requests.get(url)
|
||||
|
||||
if response.status_code == 200:
|
||||
content = response.text
|
||||
return self.create_text_message(self.summary(user_id=user_id, content=content))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to retrieve the documentation. Status code: {response.status_code}"
|
||||
)
|
||||
@ -1,34 +0,0 @@
|
||||
identity:
|
||||
name: searchDevDocs
|
||||
author: Richards Tu
|
||||
label:
|
||||
en_US: Search Developer Docs
|
||||
zh_Hans: 搜索开发者文档
|
||||
description:
|
||||
human:
|
||||
en_US: A tools for searching for a specific topic and path in DevDocs based on the provided documentation name and topic. Don't for get to add some shots in the system prompt; for example, the documentation name should be like \"vuex~4\", \"css\", or \"python~3.12\", while the topic should be like \"guide/actions\" for Vuex 4, \"display-box\" for CSS, or \"library/code\" for Python 3.12.
|
||||
zh_Hans: 一个用于根据提供的文档名称和主题,在DevDocs中搜索特定主题和路径的工具。不要忘记在系统提示词中添加一些示例;例如,文档名称应该是\"vuex~4\"、\"css\"或\"python~3.12\",而主题应该是\"guide/actions\"用于Vuex 4,\"display-box\"用于CSS,或\"library/code\"用于Python 3.12。
|
||||
llm: A tools for searching for specific developer documentation in DevDocs based on the provided documentation name and topic.
|
||||
parameters:
|
||||
- name: doc
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Documentation name
|
||||
zh_Hans: 文档名称
|
||||
human_description:
|
||||
en_US: The name of the documentation.
|
||||
zh_Hans: 文档名称。
|
||||
llm_description: The name of the documentation, such as \"vuex~4\", \"css\", or \"python~3.12\". The exact value should be identified by the user.
|
||||
form: llm
|
||||
- name: topic
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Topic name
|
||||
zh_Hans: 主题名称
|
||||
human_description:
|
||||
en_US: The path of the section/topic.
|
||||
zh_Hans: 文档主题的路径。
|
||||
llm_description: The path of the section/topic, such as \"guide/actions\" for Vuex 4, \"display-box\" for CSS, or \"library/code\" for Python 3.12.
|
||||
form: llm
|
||||
@ -1,14 +0,0 @@
|
||||
<svg viewBox="0 0 40 40" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect x="0" y="0" width="40" height="40" style="fill: #0d0a08;"></rect>
|
||||
<g clip-path="url(#clip0_269_13)" transform="matrix(0.429227, 0, 0, 0.429227, 6.326543, 9.593137)" style="background-color: f4f3f2">
|
||||
<path d="M6.05159 7.04111H0.5V44.0227H6.05159C13.5 44.0227 16.6023 42.1692 16.6023 34.1718V16.8831C16.6023 8.791 13.503 7.03223 6.05159 7.03223V7.03815V7.04111ZM11.9755 34.1718C11.9755 38.7019 10.5898 39.3948 6.09591 39.3948H5.12091V11.6601H6.09591C10.5839 11.6601 11.9755 12.353 11.9755 16.8831V34.1718Z" fill="white"></path>
|
||||
<path d="M18.9834 26.2188V29.9169H25.9207V26.2188H18.9834Z" fill="white"></path>
|
||||
<path d="M28.562 13.9783V44.0225H33.1888V13.9783H28.562Z" fill="white"></path>
|
||||
<path d="M41.3822 7.04111H35.8306V44.0227H41.3822C48.8306 44.0227 51.9358 42.1692 51.9358 34.1718V16.8831C51.9358 8.791 48.8365 7.03223 41.3822 7.03223V7.03815V7.04111ZM47.306 34.1718C47.306 38.7019 45.9203 39.3948 41.4265 39.3948H40.4515V11.6601H41.4265C45.9144 11.6601 47.306 12.353 47.306 16.8831V34.1718Z" fill="white"></path>
|
||||
<path d="M30.8758 11.2278C32.2775 11.2278 33.4138 10.0917 33.4138 8.69032C33.4138 7.2889 32.2775 6.15283 30.8758 6.15283C29.4742 6.15283 28.3379 7.2889 28.3379 8.69032C28.3379 10.0917 29.4742 11.2278 30.8758 11.2278Z" fill="#FF882E"></path>
|
||||
<path d="M36.191 4.02677C36.9621 4.02677 37.5885 3.40202 37.5885 2.62923C37.5885 1.85644 36.9621 1.23169 36.191 1.23169C35.4198 1.23169 34.7935 1.85644 34.7935 2.62923C34.7935 3.40202 35.4198 4.02677 36.191 4.02677Z" fill="#FF882E"></path>
|
||||
<path d="M42.1978 2.09631C42.7769 2.09631 43.2467 1.62553 43.2467 1.04816C43.2467 0.470782 42.7769 0 42.1978 0C41.6187 0 41.1489 0.470782 41.1489 1.04816C41.1489 1.62553 41.6187 2.09631 42.1978 2.09631Z" fill="#FF882E"></path>
|
||||
<path d="M47.8467 3.14734C48.4258 3.14734 48.8956 2.67656 48.8956 2.09918C48.8956 1.52181 48.4258 1.05103 47.8467 1.05103C47.2676 1.05103 46.7979 1.52181 46.7979 2.09918C46.7979 2.67656 47.2676 3.14734 47.8467 3.14734Z" fill="#FF882E"></path>
|
||||
<path d="M55.9065 53C54.7276 53 53.729 52.6239 53.0081 52.3515L52.7422 52.2538C51.5367 51.8156 50.3726 51.3774 49.2854 50.951C48.6826 50.7142 48.3842 50.0332 48.6206 49.4291C48.857 48.8251 49.5395 48.529 50.1422 48.7659C51.2117 49.1863 52.3581 49.6157 53.5488 50.048C53.6433 50.0835 53.7408 50.119 53.8383 50.1575C54.6449 50.4625 55.5608 50.8089 56.5654 50.5839C57.4635 50.3825 58.0219 50.0391 58.3144 49.5091C58.5892 49.0117 58.5035 48.6593 58.3144 48.0227C58.1549 47.4897 57.9599 46.8265 58.214 46.107C58.4976 45.3016 59.0738 44.9078 59.4963 44.6206C59.8833 44.3542 59.9631 44.2831 60.0074 44.0581C60.1049 43.5606 59.8272 43.3001 59.7297 43.2261C59.2895 43.0662 58.9763 42.6516 58.9585 42.1661C58.9379 41.5916 59.3338 41.0883 59.8951 40.9728C59.9956 40.9521 60.2999 40.8899 60.5451 40.6412C60.6722 40.5139 60.8908 40.2474 60.9115 39.8891C60.9292 39.5605 60.7549 39.291 60.4683 38.8913C60.1492 38.4501 59.5554 37.627 60.1049 36.7032C60.5392 35.9719 61.4581 35.5899 62.1967 35.282C62.3504 35.2168 62.4892 35.1606 62.5956 35.1103C63.0388 34.8911 63.2338 34.6484 63.1392 33.8696C63.1097 33.6357 63.0004 33.4492 62.9295 33.3485C61.9456 32.0813 61.0297 30.8081 60.1315 29.4579C59.3397 28.2617 58.8079 27.3823 58.601 26.1328C58.3913 24.8804 59.0472 22.5916 59.124 22.334C59.907 19.6692 59.9424 17.641 58.0367 13.321C56.5979 10.064 54.376 7.8345 52.7658 6.53762C52.2606 6.13198 52.1808 5.39176 52.5885 4.88841C52.9963 4.38209 53.7349 4.30215 54.2401 4.71075C56.8401 6.8041 58.8935 9.4541 60.1847 12.3735C62.1435 16.8119 62.4331 19.3938 61.3783 22.9943C61.1006 23.9388 60.8465 25.3008 60.9204 25.7479C61.0415 26.4792 61.3192 26.9915 62.0904 28.161C62.959 29.4668 63.8454 30.6985 64.7997 31.9243C64.8085 31.9362 64.8174 31.9451 64.8233 31.9569C65.0685 32.2944 65.3788 32.8511 65.4704 33.5824C65.7244 35.7084 64.6135 36.7299 63.6385 37.2125C63.4967 37.2835 63.3076 37.3635 63.1008 37.4494C62.9531 37.5115 62.7226 37.6063 62.5129 37.707C62.8645 38.2103 63.3195 38.9742 63.2604 40.0165C63.2131 40.8603 62.8408 41.6716 62.2115 42.2993C62.1613 42.3496 62.1081 42.4 62.052 42.4473C62.3622 43.0721 62.4567 43.7857 62.3179 44.5022C62.0845 45.6954 61.2956 46.2343 60.8258 46.5541C60.6249 46.6903 60.492 46.7851 60.4476 46.8561C60.4565 46.9597 60.5245 47.1818 60.5717 47.3476C60.7845 48.0612 61.139 49.2574 60.3767 50.6372C59.7533 51.7682 58.6454 52.5173 57.0883 52.8667C56.6806 52.9585 56.2876 52.997 55.9124 52.997L55.9065 53Z" fill="#FF882E"></path>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 4.4 KiB |
@ -1,18 +0,0 @@
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.did.tools.talks import TalksTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class DIDProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
# Example validation using the D-ID talks tool
|
||||
TalksTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"source_url": "https://www.d-id.com/wp-content/uploads/2023/11/Hero-image-1.png",
|
||||
"text_input": "Hello, welcome to use D-ID tool in Dify",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,28 +0,0 @@
|
||||
identity:
|
||||
author: Matri Qi
|
||||
name: did
|
||||
label:
|
||||
en_US: D-ID
|
||||
description:
|
||||
en_US: D-ID is a tool enabling the creation of high-quality, custom videos of Digital Humans from a single image.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- videos
|
||||
credentials_for_provider:
|
||||
did_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: D-ID API Key
|
||||
placeholder:
|
||||
en_US: Please input your D-ID API key
|
||||
help:
|
||||
en_US: Get your D-ID API key from your D-ID account settings.
|
||||
url: https://studio.d-id.com/account-settings
|
||||
base_url:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: D-ID server's Base URL
|
||||
placeholder:
|
||||
en_US: https://api.d-id.com
|
||||
@ -1,87 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DIDApp:
|
||||
def __init__(self, api_key: str | None = None, base_url: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url or "https://api.d-id.com"
|
||||
if not self.api_key:
|
||||
raise ValueError("API key is required")
|
||||
|
||||
def _prepare_headers(self, idempotency_key: str | None = None):
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Basic {self.api_key}"}
|
||||
if idempotency_key:
|
||||
headers["Idempotency-Key"] = idempotency_key
|
||||
return headers
|
||||
|
||||
def _request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
data: Mapping[str, Any] | None = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
retries: int = 3,
|
||||
backoff_factor: float = 0.3,
|
||||
) -> Mapping[str, Any] | None:
|
||||
for i in range(retries):
|
||||
try:
|
||||
response = requests.request(method, url, json=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500:
|
||||
time.sleep(backoff_factor * (2**i))
|
||||
else:
|
||||
raise
|
||||
return None
|
||||
|
||||
def talks(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs):
|
||||
endpoint = f"{self.base_url}/talks"
|
||||
headers = self._prepare_headers(idempotency_key)
|
||||
data = kwargs["params"]
|
||||
logger.debug(f"Send request to {endpoint=} body={data}")
|
||||
response = self._request("POST", endpoint, data, headers)
|
||||
if response is None:
|
||||
raise HTTPError("Failed to initiate D-ID talks after multiple retries")
|
||||
id: str = response["id"]
|
||||
if wait:
|
||||
return self._monitor_job_status(id=id, target="talks", poll_interval=poll_interval)
|
||||
return id
|
||||
|
||||
def animations(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs):
|
||||
endpoint = f"{self.base_url}/animations"
|
||||
headers = self._prepare_headers(idempotency_key)
|
||||
data = kwargs["params"]
|
||||
logger.debug(f"Send request to {endpoint=} body={data}")
|
||||
response = self._request("POST", endpoint, data, headers)
|
||||
if response is None:
|
||||
raise HTTPError("Failed to initiate D-ID talks after multiple retries")
|
||||
id: str = response["id"]
|
||||
if wait:
|
||||
return self._monitor_job_status(target="animations", id=id, poll_interval=poll_interval)
|
||||
return id
|
||||
|
||||
def check_did_status(self, target: str, id: str):
|
||||
endpoint = f"{self.base_url}/{target}/{id}"
|
||||
headers = self._prepare_headers()
|
||||
response = self._request("GET", endpoint, headers=headers)
|
||||
if response is None:
|
||||
raise HTTPError(f"Failed to check status for talks {id} after multiple retries")
|
||||
return response
|
||||
|
||||
def _monitor_job_status(self, target: str, id: str, poll_interval: int):
|
||||
while True:
|
||||
status = self.check_did_status(target=target, id=id)
|
||||
if status["status"] == "done":
|
||||
return status
|
||||
elif status["status"] == "error" or status["status"] == "rejected":
|
||||
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}')
|
||||
time.sleep(poll_interval)
|
||||
@ -1,49 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.did.did_appx import DIDApp
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AnimationsTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"])
|
||||
|
||||
driver_expressions_str = tool_parameters.get("driver_expressions")
|
||||
driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None
|
||||
|
||||
config = {
|
||||
"stitch": tool_parameters.get("stitch", True),
|
||||
"mute": tool_parameters.get("mute"),
|
||||
"result_format": tool_parameters.get("result_format") or "mp4",
|
||||
}
|
||||
config = {k: v for k, v in config.items() if v is not None and v != ""}
|
||||
|
||||
options = {
|
||||
"source_url": tool_parameters["source_url"],
|
||||
"driver_url": tool_parameters.get("driver_url"),
|
||||
"config": config,
|
||||
}
|
||||
options = {k: v for k, v in options.items() if v is not None and v != ""}
|
||||
|
||||
if not options.get("source_url"):
|
||||
raise ValueError("Source URL is required")
|
||||
|
||||
if config.get("logo_url"):
|
||||
if not config.get("logo_x"):
|
||||
raise ValueError("Logo X position is required when logo URL is provided")
|
||||
if not config.get("logo_y"):
|
||||
raise ValueError("Logo Y position is required when logo URL is provided")
|
||||
|
||||
animations_result = app.animations(params=options, wait=True)
|
||||
|
||||
if not isinstance(animations_result, str):
|
||||
animations_result = json.dumps(animations_result, ensure_ascii=False, indent=4)
|
||||
|
||||
if not animations_result:
|
||||
return self.create_text_message("D-ID animations request failed.")
|
||||
|
||||
return self.create_text_message(animations_result)
|
||||