Introduce Plugins (#13836)
Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: xhe <xw897002528@gmail.com> Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: takatost <takatost@gmail.com> Co-authored-by: kurokobo <kuro664@gmail.com> Co-authored-by: Novice Lee <novicelee@NoviPro.local> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: AkaraChen <akarachen@outlook.com> Co-authored-by: Yi <yxiaoisme@gmail.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: Hiroshi Fujita <fujita-h@users.noreply.github.com> Co-authored-by: AkaraChen <85140972+AkaraChen@users.noreply.github.com> Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: Novice <857526207@qq.com> Co-authored-by: Hiroki Nagai <82458324+nagaihiroki-git@users.noreply.github.com> Co-authored-by: Gen Sato <52241300+halogen22@users.noreply.github.com> Co-authored-by: eux <euxuuu@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: lotsik <lotsik@mail.ru> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: gakkiyomi <gakkiyomi@aliyun.com> Co-authored-by: CN-P5 <heibai2006@gmail.com> Co-authored-by: CN-P5 <heibai2006@qq.com> Co-authored-by: Chuehnone <1897025+chuehnone@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Kevin9703 <51311316+Kevin9703@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: Boris Feld <lothiraldan@gmail.com> Co-authored-by: mbo <himabo@gmail.com> Co-authored-by: mabo <mabo@aeyes.ai> Co-authored-by: Warren Chen <warren.chen830@gmail.com> Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com> Co-authored-by: jiandanfeng <chenjh3@wangsu.com> Co-authored-by: zhu-an <70234959+xhdd123321@users.noreply.github.com> Co-authored-by: zhaoqingyu.1075 <zhaoqingyu.1075@bytedance.com> Co-authored-by: 海狸大師 <86974027+yenslife@users.noreply.github.com> Co-authored-by: Xu Song <xusong.vip@gmail.com> Co-authored-by: rayshaw001 <396301947@163.com> Co-authored-by: Ding Jiatong <dingjiatong@gmail.com> Co-authored-by: Bowen Liang <liangbowen@gf.com.cn> Co-authored-by: JasonVV <jasonwangiii@outlook.com> Co-authored-by: le0zh <newlight@qq.com> Co-authored-by: zhuxinliang <zhuxinliang@didiglobal.com> Co-authored-by: k-zaku <zaku99@outlook.jp> Co-authored-by: luckylhb90 <luckylhb90@gmail.com> Co-authored-by: hobo.l <hobo.l@binance.com> Co-authored-by: jiangbo721 <365065261@qq.com> Co-authored-by: 刘江波 <jiangbo721@163.com> Co-authored-by: Shun Miyazawa <34241526+miya@users.noreply.github.com> Co-authored-by: EricPan <30651140+Egfly@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: sino <sino2322@gmail.com> Co-authored-by: Jhvcc <37662342+Jhvcc@users.noreply.github.com> Co-authored-by: lowell <lowell.hu@zkteco.in> Co-authored-by: Boris Polonsky <BorisPolonsky@users.noreply.github.com> Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com> Co-authored-by: Ademílson Tonato <ademilson.tonato@refurbed.com> Co-authored-by: IWAI, Masaharu <iwaim.sub@gmail.com> Co-authored-by: Yueh-Po Peng (Yabi) <94939112+y10ab1@users.noreply.github.com> Co-authored-by: Jason <ggbbddjm@gmail.com> Co-authored-by: Xin Zhang <sjhpzx@gmail.com> Co-authored-by: yjc980121 <3898524+yjc980121@users.noreply.github.com> Co-authored-by: heyszt <36215648+hieheihei@users.noreply.github.com> Co-authored-by: Abdullah AlOsaimi <osaimiacc@gmail.com> Co-authored-by: Abdullah AlOsaimi <189027247+osaimi@users.noreply.github.com> Co-authored-by: Yingchun Lai <laiyingchun@apache.org> Co-authored-by: Hash Brown <hi@xzd.me> Co-authored-by: zuodongxu <192560071+zuodongxu@users.noreply.github.com> Co-authored-by: Masashi Tomooka <tmokmss@users.noreply.github.com> Co-authored-by: aplio <ryo.091219@gmail.com> Co-authored-by: Obada Khalili <54270856+obadakhalili@users.noreply.github.com> Co-authored-by: Nam Vu <zuzoovn@gmail.com> Co-authored-by: Kei YAMAZAKI <1715090+kei-yamazaki@users.noreply.github.com> Co-authored-by: TechnoHouse <13776377+deephbz@users.noreply.github.com> Co-authored-by: Riddhimaan-Senapati <114703025+Riddhimaan-Senapati@users.noreply.github.com> Co-authored-by: MaFee921 <31881301+2284730142@users.noreply.github.com> Co-authored-by: te-chan <t-nakanome@sakura-is.co.jp> Co-authored-by: HQidea <HQidea@users.noreply.github.com> Co-authored-by: Joshbly <36315710+Joshbly@users.noreply.github.com> Co-authored-by: xhe <xw897002528@gmail.com> Co-authored-by: weiwenyan-dev <154779315+weiwenyan-dev@users.noreply.github.com> Co-authored-by: ex_wenyan.wei <ex_wenyan.wei@tcl.com> Co-authored-by: engchina <12236799+engchina@users.noreply.github.com> Co-authored-by: engchina <atjapan2015@gmail.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: 呆萌闷油瓶 <253605712@qq.com> Co-authored-by: Kemal <kemalmeler@outlook.com> Co-authored-by: Lazy_Frog <4590648+lazyFrogLOL@users.noreply.github.com> Co-authored-by: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Co-authored-by: Steven sun <98230804+Tuyohai@users.noreply.github.com> Co-authored-by: steven <sunzwj@digitalchina.com> Co-authored-by: Kalo Chin <91766386+fdb02983rhy@users.noreply.github.com> Co-authored-by: Katy Tao <34019945+KatyTao@users.noreply.github.com> Co-authored-by: depy <42985524+h4ckdepy@users.noreply.github.com> Co-authored-by: 胡春东 <gycm520@gmail.com> Co-authored-by: Junjie.M <118170653@qq.com> Co-authored-by: MuYu <mr.muzea@gmail.com> Co-authored-by: Naoki Takashima <39912547+takatea@users.noreply.github.com> Co-authored-by: Summer-Gu <37869445+gubinjie@users.noreply.github.com> Co-authored-by: Fei He <droxer.he@gmail.com> Co-authored-by: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Co-authored-by: Yuanbo Li <ybalbert@amazon.com> Co-authored-by: douxc <7553076+douxc@users.noreply.github.com> Co-authored-by: liuzhenghua <1090179900@qq.com> Co-authored-by: Wu Jiayang <62842862+Wu-Jiayang@users.noreply.github.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: kimjion <45935338+kimjion@users.noreply.github.com> Co-authored-by: AugNSo <song.tiankai@icloud.com> Co-authored-by: llinvokerl <38915183+llinvokerl@users.noreply.github.com> Co-authored-by: liusurong.lsr <liusurong.lsr@alibaba-inc.com> Co-authored-by: Vasu Negi <vasu-negi@users.noreply.github.com> Co-authored-by: Hundredwz <1808096180@qq.com> Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
223
api/core/tools/__base/tool.py
Normal file
@ -0,0 +1,223 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import File
|
||||
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""
|
||||
The base class of a tool
|
||||
"""
|
||||
|
||||
entity: ToolEntity
|
||||
runtime: ToolRuntime
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
:param meta: the meta data of a tool call processing, tenant_id is required
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
get the tool provider type
|
||||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage]:
|
||||
if self.runtime and self.runtime.runtime_parameters:
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
|
||||
# try parse tool parameters into the correct type
|
||||
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
|
||||
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
if isinstance(result, ToolInvokeMessage):
|
||||
|
||||
def single_generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield result
|
||||
|
||||
return single_generator()
|
||||
elif isinstance(result, list):
|
||||
|
||||
def generator() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield from result
|
||||
|
||||
return generator()
|
||||
else:
|
||||
return result
|
||||
|
||||
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Transform tool parameters type
|
||||
"""
|
||||
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
|
||||
result = deepcopy(tool_parameters)
|
||||
for parameter in self.entity.parameters or []:
|
||||
if parameter.name in tool_parameters:
|
||||
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
|
||||
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
|
||||
pass
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters
|
||||
|
||||
interface for developer to dynamic change the parameters of a tool depends on the variables pool
|
||||
|
||||
:return: the runtime parameters
|
||||
"""
|
||||
return self.entity.parameters
|
||||
|
||||
def get_merged_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get merged runtime parameters
|
||||
|
||||
:return: merged runtime parameters
|
||||
"""
|
||||
parameters = self.entity.parameters
|
||||
parameters = parameters.copy()
|
||||
user_parameters = self.get_runtime_parameters() or []
|
||||
user_parameters = user_parameters.copy()
|
||||
|
||||
# override parameters
|
||||
for parameter in user_parameters:
|
||||
# check if parameter in tool parameters
|
||||
for tool_parameter in parameters:
|
||||
if tool_parameter.name == parameter.name:
|
||||
# override parameter
|
||||
tool_parameter.type = parameter.type
|
||||
tool_parameter.form = parameter.form
|
||||
tool_parameter.required = parameter.required
|
||||
tool_parameter.default = parameter.default
|
||||
tool_parameter.options = parameter.options
|
||||
tool_parameter.llm_description = parameter.llm_description
|
||||
break
|
||||
else:
|
||||
# add new parameter
|
||||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
||||
def create_image_message(
|
||||
self,
|
||||
image: str,
|
||||
) -> ToolInvokeMessage:
|
||||
"""
|
||||
create an image message
|
||||
|
||||
:param image: the url of the image
|
||||
:return: the image message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image)
|
||||
)
|
||||
|
||||
def create_file_message(self, file: "File") -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.FILE,
|
||||
message=ToolInvokeMessage.FileMessage(),
|
||||
meta={"file": file},
|
||||
)
|
||||
|
||||
def create_link_message(self, link: str) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a link message
|
||||
|
||||
:param link: the url of the link
|
||||
:return: the link message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link)
|
||||
)
|
||||
|
||||
def create_text_message(self, text: str) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a text message
|
||||
|
||||
:param text: the text
|
||||
:return: the text message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=ToolInvokeMessage.TextMessage(text=text),
|
||||
)
|
||||
|
||||
def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a blob message
|
||||
|
||||
:param blob: the blob
|
||||
:return: the blob message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.BLOB,
|
||||
message=ToolInvokeMessage.BlobMessage(blob=blob),
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def create_json_message(self, object: dict) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a json message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
|
||||
)
|
||||
109
api/core/tools/__base/tool_provider.py
Normal file
@ -0,0 +1,109 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolProviderEntity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class ToolProviderController(ABC):
|
||||
entity: ToolProviderEntity
|
||||
|
||||
def __init__(self, entity: ToolProviderEntity) -> None:
|
||||
self.entity = entity
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
return deepcopy(self.entity.credentials_schema)
|
||||
|
||||
@abstractmethod
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
"""
|
||||
returns a tool that the provider can provide
|
||||
|
||||
:return: tool
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
||||
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||
)
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if not credential_schema.required and credentials[credential_name] is None:
|
||||
continue
|
||||
|
||||
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
options = credential_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} should be one of {options}"
|
||||
)
|
||||
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema.required:
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||
|
||||
# the credential is not set currently, set the default value if needed
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type in {
|
||||
ProviderConfig.Type.SECRET_INPUT,
|
||||
ProviderConfig.Type.TEXT_INPUT,
|
||||
ProviderConfig.Type.SELECT,
|
||||
}:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
36
api/core/tools/__base/tool_runtime.py
Normal file
@ -0,0 +1,36 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from openai import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom
|
||||
|
||||
|
||||
class ToolRuntime(BaseModel):
|
||||
"""
|
||||
Meta data of a tool call processing
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
tool_id: Optional[str] = None
|
||||
invoke_from: Optional[InvokeFrom] = None
|
||||
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class FakeToolRuntime(ToolRuntime):
|
||||
"""
|
||||
Fake tool runtime for testing
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
tenant_id="fake_tenant_id",
|
||||
tool_id="fake_tool_id",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
credentials={},
|
||||
runtime_parameters={},
|
||||
)
|
||||
3
api/core/tools/builtin_tool/_position.yaml
Normal file
@ -0,0 +1,3 @@
|
||||
- code
|
||||
- time
|
||||
- qrcode
|
||||
173
api/core/tools/builtin_tool/provider.py
Normal file
@ -0,0 +1,173 @@
|
||||
from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
tools: list[BuiltinTool]
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
self.tools = []
|
||||
|
||||
# load provider yaml
|
||||
provider = self.__class__.__module__.split(".")[-1]
|
||||
yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml")
|
||||
try:
|
||||
provider_yaml = load_yaml_file(yaml_path, ignore_error=False)
|
||||
except Exception as e:
|
||||
raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}")
|
||||
|
||||
if "credentials_for_provider" in provider_yaml and provider_yaml["credentials_for_provider"] is not None:
|
||||
# set credentials name
|
||||
for credential_name in provider_yaml["credentials_for_provider"]:
|
||||
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
|
||||
|
||||
credentials_schema = []
|
||||
for credential in provider_yaml.get("credentials_for_provider", {}):
|
||||
credentials_schema.append(credential)
|
||||
|
||||
super().__init__(
|
||||
entity=ToolProviderEntity(
|
||||
identity=provider_yaml["identity"],
|
||||
credentials_schema=credentials_schema,
|
||||
),
|
||||
)
|
||||
|
||||
self._load_tools()
|
||||
|
||||
def _load_tools(self):
|
||||
provider = self.entity.identity.name
|
||||
tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")
|
||||
# get all the yaml files in the tool path
|
||||
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
|
||||
tools = []
|
||||
for tool_file in tool_files:
|
||||
# get tool name
|
||||
tool_name = tool_file.split(".")[0]
|
||||
tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
|
||||
|
||||
# get tool class, import the module
|
||||
assistant_tool_class: type[BuiltinTool] = load_single_subclass_from_source(
|
||||
module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
|
||||
script_path=path.join(
|
||||
path.dirname(path.realpath(__file__)),
|
||||
"builtin_tool",
|
||||
"providers",
|
||||
provider,
|
||||
"tools",
|
||||
f"{tool_name}.py",
|
||||
),
|
||||
parent_type=BuiltinTool,
|
||||
)
|
||||
tool["identity"]["provider"] = provider
|
||||
tools.append(
|
||||
assistant_tool_class(
|
||||
provider=provider,
|
||||
entity=ToolEntity(**tool),
|
||||
runtime=ToolRuntime(tenant_id=""),
|
||||
)
|
||||
)
|
||||
|
||||
self.tools = tools
|
||||
|
||||
def _get_builtin_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
:return: list of tools
|
||||
"""
|
||||
return self.tools
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
if not self.entity.credentials_schema:
|
||||
return []
|
||||
|
||||
return self.entity.credentials_schema.copy()
|
||||
|
||||
def get_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
returns a list of tools that the provider can provide
|
||||
|
||||
:return: list of tools
|
||||
"""
|
||||
return self._get_builtin_tools()
|
||||
|
||||
def get_tool(self, tool_name: str) -> BuiltinTool | None: # type: ignore
|
||||
"""
|
||||
returns the tool that the provider can provide
|
||||
"""
|
||||
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
"""
|
||||
returns whether the provider needs credentials
|
||||
|
||||
:return: whether the provider needs credentials
|
||||
"""
|
||||
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
@property
|
||||
def tool_labels(self) -> list[str]:
|
||||
"""
|
||||
returns the labels of the provider
|
||||
|
||||
:return: labels of the provider
|
||||
"""
|
||||
label_enums = self._get_tool_labels()
|
||||
return [default_tool_label_dict[label].name for label in label_enums]
|
||||
|
||||
def _get_tool_labels(self) -> list[ToolLabelEnum]:
|
||||
"""
|
||||
returns the labels of the provider
|
||||
"""
|
||||
return self.entity.identity.tags or []
|
||||
|
||||
def validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
# validate credentials format
|
||||
self.validate_credentials_format(credentials)
|
||||
|
||||
# validate credentials
|
||||
self._validate_credentials(user_id, credentials)
|
||||
|
||||
@abstractmethod
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
:param tool_name: the name of the tool, defined in `get_tools`
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
pass
|
||||
@ -1,18 +1,18 @@
|
||||
import os.path
|
||||
|
||||
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
|
||||
from core.tools.entities.api_entities import UserToolProvider
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
_position: dict[str, int] = {}
|
||||
|
||||
@classmethod
|
||||
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
||||
def sort(cls, providers: list[ToolProviderApiEntity]) -> list[ToolProviderApiEntity]:
|
||||
if not cls._position:
|
||||
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
def name_func(provider: UserToolProvider) -> str:
|
||||
def name_func(provider: ToolProviderApiEntity) -> str:
|
||||
return provider.name
|
||||
|
||||
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
|
||||
|
Before Width: | Height: | Size: 1.5 KiB After Width: | Height: | Size: 1.5 KiB |
8
api/core/tools/builtin_tool/providers/audio/audio.py
Normal file
@ -0,0 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AudioToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
@ -1,24 +1,34 @@
|
||||
import io
|
||||
from typing import Any
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.file.enums import FileType
|
||||
from core.file.file_manager import download
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class ASRTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
file = tool_parameters.get("audio_file")
|
||||
if file.type != FileType.AUDIO:
|
||||
return [self.create_text_message("not a valid audio file")]
|
||||
audio_binary = io.BytesIO(download(file))
|
||||
if file.type != FileType.AUDIO: # type: ignore
|
||||
yield self.create_text_message("not a valid audio file")
|
||||
return
|
||||
audio_binary = io.BytesIO(download(file)) # type: ignore
|
||||
audio_binary.name = "temp.mp3"
|
||||
provider, model = tool_parameters.get("model").split("#")
|
||||
provider, model = tool_parameters.get("model").split("#") # type: ignore
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
@ -30,7 +40,7 @@ class ASRTool(BuiltinTool):
|
||||
file=audio_binary,
|
||||
user=user_id,
|
||||
)
|
||||
return [self.create_text_message(text)]
|
||||
yield self.create_text_message(text)
|
||||
|
||||
def get_available_models(self) -> list[tuple[str, str]]:
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -44,12 +54,17 @@ class ASRTool(BuiltinTool):
|
||||
items.append((provider, model.model))
|
||||
return items
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
parameters = []
|
||||
|
||||
options = []
|
||||
for provider, model in self.get_available_models():
|
||||
option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
|
||||
option = PluginParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
|
||||
options.append(option)
|
||||
|
||||
parameters.append(
|
||||
@ -1,18 +1,27 @@
|
||||
import io
|
||||
from typing import Any
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
class TTSTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
||||
provider, model = tool_parameters.get("model", "").split("#")
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}", "")
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
provider, model = tool_parameters.get("model").split("#") # type: ignore
|
||||
voice = tool_parameters.get(f"voice#{provider}#{model}")
|
||||
model_manager = ModelManager()
|
||||
if not self.runtime:
|
||||
raise ValueError("Runtime is required")
|
||||
@ -23,24 +32,21 @@ class TTSTool(BuiltinTool):
|
||||
model=model,
|
||||
)
|
||||
tts = model_instance.invoke_tts(
|
||||
content_text=tool_parameters.get("text", ""),
|
||||
content_text=tool_parameters.get("text"), # type: ignore
|
||||
user=user_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
voice=voice,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
voice=voice, # type: ignore
|
||||
)
|
||||
buffer = io.BytesIO()
|
||||
for chunk in tts:
|
||||
buffer.write(chunk)
|
||||
|
||||
wav_bytes = buffer.getvalue()
|
||||
return [
|
||||
self.create_text_message("Audio generated successfully"),
|
||||
self.create_blob_message(
|
||||
blob=wav_bytes,
|
||||
meta={"mime_type": "audio/x-wav"},
|
||||
save_as=self.VariableKey.AUDIO,
|
||||
),
|
||||
]
|
||||
yield self.create_text_message("Audio generated successfully")
|
||||
yield self.create_blob_message(
|
||||
blob=wav_bytes,
|
||||
meta={"mime_type": "audio/x-wav"},
|
||||
)
|
||||
|
||||
def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
|
||||
if not self.runtime:
|
||||
@ -56,12 +62,17 @@ class TTSTool(BuiltinTool):
|
||||
items.append((provider, model.model, voices))
|
||||
return items
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
parameters = []
|
||||
|
||||
options = []
|
||||
for provider, model, voices in self.get_available_models():
|
||||
option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
|
||||
option = PluginParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
|
||||
options.append(option)
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
@ -72,7 +83,7 @@ class TTSTool(BuiltinTool):
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
options=[
|
||||
ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
|
||||
PluginParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
|
||||
for voice in voices
|
||||
],
|
||||
)
|
||||
|
Before Width: | Height: | Size: 1.4 KiB After Width: | Height: | Size: 1.4 KiB |
8
api/core/tools/builtin_tool/providers/code/code.py
Normal file
@ -0,0 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CodeToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
@ -12,4 +12,3 @@ identity:
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- productivity
|
||||
credentials_for_provider:
|
||||
@ -1,12 +1,20 @@
|
||||
from typing import Any
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class SimpleCode(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke simple code
|
||||
"""
|
||||
@ -19,4 +27,4 @@ class SimpleCode(BuiltinTool):
|
||||
|
||||
result = CodeExecutor.execute_code(language, "", code)
|
||||
|
||||
return self.create_text_message(result)
|
||||
yield self.create_text_message(result)
|
||||
|
Before Width: | Height: | Size: 691 B After Width: | Height: | Size: 691 B |
8
api/core/tools/builtin_tool/providers/time/time.py
Normal file
@ -0,0 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
@ -12,4 +12,3 @@ identity:
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- utilities
|
||||
credentials_for_provider:
|
||||
@ -0,0 +1,35 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pytz import timezone as pytz_timezone
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class CurrentTimeTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
# get timezone
|
||||
tz = tool_parameters.get("timezone", "UTC")
|
||||
fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z"
|
||||
if tz == "UTC":
|
||||
yield self.create_text_message(f"{datetime.now(UTC).strftime(fm)}")
|
||||
return
|
||||
|
||||
try:
|
||||
tz = pytz_timezone(tz)
|
||||
except Exception:
|
||||
yield self.create_text_message(f"Invalid timezone: {tz}")
|
||||
return
|
||||
yield self.create_text_message(f"{datetime.now(tz).strftime(fm)}")
|
||||
@ -1,11 +1,12 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytz
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class LocaltimeToTimestampTool(BuiltinTool):
|
||||
@ -13,7 +14,10 @@ class LocaltimeToTimestampTool(BuiltinTool):
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Convert localtime to timestamp
|
||||
"""
|
||||
@ -23,11 +27,12 @@ class LocaltimeToTimestampTool(BuiltinTool):
|
||||
timezone = None
|
||||
time_format = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
timestamp = self.localtime_to_timestamp(localtime, time_format, timezone)
|
||||
timestamp = self.localtime_to_timestamp(localtime, time_format, timezone) # type: ignore
|
||||
if not timestamp:
|
||||
return self.create_text_message(f"Invalid localtime: {localtime}")
|
||||
yield self.create_text_message(f"Invalid localtime: {localtime}")
|
||||
return
|
||||
|
||||
return self.create_text_message(f"{timestamp}")
|
||||
yield self.create_text_message(f"{timestamp}")
|
||||
|
||||
@staticmethod
|
||||
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
|
||||
@ -37,8 +42,8 @@ class LocaltimeToTimestampTool(BuiltinTool):
|
||||
if isinstance(local_tz, str):
|
||||
local_tz = pytz.timezone(local_tz)
|
||||
local_time = datetime.strptime(localtime, time_format)
|
||||
localtime = local_tz.localize(local_time)
|
||||
timestamp = int(localtime.timestamp())
|
||||
localtime = local_tz.localize(local_time) # type: ignore
|
||||
timestamp = int(localtime.timestamp()) # type: ignore
|
||||
return timestamp
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
@ -1,11 +1,12 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytz
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class TimestampToLocaltimeTool(BuiltinTool):
|
||||
@ -13,11 +14,14 @@ class TimestampToLocaltimeTool(BuiltinTool):
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Convert timestamp to localtime
|
||||
"""
|
||||
timestamp = tool_parameters.get("timestamp")
|
||||
timestamp: int = tool_parameters.get("timestamp", 0)
|
||||
timezone = tool_parameters.get("timezone", "Asia/Shanghai")
|
||||
if not timezone:
|
||||
timezone = None
|
||||
@ -25,11 +29,12 @@ class TimestampToLocaltimeTool(BuiltinTool):
|
||||
|
||||
locatime = self.timestamp_to_localtime(timestamp, timezone)
|
||||
if not locatime:
|
||||
return self.create_text_message(f"Invalid timestamp: {timestamp}")
|
||||
yield self.create_text_message(f"Invalid timestamp: {timestamp}")
|
||||
return
|
||||
|
||||
localtime_format = locatime.strftime(time_format)
|
||||
|
||||
return self.create_text_message(f"{localtime_format}")
|
||||
yield self.create_text_message(f"{localtime_format}")
|
||||
|
||||
@staticmethod
|
||||
def timestamp_to_localtime(timestamp: int, local_tz=None) -> datetime | None:
|
||||
@ -1,11 +1,12 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytz
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class TimezoneConversionTool(BuiltinTool):
|
||||
@ -13,20 +14,24 @@ class TimezoneConversionTool(BuiltinTool):
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Convert time to equivalent time zone
|
||||
"""
|
||||
current_time = tool_parameters.get("current_time")
|
||||
current_timezone = tool_parameters.get("current_timezone", "Asia/Shanghai")
|
||||
target_timezone = tool_parameters.get("target_timezone", "Asia/Tokyo")
|
||||
target_time = self.timezone_convert(current_time, current_timezone, target_timezone)
|
||||
target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore
|
||||
if not target_time:
|
||||
return self.create_text_message(
|
||||
yield self.create_text_message(
|
||||
f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}"
|
||||
)
|
||||
return
|
||||
|
||||
return self.create_text_message(f"{target_time}")
|
||||
yield self.create_text_message(f"{target_time}")
|
||||
|
||||
@staticmethod
|
||||
def timezone_convert(current_time: str, source_timezone: str, target_timezone: str) -> str:
|
||||
@ -43,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
|
||||
datetime_with_tz = input_timezone.localize(local_time)
|
||||
# timezone convert
|
||||
converted_datetime = datetime_with_tz.astimezone(output_timezone)
|
||||
return converted_datetime.strftime(format=time_format)
|
||||
return converted_datetime.strftime(format=time_format) # type: ignore
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
@ -1,9 +1,10 @@
|
||||
import calendar
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class WeekdayTool(BuiltinTool):
|
||||
@ -11,22 +12,28 @@ class WeekdayTool(BuiltinTool):
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Calculate the day of the week for a given date
|
||||
"""
|
||||
year = tool_parameters.get("year")
|
||||
month = tool_parameters.get("month")
|
||||
if month is None:
|
||||
raise ValueError("Month is required")
|
||||
day = tool_parameters.get("day")
|
||||
|
||||
date_obj = self.convert_datetime(year, month, day)
|
||||
if not date_obj:
|
||||
return self.create_text_message(f"Invalid date: Year {year}, Month {month}, Day {day}.")
|
||||
yield self.create_text_message(f"Invalid date: Year {year}, Month {month}, Day {day}.")
|
||||
return
|
||||
|
||||
weekday_name = calendar.day_name[date_obj.weekday()]
|
||||
month_name = calendar.month_name[month]
|
||||
readable_date = f"{month_name} {date_obj.day}, {date_obj.year}"
|
||||
return self.create_text_message(f"{readable_date} is {weekday_name}.")
|
||||
yield self.create_text_message(f"{readable_date} is {weekday_name}.")
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(year, month, day) -> datetime | None:
|
||||
|
Before Width: | Height: | Size: 2.2 KiB After Width: | Height: | Size: 2.2 KiB |
@ -1,8 +1,10 @@
|
||||
from typing import Any, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.web_reader_tool import get_url
|
||||
|
||||
|
||||
class WebscraperTool(BuiltinTool):
|
||||
@ -10,7 +12,10 @@ class WebscraperTool(BuiltinTool):
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
@ -18,16 +23,17 @@ class WebscraperTool(BuiltinTool):
|
||||
url = tool_parameters.get("url", "")
|
||||
user_agent = tool_parameters.get("user_agent", "")
|
||||
if not url:
|
||||
return self.create_text_message("Please input url")
|
||||
yield self.create_text_message("Please input url")
|
||||
return
|
||||
|
||||
# get webpage
|
||||
result = self.get_url(url, user_agent=user_agent)
|
||||
result = get_url(url, user_agent=user_agent)
|
||||
|
||||
if tool_parameters.get("generate_summary"):
|
||||
# summarize and return
|
||||
return self.create_text_message(self.summary(user_id=user_id, content=result))
|
||||
yield self.create_text_message(self.summary(user_id=user_id, content=result))
|
||||
else:
|
||||
# return full webpage
|
||||
return self.create_text_message(result)
|
||||
yield self.create_text_message(result)
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(str(e))
|
||||
@ -49,12 +49,12 @@ parameters:
|
||||
zh_Hans: 如果启用,爬虫将仅返回页面摘要内容。
|
||||
form: form
|
||||
options:
|
||||
- value: 'true'
|
||||
- value: "true"
|
||||
label:
|
||||
en_US: 'Yes'
|
||||
en_US: "Yes"
|
||||
zh_Hans: 是
|
||||
- value: 'false'
|
||||
- value: "false"
|
||||
label:
|
||||
en_US: 'No'
|
||||
en_US: "No"
|
||||
zh_Hans: 否
|
||||
default: 'false'
|
||||
default: "false"
|
||||
@ -0,0 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WebscraperProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate credentials
|
||||
"""
|
||||
pass
|
||||
@ -12,4 +12,4 @@ identity:
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- productivity
|
||||
credentials_for_provider:
|
||||
credentials_for_provider: []
|
||||
@ -1,11 +1,9 @@
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
from core.tools.utils.web_reader_tool import get_url
|
||||
|
||||
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
|
||||
and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
|
||||
@ -22,6 +20,25 @@ class BuiltinTool(Tool):
|
||||
:param meta: the meta data of a tool call processing
|
||||
"""
|
||||
|
||||
provider: str
|
||||
|
||||
def __init__(self, provider: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.provider = provider
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
:param meta: the meta data of a tool call processing, tenant_id is required
|
||||
:return: the new tool
|
||||
"""
|
||||
return self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult:
|
||||
"""
|
||||
invoke model
|
||||
@ -32,14 +49,11 @@ class BuiltinTool(Tool):
|
||||
:return: the model result
|
||||
"""
|
||||
# invoke model
|
||||
if self.runtime is None or self.identity is None:
|
||||
raise ValueError("runtime and identity are required")
|
||||
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tool_type="builtin",
|
||||
tool_name=self.identity.name,
|
||||
tool_name=self.entity.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
@ -92,7 +106,8 @@ class BuiltinTool(Tool):
|
||||
stop=[],
|
||||
)
|
||||
|
||||
return cast(str, summary.message.content)
|
||||
assert isinstance(summary.message.content, str)
|
||||
return summary.message.content
|
||||
|
||||
lines = content.split("\n")
|
||||
new_lines = []
|
||||
@ -136,9 +151,3 @@ class BuiltinTool(Tool):
|
||||
return self.summary(user_id=user_id, content=result)
|
||||
|
||||
return result
|
||||
|
||||
def get_url(self, url: str, user_agent: Optional[str] = None) -> str:
|
||||
"""
|
||||
get url
|
||||
"""
|
||||
return get_url(url, user_agent=user_agent)
|
||||
@ -1,86 +1,99 @@
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolCredentialsOption,
|
||||
ToolDescription,
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolProviderCredentials,
|
||||
ToolProviderEntity,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
|
||||
|
||||
class ApiToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
tools: list[ApiTool] = Field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
|
||||
credentials_schema = {
|
||||
"auth_type": ToolProviderCredentials(
|
||||
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None:
|
||||
super().__init__(entity)
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
self.tools = []
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
|
||||
credentials_schema = [
|
||||
ProviderConfig(
|
||||
name="auth_type",
|
||||
required=True,
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
options=[
|
||||
ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ToolCredentialsOption(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
|
||||
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ProviderConfig.Option(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
|
||||
],
|
||||
default="none",
|
||||
help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
|
||||
)
|
||||
}
|
||||
]
|
||||
if auth_type == ApiProviderAuthType.API_KEY:
|
||||
credentials_schema = {
|
||||
**credentials_schema,
|
||||
"api_key_header": ToolProviderCredentials(
|
||||
credentials_schema = [
|
||||
*credentials_schema,
|
||||
ProviderConfig(
|
||||
name="api_key_header",
|
||||
required=False,
|
||||
default="api_key",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
|
||||
),
|
||||
"api_key_value": ToolProviderCredentials(
|
||||
ProviderConfig(
|
||||
name="api_key_value",
|
||||
required=True,
|
||||
type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
|
||||
type=ProviderConfig.Type.SECRET_INPUT,
|
||||
help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
|
||||
),
|
||||
"api_key_header_prefix": ToolProviderCredentials(
|
||||
ProviderConfig(
|
||||
name="api_key_header_prefix",
|
||||
required=False,
|
||||
default="basic",
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
|
||||
options=[
|
||||
ToolCredentialsOption(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
|
||||
ToolCredentialsOption(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
|
||||
ToolCredentialsOption(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
|
||||
ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
|
||||
ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
|
||||
ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
|
||||
],
|
||||
),
|
||||
}
|
||||
]
|
||||
elif auth_type == ApiProviderAuthType.NONE:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"invalid auth type {auth_type}")
|
||||
user_name = db_provider.user.name if db_provider.user_id and db_provider.user is not None else ""
|
||||
|
||||
user = db_provider.user
|
||||
user_name = user.name if user else ""
|
||||
|
||||
return ApiToolProviderController(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user_name,
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
icon=db_provider.icon,
|
||||
entity=ToolProviderEntity(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user_name,
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
credentials_schema=credentials_schema,
|
||||
plugin_id=None,
|
||||
),
|
||||
credentials_schema=credentials_schema,
|
||||
provider_id=db_provider.id or "",
|
||||
tools=None,
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
)
|
||||
|
||||
@property
|
||||
@ -96,21 +109,28 @@ class ApiToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return ApiTool(
|
||||
api_bundle=tool_bundle,
|
||||
identity=ToolIdentity(
|
||||
author=tool_bundle.author,
|
||||
name=tool_bundle.operation_id or "",
|
||||
label=I18nObject(en_US=tool_bundle.operation_id, zh_Hans=tool_bundle.operation_id),
|
||||
icon=self.identity.icon if self.identity else None,
|
||||
provider=self.provider_id,
|
||||
provider_id=self.provider_id,
|
||||
entity=ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author=tool_bundle.author,
|
||||
name=tool_bundle.operation_id or "default_tool",
|
||||
label=I18nObject(
|
||||
en_US=tool_bundle.operation_id or "default_tool",
|
||||
zh_Hans=tool_bundle.operation_id or "default_tool",
|
||||
),
|
||||
icon=self.entity.identity.icon,
|
||||
provider=self.provider_id,
|
||||
),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
|
||||
llm=tool_bundle.summary or "",
|
||||
),
|
||||
parameters=tool_bundle.parameters or [],
|
||||
),
|
||||
description=ToolDescription(
|
||||
human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
|
||||
llm=tool_bundle.summary or "",
|
||||
),
|
||||
parameters=tool_bundle.parameters or [],
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
)
|
||||
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[Tool]:
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]):
|
||||
"""
|
||||
load bundled tools
|
||||
|
||||
@ -121,7 +141,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
|
||||
return self.tools
|
||||
|
||||
def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]:
|
||||
def get_tools(self, tenant_id: str) -> list[ApiTool]:
|
||||
"""
|
||||
fetch tools from database
|
||||
|
||||
@ -129,17 +149,15 @@ class ApiToolProviderController(ToolProviderController):
|
||||
:param tenant_id: the tenant id
|
||||
:return: the tools
|
||||
"""
|
||||
if self.tools is not None:
|
||||
if len(self.tools) > 0:
|
||||
return self.tools
|
||||
if self.identity is None:
|
||||
return None
|
||||
|
||||
tools: list[Tool] = []
|
||||
tools: list[ApiTool] = []
|
||||
|
||||
# get tenant api providers
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
|
||||
.all()
|
||||
)
|
||||
|
||||
@ -147,13 +165,12 @@ class ApiToolProviderController(ToolProviderController):
|
||||
for db_provider in db_providers:
|
||||
for tool in db_provider.tools:
|
||||
assistant_tool = self._parse_tool_bundle(tool)
|
||||
assistant_tool.is_team_authorization = True
|
||||
tools.append(assistant_tool)
|
||||
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
def get_tool(self, tool_name: str) -> Tool:
|
||||
def get_tool(self, tool_name: str):
|
||||
"""
|
||||
get tool by name
|
||||
|
||||
@ -161,12 +178,10 @@ class ApiToolProviderController(ToolProviderController):
|
||||
:return: the tool
|
||||
"""
|
||||
if self.tools is None:
|
||||
self.get_tools()
|
||||
self.get_tools(self.tenant_id)
|
||||
|
||||
for tool in self.tools or []:
|
||||
if tool.identity is None:
|
||||
continue
|
||||
if tool.identity.name == tool_name:
|
||||
for tool in self.tools:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
raise ValueError(f"tool {tool_name} not found")
|
||||
@ -1,16 +1,18 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from os import getenv
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
|
||||
from core.file.file_manager import download
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
API_TOOL_DEFAULT_TIMEOUT = (
|
||||
int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
|
||||
@ -20,12 +22,18 @@ API_TOOL_DEFAULT_TIMEOUT = (
|
||||
|
||||
class ApiTool(Tool):
|
||||
api_bundle: ApiToolBundle
|
||||
provider_id: str
|
||||
|
||||
"""
|
||||
Api tool
|
||||
"""
|
||||
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
|
||||
def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime, provider_id: str):
|
||||
super().__init__(entity, runtime)
|
||||
self.api_bundle = api_bundle
|
||||
self.provider_id = provider_id
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime):
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@ -35,11 +43,10 @@ class ApiTool(Tool):
|
||||
if self.api_bundle is None:
|
||||
raise ValueError("api_bundle is required")
|
||||
return self.__class__(
|
||||
identity=self.identity.model_copy() if self.identity else None,
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.model_copy() if self.description else None,
|
||||
entity=self.entity,
|
||||
api_bundle=self.api_bundle.model_copy(),
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
runtime=runtime,
|
||||
provider_id=self.provider_id,
|
||||
)
|
||||
|
||||
def validate_credentials(
|
||||
@ -62,6 +69,9 @@ class ApiTool(Tool):
|
||||
return ToolProviderType.API
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
if self.runtime is None:
|
||||
raise ToolProviderCredentialValidationError("runtime not initialized")
|
||||
|
||||
headers = {}
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
@ -115,9 +125,9 @@ class ApiTool(Tool):
|
||||
response = response.json()
|
||||
try:
|
||||
return json.dumps(response, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return json.dumps(response)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return response.text
|
||||
else:
|
||||
raise ValueError(f"Invalid response type {type(response)}")
|
||||
@ -301,10 +311,17 @@ class ApiTool(Tool):
|
||||
raise ValueError(f"Invalid type {property['type']} for property {property}")
|
||||
elif "anyOf" in property and isinstance(property["anyOf"], list):
|
||||
return self._convert_body_property_any_of(property, value, property["anyOf"])
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
return value
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
invoke http request
|
||||
"""
|
||||
@ -319,4 +336,4 @@ class ApiTool(Tool):
|
||||
response = self.validate_and_parse_response(response)
|
||||
|
||||
# assemble invoke message
|
||||
return self.create_text_message(response)
|
||||
yield self.create_text_message(response)
|
||||
@ -3,37 +3,40 @@ from typing import Literal, Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentials, ToolProviderType
|
||||
from core.tools.tool.tool import ToolParameter
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
||||
class UserTool(BaseModel):
|
||||
class ToolApiEntity(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
labels: list[str] | None = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
|
||||
UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
|
||||
|
||||
class UserToolProvider(BaseModel):
|
||||
class ToolProviderApiEntity(BaseModel):
|
||||
id: str
|
||||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str
|
||||
icon: str | dict
|
||||
label: I18nObject # label
|
||||
type: ToolProviderType
|
||||
masked_credentials: Optional[dict] = None
|
||||
original_credentials: Optional[dict] = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
tools: list[UserTool] = Field(default_factory=list)
|
||||
labels: list[str] | None = None
|
||||
plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
|
||||
tools: list[ToolApiEntity] = Field(default_factory=list)
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
@ -55,6 +58,8 @@ class UserToolProvider(BaseModel):
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
"name": self.name,
|
||||
"plugin_id": self.plugin_id,
|
||||
"plugin_unique_identifier": self.plugin_unique_identifier,
|
||||
"description": self.description.to_dict(),
|
||||
"icon": self.icon,
|
||||
"label": self.label.to_dict(),
|
||||
@ -65,7 +70,3 @@ class UserToolProvider(BaseModel):
|
||||
"tools": tools,
|
||||
"labels": self.labels,
|
||||
}
|
||||
|
||||
|
||||
class UserToolProviderCredentials(BaseModel):
|
||||
credentials: dict[str, ToolProviderCredentials]
|
||||
|
||||
1
api/core/tools/entities/constants.py
Normal file
@ -0,0 +1 @@
|
||||
TOOL_SELECTOR_MODEL_IDENTITY = "__dify__tool_selector__"
|
||||
1
api/core/tools/entities/file_entities.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
@ -18,7 +18,7 @@ class ApiToolBundle(BaseModel):
|
||||
# summary
|
||||
summary: Optional[str] = None
|
||||
# operation_id
|
||||
operation_id: str | None = None
|
||||
operation_id: Optional[str] = None
|
||||
# parameters
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
# author
|
||||
|
||||
@ -1,9 +1,22 @@
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Optional, Union, cast
|
||||
import base64
|
||||
import enum
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
PluginParameterType,
|
||||
as_normal_type,
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
|
||||
class ToolLabelEnum(Enum):
|
||||
@ -25,11 +38,12 @@ class ToolLabelEnum(Enum):
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class ToolProviderType(Enum):
|
||||
class ToolProviderType(enum.StrEnum):
|
||||
"""
|
||||
Enum class for tool provider
|
||||
"""
|
||||
|
||||
PLUGIN = "plugin"
|
||||
BUILT_IN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
API = "api"
|
||||
@ -97,6 +111,64 @@ class ApiProviderAuthType(Enum):
|
||||
|
||||
|
||||
class ToolInvokeMessage(BaseModel):
|
||||
class TextMessage(BaseModel):
|
||||
text: str
|
||||
|
||||
class JsonMessage(BaseModel):
|
||||
json_object: dict
|
||||
|
||||
class BlobMessage(BaseModel):
|
||||
blob: bytes
|
||||
|
||||
class FileMessage(BaseModel):
|
||||
pass
|
||||
|
||||
class VariableMessage(BaseModel):
|
||||
variable_name: str = Field(..., description="The name of the variable")
|
||||
variable_value: Any = Field(..., description="The value of the variable")
|
||||
stream: bool = Field(default=False, description="Whether the variable is streamed")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def transform_variable_value(cls, values) -> Any:
|
||||
"""
|
||||
Only basic types and lists are allowed.
|
||||
"""
|
||||
value = values.get("variable_value")
|
||||
if not isinstance(value, dict | list | str | int | float | bool):
|
||||
raise ValueError("Only basic types and lists are allowed.")
|
||||
|
||||
# if stream is true, the value must be a string
|
||||
if values.get("stream"):
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
|
||||
return values
|
||||
|
||||
@field_validator("variable_name", mode="before")
|
||||
@classmethod
|
||||
def transform_variable_name(cls, value: str) -> str:
|
||||
"""
|
||||
The variable name must be a string.
|
||||
"""
|
||||
if value in {"json", "text", "files"}:
|
||||
raise ValueError(f"The variable name '{value}' is reserved.")
|
||||
return value
|
||||
|
||||
class LogMessage(BaseModel):
|
||||
class LogStatus(Enum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
|
||||
id: str
|
||||
label: str = Field(..., description="The label of the log")
|
||||
parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
|
||||
error: Optional[str] = Field(default=None, description="The error message")
|
||||
status: LogStatus = Field(..., description="The status of the log")
|
||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
|
||||
|
||||
class MessageType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
@ -104,132 +176,86 @@ class ToolInvokeMessage(BaseModel):
|
||||
BLOB = "blob"
|
||||
JSON = "json"
|
||||
IMAGE_LINK = "image_link"
|
||||
BINARY_LINK = "binary_link"
|
||||
VARIABLE = "variable"
|
||||
FILE = "file"
|
||||
LOG = "log"
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
plain text, image url or link url
|
||||
"""
|
||||
message: str | bytes | dict | None = None
|
||||
# TODO: Use a BaseModel for meta
|
||||
meta: dict[str, Any] = Field(default_factory=dict)
|
||||
save_as: str = ""
|
||||
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | LogMessage | None
|
||||
meta: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("message", mode="before")
|
||||
@classmethod
|
||||
def decode_blob_message(cls, v):
|
||||
if isinstance(v, dict) and "blob" in v:
|
||||
try:
|
||||
v["blob"] = base64.b64decode(v["blob"])
|
||||
except Exception:
|
||||
pass
|
||||
return v
|
||||
|
||||
@field_serializer("message")
|
||||
def serialize_message(self, v):
|
||||
if isinstance(v, self.BlobMessage):
|
||||
return {"blob": base64.b64encode(v.blob).decode("utf-8")}
|
||||
return v
|
||||
|
||||
|
||||
class ToolInvokeMessageBinary(BaseModel):
|
||||
mimetype: str = Field(..., description="The mimetype of the binary")
|
||||
url: str = Field(..., description="The url of the binary")
|
||||
save_as: str = ""
|
||||
file_var: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class ToolParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
class ToolParameter(PluginParameter):
|
||||
"""
|
||||
Overrides type
|
||||
"""
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def transform_id_to_str(cls, value) -> str:
|
||||
if not isinstance(value, str):
|
||||
return str(value)
|
||||
else:
|
||||
return value
|
||||
class ToolParameterType(enum.StrEnum):
|
||||
"""
|
||||
removes TOOLS_SELECTOR from PluginParameterType
|
||||
"""
|
||||
|
||||
|
||||
class ToolParameter(BaseModel):
|
||||
class ToolParameterType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
SELECT = "select"
|
||||
SECRET_INPUT = "secret-input"
|
||||
FILE = "file"
|
||||
FILES = "files"
|
||||
STRING = PluginParameterType.STRING.value
|
||||
NUMBER = PluginParameterType.NUMBER.value
|
||||
BOOLEAN = PluginParameterType.BOOLEAN.value
|
||||
SELECT = PluginParameterType.SELECT.value
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
|
||||
FILE = PluginParameterType.FILE.value
|
||||
FILES = PluginParameterType.FILES.value
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = "systme-files"
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
|
||||
def as_normal_type(self):
|
||||
if self in {
|
||||
ToolParameter.ToolParameterType.SECRET_INPUT,
|
||||
ToolParameter.ToolParameterType.SELECT,
|
||||
}:
|
||||
return "string"
|
||||
return self.value
|
||||
return as_normal_type(self)
|
||||
|
||||
def cast_value(self, value: Any, /):
|
||||
try:
|
||||
match self:
|
||||
case (
|
||||
ToolParameter.ToolParameterType.STRING
|
||||
| ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
| ToolParameter.ToolParameterType.SELECT
|
||||
):
|
||||
if value is None:
|
||||
return ""
|
||||
else:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
case ToolParameter.ToolParameterType.BOOLEAN:
|
||||
if value is None:
|
||||
return False
|
||||
elif isinstance(value, str):
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
else:
|
||||
return value if isinstance(value, bool) else bool(value)
|
||||
|
||||
case ToolParameter.ToolParameterType.NUMBER:
|
||||
if isinstance(value, int | float):
|
||||
return value
|
||||
elif isinstance(value, str) and value:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
case (
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES
|
||||
| ToolParameter.ToolParameterType.FILE
|
||||
| ToolParameter.ToolParameterType.FILES
|
||||
):
|
||||
return value
|
||||
case _:
|
||||
return str(value)
|
||||
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type.")
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
FORM = "form" # should be set before invoking tool
|
||||
LLM = "llm" # will be set by LLM
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
human_description: Optional[I18nObject] = Field(None, description="The description presented to the user")
|
||||
placeholder: Optional[I18nObject] = Field(None, description="The placeholder presented to the user")
|
||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||
llm_description: Optional[str] = None
|
||||
required: Optional[bool] = False
|
||||
default: Optional[Union[float, int, str]] = None
|
||||
min: Optional[Union[float, int]] = None
|
||||
max: Optional[Union[float, int]] = None
|
||||
options: Optional[list[ToolParameterOption]] = None
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(
|
||||
cls,
|
||||
name: str,
|
||||
llm_description: str,
|
||||
type: ToolParameterType,
|
||||
typ: ToolParameterType,
|
||||
required: bool,
|
||||
options: Optional[list[str]] = None,
|
||||
) -> "ToolParameter":
|
||||
@ -245,22 +271,28 @@ class ToolParameter(BaseModel):
|
||||
# convert options to ToolParameterOption
|
||||
# FIXME fix the type error
|
||||
if options:
|
||||
options = [
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) # type: ignore
|
||||
for option in options # type: ignore
|
||||
option_objs = [
|
||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
]
|
||||
else:
|
||||
option_objs = []
|
||||
|
||||
return cls(
|
||||
name=name,
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
human_description=I18nObject(en_US="", zh_Hans=""),
|
||||
placeholder=None,
|
||||
type=type,
|
||||
human_description=I18nObject(en_US="", zh_Hans=""),
|
||||
type=typ,
|
||||
form=cls.ToolParameterForm.LLM,
|
||||
llm_description=llm_description,
|
||||
required=required,
|
||||
options=options, # type: ignore
|
||||
options=option_objs,
|
||||
)
|
||||
|
||||
def init_frontend_parameter(self, value: Any):
|
||||
return init_frontend_parameter(self, self.type, value)
|
||||
|
||||
|
||||
class ToolProviderIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
@ -274,11 +306,6 @@ class ToolProviderIdentity(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ToolDescription(BaseModel):
|
||||
human: I18nObject = Field(..., description="The description presented to the user")
|
||||
llm: str = Field(..., description="The description presented to the LLM")
|
||||
|
||||
|
||||
class ToolIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
@ -287,185 +314,35 @@ class ToolIdentity(BaseModel):
|
||||
icon: Optional[str] = None
|
||||
|
||||
|
||||
class ToolCredentialsOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
class ToolDescription(BaseModel):
|
||||
human: I18nObject = Field(..., description="The description presented to the user")
|
||||
llm: str = Field(..., description="The description presented to the LLM")
|
||||
|
||||
|
||||
class ToolProviderCredentials(BaseModel):
|
||||
class CredentialsType(Enum):
|
||||
SECRET_INPUT = "secret-input"
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
BOOLEAN = "boolean"
|
||||
class ToolEntity(BaseModel):
|
||||
identity: ToolIdentity
|
||||
parameters: list[ToolParameter] = Field(default_factory=list)
|
||||
description: Optional[ToolDescription] = None
|
||||
output_schema: Optional[dict] = None
|
||||
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
@staticmethod
|
||||
def default(value: str) -> str:
|
||||
return ""
|
||||
|
||||
name: str = Field(..., description="The name of the credentials")
|
||||
type: CredentialsType = Field(..., description="The type of the credentials")
|
||||
required: bool = False
|
||||
default: Optional[Union[int, str]] = None
|
||||
options: Optional[list[ToolCredentialsOption]] = None
|
||||
label: Optional[I18nObject] = None
|
||||
help: Optional[I18nObject] = None
|
||||
url: Optional[str] = None
|
||||
placeholder: Optional[I18nObject] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"name": self.name,
|
||||
"type": self.type.value,
|
||||
"required": self.required,
|
||||
"default": self.default,
|
||||
"options": self.options,
|
||||
"help": self.help.to_dict() if self.help else None,
|
||||
"label": self.label.to_dict() if self.label else None,
|
||||
"url": self.url,
|
||||
"placeholder": self.placeholder.to_dict() if self.placeholder else None,
|
||||
}
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
|
||||
return v or []
|
||||
|
||||
|
||||
class ToolRuntimeVariableType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: Optional[str] = None
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolRuntimeVariable(BaseModel):
|
||||
type: ToolRuntimeVariableType = Field(..., description="The type of the variable")
|
||||
name: str = Field(..., description="The name of the variable")
|
||||
position: int = Field(..., description="The position of the variable")
|
||||
tool_name: str = Field(..., description="The name of the tool")
|
||||
|
||||
|
||||
class ToolRuntimeTextVariable(ToolRuntimeVariable):
|
||||
value: str = Field(..., description="The value of the variable")
|
||||
|
||||
|
||||
class ToolRuntimeImageVariable(ToolRuntimeVariable):
|
||||
value: str = Field(..., description="The path of the image")
|
||||
|
||||
|
||||
class ToolRuntimeVariablePool(BaseModel):
|
||||
conversation_id: str = Field(..., description="The conversation id")
|
||||
user_id: str = Field(..., description="The user id")
|
||||
tenant_id: str = Field(..., description="The tenant id of assistant")
|
||||
|
||||
pool: list[ToolRuntimeVariable] = Field(..., description="The pool of variables")
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
pool = data.get("pool", [])
|
||||
# convert pool into correct type
|
||||
for index, variable in enumerate(pool):
|
||||
if variable["type"] == ToolRuntimeVariableType.TEXT.value:
|
||||
pool[index] = ToolRuntimeTextVariable(**variable)
|
||||
elif variable["type"] == ToolRuntimeVariableType.IMAGE.value:
|
||||
pool[index] = ToolRuntimeImageVariable(**variable)
|
||||
super().__init__(**data)
|
||||
|
||||
def dict(self) -> dict: # type: ignore
|
||||
"""
|
||||
FIXME: just ignore the type check for now
|
||||
"""
|
||||
return {
|
||||
"conversation_id": self.conversation_id,
|
||||
"user_id": self.user_id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"pool": [variable.model_dump() for variable in self.pool],
|
||||
}
|
||||
|
||||
def set_text(self, tool_name: str, name: str, value: str) -> None:
|
||||
"""
|
||||
set a text variable
|
||||
"""
|
||||
for variable in self.pool:
|
||||
if variable.name == name:
|
||||
if variable.type == ToolRuntimeVariableType.TEXT:
|
||||
variable = cast(ToolRuntimeTextVariable, variable)
|
||||
variable.value = value
|
||||
return
|
||||
|
||||
variable = ToolRuntimeTextVariable(
|
||||
type=ToolRuntimeVariableType.TEXT,
|
||||
name=name,
|
||||
position=len(self.pool),
|
||||
tool_name=tool_name,
|
||||
value=value,
|
||||
)
|
||||
|
||||
self.pool.append(variable)
|
||||
|
||||
def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None:
|
||||
"""
|
||||
set an image variable
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:param value: the id of the file
|
||||
"""
|
||||
# check how many image variables are there
|
||||
image_variable_count = 0
|
||||
for variable in self.pool:
|
||||
if variable.type == ToolRuntimeVariableType.IMAGE:
|
||||
image_variable_count += 1
|
||||
|
||||
if name is None:
|
||||
name = f"file_{image_variable_count}"
|
||||
|
||||
for variable in self.pool:
|
||||
if variable.name == name:
|
||||
if variable.type == ToolRuntimeVariableType.IMAGE:
|
||||
variable = cast(ToolRuntimeImageVariable, variable)
|
||||
variable.value = value
|
||||
return
|
||||
|
||||
variable = ToolRuntimeImageVariable(
|
||||
type=ToolRuntimeVariableType.IMAGE,
|
||||
name=name,
|
||||
position=len(self.pool),
|
||||
tool_name=tool_name,
|
||||
value=value,
|
||||
)
|
||||
|
||||
self.pool.append(variable)
|
||||
|
||||
|
||||
class ModelToolPropertyKey(Enum):
|
||||
IMAGE_PARAMETER_NAME = "image_parameter_name"
|
||||
|
||||
|
||||
class ModelToolConfiguration(BaseModel):
|
||||
"""
|
||||
Model tool configuration
|
||||
"""
|
||||
|
||||
type: str = Field(..., description="The type of the model tool")
|
||||
model: str = Field(..., description="The model")
|
||||
label: I18nObject = Field(..., description="The label of the model tool")
|
||||
properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool")
|
||||
|
||||
|
||||
class ModelToolProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Model tool provider configuration
|
||||
"""
|
||||
|
||||
provider: str = Field(..., description="The provider of the model tool")
|
||||
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
|
||||
label: I18nObject = Field(..., description="The label of the model tool")
|
||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||
tools: list[ToolEntity] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkflowToolParameterConfiguration(BaseModel):
|
||||
@ -526,3 +403,25 @@ class ToolInvokeFrom(Enum):
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
AGENT = "agent"
|
||||
PLUGIN = "plugin"
|
||||
|
||||
|
||||
class ToolSelector(BaseModel):
|
||||
dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
class Parameter(BaseModel):
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter")
|
||||
required: bool = Field(..., description="Whether the parameter is required")
|
||||
description: str = Field(..., description="The description of the parameter")
|
||||
default: Optional[Union[int, float, str]] = None
|
||||
options: Optional[list[PluginParameterOption]] = None
|
||||
|
||||
provider_id: str = Field(..., description="The id of the provider")
|
||||
tool_name: str = Field(..., description="The name of the tool")
|
||||
tool_description: str = Field(..., description="The description of the tool")
|
||||
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
|
||||
tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return self.model_dump()
|
||||
|
||||
79
api/core/tools/plugin_tool/provider.py
Normal file
@ -0,0 +1,79 @@
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
|
||||
|
||||
class PluginToolProviderController(BuiltinToolProviderController):
|
||||
entity: ToolProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
:return: type of the provider
|
||||
"""
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
manager = PluginToolManager()
|
||||
if not manager.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=self.entity.identity.name,
|
||||
credentials=credentials,
|
||||
):
|
||||
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||
|
||||
def get_tool(self, tool_name: str) -> PluginTool: # type: ignore
|
||||
"""
|
||||
return tool with given name
|
||||
"""
|
||||
tool_entity = next(
|
||||
(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
|
||||
)
|
||||
|
||||
if not tool_entity:
|
||||
raise ValueError(f"Tool with name {tool_name} not found")
|
||||
|
||||
return PluginTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_tools(self) -> list[PluginTool]: # type: ignore
|
||||
"""
|
||||
get all tools
|
||||
"""
|
||||
return [
|
||||
PluginTool(
|
||||
entity=tool_entity,
|
||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
for tool_entity in self.entity.tools
|
||||
]
|
||||
89
api/core/tools/plugin_tool/tool.py
Normal file
@ -0,0 +1,89 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
|
||||
|
||||
class PluginTool(Tool):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
runtime_parameters: Optional[list[ToolParameter]]
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
self.runtime_parameters = None
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
manager = PluginToolManager()
|
||||
|
||||
tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
|
||||
|
||||
yield from manager.invoke(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
tool_provider=self.entity.identity.provider,
|
||||
tool_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
tool_parameters=tool_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool":
|
||||
return PluginTool(
|
||||
entity=self.entity,
|
||||
runtime=runtime,
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters
|
||||
"""
|
||||
if not self.entity.has_runtime_parameters:
|
||||
return self.entity.parameters
|
||||
|
||||
if self.runtime_parameters is not None:
|
||||
return self.runtime_parameters
|
||||
|
||||
manager = PluginToolManager()
|
||||
self.runtime_parameters = manager.get_runtime_parameters(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="",
|
||||
provider=self.entity.identity.provider,
|
||||
tool=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
return self.runtime_parameters
|
||||
@ -1,81 +0,0 @@
|
||||
- google
|
||||
- bing
|
||||
- perplexity
|
||||
- duckduckgo
|
||||
- searchapi
|
||||
- serper
|
||||
- searxng
|
||||
- websearch
|
||||
- tavily
|
||||
- stackexchange
|
||||
- pubmed
|
||||
- arxiv
|
||||
- aws
|
||||
- nominatim
|
||||
- devdocs
|
||||
- spider
|
||||
- firecrawl
|
||||
- brave
|
||||
- crossref
|
||||
- jina
|
||||
- webscraper
|
||||
- dalle
|
||||
- azuredalle
|
||||
- stability
|
||||
- stablediffusion
|
||||
- cogview
|
||||
- comfyui
|
||||
- getimgai
|
||||
- siliconflow
|
||||
- spark
|
||||
- stepfun
|
||||
- xinference
|
||||
- alphavantage
|
||||
- yahoo
|
||||
- openweather
|
||||
- gaode
|
||||
- aippt
|
||||
- chart
|
||||
- youtube
|
||||
- did
|
||||
- dingtalk
|
||||
- discord
|
||||
- feishu
|
||||
- feishu_base
|
||||
- feishu_document
|
||||
- feishu_message
|
||||
- feishu_wiki
|
||||
- feishu_task
|
||||
- feishu_calendar
|
||||
- feishu_spreadsheet
|
||||
- lark_base
|
||||
- lark_document
|
||||
- lark_message_and_group
|
||||
- lark_wiki
|
||||
- lark_task
|
||||
- lark_calendar
|
||||
- lark_spreadsheet
|
||||
- slack
|
||||
- twilio
|
||||
- wecom
|
||||
- wikipedia
|
||||
- code
|
||||
- wolframalpha
|
||||
- maths
|
||||
- github
|
||||
- gitlab
|
||||
- time
|
||||
- vectorizer
|
||||
- qrcode
|
||||
- tianditu
|
||||
- aliyuque
|
||||
- google_translate
|
||||
- hap
|
||||
- json_process
|
||||
- judge0ce
|
||||
- novitaai
|
||||
- onebot
|
||||
- regex
|
||||
- trello
|
||||
- vanna
|
||||
- fal
|
||||
@ -1,106 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
from models.tools import PublishedAppTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppToolProviderEntity(ToolProviderController):
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.APP
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]:
|
||||
db_tools: list[PublishedAppTool] = (
|
||||
db.session.query(PublishedAppTool)
|
||||
.filter(
|
||||
PublishedAppTool.user_id == user_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not db_tools or len(db_tools) == 0:
|
||||
return []
|
||||
|
||||
tools: list[Tool] = []
|
||||
|
||||
for db_tool in db_tools:
|
||||
tool: dict[str, Any] = {
|
||||
"identity": {
|
||||
"author": db_tool.author,
|
||||
"name": db_tool.tool_name,
|
||||
"label": {"en_US": db_tool.tool_name, "zh_Hans": db_tool.tool_name},
|
||||
"icon": "",
|
||||
},
|
||||
"description": {
|
||||
"human": {"en_US": db_tool.description_i18n.en_US, "zh_Hans": db_tool.description_i18n.zh_Hans},
|
||||
"llm": db_tool.llm_description,
|
||||
},
|
||||
"parameters": [],
|
||||
}
|
||||
# get app from db
|
||||
app: Optional[App] = db_tool.app
|
||||
|
||||
if not app:
|
||||
logger.error(f"app {db_tool.app_id} not found")
|
||||
continue
|
||||
|
||||
app_model_config: AppModelConfig = app.app_model_config
|
||||
user_input_form_list = app_model_config.user_input_form_list
|
||||
for input_form in user_input_form_list:
|
||||
# get type
|
||||
form_type = list(input_form.keys())[0]
|
||||
default = input_form[form_type]["default"]
|
||||
required = input_form[form_type]["required"]
|
||||
label = input_form[form_type]["label"]
|
||||
variable_name = input_form[form_type]["variable_name"]
|
||||
options = input_form[form_type].get("options", [])
|
||||
if form_type in {"paragraph", "text-input"}:
|
||||
tool["parameters"].append(
|
||||
ToolParameter(
|
||||
name=variable_name,
|
||||
label=I18nObject(en_US=label, zh_Hans=label),
|
||||
human_description=I18nObject(en_US=label, zh_Hans=label),
|
||||
llm_description=label,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=required,
|
||||
default=default,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
)
|
||||
elif form_type == "select":
|
||||
tool["parameters"].append(
|
||||
ToolParameter(
|
||||
name=variable_name,
|
||||
label=I18nObject(en_US=label, zh_Hans=label),
|
||||
human_description=I18nObject(en_US=label, zh_Hans=label),
|
||||
llm_description=label,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
required=required,
|
||||
default=default,
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
options=[
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
tools.append(ApiTool(**tool))
|
||||
return tools
|
||||
|
Before Width: | Height: | Size: 1.9 KiB |
@ -1,11 +0,0 @@
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AIPPTProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
AIPPTGenerateTool._get_api_token(credentials, user_id="__dify_system__")
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,45 +0,0 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: aippt
|
||||
label:
|
||||
en_US: AIPPT
|
||||
zh_Hans: AIPPT
|
||||
description:
|
||||
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
|
||||
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
|
||||
icon: icon.png
|
||||
tags:
|
||||
- productivity
|
||||
- design
|
||||
credentials_for_provider:
|
||||
aippt_access_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: AIPPT API key
|
||||
zh_Hans: AIPPT API key
|
||||
pt_BR: AIPPT API key
|
||||
help:
|
||||
en_US: Please input your AIPPT API key
|
||||
zh_Hans: 请输入你的 AIPPT API key
|
||||
pt_BR: Please input your AIPPT API key
|
||||
placeholder:
|
||||
en_US: Please input your AIPPT API key
|
||||
zh_Hans: 请输入你的 AIPPT API key
|
||||
pt_BR: Please input your AIPPT API key
|
||||
url: https://www.aippt.cn
|
||||
aippt_secret_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: AIPPT Secret key
|
||||
zh_Hans: AIPPT Secret key
|
||||
pt_BR: AIPPT Secret key
|
||||
help:
|
||||
en_US: Please input your AIPPT Secret key
|
||||
zh_Hans: 请输入你的 AIPPT Secret key
|
||||
pt_BR: Please input your AIPPT Secret key
|
||||
placeholder:
|
||||
en_US: Please input your AIPPT Secret key
|
||||
zh_Hans: 请输入你的 AIPPT Secret key
|
||||
pt_BR: Please input your AIPPT Secret key
|
||||
@ -1,524 +0,0 @@
|
||||
from base64 import b64encode
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
from json import loads as json_loads
|
||||
from threading import Lock
|
||||
from time import sleep, time
|
||||
from typing import Any, Union
|
||||
|
||||
from httpx import get, post
|
||||
from requests import get as requests_get
|
||||
from yarl import URL
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AIPPTGenerateToolAdapter:
|
||||
"""
|
||||
A tool for generating a ppt
|
||||
"""
|
||||
|
||||
_api_base_url = URL("https://co.aippt.cn/api")
|
||||
_api_token_cache: dict[str, dict[str, Union[str, float]]] = {}
|
||||
_style_cache: dict[str, dict[str, Union[list[dict[str, Any]], float]]] = {}
|
||||
|
||||
_api_token_cache_lock: Lock = Lock()
|
||||
_style_cache_lock: Lock = Lock()
|
||||
|
||||
_task: dict[str, Any] = {}
|
||||
_task_type_map = {
|
||||
"auto": 1,
|
||||
"markdown": 7,
|
||||
}
|
||||
_tool: BuiltinTool | None
|
||||
|
||||
def __init__(self, tool: BuiltinTool | None = None):
|
||||
self._tool = tool
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invokes the AIPPT generate tool with the given user ID and tool parameters.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user invoking the tool.
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
|
||||
which can be a single message or a list of messages.
|
||||
"""
|
||||
title = tool_parameters.get("title", "")
|
||||
if not title:
|
||||
return self._tool.create_text_message("Please provide a title for the ppt")
|
||||
|
||||
model = tool_parameters.get("model", "aippt")
|
||||
if not model:
|
||||
return self._tool.create_text_message("Please provide a model for the ppt")
|
||||
|
||||
outline = tool_parameters.get("outline", "")
|
||||
|
||||
# create task
|
||||
task_id = self._create_task(
|
||||
type=self._task_type_map["auto" if not outline else "markdown"],
|
||||
title=title,
|
||||
content=outline,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# get suit
|
||||
color: str = tool_parameters.get("color", "")
|
||||
style: str = tool_parameters.get("style", "")
|
||||
|
||||
if color == "__default__":
|
||||
color_id = ""
|
||||
else:
|
||||
color_id = int(color.split("-")[1])
|
||||
|
||||
if style == "__default__":
|
||||
style_id = ""
|
||||
else:
|
||||
style_id = int(style.split("-")[1])
|
||||
|
||||
suit_id = self._get_suit(style_id=style_id, colour_id=color_id)
|
||||
|
||||
# generate outline
|
||||
if not outline:
|
||||
self._generate_outline(task_id=task_id, model=model, user_id=user_id)
|
||||
|
||||
# generate content
|
||||
self._generate_content(task_id=task_id, model=model, user_id=user_id)
|
||||
|
||||
# generate ppt
|
||||
_, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id)
|
||||
|
||||
return self._tool.create_text_message(
|
||||
"""the ppt has been created successfully,"""
|
||||
f"""the ppt url is {ppt_url} ."""
|
||||
"""please give the ppt url to user and direct user to download it."""
|
||||
)
|
||||
|
||||
def _create_task(self, type: int, title: str, content: str, user_id: str) -> str:
|
||||
"""
|
||||
Create a task
|
||||
|
||||
:param type: the task type
|
||||
:param title: the task title
|
||||
:param content: the task content
|
||||
|
||||
:return: the task ID
|
||||
"""
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
response = post(
|
||||
str(self._api_base_url / "ai" / "chat" / "v2" / "task"),
|
||||
headers=headers,
|
||||
files={"type": ("", str(type)), "title": ("", title), "content": ("", content)},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to create task: {response.get('msg')}")
|
||||
|
||||
return response.get("data", {}).get("id")
|
||||
|
||||
def _generate_outline(self, task_id: str, model: str, user_id: str) -> str:
|
||||
api_url = (
|
||||
self._api_base_url / "ai" / "chat" / "outline"
|
||||
if model == "aippt"
|
||||
else self._api_base_url / "ai" / "chat" / "wx" / "outline"
|
||||
)
|
||||
api_url %= {"task_id": task_id}
|
||||
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
outline = ""
|
||||
for chunk in response.iter_lines(delimiter=b"\n\n"):
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
event = ""
|
||||
lines = chunk.decode("utf-8").split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("event:"):
|
||||
event = line[6:]
|
||||
elif line.startswith("data:"):
|
||||
data = line[5:]
|
||||
if event == "message":
|
||||
try:
|
||||
data = json_loads(data)
|
||||
outline += data.get("content", "")
|
||||
except Exception as e:
|
||||
pass
|
||||
elif event == "close":
|
||||
break
|
||||
elif event in {"error", "filter"}:
|
||||
raise Exception(f"Failed to generate outline: {data}")
|
||||
|
||||
return outline
|
||||
|
||||
def _generate_content(self, task_id: str, model: str, user_id: str) -> str:
|
||||
api_url = (
|
||||
self._api_base_url / "ai" / "chat" / "content"
|
||||
if model == "aippt"
|
||||
else self._api_base_url / "ai" / "chat" / "wx" / "content"
|
||||
)
|
||||
api_url %= {"task_id": task_id}
|
||||
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
if model == "aippt":
|
||||
content = ""
|
||||
for chunk in response.iter_lines(delimiter=b"\n\n"):
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
event = ""
|
||||
lines = chunk.decode("utf-8").split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("event:"):
|
||||
event = line[6:]
|
||||
elif line.startswith("data:"):
|
||||
data = line[5:]
|
||||
if event == "message":
|
||||
try:
|
||||
data = json_loads(data)
|
||||
content += data.get("content", "")
|
||||
except Exception as e:
|
||||
pass
|
||||
elif event == "close":
|
||||
break
|
||||
elif event in {"error", "filter"}:
|
||||
raise Exception(f"Failed to generate content: {data}")
|
||||
|
||||
return content
|
||||
elif model == "wenxin":
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to generate content: {response.get('msg')}")
|
||||
|
||||
return response.get("data", "")
|
||||
|
||||
return ""
|
||||
|
||||
def _generate_ppt(self, task_id: str, suit_id: int, user_id: str) -> tuple[str, str]:
|
||||
"""
|
||||
Generate a ppt
|
||||
|
||||
:param task_id: the task ID
|
||||
:param suit_id: the suit ID
|
||||
:return: the cover url of the ppt and the ppt url
|
||||
"""
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = post(
|
||||
str(self._api_base_url / "design" / "v2" / "save"),
|
||||
headers=headers,
|
||||
data={"task_id": task_id, "template_id": suit_id},
|
||||
timeout=(10, 60),
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to generate ppt: {response.get('msg')}")
|
||||
|
||||
id = response.get("data", {}).get("id")
|
||||
cover_url = response.get("data", {}).get("cover_url")
|
||||
|
||||
response = post(
|
||||
str(self._api_base_url / "download" / "export" / "file"),
|
||||
headers=headers,
|
||||
data={"id": id, "format": "ppt", "files_to_zip": False, "edit": True},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to generate ppt: {response.get('msg')}")
|
||||
|
||||
export_code = response.get("data")
|
||||
if not export_code:
|
||||
raise Exception("Failed to generate ppt, the export code is empty")
|
||||
|
||||
current_iteration = 0
|
||||
while current_iteration < 50:
|
||||
# get ppt url
|
||||
response = post(
|
||||
str(self._api_base_url / "download" / "export" / "file" / "result"),
|
||||
headers=headers,
|
||||
data={"task_key": export_code},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to generate ppt: {response.get('msg')}")
|
||||
|
||||
if response.get("msg") == "导出中":
|
||||
current_iteration += 1
|
||||
sleep(2)
|
||||
continue
|
||||
|
||||
ppt_url = response.get("data", [])
|
||||
if len(ppt_url) == 0:
|
||||
raise Exception("Failed to generate ppt, the ppt url is empty")
|
||||
|
||||
return cover_url, ppt_url[0]
|
||||
|
||||
raise Exception("Failed to generate ppt, the export is timeout")
|
||||
|
||||
@classmethod
|
||||
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
|
||||
"""
|
||||
Get API token
|
||||
|
||||
:param credentials: the credentials
|
||||
:return: the API token
|
||||
"""
|
||||
access_key = credentials["aippt_access_key"]
|
||||
secret_key = credentials["aippt_secret_key"]
|
||||
|
||||
cache_key = f"{access_key}#@#{user_id}"
|
||||
|
||||
with cls._api_token_cache_lock:
|
||||
# clear expired tokens
|
||||
now = time()
|
||||
for key in list(cls._api_token_cache.keys()):
|
||||
if cls._api_token_cache[key]["expire"] < now:
|
||||
del cls._api_token_cache[key]
|
||||
|
||||
if cache_key in cls._api_token_cache:
|
||||
return cls._api_token_cache[cache_key]["token"]
|
||||
|
||||
# get token
|
||||
headers = {
|
||||
"x-api-key": access_key,
|
||||
"x-timestamp": str(int(now)),
|
||||
"x-signature": cls._calculate_sign(access_key, secret_key, int(now)),
|
||||
}
|
||||
|
||||
param = {"uid": user_id, "channel": ""}
|
||||
|
||||
response = get(str(cls._api_base_url / "grant" / "token"), params=param, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
response = response.json()
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
|
||||
|
||||
token = response.get("data", {}).get("token")
|
||||
expire = response.get("data", {}).get("time_expire")
|
||||
|
||||
with cls._api_token_cache_lock:
|
||||
cls._api_token_cache[cache_key] = {"token": token, "expire": now + expire}
|
||||
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str:
|
||||
return b64encode(
|
||||
hmac_new(
|
||||
key=secret_key.encode("utf-8"),
|
||||
msg=f"GET@/api/grant/token/@{timestamp}".encode(),
|
||||
digestmod=sha1,
|
||||
).digest()
|
||||
).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def _get_styles(
|
||||
cls, credentials: dict[str, str], user_id: str
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""
|
||||
Get styles
|
||||
"""
|
||||
|
||||
# check cache
|
||||
with cls._style_cache_lock:
|
||||
# clear expired styles
|
||||
now = time()
|
||||
for key in list(cls._style_cache.keys()):
|
||||
if cls._style_cache[key]["expire"] < now:
|
||||
del cls._style_cache[key]
|
||||
|
||||
key = f"{credentials['aippt_access_key']}#@#{user_id}"
|
||||
if key in cls._style_cache:
|
||||
return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"]
|
||||
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": credentials["aippt_access_key"],
|
||||
"x-token": cls._get_api_token(credentials=credentials, user_id=user_id),
|
||||
}
|
||||
response = get(str(cls._api_base_url / "template_component" / "suit" / "select"), headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
|
||||
|
||||
colors = [
|
||||
{
|
||||
"id": f"id-{item.get('id')}",
|
||||
"name": item.get("name"),
|
||||
"en_name": item.get("en_name", item.get("name")),
|
||||
}
|
||||
for item in response.get("data", {}).get("colour") or []
|
||||
]
|
||||
styles = [
|
||||
{
|
||||
"id": f"id-{item.get('id')}",
|
||||
"name": item.get("title"),
|
||||
}
|
||||
for item in response.get("data", {}).get("suit_style") or []
|
||||
]
|
||||
|
||||
with cls._style_cache_lock:
|
||||
cls._style_cache[key] = {"colors": colors, "styles": styles, "expire": now + 60 * 60}
|
||||
|
||||
return colors, styles
|
||||
|
||||
def get_styles(self, user_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""
|
||||
Get styles
|
||||
|
||||
:param credentials: the credentials
|
||||
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
|
||||
"""
|
||||
if not self._tool.runtime.credentials.get("aippt_access_key") or not self._tool.runtime.credentials.get(
|
||||
"aippt_secret_key"
|
||||
):
|
||||
raise Exception("Please provide aippt credentials")
|
||||
|
||||
return self._get_styles(credentials=self._tool.runtime.credentials, user_id=user_id)
|
||||
|
||||
def _get_suit(self, style_id: int, colour_id: int) -> int:
|
||||
"""
|
||||
Get suit
|
||||
"""
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id="__dify_system__"),
|
||||
}
|
||||
response = get(
|
||||
str(self._api_base_url / "template_component" / "suit" / "search"),
|
||||
headers=headers,
|
||||
params={"style_id": style_id, "colour_id": colour_id, "page": 1, "page_size": 1},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
|
||||
|
||||
if len(response.get("data", {}).get("list") or []) > 0:
|
||||
return response.get("data", {}).get("list")[0].get("id")
|
||||
|
||||
raise Exception("Failed to get suit, the suit does not exist, please check the style and color")
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
Get runtime parameters
|
||||
|
||||
Override this method to add runtime parameters to the tool.
|
||||
"""
|
||||
try:
|
||||
colors, styles = self.get_styles(user_id="__dify_system__")
|
||||
except Exception as e:
|
||||
colors, styles = (
|
||||
[{"id": "-1", "name": "__default__", "en_name": "__default__"}],
|
||||
[{"id": "-1", "name": "__default__", "en_name": "__default__"}],
|
||||
)
|
||||
|
||||
return [
|
||||
ToolParameter(
|
||||
name="color",
|
||||
label=I18nObject(zh_Hans="颜色", en_US="Color"),
|
||||
human_description=I18nObject(zh_Hans="颜色", en_US="Color"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=False,
|
||||
default=colors[0]["id"],
|
||||
options=[
|
||||
ToolParameterOption(
|
||||
value=color["id"], label=I18nObject(zh_Hans=color["name"], en_US=color["en_name"])
|
||||
)
|
||||
for color in colors
|
||||
],
|
||||
),
|
||||
ToolParameter(
|
||||
name="style",
|
||||
label=I18nObject(zh_Hans="风格", en_US="Style"),
|
||||
human_description=I18nObject(zh_Hans="风格", en_US="Style"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=False,
|
||||
default=styles[0]["id"],
|
||||
options=[
|
||||
ToolParameterOption(value=style["id"], label=I18nObject(zh_Hans=style["name"], en_US=style["name"]))
|
||||
for style in styles
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class AIPPTGenerateTool(BuiltinTool):
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
return AIPPTGenerateToolAdapter(self).get_runtime_parameters()
|
||||
|
||||
@classmethod
|
||||
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
|
||||
return AIPPTGenerateToolAdapter()._get_api_token(credentials, user_id)
|
||||
@ -1,54 +0,0 @@
|
||||
identity:
|
||||
name: aippt
|
||||
author: Dify
|
||||
label:
|
||||
en_US: AIPPT
|
||||
zh_Hans: AIPPT
|
||||
description:
|
||||
human:
|
||||
en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
|
||||
zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
|
||||
llm: A tool used to generate PPT with AI, input your content topic, and let AI generate PPT for you.
|
||||
parameters:
|
||||
- name: title
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Title
|
||||
zh_Hans: 标题
|
||||
human_description:
|
||||
en_US: The title of the PPT.
|
||||
zh_Hans: PPT的标题。
|
||||
llm_description: The title of the PPT, which will be used to generate the PPT outline.
|
||||
form: llm
|
||||
- name: outline
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Outline
|
||||
zh_Hans: 大纲
|
||||
human_description:
|
||||
en_US: The outline of the PPT
|
||||
zh_Hans: PPT的大纲
|
||||
llm_description: The outline of the PPT, which will be used to generate the PPT content. provide it if you have.
|
||||
form: llm
|
||||
- name: llm
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: LLM model
|
||||
zh_Hans: 生成大纲的LLM
|
||||
options:
|
||||
- value: aippt
|
||||
label:
|
||||
en_US: AIPPT default model
|
||||
zh_Hans: AIPPT默认模型
|
||||
- value: wenxin
|
||||
label:
|
||||
en_US: Wenxin ErnieBot
|
||||
zh_Hans: 文心一言
|
||||
default: aippt
|
||||
human_description:
|
||||
en_US: The LLM model used for generating PPT outline.
|
||||
zh_Hans: 用于生成PPT大纲的LLM模型。
|
||||
form: form
|
||||
|
Before Width: | Height: | Size: 7.1 KiB |
@ -1,19 +0,0 @@
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AliYuqueProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
token = credentials.get("token")
|
||||
if not token:
|
||||
raise ToolProviderCredentialValidationError("token is required")
|
||||
|
||||
try:
|
||||
resp = AliYuqueTool.auth(token)
|
||||
if resp and resp.get("data", {}).get("id"):
|
||||
return
|
||||
|
||||
raise ToolProviderCredentialValidationError(resp)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,29 +0,0 @@
|
||||
identity:
|
||||
author: 佐井
|
||||
name: aliyuque
|
||||
label:
|
||||
en_US: yuque
|
||||
zh_Hans: 语雀
|
||||
pt_BR: yuque
|
||||
description:
|
||||
en_US: Yuque, https://www.yuque.com.
|
||||
zh_Hans: 语雀,https://www.yuque.com。
|
||||
pt_BR: Yuque, https://www.yuque.com.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- productivity
|
||||
- search
|
||||
credentials_for_provider:
|
||||
token:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Yuque Team Token
|
||||
zh_Hans: 语雀团队Token
|
||||
placeholder:
|
||||
en_US: Please input your Yuque team token
|
||||
zh_Hans: 请输入你的语雀团队Token
|
||||
help:
|
||||
en_US: Get Alibaba Yuque team token
|
||||
zh_Hans: 先获取语雀团队Token
|
||||
url: https://www.yuque.com/settings/tokens
|
||||
@ -1,42 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class AliYuqueTool:
|
||||
# yuque service url
|
||||
server_url = "https://www.yuque.com"
|
||||
|
||||
@staticmethod
|
||||
def auth(token):
|
||||
session = requests.Session()
|
||||
session.headers.update({"Accept": "application/json", "X-Auth-Token": token})
|
||||
login = session.request("GET", AliYuqueTool.server_url + "/api/v2/user")
|
||||
login.raise_for_status()
|
||||
resp = login.json()
|
||||
return resp
|
||||
|
||||
def request(self, method: str, token, tool_parameters: dict[str, Any], path: str) -> str:
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
session = requests.Session()
|
||||
session.headers.update({"accept": "application/json", "X-Auth-Token": token})
|
||||
new_params = {**tool_parameters}
|
||||
|
||||
replacements = {k: v for k, v in new_params.items() if f"{{{k}}}" in path}
|
||||
|
||||
for key, value in replacements.items():
|
||||
path = path.replace(f"{{{key}}}", str(value))
|
||||
del new_params[key]
|
||||
|
||||
if method.upper() in {"POST", "PUT"}:
|
||||
session.headers.update(
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
)
|
||||
response = session.request(method.upper(), self.server_url + path, json=new_params)
|
||||
else:
|
||||
response = session.request(method, self.server_url + path, params=new_params)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
@ -1,15 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AliYuqueCreateDocumentTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
return self.create_text_message(self.request("POST", token, tool_parameters, "/api/v2/repos/{book_id}/docs"))
|
||||
@ -1,99 +0,0 @@
|
||||
identity:
|
||||
name: aliyuque_create_document
|
||||
author: 佐井
|
||||
label:
|
||||
en_US: Create Document
|
||||
zh_Hans: 创建文档
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Creates a new document within a knowledge base without automatic addition to the table of contents. Requires a subsequent call to the "knowledge base directory update API". Supports setting visibility, format, and content. # 接口英文描述
|
||||
zh_Hans: 在知识库中创建新文档,但不会自动加入目录,需额外调用“知识库目录更新接口”。允许设置公开性、格式及正文内容。
|
||||
llm: Creates docs in a KB.
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Knowledge Base ID
|
||||
zh_Hans: 知识库ID
|
||||
human_description:
|
||||
en_US: The unique identifier of the knowledge base where the document will be created.
|
||||
zh_Hans: 文档将被创建的知识库的唯一标识。
|
||||
llm_description: ID of the target knowledge base.
|
||||
|
||||
- name: title
|
||||
type: string
|
||||
required: false
|
||||
form: llm
|
||||
label:
|
||||
en_US: Title
|
||||
zh_Hans: 标题
|
||||
human_description:
|
||||
en_US: The title of the document, defaults to 'Untitled' if not provided.
|
||||
zh_Hans: 文档标题,默认为'无标题'如未提供。
|
||||
llm_description: Title of the document, defaults to 'Untitled'.
|
||||
|
||||
- name: public
|
||||
type: select
|
||||
required: false
|
||||
form: llm
|
||||
options:
|
||||
- value: 0
|
||||
label:
|
||||
en_US: Private
|
||||
zh_Hans: 私密
|
||||
- value: 1
|
||||
label:
|
||||
en_US: Public
|
||||
zh_Hans: 公开
|
||||
- value: 2
|
||||
label:
|
||||
en_US: Enterprise-only
|
||||
zh_Hans: 企业内公开
|
||||
label:
|
||||
en_US: Visibility
|
||||
zh_Hans: 公开性
|
||||
human_description:
|
||||
en_US: Document visibility (0 Private, 1 Public, 2 Enterprise-only).
|
||||
zh_Hans: 文档可见性(0 私密, 1 公开, 2 企业内公开)。
|
||||
llm_description: Doc visibility options, 0-private, 1-public, 2-enterprise.
|
||||
|
||||
- name: format
|
||||
type: select
|
||||
required: false
|
||||
form: llm
|
||||
options:
|
||||
- value: markdown
|
||||
label:
|
||||
en_US: markdown
|
||||
zh_Hans: markdown
|
||||
- value: html
|
||||
label:
|
||||
en_US: html
|
||||
zh_Hans: html
|
||||
- value: lake
|
||||
label:
|
||||
en_US: lake
|
||||
zh_Hans: lake
|
||||
label:
|
||||
en_US: Content Format
|
||||
zh_Hans: 内容格式
|
||||
human_description:
|
||||
en_US: Format of the document content (markdown, HTML, Lake).
|
||||
zh_Hans: 文档内容格式(markdown, HTML, Lake)。
|
||||
llm_description: Content format choices, markdown, HTML, Lake.
|
||||
|
||||
- name: body
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Body Content
|
||||
zh_Hans: 正文内容
|
||||
human_description:
|
||||
en_US: The actual content of the document.
|
||||
zh_Hans: 文档的实际内容。
|
||||
llm_description: Content of the document.
|
||||
@ -1,17 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AliYuqueDeleteDocumentTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
return self.create_text_message(
|
||||
self.request("DELETE", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}")
|
||||
)
|
||||
@ -1,37 +0,0 @@
|
||||
identity:
|
||||
name: aliyuque_delete_document
|
||||
author: 佐井
|
||||
label:
|
||||
en_US: Delete Document
|
||||
zh_Hans: 删除文档
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Delete Document
|
||||
zh_Hans: 根据id删除文档
|
||||
llm: Delete document.
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Knowledge Base ID
|
||||
zh_Hans: 知识库ID
|
||||
human_description:
|
||||
en_US: The unique identifier of the knowledge base where the document will be created.
|
||||
zh_Hans: 文档将被创建的知识库的唯一标识。
|
||||
llm_description: ID of the target knowledge base.
|
||||
|
||||
- name: id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Document ID or Path
|
||||
zh_Hans: 文档 ID or 路径
|
||||
human_description:
|
||||
en_US: Document ID or path.
|
||||
zh_Hans: 文档 ID or 路径。
|
||||
llm_description: Document ID or path.
|
||||
@ -1,17 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AliYuqueDescribeBookIndexPageTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
return self.create_text_message(
|
||||
self.request("GET", token, tool_parameters, "/api/v2/repos/{group_login}/{book_slug}/index_page")
|
||||
)
|
||||
@ -1,38 +0,0 @@
|
||||
identity:
|
||||
name: aliyuque_describe_book_index_page
|
||||
author: 佐井
|
||||
label:
|
||||
en_US: Get Repo Index Page
|
||||
zh_Hans: 获取知识库首页
|
||||
icon: icon.svg
|
||||
|
||||
description:
|
||||
human:
|
||||
en_US: Retrieves the homepage of a knowledge base within a group, supporting both book ID and group login with book slug access.
|
||||
zh_Hans: 获取团队中知识库的首页信息,可通过书籍ID或团队登录名与书籍路径访问。
|
||||
llm: Fetches the knowledge base homepage using group and book identifiers with support for alternate access paths.
|
||||
|
||||
parameters:
|
||||
- name: group_login
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Group Login
|
||||
zh_Hans: 团队登录名
|
||||
human_description:
|
||||
en_US: The login name of the group that owns the knowledge base.
|
||||
zh_Hans: 拥有该知识库的团队登录名。
|
||||
llm_description: Team login identifier for the knowledge base owner.
|
||||
|
||||
- name: book_slug
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Book Slug
|
||||
zh_Hans: 知识库路径
|
||||
human_description:
|
||||
en_US: The unique slug representing the path of the knowledge base.
|
||||
zh_Hans: 知识库的唯一路径标识。
|
||||
llm_description: Unique path identifier for the knowledge base.
|
||||
@ -1,15 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> (Union)[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
return self.create_text_message(self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/toc"))
|
||||
@ -1,25 +0,0 @@
|
||||
identity:
|
||||
name: aliyuque_describe_book_table_of_contents
|
||||
author: 佐井
|
||||
label:
|
||||
en_US: Get Book's Table of Contents
|
||||
zh_Hans: 获取知识库的目录
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Get Book's Table of Contents.
|
||||
zh_Hans: 获取知识库的目录。
|
||||
llm: Get Book's Table of Contents.
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Book ID
|
||||
zh_Hans: 知识库 ID
|
||||
human_description:
|
||||
en_US: Book ID.
|
||||
zh_Hans: 知识库 ID。
|
||||
llm_description: Book ID.
|
||||
@ -1,53 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AliYuqueDescribeDocumentContentTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
new_params = {**tool_parameters}
|
||||
token = new_params.pop("token")
|
||||
if not token or token.lower() == "none":
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
new_params = {**tool_parameters}
|
||||
url = new_params.pop("url")
|
||||
if not url or not url.startswith("http"):
|
||||
raise Exception("url is not valid")
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
path_parts = parsed_url.path.strip("/").split("/")
|
||||
if len(path_parts) < 3:
|
||||
raise Exception("url is not correct")
|
||||
doc_id = path_parts[-1]
|
||||
book_slug = path_parts[-2]
|
||||
group_id = path_parts[-3]
|
||||
|
||||
new_params["group_login"] = group_id
|
||||
new_params["book_slug"] = book_slug
|
||||
index_page = json.loads(
|
||||
self.request("GET", token, new_params, "/api/v2/repos/{group_login}/{book_slug}/index_page")
|
||||
)
|
||||
book_id = index_page.get("data", {}).get("book", {}).get("id")
|
||||
if not book_id:
|
||||
raise Exception(f"can not parse book_id from {index_page}")
|
||||
|
||||
new_params["book_id"] = book_id
|
||||
new_params["id"] = doc_id
|
||||
data = self.request("GET", token, new_params, "/api/v2/repos/{book_id}/docs/{id}")
|
||||
data = json.loads(data)
|
||||
body_only = tool_parameters.get("body_only") or ""
|
||||
if body_only.lower() == "true":
|
||||
return self.create_text_message(data.get("data").get("body"))
|
||||
else:
|
||||
raw = data.get("data")
|
||||
del raw["body_lake"]
|
||||
del raw["body_html"]
|
||||
return self.create_text_message(json.dumps(data))
|
||||
@ -1,50 +0,0 @@
|
||||
identity:
|
||||
name: aliyuque_describe_document_content
|
||||
author: 佐井
|
||||
label:
|
||||
en_US: Fetch Document Content
|
||||
zh_Hans: 获取文档内容
|
||||
icon: icon.svg
|
||||
|
||||
description:
|
||||
human:
|
||||
en_US: Retrieves document content from Yuque based on the provided document URL, which can be a normal or shared link.
|
||||
zh_Hans: 根据提供的语雀文档地址(支持正常链接或分享链接)获取文档内容。
|
||||
llm: Fetches Yuque document content given a URL.
|
||||
|
||||
parameters:
|
||||
- name: url
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Document URL
|
||||
zh_Hans: 文档地址
|
||||
human_description:
|
||||
en_US: The URL of the document to retrieve content from, can be normal or shared.
|
||||
zh_Hans: 需要获取内容的文档地址,可以是正常链接或分享链接。
|
||||
llm_description: URL of the Yuque document to fetch content.
|
||||
|
||||
- name: body_only
|
||||
type: string
|
||||
required: false
|
||||
form: llm
|
||||
label:
|
||||
en_US: return body content only
|
||||
zh_Hans: 仅返回body内容
|
||||
human_description:
|
||||
en_US: true:Body content only, false:Full response with metadata.
|
||||
zh_Hans: true:仅返回body内容,不返回其他元数据,false:返回所有元数据。
|
||||
llm_description: true:Body content only, false:Full response with metadata.
|
||||
|
||||
- name: token
|
||||
type: secret-input
|
||||
required: false
|
||||
form: llm
|
||||
label:
|
||||
en_US: Yuque API Token
|
||||
zh_Hans: 语雀接口Token
|
||||
human_description:
|
||||
en_US: The token for calling the Yuque API defaults to the Yuque token bound to the current tool if not provided.
|
||||
zh_Hans: 调用语雀接口的token,如果不传则默认为当前工具绑定的语雀Token。
|
||||
llm_description: If the token for calling the Yuque API is not provided, it will default to the Yuque token bound to the current tool.
|
||||
@ -1,17 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AliYuqueDescribeDocumentsTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
return self.create_text_message(
|
||||
self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}")
|
||||
)
|
||||
@ -1,38 +0,0 @@
|
||||
identity:
|
||||
name: aliyuque_describe_documents
|
||||
author: 佐井
|
||||
label:
|
||||
en_US: Get Doc Detail
|
||||
zh_Hans: 获取文档详情
|
||||
icon: icon.svg
|
||||
|
||||
description:
|
||||
human:
|
||||
en_US: Retrieves detailed information of a specific document identified by its ID or path within a knowledge base.
|
||||
zh_Hans: 根据知识库ID和文档ID或路径获取文档详细信息。
|
||||
llm: Fetches detailed doc info using ID/path from a knowledge base; supports doc lookup in Yuque.
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Knowledge Base ID
|
||||
zh_Hans: 知识库 ID
|
||||
human_description:
|
||||
en_US: Identifier for the knowledge base where the document resides.
|
||||
zh_Hans: 文档所属知识库的唯一标识。
|
||||
llm_description: ID of the knowledge base holding the document.
|
||||
|
||||
- name: id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Document ID or Path
|
||||
zh_Hans: 文档 ID 或路径
|
||||
human_description:
|
||||
en_US: The unique identifier or path of the document to retrieve.
|
||||
zh_Hans: 需要获取的文档的ID或其在知识库中的路径。
|
||||
llm_description: Unique doc ID or its path for retrieval.
|
||||
@ -1,21 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> (Union)[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
|
||||
doc_ids = tool_parameters.get("doc_ids")
|
||||
if doc_ids:
|
||||
doc_ids = [int(doc_id.strip()) for doc_id in doc_ids.split(",")]
|
||||
tool_parameters["doc_ids"] = doc_ids
|
||||
|
||||
return self.create_text_message(self.request("PUT", token, tool_parameters, "/api/v2/repos/{book_id}/toc"))
|
||||
@ -1,222 +0,0 @@
|
||||
identity:
|
||||
name: aliyuque_update_book_table_of_contents
|
||||
author: 佐井
|
||||
label:
|
||||
en_US: Update Book's Table of Contents
|
||||
zh_Hans: 更新知识库目录
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Update Book's Table of Contents.
|
||||
zh_Hans: 更新知识库目录。
|
||||
llm: Update Book's Table of Contents.
|
||||
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Book ID
|
||||
zh_Hans: 知识库 ID
|
||||
human_description:
|
||||
en_US: Book ID.
|
||||
zh_Hans: 知识库 ID。
|
||||
llm_description: Book ID.
|
||||
|
||||
- name: action
|
||||
type: select
|
||||
required: true
|
||||
form: llm
|
||||
options:
|
||||
- value: appendNode
|
||||
label:
|
||||
en_US: appendNode
|
||||
zh_Hans: appendNode
|
||||
pt_BR: appendNode
|
||||
- value: prependNode
|
||||
label:
|
||||
en_US: prependNode
|
||||
zh_Hans: prependNode
|
||||
pt_BR: prependNode
|
||||
- value: editNode
|
||||
label:
|
||||
en_US: editNode
|
||||
zh_Hans: editNode
|
||||
pt_BR: editNode
|
||||
- value: editNode
|
||||
label:
|
||||
en_US: removeNode
|
||||
zh_Hans: removeNode
|
||||
pt_BR: removeNode
|
||||
label:
|
||||
en_US: Action Type
|
||||
zh_Hans: 操作
|
||||
human_description:
|
||||
en_US: In the operation scenario, sibling node prepending is not supported, deleting a node doesn't remove associated documents, and node deletion has two modes, 'sibling' (delete current node) and 'child' (delete current node and its children).
|
||||
zh_Hans: 操作,创建场景下不支持同级头插 prependNode,删除节点不会删除关联文档,删除节点时action_mode=sibling (删除当前节点), action_mode=child (删除当前节点及子节点)
|
||||
llm_description: In the operation scenario, sibling node prepending is not supported, deleting a node doesn't remove associated documents, and node deletion has two modes, 'sibling' (delete current node) and 'child' (delete current node and its children).
|
||||
|
||||
|
||||
- name: action_mode
|
||||
type: select
|
||||
required: false
|
||||
form: llm
|
||||
options:
|
||||
- value: sibling
|
||||
label:
|
||||
en_US: sibling
|
||||
zh_Hans: 同级
|
||||
pt_BR: sibling
|
||||
- value: child
|
||||
label:
|
||||
en_US: child
|
||||
zh_Hans: 子集
|
||||
pt_BR: child
|
||||
label:
|
||||
en_US: Action Type
|
||||
zh_Hans: 操作
|
||||
human_description:
|
||||
en_US: Operation mode (sibling:same level, child:child level).
|
||||
zh_Hans: 操作模式 (sibling:同级, child:子级)。
|
||||
llm_description: Operation mode (sibling:same level, child:child level).
|
||||
|
||||
- name: target_uuid
|
||||
type: string
|
||||
required: false
|
||||
form: llm
|
||||
label:
|
||||
en_US: Target node UUID
|
||||
zh_Hans: 目标节点 UUID
|
||||
human_description:
|
||||
en_US: Target node UUID, defaults to root node if left empty.
|
||||
zh_Hans: 目标节点 UUID, 不填默认为根节点。
|
||||
llm_description: Target node UUID, defaults to root node if left empty.
|
||||
|
||||
- name: node_uuid
|
||||
type: string
|
||||
required: false
|
||||
form: llm
|
||||
label:
|
||||
en_US: Node UUID
|
||||
zh_Hans: 操作节点 UUID
|
||||
human_description:
|
||||
en_US: Operation node UUID [required for move/update/delete].
|
||||
zh_Hans: 操作节点 UUID [移动/更新/删除必填]。
|
||||
llm_description: Operation node UUID [required for move/update/delete].
|
||||
|
||||
- name: doc_ids
|
||||
type: string
|
||||
required: false
|
||||
form: llm
|
||||
label:
|
||||
en_US: Document IDs
|
||||
zh_Hans: 文档id列表
|
||||
human_description:
|
||||
en_US: Document IDs [required for creating documents], separate multiple IDs with ','.
|
||||
zh_Hans: 文档 IDs [创建文档必填],多个用','分隔。
|
||||
llm_description: Document IDs [required for creating documents], separate multiple IDs with ','.
|
||||
|
||||
|
||||
- name: type
|
||||
type: select
|
||||
required: false
|
||||
form: llm
|
||||
default: DOC
|
||||
options:
|
||||
- value: DOC
|
||||
label:
|
||||
en_US: DOC
|
||||
zh_Hans: 文档
|
||||
pt_BR: DOC
|
||||
- value: LINK
|
||||
label:
|
||||
en_US: LINK
|
||||
zh_Hans: 链接
|
||||
pt_BR: LINK
|
||||
- value: TITLE
|
||||
label:
|
||||
en_US: TITLE
|
||||
zh_Hans: 分组
|
||||
pt_BR: TITLE
|
||||
label:
|
||||
en_US: Node type
|
||||
zh_Hans: 操节点类型
|
||||
human_description:
|
||||
en_US: Node type [required for creation] (DOC:document, LINK:external link, TITLE:group).
|
||||
zh_Hans: 操节点类型 [创建必填] (DOC:文档, LINK:外链, TITLE:分组)。
|
||||
llm_description: Node type [required for creation] (DOC:document, LINK:external link, TITLE:group).
|
||||
|
||||
- name: title
|
||||
type: string
|
||||
required: false
|
||||
form: llm
|
||||
label:
|
||||
en_US: Node Name
|
||||
zh_Hans: 节点名称
|
||||
human_description:
|
||||
en_US: Node name [required for creating groups/external links].
|
||||
zh_Hans: 节点名称 [创建分组/外链必填]。
|
||||
llm_description: Node name [required for creating groups/external links].
|
||||
|
||||
- name: url
|
||||
type: string
|
||||
required: false
|
||||
form: llm
|
||||
label:
|
||||
en_US: Node URL
|
||||
zh_Hans: 节点URL
|
||||
human_description:
|
||||
en_US: Node URL [required for creating external links].
|
||||
zh_Hans: 节点 URL [创建外链必填]。
|
||||
llm_description: Node URL [required for creating external links].
|
||||
|
||||
|
||||
- name: open_window
|
||||
type: select
|
||||
required: false
|
||||
form: llm
|
||||
default: 0
|
||||
options:
|
||||
- value: 0
|
||||
label:
|
||||
en_US: DOC
|
||||
zh_Hans: Current Page
|
||||
pt_BR: DOC
|
||||
- value: 1
|
||||
label:
|
||||
en_US: LINK
|
||||
zh_Hans: New Page
|
||||
pt_BR: LINK
|
||||
label:
|
||||
en_US: Open in new window
|
||||
zh_Hans: 是否新窗口打开
|
||||
human_description:
|
||||
en_US: Open in new window [optional for external links] (0:open in current page, 1:open in new window).
|
||||
zh_Hans: 是否新窗口打开 [外链选填] (0:当前页打开, 1:新窗口打开)。
|
||||
llm_description: Open in new window [optional for external links] (0:open in current page, 1:open in new window).
|
||||
|
||||
|
||||
- name: visible
|
||||
type: select
|
||||
required: false
|
||||
form: llm
|
||||
default: 1
|
||||
options:
|
||||
- value: 0
|
||||
label:
|
||||
en_US: Invisible
|
||||
zh_Hans: 隐藏
|
||||
pt_BR: Invisible
|
||||
- value: 1
|
||||
label:
|
||||
en_US: Visible
|
||||
zh_Hans: 可见
|
||||
pt_BR: Visible
|
||||
label:
|
||||
en_US: Visibility
|
||||
zh_Hans: 是否可见
|
||||
human_description:
|
||||
en_US: Visibility (0:invisible, 1:visible).
|
||||
zh_Hans: 是否可见 (0:不可见, 1:可见)。
|
||||
llm_description: Visibility (0:invisible, 1:visible).
|
||||
@ -1,17 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AliYuqueUpdateDocumentTool(AliYuqueTool, BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
token = self.runtime.credentials.get("token", None)
|
||||
if not token:
|
||||
raise Exception("token is required")
|
||||
return self.create_text_message(
|
||||
self.request("PUT", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}")
|
||||
)
|
||||
@ -1,87 +0,0 @@
|
||||
identity:
|
||||
name: aliyuque_update_document
|
||||
author: 佐井
|
||||
label:
|
||||
en_US: Update Document
|
||||
zh_Hans: 更新文档
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Update an existing document within a specified knowledge base by providing the document ID or path.
|
||||
zh_Hans: 通过提供文档ID或路径,更新指定知识库中的现有文档。
|
||||
llm: Update doc in a knowledge base via ID/path.
|
||||
parameters:
|
||||
- name: book_id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Knowledge Base ID
|
||||
zh_Hans: 知识库 ID
|
||||
human_description:
|
||||
en_US: The unique identifier of the knowledge base where the document resides.
|
||||
zh_Hans: 文档所属知识库的ID。
|
||||
llm_description: ID of the knowledge base holding the doc.
|
||||
- name: id
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Document ID or Path
|
||||
zh_Hans: 文档 ID 或 路径
|
||||
human_description:
|
||||
en_US: The unique identifier or the path of the document to be updated.
|
||||
zh_Hans: 要更新的文档的唯一ID或路径。
|
||||
llm_description: Doc's ID or path for update.
|
||||
|
||||
- name: title
|
||||
type: string
|
||||
required: false
|
||||
form: llm
|
||||
label:
|
||||
en_US: Title
|
||||
zh_Hans: 标题
|
||||
human_description:
|
||||
en_US: The title of the document, defaults to 'Untitled' if not provided.
|
||||
zh_Hans: 文档标题,默认为'无标题'如未提供。
|
||||
llm_description: Title of the document, defaults to 'Untitled'.
|
||||
|
||||
- name: format
|
||||
type: select
|
||||
required: false
|
||||
form: llm
|
||||
options:
|
||||
- value: markdown
|
||||
label:
|
||||
en_US: markdown
|
||||
zh_Hans: markdown
|
||||
pt_BR: markdown
|
||||
- value: html
|
||||
label:
|
||||
en_US: html
|
||||
zh_Hans: html
|
||||
pt_BR: html
|
||||
- value: lake
|
||||
label:
|
||||
en_US: lake
|
||||
zh_Hans: lake
|
||||
pt_BR: lake
|
||||
label:
|
||||
en_US: Content Format
|
||||
zh_Hans: 内容格式
|
||||
human_description:
|
||||
en_US: Format of the document content (markdown, HTML, Lake).
|
||||
zh_Hans: 文档内容格式(markdown, HTML, Lake)。
|
||||
llm_description: Content format choices, markdown, HTML, Lake.
|
||||
|
||||
- name: body
|
||||
type: string
|
||||
required: true
|
||||
form: llm
|
||||
label:
|
||||
en_US: Body Content
|
||||
zh_Hans: 正文内容
|
||||
human_description:
|
||||
en_US: The actual content of the document.
|
||||
zh_Hans: 文档的实际内容。
|
||||
llm_description: Content of the document.
|
||||
@ -1,7 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="56px" height="56px" viewBox="0 0 56 56" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>形状结合</title>
|
||||
<g id="设计规范" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<path d="M56,0 L56,56 L0,56 L0,0 L56,0 Z M31.6063018,12 L24.3936982,12 L24.1061064,12.7425499 L12.6071308,42.4324141 L12,44 L19.7849972,44 L20.0648488,43.2391815 L22.5196173,36.5567427 L33.4780427,36.5567427 L35.9351512,43.2391815 L36.2150028,44 L44,44 L43.3928692,42.4324141 L31.8938936,12.7425499 L31.6063018,12 Z M28.0163803,21.5755126 L31.1613993,30.2523823 L24.8432808,30.2523823 L28.0163803,21.5755126 Z" id="形状结合" fill="#2F4F4F"></path>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 780 B |
@ -1,22 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.alphavantage.tools.query_stock import QueryStockTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AlphaVantageProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
QueryStockTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"code": "AAPL", # Apple Inc.
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,31 +0,0 @@
|
||||
identity:
|
||||
author: zhuhao
|
||||
name: alphavantage
|
||||
label:
|
||||
en_US: AlphaVantage
|
||||
zh_Hans: AlphaVantage
|
||||
pt_BR: AlphaVantage
|
||||
description:
|
||||
en_US: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis.
|
||||
zh_Hans: AlphaVantage是一个在线平台,它提供金融市场数据和API,便于个人投资者和开发者获取股票报价、技术指标和股票分析。
|
||||
pt_BR: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- finance
|
||||
credentials_for_provider:
|
||||
api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: AlphaVantage API key
|
||||
zh_Hans: AlphaVantage API key
|
||||
pt_BR: AlphaVantage API key
|
||||
placeholder:
|
||||
en_US: Please input your AlphaVantage API key
|
||||
zh_Hans: 请输入你的 AlphaVantage API key
|
||||
pt_BR: Please input your AlphaVantage API key
|
||||
help:
|
||||
en_US: Get your AlphaVantage API key from AlphaVantage
|
||||
zh_Hans: 从 AlphaVantage 获取您的 AlphaVantage API key
|
||||
pt_BR: Get your AlphaVantage API key from AlphaVantage
|
||||
url: https://www.alphavantage.co/support/#api-key
|
||||
@ -1,48 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query"
|
||||
|
||||
|
||||
class QueryStockTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
stock_code = tool_parameters.get("code", "")
|
||||
if not stock_code:
|
||||
return self.create_text_message("Please tell me your stock code")
|
||||
|
||||
if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
|
||||
return self.create_text_message("Alpha Vantage API key is required.")
|
||||
|
||||
params = {
|
||||
"function": "TIME_SERIES_DAILY",
|
||||
"symbol": stock_code,
|
||||
"outputsize": "compact",
|
||||
"datatype": "json",
|
||||
"apikey": self.runtime.credentials["api_key"],
|
||||
}
|
||||
response = requests.get(url=ALPHAVANTAGE_API_URL, params=params)
|
||||
response.raise_for_status()
|
||||
result = self._handle_response(response.json())
|
||||
return self.create_json_message(result)
|
||||
|
||||
def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]:
|
||||
result = response.get("Time Series (Daily)", {})
|
||||
if not result:
|
||||
return {}
|
||||
stock_result = {}
|
||||
for k, v in result.items():
|
||||
stock_result[k] = {}
|
||||
stock_result[k]["open"] = v.get("1. open")
|
||||
stock_result[k]["high"] = v.get("2. high")
|
||||
stock_result[k]["low"] = v.get("3. low")
|
||||
stock_result[k]["close"] = v.get("4. close")
|
||||
stock_result[k]["volume"] = v.get("5. volume")
|
||||
return stock_result
|
||||
@ -1,27 +0,0 @@
|
||||
identity:
|
||||
name: query_stock
|
||||
author: zhuhao
|
||||
label:
|
||||
en_US: query_stock
|
||||
zh_Hans: query_stock
|
||||
pt_BR: query_stock
|
||||
description:
|
||||
human:
|
||||
en_US: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol.
|
||||
zh_Hans: 获取指定股票代码的每日开盘价、每日最高价、每日最低价、每日收盘价和每日交易量等信息。
|
||||
pt_BR: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol
|
||||
llm: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol
|
||||
parameters:
|
||||
- name: code
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: stock code
|
||||
zh_Hans: 股票代码
|
||||
pt_BR: stock code
|
||||
human_description:
|
||||
en_US: stock code
|
||||
zh_Hans: 股票代码
|
||||
pt_BR: stock code
|
||||
llm_description: stock code for query from alphavantage
|
||||
form: llm
|
||||
@ -1 +0,0 @@
|
||||
<svg id="logomark" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 17.732 24.269"><g id="tiny"><path d="M573.549,280.916l2.266,2.738,6.674-7.84c.353-.47.52-.717.353-1.117a1.218,1.218,0,0,0-1.061-.748h0a.953.953,0,0,0-.712.262Z" transform="translate(-566.984 -271.548)" fill="#bdb9b4"/><path d="M579.525,282.225l-10.606-10.174a1.413,1.413,0,0,0-.834-.5,1.09,1.09,0,0,0-1.027.66c-.167.4-.047.681.319,1.206l8.44,10.242h0l-6.282,7.716a1.336,1.336,0,0,0-.323,1.3,1.114,1.114,0,0,0,1.04.69A.992.992,0,0,0,571,293l8.519-7.92A1.924,1.924,0,0,0,579.525,282.225Z" transform="translate(-566.984 -271.548)" fill="#b31b1b"/><path d="M584.32,293.912l-8.525-10.275,0,0L573.53,280.9l-1.389,1.254a2.063,2.063,0,0,0,0,2.965l10.812,10.419a.925.925,0,0,0,.742.282,1.039,1.039,0,0,0,.953-.667A1.261,1.261,0,0,0,584.32,293.912Z" transform="translate(-566.984 -271.548)" fill="#bdb9b4"/></g></svg>
|
||||
|
Before Width: | Height: | Size: 874 B |
@ -1,20 +0,0 @@
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.arxiv.tools.arxiv_search import ArxivSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class ArxivProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
ArxivSearchTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"query": "John Doe",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,14 +0,0 @@
|
||||
identity:
|
||||
author: Yash Parmar
|
||||
name: arxiv
|
||||
label:
|
||||
en_US: ArXiv
|
||||
zh_Hans: ArXiv
|
||||
ja_JP: ArXiv
|
||||
description:
|
||||
en_US: Access to a vast repository of scientific papers and articles in various fields of research.
|
||||
zh_Hans: 访问各个研究领域大量科学论文和文章的存储库。
|
||||
ja_JP: 多様な研究分野の科学論文や記事の膨大なリポジトリへのアクセス。
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
@ -1,119 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import arxiv # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ArxivAPIWrapper(BaseModel):
|
||||
"""Wrapper around ArxivAPI.
|
||||
|
||||
To use, you should have the ``arxiv`` python package installed.
|
||||
https://lukasschwab.me/arxiv.py/index.html
|
||||
This wrapper will use the Arxiv API to conduct searches and
|
||||
fetch document summaries. By default, it will return the document summaries
|
||||
of the top-k results.
|
||||
It limits the Document content by doc_content_chars_max.
|
||||
Set doc_content_chars_max=None if you don't want to limit the content size.
|
||||
|
||||
Args:
|
||||
top_k_results: number of the top-scored document used for the arxiv tool
|
||||
ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
|
||||
load_max_docs: a limit to the number of loaded documents
|
||||
load_all_available_meta:
|
||||
if True: the `metadata` of the loaded Documents contains all available
|
||||
meta info (see https://lukasschwab.me/arxiv.py/index.html#Result),
|
||||
if False: the `metadata` contains only the published date, title,
|
||||
authors and summary.
|
||||
doc_content_chars_max: an optional cut limit for the length of a document's
|
||||
content
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
arxiv = ArxivAPIWrapper(
|
||||
top_k_results = 3,
|
||||
ARXIV_MAX_QUERY_LENGTH = 300,
|
||||
load_max_docs = 3,
|
||||
load_all_available_meta = False,
|
||||
doc_content_chars_max = 40000
|
||||
)
|
||||
arxiv.run("tree of thought llm)
|
||||
"""
|
||||
|
||||
arxiv_search: type[arxiv.Search] = arxiv.Search #: :meta private:
|
||||
arxiv_http_error: tuple[type[Exception]] = (arxiv.ArxivError, arxiv.UnexpectedEmptyPageError, arxiv.HTTPError)
|
||||
top_k_results: int = 3
|
||||
ARXIV_MAX_QUERY_LENGTH: int = 300
|
||||
load_max_docs: int = 100
|
||||
load_all_available_meta: bool = False
|
||||
doc_content_chars_max: Optional[int] = 4000
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""
|
||||
Performs an arxiv search and A single string
|
||||
with the publish date, title, authors, and summary
|
||||
for each article separated by two newlines.
|
||||
|
||||
If an error occurs or no documents found, error text
|
||||
is returned instead. Wrapper for
|
||||
https://lukasschwab.me/arxiv.py/index.html#Search
|
||||
|
||||
Args:
|
||||
query: a plaintext search query
|
||||
"""
|
||||
try:
|
||||
results = self.arxiv_search( # type: ignore
|
||||
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
|
||||
).results()
|
||||
except arxiv_http_error as ex:
|
||||
return f"Arxiv exception: {ex}"
|
||||
docs = [
|
||||
f"Published: {result.updated.date()}\n"
|
||||
f"Title: {result.title}\n"
|
||||
f"Authors: {', '.join(a.name for a in result.authors)}\n"
|
||||
f"Summary: {result.summary}"
|
||||
for result in results
|
||||
]
|
||||
if docs:
|
||||
return "\n\n".join(docs)[: self.doc_content_chars_max]
|
||||
else:
|
||||
return "No good Arxiv Result was found"
|
||||
|
||||
|
||||
class ArxivSearchInput(BaseModel):
|
||||
query: str = Field(..., description="Search query.")
|
||||
|
||||
|
||||
class ArxivSearchTool(BuiltinTool):
|
||||
"""
|
||||
A tool for searching articles on Arxiv.
|
||||
"""
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invokes the Arxiv search tool with the given user ID and tool parameters.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user invoking the tool.
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool, including the 'query' parameter.
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
|
||||
which can be a single message or a list of messages.
|
||||
"""
|
||||
query = tool_parameters.get("query", "")
|
||||
|
||||
if not query:
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
arxiv = ArxivAPIWrapper()
|
||||
|
||||
response = arxiv.run(query)
|
||||
|
||||
return self.create_text_message(self.summary(user_id=user_id, content=response))
|
||||
@ -1,27 +0,0 @@
|
||||
identity:
|
||||
name: arxiv_search
|
||||
author: Yash Parmar
|
||||
label:
|
||||
en_US: Arxiv Search
|
||||
zh_Hans: Arxiv 搜索
|
||||
ja_JP: Arxiv 検索
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for searching scientific papers and articles from the Arxiv repository. Input can be an Arxiv ID or an author's name.
|
||||
zh_Hans: 一个用于从Arxiv存储库搜索科学论文和文章的工具。 输入可以是Arxiv ID或作者姓名。
|
||||
ja_JP: Arxivリポジトリから科学論文や記事を検索するためのツールです。入力はArxiv IDまたは著者名にすることができます。
|
||||
llm: A tool for searching scientific papers and articles from the Arxiv repository. Input can be an Arxiv ID or an author's name.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询字符串
|
||||
ja_JP: クエリ文字列
|
||||
human_description:
|
||||
en_US: The Arxiv ID or author's name used for searching.
|
||||
zh_Hans: 用于搜索的Arxiv ID或作者姓名。
|
||||
ja_JP: 検索に使用されるArxiv IDまたは著者名。
|
||||
llm_description: The Arxiv ID or author's name used for searching.
|
||||
form: llm
|
||||
@ -1,6 +0,0 @@
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AudioToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
pass
|
||||
@ -1,9 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
|
||||
<!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
|
||||
<svg width="800px" height="800px" viewBox="0 0 16 16" xmlns="http://www.w3.org/2000/svg" fill="none">
|
||||
|
||||
<path fill="#252F3E" d="M4.51 7.687c0 .197.02.357.058.475.042.117.096.245.17.384a.233.233 0 01.037.123c0 .053-.032.107-.1.16l-.336.224a.255.255 0 01-.138.048c-.054 0-.107-.026-.16-.074a1.652 1.652 0 01-.192-.251 4.137 4.137 0 01-.165-.315c-.415.491-.936.737-1.564.737-.447 0-.804-.129-1.064-.385-.261-.256-.394-.598-.394-1.025 0-.454.16-.822.484-1.1.325-.278.756-.416 1.304-.416.18 0 .367.016.564.042.197.027.4.07.612.118v-.39c0-.406-.085-.689-.25-.854-.17-.166-.458-.246-.868-.246-.186 0-.377.022-.574.07a4.23 4.23 0 00-.575.181 1.525 1.525 0 01-.186.07.326.326 0 01-.085.016c-.075 0-.112-.054-.112-.166v-.262c0-.085.01-.15.037-.186a.399.399 0 01.15-.113c.185-.096.409-.176.67-.24.26-.07.537-.101.83-.101.633 0 1.096.144 1.394.432.293.288.442.726.442 1.314v1.73h.01zm-2.161.811c.175 0 .356-.032.548-.096.191-.064.362-.182.505-.342a.848.848 0 00.181-.341c.032-.129.054-.283.054-.465V7.03a4.43 4.43 0 00-.49-.09 3.996 3.996 0 00-.5-.033c-.357 0-.618.07-.793.214-.176.144-.26.347-.26.614 0 .25.063.437.196.566.128.133.314.197.559.197zm4.273.577c-.096 0-.16-.016-.202-.054-.043-.032-.08-.106-.112-.208l-1.25-4.127a.938.938 0 01-.049-.214c0-.085.043-.133.128-.133h.522c.1 0 .17.016.207.053.043.032.075.107.107.208l.894 3.535.83-3.535c.026-.106.058-.176.1-.208a.365.365 0 01.214-.053h.425c.102 0 .17.016.213.053.043.032.08.107.101.208l.841 3.578.92-3.578a.458.458 0 01.107-.208.346.346 0 01.208-.053h.495c.085 0 .133.043.133.133 0 .027-.006.054-.01.086a.76.76 0 01-.038.133l-1.283 4.127c-.032.107-.069.177-.111.209a.34.34 0 01-.203.053h-.457c-.101 0-.17-.016-.213-.053-.043-.038-.08-.107-.101-.214L8.213 5.37l-.82 3.439c-.026.107-.058.176-.1.213-.043.038-.118.054-.213.054h-.458zm6.838.144a3.51 3.51 0 01-.82-.096c-.266-.064-.473-.134-.612-.214-.085-.048-.143-.101-.165-.15a.378.378 0 01-.031-.149v-.272c0-.112.042-.166.122-.166a.3.3 0 01.096.016c.032.011.08.032.133.054.18.08.378.144.585.187.213.042.42.064.633.064.336 0 .596-.059.777-.176a.575.575 0 00.277-.508.52.52 0 00-.144-.373c-.095-.102-.276-.193-.537-.278l-.772-.24c-.388-.123-.676-.305-.851-.545a1.275 1.275 0 01-.266-.774c0-.224.048-.422.143-.593.096-.17.224-.32.384-.438.16-.122.34-.213.553-.277.213-.064.436-.091.67-.091.118 0 .24.005.357.021.122.016.234.038.346.06.106.026.208.052.303.085.096.032.17.064.224.096a.46.46 0 01.16.133.289.289 0 01.047.176v.251c0 .112-.042.171-.122.171a.552.552 0 01-.202-.064 2.427 2.427 0 00-1.022-.208c-.303 0-.543.048-.708.15-.165.1-.25.256-.25.475 0 .149.053.277.16.379.106.101.303.202.585.293l.756.24c.383.123.66.294.825.513.165.219.244.47.244.748 0 .23-.047.437-.138.619a1.436 1.436 0 01-.388.47c-.165.133-.362.23-.591.299-.24.075-.49.112-.761.112z"/>
|
||||
|
||||
<g fill="#F90" fill-rule="evenodd" clip-rule="evenodd">
|
||||
|
||||
|
Before Width: | Height: | Size: 3.3 KiB |
@ -1,24 +0,0 @@
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.aws.tools.sagemaker_text_rerank import SageMakerReRankTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class SageMakerProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
SageMakerReRankTool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"sagemaker_endpoint": "",
|
||||
"query": "misaka mikoto",
|
||||
"candidate_texts": "hello$$$hello world",
|
||||
"topk": 5,
|
||||
"aws_region": "",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -1,15 +0,0 @@
|
||||
identity:
|
||||
author: AWS
|
||||
name: aws
|
||||
label:
|
||||
en_US: AWS
|
||||
zh_Hans: 亚马逊云科技
|
||||
pt_BR: AWS
|
||||
description:
|
||||
en_US: Services on AWS.
|
||||
zh_Hans: 亚马逊云科技的各类服务
|
||||
pt_BR: Services on AWS.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
credentials_for_provider:
|
||||
@ -1,90 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3 # type: ignore
|
||||
from botocore.exceptions import BotoCoreError # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuardrailParameters(BaseModel):
|
||||
guardrail_id: str = Field(..., description="The identifier of the guardrail")
|
||||
guardrail_version: str = Field(..., description="The version of the guardrail")
|
||||
source: str = Field(..., description="The source of the content")
|
||||
text: str = Field(..., description="The text to apply the guardrail to")
|
||||
aws_region: str = Field(..., description="AWS region for the Bedrock client")
|
||||
|
||||
|
||||
class ApplyGuardrailTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke the ApplyGuardrail tool
|
||||
"""
|
||||
try:
|
||||
# Validate and parse input parameters
|
||||
params = GuardrailParameters(**tool_parameters)
|
||||
|
||||
# Initialize AWS client
|
||||
bedrock_client = boto3.client("bedrock-runtime", region_name=params.aws_region)
|
||||
|
||||
# Apply guardrail
|
||||
response = bedrock_client.apply_guardrail(
|
||||
guardrailIdentifier=params.guardrail_id,
|
||||
guardrailVersion=params.guardrail_version,
|
||||
source=params.source,
|
||||
content=[{"text": {"text": params.text}}],
|
||||
)
|
||||
|
||||
logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}")
|
||||
|
||||
# Check for empty response
|
||||
if not response:
|
||||
return self.create_text_message(text="Received empty response from AWS Bedrock.")
|
||||
|
||||
# Process the result
|
||||
action = response.get("action", "No action specified")
|
||||
outputs = response.get("outputs", [])
|
||||
output = outputs[0].get("text", "No output received") if outputs else "No output received"
|
||||
assessments = response.get("assessments", [])
|
||||
|
||||
# Format assessments
|
||||
formatted_assessments = []
|
||||
for assessment in assessments:
|
||||
for policy_type, policy_data in assessment.items():
|
||||
if isinstance(policy_data, dict) and "topics" in policy_data:
|
||||
for topic in policy_data["topics"]:
|
||||
formatted_assessments.append(
|
||||
f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']},"
|
||||
f" Action: {topic['action']}"
|
||||
)
|
||||
else:
|
||||
formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}")
|
||||
|
||||
result = f"Action: {action}\n "
|
||||
result += f"Output: {output}\n "
|
||||
if formatted_assessments:
|
||||
result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n "
|
||||
# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except BotoCoreError as e:
|
||||
error_message = f"AWS service error: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
except json.JSONDecodeError as e:
|
||||
error_message = f"JSON parsing error: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
except Exception as e:
|
||||
error_message = f"An unexpected error occurred: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
@ -1,67 +0,0 @@
|
||||
identity:
|
||||
name: apply_guardrail
|
||||
author: AWS
|
||||
label:
|
||||
en_US: Content Moderation Guardrails
|
||||
zh_Hans: 内容审查护栏
|
||||
description:
|
||||
human:
|
||||
en_US: Content Moderation Guardrails utilizes the ApplyGuardrail API, a feature of Guardrails for Amazon Bedrock. This API is capable of evaluating input prompts and model responses for all Foundation Models (FMs), including those on Amazon Bedrock, custom FMs, and third-party FMs. By implementing this functionality, organizations can achieve centralized governance across all their generative AI applications, thereby enhancing control and consistency in content moderation.
|
||||
zh_Hans: 内容审查护栏采用 Guardrails for Amazon Bedrock 功能中的 ApplyGuardrail API 。ApplyGuardrail 可以评估所有基础模型(FMs)的输入提示和模型响应,包括 Amazon Bedrock 上的 FMs、自定义 FMs 和第三方 FMs。通过实施这一功能, 组织可以在所有生成式 AI 应用程序中实现集中化的治理,从而增强内容审核的控制力和一致性。
|
||||
llm: Content Moderation Guardrails utilizes the ApplyGuardrail API, a feature of Guardrails for Amazon Bedrock. This API is capable of evaluating input prompts and model responses for all Foundation Models (FMs), including those on Amazon Bedrock, custom FMs, and third-party FMs. By implementing this functionality, organizations can achieve centralized governance across all their generative AI applications, thereby enhancing control and consistency in content moderation.
|
||||
parameters:
|
||||
- name: guardrail_id
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Guardrail ID
|
||||
zh_Hans: Guardrail ID
|
||||
human_description:
|
||||
en_US: Please enter the ID of the Guardrail that has already been created on Amazon Bedrock, for example 'qk5nk0e4b77b'.
|
||||
zh_Hans: 请输入已经在 Amazon Bedrock 上创建好的 Guardrail ID, 例如 'qk5nk0e4b77b'.
|
||||
llm_description: Please enter the ID of the Guardrail that has already been created on Amazon Bedrock, for example 'qk5nk0e4b77b'.
|
||||
form: form
|
||||
- name: guardrail_version
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Guardrail Version Number
|
||||
zh_Hans: Guardrail 版本号码
|
||||
human_description:
|
||||
en_US: Please enter the published version of the Guardrail ID that has already been created on Amazon Bedrock. This is typically a version number, such as 2.
|
||||
zh_Hans: 请输入已经在Amazon Bedrock 上创建好的Guardrail ID发布的版本, 通常使用版本号, 例如2.
|
||||
llm_description: Please enter the published version of the Guardrail ID that has already been created on Amazon Bedrock. This is typically a version number, such as 2.
|
||||
form: form
|
||||
- name: source
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Content Source (INPUT or OUTPUT)
|
||||
zh_Hans: 内容来源 (INPUT or OUTPUT)
|
||||
human_description:
|
||||
en_US: The source of data used in the request to apply the guardrail. Valid Values "INPUT | OUTPUT"
|
||||
zh_Hans: 用于应用护栏的请求中所使用的数据来源。有效值为 "INPUT | OUTPUT"
|
||||
llm_description: The source of data used in the request to apply the guardrail. Valid Values "INPUT | OUTPUT"
|
||||
form: form
|
||||
- name: text
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Content to be reviewed
|
||||
zh_Hans: 待审查内容
|
||||
human_description:
|
||||
en_US: The content used for requesting guardrail review, which can be either user input or LLM output.
|
||||
zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。
|
||||
llm_description: The content used for requesting guardrail review, which can be either user input or LLM output.
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
human_description:
|
||||
en_US: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
|
||||
zh_Hans: 请输入 Bedrock 客户端的 AWS 区域,例如 'us-east-1'。
|
||||
llm_description: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
|
||||
form: form
|
||||
@ -1,162 +0,0 @@
|
||||
import json
|
||||
import operator
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class BedrockRetrieveTool(BuiltinTool):
|
||||
bedrock_client: Any = None
|
||||
knowledge_base_id: str = None
|
||||
topk: int = None
|
||||
|
||||
def _bedrock_retrieve(
|
||||
self,
|
||||
query_input: str,
|
||||
knowledge_base_id: str,
|
||||
num_results: int,
|
||||
search_type: str,
|
||||
rerank_model_id: str,
|
||||
metadata_filter: Optional[dict] = None,
|
||||
):
|
||||
try:
|
||||
retrieval_query = {"text": query_input}
|
||||
|
||||
if search_type not in ["HYBRID", "SEMANTIC"]:
|
||||
raise RuntimeException("search_type should be HYBRID or SEMANTIC")
|
||||
|
||||
retrieval_configuration = {
|
||||
"vectorSearchConfiguration": {"numberOfResults": num_results, "overrideSearchType": search_type}
|
||||
}
|
||||
|
||||
if rerank_model_id != "default":
|
||||
model_for_rerank_arn = f"arn:aws:bedrock:us-west-2::foundation-model/{rerank_model_id}"
|
||||
rerankingConfiguration = {
|
||||
"bedrockRerankingConfiguration": {
|
||||
"numberOfRerankedResults": num_results,
|
||||
"modelConfiguration": {"modelArn": model_for_rerank_arn},
|
||||
},
|
||||
"type": "BEDROCK_RERANKING_MODEL",
|
||||
}
|
||||
|
||||
retrieval_configuration["vectorSearchConfiguration"]["rerankingConfiguration"] = rerankingConfiguration
|
||||
retrieval_configuration["vectorSearchConfiguration"]["numberOfResults"] = num_results * 5
|
||||
|
||||
# 如果有元数据过滤条件,则添加到检索配置中
|
||||
if metadata_filter:
|
||||
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
|
||||
|
||||
response = self.bedrock_client.retrieve(
|
||||
knowledgeBaseId=knowledge_base_id,
|
||||
retrievalQuery=retrieval_query,
|
||||
retrievalConfiguration=retrieval_configuration,
|
||||
)
|
||||
|
||||
results = []
|
||||
for result in response.get("retrievalResults", []):
|
||||
results.append(
|
||||
{
|
||||
"content": result.get("content", {}).get("text", ""),
|
||||
"score": result.get("score", 0.0),
|
||||
"metadata": result.get("metadata", {}),
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
raise Exception(f"Error retrieving from knowledge base: {str(e)}")
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
line = 0
|
||||
# Initialize Bedrock client if not already initialized
|
||||
if not self.bedrock_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
aws_access_key_id = tool_parameters.get("aws_access_key_id")
|
||||
aws_secret_access_key = tool_parameters.get("aws_secret_access_key")
|
||||
|
||||
client_kwargs = {"service_name": "bedrock-agent-runtime", "region_name": aws_region or None}
|
||||
|
||||
# Only add credentials if both access key and secret key are provided
|
||||
if aws_access_key_id and aws_secret_access_key:
|
||||
client_kwargs.update(
|
||||
{"aws_access_key_id": aws_access_key_id, "aws_secret_access_key": aws_secret_access_key}
|
||||
)
|
||||
|
||||
self.bedrock_client = boto3.client(**client_kwargs)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to initialize Bedrock client: {str(e)}")
|
||||
|
||||
try:
|
||||
line = 1
|
||||
if not self.knowledge_base_id:
|
||||
self.knowledge_base_id = tool_parameters.get("knowledge_base_id")
|
||||
if not self.knowledge_base_id:
|
||||
return self.create_text_message("Please provide knowledge_base_id")
|
||||
|
||||
line = 2
|
||||
if not self.topk:
|
||||
self.topk = tool_parameters.get("topk", 5)
|
||||
|
||||
line = 3
|
||||
query = tool_parameters.get("query", "")
|
||||
if not query:
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
# 获取元数据过滤条件(如果存在)
|
||||
metadata_filter_str = tool_parameters.get("metadata_filter")
|
||||
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
|
||||
|
||||
search_type = tool_parameters.get("search_type")
|
||||
rerank_model_id = tool_parameters.get("rerank_model_id")
|
||||
|
||||
line = 4
|
||||
retrieved_docs = self._bedrock_retrieve(
|
||||
query_input=query,
|
||||
knowledge_base_id=self.knowledge_base_id,
|
||||
num_results=self.topk,
|
||||
search_type=search_type,
|
||||
rerank_model_id=rerank_model_id,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
line = 5
|
||||
# Sort results by score in descending order
|
||||
sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)
|
||||
|
||||
line = 6
|
||||
result_type = tool_parameters.get("result_type")
|
||||
if result_type == "json":
|
||||
return [self.create_json_message(res) for res in sorted_docs]
|
||||
else:
|
||||
text = ""
|
||||
for i, res in enumerate(sorted_docs):
|
||||
text += f"{i + 1}: {res['content']}\n"
|
||||
return self.create_text_message(text)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
|
||||
def validate_parameters(self, parameters: dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate the parameters
|
||||
"""
|
||||
if not parameters.get("knowledge_base_id"):
|
||||
raise ValueError("knowledge_base_id is required")
|
||||
|
||||
if not parameters.get("query"):
|
||||
raise ValueError("query is required")
|
||||
|
||||
metadata_filter_str = parameters.get("metadata_filter")
|
||||
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
|
||||
raise ValueError("metadata_filter must be a valid JSON object")
|
||||
@ -1,179 +0,0 @@
|
||||
identity:
|
||||
name: bedrock_retrieve
|
||||
author: AWS
|
||||
label:
|
||||
en_US: Bedrock Retrieve
|
||||
zh_Hans: Bedrock检索
|
||||
pt_BR: Bedrock Retrieve
|
||||
icon: icon.svg
|
||||
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明
|
||||
pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base.
|
||||
llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
|
||||
parameters:
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS区域
|
||||
human_description:
|
||||
en_US: AWS region for the Bedrock service
|
||||
zh_Hans: Bedrock服务的AWS区域
|
||||
form: form
|
||||
|
||||
- name: aws_access_key_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Access Key ID
|
||||
zh_Hans: AWS访问密钥ID
|
||||
human_description:
|
||||
en_US: AWS access key ID for authentication (optional)
|
||||
zh_Hans: 用于身份验证的AWS访问密钥ID(可选)
|
||||
form: form
|
||||
|
||||
- name: aws_secret_access_key
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Secret Access Key
|
||||
zh_Hans: AWS秘密访问密钥
|
||||
human_description:
|
||||
en_US: AWS secret access key for authentication (optional)
|
||||
zh_Hans: 用于身份验证的AWS秘密访问密钥(可选)
|
||||
form: form
|
||||
|
||||
- name: result_type
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: result type
|
||||
zh_Hans: 结果类型
|
||||
human_description:
|
||||
en_US: return a list of json or texts
|
||||
zh_Hans: 返回一个列表,内容是json还是纯文本
|
||||
default: text
|
||||
options:
|
||||
- value: json
|
||||
label:
|
||||
en_US: JSON
|
||||
zh_Hans: JSON
|
||||
- value: text
|
||||
label:
|
||||
en_US: Text
|
||||
zh_Hans: 文本
|
||||
form: form
|
||||
|
||||
- name: knowledge_base_id
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Bedrock Knowledge Base ID
|
||||
zh_Hans: Bedrock知识库ID
|
||||
pt_BR: Bedrock Knowledge Base ID
|
||||
human_description:
|
||||
en_US: ID of the Bedrock Knowledge Base to retrieve from
|
||||
zh_Hans: 用于检索的Bedrock知识库ID
|
||||
pt_BR: ID of the Bedrock Knowledge Base to retrieve from
|
||||
llm_description: ID of the Bedrock Knowledge Base to retrieve from
|
||||
form: form
|
||||
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query string
|
||||
zh_Hans: 查询语句
|
||||
pt_BR: Query string
|
||||
human_description:
|
||||
en_US: The search query to retrieve relevant information
|
||||
zh_Hans: 用于检索相关信息的查询语句
|
||||
pt_BR: The search query to retrieve relevant information
|
||||
llm_description: The search query to retrieve relevant information
|
||||
form: llm
|
||||
|
||||
- name: topk
|
||||
type: number
|
||||
required: false
|
||||
form: form
|
||||
label:
|
||||
en_US: Limit for results count
|
||||
zh_Hans: 返回结果数量限制
|
||||
pt_BR: Limit for results count
|
||||
human_description:
|
||||
en_US: Maximum number of results to return
|
||||
zh_Hans: 最大返回结果数量
|
||||
pt_BR: Maximum number of results to return
|
||||
min: 1
|
||||
max: 10
|
||||
default: 5
|
||||
|
||||
- name: search_type
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: search type
|
||||
zh_Hans: 搜索类型
|
||||
pt_BR: search type
|
||||
human_description:
|
||||
en_US: search type
|
||||
zh_Hans: 搜索类型
|
||||
pt_BR: search type
|
||||
llm_description: search type
|
||||
default: SEMANTIC
|
||||
options:
|
||||
- value: SEMANTIC
|
||||
label:
|
||||
en_US: SEMANTIC
|
||||
zh_Hans: 语义搜索
|
||||
- value: HYBRID
|
||||
label:
|
||||
en_US: HYBRID
|
||||
zh_Hans: 混合搜索
|
||||
form: form
|
||||
|
||||
- name: rerank_model_id
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: rerank model id
|
||||
zh_Hans: 重拍模型ID
|
||||
pt_BR: rerank model id
|
||||
human_description:
|
||||
en_US: rerank model id
|
||||
zh_Hans: 重拍模型ID
|
||||
pt_BR: rerank model id
|
||||
llm_description: rerank model id
|
||||
default: default
|
||||
options:
|
||||
- value: default
|
||||
label:
|
||||
en_US: default
|
||||
zh_Hans: 默认
|
||||
- value: cohere.rerank-v3-5:0
|
||||
label:
|
||||
en_US: cohere.rerank-v3-5:0
|
||||
zh_Hans: cohere.rerank-v3-5:0
|
||||
- value: amazon.rerank-v1:0
|
||||
label:
|
||||
en_US: amazon.rerank-v1:0
|
||||
zh_Hans: amazon.rerank-v1:0
|
||||
form: form
|
||||
|
||||
- name: metadata_filter # Additional parameter for metadata filtering
|
||||
type: string # String type, expects JSON-formatted filter conditions
|
||||
required: false # Optional field - can be omitted
|
||||
label:
|
||||
en_US: Metadata Filter
|
||||
zh_Hans: 元数据过滤器
|
||||
pt_BR: Metadata Filter
|
||||
human_description:
|
||||
en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})'
|
||||
zh_Hans: '元数据的JSON格式过滤条件(例如,{{"greaterThan": {"key: "aaa", "value": 10}})'
|
||||
pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})'
|
||||
form: form
|
||||
@ -1,137 +0,0 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class BedrockRetrieveAndGenerateTool(BuiltinTool):
|
||||
bedrock_client: Any = None
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> ToolInvokeMessage:
|
||||
try:
|
||||
# Initialize Bedrock client if not already initialized
|
||||
if not self.bedrock_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
aws_access_key_id = tool_parameters.get("aws_access_key_id")
|
||||
aws_secret_access_key = tool_parameters.get("aws_secret_access_key")
|
||||
|
||||
client_kwargs = {"service_name": "bedrock-agent-runtime", "region_name": aws_region or None}
|
||||
|
||||
# Only add credentials if both access key and secret key are provided
|
||||
if aws_access_key_id and aws_secret_access_key:
|
||||
client_kwargs.update(
|
||||
{"aws_access_key_id": aws_access_key_id, "aws_secret_access_key": aws_secret_access_key}
|
||||
)
|
||||
|
||||
self.bedrock_client = boto3.client(**client_kwargs)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to initialize Bedrock client: {str(e)}")
|
||||
|
||||
try:
|
||||
request_config = {}
|
||||
|
||||
# Set input configuration
|
||||
input_text = tool_parameters.get("input")
|
||||
if input_text:
|
||||
request_config["input"] = {"text": input_text}
|
||||
|
||||
# Build retrieve and generate configuration
|
||||
config_type = tool_parameters.get("type")
|
||||
retrieve_generate_config = {"type": config_type}
|
||||
|
||||
# Add configuration based on type
|
||||
if config_type == "KNOWLEDGE_BASE":
|
||||
kb_config_str = tool_parameters.get("knowledge_base_configuration")
|
||||
kb_config = json.loads(kb_config_str) if kb_config_str else None
|
||||
retrieve_generate_config["knowledgeBaseConfiguration"] = kb_config
|
||||
else: # EXTERNAL_SOURCES
|
||||
es_config_str = tool_parameters.get("external_sources_configuration")
|
||||
es_config = json.loads(kb_config_str) if es_config_str else None
|
||||
retrieve_generate_config["externalSourcesConfiguration"] = es_config
|
||||
|
||||
request_config["retrieveAndGenerateConfiguration"] = retrieve_generate_config
|
||||
|
||||
# Parse session configuration
|
||||
session_config_str = tool_parameters.get("session_configuration")
|
||||
session_config = json.loads(session_config_str) if session_config_str else None
|
||||
if session_config:
|
||||
request_config["sessionConfiguration"] = session_config
|
||||
|
||||
# Add session ID if provided
|
||||
session_id = tool_parameters.get("session_id")
|
||||
if session_id:
|
||||
request_config["sessionId"] = session_id
|
||||
|
||||
# Send request
|
||||
response = self.bedrock_client.retrieve_and_generate(**request_config)
|
||||
|
||||
# Process response
|
||||
result = {"output": response.get("output", {}).get("text", ""), "citations": []}
|
||||
|
||||
# Process citations
|
||||
for citation in response.get("citations", []):
|
||||
citation_info = {
|
||||
"text": citation.get("generatedResponsePart", {}).get("textResponsePart", {}).get("text", ""),
|
||||
"references": [],
|
||||
}
|
||||
|
||||
for ref in citation.get("retrievedReferences", []):
|
||||
reference = {
|
||||
"content": ref.get("content", {}).get("text", ""),
|
||||
"metadata": ref.get("metadata", {}),
|
||||
"location": None,
|
||||
}
|
||||
|
||||
location = ref.get("location", {})
|
||||
if location.get("type") == "S3":
|
||||
reference["location"] = location.get("s3Location", {}).get("uri")
|
||||
|
||||
citation_info["references"].append(reference)
|
||||
|
||||
result["citations"].append(citation_info)
|
||||
result_type = tool_parameters.get("result_type")
|
||||
if result_type == "json":
|
||||
return self.create_json_message(result)
|
||||
elif result_type == "text-with-citations":
|
||||
return self.create_text_message(result)
|
||||
else:
|
||||
return self.create_text_message(result.get("output"))
|
||||
except json.JSONDecodeError as e:
|
||||
return self.create_text_message(f"Invalid JSON format: {str(e)}")
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Tool invocation error: {str(e)}")
|
||||
|
||||
def validate_parameters(self, parameters: dict[str, Any]) -> None:
|
||||
"""Validate the parameters"""
|
||||
# Validate required parameters
|
||||
if not parameters.get("input"):
|
||||
raise ValueError("input is required")
|
||||
if not parameters.get("type"):
|
||||
raise ValueError("type is required")
|
||||
|
||||
# Validate JSON configurations
|
||||
json_configs = ["knowledge_base_configuration", "external_sources_configuration", "session_configuration"]
|
||||
for config in json_configs:
|
||||
if config_value := parameters.get(config):
|
||||
try:
|
||||
json.loads(config_value)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"{config} must be a valid JSON string")
|
||||
|
||||
# Validate configuration type
|
||||
config_type = parameters.get("type")
|
||||
if config_type not in ["KNOWLEDGE_BASE", "EXTERNAL_SOURCES"]:
|
||||
raise ValueError("type must be either KNOWLEDGE_BASE or EXTERNAL_SOURCES")
|
||||
|
||||
# Validate type-specific configuration
|
||||
if config_type == "KNOWLEDGE_BASE" and not parameters.get("knowledge_base_configuration"):
|
||||
raise ValueError("knowledge_base_configuration is required when type is KNOWLEDGE_BASE")
|
||||
elif config_type == "EXTERNAL_SOURCES" and not parameters.get("external_sources_configuration"):
|
||||
raise ValueError("external_sources_configuration is required when type is EXTERNAL_SOURCES")
|
||||
@ -1,148 +0,0 @@
|
||||
identity:
|
||||
name: bedrock_retrieve_and_generate
|
||||
author: AWS
|
||||
label:
|
||||
en_US: Bedrock Retrieve and Generate
|
||||
zh_Hans: Bedrock检索和生成
|
||||
icon: icon.svg
|
||||
|
||||
description:
|
||||
human:
|
||||
en_US: "This is an advanced usage of Bedrock Retrieve. Please refer to the API documentation for detailed parameters and paste them into the corresponding Knowledge Base Configuration or External Sources Configuration"
|
||||
zh_Hans: "这个工具为Bedrock Retrieve的高级用法,请参考API设置详细的参数,并粘贴到对应的知识库配置或者外部源配置"
|
||||
llm: A tool for retrieving and generating information using Amazon Bedrock Knowledge Base
|
||||
|
||||
parameters:
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS区域
|
||||
human_description:
|
||||
en_US: AWS region for the Bedrock service
|
||||
zh_Hans: Bedrock服务的AWS区域
|
||||
form: form
|
||||
|
||||
- name: aws_access_key_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Access Key ID
|
||||
zh_Hans: AWS访问密钥ID
|
||||
human_description:
|
||||
en_US: AWS access key ID for authentication (optional)
|
||||
zh_Hans: 用于身份验证的AWS访问密钥ID(可选)
|
||||
form: form
|
||||
|
||||
- name: aws_secret_access_key
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Secret Access Key
|
||||
zh_Hans: AWS秘密访问密钥
|
||||
human_description:
|
||||
en_US: AWS secret access key for authentication (optional)
|
||||
zh_Hans: 用于身份验证的AWS秘密访问密钥(可选)
|
||||
form: form
|
||||
|
||||
- name: result_type
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: result type
|
||||
zh_Hans: 结果类型
|
||||
human_description:
|
||||
en_US: return a list of json or texts
|
||||
zh_Hans: 返回一个列表,内容是json还是纯文本
|
||||
default: text
|
||||
options:
|
||||
- value: json
|
||||
label:
|
||||
en_US: JSON
|
||||
zh_Hans: JSON
|
||||
- value: text
|
||||
label:
|
||||
en_US: Text
|
||||
zh_Hans: 文本
|
||||
- value: text-with-citations
|
||||
label:
|
||||
en_US: Text With Citations
|
||||
zh_Hans: 文本(包含引用)
|
||||
form: form
|
||||
|
||||
- name: input
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Input Text
|
||||
zh_Hans: 输入文本
|
||||
human_description:
|
||||
en_US: The text query to retrieve information
|
||||
zh_Hans: 用于检索信息的文本查询
|
||||
form: llm
|
||||
|
||||
- name: type
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: Configuration Type
|
||||
zh_Hans: 配置类型
|
||||
human_description:
|
||||
en_US: Type of retrieve and generate configuration
|
||||
zh_Hans: 检索和生成配置的类型
|
||||
options:
|
||||
- value: KNOWLEDGE_BASE
|
||||
label:
|
||||
en_US: Knowledge Base
|
||||
zh_Hans: 知识库
|
||||
- value: EXTERNAL_SOURCES
|
||||
label:
|
||||
en_US: External Sources
|
||||
zh_Hans: 外部源
|
||||
form: form
|
||||
|
||||
- name: knowledge_base_configuration
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Knowledge Base Configuration
|
||||
zh_Hans: 知识库配置
|
||||
human_description:
|
||||
en_US: Please refer to @https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html#retrieve-and-generate for complete parameters and paste them here
|
||||
zh_Hans: 请参考 https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html#retrieve-and-generate 配置完整的参数并粘贴到这里
|
||||
form: form
|
||||
|
||||
- name: external_sources_configuration
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: External Sources Configuration
|
||||
zh_Hans: 外部源配置
|
||||
human_description:
|
||||
en_US: Please refer to https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html#retrieve-and-generate for complete parameters and paste them here
|
||||
zh_Hans: 请参考 https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html#retrieve-and-generate 配置完整的参数并粘贴到这里
|
||||
form: form
|
||||
|
||||
- name: session_configuration
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Session Configuration
|
||||
zh_Hans: 会话配置
|
||||
human_description:
|
||||
en_US: JSON formatted session configuration
|
||||
zh_Hans: JSON格式的会话配置
|
||||
default: ""
|
||||
form: form
|
||||
|
||||
- name: session_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Session ID
|
||||
zh_Hans: 会话ID
|
||||
human_description:
|
||||
en_US: Session ID for continuous conversations
|
||||
zh_Hans: 用于连续对话的会话ID
|
||||
form: form
|
||||
@ -1,91 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3 # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class LambdaTranslateUtilsTool(BuiltinTool):
|
||||
lambda_client: Any = None
|
||||
|
||||
def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name):
|
||||
msg = {
|
||||
"src_contents": [text_content],
|
||||
"src_lang": src_lang,
|
||||
"dest_lang": dest_lang,
|
||||
"dictionary_id": dictionary_name,
|
||||
"request_type": request_type,
|
||||
"model_id": model_id,
|
||||
}
|
||||
|
||||
invoke_response = self.lambda_client.invoke(
|
||||
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
|
||||
)
|
||||
response_body = invoke_response["Payload"]
|
||||
|
||||
response_str = response_body.read().decode("unicode_escape")
|
||||
|
||||
return response_str
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
if not self.lambda_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.lambda_client = boto3.client("lambda", region_name=aws_region)
|
||||
else:
|
||||
self.lambda_client = boto3.client("lambda")
|
||||
|
||||
line = 1
|
||||
text_content = tool_parameters.get("text_content", "")
|
||||
if not text_content:
|
||||
return self.create_text_message("Please input text_content")
|
||||
|
||||
line = 2
|
||||
src_lang = tool_parameters.get("src_lang", "")
|
||||
if not src_lang:
|
||||
return self.create_text_message("Please input src_lang")
|
||||
|
||||
line = 3
|
||||
dest_lang = tool_parameters.get("dest_lang", "")
|
||||
if not dest_lang:
|
||||
return self.create_text_message("Please input dest_lang")
|
||||
|
||||
line = 4
|
||||
lambda_name = tool_parameters.get("lambda_name", "")
|
||||
if not lambda_name:
|
||||
return self.create_text_message("Please input lambda_name")
|
||||
|
||||
line = 5
|
||||
request_type = tool_parameters.get("request_type", "")
|
||||
if not request_type:
|
||||
return self.create_text_message("Please input request_type")
|
||||
|
||||
line = 6
|
||||
model_id = tool_parameters.get("model_id", "")
|
||||
if not model_id:
|
||||
return self.create_text_message("Please input model_id")
|
||||
|
||||
line = 7
|
||||
dictionary_name = tool_parameters.get("dictionary_name", "")
|
||||
if not dictionary_name:
|
||||
return self.create_text_message("Please input dictionary_name")
|
||||
|
||||
result = self._invoke_lambda(
|
||||
text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name
|
||||
)
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
@ -1,134 +0,0 @@
|
||||
identity:
|
||||
name: lambda_translate_utils
|
||||
author: AWS
|
||||
label:
|
||||
en_US: TranslateTool
|
||||
zh_Hans: 翻译工具
|
||||
pt_BR: TranslateTool
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/aws-samples/rag-based-translation-with-dynamodb-and-bedrock
|
||||
zh_Hans: 大语言模型翻译工具(专词映射获取),需要在AWS上进行额外部署,可参考Github Repo - https://github.com/aws-samples/rag-based-translation-with-dynamodb-and-bedrock
|
||||
pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/aws-samples/rag-based-translation-with-dynamodb-and-bedrock
|
||||
llm: A util tools for translation.
|
||||
parameters:
|
||||
- name: text_content
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: source content for translation
|
||||
zh_Hans: 待翻译原文
|
||||
pt_BR: source content for translation
|
||||
human_description:
|
||||
en_US: source content for translation
|
||||
zh_Hans: 待翻译原文
|
||||
pt_BR: source content for translation
|
||||
llm_description: source content for translation
|
||||
form: llm
|
||||
- name: src_lang
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: source language code
|
||||
zh_Hans: 原文语言代号
|
||||
pt_BR: source language code
|
||||
human_description:
|
||||
en_US: source language code
|
||||
zh_Hans: 原文语言代号
|
||||
pt_BR: source language code
|
||||
llm_description: source language code
|
||||
form: llm
|
||||
- name: dest_lang
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: target language code
|
||||
zh_Hans: 目标语言代号
|
||||
pt_BR: target language code
|
||||
human_description:
|
||||
en_US: target language code
|
||||
zh_Hans: 目标语言代号
|
||||
pt_BR: target language code
|
||||
llm_description: target language code
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of Lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of Lambda
|
||||
human_description:
|
||||
en_US: region of Lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of Lambda
|
||||
llm_description: region of Lambda
|
||||
form: form
|
||||
- name: model_id
|
||||
type: string
|
||||
required: false
|
||||
default: anthropic.claude-3-sonnet-20240229-v1:0
|
||||
label:
|
||||
en_US: LLM model_id in bedrock
|
||||
zh_Hans: bedrock上的大语言模型model_id
|
||||
pt_BR: LLM model_id in bedrock
|
||||
human_description:
|
||||
en_US: LLM model_id in bedrock
|
||||
zh_Hans: bedrock上的大语言模型model_id
|
||||
pt_BR: LLM model_id in bedrock
|
||||
llm_description: LLM model_id in bedrock
|
||||
form: form
|
||||
- name: dictionary_name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: dictionary name for term mapping
|
||||
zh_Hans: 专词映射表名称
|
||||
pt_BR: dictionary name for term mapping
|
||||
human_description:
|
||||
en_US: dictionary name for term mapping
|
||||
zh_Hans: 专词映射表名称
|
||||
pt_BR: dictionary name for term mapping
|
||||
llm_description: dictionary name for term mapping
|
||||
form: form
|
||||
- name: request_type
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: request type
|
||||
zh_Hans: 请求类型
|
||||
pt_BR: request type
|
||||
human_description:
|
||||
en_US: request type
|
||||
zh_Hans: 请求类型
|
||||
pt_BR: request type
|
||||
default: term_mapping
|
||||
options:
|
||||
- value: term_mapping
|
||||
label:
|
||||
en_US: term_mapping
|
||||
zh_Hans: 专词映射
|
||||
- value: segment_only
|
||||
label:
|
||||
en_US: segment_only
|
||||
zh_Hans: 仅切词
|
||||
- value: translate
|
||||
label:
|
||||
en_US: translate
|
||||
zh_Hans: 翻译内容
|
||||
form: form
|
||||
- name: lambda_name
|
||||
type: string
|
||||
default: "translate_tool"
|
||||
required: true
|
||||
label:
|
||||
en_US: AWS Lambda for term mapping retrieval
|
||||
zh_Hans: 专词召回映射 - AWS Lambda
|
||||
pt_BR: lambda name for term mapping retrieval
|
||||
human_description:
|
||||
en_US: AWS Lambda for term mapping retrieval
|
||||
zh_Hans: 专词召回映射 - AWS Lambda
|
||||
pt_BR: AWS Lambda for term mapping retrieval
|
||||
llm_description: AWS Lambda for term mapping retrieval
|
||||
form: form
|
||||
@ -1,70 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3 # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
|
||||
class LambdaYamlToJsonTool(BuiltinTool):
|
||||
lambda_client: Any = None
|
||||
|
||||
def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
|
||||
msg = {"body": yaml_content}
|
||||
logger.info(json.dumps(msg))
|
||||
|
||||
invoke_response = self.lambda_client.invoke(
|
||||
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
|
||||
)
|
||||
response_body = invoke_response["Payload"]
|
||||
|
||||
response_str = response_body.read().decode("utf-8")
|
||||
resp_json = json.loads(response_str)
|
||||
|
||||
logger.info(resp_json)
|
||||
if resp_json["statusCode"] != 200:
|
||||
raise Exception(f"Invalid status code: {response_str}")
|
||||
|
||||
return resp_json["body"]
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
if not self.lambda_client:
|
||||
aws_region = tool_parameters.get("aws_region") # todo: move aws_region out, and update client region
|
||||
if aws_region:
|
||||
self.lambda_client = boto3.client("lambda", region_name=aws_region)
|
||||
else:
|
||||
self.lambda_client = boto3.client("lambda")
|
||||
|
||||
yaml_content = tool_parameters.get("yaml_content", "")
|
||||
if not yaml_content:
|
||||
return self.create_text_message("Please input yaml_content")
|
||||
|
||||
lambda_name = tool_parameters.get("lambda_name", "")
|
||||
if not lambda_name:
|
||||
return self.create_text_message("Please input lambda_name")
|
||||
logger.debug(f"{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}")
|
||||
|
||||
result = self._invoke_lambda(lambda_name, yaml_content)
|
||||
logger.debug(result)
|
||||
|
||||
return self.create_text_message(result)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception: {str(e)}")
|
||||
|
||||
console_handler.flush()
|
||||
@ -1,53 +0,0 @@
|
||||
identity:
|
||||
name: lambda_yaml_to_json
|
||||
author: AWS
|
||||
label:
|
||||
en_US: LambdaYamlToJson
|
||||
zh_Hans: LambdaYamlToJson
|
||||
pt_BR: LambdaYamlToJson
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool to convert yaml to json using AWS Lambda.
|
||||
zh_Hans: 将 YAML 转为 JSON 的工具(通过AWS Lambda)。
|
||||
pt_BR: A tool to convert yaml to json using AWS Lambda.
|
||||
llm: A tool to convert yaml to json.
|
||||
parameters:
|
||||
- name: yaml_content
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: YAML content to convert for
|
||||
zh_Hans: YAML 内容
|
||||
pt_BR: YAML content to convert for
|
||||
human_description:
|
||||
en_US: YAML content to convert for
|
||||
zh_Hans: YAML 内容
|
||||
pt_BR: YAML content to convert for
|
||||
llm_description: YAML content to convert for
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of lambda
|
||||
human_description:
|
||||
en_US: region of lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of lambda
|
||||
llm_description: region of lambda
|
||||
form: form
|
||||
- name: lambda_name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: name of lambda
|
||||
zh_Hans: Lambda 名称
|
||||
pt_BR: name of lambda
|
||||
human_description:
|
||||
en_US: name of lambda
|
||||
zh_Hans: Lambda 名称
|
||||
pt_BR: name of lambda
|
||||
form: form
|
||||
@ -1,357 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NovaCanvasTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke AWS Bedrock Nova Canvas model for image generation
|
||||
"""
|
||||
# Get common parameters
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip()
|
||||
if not prompt:
|
||||
return self.create_text_message("Please provide a text prompt for image generation.")
|
||||
if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3":
|
||||
return self.create_text_message("Please provide an valid S3 URI for image output.")
|
||||
|
||||
task_type = tool_parameters.get("task_type", "TEXT_IMAGE")
|
||||
aws_region = tool_parameters.get("aws_region", "us-east-1")
|
||||
|
||||
# Get common image generation config parameters
|
||||
width = tool_parameters.get("width", 1024)
|
||||
height = tool_parameters.get("height", 1024)
|
||||
cfg_scale = tool_parameters.get("cfg_scale", 8.0)
|
||||
negative_prompt = tool_parameters.get("negative_prompt", "")
|
||||
seed = tool_parameters.get("seed", 0)
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
|
||||
# Handle S3 image if provided
|
||||
image_input_s3uri = tool_parameters.get("image_input_s3uri", "")
|
||||
if task_type != "TEXT_IMAGE":
|
||||
if not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3":
|
||||
return self.create_text_message("Please provide a valid S3 URI for image to image generation.")
|
||||
|
||||
# Parse S3 URI
|
||||
parsed_uri = urlparse(image_input_s3uri)
|
||||
bucket = parsed_uri.netloc
|
||||
key = parsed_uri.path.lstrip("/")
|
||||
|
||||
# Initialize S3 client and download image
|
||||
s3_client = boto3.client("s3")
|
||||
response = s3_client.get_object(Bucket=bucket, Key=key)
|
||||
image_data = response["Body"].read()
|
||||
|
||||
# Base64 encode the image
|
||||
input_image = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
try:
|
||||
# Initialize Bedrock client
|
||||
bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region)
|
||||
|
||||
# Base image generation config
|
||||
image_generation_config = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"cfgScale": cfg_scale,
|
||||
"seed": seed,
|
||||
"numberOfImages": 1,
|
||||
"quality": quality,
|
||||
}
|
||||
|
||||
# Prepare request body based on task type
|
||||
body = {"imageGenerationConfig": image_generation_config}
|
||||
|
||||
if task_type == "TEXT_IMAGE":
|
||||
body["taskType"] = "TEXT_IMAGE"
|
||||
body["textToImageParams"] = {"text": prompt}
|
||||
if negative_prompt:
|
||||
body["textToImageParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "COLOR_GUIDED_GENERATION":
|
||||
colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680")
|
||||
if not self._validate_color_string(colors):
|
||||
return self.create_text_message("Please provide valid colors in hexadecimal format.")
|
||||
|
||||
body["taskType"] = "COLOR_GUIDED_GENERATION"
|
||||
body["colorGuidedGenerationParams"] = {
|
||||
"colors": colors.split("-"),
|
||||
"referenceImage": input_image,
|
||||
"text": prompt,
|
||||
}
|
||||
if negative_prompt:
|
||||
body["colorGuidedGenerationParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "IMAGE_VARIATION":
|
||||
similarity_strength = tool_parameters.get("similarity_strength", 0.5)
|
||||
|
||||
body["taskType"] = "IMAGE_VARIATION"
|
||||
body["imageVariationParams"] = {
|
||||
"images": [input_image],
|
||||
"similarityStrength": similarity_strength,
|
||||
"text": prompt,
|
||||
}
|
||||
if negative_prompt:
|
||||
body["imageVariationParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "INPAINTING":
|
||||
mask_prompt = tool_parameters.get("mask_prompt")
|
||||
if not mask_prompt:
|
||||
return self.create_text_message("Please provide a mask prompt for image inpainting.")
|
||||
|
||||
body["taskType"] = "INPAINTING"
|
||||
body["inPaintingParams"] = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt}
|
||||
if negative_prompt:
|
||||
body["inPaintingParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "OUTPAINTING":
|
||||
mask_prompt = tool_parameters.get("mask_prompt")
|
||||
if not mask_prompt:
|
||||
return self.create_text_message("Please provide a mask prompt for image outpainting.")
|
||||
outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT")
|
||||
|
||||
body["taskType"] = "OUTPAINTING"
|
||||
body["outPaintingParams"] = {
|
||||
"image": input_image,
|
||||
"maskPrompt": mask_prompt,
|
||||
"outPaintingMode": outpainting_mode,
|
||||
"text": prompt,
|
||||
}
|
||||
if negative_prompt:
|
||||
body["outPaintingParams"]["negativeText"] = negative_prompt
|
||||
|
||||
elif task_type == "BACKGROUND_REMOVAL":
|
||||
body["taskType"] = "BACKGROUND_REMOVAL"
|
||||
body["backgroundRemovalParams"] = {"image": input_image}
|
||||
|
||||
else:
|
||||
return self.create_text_message(f"Unsupported task type: {task_type}")
|
||||
|
||||
# Call Nova Canvas model
|
||||
response = bedrock.invoke_model(
|
||||
body=json.dumps(body),
|
||||
modelId="amazon.nova-canvas-v1:0",
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
|
||||
# Process response
|
||||
response_body = json.loads(response.get("body").read())
|
||||
if response_body.get("error"):
|
||||
raise Exception(f"Error in model response: {response_body.get('error')}")
|
||||
base64_image = response_body.get("images")[0]
|
||||
|
||||
# Upload to S3 if image_output_s3uri is provided
|
||||
try:
|
||||
# Parse S3 URI for output
|
||||
parsed_uri = urlparse(image_output_s3uri)
|
||||
output_bucket = parsed_uri.netloc
|
||||
output_base_path = parsed_uri.path.lstrip("/")
|
||||
# Generate filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_key = f"{output_base_path}/canvas-output-{timestamp}.png"
|
||||
|
||||
# Initialize S3 client if not already done
|
||||
s3_client = boto3.client("s3", region_name=aws_region)
|
||||
|
||||
# Decode base64 image and upload to S3
|
||||
image_data = base64.b64decode(base64_image)
|
||||
s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png")
|
||||
logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to upload image to S3")
|
||||
# Return image
|
||||
return [
|
||||
self.create_text_message(f"Image is available at: s3://{output_bucket}/{output_key}"),
|
||||
self.create_blob_message(
|
||||
blob=base64.b64decode(base64_image),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
),
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to generate image: {str(e)}")
|
||||
|
||||
def _validate_color_string(self, color_string) -> bool:
|
||||
color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$"
|
||||
|
||||
if re.match(color_pattern, color_string):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
parameters = [
|
||||
ToolParameter(
|
||||
name="prompt",
|
||||
label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=True,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Text description of the image you want to generate or modify",
|
||||
zh_Hans="您想要生成或修改的图像的文本描述",
|
||||
),
|
||||
llm_description="Describe the image you want to generate or how you want to modify the input image",
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_input_s3uri",
|
||||
label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_output_s3uri",
|
||||
label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=True,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="width",
|
||||
label=I18nObject(en_US="Width", zh_Hans="宽度"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=1024,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="height",
|
||||
label=I18nObject(en_US="Height", zh_Hans="高度"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=1024,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="cfg_scale",
|
||||
label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=8.0,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="negative_prompt",
|
||||
label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="",
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="seed",
|
||||
label=I18nObject(en_US="Seed", zh_Hans="种子值"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=0,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="aws_region",
|
||||
label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="us-east-1",
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="task_type",
|
||||
label=I18nObject(en_US="Task Type", zh_Hans="任务类型"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="TEXT_IMAGE",
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="quality",
|
||||
label=I18nObject(en_US="Quality", zh_Hans="质量"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="standard",
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="colors",
|
||||
label=I18nObject(en_US="Colors", zh_Hans="颜色"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680",
|
||||
zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680",
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="similarity_strength",
|
||||
label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=0.5,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="How similar the generated image should be to the input image (0.0 to 1.0)",
|
||||
zh_Hans="生成的图像应该与输入图像的相似程度(0.0到1.0)",
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="mask_prompt",
|
||||
label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Text description to generate mask for inpainting/outpainting",
|
||||
zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述",
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="outpainting_mode",
|
||||
label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default="DEFAULT",
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="Mode for outpainting (DEFAULT or other supported modes)",
|
||||
zh_Hans="外补绘制的模式(DEFAULT或其他支持的模式)",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
return parameters
|
||||