Merge main

This commit is contained in:
Yeuoly
2024-09-14 02:47:01 +08:00
959 changed files with 25695 additions and 24057 deletions

View File

@ -11,23 +11,23 @@ from core.tools.tool.tool import ToolParameter
class UserTool(BaseModel):
author: str
name: str # identifier
label: I18nObject # label
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[list[ToolParameter]] = None
labels: list[str] = Field(default_factory=list)
UserToolProviderTypeLiteral = Optional[Literal[
'builtin', 'api', 'workflow'
]]
UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]]
class UserToolProvider(BaseModel):
id: str
author: str
name: str # identifier
name: str # identifier
description: I18nObject
icon: str | dict
label: I18nObject # label
label: I18nObject # label
type: ToolProviderType
masked_credentials: Optional[dict] = None
original_credentials: Optional[dict] = None
@ -41,26 +41,27 @@ class UserToolProvider(BaseModel):
# overwrite tool parameter types for temp fix
tools = jsonable_encoder(self.tools)
for tool in tools:
if tool.get('parameters'):
for parameter in tool.get('parameters'):
if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value:
parameter['type'] = 'files'
if tool.get("parameters"):
for parameter in tool.get("parameters"):
if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value:
parameter["type"] = "files"
# -------------
return {
'id': self.id,
'author': self.author,
'name': self.name,
'description': self.description.to_dict(),
'icon': self.icon,
'label': self.label.to_dict(),
'type': self.type.value,
'team_credentials': self.masked_credentials,
'is_team_authorization': self.is_team_authorization,
'allow_delete': self.allow_delete,
'tools': tools,
'labels': self.labels,
"id": self.id,
"author": self.author,
"name": self.name,
"description": self.description.to_dict(),
"icon": self.icon,
"label": self.label.to_dict(),
"type": self.type.value,
"team_credentials": self.masked_credentials,
"is_team_authorization": self.is_team_authorization,
"allow_delete": self.allow_delete,
"tools": tools,
"labels": self.labels,
}
class UserToolProviderCredentials(BaseModel):
credentials: dict[str, ProviderConfig]
credentials: dict[str, ProviderConfig]

View File

@ -7,6 +7,7 @@ class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
zh_Hans: Optional[str] = None
pt_BR: Optional[str] = None
en_US: str
@ -19,8 +20,4 @@ class I18nObject(BaseModel):
self.pt_BR = self.en_US
def to_dict(self) -> dict:
return {
'zh_Hans': self.zh_Hans,
'en_US': self.en_US,
'pt_BR': self.pt_BR
}
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR}

View File

@ -7,8 +7,10 @@ from core.tools.entities.tool_entities import ToolParameter
class ApiToolBundle(BaseModel):
"""
This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc.
This class is used to store the schema information of an api based tool.
such as the url, the method, the parameters, etc.
"""
# server_url
server_url: str
# method

View File

@ -9,27 +9,29 @@ from core.tools.entities.common_entities import I18nObject
class ToolLabelEnum(Enum):
SEARCH = 'search'
IMAGE = 'image'
VIDEOS = 'videos'
WEATHER = 'weather'
FINANCE = 'finance'
DESIGN = 'design'
TRAVEL = 'travel'
SOCIAL = 'social'
NEWS = 'news'
MEDICAL = 'medical'
PRODUCTIVITY = 'productivity'
EDUCATION = 'education'
BUSINESS = 'business'
ENTERTAINMENT = 'entertainment'
UTILITIES = 'utilities'
OTHER = 'other'
SEARCH = "search"
IMAGE = "image"
VIDEOS = "videos"
WEATHER = "weather"
FINANCE = "finance"
DESIGN = "design"
TRAVEL = "travel"
SOCIAL = "social"
NEWS = "news"
MEDICAL = "medical"
PRODUCTIVITY = "productivity"
EDUCATION = "education"
BUSINESS = "business"
ENTERTAINMENT = "entertainment"
UTILITIES = "utilities"
OTHER = "other"
class ToolProviderType(str, Enum):
"""
Enum class for tool provider
Enum class for tool provider
"""
BUILT_IN = "builtin"
WORKFLOW = "workflow"
API = "api"
@ -37,7 +39,7 @@ class ToolProviderType(str, Enum):
DATASET_RETRIEVAL = "dataset-retrieval"
@classmethod
def value_of(cls, value: str) -> 'ToolProviderType':
def value_of(cls, value: str) -> "ToolProviderType":
"""
Get value of given mode.
@ -47,19 +49,21 @@ class ToolProviderType(str, Enum):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
raise ValueError(f"invalid mode value {value}")
class ApiProviderSchemaType(Enum):
"""
Enum class for api provider schema type.
"""
OPENAPI = "openapi"
SWAGGER = "swagger"
OPENAI_PLUGIN = "openai_plugin"
OPENAI_ACTIONS = "openai_actions"
@classmethod
def value_of(cls, value: str) -> 'ApiProviderSchemaType':
def value_of(cls, value: str) -> "ApiProviderSchemaType":
"""
Get value of given mode.
@ -69,17 +73,19 @@ class ApiProviderSchemaType(Enum):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
raise ValueError(f"invalid mode value {value}")
class ApiProviderAuthType(Enum):
"""
Enum class for api provider auth type.
"""
NONE = "none"
API_KEY = "api_key"
@classmethod
def value_of(cls, value: str) -> 'ApiProviderAuthType':
def value_of(cls, value: str) -> "ApiProviderAuthType":
"""
Get value of given mode.
@ -89,7 +95,8 @@ class ApiProviderAuthType(Enum):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid mode value {value}')
raise ValueError(f"invalid mode value {value}")
class ToolInvokeMessage(BaseModel):
class TextMessage(BaseModel):
@ -107,7 +114,7 @@ class ToolInvokeMessage(BaseModel):
stream: bool = Field(default=False, description="Whether the variable is streamed")
@field_validator("variable_value", mode="before")
def transform_variable_value(cls, value, values) -> Any:
def transform_variable_value(self, value, values) -> Any:
"""
Only basic types and lists are allowed.
"""
@ -122,11 +129,11 @@ class ToolInvokeMessage(BaseModel):
return value
@field_validator("variable_name", mode="before")
def transform_variable_name(cls, value) -> str:
def transform_variable_name(self, value) -> str:
"""
The variable name must be a string.
"""
if value in ["json", "text", "files"]:
if value in {"json", "text", "files"}:
raise ValueError(f"The variable name '{value}' is reserved.")
return value
@ -146,7 +153,7 @@ class ToolInvokeMessage(BaseModel):
"""
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | None
meta: dict[str, Any] | None = None
save_as: str = ''
save_as: str = ""
@field_validator('message', mode='before')
@classmethod
@ -166,17 +173,19 @@ class ToolInvokeMessage(BaseModel):
}
return v
class ToolInvokeMessageBinary(BaseModel):
mimetype: str = Field(..., description="The mimetype of the binary")
url: str = Field(..., description="The url of the binary")
save_as: str = ''
save_as: str = ""
file_var: Optional[dict[str, Any]] = None
class ToolParameterOption(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
@field_validator('value', mode='before')
@field_validator("value", mode="before")
@classmethod
def transform_id_to_str(cls, value) -> str:
if not isinstance(value, str):
@ -195,9 +204,9 @@ class ToolParameter(BaseModel):
FILE = CommonParameterType.FILE.value
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
FORM = "form" # should be set before invoking tool
LLM = "llm" # will be set by LLM
SCHEMA = "schema" # should be set while adding tool
FORM = "form" # should be set before invoking tool
LLM = "llm" # will be set by LLM
name: str = Field(..., description="The name of the parameter")
label: I18nObject = Field(..., description="The label presented to the user")
@ -214,21 +223,28 @@ class ToolParameter(BaseModel):
options: list[ToolParameterOption] = Field(default_factory=list)
@classmethod
def get_simple_instance(cls,
name: str, llm_description: str, type: ToolParameterType,
required: bool, options: Optional[list[str]] = None) -> 'ToolParameter':
def get_simple_instance(
cls,
name: str,
llm_description: str,
type: ToolParameterType,
required: bool,
options: Optional[list[str]] = None,
) -> "ToolParameter":
"""
get a simple tool parameter
get a simple tool parameter
:param name: the name of the parameter
:param llm_description: the description presented to the LLM
:param type: the type of the parameter
:param required: if the parameter is required
:param options: the options of the parameter
:param name: the name of the parameter
:param llm_description: the description presented to the LLM
:param type: the type of the parameter
:param required: if the parameter is required
:param options: the options of the parameter
"""
# convert options to ToolParameterOption
if options:
option_objs = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options]
option_objs = [
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options
]
else:
option_objs = []
return cls(
@ -243,18 +259,24 @@ class ToolParameter(BaseModel):
options=option_objs,
)
class ToolProviderIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: Optional[list[ToolLabelEnum]] = Field(default=[], description="The tags of the tool", )
tags: Optional[list[ToolLabelEnum]] = Field(
default=[],
description="The tags of the tool",
)
class ToolDescription(BaseModel):
human: I18nObject = Field(..., description="The description presented to the user")
llm: str = Field(..., description="The description presented to the LLM")
class ToolIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
@ -262,22 +284,27 @@ class ToolIdentity(BaseModel):
provider: str = Field(..., description="The provider of the tool")
icon: Optional[str] = None
class ToolRuntimeVariableType(Enum):
TEXT = "text"
IMAGE = "image"
class ToolRuntimeVariable(BaseModel):
type: ToolRuntimeVariableType = Field(..., description="The type of the variable")
name: str = Field(..., description="The name of the variable")
position: int = Field(..., description="The position of the variable")
tool_name: str = Field(..., description="The name of the tool")
class ToolRuntimeTextVariable(ToolRuntimeVariable):
value: str = Field(..., description="The value of the variable")
class ToolRuntimeImageVariable(ToolRuntimeVariable):
value: str = Field(..., description="The path of the image")
class ToolRuntimeVariablePool(BaseModel):
conversation_id: str = Field(..., description="The conversation id")
user_id: str = Field(..., description="The user id")
@ -286,26 +313,26 @@ class ToolRuntimeVariablePool(BaseModel):
pool: list[ToolRuntimeVariable] = Field(..., description="The pool of variables")
def __init__(self, **data: Any):
pool = data.get('pool', [])
pool = data.get("pool", [])
# convert pool into correct type
for index, variable in enumerate(pool):
if variable['type'] == ToolRuntimeVariableType.TEXT.value:
if variable["type"] == ToolRuntimeVariableType.TEXT.value:
pool[index] = ToolRuntimeTextVariable(**variable)
elif variable['type'] == ToolRuntimeVariableType.IMAGE.value:
elif variable["type"] == ToolRuntimeVariableType.IMAGE.value:
pool[index] = ToolRuntimeImageVariable(**variable)
super().__init__(**data)
def dict(self) -> dict:
return {
'conversation_id': self.conversation_id,
'user_id': self.user_id,
'tenant_id': self.tenant_id,
'pool': [variable.model_dump() for variable in self.pool],
"conversation_id": self.conversation_id,
"user_id": self.user_id,
"tenant_id": self.tenant_id,
"pool": [variable.model_dump() for variable in self.pool],
}
def set_text(self, tool_name: str, name: str, value: str) -> None:
"""
set a text variable
set a text variable
"""
for variable in self.pool:
if variable.name == name:
@ -326,10 +353,10 @@ class ToolRuntimeVariablePool(BaseModel):
def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None:
"""
set an image variable
set an image variable
:param tool_name: the name of the tool
:param value: the id of the file
:param tool_name: the name of the tool
:param value: the id of the file
"""
# check how many image variables are there
image_variable_count = 0
@ -357,22 +384,27 @@ class ToolRuntimeVariablePool(BaseModel):
self.pool.append(variable)
class ModelToolPropertyKey(Enum):
IMAGE_PARAMETER_NAME = "image_parameter_name"
class ModelToolConfiguration(BaseModel):
"""
Model tool configuration
"""
type: str = Field(..., description="The type of the model tool")
model: str = Field(..., description="The model")
label: I18nObject = Field(..., description="The label of the model tool")
properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool")
class ModelToolProviderConfiguration(BaseModel):
"""
Model tool provider configuration
"""
provider: str = Field(..., description="The provider of the model tool")
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
label: I18nObject = Field(..., description="The label of the model tool")
@ -382,27 +414,30 @@ class WorkflowToolParameterConfiguration(BaseModel):
"""
Workflow tool configuration
"""
name: str = Field(..., description="The name of the parameter")
description: str = Field(..., description="The description of the parameter")
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
class ToolInvokeMeta(BaseModel):
"""
Tool invoke meta
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: Optional[str] = None
tool_config: Optional[dict] = None
@classmethod
def empty(cls) -> 'ToolInvokeMeta':
def empty(cls) -> "ToolInvokeMeta":
"""
Get an empty instance of ToolInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> 'ToolInvokeMeta':
def error_instance(cls, error: str) -> "ToolInvokeMeta":
"""
Get an instance of ToolInvokeMeta with error
"""
@ -410,22 +445,26 @@ class ToolInvokeMeta(BaseModel):
def to_dict(self) -> dict:
return {
'time_cost': self.time_cost,
'error': self.error,
'tool_config': self.tool_config,
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
}
class ToolLabel(BaseModel):
"""
Tool label
"""
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
icon: str = Field(..., description="The icon of the tool")
class ToolInvokeFrom(Enum):
"""
Enum class for tool invoke
"""
WORKFLOW = "workflow"
AGENT = "agent"

View File

@ -2,73 +2,109 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum
ICONS = {
ToolLabelEnum.SEARCH: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
ToolLabelEnum.SEARCH: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M7.33398 1.3335C10.646 1.3335 13.334 4.0215 13.334 7.3335C13.334 10.6455 10.646 13.3335 7.33398 13.3335C4.02198 13.3335 1.33398 10.6455 1.33398 7.3335C1.33398 4.0215 4.02198 1.3335 7.33398 1.3335ZM7.33398 12.0002C9.91232 12.0002 12.0007 9.91183 12.0007 7.3335C12.0007 4.75516 9.91232 2.66683 7.33398 2.66683C4.75565 2.66683 2.66732 4.75516 2.66732 7.3335C2.66732 9.91183 4.75565 12.0002 7.33398 12.0002ZM12.9909 12.0476L14.8764 13.9332L13.9337 14.876L12.0481 12.9904L12.9909 12.0476Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.IMAGE: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.IMAGE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M13.0514 9.71752L10.4718 7.13792C10.2115 6.87752 9.78932 6.87752 9.52898 7.13792L4.57721 12.0897C3.4097 11.1113 2.66732 9.64232 2.66732 7.99992C2.66732 5.0544 5.05513 2.66659 8.00065 2.66659C10.9462 2.66659 13.334 5.0544 13.334 7.99992C13.334 8.60085 13.2346 9.17852 13.0514 9.71752ZM5.72683 12.8257L10.0004 8.55212L12.4259 10.9777C11.4668 12.4001 9.84152 13.3331 8.00038 13.3331C7.18632 13.3331 6.41628 13.1511 5.72683 12.8257ZM8.00065 14.6666C11.6825 14.6666 14.6673 11.6818 14.6673 7.99992C14.6673 4.31802 11.6825 1.33325 8.00065 1.33325C4.31875 1.33325 1.33398 4.31802 1.33398 7.99992C1.33398 11.6818 4.31875 14.6666 8.00065 14.6666ZM7.33398 6.66658C7.33398 7.40299 6.73705 7.99992 6.00065 7.99992C5.26427 7.99992 4.66732 7.40299 4.66732 6.66658C4.66732 5.9302 5.26427 5.33325 6.00065 5.33325C6.73705 5.33325 7.33398 5.9302 7.33398 6.66658Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.VIDEOS: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.VIDEOS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00065 13.3333H13.334V14.6666H8.00065C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992C14.6673 9.50072 14.1714 10.8857 13.3345 11.9999H11.5284C12.6356 11.0227 13.334 9.59285 13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333ZM8.00065 6.66658C7.26425 6.66658 6.66732 6.06963 6.66732 5.33325C6.66732 4.59687 7.26425 3.99992 8.00065 3.99992C8.73705 3.99992 9.33398 4.59687 9.33398 5.33325C9.33398 6.06963 8.73705 6.66658 8.00065 6.66658ZM5.33398 9.33325C4.5976 9.33325 4.00065 8.73632 4.00065 7.99992C4.00065 7.26352 4.5976 6.66658 5.33398 6.66658C6.07036 6.66658 6.66732 7.26352 6.66732 7.99992C6.66732 8.73632 6.07036 9.33325 5.33398 9.33325ZM10.6673 9.33325C9.93092 9.33325 9.33398 8.73632 9.33398 7.99992C9.33398 7.26352 9.93092 6.66658 10.6673 6.66658C11.4037 6.66658 12.0007 7.26352 12.0007 7.99992C12.0007 8.73632 11.4037 9.33325 10.6673 9.33325ZM8.00065 11.9999C7.26425 11.9999 6.66732 11.403 6.66732 10.6666C6.66732 9.93018 7.26425 9.33325 8.00065 9.33325C8.73705 9.33325 9.33398 9.93018 9.33398 10.6666C9.33398 11.403 8.73705 11.9999 8.00065 11.9999Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.WEATHER: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.WEATHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M6.6553 3.37344C7.42088 2.1484 8.78162 1.3335 10.3327 1.3335C12.7259 1.3335 14.666 3.2736 14.666 5.66683C14.666 6.38704 14.4903 7.06623 14.1794 7.66383C14.8894 8.3325 15.3327 9.28123 15.3327 10.3335C15.3327 12.3586 13.6911 14.0002 11.666 14.0002H5.99935C3.05383 14.0002 0.666016 11.6124 0.666016 8.66683C0.666016 5.72131 3.05383 3.3335 5.99935 3.3335C6.22143 3.3335 6.44034 3.34707 6.6553 3.37344ZM8.03628 3.73629C9.37768 4.29108 10.4435 5.37735 10.9711 6.73256C11.1961 6.68943 11.4284 6.66683 11.666 6.66683C12.1561 6.66683 12.6237 6.76296 13.0511 6.93743C13.2317 6.55162 13.3327 6.12102 13.3327 5.66683C13.3327 4.00998 11.9895 2.66683 10.3327 2.66683C9.41115 2.66683 8.58662 3.08236 8.03628 3.73629ZM11.666 12.6668C12.9547 12.6668 13.9993 11.6222 13.9993 10.3335C13.9993 9.04483 12.9547 8.00016 11.666 8.00016C11.013 8.00016 10.4227 8.26836 9.99922 8.70063C9.99928 8.68936 9.99935 8.6781 9.99935 8.66683C9.99935 6.45769 8.20848 4.66683 5.99935 4.66683C3.79021 4.66683 1.99935 6.45769 1.99935 8.66683C1.99935 10.876 3.79021 12.6668 5.99935 12.6668H11.666Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.FINANCE: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.FINANCE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00262 14.6685C4.32071 14.6685 1.33594 11.6838 1.33594 8.00184C1.33594 4.31997 4.32071 1.33521 8.00262 1.33521C11.6845 1.33521 14.6693 4.31997 14.6693 8.00184C14.6693 11.6838 11.6845 14.6685 8.00262 14.6685ZM8.00262 13.3352C10.9482 13.3352 13.336 10.9474 13.336 8.00184C13.336 5.05635 10.9482 2.66854 8.00262 2.66854C5.05708 2.66854 2.66927 5.05635 2.66927 8.00184C2.66927 10.9474 5.05708 13.3352 8.00262 13.3352ZM5.66927 9.33517H9.33595C9.52002 9.33517 9.66928 9.18597 9.66928 9.00184C9.66928 8.81777 9.52002 8.66851 9.33595 8.66851H6.66928C5.7488 8.66851 5.0026 7.92237 5.0026 7.00184C5.0026 6.08139 5.7488 5.33521 6.66928 5.33521H7.33595V4.00187H8.66928V5.33521H10.336V6.66851H6.66928C6.48518 6.66851 6.33594 6.81777 6.33594 7.00184C6.33594 7.18597 6.48518 7.33517 6.66928 7.33517H9.33595C10.2564 7.33517 11.0026 8.08137 11.0026 9.00184C11.0026 9.92237 10.2564 10.6685 9.33595 10.6685H8.66928V12.0018H7.33595V10.6685H5.66927V9.33517Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.DESIGN: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.DESIGN: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M4.70152 9.41416L3.2873 10.8284L5.17292 12.714L12.7154 5.17154L10.8298 3.28592L9.41557 4.70013L10.3584 5.64295L9.41557 6.58575L8.47277 5.64295L7.52997 6.58575L8.47277 7.52856L7.52997 8.47136L6.58713 7.52856L5.64433 8.47136L6.58713 9.41416L5.64433 10.357L4.70152 9.41416ZM11.3012 1.87171L14.1296 4.70013C14.39 4.96049 14.39 5.38259 14.1296 5.64295L5.64433 14.1282C5.38397 14.3886 4.96187 14.3886 4.70152 14.1282L1.87309 11.2998C1.61274 11.0394 1.61274 10.6174 1.87309 10.357L10.3584 1.87171C10.6187 1.61136 11.0408 1.61136 11.3012 1.87171ZM9.41557 12.2423L10.3584 11.2995L11.8534 12.7945H12.7962V11.8517L11.3012 10.3567L12.244 9.41383L14.0011 11.171V13.9999H11.1732L9.41557 12.2423ZM3.75861 6.58533L1.87299 4.69971C1.61265 4.43937 1.61265 4.01725 1.87299 3.75691L3.75861 1.87129C4.01896 1.61094 4.44107 1.61094 4.70142 1.87129L6.58704 3.75691L5.64423 4.69971L4.23002 3.2855L3.28721 4.22831L4.70142 5.64253L3.75861 6.58533Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.TRAVEL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.TRAVEL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M9.44839 2C9.80198 2 10.1411 2.14047 10.3912 2.39053L13.6101 5.60947C13.8602 5.85953 14.0007 6.19866 14.0007 6.55229V11.3333H15.334V12.6667L9.91652 12.6672C9.62032 13.8171 8.57638 14.6667 7.33398 14.6667C6.0916 14.6667 5.04766 13.8171 4.75146 12.6672L2.00065 12.6667C1.63246 12.6667 1.33398 12.3682 1.33398 12V3.33333C1.33398 2.59695 1.93094 2 2.66732 2H9.44839ZM7.33398 10.6667C6.5976 10.6667 6.00065 11.2636 6.00065 12C6.00065 12.7364 6.5976 13.3333 7.33398 13.3333C8.07038 13.3333 8.66732 12.7364 8.66732 12C8.66732 11.2636 8.07038 10.6667 7.33398 10.6667ZM9.44839 3.33333H2.66732V11.3333L4.75128 11.3335C5.04726 10.1833 6.09136 9.33333 7.33398 9.33333C8.57658 9.33333 9.62072 10.1833 9.91665 11.3335L12.6673 11.3333V6.55229L9.44839 3.33333ZM9.33398 4.66667V8.66667H4.00065V4.66667H9.33398ZM8.00065 6H5.33398V7.33333H8.00065V6Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.SOCIAL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.SOCIAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333C9.09518 13.3333 10.1127 13.0035 10.9594 12.438L11.699 13.5475C10.6408 14.2545 9.36885 14.6666 8.00065 14.6666C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992V8.99992C14.6673 10.2886 13.6227 11.3333 12.334 11.3333C11.5312 11.3333 10.8231 10.9278 10.4032 10.3105C9.79678 10.9409 8.94452 11.3333 8.00065 11.3333C6.1597 11.3333 4.66732 9.84085 4.66732 7.99992C4.66732 6.15897 6.1597 4.66658 8.00065 4.66658C8.75118 4.66658 9.44378 4.91464 10.001 5.33325H11.334V8.99992C11.334 9.55219 11.7817 9.99992 12.334 9.99992C12.8863 9.99992 13.334 9.55219 13.334 8.99992V7.99992ZM8.00065 5.99992C6.89605 5.99992 6.00065 6.89532 6.00065 7.99992C6.00065 9.10452 6.89605 9.99992 8.00065 9.99992C9.10525 9.99992 10.0007 9.10452 10.0007 7.99992C10.0007 6.89532 9.10525 5.99992 8.00065 5.99992Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.NEWS: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.NEWS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M10.6673 13.3335V2.66683H2.66732V12.6668C2.66732 13.035 2.9658 13.3335 3.33398 13.3335H10.6673ZM12.6673 14.6668H3.33398C2.22942 14.6668 1.33398 13.7714 1.33398 12.6668V2.00016C1.33398 1.63198 1.63246 1.3335 2.00065 1.3335H11.334C11.7022 1.3335 12.0007 1.63198 12.0007 2.00016V6.66683H14.6673V12.6668C14.6673 13.7714 13.7719 14.6668 12.6673 14.6668ZM12.0007 8.00016V12.6668C12.0007 13.035 12.2991 13.3335 12.6673 13.3335C13.0355 13.3335 13.334 13.035 13.334 12.6668V8.00016H12.0007ZM4.00065 4.00016H8.00065V8.00016H4.00065V4.00016ZM5.33398 5.3335V6.66683H6.66732V5.3335H5.33398ZM4.00065 8.66683H9.33398V10.0002H4.00065V8.66683ZM4.00065 10.6668H9.33398V12.0002H4.00065V10.6668Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.MEDICAL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.MEDICAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.79747 1.51186L10.9641 5.26464C11.1482 5.5835 11.0389 5.99122 10.7201 6.17532L9.85373 6.67474L10.5207 7.83001L9.366 8.49668L8.699 7.34141L7.83333 7.84201C7.51447 8.02608 7.10673 7.91681 6.92267 7.59794L5.69747 5.47632C4.32922 5.89145 3.33333 7.16268 3.33333 8.66654C3.33333 9.08348 3.40987 9.48248 3.54965 9.85034C4.06613 9.52254 4.67762 9.33321 5.33333 9.33321C6.45605 9.33321 7.44913 9.88828 8.05313 10.7389L13.1787 7.78014L13.8454 8.93488L8.5932 11.9672C8.64133 12.1927 8.66667 12.4267 8.66667 12.6665C8.66667 12.895 8.64367 13.1181 8.59993 13.3337L14 13.3332V14.6665L2.66703 14.6673C2.2482 14.1101 2 13.4173 2 12.6665C2 11.9951 2.19855 11.3699 2.54014 10.8467C2.19517 10.1964 2 9.45428 2 8.66654C2 6.66968 3.25421 4.96575 5.01785 4.29953L4.75598 3.84519C4.38779 3.20747 4.60629 2.39202 5.24402 2.02382L6.97607 1.02382C7.6138 0.655637 8.42927 0.874138 8.79747 1.51186ZM5.33333 10.6665C4.22877 10.6665 3.33333 11.562 3.33333 12.6665C3.33333 12.9003 3.37343 13.1247 3.44711 13.3331H7.21953C7.29327 13.1247 7.33333 12.9003 7.33333 12.6665C7.33333 11.562 6.4379 10.6665 5.33333 10.6665ZM7.64273 2.17852L5.91068 3.17852L7.744 6.35395L9.47607 5.35395L7.64273 2.17852Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.PRODUCTIVITY: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.PRODUCTIVITY: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M6.64807 11.9999H9.35062C9.43862 11.1989 9.84742 10.5376 10.5111 9.81499C10.5858 9.73365 11.0652 9.23752 11.1221 9.16665C11.6872 8.46199 11.9993 7.58992 11.9993 6.66659C11.9993 4.45745 10.2085 2.66659 7.99935 2.66659C5.79021 2.66659 3.99935 4.45745 3.99935 6.66659C3.99935 7.58945 4.31118 8.46105 4.87576 9.16552C4.93271 9.23659 5.41322 9.73405 5.48704 9.81445C6.15112 10.5375 6.56004 11.1989 6.64807 11.9999ZM9.33268 13.3333H6.66602V13.9999H9.33268V13.3333ZM3.83532 9.99939C3.10365 9.08639 2.66602 7.92759 2.66602 6.66659C2.66602 3.72107 5.05383 1.33325 7.99935 1.33325C10.9449 1.33325 13.3327 3.72107 13.3327 6.66659C13.3327 7.92825 12.8945 9.08759 12.1622 10.0009C11.7487 10.5165 10.666 11.3333 10.666 12.3333V13.9999C10.666 14.7363 10.0691 15.3333 9.33268 15.3333H6.66602C5.92964 15.3333 5.33268 14.7363 5.33268 13.9999V12.3333C5.33268 11.3333 4.24907 10.5157 3.83532 9.99939ZM8.66602 6.66979H10.3327L7.33268 10.6698V8.00312H5.66602L8.66602 3.99992V6.66979Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.EDUCATION: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.EDUCATION: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M14 2.66683H4.66667C3.93029 2.66683 3.33333 3.26378 3.33333 4.00016C3.33333 4.73654 3.93029 5.3335 4.66667 5.3335H14V14.0002C14 14.3684 13.7015 14.6668 13.3333 14.6668H4.66667C3.19391 14.6668 2 13.4729 2 12.0002V4.00016C2 2.5274 3.19391 1.3335 4.66667 1.3335H13.3333C13.7015 1.3335 14 1.63198 14 2.00016V2.66683ZM3.33333 12.0002C3.33333 12.7366 3.93029 13.3335 4.66667 13.3335H12.6667V6.66683H4.66667C4.18095 6.66683 3.72557 6.53697 3.33333 6.31008V12.0002ZM13.3333 4.66683H4.66667C4.29848 4.66683 4 4.36835 4 4.00016C4 3.63198 4.29848 3.3335 4.66667 3.3335H13.3333V4.66683Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.BUSINESS: '''<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.BUSINESS: """<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
<path d="M3.66732 3.33341V1.33341C3.66732 0.965228 3.9658 0.666748 4.33398 0.666748H9.66732C10.0355 0.666748 10.334 0.965228 10.334 1.33341V3.33341H13.0007C13.3689 3.33341 13.6673 3.63189 13.6673 4.00008V13.3334C13.6673 13.7016 13.3689 14.0001 13.0007 14.0001H1.00065C0.632464 14.0001 0.333984 13.7016 0.333984 13.3334V4.00008C0.333984 3.63189 0.632464 3.33341 1.00065 3.33341H3.66732ZM12.334 8.66675H1.66732V12.6667H12.334V8.66675ZM12.334 4.66675H1.66732V7.33341H3.66732V6.00008H5.00065V7.33341H9.00065V6.00008H10.334V7.33341H12.334V4.66675ZM5.00065 2.00008V3.33341H9.00065V2.00008H5.00065Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.ENTERTAINMENT: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.ENTERTAINMENT: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M11.3327 2.66675C13.5418 2.66675 15.3327 4.45761 15.3327 6.66675V9.33342C15.3327 11.5425 13.5418 13.3334 11.3327 13.3334H4.66602C2.45688 13.3334 0.666016 11.5425 0.666016 9.33342V6.66675C0.666016 4.45761 2.45688 2.66675 4.66602 2.66675H11.3327ZM11.3327 4.00008H4.66602C3.23788 4.00008 2.07196 5.12273 2.00262 6.53365L1.99935 6.66675V9.33342C1.99935 10.7615 3.122 11.9275 4.53292 11.9968L4.66602 12.0001H11.3327C12.7608 12.0001 13.9267 10.8774 13.9961 9.46648L13.9993 9.33342V6.66675C13.9993 5.23861 12.8767 4.07269 11.4657 4.00335L11.3327 4.00008ZM6.66602 6.00008V7.33342H7.99935V8.66675H6.66535L6.66602 10.0001H5.33268L5.33202 8.66675H3.99935V7.33342H5.33268V6.00008H6.66602ZM11.9993 8.66675V10.0001H10.666V8.66675H11.9993ZM10.666 6.00008V7.33342H9.33268V6.00008H10.666Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.UTILITIES: '''<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.UTILITIES: """<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
<path d="M12.3346 0.333252C12.7028 0.333252 13.0013 0.631732 13.0013 0.999919V4.33325C13.0013 4.70144 12.7028 4.99992 12.3346 4.99992H9.0013V13.6666C9.0013 14.0348 8.70284 14.3333 8.33463 14.3333H5.66797C5.29978 14.3333 5.0013 14.0348 5.0013 13.6666V4.99992H1.33464C0.966449 4.99992 0.667969 4.70144 0.667969 4.33325V2.74527C0.667969 2.49276 0.810635 2.26192 1.0365 2.14899L4.66797 0.333252H12.3346ZM9.0013 1.66659H4.98273L2.0013 3.1573V3.66659H6.33464V12.9999H7.66797V3.66659H9.0013V1.66659ZM11.668 1.66659H10.3346V3.66659H11.668V1.66659Z" fill="#344054"/>
</svg>''',
ToolLabelEnum.OTHER: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
</svg>""", # noqa: E501
ToolLabelEnum.OTHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
<path d="M8.00052 0.666748L4.00065 7.33342H12.0007L8.00052 0.666748ZM8.00052 3.25828L9.64572 6.00008H6.35553L8.00052 3.25828ZM4.50065 13.3334C3.48813 13.3334 2.66732 12.5126 2.66732 11.5001C2.66732 10.4875 3.48813 9.66675 4.50065 9.66675C5.51317 9.66675 6.33398 10.4875 6.33398 11.5001C6.33398 12.5126 5.51317 13.3334 4.50065 13.3334ZM4.50065 14.6667C6.24955 14.6667 7.66732 13.249 7.66732 11.5001C7.66732 9.75115 6.24955 8.33342 4.50065 8.33342C2.75175 8.33342 1.33398 9.75115 1.33398 11.5001C1.33398 13.249 2.75175 14.6667 4.50065 14.6667ZM10.0007 10.3334V13.0001H12.6673V10.3334H10.0007ZM8.66732 14.3334V9.00008H14.0007V14.3334H8.66732Z" fill="#344054"/>
</svg>'''
</svg>""", # noqa: E501
}
default_tool_label_dict = {
ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]),
ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]),
ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]),
ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]),
ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]),
ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]),
ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]),
ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]),
ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]),
ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]),
ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]),
ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]),
ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]),
ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]),
ToolLabelEnum.UTILITIES: ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]),
ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]),
ToolLabelEnum.SEARCH: ToolLabel(
name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH]
),
ToolLabelEnum.IMAGE: ToolLabel(
name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE]
),
ToolLabelEnum.VIDEOS: ToolLabel(
name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS]
),
ToolLabelEnum.WEATHER: ToolLabel(
name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER]
),
ToolLabelEnum.FINANCE: ToolLabel(
name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE]
),
ToolLabelEnum.DESIGN: ToolLabel(
name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN]
),
ToolLabelEnum.TRAVEL: ToolLabel(
name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL]
),
ToolLabelEnum.SOCIAL: ToolLabel(
name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL]
),
ToolLabelEnum.NEWS: ToolLabel(
name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS]
),
ToolLabelEnum.MEDICAL: ToolLabel(
name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL]
),
ToolLabelEnum.PRODUCTIVITY: ToolLabel(
name="productivity",
label=I18nObject(en_US="Productivity", zh_Hans="生产力"),
icon=ICONS[ToolLabelEnum.PRODUCTIVITY],
),
ToolLabelEnum.EDUCATION: ToolLabel(
name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION]
),
ToolLabelEnum.BUSINESS: ToolLabel(
name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS]
),
ToolLabelEnum.ENTERTAINMENT: ToolLabel(
name="entertainment",
label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"),
icon=ICONS[ToolLabelEnum.ENTERTAINMENT],
),
ToolLabelEnum.UTILITIES: ToolLabel(
name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES]
),
ToolLabelEnum.OTHER: ToolLabel(
name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER]
),
}
default_tool_labels = [v for k, v in default_tool_label_dict.items()]

View File

@ -4,23 +4,30 @@ from core.tools.entities.tool_entities import ToolInvokeMeta
class ToolProviderNotFoundError(ValueError):
pass
class ToolNotFoundError(ValueError):
pass
class ToolParameterValidationError(ValueError):
pass
class ToolProviderCredentialValidationError(ValueError):
pass
class ToolNotSupportedError(ValueError):
pass
class ToolInvokeError(ValueError):
pass
class ToolApiSchemaError(ValueError):
pass
class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta
meta: ToolInvokeMeta

View File

@ -1,4 +1,3 @@
from pydantic import Field
from core.entities.provider_entities import ProviderConfig
@ -20,86 +19,70 @@ class ApiToolProviderController(ToolProviderController):
tools: list[ApiTool] = Field(default_factory=list)
@staticmethod
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
credentials_schema = {
'auth_type': ProviderConfig(
name='auth_type',
"auth_type": ProviderConfig(
name="auth_type",
required=True,
type=ProviderConfig.Type.SELECT,
options=[
ProviderConfig.Option(value='none', label=I18nObject(en_US='None', zh_Hans='')),
ProviderConfig.Option(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="")),
ProviderConfig.Option(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
],
default='none',
help=I18nObject(
en_US='The auth type of the api provider',
zh_Hans='api provider 的认证类型'
)
default="none",
help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
)
}
if auth_type == ApiProviderAuthType.API_KEY:
credentials_schema = {
**credentials_schema,
'api_key_header': ProviderConfig(
name='api_key_header',
"api_key_header": ProviderConfig(
name="api_key_header",
required=False,
default='api_key',
default="api_key",
type=ProviderConfig.Type.TEXT_INPUT,
help=I18nObject(
en_US='The header name of the api key',
zh_Hans='携带 api key 的 header 名称'
)
help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
),
'api_key_value': ProviderConfig(
name='api_key_value',
"api_key_value": ProviderConfig(
name="api_key_value",
required=True,
type=ProviderConfig.Type.SECRET_INPUT,
help=I18nObject(
en_US='The api key',
zh_Hans='api key的值'
)
help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
),
'api_key_header_prefix': ProviderConfig(
name='api_key_header_prefix',
"api_key_header_prefix": ProviderConfig(
name="api_key_header_prefix",
required=False,
default='basic',
default="basic",
type=ProviderConfig.Type.SELECT,
help=I18nObject(
en_US='The prefix of the api key header',
zh_Hans='api key header 的前缀'
),
help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
options=[
ProviderConfig.Option(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
ProviderConfig.Option(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
ProviderConfig.Option(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
]
)
ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
],
),
}
elif auth_type == ApiProviderAuthType.NONE:
pass
else:
raise ValueError(f'invalid auth type {auth_type}')
raise ValueError(f"invalid auth type {auth_type}")
user_name = db_provider.user.name if db_provider.user_id else ''
user_name = db_provider.user.name if db_provider.user_id else ""
return ApiToolProviderController(**{
'identity': {
'author': user_name,
'name': db_provider.name,
'label': {
'en_US': db_provider.name,
'zh_Hans': db_provider.name
return ApiToolProviderController(
**{
"identity": {
"author": user_name,
"name": db_provider.name,
"label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
"icon": db_provider.icon,
},
'description': {
'en_US': db_provider.description,
'zh_Hans': db_provider.description
},
'icon': db_provider.icon,
"credentials_schema": credentials_schema,
"provider_id": db_provider.id or "",
"tenant_id": db_provider.tenant_id or "",
},
'credentials_schema': credentials_schema,
'provider_id': db_provider.id or '',
'tenant_id': db_provider.tenant_id or '',
})
)
@property
def provider_type(self) -> ToolProviderType:
@ -107,39 +90,35 @@ class ApiToolProviderController(ToolProviderController):
def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
"""
parse tool bundle to tool
parse tool bundle to tool
:param tool_bundle: the tool bundle
:return: the tool
:param tool_bundle: the tool bundle
:return: the tool
"""
return ApiTool(**{
'api_bundle': tool_bundle,
'identity' : {
'author': tool_bundle.author,
'name': tool_bundle.operation_id,
'label': {
'en_US': tool_bundle.operation_id,
'zh_Hans': tool_bundle.operation_id
return ApiTool(
**{
"api_bundle": tool_bundle,
"identity": {
"author": tool_bundle.author,
"name": tool_bundle.operation_id,
"label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
"icon": self.identity.icon,
"provider": self.provider_id,
},
'icon': self.identity.icon,
'provider': self.provider_id,
},
'description': {
'human': {
'en_US': tool_bundle.summary or '',
'zh_Hans': tool_bundle.summary or ''
"description": {
"human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
"llm": tool_bundle.summary or "",
},
'llm': tool_bundle.summary or ''
},
'parameters' : tool_bundle.parameters if tool_bundle.parameters else [],
})
"parameters": tool_bundle.parameters or [],
}
)
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
"""
load bundled tools
load bundled tools
:param tools: the bundled tools
:return: the tools
:param tools: the bundled tools
:return: the tools
"""
self.tools = [self._parse_tool_bundle(tool) for tool in tools]
@ -147,22 +126,23 @@ class ApiToolProviderController(ToolProviderController):
def get_tools(self, tenant_id: str) -> list[ApiTool]:
"""
fetch tools from database
fetch tools from database
:param user_id: the user id
:param tenant_id: the tenant id
:return: the tools
:param user_id: the user id
:param tenant_id: the tenant id
:return: the tools
"""
if self.tools is not None:
return self.tools
tools: list[ApiTool] = []
# get tenant api providers
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == self.identity.name
).all()
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name)
.all()
)
if db_providers and len(db_providers) != 0:
for db_provider in db_providers:
@ -170,16 +150,16 @@ class ApiToolProviderController(ToolProviderController):
assistant_tool = self._parse_tool_bundle(tool)
assistant_tool.is_team_authorization = True
tools.append(assistant_tool)
self.tools = tools
return tools
def get_tool(self, tool_name: str) -> ApiTool:
"""
get tool by name
get tool by name
:param tool_name: the name of the tool
:return: the tool
:param tool_name: the name of the tool
:return: the tool
"""
if self.tools is None:
self.get_tools(self.tenant_id)
@ -188,4 +168,4 @@ class ApiToolProviderController(ToolProviderController):
if tool.identity.name == tool_name:
return tool
raise ValueError(f'tool {tool_name} not found')
raise ValueError(f"tool {tool_name} not found")

View File

@ -0,0 +1,103 @@
import logging
from typing import Any
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool
from extensions.ext_database import db
from models.model import App, AppModelConfig
from models.tools import PublishedAppTool
logger = logging.getLogger(__name__)
class AppToolProviderEntity(ToolProviderController):
@property
def provider_type(self) -> ToolProviderType:
return ToolProviderType.APP
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
pass
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
pass
def get_tools(self, user_id: str) -> list[Tool]:
db_tools: list[PublishedAppTool] = (
db.session.query(PublishedAppTool)
.filter(
PublishedAppTool.user_id == user_id,
)
.all()
)
if not db_tools or len(db_tools) == 0:
return []
tools: list[Tool] = []
for db_tool in db_tools:
tool = {
"identity": {
"author": db_tool.author,
"name": db_tool.tool_name,
"label": {"en_US": db_tool.tool_name, "zh_Hans": db_tool.tool_name},
"icon": "",
},
"description": {
"human": {"en_US": db_tool.description_i18n.en_US, "zh_Hans": db_tool.description_i18n.zh_Hans},
"llm": db_tool.llm_description,
},
"parameters": [],
}
# get app from db
app: App = db_tool.app
if not app:
logger.error(f"app {db_tool.app_id} not found")
continue
app_model_config: AppModelConfig = app.app_model_config
user_input_form_list = app_model_config.user_input_form_list
for input_form in user_input_form_list:
# get type
form_type = input_form.keys()[0]
default = input_form[form_type]["default"]
required = input_form[form_type]["required"]
label = input_form[form_type]["label"]
variable_name = input_form[form_type]["variable_name"]
options = input_form[form_type].get("options", [])
if form_type in {"paragraph", "text-input"}:
tool["parameters"].append(
ToolParameter(
name=variable_name,
label=I18nObject(en_US=label, zh_Hans=label),
human_description=I18nObject(en_US=label, zh_Hans=label),
llm_description=label,
form=ToolParameter.ToolParameterForm.FORM,
type=ToolParameter.ToolParameterType.STRING,
required=required,
default=default,
)
)
elif form_type == "select":
tool["parameters"].append(
ToolParameter(
name=variable_name,
label=I18nObject(en_US=label, zh_Hans=label),
human_description=I18nObject(en_US=label, zh_Hans=label),
llm_description=label,
form=ToolParameter.ToolParameterForm.FORM,
type=ToolParameter.ToolParameterType.SELECT,
required=required,
default=default,
options=[
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
for option in options
],
)
)
tools.append(Tool(**tool))
return tools

View File

@ -10,7 +10,7 @@ class BuiltinToolProviderSort:
@classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position:
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..'))
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), ".."))
def name_func(provider: UserToolProvider) -> str:
return provider.name

View File

@ -6,6 +6,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class AIPPTProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__')
AIPPTGenerateTool._get_api_token(credentials, user_id="__dify_system__")
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -20,16 +20,16 @@ class AIPPTGenerateTool(BuiltinTool):
A tool for generating a ppt
"""
_api_base_url = URL('https://co.aippt.cn/api')
_api_base_url = URL("https://co.aippt.cn/api")
_api_token_cache = {}
_api_token_cache_lock:Optional[Lock] = None
_api_token_cache_lock: Optional[Lock] = None
_style_cache = {}
_style_cache_lock:Optional[Lock] = None
_style_cache_lock: Optional[Lock] = None
_task = {}
_task_type_map = {
'auto': 1,
'markdown': 7,
"auto": 1,
"markdown": 7,
}
def __init__(self, **kwargs: Any):
@ -46,67 +46,58 @@ class AIPPTGenerateTool(BuiltinTool):
tool_parameters (dict[str, Any]): The parameters for the tool
Returns:
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages.
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
which can be a single message or a list of messages.
"""
title = tool_parameters.get('title', '')
title = tool_parameters.get("title", "")
if not title:
return self.create_text_message('Please provide a title for the ppt')
model = tool_parameters.get('model', 'aippt')
return self.create_text_message("Please provide a title for the ppt")
model = tool_parameters.get("model", "aippt")
if not model:
return self.create_text_message('Please provide a model for the ppt')
outline = tool_parameters.get('outline', '')
return self.create_text_message("Please provide a model for the ppt")
outline = tool_parameters.get("outline", "")
# create task
task_id = self._create_task(
type=self._task_type_map['auto' if not outline else 'markdown'],
type=self._task_type_map["auto" if not outline else "markdown"],
title=title,
content=outline,
user_id=user_id
user_id=user_id,
)
# get suit
color = tool_parameters.get('color')
style = tool_parameters.get('style')
color = tool_parameters.get("color")
style = tool_parameters.get("style")
if color == '__default__':
color_id = ''
if color == "__default__":
color_id = ""
else:
color_id = int(color.split('-')[1])
color_id = int(color.split("-")[1])
if style == '__default__':
style_id = ''
if style == "__default__":
style_id = ""
else:
style_id = int(style.split('-')[1])
style_id = int(style.split("-")[1])
suit_id = self._get_suit(style_id=style_id, colour_id=color_id)
# generate outline
if not outline:
self._generate_outline(
task_id=task_id,
model=model,
user_id=user_id
)
self._generate_outline(task_id=task_id, model=model, user_id=user_id)
# generate content
self._generate_content(
task_id=task_id,
model=model,
user_id=user_id
)
self._generate_content(task_id=task_id, model=model, user_id=user_id)
# generate ppt
_, ppt_url = self._generate_ppt(
task_id=task_id,
suit_id=suit_id,
user_id=user_id
)
_, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id)
return self.create_text_message('''the ppt has been created successfully,'''
f'''the ppt url is {ppt_url}'''
'''please give the ppt url to user and direct user to download it.''')
return self.create_text_message(
"""the ppt has been created successfully,"""
f"""the ppt url is {ppt_url}"""
"""please give the ppt url to user and direct user to download it."""
)
def _create_task(self, type: int, title: str, content: str, user_id: str) -> str:
"""
@ -119,129 +110,121 @@ class AIPPTGenerateTool(BuiltinTool):
:return: the task ID
"""
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
}
response = post(
str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'),
str(self._api_base_url / "ai" / "chat" / "v2" / "task"),
headers=headers,
files={
'type': ('', str(type)),
'title': ('', title),
'content': ('', content)
}
files={"type": ("", str(type)), "title": ("", title), "content": ("", content)},
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
raise Exception(f"Failed to connect to aippt: {response.text}")
response = response.json()
if response.get('code') != 0:
if response.get("code") != 0:
raise Exception(f'Failed to create task: {response.get("msg")}')
return response.get('data', {}).get('id')
return response.get("data", {}).get("id")
def _generate_outline(self, task_id: str, model: str, user_id: str) -> str:
api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \
self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline'
api_url %= {'task_id': task_id}
api_url = (
self._api_base_url / "ai" / "chat" / "outline"
if model == "aippt"
else self._api_base_url / "ai" / "chat" / "wx" / "outline"
)
api_url %= {"task_id": task_id}
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
}
response = requests_get(
url=api_url,
headers=headers,
stream=True,
timeout=(10, 60)
)
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
outline = ''
for chunk in response.iter_lines(delimiter=b'\n\n'):
raise Exception(f"Failed to connect to aippt: {response.text}")
outline = ""
for chunk in response.iter_lines(delimiter=b"\n\n"):
if not chunk:
continue
event = ''
lines = chunk.decode('utf-8').split('\n')
event = ""
lines = chunk.decode("utf-8").split("\n")
for line in lines:
if line.startswith('event:'):
if line.startswith("event:"):
event = line[6:]
elif line.startswith('data:'):
elif line.startswith("data:"):
data = line[5:]
if event == 'message':
if event == "message":
try:
data = json_loads(data)
outline += data.get('content', '')
outline += data.get("content", "")
except Exception as e:
pass
elif event == 'close':
elif event == "close":
break
elif event == 'error' or event == 'filter':
raise Exception(f'Failed to generate outline: {data}')
elif event in {"error", "filter"}:
raise Exception(f"Failed to generate outline: {data}")
return outline
def _generate_content(self, task_id: str, model: str, user_id: str) -> str:
api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \
self._api_base_url / 'ai' / 'chat' / 'wx' / 'content'
api_url %= {'task_id': task_id}
api_url = (
self._api_base_url / "ai" / "chat" / "content"
if model == "aippt"
else self._api_base_url / "ai" / "chat" / "wx" / "content"
)
api_url %= {"task_id": task_id}
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
}
response = requests_get(
url=api_url,
headers=headers,
stream=True,
timeout=(10, 60)
)
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
if model == 'aippt':
content = ''
for chunk in response.iter_lines(delimiter=b'\n\n'):
raise Exception(f"Failed to connect to aippt: {response.text}")
if model == "aippt":
content = ""
for chunk in response.iter_lines(delimiter=b"\n\n"):
if not chunk:
continue
event = ''
lines = chunk.decode('utf-8').split('\n')
event = ""
lines = chunk.decode("utf-8").split("\n")
for line in lines:
if line.startswith('event:'):
if line.startswith("event:"):
event = line[6:]
elif line.startswith('data:'):
elif line.startswith("data:"):
data = line[5:]
if event == 'message':
if event == "message":
try:
data = json_loads(data)
content += data.get('content', '')
content += data.get("content", "")
except Exception as e:
pass
elif event == 'close':
elif event == "close":
break
elif event == 'error' or event == 'filter':
raise Exception(f'Failed to generate content: {data}')
elif event in {"error", "filter"}:
raise Exception(f"Failed to generate content: {data}")
return content
elif model == 'wenxin':
elif model == "wenxin":
response = response.json()
if response.get('code') != 0:
if response.get("code") != 0:
raise Exception(f'Failed to generate content: {response.get("msg")}')
return response.get('data', '')
return ''
return response.get("data", "")
return ""
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
"""
@ -252,83 +235,73 @@ class AIPPTGenerateTool(BuiltinTool):
:return: the cover url of the ppt and the ppt url
"""
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
}
response = post(
str(self._api_base_url / 'design' / 'v2' / 'save'),
str(self._api_base_url / "design" / "v2" / "save"),
headers=headers,
data={
'task_id': task_id,
'template_id': suit_id
}
data={"task_id": task_id, "template_id": suit_id},
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
raise Exception(f"Failed to connect to aippt: {response.text}")
response = response.json()
if response.get('code') != 0:
if response.get("code") != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
id = response.get('data', {}).get('id')
cover_url = response.get('data', {}).get('cover_url')
id = response.get("data", {}).get("id")
cover_url = response.get("data", {}).get("cover_url")
response = post(
str(self._api_base_url / 'download' / 'export' / 'file'),
str(self._api_base_url / "download" / "export" / "file"),
headers=headers,
data={
'id': id,
'format': 'ppt',
'files_to_zip': False,
'edit': True
}
data={"id": id, "format": "ppt", "files_to_zip": False, "edit": True},
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
raise Exception(f"Failed to connect to aippt: {response.text}")
response = response.json()
if response.get('code') != 0:
if response.get("code") != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
export_code = response.get('data')
export_code = response.get("data")
if not export_code:
raise Exception('Failed to generate ppt, the export code is empty')
raise Exception("Failed to generate ppt, the export code is empty")
current_iteration = 0
while current_iteration < 50:
# get ppt url
response = post(
str(self._api_base_url / 'download' / 'export' / 'file' / 'result'),
str(self._api_base_url / "download" / "export" / "file" / "result"),
headers=headers,
data={
'task_key': export_code
}
data={"task_key": export_code},
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
raise Exception(f"Failed to connect to aippt: {response.text}")
response = response.json()
if response.get('code') != 0:
if response.get("code") != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
if response.get('msg') == '导出中':
if response.get("msg") == "导出中":
current_iteration += 1
sleep(2)
continue
ppt_url = response.get('data', [])
ppt_url = response.get("data", [])
if len(ppt_url) == 0:
raise Exception('Failed to generate ppt, the ppt url is empty')
raise Exception("Failed to generate ppt, the ppt url is empty")
return cover_url, ppt_url[0]
raise Exception('Failed to generate ppt, the export is timeout')
raise Exception("Failed to generate ppt, the export is timeout")
@classmethod
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
"""
@ -337,53 +310,43 @@ class AIPPTGenerateTool(BuiltinTool):
:param credentials: the credentials
:return: the API token
"""
access_key = credentials['aippt_access_key']
secret_key = credentials['aippt_secret_key']
access_key = credentials["aippt_access_key"]
secret_key = credentials["aippt_secret_key"]
cache_key = f'{access_key}#@#{user_id}'
cache_key = f"{access_key}#@#{user_id}"
with cls._api_token_cache_lock:
# clear expired tokens
now = time()
for key in list(cls._api_token_cache.keys()):
if cls._api_token_cache[key]['expire'] < now:
if cls._api_token_cache[key]["expire"] < now:
del cls._api_token_cache[key]
if cache_key in cls._api_token_cache:
return cls._api_token_cache[cache_key]['token']
return cls._api_token_cache[cache_key]["token"]
# get token
headers = {
'x-api-key': access_key,
'x-timestamp': str(int(now)),
'x-signature': cls._calculate_sign(access_key, secret_key, int(now))
"x-api-key": access_key,
"x-timestamp": str(int(now)),
"x-signature": cls._calculate_sign(access_key, secret_key, int(now)),
}
param = {
'uid': user_id,
'channel': ''
}
param = {"uid": user_id, "channel": ""}
response = get(
str(cls._api_base_url / 'grant' / 'token'),
params=param,
headers=headers
)
response = get(str(cls._api_base_url / "grant" / "token"), params=param, headers=headers)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
raise Exception(f"Failed to connect to aippt: {response.text}")
response = response.json()
if response.get('code') != 0:
if response.get("code") != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
token = response.get('data', {}).get('token')
expire = response.get('data', {}).get('time_expire')
token = response.get("data", {}).get("token")
expire = response.get("data", {}).get("time_expire")
with cls._api_token_cache_lock:
cls._api_token_cache[cache_key] = {
'token': token,
'expire': now + expire
}
cls._api_token_cache[cache_key] = {"token": token, "expire": now + expire}
return token
@ -391,11 +354,9 @@ class AIPPTGenerateTool(BuiltinTool):
def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
return b64encode(
hmac_new(
key=secret_key.encode('utf-8'),
msg=f'GET@/api/grant/token/@{timestamp}'.encode(),
digestmod=sha1
key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1
).digest()
).decode('utf-8')
).decode("utf-8")
@classmethod
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
@ -408,47 +369,46 @@ class AIPPTGenerateTool(BuiltinTool):
# clear expired styles
now = time()
for key in list(cls._style_cache.keys()):
if cls._style_cache[key]['expire'] < now:
if cls._style_cache[key]["expire"] < now:
del cls._style_cache[key]
key = f'{credentials["aippt_access_key"]}#@#{user_id}'
if key in cls._style_cache:
return cls._style_cache[key]['colors'], cls._style_cache[key]['styles']
return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"]
headers = {
'x-channel': '',
'x-api-key': credentials['aippt_access_key'],
'x-token': cls._get_api_token(credentials=credentials, user_id=user_id)
"x-channel": "",
"x-api-key": credentials["aippt_access_key"],
"x-token": cls._get_api_token(credentials=credentials, user_id=user_id),
}
response = get(
str(cls._api_base_url / 'template_component' / 'suit' / 'select'),
headers=headers
)
response = get(str(cls._api_base_url / "template_component" / "suit" / "select"), headers=headers)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
raise Exception(f"Failed to connect to aippt: {response.text}")
response = response.json()
if response.get('code') != 0:
if response.get("code") != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
colors = [{
'id': f'id-{item.get("id")}',
'name': item.get('name'),
'en_name': item.get('en_name', item.get('name')),
} for item in response.get('data', {}).get('colour') or []]
styles = [{
'id': f'id-{item.get("id")}',
'name': item.get('title'),
} for item in response.get('data', {}).get('suit_style') or []]
colors = [
{
"id": f'id-{item.get("id")}',
"name": item.get("name"),
"en_name": item.get("en_name", item.get("name")),
}
for item in response.get("data", {}).get("colour") or []
]
styles = [
{
"id": f'id-{item.get("id")}',
"name": item.get("title"),
}
for item in response.get("data", {}).get("suit_style") or []
]
with cls._style_cache_lock:
cls._style_cache[key] = {
'colors': colors,
'styles': styles,
'expire': now + 60 * 60
}
cls._style_cache[key] = {"colors": colors, "styles": styles, "expire": now + 60 * 60}
return colors, styles
@ -459,44 +419,39 @@ class AIPPTGenerateTool(BuiltinTool):
:param credentials: the credentials
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
"""
if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'):
raise Exception('Please provide aippt credentials')
if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"):
raise Exception("Please provide aippt credentials")
return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
def _get_suit(self, style_id: int, colour_id: int) -> int:
"""
Get suit
"""
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__')
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"),
}
response = get(
str(self._api_base_url / 'template_component' / 'suit' / 'search'),
str(self._api_base_url / "template_component" / "suit" / "search"),
headers=headers,
params={
'style_id': style_id,
'colour_id': colour_id,
'page': 1,
'page_size': 1
}
params={"style_id": style_id, "colour_id": colour_id, "page": 1, "page_size": 1},
)
if response.status_code != 200:
raise Exception(f'Failed to connect to aippt: {response.text}')
raise Exception(f"Failed to connect to aippt: {response.text}")
response = response.json()
if response.get('code') != 0:
if response.get("code") != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
if len(response.get('data', {}).get('list') or []) > 0:
return response.get('data', {}).get('list')[0].get('id')
raise Exception('Failed to get suit, the suit does not exist, please check the style and color')
if len(response.get("data", {}).get("list") or []) > 0:
return response.get("data", {}).get("list")[0].get("id")
raise Exception("Failed to get suit, the suit does not exist, please check the style and color")
def get_runtime_parameters(self) -> list[ToolParameter]:
"""
Get runtime parameters
@ -504,43 +459,40 @@ class AIPPTGenerateTool(BuiltinTool):
Override this method to add runtime parameters to the tool.
"""
try:
colors, styles = self.get_styles(user_id='__dify_system__')
colors, styles = self.get_styles(user_id="__dify_system__")
except Exception as e:
colors, styles = [
{'id': '-1', 'name': '__default__', 'en_name': '__default__'}
], [
{'id': '-1', 'name': '__default__', 'en_name': '__default__'}
]
colors, styles = (
[{"id": "-1", "name": "__default__", "en_name": "__default__"}],
[{"id": "-1", "name": "__default__", "en_name": "__default__"}],
)
return [
ToolParameter(
name='color',
label=I18nObject(zh_Hans='颜色', en_US='Color'),
human_description=I18nObject(zh_Hans='颜色', en_US='Color'),
name="color",
label=I18nObject(zh_Hans="颜色", en_US="Color"),
human_description=I18nObject(zh_Hans="颜色", en_US="Color"),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=False,
default=colors[0]['id'],
default=colors[0]["id"],
options=[
ToolParameterOption(
value=color['id'],
label=I18nObject(zh_Hans=color['name'], en_US=color['en_name'])
) for color in colors
]
value=color["id"], label=I18nObject(zh_Hans=color["name"], en_US=color["en_name"])
)
for color in colors
],
),
ToolParameter(
name='style',
label=I18nObject(zh_Hans='风格', en_US='Style'),
human_description=I18nObject(zh_Hans='风格', en_US='Style'),
name="style",
label=I18nObject(zh_Hans="风格", en_US="Style"),
human_description=I18nObject(zh_Hans="风格", en_US="Style"),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
required=False,
default=styles[0]['id'],
default=styles[0]["id"],
options=[
ToolParameterOption(
value=style['id'],
label=I18nObject(zh_Hans=style['name'], en_US=style['name'])
) for style in styles
]
ToolParameterOption(value=style["id"], label=I18nObject(zh_Hans=style["name"], en_US=style["name"]))
for style in styles
],
),
]
]

View File

@ -13,7 +13,7 @@ class AlphaVantageProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
user_id="",
tool_parameters={
"code": "AAPL", # Apple Inc.
},

View File

@ -9,17 +9,16 @@ ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query"
class QueryStockTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
stock_code = tool_parameters.get('code', '')
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
stock_code = tool_parameters.get("code", "")
if not stock_code:
return self.create_text_message('Please tell me your stock code')
return self.create_text_message("Please tell me your stock code")
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
return self.create_text_message("Alpha Vantage API key is required.")
params = {
@ -27,7 +26,7 @@ class QueryStockTool(BuiltinTool):
"symbol": stock_code,
"outputsize": "compact",
"datatype": "json",
"apikey": self.runtime.credentials['api_key']
"apikey": self.runtime.credentials["api_key"],
}
response = requests.get(url=ALPHAVANTAGE_API_URL, params=params)
response.raise_for_status()
@ -35,15 +34,15 @@ class QueryStockTool(BuiltinTool):
return self.create_json_message(result)
def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]:
result = response.get('Time Series (Daily)', {})
result = response.get("Time Series (Daily)", {})
if not result:
return {}
stock_result = {}
for k, v in result.items():
stock_result[k] = {}
stock_result[k]['open'] = v.get('1. open')
stock_result[k]['high'] = v.get('2. high')
stock_result[k]['low'] = v.get('3. low')
stock_result[k]['close'] = v.get('4. close')
stock_result[k]['volume'] = v.get('5. volume')
stock_result[k]["open"] = v.get("1. open")
stock_result[k]["high"] = v.get("2. high")
stock_result[k]["low"] = v.get("3. low")
stock_result[k]["close"] = v.get("4. close")
stock_result[k]["volume"] = v.get("5. volume")
return stock_result

View File

@ -11,11 +11,10 @@ class ArxivProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
user_id="",
tool_parameters={
"query": "John Doe",
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -8,6 +8,8 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
logger = logging.getLogger(__name__)
class ArxivAPIWrapper(BaseModel):
"""Wrapper around ArxivAPI.
@ -86,11 +88,13 @@ class ArxivAPIWrapper(BaseModel):
class ArxivSearchInput(BaseModel):
query: str = Field(..., description="Search query.")
class ArxivSearchTool(BuiltinTool):
"""
A tool for searching articles on Arxiv.
"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
Invokes the Arxiv search tool with the given user ID and tool parameters.
@ -100,15 +104,16 @@ class ArxivSearchTool(BuiltinTool):
tool_parameters (dict[str, Any]): The parameters for the tool, including the 'query' parameter.
Returns:
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages.
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
which can be a single message or a list of messages.
"""
query = tool_parameters.get('query', '')
query = tool_parameters.get("query", "")
if not query:
return self.create_text_message('Please input query')
return self.create_text_message("Please input query")
arxiv = ArxivAPIWrapper()
response = arxiv.run(query)
return self.create_text_message(self.summary(user_id=user_id, content=response))

View File

@ -11,15 +11,14 @@ class SageMakerProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
user_id="",
tool_parameters={
"sagemaker_endpoint" : "",
"sagemaker_endpoint": "",
"query": "misaka mikoto",
"candidate_texts" : "hello$$$hello world",
"topk" : 5,
"aws_region" : ""
"candidate_texts": "hello$$$hello world",
"topk": 5,
"aws_region": "",
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -12,6 +12,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GuardrailParameters(BaseModel):
guardrail_id: str = Field(..., description="The identifier of the guardrail")
guardrail_version: str = Field(..., description="The version of the guardrail")
@ -19,35 +20,35 @@ class GuardrailParameters(BaseModel):
text: str = Field(..., description="The text to apply the guardrail to")
aws_region: str = Field(..., description="AWS region for the Bedrock client")
class ApplyGuardrailTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invoke the ApplyGuardrail tool
"""
try:
# Validate and parse input parameters
params = GuardrailParameters(**tool_parameters)
# Initialize AWS client
bedrock_client = boto3.client('bedrock-runtime', region_name=params.aws_region)
bedrock_client = boto3.client("bedrock-runtime", region_name=params.aws_region)
# Apply guardrail
response = bedrock_client.apply_guardrail(
guardrailIdentifier=params.guardrail_id,
guardrailVersion=params.guardrail_version,
source=params.source,
content=[{"text": {"text": params.text}}]
content=[{"text": {"text": params.text}}],
)
logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}")
# Check for empty response
if not response:
return self.create_text_message(text="Received empty response from AWS Bedrock.")
# Process the result
action = response.get("action", "No action specified")
outputs = response.get("outputs", [])
@ -58,9 +59,12 @@ class ApplyGuardrailTool(BuiltinTool):
formatted_assessments = []
for assessment in assessments:
for policy_type, policy_data in assessment.items():
if isinstance(policy_data, dict) and 'topics' in policy_data:
for topic in policy_data['topics']:
formatted_assessments.append(f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}")
if isinstance(policy_data, dict) and "topics" in policy_data:
for topic in policy_data["topics"]:
formatted_assessments.append(
f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']},"
f" Action: {topic['action']}"
)
else:
formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}")
@ -68,19 +72,19 @@ class ApplyGuardrailTool(BuiltinTool):
result += f"Output: {output}\n "
if formatted_assessments:
result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n "
# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"
# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"
return self.create_text_message(text=result)
except BotoCoreError as e:
error_message = f'AWS service error: {str(e)}'
error_message = f"AWS service error: {str(e)}"
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
except json.JSONDecodeError as e:
error_message = f'JSON parsing error: {str(e)}'
error_message = f"JSON parsing error: {str(e)}"
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
except Exception as e:
error_message = f'An unexpected error occurred: {str(e)}'
error_message = f"An unexpected error occurred: {str(e)}"
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
return self.create_text_message(text=error_message)

View File

@ -11,78 +11,81 @@ class LambdaTranslateUtilsTool(BuiltinTool):
lambda_client: Any = None
def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name):
msg = {
"src_content":text_content,
"src_lang": src_lang,
"dest_lang":dest_lang,
msg = {
"src_content": text_content,
"src_lang": src_lang,
"dest_lang": dest_lang,
"dictionary_id": dictionary_name,
"request_type" : request_type,
"model_id" : model_id
"request_type": request_type,
"model_id": model_id,
}
invoke_response = self.lambda_client.invoke(FunctionName=lambda_name,
InvocationType='RequestResponse',
Payload=json.dumps(msg))
response_body = invoke_response['Payload']
invoke_response = self.lambda_client.invoke(
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
)
response_body = invoke_response["Payload"]
response_str = response_body.read().decode("unicode_escape")
return response_str
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
line = 0
try:
if not self.lambda_client:
aws_region = tool_parameters.get('aws_region')
aws_region = tool_parameters.get("aws_region")
if aws_region:
self.lambda_client = boto3.client("lambda", region_name=aws_region)
else:
self.lambda_client = boto3.client("lambda")
line = 1
text_content = tool_parameters.get('text_content', '')
text_content = tool_parameters.get("text_content", "")
if not text_content:
return self.create_text_message('Please input text_content')
return self.create_text_message("Please input text_content")
line = 2
src_lang = tool_parameters.get('src_lang', '')
src_lang = tool_parameters.get("src_lang", "")
if not src_lang:
return self.create_text_message('Please input src_lang')
return self.create_text_message("Please input src_lang")
line = 3
dest_lang = tool_parameters.get('dest_lang', '')
dest_lang = tool_parameters.get("dest_lang", "")
if not dest_lang:
return self.create_text_message('Please input dest_lang')
return self.create_text_message("Please input dest_lang")
line = 4
lambda_name = tool_parameters.get('lambda_name', '')
lambda_name = tool_parameters.get("lambda_name", "")
if not lambda_name:
return self.create_text_message('Please input lambda_name')
return self.create_text_message("Please input lambda_name")
line = 5
request_type = tool_parameters.get('request_type', '')
request_type = tool_parameters.get("request_type", "")
if not request_type:
return self.create_text_message('Please input request_type')
return self.create_text_message("Please input request_type")
line = 6
model_id = tool_parameters.get('model_id', '')
model_id = tool_parameters.get("model_id", "")
if not model_id:
return self.create_text_message('Please input model_id')
return self.create_text_message("Please input model_id")
line = 7
dictionary_name = tool_parameters.get('dictionary_name', '')
dictionary_name = tool_parameters.get("dictionary_name", "")
if not dictionary_name:
return self.create_text_message('Please input dictionary_name')
result = self._invoke_lambda(text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name)
return self.create_text_message("Please input dictionary_name")
result = self._invoke_lambda(
text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name
)
return self.create_text_message(text=result)
except Exception as e:
return self.create_text_message(f'Exception {str(e)}, line : {line}')
return self.create_text_message(f"Exception {str(e)}, line : {line}")

View File

@ -10,7 +10,7 @@ description:
human:
en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag
zh_Hans: 大语言模型翻译工具(专词映射获取)需要在AWS上进行额外部署可参考Github Repo - https://github.com/ybalbert001/dynamodb-rag
pt_BR: A util tools for LLM translation, specfic Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag
pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag
llm: A util tools for translation.
parameters:
- name: text_content

View File

@ -18,54 +18,53 @@ class LambdaYamlToJsonTool(BuiltinTool):
lambda_client: Any = None
def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
msg = {
"body": yaml_content
}
msg = {"body": yaml_content}
logger.info(json.dumps(msg))
invoke_response = self.lambda_client.invoke(FunctionName=lambda_name,
InvocationType='RequestResponse',
Payload=json.dumps(msg))
response_body = invoke_response['Payload']
invoke_response = self.lambda_client.invoke(
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
)
response_body = invoke_response["Payload"]
response_str = response_body.read().decode("utf-8")
resp_json = json.loads(response_str)
logger.info(resp_json)
if resp_json['statusCode'] != 200:
if resp_json["statusCode"] != 200:
raise Exception(f"Invalid status code: {response_str}")
return resp_json['body']
return resp_json["body"]
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
try:
if not self.lambda_client:
aws_region = tool_parameters.get('aws_region') # todo: move aws_region out, and update client region
aws_region = tool_parameters.get("aws_region") # todo: move aws_region out, and update client region
if aws_region:
self.lambda_client = boto3.client("lambda", region_name=aws_region)
else:
self.lambda_client = boto3.client("lambda")
yaml_content = tool_parameters.get('yaml_content', '')
yaml_content = tool_parameters.get("yaml_content", "")
if not yaml_content:
return self.create_text_message('Please input yaml_content')
return self.create_text_message("Please input yaml_content")
lambda_name = tool_parameters.get('lambda_name', '')
lambda_name = tool_parameters.get("lambda_name", "")
if not lambda_name:
return self.create_text_message('Please input lambda_name')
logger.debug(f'{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}')
return self.create_text_message("Please input lambda_name")
logger.debug(f"{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}")
result = self._invoke_lambda(lambda_name, yaml_content)
logger.debug(result)
return self.create_text_message(result)
except Exception as e:
return self.create_text_message(f'Exception: {str(e)}')
return self.create_text_message(f"Exception: {str(e)}")
console_handler.flush()
console_handler.flush()

View File

@ -1,4 +1,5 @@
import json
import operator
from typing import Any, Union
import boto3
@ -9,37 +10,33 @@ from core.tools.tool.builtin_tool import BuiltinTool
class SageMakerReRankTool(BuiltinTool):
sagemaker_client: Any = None
sagemaker_endpoint:str = None
topk:int = None
sagemaker_endpoint: str = None
topk: int = None
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str):
inputs = [query_input]*len(docs)
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
inputs = [query_input] * len(docs)
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=rerank_endpoint,
Body=json.dumps(
{
"inputs": inputs,
"docs": docs
}
),
Body=json.dumps({"inputs": inputs, "docs": docs}),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_str = response_model["Body"].read().decode("utf8")
json_obj = json.loads(json_str)
scores = json_obj['scores']
scores = json_obj["scores"]
return scores if isinstance(scores, list) else [scores]
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
line = 0
try:
if not self.sagemaker_client:
aws_region = tool_parameters.get('aws_region')
aws_region = tool_parameters.get("aws_region")
if aws_region:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
else:
@ -47,25 +44,25 @@ class SageMakerReRankTool(BuiltinTool):
line = 1
if not self.sagemaker_endpoint:
self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
line = 2
if not self.topk:
self.topk = tool_parameters.get('topk', 5)
self.topk = tool_parameters.get("topk", 5)
line = 3
query = tool_parameters.get('query', '')
query = tool_parameters.get("query", "")
if not query:
return self.create_text_message('Please input query')
return self.create_text_message("Please input query")
line = 4
candidate_texts = tool_parameters.get('candidate_texts')
candidate_texts = tool_parameters.get("candidate_texts")
if not candidate_texts:
return self.create_text_message('Please input candidate_texts')
return self.create_text_message("Please input candidate_texts")
line = 5
candidate_docs = json.loads(candidate_texts)
docs = [ item.get('content') for item in candidate_docs ]
docs = [item.get("content") for item in candidate_docs]
line = 6
scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint)
@ -75,10 +72,10 @@ class SageMakerReRankTool(BuiltinTool):
candidate_docs[idx]["score"] = scores[idx]
line = 8
sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
line = 9
return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ]
return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
except Exception as e:
return self.create_text_message(f'Exception {str(e)}, line : {line}')
return self.create_text_message(f"Exception {str(e)}, line : {line}")

View File

@ -14,82 +14,88 @@ class TTSModelType(Enum):
CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
InstructVoice = "InstructVoice"
class SageMakerTTSTool(BuiltinTool):
sagemaker_client: Any = None
sagemaker_endpoint:str = None
s3_client : Any = None
comprehend_client : Any = None
sagemaker_endpoint: str = None
s3_client: Any = None
comprehend_client: Any = None
def _detect_lang_code(self, content:str, map_dict:dict=None):
map_dict = {
"zh" : "<|zh|>",
"en" : "<|en|>",
"ja" : "<|jp|>",
"zh-TW" : "<|yue|>",
"ko" : "<|ko|>"
}
def _detect_lang_code(self, content: str, map_dict: dict = None):
map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"}
response = self.comprehend_client.detect_dominant_language(Text=content)
language_code = response['Languages'][0]['LanguageCode']
return map_dict.get(language_code, '<|zh|>')
language_code = response["Languages"][0]["LanguageCode"]
return map_dict.get(language_code, "<|zh|>")
def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str):
def _build_tts_payload(
self,
model_type: str,
content_text: str,
model_role: str,
prompt_text: str,
prompt_audio: str,
instruct_text: str,
):
if model_type == TTSModelType.PresetVoice.value and model_role:
return { "tts_text" : content_text, "role" : model_role }
return {"tts_text": content_text, "role": model_role}
if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio }
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
lang_tag = self._detect_lang_code(content_text)
return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag }
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text }
return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag}
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}
raise RuntimeError(f"Invalid params for {model_type}")
def _invoke_sagemaker(self, payload:dict, endpoint:str):
def _invoke_sagemaker(self, payload: dict, endpoint: str):
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=endpoint,
Body=json.dumps(payload),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_str = response_model["Body"].read().decode("utf8")
json_obj = json.loads(json_str)
return json_obj
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
try:
if not self.sagemaker_client:
aws_region = tool_parameters.get('aws_region')
aws_region = tool_parameters.get("aws_region")
if aws_region:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
self.s3_client = boto3.client("s3", region_name=aws_region)
self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
self.comprehend_client = boto3.client("comprehend", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
self.s3_client = boto3.client("s3")
self.comprehend_client = boto3.client('comprehend')
self.comprehend_client = boto3.client("comprehend")
if not self.sagemaker_endpoint:
self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
tts_text = tool_parameters.get('tts_text')
tts_infer_type = tool_parameters.get('tts_infer_type')
tts_text = tool_parameters.get("tts_text")
tts_infer_type = tool_parameters.get("tts_infer_type")
voice = tool_parameters.get('voice')
mock_voice_audio = tool_parameters.get('mock_voice_audio')
mock_voice_text = tool_parameters.get('mock_voice_text')
voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt')
payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt)
voice = tool_parameters.get("voice")
mock_voice_audio = tool_parameters.get("mock_voice_audio")
mock_voice_text = tool_parameters.get("mock_voice_text")
voice_instruct_prompt = tool_parameters.get("voice_instruct_prompt")
payload = self._build_tts_payload(
tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt
)
result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)
return self.create_text_message(text=result['s3_presign_url'])
return self.create_text_message(text=result["s3_presign_url"])
except Exception as e:
return self.create_text_message(f'Exception {str(e)}')
return self.create_text_message(f"Exception {str(e)}")

View File

@ -13,12 +13,8 @@ class AzureDALLEProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
tool_parameters={
"prompt": "cute girl, blue eyes, white hair, anime style",
"size": "square",
"n": 1
},
user_id="",
tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "square", "n": 1},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -9,47 +9,48 @@ from core.tools.tool.builtin_tool import BuiltinTool
class DallE3Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
client = AzureOpenAI(
api_version=self.runtime.credentials['azure_openai_api_version'],
azure_endpoint=self.runtime.credentials['azure_openai_base_url'],
api_key=self.runtime.credentials['azure_openai_api_key'],
api_version=self.runtime.credentials["azure_openai_api_version"],
azure_endpoint=self.runtime.credentials["azure_openai_base_url"],
api_key=self.runtime.credentials["azure_openai_api_key"],
)
SIZE_MAPPING = {
'square': '1024x1024',
'vertical': '1024x1792',
'horizontal': '1792x1024',
"square": "1024x1024",
"vertical": "1024x1792",
"horizontal": "1792x1024",
}
# prompt
prompt = tool_parameters.get('prompt', '')
prompt = tool_parameters.get("prompt", "")
if not prompt:
return self.create_text_message('Please input prompt')
return self.create_text_message("Please input prompt")
# get size
size = SIZE_MAPPING[tool_parameters.get('size', 'square')]
size = SIZE_MAPPING[tool_parameters.get("size", "square")]
# get n
n = tool_parameters.get('n', 1)
n = tool_parameters.get("n", 1)
# get quality
quality = tool_parameters.get('quality', 'standard')
if quality not in ['standard', 'hd']:
return self.create_text_message('Invalid quality')
quality = tool_parameters.get("quality", "standard")
if quality not in {"standard", "hd"}:
return self.create_text_message("Invalid quality")
# get style
style = tool_parameters.get('style', 'vivid')
if style not in ['natural', 'vivid']:
return self.create_text_message('Invalid style')
style = tool_parameters.get("style", "vivid")
if style not in {"natural", "vivid"}:
return self.create_text_message("Invalid style")
# set extra body
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
extra_body = {'seed': seed_id}
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
extra_body = {"seed": seed_id}
# call openapi dalle3
model = self.runtime.credentials['azure_openai_api_model_name']
model = self.runtime.credentials["azure_openai_api_model_name"]
response = client.images.generate(
prompt=prompt,
model=model,
@ -58,21 +59,25 @@ class DallE3Tool(BuiltinTool):
extra_body=extra_body,
style=style,
quality=quality,
response_format='b64_json'
response_format="b64_json",
)
result = []
for image in response.data:
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={'mime_type': 'image/png'},
save_as=self.VARIABLE_KEY.IMAGE.value))
result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}'))
result.append(
self.create_blob_message(
blob=b64decode(image.b64_json),
meta={"mime_type": "image/png"},
save_as=self.VariableKey.IMAGE.value,
)
)
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}"))
return result
@staticmethod
def _generate_random_id(length=8):
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
random_id = ''.join(random.choices(characters, k=length))
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
random_id = "".join(random.choices(characters, k=length))
return random_id

View File

@ -8,142 +8,135 @@ from core.tools.tool.builtin_tool import BuiltinTool
class BingSearchTool(BuiltinTool):
url: str = 'https://api.bing.microsoft.com/v7.0/search'
url: str = "https://api.bing.microsoft.com/v7.0/search"
def _invoke_bing(self,
user_id: str,
server_url: str,
subscription_key: str, query: str, limit: int,
result_type: str, market: str, lang: str,
filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke_bing(
self,
user_id: str,
server_url: str,
subscription_key: str,
query: str,
limit: int,
result_type: str,
market: str,
lang: str,
filters: list[str],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke bing search
invoke bing search
"""
market_code = f'{lang}-{market}'
accept_language = f'{lang},{market_code};q=0.9'
headers = {
'Ocp-Apim-Subscription-Key': subscription_key,
'Accept-Language': accept_language
}
market_code = f"{lang}-{market}"
accept_language = f"{lang},{market_code};q=0.9"
headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language}
query = quote(query)
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
response = get(server_url, headers=headers)
if response.status_code != 200:
raise Exception(f'Error {response.status_code}: {response.text}')
response = response.json()
search_results = response['webPages']['value'][:limit] if 'webPages' in response else []
related_searches = response['relatedSearches']['value'] if 'relatedSearches' in response else []
entities = response['entities']['value'] if 'entities' in response else []
news = response['news']['value'] if 'news' in response else []
computation = response['computation']['value'] if 'computation' in response else None
raise Exception(f"Error {response.status_code}: {response.text}")
if result_type == 'link':
response = response.json()
search_results = response["webPages"]["value"][:limit] if "webPages" in response else []
related_searches = response["relatedSearches"]["value"] if "relatedSearches" in response else []
entities = response["entities"]["value"] if "entities" in response else []
news = response["news"]["value"] if "news" in response else []
computation = response["computation"]["value"] if "computation" in response else None
if result_type == "link":
results = []
if search_results:
for result in search_results:
url = f': {result["url"]}' if "url" in result else ""
results.append(self.create_text_message(
text=f'{result["name"]}{url}'
))
results.append(self.create_text_message(text=f'{result["name"]}{url}'))
if entities:
for entity in entities:
url = f': {entity["url"]}' if "url" in entity else ""
results.append(self.create_text_message(
text=f'{entity.get("name", "")}{url}'
))
results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}'))
if news:
for news_item in news:
url = f': {news_item["url"]}' if "url" in news_item else ""
results.append(self.create_text_message(
text=f'{news_item.get("name", "")}{url}'
))
results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}'))
if related_searches:
for related in related_searches:
url = f': {related["displayText"]}' if "displayText" in related else ""
results.append(self.create_text_message(
text=f'{related.get("displayText", "")}{url}'
))
results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}'))
return results
else:
# construct text
text = ''
text = ""
if search_results:
for i, result in enumerate(search_results):
text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
if computation and 'expression' in computation and 'value' in computation:
text += '\nComputation:\n'
if computation and "expression" in computation and "value" in computation:
text += "\nComputation:\n"
text += f'{computation["expression"]} = {computation["value"]}\n'
if entities:
text += '\nEntities:\n'
text += "\nEntities:\n"
for entity in entities:
url = f'- {entity["url"]}' if "url" in entity else ""
text += f'{entity.get("name", "")}{url}\n'
if news:
text += '\nNews:\n'
text += "\nNews:\n"
for news_item in news:
url = f'- {news_item["url"]}' if "url" in news_item else ""
text += f'{news_item.get("name", "")}{url}\n'
if related_searches:
text += '\n\nRelated Searches:\n'
text += "\n\nRelated Searches:\n"
for related in related_searches:
url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else ""
text += f'{related.get("displayText", "")}{url}\n'
return self.create_text_message(text=self.summary(user_id=user_id, content=text))
def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None:
key = credentials.get('subscription_key')
key = credentials.get("subscription_key")
if not key:
raise Exception('subscription_key is required')
server_url = credentials.get('server_url')
raise Exception("subscription_key is required")
server_url = credentials.get("server_url")
if not server_url:
server_url = self.url
query = tool_parameters.get('query')
query = tool_parameters.get("query")
if not query:
raise Exception('query is required')
limit = min(tool_parameters.get('limit', 5), 10)
result_type = tool_parameters.get('result_type', 'text') or 'text'
raise Exception("query is required")
market = tool_parameters.get('market', 'US')
lang = tool_parameters.get('language', 'en')
limit = min(tool_parameters.get("limit", 5), 10)
result_type = tool_parameters.get("result_type", "text") or "text"
market = tool_parameters.get("market", "US")
lang = tool_parameters.get("language", "en")
filter = []
if credentials.get('allow_entities', False):
filter.append('Entities')
if credentials.get("allow_entities", False):
filter.append("Entities")
if credentials.get('allow_computation', False):
filter.append('Computation')
if credentials.get("allow_computation", False):
filter.append("Computation")
if credentials.get('allow_news', False):
filter.append('News')
if credentials.get("allow_news", False):
filter.append("News")
if credentials.get('allow_related_searches', False):
filter.append('RelatedSearches')
if credentials.get("allow_related_searches", False):
filter.append("RelatedSearches")
if credentials.get('allow_web_pages', False):
filter.append('WebPages')
if credentials.get("allow_web_pages", False):
filter.append("WebPages")
if not filter:
raise Exception('At least one filter is required')
raise Exception("At least one filter is required")
self._invoke_bing(
user_id='test',
user_id="test",
server_url=server_url,
subscription_key=key,
query=query,
@ -151,50 +144,51 @@ class BingSearchTool(BuiltinTool):
result_type=result_type,
market=market,
lang=lang,
filters=filter
filters=filter,
)
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
key = self.runtime.credentials.get('subscription_key', None)
key = self.runtime.credentials.get("subscription_key", None)
if not key:
raise Exception('subscription_key is required')
server_url = self.runtime.credentials.get('server_url', None)
raise Exception("subscription_key is required")
server_url = self.runtime.credentials.get("server_url", None)
if not server_url:
server_url = self.url
query = tool_parameters.get('query')
query = tool_parameters.get("query")
if not query:
raise Exception('query is required')
limit = min(tool_parameters.get('limit', 5), 10)
result_type = tool_parameters.get('result_type', 'text') or 'text'
market = tool_parameters.get('market', 'US')
lang = tool_parameters.get('language', 'en')
raise Exception("query is required")
limit = min(tool_parameters.get("limit", 5), 10)
result_type = tool_parameters.get("result_type", "text") or "text"
market = tool_parameters.get("market", "US")
lang = tool_parameters.get("language", "en")
filter = []
if tool_parameters.get('enable_computation', False):
filter.append('Computation')
if tool_parameters.get('enable_entities', False):
filter.append('Entities')
if tool_parameters.get('enable_news', False):
filter.append('News')
if tool_parameters.get('enable_related_search', False):
filter.append('RelatedSearches')
if tool_parameters.get('enable_webpages', False):
filter.append('WebPages')
if tool_parameters.get("enable_computation", False):
filter.append("Computation")
if tool_parameters.get("enable_entities", False):
filter.append("Entities")
if tool_parameters.get("enable_news", False):
filter.append("News")
if tool_parameters.get("enable_related_search", False):
filter.append("RelatedSearches")
if tool_parameters.get("enable_webpages", False):
filter.append("WebPages")
if not filter:
raise Exception('At least one filter is required')
raise Exception("At least one filter is required")
return self._invoke_bing(
user_id=user_id,
server_url=server_url,
@ -204,5 +198,5 @@ class BingSearchTool(BuiltinTool):
result_type=result_type,
market=market,
lang=lang,
filters=filter
)
filters=filter,
)

View File

@ -13,11 +13,10 @@ class BraveProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
user_id="",
tool_parameters={
"query": "Sachin Tendulkar",
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -37,7 +37,7 @@ class BraveSearchWrapper(BaseModel):
for item in web_search_results
]
return json.dumps(final_results)
def _search_request(self, query: str) -> list[dict]:
headers = {
"X-Subscription-Token": self.api_key,
@ -55,6 +55,7 @@ class BraveSearchWrapper(BaseModel):
return response.json().get("web", {}).get("results", [])
class BraveSearch(BaseModel):
"""Tool that queries the BraveSearch."""
@ -67,9 +68,7 @@ class BraveSearch(BaseModel):
search_wrapper: BraveSearchWrapper
@classmethod
def from_api_key(
cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any
) -> "BraveSearch":
def from_api_key(cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any) -> "BraveSearch":
"""Create a tool from an api key.
Args:
@ -90,6 +89,7 @@ class BraveSearch(BaseModel):
"""Use the tool."""
return self.search_wrapper.run(query)
class BraveSearchTool(BuiltinTool):
"""
Tool for performing a search using Brave search engine.
@ -106,12 +106,12 @@ class BraveSearchTool(BuiltinTool):
Returns:
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation.
"""
query = tool_parameters.get('query', '')
count = tool_parameters.get('count', 3)
api_key = self.runtime.credentials['brave_search_api_key']
query = tool_parameters.get("query", "")
count = tool_parameters.get("count", 3)
api_key = self.runtime.credentials["brave_search_api_key"]
if not query:
return self.create_text_message('Please input query')
return self.create_text_message("Please input query")
tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count})
@ -121,4 +121,3 @@ class BraveSearchTool(BuiltinTool):
return self.create_text_message(f"No results found for '{query}' in Tavily")
else:
return self.create_text_message(text=results)

View File

@ -7,16 +7,34 @@ from core.tools.provider.builtin.chart.tools.line import LinearChartTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
# use a business theme
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['axes.unicode_minus'] = False
plt.style.use("seaborn-v0_8-darkgrid")
plt.rcParams["axes.unicode_minus"] = False
def init_fonts():
fonts = findSystemFonts()
popular_unicode_fonts = [
'Arial Unicode MS', 'DejaVu Sans', 'DejaVu Sans Mono', 'DejaVu Serif', 'FreeMono', 'FreeSans', 'FreeSerif',
'Liberation Mono', 'Liberation Sans', 'Liberation Serif', 'Noto Mono', 'Noto Sans', 'Noto Serif', 'Open Sans',
'Roboto', 'Source Code Pro', 'Source Sans Pro', 'Source Serif Pro', 'Ubuntu', 'Ubuntu Mono'
"Arial Unicode MS",
"DejaVu Sans",
"DejaVu Sans Mono",
"DejaVu Serif",
"FreeMono",
"FreeSans",
"FreeSerif",
"Liberation Mono",
"Liberation Sans",
"Liberation Serif",
"Noto Mono",
"Noto Sans",
"Noto Serif",
"Open Sans",
"Roboto",
"Source Code Pro",
"Source Sans Pro",
"Source Serif Pro",
"Ubuntu",
"Ubuntu Mono",
]
supported_fonts = []
@ -25,21 +43,23 @@ def init_fonts():
try:
font = TTFont(font_path)
# get family name
family_name = font['name'].getName(1, 3, 1).toUnicode()
family_name = font["name"].getName(1, 3, 1).toUnicode()
if family_name in popular_unicode_fonts:
supported_fonts.append(family_name)
except:
pass
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams["font.family"] = "sans-serif"
# sort by order of popular_unicode_fonts
for font in popular_unicode_fonts:
if font in supported_fonts:
plt.rcParams['font.sans-serif'] = font
plt.rcParams["font.sans-serif"] = font
break
init_fonts()
class ChartProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
@ -48,11 +68,10 @@ class ChartProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
user_id="",
tool_parameters={
"data": "1,3,5,7,9,2,4,6,8,10",
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -8,12 +8,13 @@ from core.tools.tool.builtin_tool import BuiltinTool
class BarChartTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
data = tool_parameters.get('data', '')
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
data = tool_parameters.get("data", "")
if not data:
return self.create_text_message('Please input data')
data = data.split(';')
return self.create_text_message("Please input data")
data = data.split(";")
# if all data is int, convert to int
if all(i.isdigit() for i in data):
@ -21,29 +22,27 @@ class BarChartTool(BuiltinTool):
else:
data = [float(i) for i in data]
axis = tool_parameters.get('x_axis') or None
axis = tool_parameters.get("x_axis") or None
if axis:
axis = axis.split(';')
axis = axis.split(";")
if len(axis) != len(data):
axis = None
flg, ax = plt.subplots(figsize=(10, 8))
if axis:
axis = [label[:10] + '...' if len(label) > 10 else label for label in axis]
ax.set_xticklabels(axis, rotation=45, ha='right')
axis = [label[:10] + "..." if len(label) > 10 else label for label in axis]
ax.set_xticklabels(axis, rotation=45, ha="right")
ax.bar(axis, data)
else:
ax.bar(range(len(data)), data)
buf = io.BytesIO()
flg.savefig(buf, format='png')
flg.savefig(buf, format="png")
buf.seek(0)
plt.close(flg)
return [
self.create_text_message('the bar chart is saved as an image.'),
self.create_blob_message(blob=buf.read(),
meta={'mime_type': 'image/png'})
self.create_text_message("the bar chart is saved as an image."),
self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
]

View File

@ -8,18 +8,19 @@ from core.tools.tool.builtin_tool import BuiltinTool
class LinearChartTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
data = tool_parameters.get('data', '')
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
data = tool_parameters.get("data", "")
if not data:
return self.create_text_message('Please input data')
data = data.split(';')
return self.create_text_message("Please input data")
data = data.split(";")
axis = tool_parameters.get('x_axis') or None
axis = tool_parameters.get("x_axis") or None
if axis:
axis = axis.split(';')
axis = axis.split(";")
if len(axis) != len(data):
axis = None
@ -32,20 +33,18 @@ class LinearChartTool(BuiltinTool):
flg, ax = plt.subplots(figsize=(10, 8))
if axis:
axis = [label[:10] + '...' if len(label) > 10 else label for label in axis]
ax.set_xticklabels(axis, rotation=45, ha='right')
axis = [label[:10] + "..." if len(label) > 10 else label for label in axis]
ax.set_xticklabels(axis, rotation=45, ha="right")
ax.plot(axis, data)
else:
ax.plot(data)
buf = io.BytesIO()
flg.savefig(buf, format='png')
flg.savefig(buf, format="png")
buf.seek(0)
plt.close(flg)
return [
self.create_text_message('the linear chart is saved as an image.'),
self.create_blob_message(blob=buf.read(),
meta={'mime_type': 'image/png'})
self.create_text_message("the linear chart is saved as an image."),
self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
]

View File

@ -8,15 +8,16 @@ from core.tools.tool.builtin_tool import BuiltinTool
class PieChartTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
data = tool_parameters.get('data', '')
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
data = tool_parameters.get("data", "")
if not data:
return self.create_text_message('Please input data')
data = data.split(';')
categories = tool_parameters.get('categories') or None
return self.create_text_message("Please input data")
data = data.split(";")
categories = tool_parameters.get("categories") or None
# if all data is int, convert to int
if all(i.isdigit() for i in data):
@ -27,7 +28,7 @@ class PieChartTool(BuiltinTool):
flg, ax = plt.subplots()
if categories:
categories = categories.split(';')
categories = categories.split(";")
if len(categories) != len(data):
categories = None
@ -37,12 +38,11 @@ class PieChartTool(BuiltinTool):
ax.pie(data)
buf = io.BytesIO()
flg.savefig(buf, format='png')
flg.savefig(buf, format="png")
buf.seek(0)
plt.close(flg)
return [
self.create_text_message('the pie chart is saved as an image.'),
self.create_blob_message(blob=buf.read(),
meta={'mime_type': 'image/png'})
]
self.create_text_message("the pie chart is saved as an image."),
self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
]

View File

@ -8,15 +8,15 @@ from core.tools.tool.builtin_tool import BuiltinTool
class SimpleCode(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
invoke simple code
invoke simple code
"""
language = tool_parameters.get('language', CodeLanguage.PYTHON3)
code = tool_parameters.get('code', '')
language = tool_parameters.get("language", CodeLanguage.PYTHON3)
code = tool_parameters.get("code", "")
if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]:
raise ValueError(f'Only python3 and javascript are supported, not {language}')
result = CodeExecutor.execute_code(language, '', code)
if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
raise ValueError(f"Only python3 and javascript are supported, not {language}")
return self.create_text_message(result)
result = CodeExecutor.execute_code(language, "", code)
return self.create_text_message(result)

View File

@ -1,4 +1,5 @@
""" Provide the input parameters type for the cogview provider class """
"""Provide the input parameters type for the cogview provider class"""
from typing import Any
from core.tools.errors import ToolProviderCredentialValidationError
@ -7,7 +8,8 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class COGVIEWProvider(BuiltinToolProviderController):
""" cogview provider """
"""cogview provider"""
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
CogView3Tool().fork_tool_runtime(
@ -15,13 +17,12 @@ class COGVIEWProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
user_id="",
tool_parameters={
"prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。",
"size": "square",
"n": 1
"n": 1,
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e)) from e

View File

@ -7,43 +7,42 @@ from core.tools.tool.builtin_tool import BuiltinTool
class CogView3Tool(BuiltinTool):
""" CogView3 Tool """
"""CogView3 Tool"""
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invoke CogView3 tool
"""
client = ZhipuAI(
base_url=self.runtime.credentials['zhipuai_base_url'],
api_key=self.runtime.credentials['zhipuai_api_key'],
base_url=self.runtime.credentials["zhipuai_base_url"],
api_key=self.runtime.credentials["zhipuai_api_key"],
)
size_mapping = {
'square': '1024x1024',
'vertical': '1024x1792',
'horizontal': '1792x1024',
"square": "1024x1024",
"vertical": "1024x1792",
"horizontal": "1792x1024",
}
# prompt
prompt = tool_parameters.get('prompt', '')
prompt = tool_parameters.get("prompt", "")
if not prompt:
return self.create_text_message('Please input prompt')
return self.create_text_message("Please input prompt")
# get size
size = size_mapping[tool_parameters.get('size', 'square')]
size = size_mapping[tool_parameters.get("size", "square")]
# get n
n = tool_parameters.get('n', 1)
n = tool_parameters.get("n", 1)
# get quality
quality = tool_parameters.get('quality', 'standard')
if quality not in ['standard', 'hd']:
return self.create_text_message('Invalid quality')
quality = tool_parameters.get("quality", "standard")
if quality not in {"standard", "hd"}:
return self.create_text_message("Invalid quality")
# get style
style = tool_parameters.get('style', 'vivid')
if style not in ['natural', 'vivid']:
return self.create_text_message('Invalid style')
style = tool_parameters.get("style", "vivid")
if style not in {"natural", "vivid"}:
return self.create_text_message("Invalid style")
# set extra body
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
extra_body = {'seed': seed_id}
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
extra_body = {"seed": seed_id}
response = client.images.generations(
prompt=prompt,
model="cogview-3",
@ -52,18 +51,22 @@ class CogView3Tool(BuiltinTool):
extra_body=extra_body,
style=style,
quality=quality,
response_format='b64_json'
response_format="b64_json",
)
result = []
for image in response.data:
result.append(self.create_image_message(image=image.url))
result.append(self.create_json_message({
"url": image.url,
}))
result.append(
self.create_json_message(
{
"url": image.url,
}
)
)
return result
@staticmethod
def _generate_random_id(length=8):
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
random_id = ''.join(random.choices(characters, k=length))
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
random_id = "".join(random.choices(characters, k=length))
return random_id

View File

@ -11,9 +11,9 @@ class CrossRefProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
user_id="",
tool_parameters={
"doi": '10.1007/s00894-022-05373-8',
"doi": "10.1007/s00894-022-05373-8",
},
)
except Exception as e:

View File

@ -11,15 +11,18 @@ class CrossRefQueryDOITool(BuiltinTool):
"""
Tool for querying the metadata of a publication using its DOI.
"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
doi = tool_parameters.get('doi')
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
doi = tool_parameters.get("doi")
if not doi:
raise ToolParameterValidationError('doi is required.')
raise ToolParameterValidationError("doi is required.")
# doc: https://github.com/CrossRef/rest-api-doc
url = f"https://api.crossref.org/works/{doi}"
response = requests.get(url)
response.raise_for_status()
response = response.json()
message = response.get('message', {})
message = response.get("message", {})
return self.create_json_message(message)

View File

@ -12,16 +12,16 @@ def convert_time_str_to_seconds(time_str: str) -> int:
Convert a time string to seconds.
example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430
"""
time_str = time_str.lower().strip().replace(' ', '')
time_str = time_str.lower().strip().replace(" ", "")
seconds = 0
if 'h' in time_str:
hours, time_str = time_str.split('h')
if "h" in time_str:
hours, time_str = time_str.split("h")
seconds += int(hours) * 3600
if 'm' in time_str:
minutes, time_str = time_str.split('m')
if "m" in time_str:
minutes, time_str = time_str.split("m")
seconds += int(minutes) * 60
if 's' in time_str:
seconds += int(time_str.replace('s', ''))
if "s" in time_str:
seconds += int(time_str.replace("s", ""))
return seconds
@ -30,6 +30,7 @@ class CrossRefQueryTitleAPI:
Tool for querying the metadata of a publication using its title.
Crossref API doc: https://github.com/CrossRef/rest-api-doc
"""
query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}"
rate_limit: int = 50
rate_interval: float = 1
@ -38,7 +39,15 @@ class CrossRefQueryTitleAPI:
def __init__(self, mailto: str):
self.mailto = mailto
def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
def _query(
self,
query: str,
rows: int = 5,
offset: int = 0,
sort: str = "relevance",
order: str = "desc",
fuzzy_query: bool = False,
) -> list[dict]:
"""
Query the metadata of a publication using its title.
:param query: the title of the publication
@ -47,33 +56,37 @@ class CrossRefQueryTitleAPI:
:param order: the sort order
:param fuzzy_query: whether to return all items that match the query
"""
url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto)
url = self.query_url_template.format(
query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto
)
response = requests.get(url)
response.raise_for_status()
rate_limit = int(response.headers['x-ratelimit-limit'])
rate_limit = int(response.headers["x-ratelimit-limit"])
# convert time string to seconds
rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval'])
rate_interval = convert_time_str_to_seconds(response.headers["x-ratelimit-interval"])
self.rate_limit = rate_limit
self.rate_interval = rate_interval
response = response.json()
if response['status'] != 'ok':
if response["status"] != "ok":
return []
message = response['message']
message = response["message"]
if fuzzy_query:
# fuzzy query return all items
return message['items']
return message["items"]
else:
for paper in message['items']:
title = paper['title'][0]
for paper in message["items"]:
title = paper["title"][0]
if title.lower() != query.lower():
continue
return [paper]
return []
def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
def query(
self, query: str, rows: int = 5, sort: str = "relevance", order: str = "desc", fuzzy_query: bool = False
) -> list[dict]:
"""
Query the metadata of a publication using its title.
:param query: the title of the publication
@ -89,7 +102,14 @@ class CrossRefQueryTitleAPI:
results = []
for i in range(query_times):
result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query)
result = self._query(
query,
rows=self.rate_limit,
offset=i * self.rate_limit,
sort=sort,
order=order,
fuzzy_query=fuzzy_query,
)
if fuzzy_query:
results.extend(result)
else:
@ -107,13 +127,16 @@ class CrossRefQueryTitleTool(BuiltinTool):
"""
Tool for querying the metadata of a publication using its title.
"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
query = tool_parameters.get('query')
fuzzy_query = tool_parameters.get('fuzzy_query', False)
rows = tool_parameters.get('rows', 3)
sort = tool_parameters.get('sort', 'relevance')
order = tool_parameters.get('order', 'desc')
mailto = self.runtime.credentials['mailto']
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
query = tool_parameters.get("query")
fuzzy_query = tool_parameters.get("fuzzy_query", False)
rows = tool_parameters.get("rows", 3)
sort = tool_parameters.get("sort", "relevance")
order = tool_parameters.get("order", "desc")
mailto = self.runtime.credentials["mailto"]
result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query)

View File

@ -13,13 +13,8 @@ class DALLEProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
tool_parameters={
"prompt": "cute girl, blue eyes, white hair, anime style",
"size": "small",
"n": 1
},
user_id="",
tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "small", "n": 1},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -9,59 +9,58 @@ from core.tools.tool.builtin_tool import BuiltinTool
class DallE2Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
openai_organization = self.runtime.credentials.get('openai_organization_id', None)
openai_organization = self.runtime.credentials.get("openai_organization_id", None)
if not openai_organization:
openai_organization = None
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
openai_base_url = self.runtime.credentials.get("openai_base_url", None)
if not openai_base_url:
openai_base_url = None
else:
openai_base_url = str(URL(openai_base_url) / 'v1')
openai_base_url = str(URL(openai_base_url) / "v1")
client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],
api_key=self.runtime.credentials["openai_api_key"],
base_url=openai_base_url,
organization=openai_organization
organization=openai_organization,
)
SIZE_MAPPING = {
'small': '256x256',
'medium': '512x512',
'large': '1024x1024',
"small": "256x256",
"medium": "512x512",
"large": "1024x1024",
}
# prompt
prompt = tool_parameters.get('prompt', '')
prompt = tool_parameters.get("prompt", "")
if not prompt:
return self.create_text_message('Please input prompt')
return self.create_text_message("Please input prompt")
# get size
size = SIZE_MAPPING[tool_parameters.get('size', 'large')]
size = SIZE_MAPPING[tool_parameters.get("size", "large")]
# get n
n = tool_parameters.get('n', 1)
n = tool_parameters.get("n", 1)
# call openapi dalle2
response = client.images.generate(
prompt=prompt,
model='dall-e-2',
size=size,
n=n,
response_format='b64_json'
)
response = client.images.generate(prompt=prompt, model="dall-e-2", size=size, n=n, response_format="b64_json")
result = []
for image in response.data:
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={ 'mime_type': 'image/png' },
save_as=self.VARIABLE_KEY.IMAGE.value))
result.append(
self.create_blob_message(
blob=b64decode(image.b64_json),
meta={"mime_type": "image/png"},
save_as=self.VariableKey.IMAGE.value,
)
)
return result

View File

@ -10,69 +10,64 @@ from core.tools.tool.builtin_tool import BuiltinTool
class DallE3Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
openai_organization = self.runtime.credentials.get('openai_organization_id', None)
openai_organization = self.runtime.credentials.get("openai_organization_id", None)
if not openai_organization:
openai_organization = None
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
openai_base_url = self.runtime.credentials.get("openai_base_url", None)
if not openai_base_url:
openai_base_url = None
else:
openai_base_url = str(URL(openai_base_url) / 'v1')
openai_base_url = str(URL(openai_base_url) / "v1")
client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],
api_key=self.runtime.credentials["openai_api_key"],
base_url=openai_base_url,
organization=openai_organization
organization=openai_organization,
)
SIZE_MAPPING = {
'square': '1024x1024',
'vertical': '1024x1792',
'horizontal': '1792x1024',
"square": "1024x1024",
"vertical": "1024x1792",
"horizontal": "1792x1024",
}
# prompt
prompt = tool_parameters.get('prompt', '')
prompt = tool_parameters.get("prompt", "")
if not prompt:
return self.create_text_message('Please input prompt')
return self.create_text_message("Please input prompt")
# get size
size = SIZE_MAPPING[tool_parameters.get('size', 'square')]
size = SIZE_MAPPING[tool_parameters.get("size", "square")]
# get n
n = tool_parameters.get('n', 1)
n = tool_parameters.get("n", 1)
# get quality
quality = tool_parameters.get('quality', 'standard')
if quality not in ['standard', 'hd']:
return self.create_text_message('Invalid quality')
quality = tool_parameters.get("quality", "standard")
if quality not in {"standard", "hd"}:
return self.create_text_message("Invalid quality")
# get style
style = tool_parameters.get('style', 'vivid')
if style not in ['natural', 'vivid']:
return self.create_text_message('Invalid style')
style = tool_parameters.get("style", "vivid")
if style not in {"natural", "vivid"}:
return self.create_text_message("Invalid style")
# call openapi dalle3
response = client.images.generate(
prompt=prompt,
model='dall-e-3',
size=size,
n=n,
style=style,
quality=quality,
response_format='b64_json'
prompt=prompt, model="dall-e-3", size=size, n=n, style=style, quality=quality, response_format="b64_json"
)
result = []
for image in response.data:
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
blob_message = self.create_blob_message(blob=blob_image,
meta={'mime_type': mime_type},
save_as=self.VARIABLE_KEY.IMAGE.value)
blob_message = self.create_blob_message(
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
)
result.append(blob_message)
return result
@ -86,7 +81,7 @@ class DallE3Tool(BuiltinTool):
:return: A tuple containing the MIME type and the decoded image bytes
"""
if DallE3Tool._is_plain_base64(base64_image):
return 'image/png', base64.b64decode(base64_image)
return "image/png", base64.b64decode(base64_image)
else:
return DallE3Tool._extract_mime_and_data(base64_image)
@ -98,7 +93,7 @@ class DallE3Tool(BuiltinTool):
:param encoded_str: Base64 encoded image string
:return: True if the string is plain base64, False otherwise
"""
return not encoded_str.startswith('data:image')
return not encoded_str.startswith("data:image")
@staticmethod
def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]:
@ -108,13 +103,13 @@ class DallE3Tool(BuiltinTool):
:param encoded_str: Base64 encoded image string with MIME type prefix
:return: A tuple containing the MIME type and the decoded image bytes
"""
mime_type = encoded_str.split(';')[0].split(':')[1]
image_data_base64 = encoded_str.split(',')[1]
mime_type = encoded_str.split(";")[0].split(":")[1]
image_data_base64 = encoded_str.split(",")[1]
decoded_data = base64.b64decode(image_data_base64)
return mime_type, decoded_data
@staticmethod
def _generate_random_id(length=8):
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
random_id = ''.join(random.choices(characters, k=length))
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
random_id = "".join(random.choices(characters, k=length))
return random_id

View File

@ -11,7 +11,7 @@ class DevDocsProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
user_id="",
tool_parameters={
"doc": "python~3.12",
"topic": "library/code",
@ -19,4 +19,3 @@ class DevDocsProvider(BuiltinToolProviderController):
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -13,7 +13,9 @@ class SearchDevDocsInput(BaseModel):
class SearchDevDocsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invokes the DevDocs search tool with the given user ID and tool parameters.
@ -22,15 +24,16 @@ class SearchDevDocsTool(BuiltinTool):
tool_parameters (dict[str, Any]): The parameters for the tool, including 'doc' and 'topic'.
Returns:
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages.
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
which can be a single message or a list of messages.
"""
doc = tool_parameters.get('doc', '')
topic = tool_parameters.get('topic', '')
doc = tool_parameters.get("doc", "")
topic = tool_parameters.get("topic", "")
if not doc:
return self.create_text_message('Please provide the documentation name.')
return self.create_text_message("Please provide the documentation name.")
if not topic:
return self.create_text_message('Please provide the topic path.')
return self.create_text_message("Please provide the topic path.")
url = f"https://documents.devdocs.io/{doc}/{topic}.html"
response = requests.get(url)
@ -39,4 +42,6 @@ class SearchDevDocsTool(BuiltinTool):
content = response.text
return self.create_text_message(self.summary(user_id=user_id, content=content))
else:
return self.create_text_message(f"Failed to retrieve the documentation. Status code: {response.status_code}")
return self.create_text_message(
f"Failed to retrieve the documentation. Status code: {response.status_code}"
)

View File

@ -7,15 +7,12 @@ class DIDProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
# Example validation using the D-ID talks tool
TalksTool().fork_tool_runtime(
runtime={"credentials": credentials}
).invoke(
user_id='',
TalksTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke(
user_id="",
tool_parameters={
"source_url": "https://www.d-id.com/wp-content/uploads/2023/11/Hero-image-1.png",
"text_input": "Hello, welcome to use D-ID tool in Dify",
}
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -12,14 +12,14 @@ logger = logging.getLogger(__name__)
class DIDApp:
def __init__(self, api_key: str | None = None, base_url: str | None = None):
self.api_key = api_key
self.base_url = base_url or 'https://api.d-id.com'
self.base_url = base_url or "https://api.d-id.com"
if not self.api_key:
raise ValueError('API key is required')
raise ValueError("API key is required")
def _prepare_headers(self, idempotency_key: str | None = None):
headers = {'Content-Type': 'application/json', 'Authorization': f'Basic {self.api_key}'}
headers = {"Content-Type": "application/json", "Authorization": f"Basic {self.api_key}"}
if idempotency_key:
headers['Idempotency-Key'] = idempotency_key
headers["Idempotency-Key"] = idempotency_key
return headers
def _request(
@ -44,44 +44,44 @@ class DIDApp:
return None
def talks(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs):
endpoint = f'{self.base_url}/talks'
endpoint = f"{self.base_url}/talks"
headers = self._prepare_headers(idempotency_key)
data = kwargs['params']
logger.debug(f'Send request to {endpoint=} body={data}')
response = self._request('POST', endpoint, data, headers)
data = kwargs["params"]
logger.debug(f"Send request to {endpoint=} body={data}")
response = self._request("POST", endpoint, data, headers)
if response is None:
raise HTTPError('Failed to initiate D-ID talks after multiple retries')
id: str = response['id']
raise HTTPError("Failed to initiate D-ID talks after multiple retries")
id: str = response["id"]
if wait:
return self._monitor_job_status(id=id, target='talks', poll_interval=poll_interval)
return self._monitor_job_status(id=id, target="talks", poll_interval=poll_interval)
return id
def animations(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs):
endpoint = f'{self.base_url}/animations'
endpoint = f"{self.base_url}/animations"
headers = self._prepare_headers(idempotency_key)
data = kwargs['params']
logger.debug(f'Send request to {endpoint=} body={data}')
response = self._request('POST', endpoint, data, headers)
data = kwargs["params"]
logger.debug(f"Send request to {endpoint=} body={data}")
response = self._request("POST", endpoint, data, headers)
if response is None:
raise HTTPError('Failed to initiate D-ID talks after multiple retries')
id: str = response['id']
raise HTTPError("Failed to initiate D-ID talks after multiple retries")
id: str = response["id"]
if wait:
return self._monitor_job_status(target='animations', id=id, poll_interval=poll_interval)
return self._monitor_job_status(target="animations", id=id, poll_interval=poll_interval)
return id
def check_did_status(self, target: str, id: str):
endpoint = f'{self.base_url}/{target}/{id}'
endpoint = f"{self.base_url}/{target}/{id}"
headers = self._prepare_headers()
response = self._request('GET', endpoint, headers=headers)
response = self._request("GET", endpoint, headers=headers)
if response is None:
raise HTTPError(f'Failed to check status for talks {id} after multiple retries')
raise HTTPError(f"Failed to check status for talks {id} after multiple retries")
return response
def _monitor_job_status(self, target: str, id: str, poll_interval: int):
while True:
status = self.check_did_status(target=target, id=id)
if status['status'] == 'done':
if status["status"] == "done":
return status
elif status['status'] == 'error' or status['status'] == 'rejected':
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}')
elif status["status"] == "error" or status["status"] == "rejected":
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}')
time.sleep(poll_interval)

View File

@ -10,33 +10,33 @@ class AnimationsTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url'])
app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"])
driver_expressions_str = tool_parameters.get('driver_expressions')
driver_expressions_str = tool_parameters.get("driver_expressions")
driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None
config = {
'stitch': tool_parameters.get('stitch', True),
'mute': tool_parameters.get('mute'),
'result_format': tool_parameters.get('result_format') or 'mp4',
"stitch": tool_parameters.get("stitch", True),
"mute": tool_parameters.get("mute"),
"result_format": tool_parameters.get("result_format") or "mp4",
}
config = {k: v for k, v in config.items() if v is not None and v != ''}
config = {k: v for k, v in config.items() if v is not None and v != ""}
options = {
'source_url': tool_parameters['source_url'],
'driver_url': tool_parameters.get('driver_url'),
'config': config,
"source_url": tool_parameters["source_url"],
"driver_url": tool_parameters.get("driver_url"),
"config": config,
}
options = {k: v for k, v in options.items() if v is not None and v != ''}
options = {k: v for k, v in options.items() if v is not None and v != ""}
if not options.get('source_url'):
raise ValueError('Source URL is required')
if not options.get("source_url"):
raise ValueError("Source URL is required")
if config.get('logo_url'):
if not config.get('logo_x'):
raise ValueError('Logo X position is required when logo URL is provided')
if not config.get('logo_y'):
raise ValueError('Logo Y position is required when logo URL is provided')
if config.get("logo_url"):
if not config.get("logo_x"):
raise ValueError("Logo X position is required when logo URL is provided")
if not config.get("logo_y"):
raise ValueError("Logo Y position is required when logo URL is provided")
animations_result = app.animations(params=options, wait=True)
@ -44,6 +44,6 @@ class AnimationsTool(BuiltinTool):
animations_result = json.dumps(animations_result, ensure_ascii=False, indent=4)
if not animations_result:
return self.create_text_message('D-ID animations request failed.')
return self.create_text_message("D-ID animations request failed.")
return self.create_text_message(animations_result)

View File

@ -10,49 +10,49 @@ class TalksTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url'])
app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"])
driver_expressions_str = tool_parameters.get('driver_expressions')
driver_expressions_str = tool_parameters.get("driver_expressions")
driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None
script = {
'type': tool_parameters.get('script_type') or 'text',
'input': tool_parameters.get('text_input'),
'audio_url': tool_parameters.get('audio_url'),
'reduce_noise': tool_parameters.get('audio_reduce_noise', False),
"type": tool_parameters.get("script_type") or "text",
"input": tool_parameters.get("text_input"),
"audio_url": tool_parameters.get("audio_url"),
"reduce_noise": tool_parameters.get("audio_reduce_noise", False),
}
script = {k: v for k, v in script.items() if v is not None and v != ''}
script = {k: v for k, v in script.items() if v is not None and v != ""}
config = {
'stitch': tool_parameters.get('stitch', True),
'sharpen': tool_parameters.get('sharpen'),
'fluent': tool_parameters.get('fluent'),
'result_format': tool_parameters.get('result_format') or 'mp4',
'pad_audio': tool_parameters.get('pad_audio'),
'driver_expressions': driver_expressions,
"stitch": tool_parameters.get("stitch", True),
"sharpen": tool_parameters.get("sharpen"),
"fluent": tool_parameters.get("fluent"),
"result_format": tool_parameters.get("result_format") or "mp4",
"pad_audio": tool_parameters.get("pad_audio"),
"driver_expressions": driver_expressions,
}
config = {k: v for k, v in config.items() if v is not None and v != ''}
config = {k: v for k, v in config.items() if v is not None and v != ""}
options = {
'source_url': tool_parameters['source_url'],
'driver_url': tool_parameters.get('driver_url'),
'script': script,
'config': config,
"source_url": tool_parameters["source_url"],
"driver_url": tool_parameters.get("driver_url"),
"script": script,
"config": config,
}
options = {k: v for k, v in options.items() if v is not None and v != ''}
options = {k: v for k, v in options.items() if v is not None and v != ""}
if not options.get('source_url'):
raise ValueError('Source URL is required')
if not options.get("source_url"):
raise ValueError("Source URL is required")
if script.get('type') == 'audio':
script.pop('input', None)
if not script.get('audio_url'):
raise ValueError('Audio URL is required for audio script type')
if script.get("type") == "audio":
script.pop("input", None)
if not script.get("audio_url"):
raise ValueError("Audio URL is required for audio script type")
if script.get('type') == 'text':
script.pop('audio_url', None)
script.pop('reduce_noise', None)
if not script.get('input'):
raise ValueError('Text input is required for text script type')
if script.get("type") == "text":
script.pop("audio_url", None)
script.pop("reduce_noise", None)
if not script.get("input"):
raise ValueError("Text input is required for text script type")
talks_result = app.talks(params=options, wait=True)
@ -60,6 +60,6 @@ class TalksTool(BuiltinTool):
talks_result = json.dumps(talks_result, ensure_ascii=False, indent=4)
if not talks_result:
return self.create_text_message('D-ID talks request failed.')
return self.create_text_message("D-ID talks request failed.")
return self.create_text_message(talks_result)

View File

@ -13,38 +13,43 @@ from core.tools.tool.builtin_tool import BuiltinTool
class DingTalkGroupBotTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
Dingtalk custom group robot API docs:
https://open.dingtalk.com/document/orgapp/custom-robot-access
invoke tools
Dingtalk custom group robot API docs:
https://open.dingtalk.com/document/orgapp/custom-robot-access
"""
content = tool_parameters.get('content')
content = tool_parameters.get("content")
if not content:
return self.create_text_message('Invalid parameter content')
return self.create_text_message("Invalid parameter content")
access_token = tool_parameters.get('access_token')
access_token = tool_parameters.get("access_token")
if not access_token:
return self.create_text_message('Invalid parameter access_token. '
'Regarding information about security details,'
'please refer to the DingTalk docs:'
'https://open.dingtalk.com/document/robots/customize-robot-security-settings')
return self.create_text_message(
"Invalid parameter access_token. "
"Regarding information about security details,"
"please refer to the DingTalk docs:"
"https://open.dingtalk.com/document/robots/customize-robot-security-settings"
)
sign_secret = tool_parameters.get('sign_secret')
sign_secret = tool_parameters.get("sign_secret")
if not sign_secret:
return self.create_text_message('Invalid parameter sign_secret. '
'Regarding information about security details,'
'please refer to the DingTalk docs:'
'https://open.dingtalk.com/document/robots/customize-robot-security-settings')
return self.create_text_message(
"Invalid parameter sign_secret. "
"Regarding information about security details,"
"please refer to the DingTalk docs:"
"https://open.dingtalk.com/document/robots/customize-robot-security-settings"
)
msgtype = 'text'
api_url = 'https://oapi.dingtalk.com/robot/send'
msgtype = "text"
api_url = "https://oapi.dingtalk.com/robot/send"
headers = {
'Content-Type': 'application/json',
"Content-Type": "application/json",
}
params = {
'access_token': access_token,
"access_token": access_token,
}
self._apply_security_mechanism(params, sign_secret)
@ -53,7 +58,7 @@ class DingTalkGroupBotTool(BuiltinTool):
"msgtype": msgtype,
"text": {
"content": content,
}
},
}
try:
@ -62,7 +67,8 @@ class DingTalkGroupBotTool(BuiltinTool):
return self.create_text_message("Text message sent successfully")
else:
return self.create_text_message(
f"Failed to send the text message, status code: {res.status_code}, response: {res.text}")
f"Failed to send the text message, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to send message to group chat bot. {}".format(e))
@ -70,14 +76,14 @@ class DingTalkGroupBotTool(BuiltinTool):
def _apply_security_mechanism(params: dict[str, Any], sign_secret: str):
try:
timestamp = str(round(time.time() * 1000))
secret_enc = sign_secret.encode('utf-8')
string_to_sign = f'{timestamp}\n{sign_secret}'
string_to_sign_enc = string_to_sign.encode('utf-8')
secret_enc = sign_secret.encode("utf-8")
string_to_sign = f"{timestamp}\n{sign_secret}"
string_to_sign_enc = string_to_sign.encode("utf-8")
hmac_code = hmac.new(secret_enc, string_to_sign_enc, digestmod=hashlib.sha256).digest()
sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
params['timestamp'] = timestamp
params['sign'] = sign
params["timestamp"] = timestamp
params["sign"] = sign
except Exception:
msg = "Failed to apply security mechanism to the request."
logging.exception(msg)

View File

@ -11,11 +11,10 @@ class DuckDuckGoProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
user_id="",
tool_parameters={
"query": "John Doe",
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -13,8 +13,8 @@ class DuckDuckGoAITool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
query_dict = {
"keywords": tool_parameters.get('query'),
"model": tool_parameters.get('model'),
"keywords": tool_parameters.get("query"),
"model": tool_parameters.get("model"),
}
response = DDGS().chat(**query_dict)
return self.create_text_message(text=response)

View File

@ -14,18 +14,17 @@ class DuckDuckGoImageSearchTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
query_dict = {
"keywords": tool_parameters.get('query'),
"timelimit": tool_parameters.get('timelimit'),
"size": tool_parameters.get('size'),
"max_results": tool_parameters.get('max_results'),
"keywords": tool_parameters.get("query"),
"timelimit": tool_parameters.get("timelimit"),
"size": tool_parameters.get("size"),
"max_results": tool_parameters.get("max_results"),
}
response = DDGS().images(**query_dict)
result = []
for res in response:
res['transfer_method'] = FileTransferMethod.REMOTE_URL
msg = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=res.get('image'),
save_as='',
meta=res)
res["transfer_method"] = FileTransferMethod.REMOTE_URL
msg = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=res.get("image"), save_as="", meta=res
)
result.append(msg)
return result

View File

@ -21,10 +21,11 @@ class DuckDuckGoSearchTool(BuiltinTool):
"""
Tool for performing a search using DuckDuckGo search engine.
"""
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
query = tool_parameters.get('query')
max_results = tool_parameters.get('max_results', 5)
require_summary = tool_parameters.get('require_summary', False)
query = tool_parameters.get("query")
max_results = tool_parameters.get("max_results", 5)
require_summary = tool_parameters.get("require_summary", False)
response = DDGS().text(query, max_results=max_results)
if require_summary:
results = "\n".join([res.get("body") for res in response])
@ -34,7 +35,11 @@ class DuckDuckGoSearchTool(BuiltinTool):
def summary_results(self, user_id: str, content: str, query: str) -> str:
prompt = SUMMARY_PROMPT.format(query=query, content=content)
summary = self.invoke_model(user_id=user_id, prompt_messages=[
SystemPromptMessage(content=prompt),
], stop=[])
summary = self.invoke_model(
user_id=user_id,
prompt_messages=[
SystemPromptMessage(content=prompt),
],
stop=[],
)
return summary.message.content

View File

@ -13,8 +13,8 @@ class DuckDuckGoTranslateTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
query_dict = {
"keywords": tool_parameters.get('query'),
"to": tool_parameters.get('translate_to'),
"keywords": tool_parameters.get("query"),
"to": tool_parameters.get("translate_to"),
}
response = DDGS().translate(**query_dict)[0].get('translated', 'Unable to translate!')
response = DDGS().translate(**query_dict)[0].get("translated", "Unable to translate!")
return self.create_text_message(text=response)

View File

@ -8,35 +8,35 @@ from core.tools.utils.uuid_utils import is_valid_uuid
class FeishuGroupBotTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
API document: https://open.feishu.cn/document/client-docs/bot-v3/add-custom-bot
invoke tools
API document: https://open.feishu.cn/document/client-docs/bot-v3/add-custom-bot
"""
url = "https://open.feishu.cn/open-apis/bot/v2/hook"
content = tool_parameters.get('content', '')
content = tool_parameters.get("content", "")
if not content:
return self.create_text_message('Invalid parameter content')
return self.create_text_message("Invalid parameter content")
hook_key = tool_parameters.get('hook_key', '')
hook_key = tool_parameters.get("hook_key", "")
if not is_valid_uuid(hook_key):
return self.create_text_message(
f'Invalid parameter hook_key ${hook_key}, not a valid UUID')
return self.create_text_message(f"Invalid parameter hook_key ${hook_key}, not a valid UUID")
msg_type = 'text'
api_url = f'{url}/{hook_key}'
msg_type = "text"
api_url = f"{url}/{hook_key}"
headers = {
'Content-Type': 'application/json',
"Content-Type": "application/json",
}
params = {}
payload = {
"msg_type": msg_type,
"content": {
"text": content,
}
},
}
try:
@ -45,6 +45,7 @@ class FeishuGroupBotTool(BuiltinTool):
return self.create_text_message("Text message sent successfully")
else:
return self.create_text_message(
f"Failed to send the text message, status code: {res.status_code}, response: {res.text}")
f"Failed to send the text message, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to send message to group chat bot. {}".format(e))
return self.create_text_message("Failed to send message to group chat bot. {}".format(e))

View File

@ -5,4 +5,4 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class FeishuBaseProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
GetTenantAccessTokenTool()
pass
pass

View File

@ -8,45 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool
class AddBaseRecordTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records"
access_token = tool_parameters.get('Authorization', '')
access_token = tool_parameters.get("Authorization", "")
if not access_token:
return self.create_text_message('Invalid parameter access_token')
return self.create_text_message("Invalid parameter access_token")
app_token = tool_parameters.get('app_token', '')
app_token = tool_parameters.get("app_token", "")
if not app_token:
return self.create_text_message('Invalid parameter app_token')
return self.create_text_message("Invalid parameter app_token")
table_id = tool_parameters.get('table_id', '')
table_id = tool_parameters.get("table_id", "")
if not table_id:
return self.create_text_message('Invalid parameter table_id')
return self.create_text_message("Invalid parameter table_id")
fields = tool_parameters.get('fields', '')
fields = tool_parameters.get("fields", "")
if not fields:
return self.create_text_message('Invalid parameter fields')
return self.create_text_message("Invalid parameter fields")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {access_token}",
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
}
params = {}
payload = {
"fields": json.loads(fields)
}
payload = {"fields": json.loads(fields)}
try:
res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params,
json=payload, timeout=30)
res = httpx.post(
url.format(app_token=app_token, table_id=table_id),
headers=headers,
params=params,
json=payload,
timeout=30,
)
res_json = res.json()
if res.is_success:
return self.create_text_message(text=json.dumps(res_json))
else:
return self.create_text_message(
f"Failed to add base record, status code: {res.status_code}, response: {res.text}")
f"Failed to add base record, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to add base record. {}".format(e))

View File

@ -8,28 +8,25 @@ from core.tools.tool.builtin_tool import BuiltinTool
class CreateBaseTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
url = "https://open.feishu.cn/open-apis/bitable/v1/apps"
access_token = tool_parameters.get('Authorization', '')
access_token = tool_parameters.get("Authorization", "")
if not access_token:
return self.create_text_message('Invalid parameter access_token')
return self.create_text_message("Invalid parameter access_token")
name = tool_parameters.get('name', '')
folder_token = tool_parameters.get('folder_token', '')
name = tool_parameters.get("name", "")
folder_token = tool_parameters.get("folder_token", "")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {access_token}",
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
}
params = {}
payload = {
"name": name,
"folder_token": folder_token
}
payload = {"name": name, "folder_token": folder_token}
try:
res = httpx.post(url, headers=headers, params=params, json=payload, timeout=30)
@ -38,6 +35,7 @@ class CreateBaseTool(BuiltinTool):
return self.create_text_message(text=json.dumps(res_json))
else:
return self.create_text_message(
f"Failed to create base, status code: {res.status_code}, response: {res.text}")
f"Failed to create base, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to create base. {}".format(e))

View File

@ -8,37 +8,32 @@ from core.tools.tool.builtin_tool import BuiltinTool
class CreateBaseTableTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables"
access_token = tool_parameters.get('Authorization', '')
access_token = tool_parameters.get("Authorization", "")
if not access_token:
return self.create_text_message('Invalid parameter access_token')
return self.create_text_message("Invalid parameter access_token")
app_token = tool_parameters.get('app_token', '')
app_token = tool_parameters.get("app_token", "")
if not app_token:
return self.create_text_message('Invalid parameter app_token')
return self.create_text_message("Invalid parameter app_token")
name = tool_parameters.get('name', '')
name = tool_parameters.get("name", "")
fields = tool_parameters.get('fields', '')
fields = tool_parameters.get("fields", "")
if not fields:
return self.create_text_message('Invalid parameter fields')
return self.create_text_message("Invalid parameter fields")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {access_token}",
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
}
params = {}
payload = {
"table": {
"name": name,
"fields": json.loads(fields)
}
}
payload = {"table": {"name": name, "fields": json.loads(fields)}}
try:
res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30)
@ -47,6 +42,7 @@ class CreateBaseTableTool(BuiltinTool):
return self.create_text_message(text=json.dumps(res_json))
else:
return self.create_text_message(
f"Failed to create base table, status code: {res.status_code}, response: {res.text}")
f"Failed to create base table, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to create base table. {}".format(e))

View File

@ -8,45 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool
class DeleteBaseRecordsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/batch_delete"
access_token = tool_parameters.get('Authorization', '')
access_token = tool_parameters.get("Authorization", "")
if not access_token:
return self.create_text_message('Invalid parameter access_token')
return self.create_text_message("Invalid parameter access_token")
app_token = tool_parameters.get('app_token', '')
app_token = tool_parameters.get("app_token", "")
if not app_token:
return self.create_text_message('Invalid parameter app_token')
return self.create_text_message("Invalid parameter app_token")
table_id = tool_parameters.get('table_id', '')
table_id = tool_parameters.get("table_id", "")
if not table_id:
return self.create_text_message('Invalid parameter table_id')
return self.create_text_message("Invalid parameter table_id")
record_ids = tool_parameters.get('record_ids', '')
record_ids = tool_parameters.get("record_ids", "")
if not record_ids:
return self.create_text_message('Invalid parameter record_ids')
return self.create_text_message("Invalid parameter record_ids")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {access_token}",
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
}
params = {}
payload = {
"records": json.loads(record_ids)
}
payload = {"records": json.loads(record_ids)}
try:
res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params,
json=payload, timeout=30)
res = httpx.post(
url.format(app_token=app_token, table_id=table_id),
headers=headers,
params=params,
json=payload,
timeout=30,
)
res_json = res.json()
if res.is_success:
return self.create_text_message(text=json.dumps(res_json))
else:
return self.create_text_message(
f"Failed to delete base records, status code: {res.status_code}, response: {res.text}")
f"Failed to delete base records, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to delete base records. {}".format(e))

View File

@ -8,32 +8,30 @@ from core.tools.tool.builtin_tool import BuiltinTool
class DeleteBaseTablesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/batch_delete"
access_token = tool_parameters.get('Authorization', '')
access_token = tool_parameters.get("Authorization", "")
if not access_token:
return self.create_text_message('Invalid parameter access_token')
return self.create_text_message("Invalid parameter access_token")
app_token = tool_parameters.get('app_token', '')
app_token = tool_parameters.get("app_token", "")
if not app_token:
return self.create_text_message('Invalid parameter app_token')
return self.create_text_message("Invalid parameter app_token")
table_ids = tool_parameters.get('table_ids', '')
table_ids = tool_parameters.get("table_ids", "")
if not table_ids:
return self.create_text_message('Invalid parameter table_ids')
return self.create_text_message("Invalid parameter table_ids")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {access_token}",
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
}
params = {}
payload = {
"table_ids": json.loads(table_ids)
}
payload = {"table_ids": json.loads(table_ids)}
try:
res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30)
@ -42,6 +40,7 @@ class DeleteBaseTablesTool(BuiltinTool):
return self.create_text_message(text=json.dumps(res_json))
else:
return self.create_text_message(
f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}")
f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to delete base tables. {}".format(e))

View File

@ -8,22 +8,22 @@ from core.tools.tool.builtin_tool import BuiltinTool
class GetBaseInfoTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}"
access_token = tool_parameters.get('Authorization', '')
access_token = tool_parameters.get("Authorization", "")
if not access_token:
return self.create_text_message('Invalid parameter access_token')
return self.create_text_message("Invalid parameter access_token")
app_token = tool_parameters.get('app_token', '')
app_token = tool_parameters.get("app_token", "")
if not app_token:
return self.create_text_message('Invalid parameter app_token')
return self.create_text_message("Invalid parameter app_token")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {access_token}",
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
}
try:
@ -33,6 +33,7 @@ class GetBaseInfoTool(BuiltinTool):
return self.create_text_message(text=json.dumps(res_json))
else:
return self.create_text_message(
f"Failed to get base info, status code: {res.status_code}, response: {res.text}")
f"Failed to get base info, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to get base info. {}".format(e))

View File

@ -8,27 +8,24 @@ from core.tools.tool.builtin_tool import BuiltinTool
class GetTenantAccessTokenTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"
app_id = tool_parameters.get('app_id', '')
app_id = tool_parameters.get("app_id", "")
if not app_id:
return self.create_text_message('Invalid parameter app_id')
return self.create_text_message("Invalid parameter app_id")
app_secret = tool_parameters.get('app_secret', '')
app_secret = tool_parameters.get("app_secret", "")
if not app_secret:
return self.create_text_message('Invalid parameter app_secret')
return self.create_text_message("Invalid parameter app_secret")
headers = {
'Content-Type': 'application/json',
"Content-Type": "application/json",
}
params = {}
payload = {
"app_id": app_id,
"app_secret": app_secret
}
payload = {"app_id": app_id, "app_secret": app_secret}
"""
{
@ -45,6 +42,7 @@ class GetTenantAccessTokenTool(BuiltinTool):
return self.create_text_message(text=json.dumps(res_json))
else:
return self.create_text_message(
f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}")
f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to get tenant access token. {}".format(e))

View File

@ -8,31 +8,31 @@ from core.tools.tool.builtin_tool import BuiltinTool
class ListBaseRecordsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/search"
access_token = tool_parameters.get('Authorization', '')
access_token = tool_parameters.get("Authorization", "")
if not access_token:
return self.create_text_message('Invalid parameter access_token')
return self.create_text_message("Invalid parameter access_token")
app_token = tool_parameters.get('app_token', '')
app_token = tool_parameters.get("app_token", "")
if not app_token:
return self.create_text_message('Invalid parameter app_token')
return self.create_text_message("Invalid parameter app_token")
table_id = tool_parameters.get('table_id', '')
table_id = tool_parameters.get("table_id", "")
if not table_id:
return self.create_text_message('Invalid parameter table_id')
return self.create_text_message("Invalid parameter table_id")
page_token = tool_parameters.get('page_token', '')
page_size = tool_parameters.get('page_size', '')
sort_condition = tool_parameters.get('sort_condition', '')
filter_condition = tool_parameters.get('filter_condition', '')
page_token = tool_parameters.get("page_token", "")
page_size = tool_parameters.get("page_size", "")
sort_condition = tool_parameters.get("sort_condition", "")
filter_condition = tool_parameters.get("filter_condition", "")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {access_token}",
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
}
params = {
@ -40,22 +40,26 @@ class ListBaseRecordsTool(BuiltinTool):
"page_size": page_size,
}
payload = {
"automatic_fields": True
}
payload = {"automatic_fields": True}
if sort_condition:
payload["sort"] = json.loads(sort_condition)
if filter_condition:
payload["filter"] = json.loads(filter_condition)
try:
res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params,
json=payload, timeout=30)
res = httpx.post(
url.format(app_token=app_token, table_id=table_id),
headers=headers,
params=params,
json=payload,
timeout=30,
)
res_json = res.json()
if res.is_success:
return self.create_text_message(text=json.dumps(res_json))
else:
return self.create_text_message(
f"Failed to list base records, status code: {res.status_code}, response: {res.text}")
f"Failed to list base records, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to list base records. {}".format(e))

View File

@ -8,25 +8,25 @@ from core.tools.tool.builtin_tool import BuiltinTool
class ListBaseTablesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables"
access_token = tool_parameters.get('Authorization', '')
access_token = tool_parameters.get("Authorization", "")
if not access_token:
return self.create_text_message('Invalid parameter access_token')
return self.create_text_message("Invalid parameter access_token")
app_token = tool_parameters.get('app_token', '')
app_token = tool_parameters.get("app_token", "")
if not app_token:
return self.create_text_message('Invalid parameter app_token')
return self.create_text_message("Invalid parameter app_token")
page_token = tool_parameters.get('page_token', '')
page_size = tool_parameters.get('page_size', '')
page_token = tool_parameters.get("page_token", "")
page_size = tool_parameters.get("page_size", "")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {access_token}",
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
}
params = {
@ -41,6 +41,7 @@ class ListBaseTablesTool(BuiltinTool):
return self.create_text_message(text=json.dumps(res_json))
else:
return self.create_text_message(
f"Failed to list base tables, status code: {res.status_code}, response: {res.text}")
f"Failed to list base tables, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to list base tables. {}".format(e))

View File

@ -8,40 +8,42 @@ from core.tools.tool.builtin_tool import BuiltinTool
class ReadBaseRecordTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}"
access_token = tool_parameters.get('Authorization', '')
access_token = tool_parameters.get("Authorization", "")
if not access_token:
return self.create_text_message('Invalid parameter access_token')
return self.create_text_message("Invalid parameter access_token")
app_token = tool_parameters.get('app_token', '')
app_token = tool_parameters.get("app_token", "")
if not app_token:
return self.create_text_message('Invalid parameter app_token')
return self.create_text_message("Invalid parameter app_token")
table_id = tool_parameters.get('table_id', '')
table_id = tool_parameters.get("table_id", "")
if not table_id:
return self.create_text_message('Invalid parameter table_id')
return self.create_text_message("Invalid parameter table_id")
record_id = tool_parameters.get('record_id', '')
record_id = tool_parameters.get("record_id", "")
if not record_id:
return self.create_text_message('Invalid parameter record_id')
return self.create_text_message("Invalid parameter record_id")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {access_token}",
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
}
try:
res = httpx.get(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers,
timeout=30)
res = httpx.get(
url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, timeout=30
)
res_json = res.json()
if res.is_success:
return self.create_text_message(text=json.dumps(res_json))
else:
return self.create_text_message(
f"Failed to read base record, status code: {res.status_code}, response: {res.text}")
f"Failed to read base record, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to read base record. {}".format(e))

View File

@ -8,49 +8,53 @@ from core.tools.tool.builtin_tool import BuiltinTool
class UpdateBaseRecordTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}"
access_token = tool_parameters.get('Authorization', '')
access_token = tool_parameters.get("Authorization", "")
if not access_token:
return self.create_text_message('Invalid parameter access_token')
return self.create_text_message("Invalid parameter access_token")
app_token = tool_parameters.get('app_token', '')
app_token = tool_parameters.get("app_token", "")
if not app_token:
return self.create_text_message('Invalid parameter app_token')
return self.create_text_message("Invalid parameter app_token")
table_id = tool_parameters.get('table_id', '')
table_id = tool_parameters.get("table_id", "")
if not table_id:
return self.create_text_message('Invalid parameter table_id')
return self.create_text_message("Invalid parameter table_id")
record_id = tool_parameters.get('record_id', '')
record_id = tool_parameters.get("record_id", "")
if not record_id:
return self.create_text_message('Invalid parameter record_id')
return self.create_text_message("Invalid parameter record_id")
fields = tool_parameters.get('fields', '')
fields = tool_parameters.get("fields", "")
if not fields:
return self.create_text_message('Invalid parameter fields')
return self.create_text_message("Invalid parameter fields")
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {access_token}",
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
}
params = {}
payload = {
"fields": json.loads(fields)
}
payload = {"fields": json.loads(fields)}
try:
res = httpx.put(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers,
params=params, json=payload, timeout=30)
res = httpx.put(
url.format(app_token=app_token, table_id=table_id, record_id=record_id),
headers=headers,
params=params,
json=payload,
timeout=30,
)
res_json = res.json()
if res.is_success:
return self.create_text_message(text=json.dumps(res_json))
else:
return self.create_text_message(
f"Failed to update base record, status code: {res.status_code}, response: {res.text}")
f"Failed to update base record, status code: {res.status_code}, response: {res.text}"
)
except Exception as e:
return self.create_text_message("Failed to update base record. {}".format(e))

View File

@ -5,11 +5,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class FeishuDocumentProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
app_id = credentials.get('app_id')
app_secret = credentials.get('app_secret')
app_id = credentials.get("app_id")
app_secret = credentials.get("app_secret")
if not app_id or not app_secret:
raise ToolProviderCredentialValidationError("app_id and app_secret is required")
try:
assert FeishuRequest(app_id, app_secret).tenant_access_token is not None
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
raise ToolProviderCredentialValidationError(str(e))

View File

@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class CreateDocumentTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
app_id = self.runtime.credentials.get('app_id')
app_secret = self.runtime.credentials.get('app_secret')
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
client = FeishuRequest(app_id, app_secret)
title = tool_parameters.get('title')
content = tool_parameters.get('content')
folder_token = tool_parameters.get('folder_token')
title = tool_parameters.get("title")
content = tool_parameters.get("content")
folder_token = tool_parameters.get("folder_token")
res = client.create_document(title, content, folder_token)
return self.create_json_message(res)

View File

@ -7,11 +7,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class GetDocumentRawContentTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
app_id = self.runtime.credentials.get('app_id')
app_secret = self.runtime.credentials.get('app_secret')
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
client = FeishuRequest(app_id, app_secret)
document_id = tool_parameters.get('document_id')
document_id = tool_parameters.get("document_id")
res = client.get_document_raw_content(document_id)
return self.create_json_message(res)
return self.create_json_message(res)

View File

@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class ListDocumentBlockTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
app_id = self.runtime.credentials.get('app_id')
app_secret = self.runtime.credentials.get('app_secret')
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
client = FeishuRequest(app_id, app_secret)
document_id = tool_parameters.get('document_id')
page_size = tool_parameters.get('page_size', 500)
page_token = tool_parameters.get('page_token', '')
document_id = tool_parameters.get("document_id")
page_size = tool_parameters.get("page_size", 500)
page_token = tool_parameters.get("page_token", "")
res = client.list_document_block(document_id, page_token, page_size)
return self.create_json_message(res)

View File

@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class CreateDocumentTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
app_id = self.runtime.credentials.get('app_id')
app_secret = self.runtime.credentials.get('app_secret')
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
client = FeishuRequest(app_id, app_secret)
document_id = tool_parameters.get('document_id')
content = tool_parameters.get('content')
position = tool_parameters.get('position')
document_id = tool_parameters.get("document_id")
content = tool_parameters.get("content")
position = tool_parameters.get("position")
res = client.write_document(document_id, content, position)
return self.create_json_message(res)

View File

@ -5,11 +5,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class FeishuMessageProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
app_id = credentials.get('app_id')
app_secret = credentials.get('app_secret')
app_id = credentials.get("app_id")
app_secret = credentials.get("app_secret")
if not app_id or not app_secret:
raise ToolProviderCredentialValidationError("app_id and app_secret is required")
try:
assert FeishuRequest(app_id, app_secret).tenant_access_token is not None
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
raise ToolProviderCredentialValidationError(str(e))

View File

@ -7,14 +7,14 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class SendBotMessageTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
app_id = self.runtime.credentials.get('app_id')
app_secret = self.runtime.credentials.get('app_secret')
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
client = FeishuRequest(app_id, app_secret)
receive_id_type = tool_parameters.get('receive_id_type')
receive_id = tool_parameters.get('receive_id')
msg_type = tool_parameters.get('msg_type')
content = tool_parameters.get('content')
receive_id_type = tool_parameters.get("receive_id_type")
receive_id = tool_parameters.get("receive_id")
msg_type = tool_parameters.get("msg_type")
content = tool_parameters.get("content")
res = client.send_bot_message(receive_id_type, receive_id, msg_type, content)
return self.create_json_message(res)

View File

@ -6,14 +6,14 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
class SendWebhookMessageTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) ->ToolInvokeMessage:
app_id = self.runtime.credentials.get('app_id')
app_secret = self.runtime.credentials.get('app_secret')
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
app_id = self.runtime.credentials.get("app_id")
app_secret = self.runtime.credentials.get("app_secret")
client = FeishuRequest(app_id, app_secret)
webhook = tool_parameters.get('webhook')
msg_type = tool_parameters.get('msg_type')
content = tool_parameters.get('content')
webhook = tool_parameters.get("webhook")
msg_type = tool_parameters.get("msg_type")
content = tool_parameters.get("content")
res = client.send_webhook_message(webhook, msg_type, content)
return self.create_json_message(res)

View File

@ -7,15 +7,8 @@ class FirecrawlProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
# Example validation using the ScrapeTool, only scraping title for minimize content
ScrapeTool().fork_tool_runtime(
runtime={"credentials": credentials}
).invoke(
user_id='',
tool_parameters={
"url": "https://google.com",
"onlyIncludeTags": 'title'
}
ScrapeTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke(
user_id="", tool_parameters={"url": "https://google.com", "onlyIncludeTags": "title"}
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -13,85 +13,83 @@ logger = logging.getLogger(__name__)
class FirecrawlApp:
def __init__(self, api_key: str | None = None, base_url: str | None = None):
self.api_key = api_key
self.base_url = base_url or 'https://api.firecrawl.dev'
self.base_url = base_url or "https://api.firecrawl.dev"
if not self.api_key:
raise ValueError("API key is required")
def _prepare_headers(self, idempotency_key: str | None = None):
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
if idempotency_key:
headers['Idempotency-Key'] = idempotency_key
headers["Idempotency-Key"] = idempotency_key
return headers
def _request(
self,
method: str,
url: str,
data: Mapping[str, Any] | None = None,
headers: Mapping[str, str] | None = None,
retries: int = 3,
backoff_factor: float = 0.3,
self,
method: str,
url: str,
data: Mapping[str, Any] | None = None,
headers: Mapping[str, str] | None = None,
retries: int = 3,
backoff_factor: float = 0.3,
) -> Mapping[str, Any] | None:
if not headers:
headers = self._prepare_headers()
for i in range(retries):
try:
response = requests.request(method, url, json=data, headers=headers)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
except requests.exceptions.RequestException:
if i < retries - 1:
time.sleep(backoff_factor * (2 ** i))
time.sleep(backoff_factor * (2**i))
else:
raise
return None
def scrape_url(self, url: str, **kwargs):
endpoint = f'{self.base_url}/v0/scrape'
data = {'url': url, **kwargs}
endpoint = f"{self.base_url}/v1/scrape"
data = {"url": url, **kwargs}
logger.debug(f"Sent request to {endpoint=} body={data}")
response = self._request('POST', endpoint, data)
response = self._request("POST", endpoint, data)
if response is None:
raise HTTPError("Failed to scrape URL after multiple retries")
return response
def search(self, query: str, **kwargs):
endpoint = f'{self.base_url}/v0/search'
data = {'query': query, **kwargs}
def map(self, url: str, **kwargs):
endpoint = f"{self.base_url}/v1/map"
data = {"url": url, **kwargs}
logger.debug(f"Sent request to {endpoint=} body={data}")
response = self._request('POST', endpoint, data)
response = self._request("POST", endpoint, data)
if response is None:
raise HTTPError("Failed to perform search after multiple retries")
raise HTTPError("Failed to perform map after multiple retries")
return response
def crawl_url(
self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs
self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs
):
endpoint = f'{self.base_url}/v0/crawl'
endpoint = f"{self.base_url}/v1/crawl"
headers = self._prepare_headers(idempotency_key)
data = {'url': url, **kwargs}
data = {"url": url, **kwargs}
logger.debug(f"Sent request to {endpoint=} body={data}")
response = self._request('POST', endpoint, data, headers)
response = self._request("POST", endpoint, data, headers)
if response is None:
raise HTTPError("Failed to initiate crawl after multiple retries")
job_id: str = response['jobId']
elif response.get("success") == False:
raise HTTPError(f'Failed to crawl: {response.get("error")}')
job_id: str = response["id"]
if wait:
return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval)
return response
def check_crawl_status(self, job_id: str):
endpoint = f'{self.base_url}/v0/crawl/status/{job_id}'
response = self._request('GET', endpoint)
endpoint = f"{self.base_url}/v1/crawl/{job_id}"
response = self._request("GET", endpoint)
if response is None:
raise HTTPError(f"Failed to check status for job {job_id} after multiple retries")
return response
def cancel_crawl_job(self, job_id: str):
endpoint = f'{self.base_url}/v0/crawl/cancel/{job_id}'
response = self._request('DELETE', endpoint)
endpoint = f"{self.base_url}/v1/crawl/{job_id}"
response = self._request("DELETE", endpoint)
if response is None:
raise HTTPError(f"Failed to cancel job {job_id} after multiple retries")
return response
@ -99,9 +97,9 @@ class FirecrawlApp:
def _monitor_job_status(self, job_id: str, poll_interval: int):
while True:
status = self.check_crawl_status(job_id)
if status['status'] == 'completed':
if status["status"] == "completed":
return status
elif status['status'] == 'failed':
elif status["status"] == "failed":
raise HTTPError(f'Job {job_id} failed: {status["error"]}')
time.sleep(poll_interval)
@ -109,7 +107,7 @@ class FirecrawlApp:
def get_array_params(tool_parameters: dict[str, Any], key):
param = tool_parameters.get(key)
if param:
return param.split(',')
return param.split(",")
def get_json_params(tool_parameters: dict[str, Any], key):
@ -119,6 +117,6 @@ def get_json_params(tool_parameters: dict[str, Any], key):
# support both single quotes and double quotes
param = param.replace("'", '"')
param = json.loads(param)
except:
except Exception:
raise ValueError(f"Invalid {key} format.")
return param

View File

@ -8,41 +8,38 @@ from core.tools.tool.builtin_tool import BuiltinTool
class CrawlTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
"""
the crawlerOptions and pageOptions comes from doc here:
the api doc:
https://docs.firecrawl.dev/api-reference/endpoint/crawl
"""
app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'],
base_url=self.runtime.credentials['base_url'])
crawlerOptions = {}
pageOptions = {}
wait_for_results = tool_parameters.get('wait_for_results', True)
crawlerOptions['excludes'] = get_array_params(tool_parameters, 'excludes')
crawlerOptions['includes'] = get_array_params(tool_parameters, 'includes')
crawlerOptions['returnOnlyUrls'] = tool_parameters.get('returnOnlyUrls', False)
crawlerOptions['maxDepth'] = tool_parameters.get('maxDepth')
crawlerOptions['mode'] = tool_parameters.get('mode')
crawlerOptions['ignoreSitemap'] = tool_parameters.get('ignoreSitemap', False)
crawlerOptions['limit'] = tool_parameters.get('limit', 5)
crawlerOptions['allowBackwardCrawling'] = tool_parameters.get('allowBackwardCrawling', False)
crawlerOptions['allowExternalContentLinks'] = tool_parameters.get('allowExternalContentLinks', False)
pageOptions['headers'] = get_json_params(tool_parameters, 'headers')
pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False)
pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False)
pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags')
pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags')
pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False)
pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False)
pageOptions['screenshot'] = tool_parameters.get('screenshot', False)
pageOptions['waitFor'] = tool_parameters.get('waitFor', 0)
crawl_result = app.crawl_url(
url=tool_parameters['url'],
wait=wait_for_results,
crawlerOptions=crawlerOptions,
pageOptions=pageOptions
app = FirecrawlApp(
api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"]
)
scrapeOptions = {}
payload = {}
wait_for_results = tool_parameters.get("wait_for_results", True)
payload["excludePaths"] = get_array_params(tool_parameters, "excludePaths")
payload["includePaths"] = get_array_params(tool_parameters, "includePaths")
payload["maxDepth"] = tool_parameters.get("maxDepth")
payload["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", False)
payload["limit"] = tool_parameters.get("limit", 5)
payload["allowBackwardLinks"] = tool_parameters.get("allowBackwardLinks", False)
payload["allowExternalLinks"] = tool_parameters.get("allowExternalLinks", False)
payload["webhook"] = tool_parameters.get("webhook")
scrapeOptions["formats"] = get_array_params(tool_parameters, "formats")
scrapeOptions["headers"] = get_json_params(tool_parameters, "headers")
scrapeOptions["includeTags"] = get_array_params(tool_parameters, "includeTags")
scrapeOptions["excludeTags"] = get_array_params(tool_parameters, "excludeTags")
scrapeOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False)
scrapeOptions["waitFor"] = tool_parameters.get("waitFor", 0)
scrapeOptions = {k: v for k, v in scrapeOptions.items() if v not in {None, ""}}
payload["scrapeOptions"] = scrapeOptions or None
payload = {k: v for k, v in payload.items() if v not in {None, ""}}
crawl_result = app.crawl_url(url=tool_parameters["url"], wait=wait_for_results, **payload)
return self.create_json_message(crawl_result)

View File

@ -31,8 +31,21 @@ parameters:
en_US: If you choose not to wait, it will directly return a job ID. You can use this job ID to check the crawling results or cancel the crawling task, which is usually very useful for a large-scale crawling task.
zh_Hans: 如果选择不等待则会直接返回一个job_id可以通过job_id查询爬取结果或取消爬取任务这通常对于一个大型爬取任务来说非常有用。
form: form
############## Crawl Options #######################
- name: includes
############## Payload #######################
- name: excludePaths
type: string
label:
en_US: URL patterns to exclude
zh_Hans: 要排除的URL模式
placeholder:
en_US: Use commas to separate multiple tags
zh_Hans: 多个标签时使用半角逗号分隔
human_description:
en_US: |
Pages matching these patterns will be skipped. Example: blog/*, about/*
zh_Hans: 匹配这些模式的页面将被跳过。示例blog/*, about/*
form: form
- name: includePaths
type: string
required: false
label:
@ -46,30 +59,6 @@ parameters:
Only pages matching these patterns will be crawled. Example: blog/*, about/*
zh_Hans: 只有与这些模式匹配的页面才会被爬取。示例blog/*, about/*
form: form
- name: excludes
type: string
label:
en_US: URL patterns to exclude
zh_Hans: 要排除的URL模式
placeholder:
en_US: Use commas to separate multiple tags
zh_Hans: 多个标签时使用半角逗号分隔
human_description:
en_US: |
Pages matching these patterns will be skipped. Example: blog/*, about/*
zh_Hans: 匹配这些模式的页面将被跳过。示例blog/*, about/*
form: form
- name: returnOnlyUrls
type: boolean
default: false
label:
en_US: return Only Urls
zh_Hans: 仅返回URL
human_description:
en_US: |
If true, returns only the URLs as a list on the crawl status. Attention: the return response will be a list of URLs inside the data, not a list of documents.
zh_Hans: 只返回爬取到的网页链接,而不是网页内容本身。
form: form
- name: maxDepth
type: number
label:
@ -80,27 +69,10 @@ parameters:
zh_Hans: 相对于输入的URL爬取的最大深度。maxDepth为0时仅抓取输入的URL。maxDepth为1时抓取输入的URL以及所有一级深层页面。maxDepth为2时抓取输入的URL以及所有两级深层页面。更高值遵循相同模式。
form: form
min: 0
- name: mode
type: select
required: false
form: form
options:
- value: default
label:
en_US: default
- value: fast
label:
en_US: fast
default: default
label:
en_US: Crawl Mode
zh_Hans: 爬取模式
human_description:
en_US: The crawling mode to use. Fast mode crawls 4x faster websites without sitemap, but may not be as accurate and shouldn't be used in heavy js-rendered websites.
zh_Hans: 使用fast模式将不会使用其站点地图比普通模式快4倍但是可能不够准确也不适用于大量js渲染的网站。
default: 2
- name: ignoreSitemap
type: boolean
default: false
default: true
label:
en_US: ignore Sitemap
zh_Hans: 忽略站点地图
@ -120,7 +92,7 @@ parameters:
form: form
min: 1
default: 5
- name: allowBackwardCrawling
- name: allowBackwardLinks
type: boolean
default: false
label:
@ -130,7 +102,7 @@ parameters:
en_US: Enables the crawler to navigate from a specific URL to previously linked pages. For instance, from 'example.com/product/123' back to 'example.com/product'
zh_Hans: 使爬虫能够从特定URL导航到之前链接的页面。例如从'example.com/product/123'返回到'example.com/product'
form: form
- name: allowExternalContentLinks
- name: allowExternalLinks
type: boolean
default: false
label:
@ -140,7 +112,30 @@ parameters:
en_US: Allows the crawler to follow links to external websites.
zh_Hans:
form: form
############## Page Options #######################
- name: webhook
type: string
label:
en_US: webhook
human_description:
en_US: |
The URL to send the webhook to. This will trigger for crawl started (crawl.started) ,every page crawled (crawl.page) and when the crawl is completed (crawl.completed or crawl.failed). The response will be the same as the /scrape endpoint.
zh_Hans: 发送Webhook的URL。这将在开始爬取crawl.started、每爬取一个页面crawl.page以及爬取完成crawl.completed或crawl.failed时触发。响应将与/scrape端点相同。
form: form
############## Scrape Options #######################
- name: formats
type: string
label:
en_US: Formats
zh_Hans: 结果的格式
placeholder:
en_US: Use commas to separate multiple tags
zh_Hans: 多个标签时使用半角逗号分隔
human_description:
en_US: |
Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot
zh_Hans: |
输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot
form: form
- name: headers
type: string
label:
@ -155,30 +150,10 @@ parameters:
en_US: Please enter an object that can be serialized in JSON
zh_Hans: 请输入可以json序列化的对象
form: form
- name: includeHtml
type: boolean
default: false
label:
en_US: include Html
zh_Hans: 包含HTML
human_description:
en_US: Include the HTML version of the content on page. Will output a html key in the response.
zh_Hans: 返回中包含一个HTML版本的内容将以html键返回。
form: form
- name: includeRawHtml
type: boolean
default: false
label:
en_US: include Raw Html
zh_Hans: 包含原始HTML
human_description:
en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response.
zh_Hans: 返回中包含一个原始HTML版本的内容将以rawHtml键返回。
form: form
- name: onlyIncludeTags
- name: includeTags
type: string
label:
en_US: only Include Tags
en_US: Include Tags
zh_Hans: 仅抓取这些标签
placeholder:
en_US: Use commas to separate multiple tags
@ -189,6 +164,20 @@ parameters:
zh_Hans: |
仅在最终输出中包含HTML页面的这些标签可以通过标签名、类或ID来设定使用逗号分隔值。示例script, .ad, #footer
form: form
- name: excludeTags
type: string
label:
en_US: Exclude Tags
zh_Hans: 要移除这些标签
human_description:
en_US: |
Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer
zh_Hans: |
要在最终输出中移除HTML页面的这些标签可以通过标签名、类或ID来设定使用逗号分隔值。示例script, .ad, #footer
placeholder:
en_US: Use commas to separate multiple tags
zh_Hans: 多个标签时使用半角逗号分隔
form: form
- name: onlyMainContent
type: boolean
default: false
@ -199,40 +188,6 @@ parameters:
en_US: Only return the main content of the page excluding headers, navs, footers, etc.
zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。
form: form
- name: removeTags
type: string
label:
en_US: remove Tags
zh_Hans: 要移除这些标签
human_description:
en_US: |
Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer
zh_Hans: |
要在最终输出中移除HTML页面的这些标签可以通过标签名、类或ID来设定使用逗号分隔值。示例script, .ad, #footer
placeholder:
en_US: Use commas to separate multiple tags
zh_Hans: 多个标签时使用半角逗号分隔
form: form
- name: replaceAllPathsWithAbsolutePaths
type: boolean
default: false
label:
en_US: All AbsolutePaths
zh_Hans: 使用绝对路径
human_description:
en_US: Replace all relative paths with absolute paths for images and links.
zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。
form: form
- name: screenshot
type: boolean
default: false
label:
en_US: screenshot
zh_Hans: 截图
human_description:
en_US: Include a screenshot of the top of the page that you are scraping.
zh_Hans: 提供正在抓取的页面的顶部的截图。
form: form
- name: waitFor
type: number
min: 0

View File

@ -7,14 +7,15 @@ from core.tools.tool.builtin_tool import BuiltinTool
class CrawlJobTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'],
base_url=self.runtime.credentials['base_url'])
operation = tool_parameters.get('operation', 'get')
if operation == 'get':
result = app.check_crawl_status(job_id=tool_parameters['job_id'])
elif operation == 'cancel':
result = app.cancel_crawl_job(job_id=tool_parameters['job_id'])
app = FirecrawlApp(
api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"]
)
operation = tool_parameters.get("operation", "get")
if operation == "get":
result = app.check_crawl_status(job_id=tool_parameters["job_id"])
elif operation == "cancel":
result = app.cancel_crawl_job(job_id=tool_parameters["job_id"])
else:
raise ValueError(f'Invalid operation: {operation}')
raise ValueError(f"Invalid operation: {operation}")
return self.create_json_message(result)

View File

@ -0,0 +1,25 @@
from typing import Any
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp
from core.tools.tool.builtin_tool import BuiltinTool
class MapTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
"""
the api doc:
https://docs.firecrawl.dev/api-reference/endpoint/map
"""
app = FirecrawlApp(
api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"]
)
payload = {}
payload["search"] = tool_parameters.get("search")
payload["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", True)
payload["includeSubdomains"] = tool_parameters.get("includeSubdomains", False)
payload["limit"] = tool_parameters.get("limit", 5000)
map_result = app.map(url=tool_parameters["url"], **payload)
return self.create_json_message(map_result)

View File

@ -0,0 +1,59 @@
identity:
name: map
author: hjlarry
label:
en_US: Map
zh_Hans: 地图式快爬
description:
human:
en_US: Input a website and get all the urls on the website - extremly fast
zh_Hans: 输入一个网站,快速获取网站上的所有网址。
llm: Input a website and get all the urls on the website - extremly fast
parameters:
- name: url
type: string
required: true
label:
en_US: Start URL
zh_Hans: 起始URL
human_description:
en_US: The base URL to start crawling from.
zh_Hans: 要爬取网站的起始URL。
llm_description: The URL of the website that needs to be crawled. This is a required parameter.
form: llm
- name: search
type: string
label:
en_US: search
zh_Hans: 搜索查询
human_description:
en_US: Search query to use for mapping. During the Alpha phase, the 'smart' part of the search functionality is limited to 100 search results. However, if map finds more results, there is no limit applied.
zh_Hans: 用于映射的搜索查询。在Alpha阶段搜索功能的“智能”部分限制为最多100个搜索结果。然而如果地图找到了更多结果则不施加任何限制。
llm_description: Search query to use for mapping. During the Alpha phase, the 'smart' part of the search functionality is limited to 100 search results. However, if map finds more results, there is no limit applied.
form: llm
############## Page Options #######################
- name: ignoreSitemap
type: boolean
default: true
label:
en_US: ignore Sitemap
zh_Hans: 忽略站点地图
human_description:
en_US: Ignore the website sitemap when crawling.
zh_Hans: 爬取时忽略网站站点地图。
form: form
- name: includeSubdomains
type: boolean
default: false
label:
en_US: include Subdomains
zh_Hans: 包含子域名
form: form
- name: limit
type: number
min: 0
default: 5000
label:
en_US: Maximum results
zh_Hans: 最大结果数量
form: form

View File

@ -6,34 +6,34 @@ from core.tools.tool.builtin_tool import BuiltinTool
class ScrapeTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
"""
the pageOptions and extractorOptions comes from doc here:
the api doc:
https://docs.firecrawl.dev/api-reference/endpoint/scrape
"""
app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'],
base_url=self.runtime.credentials['base_url'])
app = FirecrawlApp(
api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"]
)
pageOptions = {}
extractorOptions = {}
payload = {}
extract = {}
pageOptions['headers'] = get_json_params(tool_parameters, 'headers')
pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False)
pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False)
pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags')
pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags')
pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False)
pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False)
pageOptions['screenshot'] = tool_parameters.get('screenshot', False)
pageOptions['waitFor'] = tool_parameters.get('waitFor', 0)
payload["formats"] = get_array_params(tool_parameters, "formats")
payload["onlyMainContent"] = tool_parameters.get("onlyMainContent", True)
payload["includeTags"] = get_array_params(tool_parameters, "includeTags")
payload["excludeTags"] = get_array_params(tool_parameters, "excludeTags")
payload["headers"] = get_json_params(tool_parameters, "headers")
payload["waitFor"] = tool_parameters.get("waitFor", 0)
payload["timeout"] = tool_parameters.get("timeout", 30000)
extractorOptions['mode'] = tool_parameters.get('mode', '')
extractorOptions['extractionPrompt'] = tool_parameters.get('extractionPrompt', '')
extractorOptions['extractionSchema'] = get_json_params(tool_parameters, 'extractionSchema')
extract["schema"] = get_json_params(tool_parameters, "schema")
extract["systemPrompt"] = tool_parameters.get("systemPrompt")
extract["prompt"] = tool_parameters.get("prompt")
extract = {k: v for k, v in extract.items() if v not in {None, ""}}
payload["extract"] = extract or None
crawl_result = app.scrape_url(url=tool_parameters['url'],
pageOptions=pageOptions,
extractorOptions=extractorOptions)
payload = {k: v for k, v in payload.items() if v not in {None, ""}}
return self.create_json_message(crawl_result)
crawl_result = app.scrape_url(url=tool_parameters["url"], **payload)
markdown_result = crawl_result.get("data", {}).get("markdown", "")
return [self.create_text_message(markdown_result), self.create_json_message(crawl_result)]

View File

@ -6,8 +6,8 @@ identity:
zh_Hans: 单页面抓取
description:
human:
en_US: Extract data from a single URL.
zh_Hans: 从单个URL抓取数据。
en_US: Turn any url into clean data.
zh_Hans: 将任何网址转换为干净的数据。
llm: This tool is designed to scrape URL and output the content in Markdown format.
parameters:
- name: url
@ -21,7 +21,59 @@ parameters:
zh_Hans: 要抓取并提取数据的网站URL。
llm_description: The URL of the website that needs to be crawled. This is a required parameter.
form: llm
############## Page Options #######################
############## Payload #######################
- name: formats
type: string
label:
en_US: Formats
zh_Hans: 结果的格式
placeholder:
en_US: Use commas to separate multiple tags
zh_Hans: 多个标签时使用半角逗号分隔
human_description:
en_US: |
Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage
zh_Hans: |
输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage
form: form
- name: onlyMainContent
type: boolean
default: false
label:
en_US: only Main Content
zh_Hans: 仅抓取主要内容
human_description:
en_US: Only return the main content of the page excluding headers, navs, footers, etc.
zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。
form: form
- name: includeTags
type: string
label:
en_US: Include Tags
zh_Hans: 仅抓取这些标签
placeholder:
en_US: Use commas to separate multiple tags
zh_Hans: 多个标签时使用半角逗号分隔
human_description:
en_US: |
Only include tags, classes and ids from the page in the final output. Use comma separated values. Example: script, .ad, #footer
zh_Hans: |
仅在最终输出中包含HTML页面的这些标签可以通过标签名、类或ID来设定使用逗号分隔值。示例script, .ad, #footer
form: form
- name: excludeTags
type: string
label:
en_US: Exclude Tags
zh_Hans: 要移除这些标签
human_description:
en_US: |
Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer
zh_Hans: |
要在最终输出中移除HTML页面的这些标签可以通过标签名、类或ID来设定使用逗号分隔值。示例script, .ad, #footer
placeholder:
en_US: Use commas to separate multiple tags
zh_Hans: 多个标签时使用半角逗号分隔
form: form
- name: headers
type: string
label:
@ -36,87 +88,10 @@ parameters:
en_US: Please enter an object that can be serialized in JSON
zh_Hans: 请输入可以json序列化的对象
form: form
- name: includeHtml
type: boolean
default: false
label:
en_US: include Html
zh_Hans: 包含HTML
human_description:
en_US: Include the HTML version of the content on page. Will output a html key in the response.
zh_Hans: 返回中包含一个HTML版本的内容将以html键返回。
form: form
- name: includeRawHtml
type: boolean
default: false
label:
en_US: include Raw Html
zh_Hans: 包含原始HTML
human_description:
en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response.
zh_Hans: 返回中包含一个原始HTML版本的内容将以rawHtml键返回。
form: form
- name: onlyIncludeTags
type: string
label:
en_US: only Include Tags
zh_Hans: 仅抓取这些标签
placeholder:
en_US: Use commas to separate multiple tags
zh_Hans: 多个标签时使用半角逗号分隔
human_description:
en_US: |
Only include tags, classes and ids from the page in the final output. Use comma separated values. Example: script, .ad, #footer
zh_Hans: |
仅在最终输出中包含HTML页面的这些标签可以通过标签名、类或ID来设定使用逗号分隔值。示例script, .ad, #footer
form: form
- name: onlyMainContent
type: boolean
default: false
label:
en_US: only Main Content
zh_Hans: 仅抓取主要内容
human_description:
en_US: Only return the main content of the page excluding headers, navs, footers, etc.
zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。
form: form
- name: removeTags
type: string
label:
en_US: remove Tags
zh_Hans: 要移除这些标签
human_description:
en_US: |
Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer
zh_Hans: |
要在最终输出中移除HTML页面的这些标签可以通过标签名、类或ID来设定使用逗号分隔值。示例script, .ad, #footer
placeholder:
en_US: Use commas to separate multiple tags
zh_Hans: 多个标签时使用半角逗号分隔
form: form
- name: replaceAllPathsWithAbsolutePaths
type: boolean
default: false
label:
en_US: All AbsolutePaths
zh_Hans: 使用绝对路径
human_description:
en_US: Replace all relative paths with absolute paths for images and links.
zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。
form: form
- name: screenshot
type: boolean
default: false
label:
en_US: screenshot
zh_Hans: 截图
human_description:
en_US: Include a screenshot of the top of the page that you are scraping.
zh_Hans: 提供正在抓取的页面的顶部的截图。
form: form
- name: waitFor
type: number
min: 0
default: 0
label:
en_US: wait For
zh_Hans: 等待时间
@ -124,57 +99,54 @@ parameters:
en_US: Wait x amount of milliseconds for the page to load to fetch content.
zh_Hans: 等待x毫秒以使页面加载并获取内容。
form: form
- name: timeout
type: number
min: 0
default: 30000
label:
en_US: Timeout
human_description:
en_US: Timeout in milliseconds for the request.
zh_Hans: 请求的超时时间(以毫秒为单位)。
form: form
############## Extractor Options #######################
- name: mode
type: select
options:
- value: markdown
label:
en_US: markdown
- value: llm-extraction
label:
en_US: llm-extraction
- value: llm-extraction-from-raw-html
label:
en_US: llm-extraction-from-raw-html
- value: llm-extraction-from-markdown
label:
en_US: llm-extraction-from-markdown
label:
en_US: Extractor Mode
zh_Hans: 提取模式
human_description:
en_US: |
The extraction mode to use. 'markdown': Returns the scraped markdown content, does not perform LLM extraction. 'llm-extraction': Extracts information from the cleaned and parsed content using LLM.
zh_Hans: 使用的提取模式。“markdown”返回抓取的markdown内容不执行LLM提取。“llm-extractioin”使用LLM按Extractor Schema从内容中提取信息。
form: form
- name: extractionPrompt
type: string
label:
en_US: Extractor Prompt
zh_Hans: 提取时的提示词
human_description:
en_US: A prompt describing what information to extract from the page, applicable for LLM extraction modes.
zh_Hans: 当使用LLM提取模式时用于给LLM描述提取规则。
form: form
- name: extractionSchema
- name: schema
type: string
label:
en_US: Extractor Schema
zh_Hans: 提取时的结构
placeholder:
en_US: Please enter an object that can be serialized in JSON
zh_Hans: 请输入可以json序列化的对象
human_description:
en_US: |
The schema for the data to be extracted, required only for LLM extraction modes. Example: {
The schema for the data to be extracted. Example: {
"type": "object",
"properties": {"company_mission": {"type": "string"}},
"required": ["company_mission"]
}
zh_Hans: |
当使用LLM提取模式时使用该结构去提取,示例:{
使用该结构去提取,示例:{
"type": "object",
"properties": {"company_mission": {"type": "string"}},
"required": ["company_mission"]
}
form: form
- name: systemPrompt
type: string
label:
en_US: Extractor System Prompt
zh_Hans: 提取时的系统提示词
human_description:
en_US: The system prompt to use for the extraction.
zh_Hans: 用于提取的系统提示。
form: form
- name: prompt
type: string
label:
en_US: Extractor Prompt
zh_Hans: 提取时的提示词
human_description:
en_US: The prompt to use for the extraction without a schema.
zh_Hans: 用于无schema时提取的提示词
form: form

View File

@ -1,28 +0,0 @@
from typing import Any
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp
from core.tools.tool.builtin_tool import BuiltinTool
class SearchTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
"""
the pageOptions and searchOptions comes from doc here:
https://docs.firecrawl.dev/api-reference/endpoint/search
"""
app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'],
base_url=self.runtime.credentials['base_url'])
pageOptions = {}
pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False)
pageOptions['fetchPageContent'] = tool_parameters.get('fetchPageContent', True)
pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False)
pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False)
searchOptions = {'limit': tool_parameters.get('limit')}
search_result = app.search(
query=tool_parameters['keyword'],
pageOptions=pageOptions,
searchOptions=searchOptions
)
return self.create_json_message(search_result)

View File

@ -1,75 +0,0 @@
identity:
name: search
author: ahasasjeb
label:
en_US: Search
zh_Hans: 搜索
description:
human:
en_US: Search, and output in Markdown format
zh_Hans: 搜索并且以Markdown格式输出
llm: This tool can perform online searches and convert the results to Markdown format.
parameters:
- name: keyword
type: string
required: true
label:
en_US: keyword
zh_Hans: 关键词
human_description:
en_US: Input keywords to use Firecrawl API for search.
zh_Hans: 输入关键词即可使用Firecrawl API进行搜索。
llm_description: Efficiently extract keywords from user text.
form: llm
############## Page Options #######################
- name: onlyMainContent
type: boolean
default: false
label:
en_US: only Main Content
zh_Hans: 仅抓取主要内容
human_description:
en_US: Only return the main content of the page excluding headers, navs, footers, etc.
zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。
form: form
- name: fetchPageContent
type: boolean
default: true
label:
en_US: fetch Page Content
zh_Hans: 抓取页面内容
human_description:
en_US: Fetch the content of each page. If false, defaults to a basic fast serp API.
zh_Hans: 获取每个页面的内容。如果为否则使用基本的快速搜索结果页面API。
form: form
- name: includeHtml
type: boolean
default: false
label:
en_US: include Html
zh_Hans: 包含HTML
human_description:
en_US: Include the HTML version of the content on page. Will output a html key in the response.
zh_Hans: 返回中包含一个HTML版本的内容将以html键返回。
form: form
- name: includeRawHtml
type: boolean
default: false
label:
en_US: include Raw Html
zh_Hans: 包含原始HTML
human_description:
en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response.
zh_Hans: 返回中包含一个原始HTML版本的内容将以rawHtml键返回。
form: form
############## Search Options #######################
- name: limit
type: number
min: 0
label:
en_US: Maximum results
zh_Hans: 最大结果数量
human_description:
en_US: Maximum number of results. Max is 20 during beta.
zh_Hans: 最大结果数量。在测试阶段最大为20。
form: form

View File

@ -9,17 +9,19 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class GaodeProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
if 'api_key' not in credentials or not credentials.get('api_key'):
if "api_key" not in credentials or not credentials.get("api_key"):
raise ToolProviderCredentialValidationError("Gaode API key is required.")
try:
response = requests.get(url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}"
"".format(address=urllib.parse.quote('广东省广州市天河区广州塔'),
apikey=credentials.get('api_key')))
if response.status_code == 200 and (response.json()).get('info') == 'OK':
response = requests.get(
url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}".format(
address=urllib.parse.quote("广东省广州市天河区广州塔"), apikey=credentials.get("api_key")
)
)
if response.status_code == 200 and (response.json()).get("info") == "OK":
pass
else:
raise ToolProviderCredentialValidationError((response.json()).get('info'))
raise ToolProviderCredentialValidationError((response.json()).get("info"))
except Exception as e:
raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e))
except Exception as e:

View File

@ -8,50 +8,57 @@ from core.tools.tool.builtin_tool import BuiltinTool
class GaodeRepositoriesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
city = tool_parameters.get('city', '')
city = tool_parameters.get("city", "")
if not city:
return self.create_text_message('Please tell me your city')
return self.create_text_message("Please tell me your city")
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
return self.create_text_message("Gaode API key is required.")
try:
s = requests.session()
api_domain = 'https://restapi.amap.com/v3'
city_response = s.request(method='GET', headers={"Content-Type": "application/json; charset=utf-8"},
url="{url}/config/district?keywords={keywords}"
"&subdistrict=0&extensions=base&key={apikey}"
"".format(url=api_domain, keywords=city,
apikey=self.runtime.credentials.get('api_key')))
api_domain = "https://restapi.amap.com/v3"
city_response = s.request(
method="GET",
headers={"Content-Type": "application/json; charset=utf-8"},
url="{url}/config/district?keywords={keywords}&subdistrict=0&extensions=base&key={apikey}".format(
url=api_domain, keywords=city, apikey=self.runtime.credentials.get("api_key")
),
)
City_data = city_response.json()
if city_response.status_code == 200 and City_data.get('info') == 'OK':
if len(City_data.get('districts')) > 0:
CityCode = City_data['districts'][0]['adcode']
weatherInfo_response = s.request(method='GET',
url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json"
"".format(url=api_domain, citycode=CityCode,
apikey=self.runtime.credentials.get('api_key')))
if city_response.status_code == 200 and City_data.get("info") == "OK":
if len(City_data.get("districts")) > 0:
CityCode = City_data["districts"][0]["adcode"]
weatherInfo_response = s.request(
method="GET",
url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json"
"".format(url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")),
)
weatherInfo_data = weatherInfo_response.json()
if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK':
if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK":
contents = []
if len(weatherInfo_data.get('forecasts')) > 0:
for item in weatherInfo_data['forecasts'][0]['casts']:
if len(weatherInfo_data.get("forecasts")) > 0:
for item in weatherInfo_data["forecasts"][0]["casts"]:
content = {}
content['date'] = item.get('date')
content['week'] = item.get('week')
content['dayweather'] = item.get('dayweather')
content['daytemp_float'] = item.get('daytemp_float')
content['daywind'] = item.get('daywind')
content['nightweather'] = item.get('nightweather')
content['nighttemp_float'] = item.get('nighttemp_float')
content["date"] = item.get("date")
content["week"] = item.get("week")
content["dayweather"] = item.get("dayweather")
content["daytemp_float"] = item.get("daytemp_float")
content["daywind"] = item.get("daywind")
content["nightweather"] = item.get("nightweather")
content["nighttemp_float"] = item.get("nighttemp_float")
contents.append(content)
s.close()
return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)))
return self.create_text_message(
self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))
)
s.close()
return self.create_text_message(f'No weather information for {city} was found.')
return self.create_text_message(f"No weather information for {city} was found.")
except Exception as e:
return self.create_text_message("Gaode API Key and Api Version is invalid. {}".format(e))

View File

@ -7,16 +7,13 @@ class GetImgAIProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
# Example validation using the text2image tool
Text2ImageTool().fork_tool_runtime(
runtime={"credentials": credentials}
).invoke(
user_id='',
Text2ImageTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke(
user_id="",
tool_parameters={
"prompt": "A fire egg",
"response_format": "url",
"style": "photorealism",
}
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -8,18 +8,16 @@ from requests.exceptions import HTTPError
logger = logging.getLogger(__name__)
class GetImgAIApp:
def __init__(self, api_key: str | None = None, base_url: str | None = None):
self.api_key = api_key
self.base_url = base_url or 'https://api.getimg.ai/v1'
self.base_url = base_url or "https://api.getimg.ai/v1"
if not self.api_key:
raise ValueError("API key is required")
def _prepare_headers(self):
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
return headers
def _request(
@ -38,22 +36,20 @@ class GetImgAIApp:
return response.json()
except requests.exceptions.RequestException as e:
if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500:
time.sleep(backoff_factor * (2 ** i))
time.sleep(backoff_factor * (2**i))
else:
raise
return None
def text2image(
self, mode: str, **kwargs
):
data = kwargs['params']
if not data.get('prompt'):
def text2image(self, mode: str, **kwargs):
data = kwargs["params"]
if not data.get("prompt"):
raise ValueError("Prompt is required")
endpoint = f'{self.base_url}/{mode}/text-to-image'
endpoint = f"{self.base_url}/{mode}/text-to-image"
headers = self._prepare_headers()
logger.debug(f"Send request to {endpoint=} body={data}")
response = self._request('POST', endpoint, data, headers)
response = self._request("POST", endpoint, data, headers)
if response is None:
raise HTTPError("Failed to initiate getimg.ai after multiple retries")
return response

View File

@ -7,28 +7,28 @@ from core.tools.tool.builtin_tool import BuiltinTool
class Text2ImageTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
app = GetImgAIApp(api_key=self.runtime.credentials['getimg_api_key'], base_url=self.runtime.credentials['base_url'])
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
app = GetImgAIApp(
api_key=self.runtime.credentials["getimg_api_key"], base_url=self.runtime.credentials["base_url"]
)
options = {
'style': tool_parameters.get('style'),
'prompt': tool_parameters.get('prompt'),
'aspect_ratio': tool_parameters.get('aspect_ratio'),
'output_format': tool_parameters.get('output_format', 'jpeg'),
'response_format': tool_parameters.get('response_format', 'url'),
'width': tool_parameters.get('width'),
'height': tool_parameters.get('height'),
'steps': tool_parameters.get('steps'),
'negative_prompt': tool_parameters.get('negative_prompt'),
'prompt_2': tool_parameters.get('prompt_2'),
"style": tool_parameters.get("style"),
"prompt": tool_parameters.get("prompt"),
"aspect_ratio": tool_parameters.get("aspect_ratio"),
"output_format": tool_parameters.get("output_format", "jpeg"),
"response_format": tool_parameters.get("response_format", "url"),
"width": tool_parameters.get("width"),
"height": tool_parameters.get("height"),
"steps": tool_parameters.get("steps"),
"negative_prompt": tool_parameters.get("negative_prompt"),
"prompt_2": tool_parameters.get("prompt_2"),
}
options = {k: v for k, v in options.items() if v}
text2image_result = app.text2image(
mode=tool_parameters.get('mode', 'essential-v2'),
params=options,
wait=True
)
text2image_result = app.text2image(mode=tool_parameters.get("mode", "essential-v2"), params=options, wait=True)
if not isinstance(text2image_result, str):
text2image_result = json.dumps(text2image_result, ensure_ascii=False, indent=4)

View File

@ -7,25 +7,25 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class GithubProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
if 'access_tokens' not in credentials or not credentials.get('access_tokens'):
if "access_tokens" not in credentials or not credentials.get("access_tokens"):
raise ToolProviderCredentialValidationError("Github API Access Tokens is required.")
if 'api_version' not in credentials or not credentials.get('api_version'):
api_version = '2022-11-28'
if "api_version" not in credentials or not credentials.get("api_version"):
api_version = "2022-11-28"
else:
api_version = credentials.get('api_version')
api_version = credentials.get("api_version")
try:
headers = {
"Content-Type": "application/vnd.github+json",
"Authorization": f"Bearer {credentials.get('access_tokens')}",
"X-GitHub-Api-Version": api_version
"X-GitHub-Api-Version": api_version,
}
response = requests.get(
url="https://api.github.com/search/users?q={account}".format(account='charli117'),
headers=headers)
url="https://api.github.com/search/users?q={account}".format(account="charli117"), headers=headers
)
if response.status_code != 200:
raise ToolProviderCredentialValidationError((response.json()).get('message'))
raise ToolProviderCredentialValidationError((response.json()).get("message"))
except Exception as e:
raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. {}".format(e))
except Exception as e:

View File

@ -10,53 +10,61 @@ from core.tools.tool.builtin_tool import BuiltinTool
class GithubRepositoriesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
top_n = tool_parameters.get('top_n', 5)
query = tool_parameters.get('query', '')
top_n = tool_parameters.get("top_n", 5)
query = tool_parameters.get("query", "")
if not query:
return self.create_text_message('Please input symbol')
return self.create_text_message("Please input symbol")
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"):
return self.create_text_message("Github API Access Tokens is required.")
if 'api_version' not in self.runtime.credentials or not self.runtime.credentials.get('api_version'):
api_version = '2022-11-28'
if "api_version" not in self.runtime.credentials or not self.runtime.credentials.get("api_version"):
api_version = "2022-11-28"
else:
api_version = self.runtime.credentials.get('api_version')
api_version = self.runtime.credentials.get("api_version")
try:
headers = {
"Content-Type": "application/vnd.github+json",
"Authorization": f"Bearer {self.runtime.credentials.get('access_tokens')}",
"X-GitHub-Api-Version": api_version
"X-GitHub-Api-Version": api_version,
}
s = requests.session()
api_domain = 'https://api.github.com'
response = s.request(method='GET', headers=headers,
url=f"{api_domain}/search/repositories?"
f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc")
api_domain = "https://api.github.com"
response = s.request(
method="GET",
headers=headers,
url=f"{api_domain}/search/repositories?q={quote(query)}&sort=stars&per_page={top_n}&order=desc",
)
response_data = response.json()
if response.status_code == 200 and isinstance(response_data.get('items'), list):
if response.status_code == 200 and isinstance(response_data.get("items"), list):
contents = []
if len(response_data.get('items')) > 0:
for item in response_data.get('items'):
if len(response_data.get("items")) > 0:
for item in response_data.get("items"):
content = {}
updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ")
content['owner'] = item['owner']['login']
content['name'] = item['name']
content['description'] = item['description'][:100] + '...' if len(item['description']) > 100 else item['description']
content['url'] = item['html_url']
content['star'] = item['watchers']
content['forks'] = item['forks']
content['updated'] = updated_at_object.strftime("%Y-%m-%d")
updated_at_object = datetime.strptime(item["updated_at"], "%Y-%m-%dT%H:%M:%SZ")
content["owner"] = item["owner"]["login"]
content["name"] = item["name"]
content["description"] = (
item["description"][:100] + "..." if len(item["description"]) > 100 else item["description"]
)
content["url"] = item["html_url"]
content["star"] = item["watchers"]
content["forks"] = item["forks"]
content["updated"] = updated_at_object.strftime("%Y-%m-%d")
contents.append(content)
s.close()
return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)))
return self.create_text_message(
self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))
)
else:
return self.create_text_message(f'No items related to {query} were found.')
return self.create_text_message(f"No items related to {query} were found.")
else:
return self.create_text_message((response.json()).get('message'))
return self.create_text_message((response.json()).get("message"))
except Exception as e:
return self.create_text_message("Github API Key and Api Version is invalid. {}".format(e))

View File

@ -9,13 +9,13 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class GitlabProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
if 'access_tokens' not in credentials or not credentials.get('access_tokens'):
if "access_tokens" not in credentials or not credentials.get("access_tokens"):
raise ToolProviderCredentialValidationError("Gitlab Access Tokens is required.")
if 'site_url' not in credentials or not credentials.get('site_url'):
site_url = 'https://gitlab.com'
if "site_url" not in credentials or not credentials.get("site_url"):
site_url = "https://gitlab.com"
else:
site_url = credentials.get('site_url')
site_url = credentials.get("site_url")
try:
headers = {
@ -23,12 +23,10 @@ class GitlabProvider(BuiltinToolProviderController):
"Authorization": f"Bearer {credentials.get('access_tokens')}",
}
response = requests.get(
url= f"{site_url}/api/v4/user",
headers=headers)
response = requests.get(url=f"{site_url}/api/v4/user", headers=headers)
if response.status_code != 200:
raise ToolProviderCredentialValidationError((response.json()).get('message'))
raise ToolProviderCredentialValidationError((response.json()).get("message"))
except Exception as e:
raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. {}".format(e))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
raise ToolProviderCredentialValidationError(str(e))

View File

@ -1,4 +1,5 @@
import json
import urllib.parse
from datetime import datetime, timedelta
from typing import Any, Union
@ -9,103 +10,133 @@ from core.tools.tool.builtin_tool import BuiltinTool
class GitlabCommitsTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
project = tool_parameters.get("project", "")
repository = tool_parameters.get("repository", "")
employee = tool_parameters.get("employee", "")
start_time = tool_parameters.get("start_time", "")
end_time = tool_parameters.get("end_time", "")
change_type = tool_parameters.get("change_type", "all")
project = tool_parameters.get('project', '')
employee = tool_parameters.get('employee', '')
start_time = tool_parameters.get('start_time', '')
end_time = tool_parameters.get('end_time', '')
change_type = tool_parameters.get('change_type', 'all')
if not project:
return self.create_text_message('Project is required')
if not project and not repository:
return self.create_text_message("Either project or repository is required")
if not start_time:
start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
if not end_time:
end_time = datetime.utcnow().isoformat()
access_token = self.runtime.credentials.get('access_tokens')
site_url = self.runtime.credentials.get('site_url')
access_token = self.runtime.credentials.get("access_tokens")
site_url = self.runtime.credentials.get("site_url")
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"):
return self.create_text_message("Gitlab API Access Tokens is required.")
if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
site_url = 'https://gitlab.com'
if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
site_url = "https://gitlab.com"
# Get commit content
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type)
if repository:
result = self.fetch_commits(
site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True
)
else:
result = self.fetch_commits(
site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False
)
return [self.create_json_message(item) for item in result]
def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]:
def fetch_commits(
self,
site_url: str,
access_token: str,
identifier: str,
employee: str,
start_time: str,
end_time: str,
change_type: str,
is_repository: bool,
) -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
try:
# Get all of projects
url = f"{domain}/api/v4/projects"
response = requests.get(url, headers=headers)
response.raise_for_status()
projects = response.json()
if is_repository:
# URL encode the repository path
encoded_identifier = urllib.parse.quote(identifier, safe="")
commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits"
else:
# Get all projects
url = f"{domain}/api/v4/projects"
response = requests.get(url, headers=headers)
response.raise_for_status()
projects = response.json()
filtered_projects = [p for p in projects if project == "*" or p['name'] == project]
filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier]
for project in filtered_projects:
project_id = project['id']
project_name = project['name']
print(f"Project: {project_name}")
for project in filtered_projects:
project_id = project["id"]
project_name = project["name"]
print(f"Project: {project_name}")
# Get all of project commits
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
params = {
'since': start_time,
'until': end_time
}
if employee:
params['author'] = employee
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
commits_response = requests.get(commits_url, headers=headers, params=params)
commits_response.raise_for_status()
commits = commits_response.json()
params = {"since": start_time, "until": end_time}
if employee:
params["author"] = employee
for commit in commits:
commit_sha = commit['id']
author_name = commit['author_name']
commits_response = requests.get(commits_url, headers=headers, params=params)
commits_response.raise_for_status()
commits = commits_response.json()
for commit in commits:
commit_sha = commit["id"]
author_name = commit["author_name"]
if is_repository:
diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff"
else:
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
diff_response = requests.get(diff_url, headers=headers)
diff_response.raise_for_status()
diffs = diff_response.json()
for diff in diffs:
# Calculate code lines of changed
added_lines = diff['diff'].count('\n+')
removed_lines = diff['diff'].count('\n-')
total_changes = added_lines + removed_lines
if change_type == "new":
if added_lines > 1:
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
results.append({
"commit_sha": commit_sha,
"author_name": author_name,
"diff": final_code
})
else:
if total_changes > 1:
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')])
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
results.append({
"commit_sha": commit_sha,
"author_name": author_name,
"diff": final_code_escaped
})
diff_response = requests.get(diff_url, headers=headers)
diff_response.raise_for_status()
diffs = diff_response.json()
for diff in diffs:
# Calculate code lines of changes
added_lines = diff["diff"].count("\n+")
removed_lines = diff["diff"].count("\n-")
total_changes = added_lines + removed_lines
if change_type == "new":
if added_lines > 1:
final_code = "".join(
[
line[1:]
for line in diff["diff"].split("\n")
if line.startswith("+") and not line.startswith("+++")
]
)
results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code})
else:
if total_changes > 1:
final_code = "".join(
[
line[1:]
for line in diff["diff"].split("\n")
if (line.startswith("+") or line.startswith("-"))
and not line.startswith("+++")
and not line.startswith("---")
]
)
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
results.append(
{"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
)
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")
return results
return results

View File

@ -21,9 +21,20 @@ parameters:
zh_Hans: 员工用户名
llm_description: User name for GitLab
form: llm
- name: repository
type: string
required: false
label:
en_US: repository
zh_Hans: 仓库路径
human_description:
en_US: repository
zh_Hans: 仓库路径以namespace/project_name的形式。
llm_description: Repository path for GitLab, like namespace/project_name.
form: llm
- name: project
type: string
required: true
required: false
label:
en_US: project
zh_Hans: 项目名

View File

@ -1,3 +1,4 @@
import urllib.parse
from typing import Any, Union
import requests
@ -7,47 +8,85 @@ from core.tools.tool.builtin_tool import BuiltinTool
class GitlabFilesTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
project = tool_parameters.get('project', '')
branch = tool_parameters.get('branch', '')
path = tool_parameters.get('path', '')
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
project = tool_parameters.get("project", "")
repository = tool_parameters.get("repository", "")
branch = tool_parameters.get("branch", "")
path = tool_parameters.get("path", "")
if not project:
return self.create_text_message('Project is required')
if not project and not repository:
return self.create_text_message("Either project or repository is required")
if not branch:
return self.create_text_message('Branch is required')
return self.create_text_message("Branch is required")
if not path:
return self.create_text_message('Path is required')
return self.create_text_message("Path is required")
access_token = self.runtime.credentials.get('access_tokens')
site_url = self.runtime.credentials.get('site_url')
access_token = self.runtime.credentials.get("access_tokens")
site_url = self.runtime.credentials.get("site_url")
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"):
return self.create_text_message("Gitlab API Access Tokens is required.")
if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
site_url = 'https://gitlab.com'
# Get project ID from project name
project_id = self.get_project_id(site_url, access_token, project)
if not project_id:
return self.create_text_message(f"Project '{project}' not found.")
if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
site_url = "https://gitlab.com"
# Get commit content
result = self.fetch(user_id, project_id, site_url, access_token, branch, path)
# Get file content
if repository:
result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True)
else:
result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False)
return [self.create_json_message(item) for item in result]
def extract_project_name_and_path(self, path: str) -> tuple[str, str]:
parts = path.split('/', 1)
if len(parts) < 2:
return None, None
return parts[0], parts[1]
def fetch_files(
self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool
) -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
try:
if is_repository:
# URL encode the repository path
encoded_identifier = urllib.parse.quote(identifier, safe="")
tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}"
else:
# Get project ID from project name
project_id = self.get_project_id(site_url, access_token, identifier)
if not project_id:
return self.create_text_message(f"Project '{identifier}' not found.")
tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
response = requests.get(tree_url, headers=headers)
response.raise_for_status()
items = response.json()
for item in items:
item_path = item["path"]
if item["type"] == "tree": # It's a directory
results.extend(
self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
)
else: # It's a file
if is_repository:
file_url = (
f"{domain}/api/v4/projects/{encoded_identifier}/repository/files"
f"/{item_path}/raw?ref={branch}"
)
else:
file_url = (
f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
)
file_response = requests.get(file_url, headers=headers)
file_response.raise_for_status()
file_content = file_response.text
results.append({"path": item_path, "branch": branch, "content": file_content})
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")
return results
def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]:
headers = {"PRIVATE-TOKEN": access_token}
@ -57,39 +96,8 @@ class GitlabFilesTool(BuiltinTool):
response.raise_for_status()
projects = response.json()
for project in projects:
if project['name'] == project_name:
return project['id']
if project["name"] == project_name:
return project["id"]
except requests.RequestException as e:
print(f"Error fetching project ID from GitLab: {e}")
return None
def fetch(self,user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = []
try:
# List files and directories in the given path
url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
response = requests.get(url, headers=headers)
response.raise_for_status()
items = response.json()
for item in items:
item_path = item['path']
if item['type'] == 'tree': # It's a directory
results.extend(self.fetch(project_id, site_url, access_token, branch, item_path))
else: # It's a file
file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
file_response = requests.get(file_url, headers=headers)
file_response.raise_for_status()
file_content = file_response.text
results.append({
"path": item_path,
"branch": branch,
"content": file_content
})
except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}")
return results

View File

@ -10,9 +10,20 @@ description:
zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。
llm: A tool for query GitLab files, Input should be a exists file or directory path.
parameters:
- name: repository
type: string
required: false
label:
en_US: repository
zh_Hans: 仓库路径
human_description:
en_US: repository
zh_Hans: 仓库路径以namespace/project_name的形式。
llm_description: Repository path for GitLab, like namespace/project_name.
form: llm
- name: project
type: string
required: true
required: false
label:
en_US: project
zh_Hans: 项目

View File

@ -13,12 +13,8 @@ class GoogleProvider(BuiltinToolProviderController):
"credentials": credentials,
}
).invoke(
user_id='',
tool_parameters={
"query": "test",
"result_type": "link"
},
user_id="",
tool_parameters={"query": "test", "result_type": "link"},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -9,7 +9,6 @@ SERP_API_URL = "https://serpapi.com/search"
class GoogleSearchTool(BuiltinTool):
def _parse_response(self, response: dict) -> dict:
result = {}
if "knowledge_graph" in response:
@ -17,25 +16,23 @@ class GoogleSearchTool(BuiltinTool):
result["description"] = response["knowledge_graph"].get("description", "")
if "organic_results" in response:
result["organic_results"] = [
{
"title": item.get("title", ""),
"link": item.get("link", ""),
"snippet": item.get("snippet", "")
}
{"title": item.get("title", ""), "link": item.get("link", ""), "snippet": item.get("snippet", "")}
for item in response["organic_results"]
]
return result
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
params = {
"api_key": self.runtime.credentials['serpapi_api_key'],
"q": tool_parameters['query'],
"api_key": self.runtime.credentials["serpapi_api_key"],
"q": tool_parameters["query"],
"engine": "google",
"google_domain": "google.com",
"gl": "us",
"hl": "en"
"hl": "en",
}
response = requests.get(url=SERP_API_URL, params=params)
response.raise_for_status()

View File

@ -8,10 +8,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
class JsonExtractProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
GoogleTranslate().invoke(user_id='',
tool_parameters={
"content": "这是一段测试文本",
"dest": "en"
})
GoogleTranslate().invoke(user_id="", tool_parameters={"content": "这是一段测试文本", "dest": "en"})
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@ -7,46 +7,41 @@ from core.tools.tool.builtin_tool import BuiltinTool
class GoogleTranslate(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
invoke tools
"""
content = tool_parameters.get('content', '')
content = tool_parameters.get("content", "")
if not content:
return self.create_text_message('Invalid parameter content')
return self.create_text_message("Invalid parameter content")
dest = tool_parameters.get('dest', '')
dest = tool_parameters.get("dest", "")
if not dest:
return self.create_text_message('Invalid parameter destination language')
return self.create_text_message("Invalid parameter destination language")
try:
result = self._translate(content, dest)
return self.create_text_message(str(result))
except Exception:
return self.create_text_message('Translation service error, please check the network')
return self.create_text_message("Translation service error, please check the network")
def _translate(self, content: str, dest: str) -> str:
try:
url = "https://translate.googleapis.com/translate_a/single"
params = {
"client": "gtx",
"sl": "auto",
"tl": dest,
"dt": "t",
"q": content
}
params = {"client": "gtx", "sl": "auto", "tl": dest, "dt": "t", "q": content}
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
" Chrome/91.0.4472.124 Safari/537.36"
}
response_json = requests.get(
url, params=params, headers=headers).json()
response_json = requests.get(url, params=params, headers=headers).json()
result = response_json[0]
translated_text = ''.join([item[0] for item in result if item[0]])
translated_text = "".join([item[0] for item in result if item[0]])
return str(translated_text)
except Exception as e:
return str(e)

Some files were not shown because too many files have changed in this diff Show More