mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
Merge main
This commit is contained in:
@ -11,23 +11,23 @@ from core.tools.tool.tool import ToolParameter
|
||||
|
||||
class UserTool(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
UserToolProviderTypeLiteral = Optional[Literal[
|
||||
'builtin', 'api', 'workflow'
|
||||
]]
|
||||
|
||||
UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
|
||||
|
||||
class UserToolProvider(BaseModel):
|
||||
id: str
|
||||
author: str
|
||||
name: str # identifier
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str | dict
|
||||
label: I18nObject # label
|
||||
label: I18nObject # label
|
||||
type: ToolProviderType
|
||||
masked_credentials: Optional[dict] = None
|
||||
original_credentials: Optional[dict] = None
|
||||
@ -41,26 +41,27 @@ class UserToolProvider(BaseModel):
|
||||
# overwrite tool parameter types for temp fix
|
||||
tools = jsonable_encoder(self.tools)
|
||||
for tool in tools:
|
||||
if tool.get('parameters'):
|
||||
for parameter in tool.get('parameters'):
|
||||
if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value:
|
||||
parameter['type'] = 'files'
|
||||
if tool.get("parameters"):
|
||||
for parameter in tool.get("parameters"):
|
||||
if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value:
|
||||
parameter["type"] = "files"
|
||||
# -------------
|
||||
|
||||
return {
|
||||
'id': self.id,
|
||||
'author': self.author,
|
||||
'name': self.name,
|
||||
'description': self.description.to_dict(),
|
||||
'icon': self.icon,
|
||||
'label': self.label.to_dict(),
|
||||
'type': self.type.value,
|
||||
'team_credentials': self.masked_credentials,
|
||||
'is_team_authorization': self.is_team_authorization,
|
||||
'allow_delete': self.allow_delete,
|
||||
'tools': tools,
|
||||
'labels': self.labels,
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
"name": self.name,
|
||||
"description": self.description.to_dict(),
|
||||
"icon": self.icon,
|
||||
"label": self.label.to_dict(),
|
||||
"type": self.type.value,
|
||||
"team_credentials": self.masked_credentials,
|
||||
"is_team_authorization": self.is_team_authorization,
|
||||
"allow_delete": self.allow_delete,
|
||||
"tools": tools,
|
||||
"labels": self.labels,
|
||||
}
|
||||
|
||||
|
||||
class UserToolProviderCredentials(BaseModel):
|
||||
credentials: dict[str, ProviderConfig]
|
||||
credentials: dict[str, ProviderConfig]
|
||||
|
||||
@ -7,6 +7,7 @@ class I18nObject(BaseModel):
|
||||
"""
|
||||
Model class for i18n object.
|
||||
"""
|
||||
|
||||
zh_Hans: Optional[str] = None
|
||||
pt_BR: Optional[str] = None
|
||||
en_US: str
|
||||
@ -19,8 +20,4 @@ class I18nObject(BaseModel):
|
||||
self.pt_BR = self.en_US
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'zh_Hans': self.zh_Hans,
|
||||
'en_US': self.en_US,
|
||||
'pt_BR': self.pt_BR
|
||||
}
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR}
|
||||
|
||||
@ -7,8 +7,10 @@ from core.tools.entities.tool_entities import ToolParameter
|
||||
|
||||
class ApiToolBundle(BaseModel):
|
||||
"""
|
||||
This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc.
|
||||
This class is used to store the schema information of an api based tool.
|
||||
such as the url, the method, the parameters, etc.
|
||||
"""
|
||||
|
||||
# server_url
|
||||
server_url: str
|
||||
# method
|
||||
|
||||
@ -9,27 +9,29 @@ from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class ToolLabelEnum(Enum):
|
||||
SEARCH = 'search'
|
||||
IMAGE = 'image'
|
||||
VIDEOS = 'videos'
|
||||
WEATHER = 'weather'
|
||||
FINANCE = 'finance'
|
||||
DESIGN = 'design'
|
||||
TRAVEL = 'travel'
|
||||
SOCIAL = 'social'
|
||||
NEWS = 'news'
|
||||
MEDICAL = 'medical'
|
||||
PRODUCTIVITY = 'productivity'
|
||||
EDUCATION = 'education'
|
||||
BUSINESS = 'business'
|
||||
ENTERTAINMENT = 'entertainment'
|
||||
UTILITIES = 'utilities'
|
||||
OTHER = 'other'
|
||||
SEARCH = "search"
|
||||
IMAGE = "image"
|
||||
VIDEOS = "videos"
|
||||
WEATHER = "weather"
|
||||
FINANCE = "finance"
|
||||
DESIGN = "design"
|
||||
TRAVEL = "travel"
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
MEDICAL = "medical"
|
||||
PRODUCTIVITY = "productivity"
|
||||
EDUCATION = "education"
|
||||
BUSINESS = "business"
|
||||
ENTERTAINMENT = "entertainment"
|
||||
UTILITIES = "utilities"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class ToolProviderType(str, Enum):
|
||||
"""
|
||||
Enum class for tool provider
|
||||
Enum class for tool provider
|
||||
"""
|
||||
|
||||
BUILT_IN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
API = "api"
|
||||
@ -37,7 +39,7 @@ class ToolProviderType(str, Enum):
|
||||
DATASET_RETRIEVAL = "dataset-retrieval"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'ToolProviderType':
|
||||
def value_of(cls, value: str) -> "ToolProviderType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@ -47,19 +49,21 @@ class ToolProviderType(str, Enum):
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderSchemaType(Enum):
|
||||
"""
|
||||
Enum class for api provider schema type.
|
||||
"""
|
||||
|
||||
OPENAPI = "openapi"
|
||||
SWAGGER = "swagger"
|
||||
OPENAI_PLUGIN = "openai_plugin"
|
||||
OPENAI_ACTIONS = "openai_actions"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'ApiProviderSchemaType':
|
||||
def value_of(cls, value: str) -> "ApiProviderSchemaType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@ -69,17 +73,19 @@ class ApiProviderSchemaType(Enum):
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderAuthType(Enum):
|
||||
"""
|
||||
Enum class for api provider auth type.
|
||||
"""
|
||||
|
||||
NONE = "none"
|
||||
API_KEY = "api_key"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'ApiProviderAuthType':
|
||||
def value_of(cls, value: str) -> "ApiProviderAuthType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@ -89,7 +95,8 @@ class ApiProviderAuthType(Enum):
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid mode value {value}')
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ToolInvokeMessage(BaseModel):
|
||||
class TextMessage(BaseModel):
|
||||
@ -107,7 +114,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
stream: bool = Field(default=False, description="Whether the variable is streamed")
|
||||
|
||||
@field_validator("variable_value", mode="before")
|
||||
def transform_variable_value(cls, value, values) -> Any:
|
||||
def transform_variable_value(self, value, values) -> Any:
|
||||
"""
|
||||
Only basic types and lists are allowed.
|
||||
"""
|
||||
@ -122,11 +129,11 @@ class ToolInvokeMessage(BaseModel):
|
||||
return value
|
||||
|
||||
@field_validator("variable_name", mode="before")
|
||||
def transform_variable_name(cls, value) -> str:
|
||||
def transform_variable_name(self, value) -> str:
|
||||
"""
|
||||
The variable name must be a string.
|
||||
"""
|
||||
if value in ["json", "text", "files"]:
|
||||
if value in {"json", "text", "files"}:
|
||||
raise ValueError(f"The variable name '{value}' is reserved.")
|
||||
return value
|
||||
|
||||
@ -146,7 +153,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
"""
|
||||
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | None
|
||||
meta: dict[str, Any] | None = None
|
||||
save_as: str = ''
|
||||
save_as: str = ""
|
||||
|
||||
@field_validator('message', mode='before')
|
||||
@classmethod
|
||||
@ -166,17 +173,19 @@ class ToolInvokeMessage(BaseModel):
|
||||
}
|
||||
return v
|
||||
|
||||
|
||||
class ToolInvokeMessageBinary(BaseModel):
|
||||
mimetype: str = Field(..., description="The mimetype of the binary")
|
||||
url: str = Field(..., description="The url of the binary")
|
||||
save_as: str = ''
|
||||
save_as: str = ""
|
||||
file_var: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class ToolParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
@field_validator('value', mode='before')
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def transform_id_to_str(cls, value) -> str:
|
||||
if not isinstance(value, str):
|
||||
@ -195,9 +204,9 @@ class ToolParameter(BaseModel):
|
||||
FILE = CommonParameterType.FILE.value
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
FORM = "form" # should be set before invoking tool
|
||||
LLM = "llm" # will be set by LLM
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
FORM = "form" # should be set before invoking tool
|
||||
LLM = "llm" # will be set by LLM
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
@ -214,21 +223,28 @@ class ToolParameter(BaseModel):
|
||||
options: list[ToolParameterOption] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(cls,
|
||||
name: str, llm_description: str, type: ToolParameterType,
|
||||
required: bool, options: Optional[list[str]] = None) -> 'ToolParameter':
|
||||
def get_simple_instance(
|
||||
cls,
|
||||
name: str,
|
||||
llm_description: str,
|
||||
type: ToolParameterType,
|
||||
required: bool,
|
||||
options: Optional[list[str]] = None,
|
||||
) -> "ToolParameter":
|
||||
"""
|
||||
get a simple tool parameter
|
||||
get a simple tool parameter
|
||||
|
||||
:param name: the name of the parameter
|
||||
:param llm_description: the description presented to the LLM
|
||||
:param type: the type of the parameter
|
||||
:param required: if the parameter is required
|
||||
:param options: the options of the parameter
|
||||
:param name: the name of the parameter
|
||||
:param llm_description: the description presented to the LLM
|
||||
:param type: the type of the parameter
|
||||
:param required: if the parameter is required
|
||||
:param options: the options of the parameter
|
||||
"""
|
||||
# convert options to ToolParameterOption
|
||||
if options:
|
||||
option_objs = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options]
|
||||
option_objs = [
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options
|
||||
]
|
||||
else:
|
||||
option_objs = []
|
||||
return cls(
|
||||
@ -243,18 +259,24 @@ class ToolParameter(BaseModel):
|
||||
options=option_objs,
|
||||
)
|
||||
|
||||
|
||||
class ToolProviderIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
description: I18nObject = Field(..., description="The description of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
tags: Optional[list[ToolLabelEnum]] = Field(default=[], description="The tags of the tool", )
|
||||
tags: Optional[list[ToolLabelEnum]] = Field(
|
||||
default=[],
|
||||
description="The tags of the tool",
|
||||
)
|
||||
|
||||
|
||||
class ToolDescription(BaseModel):
|
||||
human: I18nObject = Field(..., description="The description presented to the user")
|
||||
llm: str = Field(..., description="The description presented to the LLM")
|
||||
|
||||
|
||||
class ToolIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
@ -262,22 +284,27 @@ class ToolIdentity(BaseModel):
|
||||
provider: str = Field(..., description="The provider of the tool")
|
||||
icon: Optional[str] = None
|
||||
|
||||
|
||||
class ToolRuntimeVariableType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
|
||||
|
||||
class ToolRuntimeVariable(BaseModel):
|
||||
type: ToolRuntimeVariableType = Field(..., description="The type of the variable")
|
||||
name: str = Field(..., description="The name of the variable")
|
||||
position: int = Field(..., description="The position of the variable")
|
||||
tool_name: str = Field(..., description="The name of the tool")
|
||||
|
||||
|
||||
class ToolRuntimeTextVariable(ToolRuntimeVariable):
|
||||
value: str = Field(..., description="The value of the variable")
|
||||
|
||||
|
||||
class ToolRuntimeImageVariable(ToolRuntimeVariable):
|
||||
value: str = Field(..., description="The path of the image")
|
||||
|
||||
|
||||
class ToolRuntimeVariablePool(BaseModel):
|
||||
conversation_id: str = Field(..., description="The conversation id")
|
||||
user_id: str = Field(..., description="The user id")
|
||||
@ -286,26 +313,26 @@ class ToolRuntimeVariablePool(BaseModel):
|
||||
pool: list[ToolRuntimeVariable] = Field(..., description="The pool of variables")
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
pool = data.get('pool', [])
|
||||
pool = data.get("pool", [])
|
||||
# convert pool into correct type
|
||||
for index, variable in enumerate(pool):
|
||||
if variable['type'] == ToolRuntimeVariableType.TEXT.value:
|
||||
if variable["type"] == ToolRuntimeVariableType.TEXT.value:
|
||||
pool[index] = ToolRuntimeTextVariable(**variable)
|
||||
elif variable['type'] == ToolRuntimeVariableType.IMAGE.value:
|
||||
elif variable["type"] == ToolRuntimeVariableType.IMAGE.value:
|
||||
pool[index] = ToolRuntimeImageVariable(**variable)
|
||||
super().__init__(**data)
|
||||
|
||||
def dict(self) -> dict:
|
||||
return {
|
||||
'conversation_id': self.conversation_id,
|
||||
'user_id': self.user_id,
|
||||
'tenant_id': self.tenant_id,
|
||||
'pool': [variable.model_dump() for variable in self.pool],
|
||||
"conversation_id": self.conversation_id,
|
||||
"user_id": self.user_id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"pool": [variable.model_dump() for variable in self.pool],
|
||||
}
|
||||
|
||||
def set_text(self, tool_name: str, name: str, value: str) -> None:
|
||||
"""
|
||||
set a text variable
|
||||
set a text variable
|
||||
"""
|
||||
for variable in self.pool:
|
||||
if variable.name == name:
|
||||
@ -326,10 +353,10 @@ class ToolRuntimeVariablePool(BaseModel):
|
||||
|
||||
def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None:
|
||||
"""
|
||||
set an image variable
|
||||
set an image variable
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:param value: the id of the file
|
||||
:param tool_name: the name of the tool
|
||||
:param value: the id of the file
|
||||
"""
|
||||
# check how many image variables are there
|
||||
image_variable_count = 0
|
||||
@ -357,22 +384,27 @@ class ToolRuntimeVariablePool(BaseModel):
|
||||
|
||||
self.pool.append(variable)
|
||||
|
||||
|
||||
class ModelToolPropertyKey(Enum):
|
||||
IMAGE_PARAMETER_NAME = "image_parameter_name"
|
||||
|
||||
|
||||
class ModelToolConfiguration(BaseModel):
|
||||
"""
|
||||
Model tool configuration
|
||||
"""
|
||||
|
||||
type: str = Field(..., description="The type of the model tool")
|
||||
model: str = Field(..., description="The model")
|
||||
label: I18nObject = Field(..., description="The label of the model tool")
|
||||
properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool")
|
||||
|
||||
|
||||
class ModelToolProviderConfiguration(BaseModel):
|
||||
"""
|
||||
Model tool provider configuration
|
||||
"""
|
||||
|
||||
provider: str = Field(..., description="The provider of the model tool")
|
||||
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
|
||||
label: I18nObject = Field(..., description="The label of the model tool")
|
||||
@ -382,27 +414,30 @@ class WorkflowToolParameterConfiguration(BaseModel):
|
||||
"""
|
||||
Workflow tool configuration
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
description: str = Field(..., description="The description of the parameter")
|
||||
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
|
||||
|
||||
|
||||
class ToolInvokeMeta(BaseModel):
|
||||
"""
|
||||
Tool invoke meta
|
||||
"""
|
||||
|
||||
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||
error: Optional[str] = None
|
||||
tool_config: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> 'ToolInvokeMeta':
|
||||
def empty(cls) -> "ToolInvokeMeta":
|
||||
"""
|
||||
Get an empty instance of ToolInvokeMeta
|
||||
"""
|
||||
return cls(time_cost=0.0, error=None, tool_config={})
|
||||
|
||||
@classmethod
|
||||
def error_instance(cls, error: str) -> 'ToolInvokeMeta':
|
||||
def error_instance(cls, error: str) -> "ToolInvokeMeta":
|
||||
"""
|
||||
Get an instance of ToolInvokeMeta with error
|
||||
"""
|
||||
@ -410,22 +445,26 @@ class ToolInvokeMeta(BaseModel):
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'time_cost': self.time_cost,
|
||||
'error': self.error,
|
||||
'tool_config': self.tool_config,
|
||||
"time_cost": self.time_cost,
|
||||
"error": self.error,
|
||||
"tool_config": self.tool_config,
|
||||
}
|
||||
|
||||
|
||||
class ToolLabel(BaseModel):
|
||||
"""
|
||||
Tool label
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class ToolInvokeFrom(Enum):
|
||||
"""
|
||||
Enum class for tool invoke
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
AGENT = "agent"
|
||||
|
||||
@ -2,73 +2,109 @@ from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum
|
||||
|
||||
ICONS = {
|
||||
ToolLabelEnum.SEARCH: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
ToolLabelEnum.SEARCH: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M7.33398 1.3335C10.646 1.3335 13.334 4.0215 13.334 7.3335C13.334 10.6455 10.646 13.3335 7.33398 13.3335C4.02198 13.3335 1.33398 10.6455 1.33398 7.3335C1.33398 4.0215 4.02198 1.3335 7.33398 1.3335ZM7.33398 12.0002C9.91232 12.0002 12.0007 9.91183 12.0007 7.3335C12.0007 4.75516 9.91232 2.66683 7.33398 2.66683C4.75565 2.66683 2.66732 4.75516 2.66732 7.3335C2.66732 9.91183 4.75565 12.0002 7.33398 12.0002ZM12.9909 12.0476L14.8764 13.9332L13.9337 14.876L12.0481 12.9904L12.9909 12.0476Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.IMAGE: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.IMAGE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M13.0514 9.71752L10.4718 7.13792C10.2115 6.87752 9.78932 6.87752 9.52898 7.13792L4.57721 12.0897C3.4097 11.1113 2.66732 9.64232 2.66732 7.99992C2.66732 5.0544 5.05513 2.66659 8.00065 2.66659C10.9462 2.66659 13.334 5.0544 13.334 7.99992C13.334 8.60085 13.2346 9.17852 13.0514 9.71752ZM5.72683 12.8257L10.0004 8.55212L12.4259 10.9777C11.4668 12.4001 9.84152 13.3331 8.00038 13.3331C7.18632 13.3331 6.41628 13.1511 5.72683 12.8257ZM8.00065 14.6666C11.6825 14.6666 14.6673 11.6818 14.6673 7.99992C14.6673 4.31802 11.6825 1.33325 8.00065 1.33325C4.31875 1.33325 1.33398 4.31802 1.33398 7.99992C1.33398 11.6818 4.31875 14.6666 8.00065 14.6666ZM7.33398 6.66658C7.33398 7.40299 6.73705 7.99992 6.00065 7.99992C5.26427 7.99992 4.66732 7.40299 4.66732 6.66658C4.66732 5.9302 5.26427 5.33325 6.00065 5.33325C6.73705 5.33325 7.33398 5.9302 7.33398 6.66658Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.VIDEOS: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.VIDEOS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00065 13.3333H13.334V14.6666H8.00065C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992C14.6673 9.50072 14.1714 10.8857 13.3345 11.9999H11.5284C12.6356 11.0227 13.334 9.59285 13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333ZM8.00065 6.66658C7.26425 6.66658 6.66732 6.06963 6.66732 5.33325C6.66732 4.59687 7.26425 3.99992 8.00065 3.99992C8.73705 3.99992 9.33398 4.59687 9.33398 5.33325C9.33398 6.06963 8.73705 6.66658 8.00065 6.66658ZM5.33398 9.33325C4.5976 9.33325 4.00065 8.73632 4.00065 7.99992C4.00065 7.26352 4.5976 6.66658 5.33398 6.66658C6.07036 6.66658 6.66732 7.26352 6.66732 7.99992C6.66732 8.73632 6.07036 9.33325 5.33398 9.33325ZM10.6673 9.33325C9.93092 9.33325 9.33398 8.73632 9.33398 7.99992C9.33398 7.26352 9.93092 6.66658 10.6673 6.66658C11.4037 6.66658 12.0007 7.26352 12.0007 7.99992C12.0007 8.73632 11.4037 9.33325 10.6673 9.33325ZM8.00065 11.9999C7.26425 11.9999 6.66732 11.403 6.66732 10.6666C6.66732 9.93018 7.26425 9.33325 8.00065 9.33325C8.73705 9.33325 9.33398 9.93018 9.33398 10.6666C9.33398 11.403 8.73705 11.9999 8.00065 11.9999Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.WEATHER: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.WEATHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M6.6553 3.37344C7.42088 2.1484 8.78162 1.3335 10.3327 1.3335C12.7259 1.3335 14.666 3.2736 14.666 5.66683C14.666 6.38704 14.4903 7.06623 14.1794 7.66383C14.8894 8.3325 15.3327 9.28123 15.3327 10.3335C15.3327 12.3586 13.6911 14.0002 11.666 14.0002H5.99935C3.05383 14.0002 0.666016 11.6124 0.666016 8.66683C0.666016 5.72131 3.05383 3.3335 5.99935 3.3335C6.22143 3.3335 6.44034 3.34707 6.6553 3.37344ZM8.03628 3.73629C9.37768 4.29108 10.4435 5.37735 10.9711 6.73256C11.1961 6.68943 11.4284 6.66683 11.666 6.66683C12.1561 6.66683 12.6237 6.76296 13.0511 6.93743C13.2317 6.55162 13.3327 6.12102 13.3327 5.66683C13.3327 4.00998 11.9895 2.66683 10.3327 2.66683C9.41115 2.66683 8.58662 3.08236 8.03628 3.73629ZM11.666 12.6668C12.9547 12.6668 13.9993 11.6222 13.9993 10.3335C13.9993 9.04483 12.9547 8.00016 11.666 8.00016C11.013 8.00016 10.4227 8.26836 9.99922 8.70063C9.99928 8.68936 9.99935 8.6781 9.99935 8.66683C9.99935 6.45769 8.20848 4.66683 5.99935 4.66683C3.79021 4.66683 1.99935 6.45769 1.99935 8.66683C1.99935 10.876 3.79021 12.6668 5.99935 12.6668H11.666Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.FINANCE: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.FINANCE: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00262 14.6685C4.32071 14.6685 1.33594 11.6838 1.33594 8.00184C1.33594 4.31997 4.32071 1.33521 8.00262 1.33521C11.6845 1.33521 14.6693 4.31997 14.6693 8.00184C14.6693 11.6838 11.6845 14.6685 8.00262 14.6685ZM8.00262 13.3352C10.9482 13.3352 13.336 10.9474 13.336 8.00184C13.336 5.05635 10.9482 2.66854 8.00262 2.66854C5.05708 2.66854 2.66927 5.05635 2.66927 8.00184C2.66927 10.9474 5.05708 13.3352 8.00262 13.3352ZM5.66927 9.33517H9.33595C9.52002 9.33517 9.66928 9.18597 9.66928 9.00184C9.66928 8.81777 9.52002 8.66851 9.33595 8.66851H6.66928C5.7488 8.66851 5.0026 7.92237 5.0026 7.00184C5.0026 6.08139 5.7488 5.33521 6.66928 5.33521H7.33595V4.00187H8.66928V5.33521H10.336V6.66851H6.66928C6.48518 6.66851 6.33594 6.81777 6.33594 7.00184C6.33594 7.18597 6.48518 7.33517 6.66928 7.33517H9.33595C10.2564 7.33517 11.0026 8.08137 11.0026 9.00184C11.0026 9.92237 10.2564 10.6685 9.33595 10.6685H8.66928V12.0018H7.33595V10.6685H5.66927V9.33517Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.DESIGN: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.DESIGN: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M4.70152 9.41416L3.2873 10.8284L5.17292 12.714L12.7154 5.17154L10.8298 3.28592L9.41557 4.70013L10.3584 5.64295L9.41557 6.58575L8.47277 5.64295L7.52997 6.58575L8.47277 7.52856L7.52997 8.47136L6.58713 7.52856L5.64433 8.47136L6.58713 9.41416L5.64433 10.357L4.70152 9.41416ZM11.3012 1.87171L14.1296 4.70013C14.39 4.96049 14.39 5.38259 14.1296 5.64295L5.64433 14.1282C5.38397 14.3886 4.96187 14.3886 4.70152 14.1282L1.87309 11.2998C1.61274 11.0394 1.61274 10.6174 1.87309 10.357L10.3584 1.87171C10.6187 1.61136 11.0408 1.61136 11.3012 1.87171ZM9.41557 12.2423L10.3584 11.2995L11.8534 12.7945H12.7962V11.8517L11.3012 10.3567L12.244 9.41383L14.0011 11.171V13.9999H11.1732L9.41557 12.2423ZM3.75861 6.58533L1.87299 4.69971C1.61265 4.43937 1.61265 4.01725 1.87299 3.75691L3.75861 1.87129C4.01896 1.61094 4.44107 1.61094 4.70142 1.87129L6.58704 3.75691L5.64423 4.69971L4.23002 3.2855L3.28721 4.22831L4.70142 5.64253L3.75861 6.58533Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.TRAVEL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.TRAVEL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M9.44839 2C9.80198 2 10.1411 2.14047 10.3912 2.39053L13.6101 5.60947C13.8602 5.85953 14.0007 6.19866 14.0007 6.55229V11.3333H15.334V12.6667L9.91652 12.6672C9.62032 13.8171 8.57638 14.6667 7.33398 14.6667C6.0916 14.6667 5.04766 13.8171 4.75146 12.6672L2.00065 12.6667C1.63246 12.6667 1.33398 12.3682 1.33398 12V3.33333C1.33398 2.59695 1.93094 2 2.66732 2H9.44839ZM7.33398 10.6667C6.5976 10.6667 6.00065 11.2636 6.00065 12C6.00065 12.7364 6.5976 13.3333 7.33398 13.3333C8.07038 13.3333 8.66732 12.7364 8.66732 12C8.66732 11.2636 8.07038 10.6667 7.33398 10.6667ZM9.44839 3.33333H2.66732V11.3333L4.75128 11.3335C5.04726 10.1833 6.09136 9.33333 7.33398 9.33333C8.57658 9.33333 9.62072 10.1833 9.91665 11.3335L12.6673 11.3333V6.55229L9.44839 3.33333ZM9.33398 4.66667V8.66667H4.00065V4.66667H9.33398ZM8.00065 6H5.33398V7.33333H8.00065V6Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.SOCIAL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.SOCIAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M13.334 7.99992C13.334 5.0544 10.9462 2.66659 8.00065 2.66659C5.05513 2.66659 2.66732 5.0544 2.66732 7.99992C2.66732 10.9455 5.05513 13.3333 8.00065 13.3333C9.09518 13.3333 10.1127 13.0035 10.9594 12.438L11.699 13.5475C10.6408 14.2545 9.36885 14.6666 8.00065 14.6666C4.31875 14.6666 1.33398 11.6818 1.33398 7.99992C1.33398 4.31802 4.31875 1.33325 8.00065 1.33325C11.6825 1.33325 14.6673 4.31802 14.6673 7.99992V8.99992C14.6673 10.2886 13.6227 11.3333 12.334 11.3333C11.5312 11.3333 10.8231 10.9278 10.4032 10.3105C9.79678 10.9409 8.94452 11.3333 8.00065 11.3333C6.1597 11.3333 4.66732 9.84085 4.66732 7.99992C4.66732 6.15897 6.1597 4.66658 8.00065 4.66658C8.75118 4.66658 9.44378 4.91464 10.001 5.33325H11.334V8.99992C11.334 9.55219 11.7817 9.99992 12.334 9.99992C12.8863 9.99992 13.334 9.55219 13.334 8.99992V7.99992ZM8.00065 5.99992C6.89605 5.99992 6.00065 6.89532 6.00065 7.99992C6.00065 9.10452 6.89605 9.99992 8.00065 9.99992C9.10525 9.99992 10.0007 9.10452 10.0007 7.99992C10.0007 6.89532 9.10525 5.99992 8.00065 5.99992Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.NEWS: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.NEWS: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M10.6673 13.3335V2.66683H2.66732V12.6668C2.66732 13.035 2.9658 13.3335 3.33398 13.3335H10.6673ZM12.6673 14.6668H3.33398C2.22942 14.6668 1.33398 13.7714 1.33398 12.6668V2.00016C1.33398 1.63198 1.63246 1.3335 2.00065 1.3335H11.334C11.7022 1.3335 12.0007 1.63198 12.0007 2.00016V6.66683H14.6673V12.6668C14.6673 13.7714 13.7719 14.6668 12.6673 14.6668ZM12.0007 8.00016V12.6668C12.0007 13.035 12.2991 13.3335 12.6673 13.3335C13.0355 13.3335 13.334 13.035 13.334 12.6668V8.00016H12.0007ZM4.00065 4.00016H8.00065V8.00016H4.00065V4.00016ZM5.33398 5.3335V6.66683H6.66732V5.3335H5.33398ZM4.00065 8.66683H9.33398V10.0002H4.00065V8.66683ZM4.00065 10.6668H9.33398V12.0002H4.00065V10.6668Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.MEDICAL: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.MEDICAL: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.79747 1.51186L10.9641 5.26464C11.1482 5.5835 11.0389 5.99122 10.7201 6.17532L9.85373 6.67474L10.5207 7.83001L9.366 8.49668L8.699 7.34141L7.83333 7.84201C7.51447 8.02608 7.10673 7.91681 6.92267 7.59794L5.69747 5.47632C4.32922 5.89145 3.33333 7.16268 3.33333 8.66654C3.33333 9.08348 3.40987 9.48248 3.54965 9.85034C4.06613 9.52254 4.67762 9.33321 5.33333 9.33321C6.45605 9.33321 7.44913 9.88828 8.05313 10.7389L13.1787 7.78014L13.8454 8.93488L8.5932 11.9672C8.64133 12.1927 8.66667 12.4267 8.66667 12.6665C8.66667 12.895 8.64367 13.1181 8.59993 13.3337L14 13.3332V14.6665L2.66703 14.6673C2.2482 14.1101 2 13.4173 2 12.6665C2 11.9951 2.19855 11.3699 2.54014 10.8467C2.19517 10.1964 2 9.45428 2 8.66654C2 6.66968 3.25421 4.96575 5.01785 4.29953L4.75598 3.84519C4.38779 3.20747 4.60629 2.39202 5.24402 2.02382L6.97607 1.02382C7.6138 0.655637 8.42927 0.874138 8.79747 1.51186ZM5.33333 10.6665C4.22877 10.6665 3.33333 11.562 3.33333 12.6665C3.33333 12.9003 3.37343 13.1247 3.44711 13.3331H7.21953C7.29327 13.1247 7.33333 12.9003 7.33333 12.6665C7.33333 11.562 6.4379 10.6665 5.33333 10.6665ZM7.64273 2.17852L5.91068 3.17852L7.744 6.35395L9.47607 5.35395L7.64273 2.17852Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.PRODUCTIVITY: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.PRODUCTIVITY: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M6.64807 11.9999H9.35062C9.43862 11.1989 9.84742 10.5376 10.5111 9.81499C10.5858 9.73365 11.0652 9.23752 11.1221 9.16665C11.6872 8.46199 11.9993 7.58992 11.9993 6.66659C11.9993 4.45745 10.2085 2.66659 7.99935 2.66659C5.79021 2.66659 3.99935 4.45745 3.99935 6.66659C3.99935 7.58945 4.31118 8.46105 4.87576 9.16552C4.93271 9.23659 5.41322 9.73405 5.48704 9.81445C6.15112 10.5375 6.56004 11.1989 6.64807 11.9999ZM9.33268 13.3333H6.66602V13.9999H9.33268V13.3333ZM3.83532 9.99939C3.10365 9.08639 2.66602 7.92759 2.66602 6.66659C2.66602 3.72107 5.05383 1.33325 7.99935 1.33325C10.9449 1.33325 13.3327 3.72107 13.3327 6.66659C13.3327 7.92825 12.8945 9.08759 12.1622 10.0009C11.7487 10.5165 10.666 11.3333 10.666 12.3333V13.9999C10.666 14.7363 10.0691 15.3333 9.33268 15.3333H6.66602C5.92964 15.3333 5.33268 14.7363 5.33268 13.9999V12.3333C5.33268 11.3333 4.24907 10.5157 3.83532 9.99939ZM8.66602 6.66979H10.3327L7.33268 10.6698V8.00312H5.66602L8.66602 3.99992V6.66979Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.EDUCATION: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.EDUCATION: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M14 2.66683H4.66667C3.93029 2.66683 3.33333 3.26378 3.33333 4.00016C3.33333 4.73654 3.93029 5.3335 4.66667 5.3335H14V14.0002C14 14.3684 13.7015 14.6668 13.3333 14.6668H4.66667C3.19391 14.6668 2 13.4729 2 12.0002V4.00016C2 2.5274 3.19391 1.3335 4.66667 1.3335H13.3333C13.7015 1.3335 14 1.63198 14 2.00016V2.66683ZM3.33333 12.0002C3.33333 12.7366 3.93029 13.3335 4.66667 13.3335H12.6667V6.66683H4.66667C4.18095 6.66683 3.72557 6.53697 3.33333 6.31008V12.0002ZM13.3333 4.66683H4.66667C4.29848 4.66683 4 4.36835 4 4.00016C4 3.63198 4.29848 3.3335 4.66667 3.3335H13.3333V4.66683Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.BUSINESS: '''<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.BUSINESS: """<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 14 14" fill="none">
|
||||
<path d="M3.66732 3.33341V1.33341C3.66732 0.965228 3.9658 0.666748 4.33398 0.666748H9.66732C10.0355 0.666748 10.334 0.965228 10.334 1.33341V3.33341H13.0007C13.3689 3.33341 13.6673 3.63189 13.6673 4.00008V13.3334C13.6673 13.7016 13.3689 14.0001 13.0007 14.0001H1.00065C0.632464 14.0001 0.333984 13.7016 0.333984 13.3334V4.00008C0.333984 3.63189 0.632464 3.33341 1.00065 3.33341H3.66732ZM12.334 8.66675H1.66732V12.6667H12.334V8.66675ZM12.334 4.66675H1.66732V7.33341H3.66732V6.00008H5.00065V7.33341H9.00065V6.00008H10.334V7.33341H12.334V4.66675ZM5.00065 2.00008V3.33341H9.00065V2.00008H5.00065Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.ENTERTAINMENT: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.ENTERTAINMENT: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M11.3327 2.66675C13.5418 2.66675 15.3327 4.45761 15.3327 6.66675V9.33342C15.3327 11.5425 13.5418 13.3334 11.3327 13.3334H4.66602C2.45688 13.3334 0.666016 11.5425 0.666016 9.33342V6.66675C0.666016 4.45761 2.45688 2.66675 4.66602 2.66675H11.3327ZM11.3327 4.00008H4.66602C3.23788 4.00008 2.07196 5.12273 2.00262 6.53365L1.99935 6.66675V9.33342C1.99935 10.7615 3.122 11.9275 4.53292 11.9968L4.66602 12.0001H11.3327C12.7608 12.0001 13.9267 10.8774 13.9961 9.46648L13.9993 9.33342V6.66675C13.9993 5.23861 12.8767 4.07269 11.4657 4.00335L11.3327 4.00008ZM6.66602 6.00008V7.33342H7.99935V8.66675H6.66535L6.66602 10.0001H5.33268L5.33202 8.66675H3.99935V7.33342H5.33268V6.00008H6.66602ZM11.9993 8.66675V10.0001H10.666V8.66675H11.9993ZM10.666 6.00008V7.33342H9.33268V6.00008H10.666Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.UTILITIES: '''<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.UTILITIES: """<svg xmlns="http://www.w3.org/2000/svg" width="13" height="15" viewBox="0 0 13 15" fill="none">
|
||||
<path d="M12.3346 0.333252C12.7028 0.333252 13.0013 0.631732 13.0013 0.999919V4.33325C13.0013 4.70144 12.7028 4.99992 12.3346 4.99992H9.0013V13.6666C9.0013 14.0348 8.70284 14.3333 8.33463 14.3333H5.66797C5.29978 14.3333 5.0013 14.0348 5.0013 13.6666V4.99992H1.33464C0.966449 4.99992 0.667969 4.70144 0.667969 4.33325V2.74527C0.667969 2.49276 0.810635 2.26192 1.0365 2.14899L4.66797 0.333252H12.3346ZM9.0013 1.66659H4.98273L2.0013 3.1573V3.66659H6.33464V12.9999H7.66797V3.66659H9.0013V1.66659ZM11.668 1.66659H10.3346V3.66659H11.668V1.66659Z" fill="#344054"/>
|
||||
</svg>''',
|
||||
ToolLabelEnum.OTHER: '''<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.OTHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00052 0.666748L4.00065 7.33342H12.0007L8.00052 0.666748ZM8.00052 3.25828L9.64572 6.00008H6.35553L8.00052 3.25828ZM4.50065 13.3334C3.48813 13.3334 2.66732 12.5126 2.66732 11.5001C2.66732 10.4875 3.48813 9.66675 4.50065 9.66675C5.51317 9.66675 6.33398 10.4875 6.33398 11.5001C6.33398 12.5126 5.51317 13.3334 4.50065 13.3334ZM4.50065 14.6667C6.24955 14.6667 7.66732 13.249 7.66732 11.5001C7.66732 9.75115 6.24955 8.33342 4.50065 8.33342C2.75175 8.33342 1.33398 9.75115 1.33398 11.5001C1.33398 13.249 2.75175 14.6667 4.50065 14.6667ZM10.0007 10.3334V13.0001H12.6673V10.3334H10.0007ZM8.66732 14.3334V9.00008H14.0007V14.3334H8.66732Z" fill="#344054"/>
|
||||
</svg>'''
|
||||
</svg>""", # noqa: E501
|
||||
}
|
||||
|
||||
default_tool_label_dict = {
|
||||
ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]),
|
||||
ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]),
|
||||
ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]),
|
||||
ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]),
|
||||
ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]),
|
||||
ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]),
|
||||
ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]),
|
||||
ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]),
|
||||
ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]),
|
||||
ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]),
|
||||
ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]),
|
||||
ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]),
|
||||
ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]),
|
||||
ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]),
|
||||
ToolLabelEnum.UTILITIES: ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]),
|
||||
ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]),
|
||||
ToolLabelEnum.SEARCH: ToolLabel(
|
||||
name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), icon=ICONS[ToolLabelEnum.SEARCH]
|
||||
),
|
||||
ToolLabelEnum.IMAGE: ToolLabel(
|
||||
name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE]
|
||||
),
|
||||
ToolLabelEnum.VIDEOS: ToolLabel(
|
||||
name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS]
|
||||
),
|
||||
ToolLabelEnum.WEATHER: ToolLabel(
|
||||
name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER]
|
||||
),
|
||||
ToolLabelEnum.FINANCE: ToolLabel(
|
||||
name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE]
|
||||
),
|
||||
ToolLabelEnum.DESIGN: ToolLabel(
|
||||
name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN]
|
||||
),
|
||||
ToolLabelEnum.TRAVEL: ToolLabel(
|
||||
name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL]
|
||||
),
|
||||
ToolLabelEnum.SOCIAL: ToolLabel(
|
||||
name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL]
|
||||
),
|
||||
ToolLabelEnum.NEWS: ToolLabel(
|
||||
name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS]
|
||||
),
|
||||
ToolLabelEnum.MEDICAL: ToolLabel(
|
||||
name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL]
|
||||
),
|
||||
ToolLabelEnum.PRODUCTIVITY: ToolLabel(
|
||||
name="productivity",
|
||||
label=I18nObject(en_US="Productivity", zh_Hans="生产力"),
|
||||
icon=ICONS[ToolLabelEnum.PRODUCTIVITY],
|
||||
),
|
||||
ToolLabelEnum.EDUCATION: ToolLabel(
|
||||
name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION]
|
||||
),
|
||||
ToolLabelEnum.BUSINESS: ToolLabel(
|
||||
name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS]
|
||||
),
|
||||
ToolLabelEnum.ENTERTAINMENT: ToolLabel(
|
||||
name="entertainment",
|
||||
label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"),
|
||||
icon=ICONS[ToolLabelEnum.ENTERTAINMENT],
|
||||
),
|
||||
ToolLabelEnum.UTILITIES: ToolLabel(
|
||||
name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES]
|
||||
),
|
||||
ToolLabelEnum.OTHER: ToolLabel(
|
||||
name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER]
|
||||
),
|
||||
}
|
||||
|
||||
default_tool_labels = [v for k, v in default_tool_label_dict.items()]
|
||||
|
||||
@ -4,23 +4,30 @@ from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
class ToolProviderNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolParameterValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolProviderCredentialValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolNotSupportedError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolInvokeError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolApiSchemaError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolEngineInvokeError(Exception):
|
||||
meta: ToolInvokeMeta
|
||||
meta: ToolInvokeMeta
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
@ -20,86 +19,70 @@ class ApiToolProviderController(ToolProviderController):
|
||||
tools: list[ApiTool] = Field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
|
||||
credentials_schema = {
|
||||
'auth_type': ProviderConfig(
|
||||
name='auth_type',
|
||||
"auth_type": ProviderConfig(
|
||||
name="auth_type",
|
||||
required=True,
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
options=[
|
||||
ProviderConfig.Option(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
|
||||
ProviderConfig.Option(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
|
||||
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ProviderConfig.Option(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
|
||||
],
|
||||
default='none',
|
||||
help=I18nObject(
|
||||
en_US='The auth type of the api provider',
|
||||
zh_Hans='api provider 的认证类型'
|
||||
)
|
||||
default="none",
|
||||
help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
|
||||
)
|
||||
}
|
||||
if auth_type == ApiProviderAuthType.API_KEY:
|
||||
credentials_schema = {
|
||||
**credentials_schema,
|
||||
'api_key_header': ProviderConfig(
|
||||
name='api_key_header',
|
||||
"api_key_header": ProviderConfig(
|
||||
name="api_key_header",
|
||||
required=False,
|
||||
default='api_key',
|
||||
default="api_key",
|
||||
type=ProviderConfig.Type.TEXT_INPUT,
|
||||
help=I18nObject(
|
||||
en_US='The header name of the api key',
|
||||
zh_Hans='携带 api key 的 header 名称'
|
||||
)
|
||||
help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
|
||||
),
|
||||
'api_key_value': ProviderConfig(
|
||||
name='api_key_value',
|
||||
"api_key_value": ProviderConfig(
|
||||
name="api_key_value",
|
||||
required=True,
|
||||
type=ProviderConfig.Type.SECRET_INPUT,
|
||||
help=I18nObject(
|
||||
en_US='The api key',
|
||||
zh_Hans='api key的值'
|
||||
)
|
||||
help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
|
||||
),
|
||||
'api_key_header_prefix': ProviderConfig(
|
||||
name='api_key_header_prefix',
|
||||
"api_key_header_prefix": ProviderConfig(
|
||||
name="api_key_header_prefix",
|
||||
required=False,
|
||||
default='basic',
|
||||
default="basic",
|
||||
type=ProviderConfig.Type.SELECT,
|
||||
help=I18nObject(
|
||||
en_US='The prefix of the api key header',
|
||||
zh_Hans='api key header 的前缀'
|
||||
),
|
||||
help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"),
|
||||
options=[
|
||||
ProviderConfig.Option(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
|
||||
ProviderConfig.Option(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
|
||||
ProviderConfig.Option(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
|
||||
]
|
||||
)
|
||||
ProviderConfig.Option(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")),
|
||||
ProviderConfig.Option(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")),
|
||||
ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
|
||||
],
|
||||
),
|
||||
}
|
||||
elif auth_type == ApiProviderAuthType.NONE:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f'invalid auth type {auth_type}')
|
||||
raise ValueError(f"invalid auth type {auth_type}")
|
||||
|
||||
user_name = db_provider.user.name if db_provider.user_id else ''
|
||||
user_name = db_provider.user.name if db_provider.user_id else ""
|
||||
|
||||
return ApiToolProviderController(**{
|
||||
'identity': {
|
||||
'author': user_name,
|
||||
'name': db_provider.name,
|
||||
'label': {
|
||||
'en_US': db_provider.name,
|
||||
'zh_Hans': db_provider.name
|
||||
return ApiToolProviderController(
|
||||
**{
|
||||
"identity": {
|
||||
"author": user_name,
|
||||
"name": db_provider.name,
|
||||
"label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
|
||||
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
|
||||
"icon": db_provider.icon,
|
||||
},
|
||||
'description': {
|
||||
'en_US': db_provider.description,
|
||||
'zh_Hans': db_provider.description
|
||||
},
|
||||
'icon': db_provider.icon,
|
||||
"credentials_schema": credentials_schema,
|
||||
"provider_id": db_provider.id or "",
|
||||
"tenant_id": db_provider.tenant_id or "",
|
||||
},
|
||||
'credentials_schema': credentials_schema,
|
||||
'provider_id': db_provider.id or '',
|
||||
'tenant_id': db_provider.tenant_id or '',
|
||||
})
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
@ -107,39 +90,35 @@ class ApiToolProviderController(ToolProviderController):
|
||||
|
||||
def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
|
||||
"""
|
||||
parse tool bundle to tool
|
||||
parse tool bundle to tool
|
||||
|
||||
:param tool_bundle: the tool bundle
|
||||
:return: the tool
|
||||
:param tool_bundle: the tool bundle
|
||||
:return: the tool
|
||||
"""
|
||||
return ApiTool(**{
|
||||
'api_bundle': tool_bundle,
|
||||
'identity' : {
|
||||
'author': tool_bundle.author,
|
||||
'name': tool_bundle.operation_id,
|
||||
'label': {
|
||||
'en_US': tool_bundle.operation_id,
|
||||
'zh_Hans': tool_bundle.operation_id
|
||||
return ApiTool(
|
||||
**{
|
||||
"api_bundle": tool_bundle,
|
||||
"identity": {
|
||||
"author": tool_bundle.author,
|
||||
"name": tool_bundle.operation_id,
|
||||
"label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
|
||||
"icon": self.identity.icon,
|
||||
"provider": self.provider_id,
|
||||
},
|
||||
'icon': self.identity.icon,
|
||||
'provider': self.provider_id,
|
||||
},
|
||||
'description': {
|
||||
'human': {
|
||||
'en_US': tool_bundle.summary or '',
|
||||
'zh_Hans': tool_bundle.summary or ''
|
||||
"description": {
|
||||
"human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
|
||||
"llm": tool_bundle.summary or "",
|
||||
},
|
||||
'llm': tool_bundle.summary or ''
|
||||
},
|
||||
'parameters' : tool_bundle.parameters if tool_bundle.parameters else [],
|
||||
})
|
||||
"parameters": tool_bundle.parameters or [],
|
||||
}
|
||||
)
|
||||
|
||||
def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]:
|
||||
"""
|
||||
load bundled tools
|
||||
load bundled tools
|
||||
|
||||
:param tools: the bundled tools
|
||||
:return: the tools
|
||||
:param tools: the bundled tools
|
||||
:return: the tools
|
||||
"""
|
||||
self.tools = [self._parse_tool_bundle(tool) for tool in tools]
|
||||
|
||||
@ -147,22 +126,23 @@ class ApiToolProviderController(ToolProviderController):
|
||||
|
||||
def get_tools(self, tenant_id: str) -> list[ApiTool]:
|
||||
"""
|
||||
fetch tools from database
|
||||
fetch tools from database
|
||||
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:return: the tools
|
||||
:param user_id: the user id
|
||||
:param tenant_id: the tenant id
|
||||
:return: the tools
|
||||
"""
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
|
||||
tools: list[ApiTool] = []
|
||||
|
||||
# get tenant api providers
|
||||
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == self.identity.name
|
||||
).all()
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name)
|
||||
.all()
|
||||
)
|
||||
|
||||
if db_providers and len(db_providers) != 0:
|
||||
for db_provider in db_providers:
|
||||
@ -170,16 +150,16 @@ class ApiToolProviderController(ToolProviderController):
|
||||
assistant_tool = self._parse_tool_bundle(tool)
|
||||
assistant_tool.is_team_authorization = True
|
||||
tools.append(assistant_tool)
|
||||
|
||||
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
|
||||
def get_tool(self, tool_name: str) -> ApiTool:
|
||||
"""
|
||||
get tool by name
|
||||
get tool by name
|
||||
|
||||
:param tool_name: the name of the tool
|
||||
:return: the tool
|
||||
:param tool_name: the name of the tool
|
||||
:return: the tool
|
||||
"""
|
||||
if self.tools is None:
|
||||
self.get_tools(self.tenant_id)
|
||||
@ -188,4 +168,4 @@ class ApiToolProviderController(ToolProviderController):
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
raise ValueError(f'tool {tool_name} not found')
|
||||
raise ValueError(f"tool {tool_name} not found")
|
||||
|
||||
103
api/core/tools/provider/app_tool_provider.py
Normal file
103
api/core/tools/provider/app_tool_provider.py
Normal file
@ -0,0 +1,103 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
from models.tools import PublishedAppTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppToolProviderEntity(ToolProviderController):
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.APP
|
||||
|
||||
def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def get_tools(self, user_id: str) -> list[Tool]:
|
||||
db_tools: list[PublishedAppTool] = (
|
||||
db.session.query(PublishedAppTool)
|
||||
.filter(
|
||||
PublishedAppTool.user_id == user_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not db_tools or len(db_tools) == 0:
|
||||
return []
|
||||
|
||||
tools: list[Tool] = []
|
||||
|
||||
for db_tool in db_tools:
|
||||
tool = {
|
||||
"identity": {
|
||||
"author": db_tool.author,
|
||||
"name": db_tool.tool_name,
|
||||
"label": {"en_US": db_tool.tool_name, "zh_Hans": db_tool.tool_name},
|
||||
"icon": "",
|
||||
},
|
||||
"description": {
|
||||
"human": {"en_US": db_tool.description_i18n.en_US, "zh_Hans": db_tool.description_i18n.zh_Hans},
|
||||
"llm": db_tool.llm_description,
|
||||
},
|
||||
"parameters": [],
|
||||
}
|
||||
# get app from db
|
||||
app: App = db_tool.app
|
||||
|
||||
if not app:
|
||||
logger.error(f"app {db_tool.app_id} not found")
|
||||
continue
|
||||
|
||||
app_model_config: AppModelConfig = app.app_model_config
|
||||
user_input_form_list = app_model_config.user_input_form_list
|
||||
for input_form in user_input_form_list:
|
||||
# get type
|
||||
form_type = input_form.keys()[0]
|
||||
default = input_form[form_type]["default"]
|
||||
required = input_form[form_type]["required"]
|
||||
label = input_form[form_type]["label"]
|
||||
variable_name = input_form[form_type]["variable_name"]
|
||||
options = input_form[form_type].get("options", [])
|
||||
if form_type in {"paragraph", "text-input"}:
|
||||
tool["parameters"].append(
|
||||
ToolParameter(
|
||||
name=variable_name,
|
||||
label=I18nObject(en_US=label, zh_Hans=label),
|
||||
human_description=I18nObject(en_US=label, zh_Hans=label),
|
||||
llm_description=label,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=required,
|
||||
default=default,
|
||||
)
|
||||
)
|
||||
elif form_type == "select":
|
||||
tool["parameters"].append(
|
||||
ToolParameter(
|
||||
name=variable_name,
|
||||
label=I18nObject(en_US=label, zh_Hans=label),
|
||||
human_description=I18nObject(en_US=label, zh_Hans=label),
|
||||
llm_description=label,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
required=required,
|
||||
default=default,
|
||||
options=[
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
tools.append(Tool(**tool))
|
||||
return tools
|
||||
@ -10,7 +10,7 @@ class BuiltinToolProviderSort:
|
||||
@classmethod
|
||||
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
||||
if not cls._position:
|
||||
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..'))
|
||||
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
def name_func(provider: UserToolProvider) -> str:
|
||||
return provider.name
|
||||
|
||||
@ -6,6 +6,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
||||
class AIPPTProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__')
|
||||
AIPPTGenerateTool._get_api_token(credentials, user_id="__dify_system__")
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -20,16 +20,16 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
A tool for generating a ppt
|
||||
"""
|
||||
|
||||
_api_base_url = URL('https://co.aippt.cn/api')
|
||||
_api_base_url = URL("https://co.aippt.cn/api")
|
||||
_api_token_cache = {}
|
||||
_api_token_cache_lock:Optional[Lock] = None
|
||||
_api_token_cache_lock: Optional[Lock] = None
|
||||
_style_cache = {}
|
||||
_style_cache_lock:Optional[Lock] = None
|
||||
_style_cache_lock: Optional[Lock] = None
|
||||
|
||||
_task = {}
|
||||
_task_type_map = {
|
||||
'auto': 1,
|
||||
'markdown': 7,
|
||||
"auto": 1,
|
||||
"markdown": 7,
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
@ -46,67 +46,58 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages.
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
|
||||
which can be a single message or a list of messages.
|
||||
"""
|
||||
title = tool_parameters.get('title', '')
|
||||
title = tool_parameters.get("title", "")
|
||||
if not title:
|
||||
return self.create_text_message('Please provide a title for the ppt')
|
||||
|
||||
model = tool_parameters.get('model', 'aippt')
|
||||
return self.create_text_message("Please provide a title for the ppt")
|
||||
|
||||
model = tool_parameters.get("model", "aippt")
|
||||
if not model:
|
||||
return self.create_text_message('Please provide a model for the ppt')
|
||||
|
||||
outline = tool_parameters.get('outline', '')
|
||||
return self.create_text_message("Please provide a model for the ppt")
|
||||
|
||||
outline = tool_parameters.get("outline", "")
|
||||
|
||||
# create task
|
||||
task_id = self._create_task(
|
||||
type=self._task_type_map['auto' if not outline else 'markdown'],
|
||||
type=self._task_type_map["auto" if not outline else "markdown"],
|
||||
title=title,
|
||||
content=outline,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# get suit
|
||||
color = tool_parameters.get('color')
|
||||
style = tool_parameters.get('style')
|
||||
color = tool_parameters.get("color")
|
||||
style = tool_parameters.get("style")
|
||||
|
||||
if color == '__default__':
|
||||
color_id = ''
|
||||
if color == "__default__":
|
||||
color_id = ""
|
||||
else:
|
||||
color_id = int(color.split('-')[1])
|
||||
color_id = int(color.split("-")[1])
|
||||
|
||||
if style == '__default__':
|
||||
style_id = ''
|
||||
if style == "__default__":
|
||||
style_id = ""
|
||||
else:
|
||||
style_id = int(style.split('-')[1])
|
||||
style_id = int(style.split("-")[1])
|
||||
|
||||
suit_id = self._get_suit(style_id=style_id, colour_id=color_id)
|
||||
|
||||
# generate outline
|
||||
if not outline:
|
||||
self._generate_outline(
|
||||
task_id=task_id,
|
||||
model=model,
|
||||
user_id=user_id
|
||||
)
|
||||
self._generate_outline(task_id=task_id, model=model, user_id=user_id)
|
||||
|
||||
# generate content
|
||||
self._generate_content(
|
||||
task_id=task_id,
|
||||
model=model,
|
||||
user_id=user_id
|
||||
)
|
||||
self._generate_content(task_id=task_id, model=model, user_id=user_id)
|
||||
|
||||
# generate ppt
|
||||
_, ppt_url = self._generate_ppt(
|
||||
task_id=task_id,
|
||||
suit_id=suit_id,
|
||||
user_id=user_id
|
||||
)
|
||||
_, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id)
|
||||
|
||||
return self.create_text_message('''the ppt has been created successfully,'''
|
||||
f'''the ppt url is {ppt_url}'''
|
||||
'''please give the ppt url to user and direct user to download it.''')
|
||||
return self.create_text_message(
|
||||
"""the ppt has been created successfully,"""
|
||||
f"""the ppt url is {ppt_url}"""
|
||||
"""please give the ppt url to user and direct user to download it."""
|
||||
)
|
||||
|
||||
def _create_task(self, type: int, title: str, content: str, user_id: str) -> str:
|
||||
"""
|
||||
@ -119,129 +110,121 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
:return: the task ID
|
||||
"""
|
||||
headers = {
|
||||
'x-channel': '',
|
||||
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
response = post(
|
||||
str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'),
|
||||
str(self._api_base_url / "ai" / "chat" / "v2" / "task"),
|
||||
headers=headers,
|
||||
files={
|
||||
'type': ('', str(type)),
|
||||
'title': ('', title),
|
||||
'content': ('', content)
|
||||
}
|
||||
files={"type": ("", str(type)), "title": ("", title), "content": ("", content)},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get('code') != 0:
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to create task: {response.get("msg")}')
|
||||
|
||||
return response.get('data', {}).get('id')
|
||||
|
||||
return response.get("data", {}).get("id")
|
||||
|
||||
def _generate_outline(self, task_id: str, model: str, user_id: str) -> str:
|
||||
api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \
|
||||
self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline'
|
||||
api_url %= {'task_id': task_id}
|
||||
api_url = (
|
||||
self._api_base_url / "ai" / "chat" / "outline"
|
||||
if model == "aippt"
|
||||
else self._api_base_url / "ai" / "chat" / "wx" / "outline"
|
||||
)
|
||||
api_url %= {"task_id": task_id}
|
||||
|
||||
headers = {
|
||||
'x-channel': '',
|
||||
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = requests_get(
|
||||
url=api_url,
|
||||
headers=headers,
|
||||
stream=True,
|
||||
timeout=(10, 60)
|
||||
)
|
||||
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||
|
||||
outline = ''
|
||||
for chunk in response.iter_lines(delimiter=b'\n\n'):
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
outline = ""
|
||||
for chunk in response.iter_lines(delimiter=b"\n\n"):
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
event = ''
|
||||
lines = chunk.decode('utf-8').split('\n')
|
||||
|
||||
event = ""
|
||||
lines = chunk.decode("utf-8").split("\n")
|
||||
for line in lines:
|
||||
if line.startswith('event:'):
|
||||
if line.startswith("event:"):
|
||||
event = line[6:]
|
||||
elif line.startswith('data:'):
|
||||
elif line.startswith("data:"):
|
||||
data = line[5:]
|
||||
if event == 'message':
|
||||
if event == "message":
|
||||
try:
|
||||
data = json_loads(data)
|
||||
outline += data.get('content', '')
|
||||
outline += data.get("content", "")
|
||||
except Exception as e:
|
||||
pass
|
||||
elif event == 'close':
|
||||
elif event == "close":
|
||||
break
|
||||
elif event == 'error' or event == 'filter':
|
||||
raise Exception(f'Failed to generate outline: {data}')
|
||||
|
||||
elif event in {"error", "filter"}:
|
||||
raise Exception(f"Failed to generate outline: {data}")
|
||||
|
||||
return outline
|
||||
|
||||
|
||||
def _generate_content(self, task_id: str, model: str, user_id: str) -> str:
|
||||
api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \
|
||||
self._api_base_url / 'ai' / 'chat' / 'wx' / 'content'
|
||||
api_url %= {'task_id': task_id}
|
||||
api_url = (
|
||||
self._api_base_url / "ai" / "chat" / "content"
|
||||
if model == "aippt"
|
||||
else self._api_base_url / "ai" / "chat" / "wx" / "content"
|
||||
)
|
||||
api_url %= {"task_id": task_id}
|
||||
|
||||
headers = {
|
||||
'x-channel': '',
|
||||
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = requests_get(
|
||||
url=api_url,
|
||||
headers=headers,
|
||||
stream=True,
|
||||
timeout=(10, 60)
|
||||
)
|
||||
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||
|
||||
if model == 'aippt':
|
||||
content = ''
|
||||
for chunk in response.iter_lines(delimiter=b'\n\n'):
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
if model == "aippt":
|
||||
content = ""
|
||||
for chunk in response.iter_lines(delimiter=b"\n\n"):
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
event = ''
|
||||
lines = chunk.decode('utf-8').split('\n')
|
||||
|
||||
event = ""
|
||||
lines = chunk.decode("utf-8").split("\n")
|
||||
for line in lines:
|
||||
if line.startswith('event:'):
|
||||
if line.startswith("event:"):
|
||||
event = line[6:]
|
||||
elif line.startswith('data:'):
|
||||
elif line.startswith("data:"):
|
||||
data = line[5:]
|
||||
if event == 'message':
|
||||
if event == "message":
|
||||
try:
|
||||
data = json_loads(data)
|
||||
content += data.get('content', '')
|
||||
content += data.get("content", "")
|
||||
except Exception as e:
|
||||
pass
|
||||
elif event == 'close':
|
||||
elif event == "close":
|
||||
break
|
||||
elif event == 'error' or event == 'filter':
|
||||
raise Exception(f'Failed to generate content: {data}')
|
||||
|
||||
elif event in {"error", "filter"}:
|
||||
raise Exception(f"Failed to generate content: {data}")
|
||||
|
||||
return content
|
||||
elif model == 'wenxin':
|
||||
elif model == "wenxin":
|
||||
response = response.json()
|
||||
if response.get('code') != 0:
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to generate content: {response.get("msg")}')
|
||||
|
||||
return response.get('data', '')
|
||||
|
||||
return ''
|
||||
|
||||
return response.get("data", "")
|
||||
|
||||
return ""
|
||||
|
||||
def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
|
||||
"""
|
||||
@ -252,83 +235,73 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
:return: the cover url of the ppt and the ppt url
|
||||
"""
|
||||
headers = {
|
||||
'x-channel': '',
|
||||
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = post(
|
||||
str(self._api_base_url / 'design' / 'v2' / 'save'),
|
||||
str(self._api_base_url / "design" / "v2" / "save"),
|
||||
headers=headers,
|
||||
data={
|
||||
'task_id': task_id,
|
||||
'template_id': suit_id
|
||||
}
|
||||
data={"task_id": task_id, "template_id": suit_id},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get('code') != 0:
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||
|
||||
id = response.get('data', {}).get('id')
|
||||
cover_url = response.get('data', {}).get('cover_url')
|
||||
|
||||
id = response.get("data", {}).get("id")
|
||||
cover_url = response.get("data", {}).get("cover_url")
|
||||
|
||||
response = post(
|
||||
str(self._api_base_url / 'download' / 'export' / 'file'),
|
||||
str(self._api_base_url / "download" / "export" / "file"),
|
||||
headers=headers,
|
||||
data={
|
||||
'id': id,
|
||||
'format': 'ppt',
|
||||
'files_to_zip': False,
|
||||
'edit': True
|
||||
}
|
||||
data={"id": id, "format": "ppt", "files_to_zip": False, "edit": True},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get('code') != 0:
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||
|
||||
export_code = response.get('data')
|
||||
|
||||
export_code = response.get("data")
|
||||
if not export_code:
|
||||
raise Exception('Failed to generate ppt, the export code is empty')
|
||||
|
||||
raise Exception("Failed to generate ppt, the export code is empty")
|
||||
|
||||
current_iteration = 0
|
||||
while current_iteration < 50:
|
||||
# get ppt url
|
||||
response = post(
|
||||
str(self._api_base_url / 'download' / 'export' / 'file' / 'result'),
|
||||
str(self._api_base_url / "download" / "export" / "file" / "result"),
|
||||
headers=headers,
|
||||
data={
|
||||
'task_key': export_code
|
||||
}
|
||||
data={"task_key": export_code},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
if response.get('code') != 0:
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to generate ppt: {response.get("msg")}')
|
||||
|
||||
if response.get('msg') == '导出中':
|
||||
|
||||
if response.get("msg") == "导出中":
|
||||
current_iteration += 1
|
||||
sleep(2)
|
||||
continue
|
||||
|
||||
ppt_url = response.get('data', [])
|
||||
|
||||
ppt_url = response.get("data", [])
|
||||
if len(ppt_url) == 0:
|
||||
raise Exception('Failed to generate ppt, the ppt url is empty')
|
||||
|
||||
raise Exception("Failed to generate ppt, the ppt url is empty")
|
||||
|
||||
return cover_url, ppt_url[0]
|
||||
|
||||
raise Exception('Failed to generate ppt, the export is timeout')
|
||||
|
||||
|
||||
raise Exception("Failed to generate ppt, the export is timeout")
|
||||
|
||||
@classmethod
|
||||
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
|
||||
"""
|
||||
@ -337,53 +310,43 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
:param credentials: the credentials
|
||||
:return: the API token
|
||||
"""
|
||||
access_key = credentials['aippt_access_key']
|
||||
secret_key = credentials['aippt_secret_key']
|
||||
access_key = credentials["aippt_access_key"]
|
||||
secret_key = credentials["aippt_secret_key"]
|
||||
|
||||
cache_key = f'{access_key}#@#{user_id}'
|
||||
cache_key = f"{access_key}#@#{user_id}"
|
||||
|
||||
with cls._api_token_cache_lock:
|
||||
# clear expired tokens
|
||||
now = time()
|
||||
for key in list(cls._api_token_cache.keys()):
|
||||
if cls._api_token_cache[key]['expire'] < now:
|
||||
if cls._api_token_cache[key]["expire"] < now:
|
||||
del cls._api_token_cache[key]
|
||||
|
||||
if cache_key in cls._api_token_cache:
|
||||
return cls._api_token_cache[cache_key]['token']
|
||||
|
||||
return cls._api_token_cache[cache_key]["token"]
|
||||
|
||||
# get token
|
||||
headers = {
|
||||
'x-api-key': access_key,
|
||||
'x-timestamp': str(int(now)),
|
||||
'x-signature': cls._calculate_sign(access_key, secret_key, int(now))
|
||||
"x-api-key": access_key,
|
||||
"x-timestamp": str(int(now)),
|
||||
"x-signature": cls._calculate_sign(access_key, secret_key, int(now)),
|
||||
}
|
||||
|
||||
param = {
|
||||
'uid': user_id,
|
||||
'channel': ''
|
||||
}
|
||||
param = {"uid": user_id, "channel": ""}
|
||||
|
||||
response = get(
|
||||
str(cls._api_base_url / 'grant' / 'token'),
|
||||
params=param,
|
||||
headers=headers
|
||||
)
|
||||
response = get(str(cls._api_base_url / "grant" / "token"), params=param, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
response = response.json()
|
||||
if response.get('code') != 0:
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||
|
||||
token = response.get('data', {}).get('token')
|
||||
expire = response.get('data', {}).get('time_expire')
|
||||
|
||||
token = response.get("data", {}).get("token")
|
||||
expire = response.get("data", {}).get("time_expire")
|
||||
|
||||
with cls._api_token_cache_lock:
|
||||
cls._api_token_cache[cache_key] = {
|
||||
'token': token,
|
||||
'expire': now + expire
|
||||
}
|
||||
cls._api_token_cache[cache_key] = {"token": token, "expire": now + expire}
|
||||
|
||||
return token
|
||||
|
||||
@ -391,11 +354,9 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
|
||||
return b64encode(
|
||||
hmac_new(
|
||||
key=secret_key.encode('utf-8'),
|
||||
msg=f'GET@/api/grant/token/@{timestamp}'.encode(),
|
||||
digestmod=sha1
|
||||
key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1
|
||||
).digest()
|
||||
).decode('utf-8')
|
||||
).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
@ -408,47 +369,46 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
# clear expired styles
|
||||
now = time()
|
||||
for key in list(cls._style_cache.keys()):
|
||||
if cls._style_cache[key]['expire'] < now:
|
||||
if cls._style_cache[key]["expire"] < now:
|
||||
del cls._style_cache[key]
|
||||
|
||||
key = f'{credentials["aippt_access_key"]}#@#{user_id}'
|
||||
if key in cls._style_cache:
|
||||
return cls._style_cache[key]['colors'], cls._style_cache[key]['styles']
|
||||
return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"]
|
||||
|
||||
headers = {
|
||||
'x-channel': '',
|
||||
'x-api-key': credentials['aippt_access_key'],
|
||||
'x-token': cls._get_api_token(credentials=credentials, user_id=user_id)
|
||||
"x-channel": "",
|
||||
"x-api-key": credentials["aippt_access_key"],
|
||||
"x-token": cls._get_api_token(credentials=credentials, user_id=user_id),
|
||||
}
|
||||
response = get(
|
||||
str(cls._api_base_url / 'template_component' / 'suit' / 'select'),
|
||||
headers=headers
|
||||
)
|
||||
response = get(str(cls._api_base_url / "template_component" / "suit" / "select"), headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
|
||||
if response.get('code') != 0:
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||
|
||||
colors = [{
|
||||
'id': f'id-{item.get("id")}',
|
||||
'name': item.get('name'),
|
||||
'en_name': item.get('en_name', item.get('name')),
|
||||
} for item in response.get('data', {}).get('colour') or []]
|
||||
styles = [{
|
||||
'id': f'id-{item.get("id")}',
|
||||
'name': item.get('title'),
|
||||
} for item in response.get('data', {}).get('suit_style') or []]
|
||||
|
||||
colors = [
|
||||
{
|
||||
"id": f'id-{item.get("id")}',
|
||||
"name": item.get("name"),
|
||||
"en_name": item.get("en_name", item.get("name")),
|
||||
}
|
||||
for item in response.get("data", {}).get("colour") or []
|
||||
]
|
||||
styles = [
|
||||
{
|
||||
"id": f'id-{item.get("id")}',
|
||||
"name": item.get("title"),
|
||||
}
|
||||
for item in response.get("data", {}).get("suit_style") or []
|
||||
]
|
||||
|
||||
with cls._style_cache_lock:
|
||||
cls._style_cache[key] = {
|
||||
'colors': colors,
|
||||
'styles': styles,
|
||||
'expire': now + 60 * 60
|
||||
}
|
||||
cls._style_cache[key] = {"colors": colors, "styles": styles, "expire": now + 60 * 60}
|
||||
|
||||
return colors, styles
|
||||
|
||||
@ -459,44 +419,39 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
:param credentials: the credentials
|
||||
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
|
||||
"""
|
||||
if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'):
|
||||
raise Exception('Please provide aippt credentials')
|
||||
if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"):
|
||||
raise Exception("Please provide aippt credentials")
|
||||
|
||||
return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
|
||||
|
||||
|
||||
def _get_suit(self, style_id: int, colour_id: int) -> int:
|
||||
"""
|
||||
Get suit
|
||||
"""
|
||||
headers = {
|
||||
'x-channel': '',
|
||||
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__')
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"),
|
||||
}
|
||||
response = get(
|
||||
str(self._api_base_url / 'template_component' / 'suit' / 'search'),
|
||||
str(self._api_base_url / "template_component" / "suit" / "search"),
|
||||
headers=headers,
|
||||
params={
|
||||
'style_id': style_id,
|
||||
'colour_id': colour_id,
|
||||
'page': 1,
|
||||
'page_size': 1
|
||||
}
|
||||
params={"style_id": style_id, "colour_id": colour_id, "page": 1, "page_size": 1},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to connect to aippt: {response.text}')
|
||||
|
||||
raise Exception(f"Failed to connect to aippt: {response.text}")
|
||||
|
||||
response = response.json()
|
||||
|
||||
if response.get('code') != 0:
|
||||
if response.get("code") != 0:
|
||||
raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
|
||||
|
||||
if len(response.get('data', {}).get('list') or []) > 0:
|
||||
return response.get('data', {}).get('list')[0].get('id')
|
||||
|
||||
raise Exception('Failed to get suit, the suit does not exist, please check the style and color')
|
||||
|
||||
|
||||
if len(response.get("data", {}).get("list") or []) > 0:
|
||||
return response.get("data", {}).get("list")[0].get("id")
|
||||
|
||||
raise Exception("Failed to get suit, the suit does not exist, please check the style and color")
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
Get runtime parameters
|
||||
@ -504,43 +459,40 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
Override this method to add runtime parameters to the tool.
|
||||
"""
|
||||
try:
|
||||
colors, styles = self.get_styles(user_id='__dify_system__')
|
||||
colors, styles = self.get_styles(user_id="__dify_system__")
|
||||
except Exception as e:
|
||||
colors, styles = [
|
||||
{'id': '-1', 'name': '__default__', 'en_name': '__default__'}
|
||||
], [
|
||||
{'id': '-1', 'name': '__default__', 'en_name': '__default__'}
|
||||
]
|
||||
colors, styles = (
|
||||
[{"id": "-1", "name": "__default__", "en_name": "__default__"}],
|
||||
[{"id": "-1", "name": "__default__", "en_name": "__default__"}],
|
||||
)
|
||||
|
||||
return [
|
||||
ToolParameter(
|
||||
name='color',
|
||||
label=I18nObject(zh_Hans='颜色', en_US='Color'),
|
||||
human_description=I18nObject(zh_Hans='颜色', en_US='Color'),
|
||||
name="color",
|
||||
label=I18nObject(zh_Hans="颜色", en_US="Color"),
|
||||
human_description=I18nObject(zh_Hans="颜色", en_US="Color"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=False,
|
||||
default=colors[0]['id'],
|
||||
default=colors[0]["id"],
|
||||
options=[
|
||||
ToolParameterOption(
|
||||
value=color['id'],
|
||||
label=I18nObject(zh_Hans=color['name'], en_US=color['en_name'])
|
||||
) for color in colors
|
||||
]
|
||||
value=color["id"], label=I18nObject(zh_Hans=color["name"], en_US=color["en_name"])
|
||||
)
|
||||
for color in colors
|
||||
],
|
||||
),
|
||||
ToolParameter(
|
||||
name='style',
|
||||
label=I18nObject(zh_Hans='风格', en_US='Style'),
|
||||
human_description=I18nObject(zh_Hans='风格', en_US='Style'),
|
||||
name="style",
|
||||
label=I18nObject(zh_Hans="风格", en_US="Style"),
|
||||
human_description=I18nObject(zh_Hans="风格", en_US="Style"),
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
required=False,
|
||||
default=styles[0]['id'],
|
||||
default=styles[0]["id"],
|
||||
options=[
|
||||
ToolParameterOption(
|
||||
value=style['id'],
|
||||
label=I18nObject(zh_Hans=style['name'], en_US=style['name'])
|
||||
) for style in styles
|
||||
]
|
||||
ToolParameterOption(value=style["id"], label=I18nObject(zh_Hans=style["name"], en_US=style["name"]))
|
||||
for style in styles
|
||||
],
|
||||
),
|
||||
]
|
||||
]
|
||||
|
||||
@ -13,7 +13,7 @@ class AlphaVantageProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"code": "AAPL", # Apple Inc.
|
||||
},
|
||||
|
||||
@ -9,17 +9,16 @@ ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query"
|
||||
|
||||
|
||||
class QueryStockTool(BuiltinTool):
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
stock_code = tool_parameters.get('code', '')
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
stock_code = tool_parameters.get("code", "")
|
||||
if not stock_code:
|
||||
return self.create_text_message('Please tell me your stock code')
|
||||
return self.create_text_message("Please tell me your stock code")
|
||||
|
||||
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
|
||||
if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
|
||||
return self.create_text_message("Alpha Vantage API key is required.")
|
||||
|
||||
params = {
|
||||
@ -27,7 +26,7 @@ class QueryStockTool(BuiltinTool):
|
||||
"symbol": stock_code,
|
||||
"outputsize": "compact",
|
||||
"datatype": "json",
|
||||
"apikey": self.runtime.credentials['api_key']
|
||||
"apikey": self.runtime.credentials["api_key"],
|
||||
}
|
||||
response = requests.get(url=ALPHAVANTAGE_API_URL, params=params)
|
||||
response.raise_for_status()
|
||||
@ -35,15 +34,15 @@ class QueryStockTool(BuiltinTool):
|
||||
return self.create_json_message(result)
|
||||
|
||||
def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]:
|
||||
result = response.get('Time Series (Daily)', {})
|
||||
result = response.get("Time Series (Daily)", {})
|
||||
if not result:
|
||||
return {}
|
||||
stock_result = {}
|
||||
for k, v in result.items():
|
||||
stock_result[k] = {}
|
||||
stock_result[k]['open'] = v.get('1. open')
|
||||
stock_result[k]['high'] = v.get('2. high')
|
||||
stock_result[k]['low'] = v.get('3. low')
|
||||
stock_result[k]['close'] = v.get('4. close')
|
||||
stock_result[k]['volume'] = v.get('5. volume')
|
||||
stock_result[k]["open"] = v.get("1. open")
|
||||
stock_result[k]["high"] = v.get("2. high")
|
||||
stock_result[k]["low"] = v.get("3. low")
|
||||
stock_result[k]["close"] = v.get("4. close")
|
||||
stock_result[k]["volume"] = v.get("5. volume")
|
||||
return stock_result
|
||||
|
||||
@ -11,11 +11,10 @@ class ArxivProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"query": "John Doe",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -8,6 +8,8 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ArxivAPIWrapper(BaseModel):
|
||||
"""Wrapper around ArxivAPI.
|
||||
|
||||
@ -86,11 +88,13 @@ class ArxivAPIWrapper(BaseModel):
|
||||
|
||||
class ArxivSearchInput(BaseModel):
|
||||
query: str = Field(..., description="Search query.")
|
||||
|
||||
|
||||
|
||||
class ArxivSearchTool(BuiltinTool):
|
||||
"""
|
||||
A tool for searching articles on Arxiv.
|
||||
"""
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invokes the Arxiv search tool with the given user ID and tool parameters.
|
||||
@ -100,15 +104,16 @@ class ArxivSearchTool(BuiltinTool):
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool, including the 'query' parameter.
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages.
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
|
||||
which can be a single message or a list of messages.
|
||||
"""
|
||||
query = tool_parameters.get('query', '')
|
||||
query = tool_parameters.get("query", "")
|
||||
|
||||
if not query:
|
||||
return self.create_text_message('Please input query')
|
||||
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
arxiv = ArxivAPIWrapper()
|
||||
|
||||
|
||||
response = arxiv.run(query)
|
||||
|
||||
|
||||
return self.create_text_message(self.summary(user_id=user_id, content=response))
|
||||
|
||||
@ -11,15 +11,14 @@ class SageMakerProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"sagemaker_endpoint" : "",
|
||||
"sagemaker_endpoint": "",
|
||||
"query": "misaka mikoto",
|
||||
"candidate_texts" : "hello$$$hello world",
|
||||
"topk" : 5,
|
||||
"aws_region" : ""
|
||||
"candidate_texts": "hello$$$hello world",
|
||||
"topk": 5,
|
||||
"aws_region": "",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -12,6 +12,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuardrailParameters(BaseModel):
|
||||
guardrail_id: str = Field(..., description="The identifier of the guardrail")
|
||||
guardrail_version: str = Field(..., description="The version of the guardrail")
|
||||
@ -19,35 +20,35 @@ class GuardrailParameters(BaseModel):
|
||||
text: str = Field(..., description="The text to apply the guardrail to")
|
||||
aws_region: str = Field(..., description="AWS region for the Bedrock client")
|
||||
|
||||
|
||||
class ApplyGuardrailTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke the ApplyGuardrail tool
|
||||
"""
|
||||
try:
|
||||
# Validate and parse input parameters
|
||||
params = GuardrailParameters(**tool_parameters)
|
||||
|
||||
|
||||
# Initialize AWS client
|
||||
bedrock_client = boto3.client('bedrock-runtime', region_name=params.aws_region)
|
||||
bedrock_client = boto3.client("bedrock-runtime", region_name=params.aws_region)
|
||||
|
||||
# Apply guardrail
|
||||
response = bedrock_client.apply_guardrail(
|
||||
guardrailIdentifier=params.guardrail_id,
|
||||
guardrailVersion=params.guardrail_version,
|
||||
source=params.source,
|
||||
content=[{"text": {"text": params.text}}]
|
||||
content=[{"text": {"text": params.text}}],
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}")
|
||||
|
||||
# Check for empty response
|
||||
if not response:
|
||||
return self.create_text_message(text="Received empty response from AWS Bedrock.")
|
||||
|
||||
|
||||
# Process the result
|
||||
action = response.get("action", "No action specified")
|
||||
outputs = response.get("outputs", [])
|
||||
@ -58,9 +59,12 @@ class ApplyGuardrailTool(BuiltinTool):
|
||||
formatted_assessments = []
|
||||
for assessment in assessments:
|
||||
for policy_type, policy_data in assessment.items():
|
||||
if isinstance(policy_data, dict) and 'topics' in policy_data:
|
||||
for topic in policy_data['topics']:
|
||||
formatted_assessments.append(f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}")
|
||||
if isinstance(policy_data, dict) and "topics" in policy_data:
|
||||
for topic in policy_data["topics"]:
|
||||
formatted_assessments.append(
|
||||
f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']},"
|
||||
f" Action: {topic['action']}"
|
||||
)
|
||||
else:
|
||||
formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}")
|
||||
|
||||
@ -68,19 +72,19 @@ class ApplyGuardrailTool(BuiltinTool):
|
||||
result += f"Output: {output}\n "
|
||||
if formatted_assessments:
|
||||
result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n "
|
||||
# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"
|
||||
# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}"
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except BotoCoreError as e:
|
||||
error_message = f'AWS service error: {str(e)}'
|
||||
error_message = f"AWS service error: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
except json.JSONDecodeError as e:
|
||||
error_message = f'JSON parsing error: {str(e)}'
|
||||
error_message = f"JSON parsing error: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
except Exception as e:
|
||||
error_message = f'An unexpected error occurred: {str(e)}'
|
||||
error_message = f"An unexpected error occurred: {str(e)}"
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
return self.create_text_message(text=error_message)
|
||||
|
||||
@ -11,78 +11,81 @@ class LambdaTranslateUtilsTool(BuiltinTool):
|
||||
lambda_client: Any = None
|
||||
|
||||
def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name):
|
||||
msg = {
|
||||
"src_content":text_content,
|
||||
"src_lang": src_lang,
|
||||
"dest_lang":dest_lang,
|
||||
msg = {
|
||||
"src_content": text_content,
|
||||
"src_lang": src_lang,
|
||||
"dest_lang": dest_lang,
|
||||
"dictionary_id": dictionary_name,
|
||||
"request_type" : request_type,
|
||||
"model_id" : model_id
|
||||
"request_type": request_type,
|
||||
"model_id": model_id,
|
||||
}
|
||||
|
||||
invoke_response = self.lambda_client.invoke(FunctionName=lambda_name,
|
||||
InvocationType='RequestResponse',
|
||||
Payload=json.dumps(msg))
|
||||
response_body = invoke_response['Payload']
|
||||
invoke_response = self.lambda_client.invoke(
|
||||
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
|
||||
)
|
||||
response_body = invoke_response["Payload"]
|
||||
|
||||
response_str = response_body.read().decode("unicode_escape")
|
||||
|
||||
return response_str
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
if not self.lambda_client:
|
||||
aws_region = tool_parameters.get('aws_region')
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.lambda_client = boto3.client("lambda", region_name=aws_region)
|
||||
else:
|
||||
self.lambda_client = boto3.client("lambda")
|
||||
|
||||
line = 1
|
||||
text_content = tool_parameters.get('text_content', '')
|
||||
text_content = tool_parameters.get("text_content", "")
|
||||
if not text_content:
|
||||
return self.create_text_message('Please input text_content')
|
||||
|
||||
return self.create_text_message("Please input text_content")
|
||||
|
||||
line = 2
|
||||
src_lang = tool_parameters.get('src_lang', '')
|
||||
src_lang = tool_parameters.get("src_lang", "")
|
||||
if not src_lang:
|
||||
return self.create_text_message('Please input src_lang')
|
||||
|
||||
return self.create_text_message("Please input src_lang")
|
||||
|
||||
line = 3
|
||||
dest_lang = tool_parameters.get('dest_lang', '')
|
||||
dest_lang = tool_parameters.get("dest_lang", "")
|
||||
if not dest_lang:
|
||||
return self.create_text_message('Please input dest_lang')
|
||||
|
||||
return self.create_text_message("Please input dest_lang")
|
||||
|
||||
line = 4
|
||||
lambda_name = tool_parameters.get('lambda_name', '')
|
||||
lambda_name = tool_parameters.get("lambda_name", "")
|
||||
if not lambda_name:
|
||||
return self.create_text_message('Please input lambda_name')
|
||||
|
||||
return self.create_text_message("Please input lambda_name")
|
||||
|
||||
line = 5
|
||||
request_type = tool_parameters.get('request_type', '')
|
||||
request_type = tool_parameters.get("request_type", "")
|
||||
if not request_type:
|
||||
return self.create_text_message('Please input request_type')
|
||||
|
||||
return self.create_text_message("Please input request_type")
|
||||
|
||||
line = 6
|
||||
model_id = tool_parameters.get('model_id', '')
|
||||
model_id = tool_parameters.get("model_id", "")
|
||||
if not model_id:
|
||||
return self.create_text_message('Please input model_id')
|
||||
return self.create_text_message("Please input model_id")
|
||||
|
||||
line = 7
|
||||
dictionary_name = tool_parameters.get('dictionary_name', '')
|
||||
dictionary_name = tool_parameters.get("dictionary_name", "")
|
||||
if not dictionary_name:
|
||||
return self.create_text_message('Please input dictionary_name')
|
||||
|
||||
result = self._invoke_lambda(text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name)
|
||||
return self.create_text_message("Please input dictionary_name")
|
||||
|
||||
result = self._invoke_lambda(
|
||||
text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name
|
||||
)
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f'Exception {str(e)}, line : {line}')
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
|
||||
@ -10,7 +10,7 @@ description:
|
||||
human:
|
||||
en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag
|
||||
zh_Hans: 大语言模型翻译工具(专词映射获取),需要在AWS上进行额外部署,可参考Github Repo - https://github.com/ybalbert001/dynamodb-rag
|
||||
pt_BR: A util tools for LLM translation, specfic Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag
|
||||
pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag
|
||||
llm: A util tools for translation.
|
||||
parameters:
|
||||
- name: text_content
|
||||
|
||||
@ -18,54 +18,53 @@ class LambdaYamlToJsonTool(BuiltinTool):
|
||||
lambda_client: Any = None
|
||||
|
||||
def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
|
||||
msg = {
|
||||
"body": yaml_content
|
||||
}
|
||||
msg = {"body": yaml_content}
|
||||
logger.info(json.dumps(msg))
|
||||
|
||||
invoke_response = self.lambda_client.invoke(FunctionName=lambda_name,
|
||||
InvocationType='RequestResponse',
|
||||
Payload=json.dumps(msg))
|
||||
response_body = invoke_response['Payload']
|
||||
invoke_response = self.lambda_client.invoke(
|
||||
FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg)
|
||||
)
|
||||
response_body = invoke_response["Payload"]
|
||||
|
||||
response_str = response_body.read().decode("utf-8")
|
||||
resp_json = json.loads(response_str)
|
||||
|
||||
logger.info(resp_json)
|
||||
if resp_json['statusCode'] != 200:
|
||||
if resp_json["statusCode"] != 200:
|
||||
raise Exception(f"Invalid status code: {response_str}")
|
||||
|
||||
return resp_json['body']
|
||||
return resp_json["body"]
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
if not self.lambda_client:
|
||||
aws_region = tool_parameters.get('aws_region') # todo: move aws_region out, and update client region
|
||||
aws_region = tool_parameters.get("aws_region") # todo: move aws_region out, and update client region
|
||||
if aws_region:
|
||||
self.lambda_client = boto3.client("lambda", region_name=aws_region)
|
||||
else:
|
||||
self.lambda_client = boto3.client("lambda")
|
||||
|
||||
yaml_content = tool_parameters.get('yaml_content', '')
|
||||
yaml_content = tool_parameters.get("yaml_content", "")
|
||||
if not yaml_content:
|
||||
return self.create_text_message('Please input yaml_content')
|
||||
return self.create_text_message("Please input yaml_content")
|
||||
|
||||
lambda_name = tool_parameters.get('lambda_name', '')
|
||||
lambda_name = tool_parameters.get("lambda_name", "")
|
||||
if not lambda_name:
|
||||
return self.create_text_message('Please input lambda_name')
|
||||
logger.debug(f'{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}')
|
||||
|
||||
return self.create_text_message("Please input lambda_name")
|
||||
logger.debug(f"{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}")
|
||||
|
||||
result = self._invoke_lambda(lambda_name, yaml_content)
|
||||
logger.debug(result)
|
||||
|
||||
|
||||
return self.create_text_message(result)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f'Exception: {str(e)}')
|
||||
return self.create_text_message(f"Exception: {str(e)}")
|
||||
|
||||
console_handler.flush()
|
||||
console_handler.flush()
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import operator
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
@ -9,37 +10,33 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
class SageMakerReRankTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint:str = None
|
||||
topk:int = None
|
||||
sagemaker_endpoint: str = None
|
||||
topk: int = None
|
||||
|
||||
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str):
|
||||
inputs = [query_input]*len(docs)
|
||||
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
|
||||
inputs = [query_input] * len(docs)
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=rerank_endpoint,
|
||||
Body=json.dumps(
|
||||
{
|
||||
"inputs": inputs,
|
||||
"docs": docs
|
||||
}
|
||||
),
|
||||
Body=json.dumps({"inputs": inputs, "docs": docs}),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model['Body'].read().decode('utf8')
|
||||
json_str = response_model["Body"].read().decode("utf8")
|
||||
json_obj = json.loads(json_str)
|
||||
scores = json_obj['scores']
|
||||
scores = json_obj["scores"]
|
||||
return scores if isinstance(scores, list) else [scores]
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
if not self.sagemaker_client:
|
||||
aws_region = tool_parameters.get('aws_region')
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
else:
|
||||
@ -47,25 +44,25 @@ class SageMakerReRankTool(BuiltinTool):
|
||||
|
||||
line = 1
|
||||
if not self.sagemaker_endpoint:
|
||||
self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')
|
||||
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
|
||||
|
||||
line = 2
|
||||
if not self.topk:
|
||||
self.topk = tool_parameters.get('topk', 5)
|
||||
self.topk = tool_parameters.get("topk", 5)
|
||||
|
||||
line = 3
|
||||
query = tool_parameters.get('query', '')
|
||||
query = tool_parameters.get("query", "")
|
||||
if not query:
|
||||
return self.create_text_message('Please input query')
|
||||
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
line = 4
|
||||
candidate_texts = tool_parameters.get('candidate_texts')
|
||||
candidate_texts = tool_parameters.get("candidate_texts")
|
||||
if not candidate_texts:
|
||||
return self.create_text_message('Please input candidate_texts')
|
||||
|
||||
return self.create_text_message("Please input candidate_texts")
|
||||
|
||||
line = 5
|
||||
candidate_docs = json.loads(candidate_texts)
|
||||
docs = [ item.get('content') for item in candidate_docs ]
|
||||
docs = [item.get("content") for item in candidate_docs]
|
||||
|
||||
line = 6
|
||||
scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint)
|
||||
@ -75,10 +72,10 @@ class SageMakerReRankTool(BuiltinTool):
|
||||
candidate_docs[idx]["score"] = scores[idx]
|
||||
|
||||
line = 8
|
||||
sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
|
||||
sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
|
||||
|
||||
line = 9
|
||||
return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ]
|
||||
|
||||
return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f'Exception {str(e)}, line : {line}')
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
|
||||
@ -14,82 +14,88 @@ class TTSModelType(Enum):
|
||||
CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
|
||||
InstructVoice = "InstructVoice"
|
||||
|
||||
|
||||
class SageMakerTTSTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint:str = None
|
||||
s3_client : Any = None
|
||||
comprehend_client : Any = None
|
||||
sagemaker_endpoint: str = None
|
||||
s3_client: Any = None
|
||||
comprehend_client: Any = None
|
||||
|
||||
def _detect_lang_code(self, content:str, map_dict:dict=None):
|
||||
map_dict = {
|
||||
"zh" : "<|zh|>",
|
||||
"en" : "<|en|>",
|
||||
"ja" : "<|jp|>",
|
||||
"zh-TW" : "<|yue|>",
|
||||
"ko" : "<|ko|>"
|
||||
}
|
||||
def _detect_lang_code(self, content: str, map_dict: dict = None):
|
||||
map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"}
|
||||
|
||||
response = self.comprehend_client.detect_dominant_language(Text=content)
|
||||
language_code = response['Languages'][0]['LanguageCode']
|
||||
return map_dict.get(language_code, '<|zh|>')
|
||||
language_code = response["Languages"][0]["LanguageCode"]
|
||||
return map_dict.get(language_code, "<|zh|>")
|
||||
|
||||
def _build_tts_payload(self, model_type:str, content_text:str, model_role:str, prompt_text:str, prompt_audio:str, instruct_text:str):
|
||||
def _build_tts_payload(
|
||||
self,
|
||||
model_type: str,
|
||||
content_text: str,
|
||||
model_role: str,
|
||||
prompt_text: str,
|
||||
prompt_audio: str,
|
||||
instruct_text: str,
|
||||
):
|
||||
if model_type == TTSModelType.PresetVoice.value and model_role:
|
||||
return { "tts_text" : content_text, "role" : model_role }
|
||||
return {"tts_text": content_text, "role": model_role}
|
||||
if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
|
||||
return { "tts_text" : content_text, "prompt_text": prompt_text, "prompt_audio" : prompt_audio }
|
||||
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
|
||||
return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
|
||||
if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
|
||||
lang_tag = self._detect_lang_code(content_text)
|
||||
return { "tts_text" : f"{content_text}", "prompt_audio" : prompt_audio, "lang_tag" : lang_tag }
|
||||
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
|
||||
return { "tts_text" : content_text, "role" : model_role, "instruct_text" : instruct_text }
|
||||
return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag}
|
||||
if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
|
||||
return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}
|
||||
|
||||
raise RuntimeError(f"Invalid params for {model_type}")
|
||||
|
||||
def _invoke_sagemaker(self, payload:dict, endpoint:str):
|
||||
def _invoke_sagemaker(self, payload: dict, endpoint: str):
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=endpoint,
|
||||
Body=json.dumps(payload),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model['Body'].read().decode('utf8')
|
||||
json_str = response_model["Body"].read().decode("utf8")
|
||||
json_obj = json.loads(json_str)
|
||||
return json_obj
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
try:
|
||||
if not self.sagemaker_client:
|
||||
aws_region = tool_parameters.get('aws_region')
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
self.s3_client = boto3.client("s3", region_name=aws_region)
|
||||
self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
|
||||
self.comprehend_client = boto3.client("comprehend", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
self.s3_client = boto3.client("s3")
|
||||
self.comprehend_client = boto3.client('comprehend')
|
||||
self.comprehend_client = boto3.client("comprehend")
|
||||
|
||||
if not self.sagemaker_endpoint:
|
||||
self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')
|
||||
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
|
||||
|
||||
tts_text = tool_parameters.get('tts_text')
|
||||
tts_infer_type = tool_parameters.get('tts_infer_type')
|
||||
tts_text = tool_parameters.get("tts_text")
|
||||
tts_infer_type = tool_parameters.get("tts_infer_type")
|
||||
|
||||
voice = tool_parameters.get('voice')
|
||||
mock_voice_audio = tool_parameters.get('mock_voice_audio')
|
||||
mock_voice_text = tool_parameters.get('mock_voice_text')
|
||||
voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt')
|
||||
payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt)
|
||||
voice = tool_parameters.get("voice")
|
||||
mock_voice_audio = tool_parameters.get("mock_voice_audio")
|
||||
mock_voice_text = tool_parameters.get("mock_voice_text")
|
||||
voice_instruct_prompt = tool_parameters.get("voice_instruct_prompt")
|
||||
payload = self._build_tts_payload(
|
||||
tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt
|
||||
)
|
||||
|
||||
result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)
|
||||
|
||||
return self.create_text_message(text=result['s3_presign_url'])
|
||||
|
||||
return self.create_text_message(text=result["s3_presign_url"])
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f'Exception {str(e)}')
|
||||
return self.create_text_message(f"Exception {str(e)}")
|
||||
|
||||
@ -13,12 +13,8 @@ class AzureDALLEProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"prompt": "cute girl, blue eyes, white hair, anime style",
|
||||
"size": "square",
|
||||
"n": 1
|
||||
},
|
||||
user_id="",
|
||||
tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "square", "n": 1},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -9,47 +9,48 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DallE3Tool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
client = AzureOpenAI(
|
||||
api_version=self.runtime.credentials['azure_openai_api_version'],
|
||||
azure_endpoint=self.runtime.credentials['azure_openai_base_url'],
|
||||
api_key=self.runtime.credentials['azure_openai_api_key'],
|
||||
api_version=self.runtime.credentials["azure_openai_api_version"],
|
||||
azure_endpoint=self.runtime.credentials["azure_openai_base_url"],
|
||||
api_key=self.runtime.credentials["azure_openai_api_key"],
|
||||
)
|
||||
|
||||
SIZE_MAPPING = {
|
||||
'square': '1024x1024',
|
||||
'vertical': '1024x1792',
|
||||
'horizontal': '1792x1024',
|
||||
"square": "1024x1024",
|
||||
"vertical": "1024x1792",
|
||||
"horizontal": "1792x1024",
|
||||
}
|
||||
|
||||
# prompt
|
||||
prompt = tool_parameters.get('prompt', '')
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
return self.create_text_message("Please input prompt")
|
||||
# get size
|
||||
size = SIZE_MAPPING[tool_parameters.get('size', 'square')]
|
||||
size = SIZE_MAPPING[tool_parameters.get("size", "square")]
|
||||
# get n
|
||||
n = tool_parameters.get('n', 1)
|
||||
n = tool_parameters.get("n", 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get('quality', 'standard')
|
||||
if quality not in ['standard', 'hd']:
|
||||
return self.create_text_message('Invalid quality')
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
if quality not in {"standard", "hd"}:
|
||||
return self.create_text_message("Invalid quality")
|
||||
# get style
|
||||
style = tool_parameters.get('style', 'vivid')
|
||||
if style not in ['natural', 'vivid']:
|
||||
return self.create_text_message('Invalid style')
|
||||
style = tool_parameters.get("style", "vivid")
|
||||
if style not in {"natural", "vivid"}:
|
||||
return self.create_text_message("Invalid style")
|
||||
# set extra body
|
||||
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
|
||||
extra_body = {'seed': seed_id}
|
||||
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
||||
extra_body = {"seed": seed_id}
|
||||
|
||||
# call openapi dalle3
|
||||
model = self.runtime.credentials['azure_openai_api_model_name']
|
||||
model = self.runtime.credentials["azure_openai_api_model_name"]
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
@ -58,21 +59,25 @@ class DallE3Tool(BuiltinTool):
|
||||
extra_body=extra_body,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format='b64_json'
|
||||
response_format="b64_json",
|
||||
)
|
||||
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={'mime_type': 'image/png'},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}'))
|
||||
result.append(
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image.b64_json),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}"))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
||||
random_id = ''.join(random.choices(characters, k=length))
|
||||
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
random_id = "".join(random.choices(characters, k=length))
|
||||
return random_id
|
||||
|
||||
@ -8,142 +8,135 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class BingSearchTool(BuiltinTool):
|
||||
url: str = 'https://api.bing.microsoft.com/v7.0/search'
|
||||
url: str = "https://api.bing.microsoft.com/v7.0/search"
|
||||
|
||||
def _invoke_bing(self,
|
||||
user_id: str,
|
||||
server_url: str,
|
||||
subscription_key: str, query: str, limit: int,
|
||||
result_type: str, market: str, lang: str,
|
||||
filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke_bing(
|
||||
self,
|
||||
user_id: str,
|
||||
server_url: str,
|
||||
subscription_key: str,
|
||||
query: str,
|
||||
limit: int,
|
||||
result_type: str,
|
||||
market: str,
|
||||
lang: str,
|
||||
filters: list[str],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke bing search
|
||||
invoke bing search
|
||||
"""
|
||||
market_code = f'{lang}-{market}'
|
||||
accept_language = f'{lang},{market_code};q=0.9'
|
||||
headers = {
|
||||
'Ocp-Apim-Subscription-Key': subscription_key,
|
||||
'Accept-Language': accept_language
|
||||
}
|
||||
market_code = f"{lang}-{market}"
|
||||
accept_language = f"{lang},{market_code};q=0.9"
|
||||
headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language}
|
||||
|
||||
query = quote(query)
|
||||
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
|
||||
response = get(server_url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Error {response.status_code}: {response.text}')
|
||||
|
||||
response = response.json()
|
||||
search_results = response['webPages']['value'][:limit] if 'webPages' in response else []
|
||||
related_searches = response['relatedSearches']['value'] if 'relatedSearches' in response else []
|
||||
entities = response['entities']['value'] if 'entities' in response else []
|
||||
news = response['news']['value'] if 'news' in response else []
|
||||
computation = response['computation']['value'] if 'computation' in response else None
|
||||
raise Exception(f"Error {response.status_code}: {response.text}")
|
||||
|
||||
if result_type == 'link':
|
||||
response = response.json()
|
||||
search_results = response["webPages"]["value"][:limit] if "webPages" in response else []
|
||||
related_searches = response["relatedSearches"]["value"] if "relatedSearches" in response else []
|
||||
entities = response["entities"]["value"] if "entities" in response else []
|
||||
news = response["news"]["value"] if "news" in response else []
|
||||
computation = response["computation"]["value"] if "computation" in response else None
|
||||
|
||||
if result_type == "link":
|
||||
results = []
|
||||
if search_results:
|
||||
for result in search_results:
|
||||
url = f': {result["url"]}' if "url" in result else ""
|
||||
results.append(self.create_text_message(
|
||||
text=f'{result["name"]}{url}'
|
||||
))
|
||||
|
||||
results.append(self.create_text_message(text=f'{result["name"]}{url}'))
|
||||
|
||||
if entities:
|
||||
for entity in entities:
|
||||
url = f': {entity["url"]}' if "url" in entity else ""
|
||||
results.append(self.create_text_message(
|
||||
text=f'{entity.get("name", "")}{url}'
|
||||
))
|
||||
results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}'))
|
||||
|
||||
if news:
|
||||
for news_item in news:
|
||||
url = f': {news_item["url"]}' if "url" in news_item else ""
|
||||
results.append(self.create_text_message(
|
||||
text=f'{news_item.get("name", "")}{url}'
|
||||
))
|
||||
results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}'))
|
||||
|
||||
if related_searches:
|
||||
for related in related_searches:
|
||||
url = f': {related["displayText"]}' if "displayText" in related else ""
|
||||
results.append(self.create_text_message(
|
||||
text=f'{related.get("displayText", "")}{url}'
|
||||
))
|
||||
|
||||
results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}'))
|
||||
|
||||
return results
|
||||
else:
|
||||
# construct text
|
||||
text = ''
|
||||
text = ""
|
||||
if search_results:
|
||||
for i, result in enumerate(search_results):
|
||||
text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
|
||||
text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
|
||||
|
||||
if computation and 'expression' in computation and 'value' in computation:
|
||||
text += '\nComputation:\n'
|
||||
if computation and "expression" in computation and "value" in computation:
|
||||
text += "\nComputation:\n"
|
||||
text += f'{computation["expression"]} = {computation["value"]}\n'
|
||||
|
||||
if entities:
|
||||
text += '\nEntities:\n'
|
||||
text += "\nEntities:\n"
|
||||
for entity in entities:
|
||||
url = f'- {entity["url"]}' if "url" in entity else ""
|
||||
text += f'{entity.get("name", "")}{url}\n'
|
||||
|
||||
if news:
|
||||
text += '\nNews:\n'
|
||||
text += "\nNews:\n"
|
||||
for news_item in news:
|
||||
url = f'- {news_item["url"]}' if "url" in news_item else ""
|
||||
text += f'{news_item.get("name", "")}{url}\n'
|
||||
|
||||
if related_searches:
|
||||
text += '\n\nRelated Searches:\n'
|
||||
text += "\n\nRelated Searches:\n"
|
||||
for related in related_searches:
|
||||
url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else ""
|
||||
text += f'{related.get("displayText", "")}{url}\n'
|
||||
|
||||
return self.create_text_message(text=self.summary(user_id=user_id, content=text))
|
||||
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None:
|
||||
key = credentials.get('subscription_key')
|
||||
key = credentials.get("subscription_key")
|
||||
if not key:
|
||||
raise Exception('subscription_key is required')
|
||||
|
||||
server_url = credentials.get('server_url')
|
||||
raise Exception("subscription_key is required")
|
||||
|
||||
server_url = credentials.get("server_url")
|
||||
if not server_url:
|
||||
server_url = self.url
|
||||
|
||||
query = tool_parameters.get('query')
|
||||
query = tool_parameters.get("query")
|
||||
if not query:
|
||||
raise Exception('query is required')
|
||||
|
||||
limit = min(tool_parameters.get('limit', 5), 10)
|
||||
result_type = tool_parameters.get('result_type', 'text') or 'text'
|
||||
raise Exception("query is required")
|
||||
|
||||
market = tool_parameters.get('market', 'US')
|
||||
lang = tool_parameters.get('language', 'en')
|
||||
limit = min(tool_parameters.get("limit", 5), 10)
|
||||
result_type = tool_parameters.get("result_type", "text") or "text"
|
||||
|
||||
market = tool_parameters.get("market", "US")
|
||||
lang = tool_parameters.get("language", "en")
|
||||
filter = []
|
||||
|
||||
if credentials.get('allow_entities', False):
|
||||
filter.append('Entities')
|
||||
if credentials.get("allow_entities", False):
|
||||
filter.append("Entities")
|
||||
|
||||
if credentials.get('allow_computation', False):
|
||||
filter.append('Computation')
|
||||
if credentials.get("allow_computation", False):
|
||||
filter.append("Computation")
|
||||
|
||||
if credentials.get('allow_news', False):
|
||||
filter.append('News')
|
||||
if credentials.get("allow_news", False):
|
||||
filter.append("News")
|
||||
|
||||
if credentials.get('allow_related_searches', False):
|
||||
filter.append('RelatedSearches')
|
||||
if credentials.get("allow_related_searches", False):
|
||||
filter.append("RelatedSearches")
|
||||
|
||||
if credentials.get('allow_web_pages', False):
|
||||
filter.append('WebPages')
|
||||
if credentials.get("allow_web_pages", False):
|
||||
filter.append("WebPages")
|
||||
|
||||
if not filter:
|
||||
raise Exception('At least one filter is required')
|
||||
|
||||
raise Exception("At least one filter is required")
|
||||
|
||||
self._invoke_bing(
|
||||
user_id='test',
|
||||
user_id="test",
|
||||
server_url=server_url,
|
||||
subscription_key=key,
|
||||
query=query,
|
||||
@ -151,50 +144,51 @@ class BingSearchTool(BuiltinTool):
|
||||
result_type=result_type,
|
||||
market=market,
|
||||
lang=lang,
|
||||
filters=filter
|
||||
filters=filter,
|
||||
)
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
|
||||
key = self.runtime.credentials.get('subscription_key', None)
|
||||
key = self.runtime.credentials.get("subscription_key", None)
|
||||
if not key:
|
||||
raise Exception('subscription_key is required')
|
||||
|
||||
server_url = self.runtime.credentials.get('server_url', None)
|
||||
raise Exception("subscription_key is required")
|
||||
|
||||
server_url = self.runtime.credentials.get("server_url", None)
|
||||
if not server_url:
|
||||
server_url = self.url
|
||||
|
||||
query = tool_parameters.get('query')
|
||||
|
||||
query = tool_parameters.get("query")
|
||||
if not query:
|
||||
raise Exception('query is required')
|
||||
|
||||
limit = min(tool_parameters.get('limit', 5), 10)
|
||||
result_type = tool_parameters.get('result_type', 'text') or 'text'
|
||||
|
||||
market = tool_parameters.get('market', 'US')
|
||||
lang = tool_parameters.get('language', 'en')
|
||||
raise Exception("query is required")
|
||||
|
||||
limit = min(tool_parameters.get("limit", 5), 10)
|
||||
result_type = tool_parameters.get("result_type", "text") or "text"
|
||||
|
||||
market = tool_parameters.get("market", "US")
|
||||
lang = tool_parameters.get("language", "en")
|
||||
filter = []
|
||||
|
||||
if tool_parameters.get('enable_computation', False):
|
||||
filter.append('Computation')
|
||||
if tool_parameters.get('enable_entities', False):
|
||||
filter.append('Entities')
|
||||
if tool_parameters.get('enable_news', False):
|
||||
filter.append('News')
|
||||
if tool_parameters.get('enable_related_search', False):
|
||||
filter.append('RelatedSearches')
|
||||
if tool_parameters.get('enable_webpages', False):
|
||||
filter.append('WebPages')
|
||||
if tool_parameters.get("enable_computation", False):
|
||||
filter.append("Computation")
|
||||
if tool_parameters.get("enable_entities", False):
|
||||
filter.append("Entities")
|
||||
if tool_parameters.get("enable_news", False):
|
||||
filter.append("News")
|
||||
if tool_parameters.get("enable_related_search", False):
|
||||
filter.append("RelatedSearches")
|
||||
if tool_parameters.get("enable_webpages", False):
|
||||
filter.append("WebPages")
|
||||
|
||||
if not filter:
|
||||
raise Exception('At least one filter is required')
|
||||
|
||||
raise Exception("At least one filter is required")
|
||||
|
||||
return self._invoke_bing(
|
||||
user_id=user_id,
|
||||
server_url=server_url,
|
||||
@ -204,5 +198,5 @@ class BingSearchTool(BuiltinTool):
|
||||
result_type=result_type,
|
||||
market=market,
|
||||
lang=lang,
|
||||
filters=filter
|
||||
)
|
||||
filters=filter,
|
||||
)
|
||||
|
||||
@ -13,11 +13,10 @@ class BraveProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"query": "Sachin Tendulkar",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -37,7 +37,7 @@ class BraveSearchWrapper(BaseModel):
|
||||
for item in web_search_results
|
||||
]
|
||||
return json.dumps(final_results)
|
||||
|
||||
|
||||
def _search_request(self, query: str) -> list[dict]:
|
||||
headers = {
|
||||
"X-Subscription-Token": self.api_key,
|
||||
@ -55,6 +55,7 @@ class BraveSearchWrapper(BaseModel):
|
||||
|
||||
return response.json().get("web", {}).get("results", [])
|
||||
|
||||
|
||||
class BraveSearch(BaseModel):
|
||||
"""Tool that queries the BraveSearch."""
|
||||
|
||||
@ -67,9 +68,7 @@ class BraveSearch(BaseModel):
|
||||
search_wrapper: BraveSearchWrapper
|
||||
|
||||
@classmethod
|
||||
def from_api_key(
|
||||
cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any
|
||||
) -> "BraveSearch":
|
||||
def from_api_key(cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any) -> "BraveSearch":
|
||||
"""Create a tool from an api key.
|
||||
|
||||
Args:
|
||||
@ -90,6 +89,7 @@ class BraveSearch(BaseModel):
|
||||
"""Use the tool."""
|
||||
return self.search_wrapper.run(query)
|
||||
|
||||
|
||||
class BraveSearchTool(BuiltinTool):
|
||||
"""
|
||||
Tool for performing a search using Brave search engine.
|
||||
@ -106,12 +106,12 @@ class BraveSearchTool(BuiltinTool):
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation.
|
||||
"""
|
||||
query = tool_parameters.get('query', '')
|
||||
count = tool_parameters.get('count', 3)
|
||||
api_key = self.runtime.credentials['brave_search_api_key']
|
||||
query = tool_parameters.get("query", "")
|
||||
count = tool_parameters.get("count", 3)
|
||||
api_key = self.runtime.credentials["brave_search_api_key"]
|
||||
|
||||
if not query:
|
||||
return self.create_text_message('Please input query')
|
||||
return self.create_text_message("Please input query")
|
||||
|
||||
tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count})
|
||||
|
||||
@ -121,4 +121,3 @@ class BraveSearchTool(BuiltinTool):
|
||||
return self.create_text_message(f"No results found for '{query}' in Tavily")
|
||||
else:
|
||||
return self.create_text_message(text=results)
|
||||
|
||||
|
||||
@ -7,16 +7,34 @@ from core.tools.provider.builtin.chart.tools.line import LinearChartTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
# use a business theme
|
||||
plt.style.use('seaborn-v0_8-darkgrid')
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
plt.style.use("seaborn-v0_8-darkgrid")
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
|
||||
def init_fonts():
|
||||
fonts = findSystemFonts()
|
||||
|
||||
popular_unicode_fonts = [
|
||||
'Arial Unicode MS', 'DejaVu Sans', 'DejaVu Sans Mono', 'DejaVu Serif', 'FreeMono', 'FreeSans', 'FreeSerif',
|
||||
'Liberation Mono', 'Liberation Sans', 'Liberation Serif', 'Noto Mono', 'Noto Sans', 'Noto Serif', 'Open Sans',
|
||||
'Roboto', 'Source Code Pro', 'Source Sans Pro', 'Source Serif Pro', 'Ubuntu', 'Ubuntu Mono'
|
||||
"Arial Unicode MS",
|
||||
"DejaVu Sans",
|
||||
"DejaVu Sans Mono",
|
||||
"DejaVu Serif",
|
||||
"FreeMono",
|
||||
"FreeSans",
|
||||
"FreeSerif",
|
||||
"Liberation Mono",
|
||||
"Liberation Sans",
|
||||
"Liberation Serif",
|
||||
"Noto Mono",
|
||||
"Noto Sans",
|
||||
"Noto Serif",
|
||||
"Open Sans",
|
||||
"Roboto",
|
||||
"Source Code Pro",
|
||||
"Source Sans Pro",
|
||||
"Source Serif Pro",
|
||||
"Ubuntu",
|
||||
"Ubuntu Mono",
|
||||
]
|
||||
|
||||
supported_fonts = []
|
||||
@ -25,21 +43,23 @@ def init_fonts():
|
||||
try:
|
||||
font = TTFont(font_path)
|
||||
# get family name
|
||||
family_name = font['name'].getName(1, 3, 1).toUnicode()
|
||||
family_name = font["name"].getName(1, 3, 1).toUnicode()
|
||||
if family_name in popular_unicode_fonts:
|
||||
supported_fonts.append(family_name)
|
||||
except:
|
||||
pass
|
||||
|
||||
plt.rcParams['font.family'] = 'sans-serif'
|
||||
plt.rcParams["font.family"] = "sans-serif"
|
||||
# sort by order of popular_unicode_fonts
|
||||
for font in popular_unicode_fonts:
|
||||
if font in supported_fonts:
|
||||
plt.rcParams['font.sans-serif'] = font
|
||||
plt.rcParams["font.sans-serif"] = font
|
||||
break
|
||||
|
||||
|
||||
|
||||
init_fonts()
|
||||
|
||||
|
||||
class ChartProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
@ -48,11 +68,10 @@ class ChartProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"data": "1,3,5,7,9,2,4,6,8,10",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -8,12 +8,13 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class BarChartTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
|
||||
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
data = tool_parameters.get('data', '')
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
data = tool_parameters.get("data", "")
|
||||
if not data:
|
||||
return self.create_text_message('Please input data')
|
||||
data = data.split(';')
|
||||
return self.create_text_message("Please input data")
|
||||
data = data.split(";")
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all(i.isdigit() for i in data):
|
||||
@ -21,29 +22,27 @@ class BarChartTool(BuiltinTool):
|
||||
else:
|
||||
data = [float(i) for i in data]
|
||||
|
||||
axis = tool_parameters.get('x_axis') or None
|
||||
axis = tool_parameters.get("x_axis") or None
|
||||
if axis:
|
||||
axis = axis.split(';')
|
||||
axis = axis.split(";")
|
||||
if len(axis) != len(data):
|
||||
axis = None
|
||||
|
||||
flg, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
if axis:
|
||||
axis = [label[:10] + '...' if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha='right')
|
||||
axis = [label[:10] + "..." if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha="right")
|
||||
ax.bar(axis, data)
|
||||
else:
|
||||
ax.bar(range(len(data)), data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format='png')
|
||||
flg.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message('the bar chart is saved as an image.'),
|
||||
self.create_blob_message(blob=buf.read(),
|
||||
meta={'mime_type': 'image/png'})
|
||||
self.create_text_message("the bar chart is saved as an image."),
|
||||
self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
|
||||
]
|
||||
|
||||
@ -8,18 +8,19 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class LinearChartTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
data = tool_parameters.get('data', '')
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
data = tool_parameters.get("data", "")
|
||||
if not data:
|
||||
return self.create_text_message('Please input data')
|
||||
data = data.split(';')
|
||||
return self.create_text_message("Please input data")
|
||||
data = data.split(";")
|
||||
|
||||
axis = tool_parameters.get('x_axis') or None
|
||||
axis = tool_parameters.get("x_axis") or None
|
||||
if axis:
|
||||
axis = axis.split(';')
|
||||
axis = axis.split(";")
|
||||
if len(axis) != len(data):
|
||||
axis = None
|
||||
|
||||
@ -32,20 +33,18 @@ class LinearChartTool(BuiltinTool):
|
||||
flg, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
if axis:
|
||||
axis = [label[:10] + '...' if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha='right')
|
||||
axis = [label[:10] + "..." if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha="right")
|
||||
ax.plot(axis, data)
|
||||
else:
|
||||
ax.plot(data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format='png')
|
||||
flg.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message('the linear chart is saved as an image.'),
|
||||
self.create_blob_message(blob=buf.read(),
|
||||
meta={'mime_type': 'image/png'})
|
||||
self.create_text_message("the linear chart is saved as an image."),
|
||||
self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
|
||||
]
|
||||
|
||||
@ -8,15 +8,16 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class PieChartTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
data = tool_parameters.get('data', '')
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
data = tool_parameters.get("data", "")
|
||||
if not data:
|
||||
return self.create_text_message('Please input data')
|
||||
data = data.split(';')
|
||||
categories = tool_parameters.get('categories') or None
|
||||
return self.create_text_message("Please input data")
|
||||
data = data.split(";")
|
||||
categories = tool_parameters.get("categories") or None
|
||||
|
||||
# if all data is int, convert to int
|
||||
if all(i.isdigit() for i in data):
|
||||
@ -27,7 +28,7 @@ class PieChartTool(BuiltinTool):
|
||||
flg, ax = plt.subplots()
|
||||
|
||||
if categories:
|
||||
categories = categories.split(';')
|
||||
categories = categories.split(";")
|
||||
if len(categories) != len(data):
|
||||
categories = None
|
||||
|
||||
@ -37,12 +38,11 @@ class PieChartTool(BuiltinTool):
|
||||
ax.pie(data)
|
||||
|
||||
buf = io.BytesIO()
|
||||
flg.savefig(buf, format='png')
|
||||
flg.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
plt.close(flg)
|
||||
|
||||
return [
|
||||
self.create_text_message('the pie chart is saved as an image.'),
|
||||
self.create_blob_message(blob=buf.read(),
|
||||
meta={'mime_type': 'image/png'})
|
||||
]
|
||||
self.create_text_message("the pie chart is saved as an image."),
|
||||
self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}),
|
||||
]
|
||||
|
||||
@ -8,15 +8,15 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
class SimpleCode(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
invoke simple code
|
||||
invoke simple code
|
||||
"""
|
||||
|
||||
language = tool_parameters.get('language', CodeLanguage.PYTHON3)
|
||||
code = tool_parameters.get('code', '')
|
||||
language = tool_parameters.get("language", CodeLanguage.PYTHON3)
|
||||
code = tool_parameters.get("code", "")
|
||||
|
||||
if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]:
|
||||
raise ValueError(f'Only python3 and javascript are supported, not {language}')
|
||||
|
||||
result = CodeExecutor.execute_code(language, '', code)
|
||||
if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
|
||||
raise ValueError(f"Only python3 and javascript are supported, not {language}")
|
||||
|
||||
return self.create_text_message(result)
|
||||
result = CodeExecutor.execute_code(language, "", code)
|
||||
|
||||
return self.create_text_message(result)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
""" Provide the input parameters type for the cogview provider class """
|
||||
"""Provide the input parameters type for the cogview provider class"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
@ -7,7 +8,8 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
||||
|
||||
|
||||
class COGVIEWProvider(BuiltinToolProviderController):
|
||||
""" cogview provider """
|
||||
"""cogview provider"""
|
||||
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
CogView3Tool().fork_tool_runtime(
|
||||
@ -15,13 +17,12 @@ class COGVIEWProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。",
|
||||
"size": "square",
|
||||
"n": 1
|
||||
"n": 1,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e)) from e
|
||||
|
||||
@ -7,43 +7,42 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class CogView3Tool(BuiltinTool):
|
||||
""" CogView3 Tool """
|
||||
"""CogView3 Tool"""
|
||||
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke CogView3 tool
|
||||
"""
|
||||
client = ZhipuAI(
|
||||
base_url=self.runtime.credentials['zhipuai_base_url'],
|
||||
api_key=self.runtime.credentials['zhipuai_api_key'],
|
||||
base_url=self.runtime.credentials["zhipuai_base_url"],
|
||||
api_key=self.runtime.credentials["zhipuai_api_key"],
|
||||
)
|
||||
size_mapping = {
|
||||
'square': '1024x1024',
|
||||
'vertical': '1024x1792',
|
||||
'horizontal': '1792x1024',
|
||||
"square": "1024x1024",
|
||||
"vertical": "1024x1792",
|
||||
"horizontal": "1792x1024",
|
||||
}
|
||||
# prompt
|
||||
prompt = tool_parameters.get('prompt', '')
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
return self.create_text_message("Please input prompt")
|
||||
# get size
|
||||
size = size_mapping[tool_parameters.get('size', 'square')]
|
||||
size = size_mapping[tool_parameters.get("size", "square")]
|
||||
# get n
|
||||
n = tool_parameters.get('n', 1)
|
||||
n = tool_parameters.get("n", 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get('quality', 'standard')
|
||||
if quality not in ['standard', 'hd']:
|
||||
return self.create_text_message('Invalid quality')
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
if quality not in {"standard", "hd"}:
|
||||
return self.create_text_message("Invalid quality")
|
||||
# get style
|
||||
style = tool_parameters.get('style', 'vivid')
|
||||
if style not in ['natural', 'vivid']:
|
||||
return self.create_text_message('Invalid style')
|
||||
style = tool_parameters.get("style", "vivid")
|
||||
if style not in {"natural", "vivid"}:
|
||||
return self.create_text_message("Invalid style")
|
||||
# set extra body
|
||||
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
|
||||
extra_body = {'seed': seed_id}
|
||||
seed_id = tool_parameters.get("seed_id", self._generate_random_id(8))
|
||||
extra_body = {"seed": seed_id}
|
||||
response = client.images.generations(
|
||||
prompt=prompt,
|
||||
model="cogview-3",
|
||||
@ -52,18 +51,22 @@ class CogView3Tool(BuiltinTool):
|
||||
extra_body=extra_body,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format='b64_json'
|
||||
response_format="b64_json",
|
||||
)
|
||||
result = []
|
||||
for image in response.data:
|
||||
result.append(self.create_image_message(image=image.url))
|
||||
result.append(self.create_json_message({
|
||||
"url": image.url,
|
||||
}))
|
||||
result.append(
|
||||
self.create_json_message(
|
||||
{
|
||||
"url": image.url,
|
||||
}
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
||||
random_id = ''.join(random.choices(characters, k=length))
|
||||
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
random_id = "".join(random.choices(characters, k=length))
|
||||
return random_id
|
||||
|
||||
@ -11,9 +11,9 @@ class CrossRefProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"doi": '10.1007/s00894-022-05373-8',
|
||||
"doi": "10.1007/s00894-022-05373-8",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@ -11,15 +11,18 @@ class CrossRefQueryDOITool(BuiltinTool):
|
||||
"""
|
||||
Tool for querying the metadata of a publication using its DOI.
|
||||
"""
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
doi = tool_parameters.get('doi')
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
doi = tool_parameters.get("doi")
|
||||
if not doi:
|
||||
raise ToolParameterValidationError('doi is required.')
|
||||
raise ToolParameterValidationError("doi is required.")
|
||||
# doc: https://github.com/CrossRef/rest-api-doc
|
||||
url = f"https://api.crossref.org/works/{doi}"
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
response = response.json()
|
||||
message = response.get('message', {})
|
||||
message = response.get("message", {})
|
||||
|
||||
return self.create_json_message(message)
|
||||
|
||||
@ -12,16 +12,16 @@ def convert_time_str_to_seconds(time_str: str) -> int:
|
||||
Convert a time string to seconds.
|
||||
example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430
|
||||
"""
|
||||
time_str = time_str.lower().strip().replace(' ', '')
|
||||
time_str = time_str.lower().strip().replace(" ", "")
|
||||
seconds = 0
|
||||
if 'h' in time_str:
|
||||
hours, time_str = time_str.split('h')
|
||||
if "h" in time_str:
|
||||
hours, time_str = time_str.split("h")
|
||||
seconds += int(hours) * 3600
|
||||
if 'm' in time_str:
|
||||
minutes, time_str = time_str.split('m')
|
||||
if "m" in time_str:
|
||||
minutes, time_str = time_str.split("m")
|
||||
seconds += int(minutes) * 60
|
||||
if 's' in time_str:
|
||||
seconds += int(time_str.replace('s', ''))
|
||||
if "s" in time_str:
|
||||
seconds += int(time_str.replace("s", ""))
|
||||
return seconds
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@ class CrossRefQueryTitleAPI:
|
||||
Tool for querying the metadata of a publication using its title.
|
||||
Crossref API doc: https://github.com/CrossRef/rest-api-doc
|
||||
"""
|
||||
|
||||
query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}"
|
||||
rate_limit: int = 50
|
||||
rate_interval: float = 1
|
||||
@ -38,7 +39,15 @@ class CrossRefQueryTitleAPI:
|
||||
def __init__(self, mailto: str):
|
||||
self.mailto = mailto
|
||||
|
||||
def _query(self, query: str, rows: int = 5, offset: int = 0, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
|
||||
def _query(
|
||||
self,
|
||||
query: str,
|
||||
rows: int = 5,
|
||||
offset: int = 0,
|
||||
sort: str = "relevance",
|
||||
order: str = "desc",
|
||||
fuzzy_query: bool = False,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Query the metadata of a publication using its title.
|
||||
:param query: the title of the publication
|
||||
@ -47,33 +56,37 @@ class CrossRefQueryTitleAPI:
|
||||
:param order: the sort order
|
||||
:param fuzzy_query: whether to return all items that match the query
|
||||
"""
|
||||
url = self.query_url_template.format(query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto)
|
||||
url = self.query_url_template.format(
|
||||
query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto
|
||||
)
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
rate_limit = int(response.headers['x-ratelimit-limit'])
|
||||
rate_limit = int(response.headers["x-ratelimit-limit"])
|
||||
# convert time string to seconds
|
||||
rate_interval = convert_time_str_to_seconds(response.headers['x-ratelimit-interval'])
|
||||
rate_interval = convert_time_str_to_seconds(response.headers["x-ratelimit-interval"])
|
||||
|
||||
self.rate_limit = rate_limit
|
||||
self.rate_interval = rate_interval
|
||||
|
||||
response = response.json()
|
||||
if response['status'] != 'ok':
|
||||
if response["status"] != "ok":
|
||||
return []
|
||||
|
||||
message = response['message']
|
||||
message = response["message"]
|
||||
if fuzzy_query:
|
||||
# fuzzy query return all items
|
||||
return message['items']
|
||||
return message["items"]
|
||||
else:
|
||||
for paper in message['items']:
|
||||
title = paper['title'][0]
|
||||
for paper in message["items"]:
|
||||
title = paper["title"][0]
|
||||
if title.lower() != query.lower():
|
||||
continue
|
||||
return [paper]
|
||||
return []
|
||||
|
||||
def query(self, query: str, rows: int = 5, sort: str = 'relevance', order: str = 'desc', fuzzy_query: bool = False) -> list[dict]:
|
||||
def query(
|
||||
self, query: str, rows: int = 5, sort: str = "relevance", order: str = "desc", fuzzy_query: bool = False
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Query the metadata of a publication using its title.
|
||||
:param query: the title of the publication
|
||||
@ -89,7 +102,14 @@ class CrossRefQueryTitleAPI:
|
||||
results = []
|
||||
|
||||
for i in range(query_times):
|
||||
result = self._query(query, rows=self.rate_limit, offset=i * self.rate_limit, sort=sort, order=order, fuzzy_query=fuzzy_query)
|
||||
result = self._query(
|
||||
query,
|
||||
rows=self.rate_limit,
|
||||
offset=i * self.rate_limit,
|
||||
sort=sort,
|
||||
order=order,
|
||||
fuzzy_query=fuzzy_query,
|
||||
)
|
||||
if fuzzy_query:
|
||||
results.extend(result)
|
||||
else:
|
||||
@ -107,13 +127,16 @@ class CrossRefQueryTitleTool(BuiltinTool):
|
||||
"""
|
||||
Tool for querying the metadata of a publication using its title.
|
||||
"""
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
query = tool_parameters.get('query')
|
||||
fuzzy_query = tool_parameters.get('fuzzy_query', False)
|
||||
rows = tool_parameters.get('rows', 3)
|
||||
sort = tool_parameters.get('sort', 'relevance')
|
||||
order = tool_parameters.get('order', 'desc')
|
||||
mailto = self.runtime.credentials['mailto']
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
query = tool_parameters.get("query")
|
||||
fuzzy_query = tool_parameters.get("fuzzy_query", False)
|
||||
rows = tool_parameters.get("rows", 3)
|
||||
sort = tool_parameters.get("sort", "relevance")
|
||||
order = tool_parameters.get("order", "desc")
|
||||
mailto = self.runtime.credentials["mailto"]
|
||||
|
||||
result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query)
|
||||
|
||||
|
||||
@ -13,13 +13,8 @@ class DALLEProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"prompt": "cute girl, blue eyes, white hair, anime style",
|
||||
"size": "small",
|
||||
"n": 1
|
||||
},
|
||||
user_id="",
|
||||
tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "small", "n": 1},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -9,59 +9,58 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DallE2Tool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
openai_organization = self.runtime.credentials.get('openai_organization_id', None)
|
||||
openai_organization = self.runtime.credentials.get("openai_organization_id", None)
|
||||
if not openai_organization:
|
||||
openai_organization = None
|
||||
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
|
||||
openai_base_url = self.runtime.credentials.get("openai_base_url", None)
|
||||
if not openai_base_url:
|
||||
openai_base_url = None
|
||||
else:
|
||||
openai_base_url = str(URL(openai_base_url) / 'v1')
|
||||
openai_base_url = str(URL(openai_base_url) / "v1")
|
||||
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials['openai_api_key'],
|
||||
api_key=self.runtime.credentials["openai_api_key"],
|
||||
base_url=openai_base_url,
|
||||
organization=openai_organization
|
||||
organization=openai_organization,
|
||||
)
|
||||
|
||||
SIZE_MAPPING = {
|
||||
'small': '256x256',
|
||||
'medium': '512x512',
|
||||
'large': '1024x1024',
|
||||
"small": "256x256",
|
||||
"medium": "512x512",
|
||||
"large": "1024x1024",
|
||||
}
|
||||
|
||||
# prompt
|
||||
prompt = tool_parameters.get('prompt', '')
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
|
||||
return self.create_text_message("Please input prompt")
|
||||
|
||||
# get size
|
||||
size = SIZE_MAPPING[tool_parameters.get('size', 'large')]
|
||||
size = SIZE_MAPPING[tool_parameters.get("size", "large")]
|
||||
|
||||
# get n
|
||||
n = tool_parameters.get('n', 1)
|
||||
n = tool_parameters.get("n", 1)
|
||||
|
||||
# call openapi dalle2
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
model='dall-e-2',
|
||||
size=size,
|
||||
n=n,
|
||||
response_format='b64_json'
|
||||
)
|
||||
response = client.images.generate(prompt=prompt, model="dall-e-2", size=size, n=n, response_format="b64_json")
|
||||
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={ 'mime_type': 'image/png' },
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
result.append(
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image.b64_json),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@ -10,69 +10,64 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DallE3Tool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
openai_organization = self.runtime.credentials.get('openai_organization_id', None)
|
||||
openai_organization = self.runtime.credentials.get("openai_organization_id", None)
|
||||
if not openai_organization:
|
||||
openai_organization = None
|
||||
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
|
||||
openai_base_url = self.runtime.credentials.get("openai_base_url", None)
|
||||
if not openai_base_url:
|
||||
openai_base_url = None
|
||||
else:
|
||||
openai_base_url = str(URL(openai_base_url) / 'v1')
|
||||
openai_base_url = str(URL(openai_base_url) / "v1")
|
||||
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials['openai_api_key'],
|
||||
api_key=self.runtime.credentials["openai_api_key"],
|
||||
base_url=openai_base_url,
|
||||
organization=openai_organization
|
||||
organization=openai_organization,
|
||||
)
|
||||
|
||||
SIZE_MAPPING = {
|
||||
'square': '1024x1024',
|
||||
'vertical': '1024x1792',
|
||||
'horizontal': '1792x1024',
|
||||
"square": "1024x1024",
|
||||
"vertical": "1024x1792",
|
||||
"horizontal": "1792x1024",
|
||||
}
|
||||
|
||||
# prompt
|
||||
prompt = tool_parameters.get('prompt', '')
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message('Please input prompt')
|
||||
return self.create_text_message("Please input prompt")
|
||||
# get size
|
||||
size = SIZE_MAPPING[tool_parameters.get('size', 'square')]
|
||||
size = SIZE_MAPPING[tool_parameters.get("size", "square")]
|
||||
# get n
|
||||
n = tool_parameters.get('n', 1)
|
||||
n = tool_parameters.get("n", 1)
|
||||
# get quality
|
||||
quality = tool_parameters.get('quality', 'standard')
|
||||
if quality not in ['standard', 'hd']:
|
||||
return self.create_text_message('Invalid quality')
|
||||
quality = tool_parameters.get("quality", "standard")
|
||||
if quality not in {"standard", "hd"}:
|
||||
return self.create_text_message("Invalid quality")
|
||||
# get style
|
||||
style = tool_parameters.get('style', 'vivid')
|
||||
if style not in ['natural', 'vivid']:
|
||||
return self.create_text_message('Invalid style')
|
||||
style = tool_parameters.get("style", "vivid")
|
||||
if style not in {"natural", "vivid"}:
|
||||
return self.create_text_message("Invalid style")
|
||||
|
||||
# call openapi dalle3
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
model='dall-e-3',
|
||||
size=size,
|
||||
n=n,
|
||||
style=style,
|
||||
quality=quality,
|
||||
response_format='b64_json'
|
||||
prompt=prompt, model="dall-e-3", size=size, n=n, style=style, quality=quality, response_format="b64_json"
|
||||
)
|
||||
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
|
||||
blob_message = self.create_blob_message(blob=blob_image,
|
||||
meta={'mime_type': mime_type},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value)
|
||||
blob_message = self.create_blob_message(
|
||||
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
|
||||
)
|
||||
result.append(blob_message)
|
||||
return result
|
||||
|
||||
@ -86,7 +81,7 @@ class DallE3Tool(BuiltinTool):
|
||||
:return: A tuple containing the MIME type and the decoded image bytes
|
||||
"""
|
||||
if DallE3Tool._is_plain_base64(base64_image):
|
||||
return 'image/png', base64.b64decode(base64_image)
|
||||
return "image/png", base64.b64decode(base64_image)
|
||||
else:
|
||||
return DallE3Tool._extract_mime_and_data(base64_image)
|
||||
|
||||
@ -98,7 +93,7 @@ class DallE3Tool(BuiltinTool):
|
||||
:param encoded_str: Base64 encoded image string
|
||||
:return: True if the string is plain base64, False otherwise
|
||||
"""
|
||||
return not encoded_str.startswith('data:image')
|
||||
return not encoded_str.startswith("data:image")
|
||||
|
||||
@staticmethod
|
||||
def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]:
|
||||
@ -108,13 +103,13 @@ class DallE3Tool(BuiltinTool):
|
||||
:param encoded_str: Base64 encoded image string with MIME type prefix
|
||||
:return: A tuple containing the MIME type and the decoded image bytes
|
||||
"""
|
||||
mime_type = encoded_str.split(';')[0].split(':')[1]
|
||||
image_data_base64 = encoded_str.split(',')[1]
|
||||
mime_type = encoded_str.split(";")[0].split(":")[1]
|
||||
image_data_base64 = encoded_str.split(",")[1]
|
||||
decoded_data = base64.b64decode(image_data_base64)
|
||||
return mime_type, decoded_data
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
||||
random_id = ''.join(random.choices(characters, k=length))
|
||||
characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
random_id = "".join(random.choices(characters, k=length))
|
||||
return random_id
|
||||
|
||||
@ -11,7 +11,7 @@ class DevDocsProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"doc": "python~3.12",
|
||||
"topic": "library/code",
|
||||
@ -19,4 +19,3 @@ class DevDocsProvider(BuiltinToolProviderController):
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -13,7 +13,9 @@ class SearchDevDocsInput(BaseModel):
|
||||
|
||||
|
||||
class SearchDevDocsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invokes the DevDocs search tool with the given user ID and tool parameters.
|
||||
|
||||
@ -22,15 +24,16 @@ class SearchDevDocsTool(BuiltinTool):
|
||||
tool_parameters (dict[str, Any]): The parameters for the tool, including 'doc' and 'topic'.
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages.
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation,
|
||||
which can be a single message or a list of messages.
|
||||
"""
|
||||
doc = tool_parameters.get('doc', '')
|
||||
topic = tool_parameters.get('topic', '')
|
||||
doc = tool_parameters.get("doc", "")
|
||||
topic = tool_parameters.get("topic", "")
|
||||
|
||||
if not doc:
|
||||
return self.create_text_message('Please provide the documentation name.')
|
||||
return self.create_text_message("Please provide the documentation name.")
|
||||
if not topic:
|
||||
return self.create_text_message('Please provide the topic path.')
|
||||
return self.create_text_message("Please provide the topic path.")
|
||||
|
||||
url = f"https://documents.devdocs.io/{doc}/{topic}.html"
|
||||
response = requests.get(url)
|
||||
@ -39,4 +42,6 @@ class SearchDevDocsTool(BuiltinTool):
|
||||
content = response.text
|
||||
return self.create_text_message(self.summary(user_id=user_id, content=content))
|
||||
else:
|
||||
return self.create_text_message(f"Failed to retrieve the documentation. Status code: {response.status_code}")
|
||||
return self.create_text_message(
|
||||
f"Failed to retrieve the documentation. Status code: {response.status_code}"
|
||||
)
|
||||
|
||||
@ -7,15 +7,12 @@ class DIDProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
# Example validation using the D-ID talks tool
|
||||
TalksTool().fork_tool_runtime(
|
||||
runtime={"credentials": credentials}
|
||||
).invoke(
|
||||
user_id='',
|
||||
TalksTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"source_url": "https://www.d-id.com/wp-content/uploads/2023/11/Hero-image-1.png",
|
||||
"text_input": "Hello, welcome to use D-ID tool in Dify",
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -12,14 +12,14 @@ logger = logging.getLogger(__name__)
|
||||
class DIDApp:
|
||||
def __init__(self, api_key: str | None = None, base_url: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url or 'https://api.d-id.com'
|
||||
self.base_url = base_url or "https://api.d-id.com"
|
||||
if not self.api_key:
|
||||
raise ValueError('API key is required')
|
||||
raise ValueError("API key is required")
|
||||
|
||||
def _prepare_headers(self, idempotency_key: str | None = None):
|
||||
headers = {'Content-Type': 'application/json', 'Authorization': f'Basic {self.api_key}'}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Basic {self.api_key}"}
|
||||
if idempotency_key:
|
||||
headers['Idempotency-Key'] = idempotency_key
|
||||
headers["Idempotency-Key"] = idempotency_key
|
||||
return headers
|
||||
|
||||
def _request(
|
||||
@ -44,44 +44,44 @@ class DIDApp:
|
||||
return None
|
||||
|
||||
def talks(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs):
|
||||
endpoint = f'{self.base_url}/talks'
|
||||
endpoint = f"{self.base_url}/talks"
|
||||
headers = self._prepare_headers(idempotency_key)
|
||||
data = kwargs['params']
|
||||
logger.debug(f'Send request to {endpoint=} body={data}')
|
||||
response = self._request('POST', endpoint, data, headers)
|
||||
data = kwargs["params"]
|
||||
logger.debug(f"Send request to {endpoint=} body={data}")
|
||||
response = self._request("POST", endpoint, data, headers)
|
||||
if response is None:
|
||||
raise HTTPError('Failed to initiate D-ID talks after multiple retries')
|
||||
id: str = response['id']
|
||||
raise HTTPError("Failed to initiate D-ID talks after multiple retries")
|
||||
id: str = response["id"]
|
||||
if wait:
|
||||
return self._monitor_job_status(id=id, target='talks', poll_interval=poll_interval)
|
||||
return self._monitor_job_status(id=id, target="talks", poll_interval=poll_interval)
|
||||
return id
|
||||
|
||||
def animations(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs):
|
||||
endpoint = f'{self.base_url}/animations'
|
||||
endpoint = f"{self.base_url}/animations"
|
||||
headers = self._prepare_headers(idempotency_key)
|
||||
data = kwargs['params']
|
||||
logger.debug(f'Send request to {endpoint=} body={data}')
|
||||
response = self._request('POST', endpoint, data, headers)
|
||||
data = kwargs["params"]
|
||||
logger.debug(f"Send request to {endpoint=} body={data}")
|
||||
response = self._request("POST", endpoint, data, headers)
|
||||
if response is None:
|
||||
raise HTTPError('Failed to initiate D-ID talks after multiple retries')
|
||||
id: str = response['id']
|
||||
raise HTTPError("Failed to initiate D-ID talks after multiple retries")
|
||||
id: str = response["id"]
|
||||
if wait:
|
||||
return self._monitor_job_status(target='animations', id=id, poll_interval=poll_interval)
|
||||
return self._monitor_job_status(target="animations", id=id, poll_interval=poll_interval)
|
||||
return id
|
||||
|
||||
def check_did_status(self, target: str, id: str):
|
||||
endpoint = f'{self.base_url}/{target}/{id}'
|
||||
endpoint = f"{self.base_url}/{target}/{id}"
|
||||
headers = self._prepare_headers()
|
||||
response = self._request('GET', endpoint, headers=headers)
|
||||
response = self._request("GET", endpoint, headers=headers)
|
||||
if response is None:
|
||||
raise HTTPError(f'Failed to check status for talks {id} after multiple retries')
|
||||
raise HTTPError(f"Failed to check status for talks {id} after multiple retries")
|
||||
return response
|
||||
|
||||
def _monitor_job_status(self, target: str, id: str, poll_interval: int):
|
||||
while True:
|
||||
status = self.check_did_status(target=target, id=id)
|
||||
if status['status'] == 'done':
|
||||
if status["status"] == "done":
|
||||
return status
|
||||
elif status['status'] == 'error' or status['status'] == 'rejected':
|
||||
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}')
|
||||
elif status["status"] == "error" or status["status"] == "rejected":
|
||||
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}')
|
||||
time.sleep(poll_interval)
|
||||
|
||||
@ -10,33 +10,33 @@ class AnimationsTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url'])
|
||||
app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"])
|
||||
|
||||
driver_expressions_str = tool_parameters.get('driver_expressions')
|
||||
driver_expressions_str = tool_parameters.get("driver_expressions")
|
||||
driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None
|
||||
|
||||
config = {
|
||||
'stitch': tool_parameters.get('stitch', True),
|
||||
'mute': tool_parameters.get('mute'),
|
||||
'result_format': tool_parameters.get('result_format') or 'mp4',
|
||||
"stitch": tool_parameters.get("stitch", True),
|
||||
"mute": tool_parameters.get("mute"),
|
||||
"result_format": tool_parameters.get("result_format") or "mp4",
|
||||
}
|
||||
config = {k: v for k, v in config.items() if v is not None and v != ''}
|
||||
config = {k: v for k, v in config.items() if v is not None and v != ""}
|
||||
|
||||
options = {
|
||||
'source_url': tool_parameters['source_url'],
|
||||
'driver_url': tool_parameters.get('driver_url'),
|
||||
'config': config,
|
||||
"source_url": tool_parameters["source_url"],
|
||||
"driver_url": tool_parameters.get("driver_url"),
|
||||
"config": config,
|
||||
}
|
||||
options = {k: v for k, v in options.items() if v is not None and v != ''}
|
||||
options = {k: v for k, v in options.items() if v is not None and v != ""}
|
||||
|
||||
if not options.get('source_url'):
|
||||
raise ValueError('Source URL is required')
|
||||
if not options.get("source_url"):
|
||||
raise ValueError("Source URL is required")
|
||||
|
||||
if config.get('logo_url'):
|
||||
if not config.get('logo_x'):
|
||||
raise ValueError('Logo X position is required when logo URL is provided')
|
||||
if not config.get('logo_y'):
|
||||
raise ValueError('Logo Y position is required when logo URL is provided')
|
||||
if config.get("logo_url"):
|
||||
if not config.get("logo_x"):
|
||||
raise ValueError("Logo X position is required when logo URL is provided")
|
||||
if not config.get("logo_y"):
|
||||
raise ValueError("Logo Y position is required when logo URL is provided")
|
||||
|
||||
animations_result = app.animations(params=options, wait=True)
|
||||
|
||||
@ -44,6 +44,6 @@ class AnimationsTool(BuiltinTool):
|
||||
animations_result = json.dumps(animations_result, ensure_ascii=False, indent=4)
|
||||
|
||||
if not animations_result:
|
||||
return self.create_text_message('D-ID animations request failed.')
|
||||
return self.create_text_message("D-ID animations request failed.")
|
||||
|
||||
return self.create_text_message(animations_result)
|
||||
|
||||
@ -10,49 +10,49 @@ class TalksTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url'])
|
||||
app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"])
|
||||
|
||||
driver_expressions_str = tool_parameters.get('driver_expressions')
|
||||
driver_expressions_str = tool_parameters.get("driver_expressions")
|
||||
driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None
|
||||
|
||||
script = {
|
||||
'type': tool_parameters.get('script_type') or 'text',
|
||||
'input': tool_parameters.get('text_input'),
|
||||
'audio_url': tool_parameters.get('audio_url'),
|
||||
'reduce_noise': tool_parameters.get('audio_reduce_noise', False),
|
||||
"type": tool_parameters.get("script_type") or "text",
|
||||
"input": tool_parameters.get("text_input"),
|
||||
"audio_url": tool_parameters.get("audio_url"),
|
||||
"reduce_noise": tool_parameters.get("audio_reduce_noise", False),
|
||||
}
|
||||
script = {k: v for k, v in script.items() if v is not None and v != ''}
|
||||
script = {k: v for k, v in script.items() if v is not None and v != ""}
|
||||
config = {
|
||||
'stitch': tool_parameters.get('stitch', True),
|
||||
'sharpen': tool_parameters.get('sharpen'),
|
||||
'fluent': tool_parameters.get('fluent'),
|
||||
'result_format': tool_parameters.get('result_format') or 'mp4',
|
||||
'pad_audio': tool_parameters.get('pad_audio'),
|
||||
'driver_expressions': driver_expressions,
|
||||
"stitch": tool_parameters.get("stitch", True),
|
||||
"sharpen": tool_parameters.get("sharpen"),
|
||||
"fluent": tool_parameters.get("fluent"),
|
||||
"result_format": tool_parameters.get("result_format") or "mp4",
|
||||
"pad_audio": tool_parameters.get("pad_audio"),
|
||||
"driver_expressions": driver_expressions,
|
||||
}
|
||||
config = {k: v for k, v in config.items() if v is not None and v != ''}
|
||||
config = {k: v for k, v in config.items() if v is not None and v != ""}
|
||||
|
||||
options = {
|
||||
'source_url': tool_parameters['source_url'],
|
||||
'driver_url': tool_parameters.get('driver_url'),
|
||||
'script': script,
|
||||
'config': config,
|
||||
"source_url": tool_parameters["source_url"],
|
||||
"driver_url": tool_parameters.get("driver_url"),
|
||||
"script": script,
|
||||
"config": config,
|
||||
}
|
||||
options = {k: v for k, v in options.items() if v is not None and v != ''}
|
||||
options = {k: v for k, v in options.items() if v is not None and v != ""}
|
||||
|
||||
if not options.get('source_url'):
|
||||
raise ValueError('Source URL is required')
|
||||
if not options.get("source_url"):
|
||||
raise ValueError("Source URL is required")
|
||||
|
||||
if script.get('type') == 'audio':
|
||||
script.pop('input', None)
|
||||
if not script.get('audio_url'):
|
||||
raise ValueError('Audio URL is required for audio script type')
|
||||
if script.get("type") == "audio":
|
||||
script.pop("input", None)
|
||||
if not script.get("audio_url"):
|
||||
raise ValueError("Audio URL is required for audio script type")
|
||||
|
||||
if script.get('type') == 'text':
|
||||
script.pop('audio_url', None)
|
||||
script.pop('reduce_noise', None)
|
||||
if not script.get('input'):
|
||||
raise ValueError('Text input is required for text script type')
|
||||
if script.get("type") == "text":
|
||||
script.pop("audio_url", None)
|
||||
script.pop("reduce_noise", None)
|
||||
if not script.get("input"):
|
||||
raise ValueError("Text input is required for text script type")
|
||||
|
||||
talks_result = app.talks(params=options, wait=True)
|
||||
|
||||
@ -60,6 +60,6 @@ class TalksTool(BuiltinTool):
|
||||
talks_result = json.dumps(talks_result, ensure_ascii=False, indent=4)
|
||||
|
||||
if not talks_result:
|
||||
return self.create_text_message('D-ID talks request failed.')
|
||||
return self.create_text_message("D-ID talks request failed.")
|
||||
|
||||
return self.create_text_message(talks_result)
|
||||
|
||||
@ -13,38 +13,43 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DingTalkGroupBotTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
Dingtalk custom group robot API docs:
|
||||
https://open.dingtalk.com/document/orgapp/custom-robot-access
|
||||
invoke tools
|
||||
Dingtalk custom group robot API docs:
|
||||
https://open.dingtalk.com/document/orgapp/custom-robot-access
|
||||
"""
|
||||
content = tool_parameters.get('content')
|
||||
content = tool_parameters.get("content")
|
||||
if not content:
|
||||
return self.create_text_message('Invalid parameter content')
|
||||
return self.create_text_message("Invalid parameter content")
|
||||
|
||||
access_token = tool_parameters.get('access_token')
|
||||
access_token = tool_parameters.get("access_token")
|
||||
if not access_token:
|
||||
return self.create_text_message('Invalid parameter access_token. '
|
||||
'Regarding information about security details,'
|
||||
'please refer to the DingTalk docs:'
|
||||
'https://open.dingtalk.com/document/robots/customize-robot-security-settings')
|
||||
return self.create_text_message(
|
||||
"Invalid parameter access_token. "
|
||||
"Regarding information about security details,"
|
||||
"please refer to the DingTalk docs:"
|
||||
"https://open.dingtalk.com/document/robots/customize-robot-security-settings"
|
||||
)
|
||||
|
||||
sign_secret = tool_parameters.get('sign_secret')
|
||||
sign_secret = tool_parameters.get("sign_secret")
|
||||
if not sign_secret:
|
||||
return self.create_text_message('Invalid parameter sign_secret. '
|
||||
'Regarding information about security details,'
|
||||
'please refer to the DingTalk docs:'
|
||||
'https://open.dingtalk.com/document/robots/customize-robot-security-settings')
|
||||
return self.create_text_message(
|
||||
"Invalid parameter sign_secret. "
|
||||
"Regarding information about security details,"
|
||||
"please refer to the DingTalk docs:"
|
||||
"https://open.dingtalk.com/document/robots/customize-robot-security-settings"
|
||||
)
|
||||
|
||||
msgtype = 'text'
|
||||
api_url = 'https://oapi.dingtalk.com/robot/send'
|
||||
msgtype = "text"
|
||||
api_url = "https://oapi.dingtalk.com/robot/send"
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
params = {
|
||||
'access_token': access_token,
|
||||
"access_token": access_token,
|
||||
}
|
||||
|
||||
self._apply_security_mechanism(params, sign_secret)
|
||||
@ -53,7 +58,7 @@ class DingTalkGroupBotTool(BuiltinTool):
|
||||
"msgtype": msgtype,
|
||||
"text": {
|
||||
"content": content,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
@ -62,7 +67,8 @@ class DingTalkGroupBotTool(BuiltinTool):
|
||||
return self.create_text_message("Text message sent successfully")
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to send the text message, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to send the text message, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to send message to group chat bot. {}".format(e))
|
||||
|
||||
@ -70,14 +76,14 @@ class DingTalkGroupBotTool(BuiltinTool):
|
||||
def _apply_security_mechanism(params: dict[str, Any], sign_secret: str):
|
||||
try:
|
||||
timestamp = str(round(time.time() * 1000))
|
||||
secret_enc = sign_secret.encode('utf-8')
|
||||
string_to_sign = f'{timestamp}\n{sign_secret}'
|
||||
string_to_sign_enc = string_to_sign.encode('utf-8')
|
||||
secret_enc = sign_secret.encode("utf-8")
|
||||
string_to_sign = f"{timestamp}\n{sign_secret}"
|
||||
string_to_sign_enc = string_to_sign.encode("utf-8")
|
||||
hmac_code = hmac.new(secret_enc, string_to_sign_enc, digestmod=hashlib.sha256).digest()
|
||||
sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
|
||||
|
||||
params['timestamp'] = timestamp
|
||||
params['sign'] = sign
|
||||
params["timestamp"] = timestamp
|
||||
params["sign"] = sign
|
||||
except Exception:
|
||||
msg = "Failed to apply security mechanism to the request."
|
||||
logging.exception(msg)
|
||||
|
||||
@ -11,11 +11,10 @@ class DuckDuckGoProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"query": "John Doe",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -13,8 +13,8 @@ class DuckDuckGoAITool(BuiltinTool):
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
query_dict = {
|
||||
"keywords": tool_parameters.get('query'),
|
||||
"model": tool_parameters.get('model'),
|
||||
"keywords": tool_parameters.get("query"),
|
||||
"model": tool_parameters.get("model"),
|
||||
}
|
||||
response = DDGS().chat(**query_dict)
|
||||
return self.create_text_message(text=response)
|
||||
|
||||
@ -14,18 +14,17 @@ class DuckDuckGoImageSearchTool(BuiltinTool):
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
||||
query_dict = {
|
||||
"keywords": tool_parameters.get('query'),
|
||||
"timelimit": tool_parameters.get('timelimit'),
|
||||
"size": tool_parameters.get('size'),
|
||||
"max_results": tool_parameters.get('max_results'),
|
||||
"keywords": tool_parameters.get("query"),
|
||||
"timelimit": tool_parameters.get("timelimit"),
|
||||
"size": tool_parameters.get("size"),
|
||||
"max_results": tool_parameters.get("max_results"),
|
||||
}
|
||||
response = DDGS().images(**query_dict)
|
||||
result = []
|
||||
for res in response:
|
||||
res['transfer_method'] = FileTransferMethod.REMOTE_URL
|
||||
msg = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=res.get('image'),
|
||||
save_as='',
|
||||
meta=res)
|
||||
res["transfer_method"] = FileTransferMethod.REMOTE_URL
|
||||
msg = ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=res.get("image"), save_as="", meta=res
|
||||
)
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
@ -21,10 +21,11 @@ class DuckDuckGoSearchTool(BuiltinTool):
|
||||
"""
|
||||
Tool for performing a search using DuckDuckGo search engine.
|
||||
"""
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
query = tool_parameters.get('query')
|
||||
max_results = tool_parameters.get('max_results', 5)
|
||||
require_summary = tool_parameters.get('require_summary', False)
|
||||
query = tool_parameters.get("query")
|
||||
max_results = tool_parameters.get("max_results", 5)
|
||||
require_summary = tool_parameters.get("require_summary", False)
|
||||
response = DDGS().text(query, max_results=max_results)
|
||||
if require_summary:
|
||||
results = "\n".join([res.get("body") for res in response])
|
||||
@ -34,7 +35,11 @@ class DuckDuckGoSearchTool(BuiltinTool):
|
||||
|
||||
def summary_results(self, user_id: str, content: str, query: str) -> str:
|
||||
prompt = SUMMARY_PROMPT.format(query=query, content=content)
|
||||
summary = self.invoke_model(user_id=user_id, prompt_messages=[
|
||||
SystemPromptMessage(content=prompt),
|
||||
], stop=[])
|
||||
summary = self.invoke_model(
|
||||
user_id=user_id,
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(content=prompt),
|
||||
],
|
||||
stop=[],
|
||||
)
|
||||
return summary.message.content
|
||||
|
||||
@ -13,8 +13,8 @@ class DuckDuckGoTranslateTool(BuiltinTool):
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
query_dict = {
|
||||
"keywords": tool_parameters.get('query'),
|
||||
"to": tool_parameters.get('translate_to'),
|
||||
"keywords": tool_parameters.get("query"),
|
||||
"to": tool_parameters.get("translate_to"),
|
||||
}
|
||||
response = DDGS().translate(**query_dict)[0].get('translated', 'Unable to translate!')
|
||||
response = DDGS().translate(**query_dict)[0].get("translated", "Unable to translate!")
|
||||
return self.create_text_message(text=response)
|
||||
|
||||
@ -8,35 +8,35 @@ from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
|
||||
|
||||
class FeishuGroupBotTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
API document: https://open.feishu.cn/document/client-docs/bot-v3/add-custom-bot
|
||||
invoke tools
|
||||
API document: https://open.feishu.cn/document/client-docs/bot-v3/add-custom-bot
|
||||
"""
|
||||
|
||||
url = "https://open.feishu.cn/open-apis/bot/v2/hook"
|
||||
|
||||
content = tool_parameters.get('content', '')
|
||||
content = tool_parameters.get("content", "")
|
||||
if not content:
|
||||
return self.create_text_message('Invalid parameter content')
|
||||
return self.create_text_message("Invalid parameter content")
|
||||
|
||||
hook_key = tool_parameters.get('hook_key', '')
|
||||
hook_key = tool_parameters.get("hook_key", "")
|
||||
if not is_valid_uuid(hook_key):
|
||||
return self.create_text_message(
|
||||
f'Invalid parameter hook_key ${hook_key}, not a valid UUID')
|
||||
return self.create_text_message(f"Invalid parameter hook_key ${hook_key}, not a valid UUID")
|
||||
|
||||
msg_type = 'text'
|
||||
api_url = f'{url}/{hook_key}'
|
||||
msg_type = "text"
|
||||
api_url = f"{url}/{hook_key}"
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
params = {}
|
||||
payload = {
|
||||
"msg_type": msg_type,
|
||||
"content": {
|
||||
"text": content,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
@ -45,6 +45,7 @@ class FeishuGroupBotTool(BuiltinTool):
|
||||
return self.create_text_message("Text message sent successfully")
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to send the text message, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to send the text message, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to send message to group chat bot. {}".format(e))
|
||||
return self.create_text_message("Failed to send message to group chat bot. {}".format(e))
|
||||
|
||||
@ -5,4 +5,4 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
||||
class FeishuBaseProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
GetTenantAccessTokenTool()
|
||||
pass
|
||||
pass
|
||||
|
||||
@ -8,45 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AddBaseRecordTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records"
|
||||
|
||||
access_token = tool_parameters.get('Authorization', '')
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message('Invalid parameter access_token')
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
|
||||
app_token = tool_parameters.get('app_token', '')
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
if not app_token:
|
||||
return self.create_text_message('Invalid parameter app_token')
|
||||
return self.create_text_message("Invalid parameter app_token")
|
||||
|
||||
table_id = tool_parameters.get('table_id', '')
|
||||
table_id = tool_parameters.get("table_id", "")
|
||||
if not table_id:
|
||||
return self.create_text_message('Invalid parameter table_id')
|
||||
return self.create_text_message("Invalid parameter table_id")
|
||||
|
||||
fields = tool_parameters.get('fields', '')
|
||||
fields = tool_parameters.get("fields", "")
|
||||
if not fields:
|
||||
return self.create_text_message('Invalid parameter fields')
|
||||
return self.create_text_message("Invalid parameter fields")
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
params = {}
|
||||
payload = {
|
||||
"fields": json.loads(fields)
|
||||
}
|
||||
payload = {"fields": json.loads(fields)}
|
||||
|
||||
try:
|
||||
res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params,
|
||||
json=payload, timeout=30)
|
||||
res = httpx.post(
|
||||
url.format(app_token=app_token, table_id=table_id),
|
||||
headers=headers,
|
||||
params=params,
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
res_json = res.json()
|
||||
if res.is_success:
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to add base record, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to add base record, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to add base record. {}".format(e))
|
||||
|
||||
@ -8,28 +8,25 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class CreateBaseTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps"
|
||||
|
||||
access_token = tool_parameters.get('Authorization', '')
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message('Invalid parameter access_token')
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
|
||||
name = tool_parameters.get('name', '')
|
||||
folder_token = tool_parameters.get('folder_token', '')
|
||||
name = tool_parameters.get("name", "")
|
||||
folder_token = tool_parameters.get("folder_token", "")
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
params = {}
|
||||
payload = {
|
||||
"name": name,
|
||||
"folder_token": folder_token
|
||||
}
|
||||
payload = {"name": name, "folder_token": folder_token}
|
||||
|
||||
try:
|
||||
res = httpx.post(url, headers=headers, params=params, json=payload, timeout=30)
|
||||
@ -38,6 +35,7 @@ class CreateBaseTool(BuiltinTool):
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to create base, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to create base, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to create base. {}".format(e))
|
||||
|
||||
@ -8,37 +8,32 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class CreateBaseTableTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables"
|
||||
|
||||
access_token = tool_parameters.get('Authorization', '')
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message('Invalid parameter access_token')
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
|
||||
app_token = tool_parameters.get('app_token', '')
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
if not app_token:
|
||||
return self.create_text_message('Invalid parameter app_token')
|
||||
return self.create_text_message("Invalid parameter app_token")
|
||||
|
||||
name = tool_parameters.get('name', '')
|
||||
name = tool_parameters.get("name", "")
|
||||
|
||||
fields = tool_parameters.get('fields', '')
|
||||
fields = tool_parameters.get("fields", "")
|
||||
if not fields:
|
||||
return self.create_text_message('Invalid parameter fields')
|
||||
return self.create_text_message("Invalid parameter fields")
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
params = {}
|
||||
payload = {
|
||||
"table": {
|
||||
"name": name,
|
||||
"fields": json.loads(fields)
|
||||
}
|
||||
}
|
||||
payload = {"table": {"name": name, "fields": json.loads(fields)}}
|
||||
|
||||
try:
|
||||
res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30)
|
||||
@ -47,6 +42,7 @@ class CreateBaseTableTool(BuiltinTool):
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to create base table, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to create base table, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to create base table. {}".format(e))
|
||||
|
||||
@ -8,45 +8,49 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DeleteBaseRecordsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/batch_delete"
|
||||
|
||||
access_token = tool_parameters.get('Authorization', '')
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message('Invalid parameter access_token')
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
|
||||
app_token = tool_parameters.get('app_token', '')
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
if not app_token:
|
||||
return self.create_text_message('Invalid parameter app_token')
|
||||
return self.create_text_message("Invalid parameter app_token")
|
||||
|
||||
table_id = tool_parameters.get('table_id', '')
|
||||
table_id = tool_parameters.get("table_id", "")
|
||||
if not table_id:
|
||||
return self.create_text_message('Invalid parameter table_id')
|
||||
return self.create_text_message("Invalid parameter table_id")
|
||||
|
||||
record_ids = tool_parameters.get('record_ids', '')
|
||||
record_ids = tool_parameters.get("record_ids", "")
|
||||
if not record_ids:
|
||||
return self.create_text_message('Invalid parameter record_ids')
|
||||
return self.create_text_message("Invalid parameter record_ids")
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
params = {}
|
||||
payload = {
|
||||
"records": json.loads(record_ids)
|
||||
}
|
||||
payload = {"records": json.loads(record_ids)}
|
||||
|
||||
try:
|
||||
res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params,
|
||||
json=payload, timeout=30)
|
||||
res = httpx.post(
|
||||
url.format(app_token=app_token, table_id=table_id),
|
||||
headers=headers,
|
||||
params=params,
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
res_json = res.json()
|
||||
if res.is_success:
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to delete base records, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to delete base records, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to delete base records. {}".format(e))
|
||||
|
||||
@ -8,32 +8,30 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class DeleteBaseTablesTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/batch_delete"
|
||||
|
||||
access_token = tool_parameters.get('Authorization', '')
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message('Invalid parameter access_token')
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
|
||||
app_token = tool_parameters.get('app_token', '')
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
if not app_token:
|
||||
return self.create_text_message('Invalid parameter app_token')
|
||||
return self.create_text_message("Invalid parameter app_token")
|
||||
|
||||
table_ids = tool_parameters.get('table_ids', '')
|
||||
table_ids = tool_parameters.get("table_ids", "")
|
||||
if not table_ids:
|
||||
return self.create_text_message('Invalid parameter table_ids')
|
||||
return self.create_text_message("Invalid parameter table_ids")
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
params = {}
|
||||
payload = {
|
||||
"table_ids": json.loads(table_ids)
|
||||
}
|
||||
payload = {"table_ids": json.loads(table_ids)}
|
||||
|
||||
try:
|
||||
res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30)
|
||||
@ -42,6 +40,7 @@ class DeleteBaseTablesTool(BuiltinTool):
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to delete base tables. {}".format(e))
|
||||
|
||||
@ -8,22 +8,22 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class GetBaseInfoTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}"
|
||||
|
||||
access_token = tool_parameters.get('Authorization', '')
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message('Invalid parameter access_token')
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
|
||||
app_token = tool_parameters.get('app_token', '')
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
if not app_token:
|
||||
return self.create_text_message('Invalid parameter app_token')
|
||||
return self.create_text_message("Invalid parameter app_token")
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
try:
|
||||
@ -33,6 +33,7 @@ class GetBaseInfoTool(BuiltinTool):
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to get base info, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to get base info, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to get base info. {}".format(e))
|
||||
|
||||
@ -8,27 +8,24 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class GetTenantAccessTokenTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"
|
||||
|
||||
app_id = tool_parameters.get('app_id', '')
|
||||
app_id = tool_parameters.get("app_id", "")
|
||||
if not app_id:
|
||||
return self.create_text_message('Invalid parameter app_id')
|
||||
return self.create_text_message("Invalid parameter app_id")
|
||||
|
||||
app_secret = tool_parameters.get('app_secret', '')
|
||||
app_secret = tool_parameters.get("app_secret", "")
|
||||
if not app_secret:
|
||||
return self.create_text_message('Invalid parameter app_secret')
|
||||
return self.create_text_message("Invalid parameter app_secret")
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
params = {}
|
||||
payload = {
|
||||
"app_id": app_id,
|
||||
"app_secret": app_secret
|
||||
}
|
||||
payload = {"app_id": app_id, "app_secret": app_secret}
|
||||
|
||||
"""
|
||||
{
|
||||
@ -45,6 +42,7 @@ class GetTenantAccessTokenTool(BuiltinTool):
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to get tenant access token. {}".format(e))
|
||||
|
||||
@ -8,31 +8,31 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class ListBaseRecordsTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/search"
|
||||
|
||||
access_token = tool_parameters.get('Authorization', '')
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message('Invalid parameter access_token')
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
|
||||
app_token = tool_parameters.get('app_token', '')
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
if not app_token:
|
||||
return self.create_text_message('Invalid parameter app_token')
|
||||
return self.create_text_message("Invalid parameter app_token")
|
||||
|
||||
table_id = tool_parameters.get('table_id', '')
|
||||
table_id = tool_parameters.get("table_id", "")
|
||||
if not table_id:
|
||||
return self.create_text_message('Invalid parameter table_id')
|
||||
return self.create_text_message("Invalid parameter table_id")
|
||||
|
||||
page_token = tool_parameters.get('page_token', '')
|
||||
page_size = tool_parameters.get('page_size', '')
|
||||
sort_condition = tool_parameters.get('sort_condition', '')
|
||||
filter_condition = tool_parameters.get('filter_condition', '')
|
||||
page_token = tool_parameters.get("page_token", "")
|
||||
page_size = tool_parameters.get("page_size", "")
|
||||
sort_condition = tool_parameters.get("sort_condition", "")
|
||||
filter_condition = tool_parameters.get("filter_condition", "")
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
params = {
|
||||
@ -40,22 +40,26 @@ class ListBaseRecordsTool(BuiltinTool):
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
payload = {
|
||||
"automatic_fields": True
|
||||
}
|
||||
payload = {"automatic_fields": True}
|
||||
if sort_condition:
|
||||
payload["sort"] = json.loads(sort_condition)
|
||||
if filter_condition:
|
||||
payload["filter"] = json.loads(filter_condition)
|
||||
|
||||
try:
|
||||
res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params,
|
||||
json=payload, timeout=30)
|
||||
res = httpx.post(
|
||||
url.format(app_token=app_token, table_id=table_id),
|
||||
headers=headers,
|
||||
params=params,
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
res_json = res.json()
|
||||
if res.is_success:
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to list base records, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to list base records, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to list base records. {}".format(e))
|
||||
|
||||
@ -8,25 +8,25 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class ListBaseTablesTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables"
|
||||
|
||||
access_token = tool_parameters.get('Authorization', '')
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message('Invalid parameter access_token')
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
|
||||
app_token = tool_parameters.get('app_token', '')
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
if not app_token:
|
||||
return self.create_text_message('Invalid parameter app_token')
|
||||
return self.create_text_message("Invalid parameter app_token")
|
||||
|
||||
page_token = tool_parameters.get('page_token', '')
|
||||
page_size = tool_parameters.get('page_size', '')
|
||||
page_token = tool_parameters.get("page_token", "")
|
||||
page_size = tool_parameters.get("page_size", "")
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
params = {
|
||||
@ -41,6 +41,7 @@ class ListBaseTablesTool(BuiltinTool):
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to list base tables, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to list base tables, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to list base tables. {}".format(e))
|
||||
|
||||
@ -8,40 +8,42 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class ReadBaseRecordTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}"
|
||||
|
||||
access_token = tool_parameters.get('Authorization', '')
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message('Invalid parameter access_token')
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
|
||||
app_token = tool_parameters.get('app_token', '')
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
if not app_token:
|
||||
return self.create_text_message('Invalid parameter app_token')
|
||||
return self.create_text_message("Invalid parameter app_token")
|
||||
|
||||
table_id = tool_parameters.get('table_id', '')
|
||||
table_id = tool_parameters.get("table_id", "")
|
||||
if not table_id:
|
||||
return self.create_text_message('Invalid parameter table_id')
|
||||
return self.create_text_message("Invalid parameter table_id")
|
||||
|
||||
record_id = tool_parameters.get('record_id', '')
|
||||
record_id = tool_parameters.get("record_id", "")
|
||||
if not record_id:
|
||||
return self.create_text_message('Invalid parameter record_id')
|
||||
return self.create_text_message("Invalid parameter record_id")
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
try:
|
||||
res = httpx.get(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers,
|
||||
timeout=30)
|
||||
res = httpx.get(
|
||||
url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, timeout=30
|
||||
)
|
||||
res_json = res.json()
|
||||
if res.is_success:
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to read base record, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to read base record, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to read base record. {}".format(e))
|
||||
|
||||
@ -8,49 +8,53 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class UpdateBaseRecordTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}"
|
||||
|
||||
access_token = tool_parameters.get('Authorization', '')
|
||||
access_token = tool_parameters.get("Authorization", "")
|
||||
if not access_token:
|
||||
return self.create_text_message('Invalid parameter access_token')
|
||||
return self.create_text_message("Invalid parameter access_token")
|
||||
|
||||
app_token = tool_parameters.get('app_token', '')
|
||||
app_token = tool_parameters.get("app_token", "")
|
||||
if not app_token:
|
||||
return self.create_text_message('Invalid parameter app_token')
|
||||
return self.create_text_message("Invalid parameter app_token")
|
||||
|
||||
table_id = tool_parameters.get('table_id', '')
|
||||
table_id = tool_parameters.get("table_id", "")
|
||||
if not table_id:
|
||||
return self.create_text_message('Invalid parameter table_id')
|
||||
return self.create_text_message("Invalid parameter table_id")
|
||||
|
||||
record_id = tool_parameters.get('record_id', '')
|
||||
record_id = tool_parameters.get("record_id", "")
|
||||
if not record_id:
|
||||
return self.create_text_message('Invalid parameter record_id')
|
||||
return self.create_text_message("Invalid parameter record_id")
|
||||
|
||||
fields = tool_parameters.get('fields', '')
|
||||
fields = tool_parameters.get("fields", "")
|
||||
if not fields:
|
||||
return self.create_text_message('Invalid parameter fields')
|
||||
return self.create_text_message("Invalid parameter fields")
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
params = {}
|
||||
payload = {
|
||||
"fields": json.loads(fields)
|
||||
}
|
||||
payload = {"fields": json.loads(fields)}
|
||||
|
||||
try:
|
||||
res = httpx.put(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers,
|
||||
params=params, json=payload, timeout=30)
|
||||
res = httpx.put(
|
||||
url.format(app_token=app_token, table_id=table_id, record_id=record_id),
|
||||
headers=headers,
|
||||
params=params,
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
res_json = res.json()
|
||||
if res.is_success:
|
||||
return self.create_text_message(text=json.dumps(res_json))
|
||||
else:
|
||||
return self.create_text_message(
|
||||
f"Failed to update base record, status code: {res.status_code}, response: {res.text}")
|
||||
f"Failed to update base record, status code: {res.status_code}, response: {res.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
return self.create_text_message("Failed to update base record. {}".format(e))
|
||||
|
||||
@ -5,11 +5,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class FeishuDocumentProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
app_id = credentials.get('app_id')
|
||||
app_secret = credentials.get('app_secret')
|
||||
app_id = credentials.get("app_id")
|
||||
app_secret = credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ToolProviderCredentialValidationError("app_id and app_secret is required")
|
||||
try:
|
||||
assert FeishuRequest(app_id, app_secret).tenant_access_token is not None
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class CreateDocumentTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get('app_id')
|
||||
app_secret = self.runtime.credentials.get('app_secret')
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
title = tool_parameters.get('title')
|
||||
content = tool_parameters.get('content')
|
||||
folder_token = tool_parameters.get('folder_token')
|
||||
title = tool_parameters.get("title")
|
||||
content = tool_parameters.get("content")
|
||||
folder_token = tool_parameters.get("folder_token")
|
||||
|
||||
res = client.create_document(title, content, folder_token)
|
||||
return self.create_json_message(res)
|
||||
|
||||
@ -7,11 +7,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class GetDocumentRawContentTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get('app_id')
|
||||
app_secret = self.runtime.credentials.get('app_secret')
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
document_id = tool_parameters.get('document_id')
|
||||
document_id = tool_parameters.get("document_id")
|
||||
|
||||
res = client.get_document_raw_content(document_id)
|
||||
return self.create_json_message(res)
|
||||
return self.create_json_message(res)
|
||||
|
||||
@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class ListDocumentBlockTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get('app_id')
|
||||
app_secret = self.runtime.credentials.get('app_secret')
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
document_id = tool_parameters.get('document_id')
|
||||
page_size = tool_parameters.get('page_size', 500)
|
||||
page_token = tool_parameters.get('page_token', '')
|
||||
document_id = tool_parameters.get("document_id")
|
||||
page_size = tool_parameters.get("page_size", 500)
|
||||
page_token = tool_parameters.get("page_token", "")
|
||||
|
||||
res = client.list_document_block(document_id, page_token, page_size)
|
||||
return self.create_json_message(res)
|
||||
|
||||
@ -7,13 +7,13 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class CreateDocumentTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get('app_id')
|
||||
app_secret = self.runtime.credentials.get('app_secret')
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
document_id = tool_parameters.get('document_id')
|
||||
content = tool_parameters.get('content')
|
||||
position = tool_parameters.get('position')
|
||||
document_id = tool_parameters.get("document_id")
|
||||
content = tool_parameters.get("content")
|
||||
position = tool_parameters.get("position")
|
||||
|
||||
res = client.write_document(document_id, content, position)
|
||||
return self.create_json_message(res)
|
||||
|
||||
@ -5,11 +5,11 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class FeishuMessageProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
app_id = credentials.get('app_id')
|
||||
app_secret = credentials.get('app_secret')
|
||||
app_id = credentials.get("app_id")
|
||||
app_secret = credentials.get("app_secret")
|
||||
if not app_id or not app_secret:
|
||||
raise ToolProviderCredentialValidationError("app_id and app_secret is required")
|
||||
try:
|
||||
assert FeishuRequest(app_id, app_secret).tenant_access_token is not None
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -7,14 +7,14 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
class SendBotMessageTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get('app_id')
|
||||
app_secret = self.runtime.credentials.get('app_secret')
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
receive_id_type = tool_parameters.get('receive_id_type')
|
||||
receive_id = tool_parameters.get('receive_id')
|
||||
msg_type = tool_parameters.get('msg_type')
|
||||
content = tool_parameters.get('content')
|
||||
receive_id_type = tool_parameters.get("receive_id_type")
|
||||
receive_id = tool_parameters.get("receive_id")
|
||||
msg_type = tool_parameters.get("msg_type")
|
||||
content = tool_parameters.get("content")
|
||||
|
||||
res = client.send_bot_message(receive_id_type, receive_id, msg_type, content)
|
||||
return self.create_json_message(res)
|
||||
|
||||
@ -6,14 +6,14 @@ from core.tools.utils.feishu_api_utils import FeishuRequest
|
||||
|
||||
|
||||
class SendWebhookMessageTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) ->ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get('app_id')
|
||||
app_secret = self.runtime.credentials.get('app_secret')
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app_id = self.runtime.credentials.get("app_id")
|
||||
app_secret = self.runtime.credentials.get("app_secret")
|
||||
client = FeishuRequest(app_id, app_secret)
|
||||
|
||||
webhook = tool_parameters.get('webhook')
|
||||
msg_type = tool_parameters.get('msg_type')
|
||||
content = tool_parameters.get('content')
|
||||
webhook = tool_parameters.get("webhook")
|
||||
msg_type = tool_parameters.get("msg_type")
|
||||
content = tool_parameters.get("content")
|
||||
|
||||
res = client.send_webhook_message(webhook, msg_type, content)
|
||||
return self.create_json_message(res)
|
||||
|
||||
@ -7,15 +7,8 @@ class FirecrawlProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
# Example validation using the ScrapeTool, only scraping title for minimize content
|
||||
ScrapeTool().fork_tool_runtime(
|
||||
runtime={"credentials": credentials}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"url": "https://google.com",
|
||||
"onlyIncludeTags": 'title'
|
||||
}
|
||||
ScrapeTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke(
|
||||
user_id="", tool_parameters={"url": "https://google.com", "onlyIncludeTags": "title"}
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -13,85 +13,83 @@ logger = logging.getLogger(__name__)
|
||||
class FirecrawlApp:
|
||||
def __init__(self, api_key: str | None = None, base_url: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url or 'https://api.firecrawl.dev'
|
||||
self.base_url = base_url or "https://api.firecrawl.dev"
|
||||
if not self.api_key:
|
||||
raise ValueError("API key is required")
|
||||
|
||||
def _prepare_headers(self, idempotency_key: str | None = None):
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
if idempotency_key:
|
||||
headers['Idempotency-Key'] = idempotency_key
|
||||
headers["Idempotency-Key"] = idempotency_key
|
||||
return headers
|
||||
|
||||
def _request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
data: Mapping[str, Any] | None = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
retries: int = 3,
|
||||
backoff_factor: float = 0.3,
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
data: Mapping[str, Any] | None = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
retries: int = 3,
|
||||
backoff_factor: float = 0.3,
|
||||
) -> Mapping[str, Any] | None:
|
||||
if not headers:
|
||||
headers = self._prepare_headers()
|
||||
for i in range(retries):
|
||||
try:
|
||||
response = requests.request(method, url, json=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
except requests.exceptions.RequestException:
|
||||
if i < retries - 1:
|
||||
time.sleep(backoff_factor * (2 ** i))
|
||||
time.sleep(backoff_factor * (2**i))
|
||||
else:
|
||||
raise
|
||||
return None
|
||||
|
||||
def scrape_url(self, url: str, **kwargs):
|
||||
endpoint = f'{self.base_url}/v0/scrape'
|
||||
data = {'url': url, **kwargs}
|
||||
endpoint = f"{self.base_url}/v1/scrape"
|
||||
data = {"url": url, **kwargs}
|
||||
logger.debug(f"Sent request to {endpoint=} body={data}")
|
||||
response = self._request('POST', endpoint, data)
|
||||
response = self._request("POST", endpoint, data)
|
||||
if response is None:
|
||||
raise HTTPError("Failed to scrape URL after multiple retries")
|
||||
return response
|
||||
|
||||
def search(self, query: str, **kwargs):
|
||||
endpoint = f'{self.base_url}/v0/search'
|
||||
data = {'query': query, **kwargs}
|
||||
def map(self, url: str, **kwargs):
|
||||
endpoint = f"{self.base_url}/v1/map"
|
||||
data = {"url": url, **kwargs}
|
||||
logger.debug(f"Sent request to {endpoint=} body={data}")
|
||||
response = self._request('POST', endpoint, data)
|
||||
response = self._request("POST", endpoint, data)
|
||||
if response is None:
|
||||
raise HTTPError("Failed to perform search after multiple retries")
|
||||
raise HTTPError("Failed to perform map after multiple retries")
|
||||
return response
|
||||
|
||||
def crawl_url(
|
||||
self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs
|
||||
self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs
|
||||
):
|
||||
endpoint = f'{self.base_url}/v0/crawl'
|
||||
endpoint = f"{self.base_url}/v1/crawl"
|
||||
headers = self._prepare_headers(idempotency_key)
|
||||
data = {'url': url, **kwargs}
|
||||
data = {"url": url, **kwargs}
|
||||
logger.debug(f"Sent request to {endpoint=} body={data}")
|
||||
response = self._request('POST', endpoint, data, headers)
|
||||
response = self._request("POST", endpoint, data, headers)
|
||||
if response is None:
|
||||
raise HTTPError("Failed to initiate crawl after multiple retries")
|
||||
job_id: str = response['jobId']
|
||||
elif response.get("success") == False:
|
||||
raise HTTPError(f'Failed to crawl: {response.get("error")}')
|
||||
job_id: str = response["id"]
|
||||
if wait:
|
||||
return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval)
|
||||
return response
|
||||
|
||||
def check_crawl_status(self, job_id: str):
|
||||
endpoint = f'{self.base_url}/v0/crawl/status/{job_id}'
|
||||
response = self._request('GET', endpoint)
|
||||
endpoint = f"{self.base_url}/v1/crawl/{job_id}"
|
||||
response = self._request("GET", endpoint)
|
||||
if response is None:
|
||||
raise HTTPError(f"Failed to check status for job {job_id} after multiple retries")
|
||||
return response
|
||||
|
||||
def cancel_crawl_job(self, job_id: str):
|
||||
endpoint = f'{self.base_url}/v0/crawl/cancel/{job_id}'
|
||||
response = self._request('DELETE', endpoint)
|
||||
endpoint = f"{self.base_url}/v1/crawl/{job_id}"
|
||||
response = self._request("DELETE", endpoint)
|
||||
if response is None:
|
||||
raise HTTPError(f"Failed to cancel job {job_id} after multiple retries")
|
||||
return response
|
||||
@ -99,9 +97,9 @@ class FirecrawlApp:
|
||||
def _monitor_job_status(self, job_id: str, poll_interval: int):
|
||||
while True:
|
||||
status = self.check_crawl_status(job_id)
|
||||
if status['status'] == 'completed':
|
||||
if status["status"] == "completed":
|
||||
return status
|
||||
elif status['status'] == 'failed':
|
||||
elif status["status"] == "failed":
|
||||
raise HTTPError(f'Job {job_id} failed: {status["error"]}')
|
||||
time.sleep(poll_interval)
|
||||
|
||||
@ -109,7 +107,7 @@ class FirecrawlApp:
|
||||
def get_array_params(tool_parameters: dict[str, Any], key):
|
||||
param = tool_parameters.get(key)
|
||||
if param:
|
||||
return param.split(',')
|
||||
return param.split(",")
|
||||
|
||||
|
||||
def get_json_params(tool_parameters: dict[str, Any], key):
|
||||
@ -119,6 +117,6 @@ def get_json_params(tool_parameters: dict[str, Any], key):
|
||||
# support both single quotes and double quotes
|
||||
param = param.replace("'", '"')
|
||||
param = json.loads(param)
|
||||
except:
|
||||
except Exception:
|
||||
raise ValueError(f"Invalid {key} format.")
|
||||
return param
|
||||
|
||||
@ -8,41 +8,38 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
class CrawlTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
"""
|
||||
the crawlerOptions and pageOptions comes from doc here:
|
||||
the api doc:
|
||||
https://docs.firecrawl.dev/api-reference/endpoint/crawl
|
||||
"""
|
||||
app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'],
|
||||
base_url=self.runtime.credentials['base_url'])
|
||||
crawlerOptions = {}
|
||||
pageOptions = {}
|
||||
|
||||
wait_for_results = tool_parameters.get('wait_for_results', True)
|
||||
|
||||
crawlerOptions['excludes'] = get_array_params(tool_parameters, 'excludes')
|
||||
crawlerOptions['includes'] = get_array_params(tool_parameters, 'includes')
|
||||
crawlerOptions['returnOnlyUrls'] = tool_parameters.get('returnOnlyUrls', False)
|
||||
crawlerOptions['maxDepth'] = tool_parameters.get('maxDepth')
|
||||
crawlerOptions['mode'] = tool_parameters.get('mode')
|
||||
crawlerOptions['ignoreSitemap'] = tool_parameters.get('ignoreSitemap', False)
|
||||
crawlerOptions['limit'] = tool_parameters.get('limit', 5)
|
||||
crawlerOptions['allowBackwardCrawling'] = tool_parameters.get('allowBackwardCrawling', False)
|
||||
crawlerOptions['allowExternalContentLinks'] = tool_parameters.get('allowExternalContentLinks', False)
|
||||
|
||||
pageOptions['headers'] = get_json_params(tool_parameters, 'headers')
|
||||
pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False)
|
||||
pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False)
|
||||
pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags')
|
||||
pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags')
|
||||
pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False)
|
||||
pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False)
|
||||
pageOptions['screenshot'] = tool_parameters.get('screenshot', False)
|
||||
pageOptions['waitFor'] = tool_parameters.get('waitFor', 0)
|
||||
|
||||
crawl_result = app.crawl_url(
|
||||
url=tool_parameters['url'],
|
||||
wait=wait_for_results,
|
||||
crawlerOptions=crawlerOptions,
|
||||
pageOptions=pageOptions
|
||||
app = FirecrawlApp(
|
||||
api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"]
|
||||
)
|
||||
|
||||
scrapeOptions = {}
|
||||
payload = {}
|
||||
|
||||
wait_for_results = tool_parameters.get("wait_for_results", True)
|
||||
|
||||
payload["excludePaths"] = get_array_params(tool_parameters, "excludePaths")
|
||||
payload["includePaths"] = get_array_params(tool_parameters, "includePaths")
|
||||
payload["maxDepth"] = tool_parameters.get("maxDepth")
|
||||
payload["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", False)
|
||||
payload["limit"] = tool_parameters.get("limit", 5)
|
||||
payload["allowBackwardLinks"] = tool_parameters.get("allowBackwardLinks", False)
|
||||
payload["allowExternalLinks"] = tool_parameters.get("allowExternalLinks", False)
|
||||
payload["webhook"] = tool_parameters.get("webhook")
|
||||
|
||||
scrapeOptions["formats"] = get_array_params(tool_parameters, "formats")
|
||||
scrapeOptions["headers"] = get_json_params(tool_parameters, "headers")
|
||||
scrapeOptions["includeTags"] = get_array_params(tool_parameters, "includeTags")
|
||||
scrapeOptions["excludeTags"] = get_array_params(tool_parameters, "excludeTags")
|
||||
scrapeOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", False)
|
||||
scrapeOptions["waitFor"] = tool_parameters.get("waitFor", 0)
|
||||
scrapeOptions = {k: v for k, v in scrapeOptions.items() if v not in {None, ""}}
|
||||
payload["scrapeOptions"] = scrapeOptions or None
|
||||
|
||||
payload = {k: v for k, v in payload.items() if v not in {None, ""}}
|
||||
|
||||
crawl_result = app.crawl_url(url=tool_parameters["url"], wait=wait_for_results, **payload)
|
||||
|
||||
return self.create_json_message(crawl_result)
|
||||
|
||||
@ -31,8 +31,21 @@ parameters:
|
||||
en_US: If you choose not to wait, it will directly return a job ID. You can use this job ID to check the crawling results or cancel the crawling task, which is usually very useful for a large-scale crawling task.
|
||||
zh_Hans: 如果选择不等待,则会直接返回一个job_id,可以通过job_id查询爬取结果或取消爬取任务,这通常对于一个大型爬取任务来说非常有用。
|
||||
form: form
|
||||
############## Crawl Options #######################
|
||||
- name: includes
|
||||
############## Payload #######################
|
||||
- name: excludePaths
|
||||
type: string
|
||||
label:
|
||||
en_US: URL patterns to exclude
|
||||
zh_Hans: 要排除的URL模式
|
||||
placeholder:
|
||||
en_US: Use commas to separate multiple tags
|
||||
zh_Hans: 多个标签时使用半角逗号分隔
|
||||
human_description:
|
||||
en_US: |
|
||||
Pages matching these patterns will be skipped. Example: blog/*, about/*
|
||||
zh_Hans: 匹配这些模式的页面将被跳过。示例:blog/*, about/*
|
||||
form: form
|
||||
- name: includePaths
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
@ -46,30 +59,6 @@ parameters:
|
||||
Only pages matching these patterns will be crawled. Example: blog/*, about/*
|
||||
zh_Hans: 只有与这些模式匹配的页面才会被爬取。示例:blog/*, about/*
|
||||
form: form
|
||||
- name: excludes
|
||||
type: string
|
||||
label:
|
||||
en_US: URL patterns to exclude
|
||||
zh_Hans: 要排除的URL模式
|
||||
placeholder:
|
||||
en_US: Use commas to separate multiple tags
|
||||
zh_Hans: 多个标签时使用半角逗号分隔
|
||||
human_description:
|
||||
en_US: |
|
||||
Pages matching these patterns will be skipped. Example: blog/*, about/*
|
||||
zh_Hans: 匹配这些模式的页面将被跳过。示例:blog/*, about/*
|
||||
form: form
|
||||
- name: returnOnlyUrls
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: return Only Urls
|
||||
zh_Hans: 仅返回URL
|
||||
human_description:
|
||||
en_US: |
|
||||
If true, returns only the URLs as a list on the crawl status. Attention: the return response will be a list of URLs inside the data, not a list of documents.
|
||||
zh_Hans: 只返回爬取到的网页链接,而不是网页内容本身。
|
||||
form: form
|
||||
- name: maxDepth
|
||||
type: number
|
||||
label:
|
||||
@ -80,27 +69,10 @@ parameters:
|
||||
zh_Hans: 相对于输入的URL,爬取的最大深度。maxDepth为0时,仅抓取输入的URL。maxDepth为1时,抓取输入的URL以及所有一级深层页面。maxDepth为2时,抓取输入的URL以及所有两级深层页面。更高值遵循相同模式。
|
||||
form: form
|
||||
min: 0
|
||||
- name: mode
|
||||
type: select
|
||||
required: false
|
||||
form: form
|
||||
options:
|
||||
- value: default
|
||||
label:
|
||||
en_US: default
|
||||
- value: fast
|
||||
label:
|
||||
en_US: fast
|
||||
default: default
|
||||
label:
|
||||
en_US: Crawl Mode
|
||||
zh_Hans: 爬取模式
|
||||
human_description:
|
||||
en_US: The crawling mode to use. Fast mode crawls 4x faster websites without sitemap, but may not be as accurate and shouldn't be used in heavy js-rendered websites.
|
||||
zh_Hans: 使用fast模式将不会使用其站点地图,比普通模式快4倍,但是可能不够准确,也不适用于大量js渲染的网站。
|
||||
default: 2
|
||||
- name: ignoreSitemap
|
||||
type: boolean
|
||||
default: false
|
||||
default: true
|
||||
label:
|
||||
en_US: ignore Sitemap
|
||||
zh_Hans: 忽略站点地图
|
||||
@ -120,7 +92,7 @@ parameters:
|
||||
form: form
|
||||
min: 1
|
||||
default: 5
|
||||
- name: allowBackwardCrawling
|
||||
- name: allowBackwardLinks
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
@ -130,7 +102,7 @@ parameters:
|
||||
en_US: Enables the crawler to navigate from a specific URL to previously linked pages. For instance, from 'example.com/product/123' back to 'example.com/product'
|
||||
zh_Hans: 使爬虫能够从特定URL导航到之前链接的页面。例如,从'example.com/product/123'返回到'example.com/product'
|
||||
form: form
|
||||
- name: allowExternalContentLinks
|
||||
- name: allowExternalLinks
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
@ -140,7 +112,30 @@ parameters:
|
||||
en_US: Allows the crawler to follow links to external websites.
|
||||
zh_Hans:
|
||||
form: form
|
||||
############## Page Options #######################
|
||||
- name: webhook
|
||||
type: string
|
||||
label:
|
||||
en_US: webhook
|
||||
human_description:
|
||||
en_US: |
|
||||
The URL to send the webhook to. This will trigger for crawl started (crawl.started) ,every page crawled (crawl.page) and when the crawl is completed (crawl.completed or crawl.failed). The response will be the same as the /scrape endpoint.
|
||||
zh_Hans: 发送Webhook的URL。这将在开始爬取(crawl.started)、每爬取一个页面(crawl.page)以及爬取完成(crawl.completed或crawl.failed)时触发。响应将与/scrape端点相同。
|
||||
form: form
|
||||
############## Scrape Options #######################
|
||||
- name: formats
|
||||
type: string
|
||||
label:
|
||||
en_US: Formats
|
||||
zh_Hans: 结果的格式
|
||||
placeholder:
|
||||
en_US: Use commas to separate multiple tags
|
||||
zh_Hans: 多个标签时使用半角逗号分隔
|
||||
human_description:
|
||||
en_US: |
|
||||
Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot
|
||||
zh_Hans: |
|
||||
输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot
|
||||
form: form
|
||||
- name: headers
|
||||
type: string
|
||||
label:
|
||||
@ -155,30 +150,10 @@ parameters:
|
||||
en_US: Please enter an object that can be serialized in JSON
|
||||
zh_Hans: 请输入可以json序列化的对象
|
||||
form: form
|
||||
- name: includeHtml
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: include Html
|
||||
zh_Hans: 包含HTML
|
||||
human_description:
|
||||
en_US: Include the HTML version of the content on page. Will output a html key in the response.
|
||||
zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。
|
||||
form: form
|
||||
- name: includeRawHtml
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: include Raw Html
|
||||
zh_Hans: 包含原始HTML
|
||||
human_description:
|
||||
en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response.
|
||||
zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。
|
||||
form: form
|
||||
- name: onlyIncludeTags
|
||||
- name: includeTags
|
||||
type: string
|
||||
label:
|
||||
en_US: only Include Tags
|
||||
en_US: Include Tags
|
||||
zh_Hans: 仅抓取这些标签
|
||||
placeholder:
|
||||
en_US: Use commas to separate multiple tags
|
||||
@ -189,6 +164,20 @@ parameters:
|
||||
zh_Hans: |
|
||||
仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer
|
||||
form: form
|
||||
- name: excludeTags
|
||||
type: string
|
||||
label:
|
||||
en_US: Exclude Tags
|
||||
zh_Hans: 要移除这些标签
|
||||
human_description:
|
||||
en_US: |
|
||||
Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer
|
||||
zh_Hans: |
|
||||
要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer
|
||||
placeholder:
|
||||
en_US: Use commas to separate multiple tags
|
||||
zh_Hans: 多个标签时使用半角逗号分隔
|
||||
form: form
|
||||
- name: onlyMainContent
|
||||
type: boolean
|
||||
default: false
|
||||
@ -199,40 +188,6 @@ parameters:
|
||||
en_US: Only return the main content of the page excluding headers, navs, footers, etc.
|
||||
zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。
|
||||
form: form
|
||||
- name: removeTags
|
||||
type: string
|
||||
label:
|
||||
en_US: remove Tags
|
||||
zh_Hans: 要移除这些标签
|
||||
human_description:
|
||||
en_US: |
|
||||
Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer
|
||||
zh_Hans: |
|
||||
要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer
|
||||
placeholder:
|
||||
en_US: Use commas to separate multiple tags
|
||||
zh_Hans: 多个标签时使用半角逗号分隔
|
||||
form: form
|
||||
- name: replaceAllPathsWithAbsolutePaths
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: All AbsolutePaths
|
||||
zh_Hans: 使用绝对路径
|
||||
human_description:
|
||||
en_US: Replace all relative paths with absolute paths for images and links.
|
||||
zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。
|
||||
form: form
|
||||
- name: screenshot
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: screenshot
|
||||
zh_Hans: 截图
|
||||
human_description:
|
||||
en_US: Include a screenshot of the top of the page that you are scraping.
|
||||
zh_Hans: 提供正在抓取的页面的顶部的截图。
|
||||
form: form
|
||||
- name: waitFor
|
||||
type: number
|
||||
min: 0
|
||||
|
||||
@ -7,14 +7,15 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
class CrawlJobTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'],
|
||||
base_url=self.runtime.credentials['base_url'])
|
||||
operation = tool_parameters.get('operation', 'get')
|
||||
if operation == 'get':
|
||||
result = app.check_crawl_status(job_id=tool_parameters['job_id'])
|
||||
elif operation == 'cancel':
|
||||
result = app.cancel_crawl_job(job_id=tool_parameters['job_id'])
|
||||
app = FirecrawlApp(
|
||||
api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"]
|
||||
)
|
||||
operation = tool_parameters.get("operation", "get")
|
||||
if operation == "get":
|
||||
result = app.check_crawl_status(job_id=tool_parameters["job_id"])
|
||||
elif operation == "cancel":
|
||||
result = app.cancel_crawl_job(job_id=tool_parameters["job_id"])
|
||||
else:
|
||||
raise ValueError(f'Invalid operation: {operation}')
|
||||
raise ValueError(f"Invalid operation: {operation}")
|
||||
|
||||
return self.create_json_message(result)
|
||||
|
||||
25
api/core/tools/provider/builtin/firecrawl/tools/map.py
Normal file
25
api/core/tools/provider/builtin/firecrawl/tools/map.py
Normal file
@ -0,0 +1,25 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class MapTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
"""
|
||||
the api doc:
|
||||
https://docs.firecrawl.dev/api-reference/endpoint/map
|
||||
"""
|
||||
app = FirecrawlApp(
|
||||
api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"]
|
||||
)
|
||||
payload = {}
|
||||
payload["search"] = tool_parameters.get("search")
|
||||
payload["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", True)
|
||||
payload["includeSubdomains"] = tool_parameters.get("includeSubdomains", False)
|
||||
payload["limit"] = tool_parameters.get("limit", 5000)
|
||||
|
||||
map_result = app.map(url=tool_parameters["url"], **payload)
|
||||
|
||||
return self.create_json_message(map_result)
|
||||
59
api/core/tools/provider/builtin/firecrawl/tools/map.yaml
Normal file
59
api/core/tools/provider/builtin/firecrawl/tools/map.yaml
Normal file
@ -0,0 +1,59 @@
|
||||
identity:
|
||||
name: map
|
||||
author: hjlarry
|
||||
label:
|
||||
en_US: Map
|
||||
zh_Hans: 地图式快爬
|
||||
description:
|
||||
human:
|
||||
en_US: Input a website and get all the urls on the website - extremly fast
|
||||
zh_Hans: 输入一个网站,快速获取网站上的所有网址。
|
||||
llm: Input a website and get all the urls on the website - extremly fast
|
||||
parameters:
|
||||
- name: url
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Start URL
|
||||
zh_Hans: 起始URL
|
||||
human_description:
|
||||
en_US: The base URL to start crawling from.
|
||||
zh_Hans: 要爬取网站的起始URL。
|
||||
llm_description: The URL of the website that needs to be crawled. This is a required parameter.
|
||||
form: llm
|
||||
- name: search
|
||||
type: string
|
||||
label:
|
||||
en_US: search
|
||||
zh_Hans: 搜索查询
|
||||
human_description:
|
||||
en_US: Search query to use for mapping. During the Alpha phase, the 'smart' part of the search functionality is limited to 100 search results. However, if map finds more results, there is no limit applied.
|
||||
zh_Hans: 用于映射的搜索查询。在Alpha阶段,搜索功能的“智能”部分限制为最多100个搜索结果。然而,如果地图找到了更多结果,则不施加任何限制。
|
||||
llm_description: Search query to use for mapping. During the Alpha phase, the 'smart' part of the search functionality is limited to 100 search results. However, if map finds more results, there is no limit applied.
|
||||
form: llm
|
||||
############## Page Options #######################
|
||||
- name: ignoreSitemap
|
||||
type: boolean
|
||||
default: true
|
||||
label:
|
||||
en_US: ignore Sitemap
|
||||
zh_Hans: 忽略站点地图
|
||||
human_description:
|
||||
en_US: Ignore the website sitemap when crawling.
|
||||
zh_Hans: 爬取时忽略网站站点地图。
|
||||
form: form
|
||||
- name: includeSubdomains
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: include Subdomains
|
||||
zh_Hans: 包含子域名
|
||||
form: form
|
||||
- name: limit
|
||||
type: number
|
||||
min: 0
|
||||
default: 5000
|
||||
label:
|
||||
en_US: Maximum results
|
||||
zh_Hans: 最大结果数量
|
||||
form: form
|
||||
@ -6,34 +6,34 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class ScrapeTool(BuiltinTool):
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
||||
"""
|
||||
the pageOptions and extractorOptions comes from doc here:
|
||||
the api doc:
|
||||
https://docs.firecrawl.dev/api-reference/endpoint/scrape
|
||||
"""
|
||||
app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'],
|
||||
base_url=self.runtime.credentials['base_url'])
|
||||
app = FirecrawlApp(
|
||||
api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"]
|
||||
)
|
||||
|
||||
pageOptions = {}
|
||||
extractorOptions = {}
|
||||
payload = {}
|
||||
extract = {}
|
||||
|
||||
pageOptions['headers'] = get_json_params(tool_parameters, 'headers')
|
||||
pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False)
|
||||
pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False)
|
||||
pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags')
|
||||
pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags')
|
||||
pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False)
|
||||
pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False)
|
||||
pageOptions['screenshot'] = tool_parameters.get('screenshot', False)
|
||||
pageOptions['waitFor'] = tool_parameters.get('waitFor', 0)
|
||||
payload["formats"] = get_array_params(tool_parameters, "formats")
|
||||
payload["onlyMainContent"] = tool_parameters.get("onlyMainContent", True)
|
||||
payload["includeTags"] = get_array_params(tool_parameters, "includeTags")
|
||||
payload["excludeTags"] = get_array_params(tool_parameters, "excludeTags")
|
||||
payload["headers"] = get_json_params(tool_parameters, "headers")
|
||||
payload["waitFor"] = tool_parameters.get("waitFor", 0)
|
||||
payload["timeout"] = tool_parameters.get("timeout", 30000)
|
||||
|
||||
extractorOptions['mode'] = tool_parameters.get('mode', '')
|
||||
extractorOptions['extractionPrompt'] = tool_parameters.get('extractionPrompt', '')
|
||||
extractorOptions['extractionSchema'] = get_json_params(tool_parameters, 'extractionSchema')
|
||||
extract["schema"] = get_json_params(tool_parameters, "schema")
|
||||
extract["systemPrompt"] = tool_parameters.get("systemPrompt")
|
||||
extract["prompt"] = tool_parameters.get("prompt")
|
||||
extract = {k: v for k, v in extract.items() if v not in {None, ""}}
|
||||
payload["extract"] = extract or None
|
||||
|
||||
crawl_result = app.scrape_url(url=tool_parameters['url'],
|
||||
pageOptions=pageOptions,
|
||||
extractorOptions=extractorOptions)
|
||||
payload = {k: v for k, v in payload.items() if v not in {None, ""}}
|
||||
|
||||
return self.create_json_message(crawl_result)
|
||||
crawl_result = app.scrape_url(url=tool_parameters["url"], **payload)
|
||||
markdown_result = crawl_result.get("data", {}).get("markdown", "")
|
||||
return [self.create_text_message(markdown_result), self.create_json_message(crawl_result)]
|
||||
|
||||
@ -6,8 +6,8 @@ identity:
|
||||
zh_Hans: 单页面抓取
|
||||
description:
|
||||
human:
|
||||
en_US: Extract data from a single URL.
|
||||
zh_Hans: 从单个URL抓取数据。
|
||||
en_US: Turn any url into clean data.
|
||||
zh_Hans: 将任何网址转换为干净的数据。
|
||||
llm: This tool is designed to scrape URL and output the content in Markdown format.
|
||||
parameters:
|
||||
- name: url
|
||||
@ -21,7 +21,59 @@ parameters:
|
||||
zh_Hans: 要抓取并提取数据的网站URL。
|
||||
llm_description: The URL of the website that needs to be crawled. This is a required parameter.
|
||||
form: llm
|
||||
############## Page Options #######################
|
||||
############## Payload #######################
|
||||
- name: formats
|
||||
type: string
|
||||
label:
|
||||
en_US: Formats
|
||||
zh_Hans: 结果的格式
|
||||
placeholder:
|
||||
en_US: Use commas to separate multiple tags
|
||||
zh_Hans: 多个标签时使用半角逗号分隔
|
||||
human_description:
|
||||
en_US: |
|
||||
Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage
|
||||
zh_Hans: |
|
||||
输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage
|
||||
form: form
|
||||
- name: onlyMainContent
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: only Main Content
|
||||
zh_Hans: 仅抓取主要内容
|
||||
human_description:
|
||||
en_US: Only return the main content of the page excluding headers, navs, footers, etc.
|
||||
zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。
|
||||
form: form
|
||||
- name: includeTags
|
||||
type: string
|
||||
label:
|
||||
en_US: Include Tags
|
||||
zh_Hans: 仅抓取这些标签
|
||||
placeholder:
|
||||
en_US: Use commas to separate multiple tags
|
||||
zh_Hans: 多个标签时使用半角逗号分隔
|
||||
human_description:
|
||||
en_US: |
|
||||
Only include tags, classes and ids from the page in the final output. Use comma separated values. Example: script, .ad, #footer
|
||||
zh_Hans: |
|
||||
仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer
|
||||
form: form
|
||||
- name: excludeTags
|
||||
type: string
|
||||
label:
|
||||
en_US: Exclude Tags
|
||||
zh_Hans: 要移除这些标签
|
||||
human_description:
|
||||
en_US: |
|
||||
Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer
|
||||
zh_Hans: |
|
||||
要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer
|
||||
placeholder:
|
||||
en_US: Use commas to separate multiple tags
|
||||
zh_Hans: 多个标签时使用半角逗号分隔
|
||||
form: form
|
||||
- name: headers
|
||||
type: string
|
||||
label:
|
||||
@ -36,87 +88,10 @@ parameters:
|
||||
en_US: Please enter an object that can be serialized in JSON
|
||||
zh_Hans: 请输入可以json序列化的对象
|
||||
form: form
|
||||
- name: includeHtml
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: include Html
|
||||
zh_Hans: 包含HTML
|
||||
human_description:
|
||||
en_US: Include the HTML version of the content on page. Will output a html key in the response.
|
||||
zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。
|
||||
form: form
|
||||
- name: includeRawHtml
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: include Raw Html
|
||||
zh_Hans: 包含原始HTML
|
||||
human_description:
|
||||
en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response.
|
||||
zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。
|
||||
form: form
|
||||
- name: onlyIncludeTags
|
||||
type: string
|
||||
label:
|
||||
en_US: only Include Tags
|
||||
zh_Hans: 仅抓取这些标签
|
||||
placeholder:
|
||||
en_US: Use commas to separate multiple tags
|
||||
zh_Hans: 多个标签时使用半角逗号分隔
|
||||
human_description:
|
||||
en_US: |
|
||||
Only include tags, classes and ids from the page in the final output. Use comma separated values. Example: script, .ad, #footer
|
||||
zh_Hans: |
|
||||
仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer
|
||||
form: form
|
||||
- name: onlyMainContent
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: only Main Content
|
||||
zh_Hans: 仅抓取主要内容
|
||||
human_description:
|
||||
en_US: Only return the main content of the page excluding headers, navs, footers, etc.
|
||||
zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。
|
||||
form: form
|
||||
- name: removeTags
|
||||
type: string
|
||||
label:
|
||||
en_US: remove Tags
|
||||
zh_Hans: 要移除这些标签
|
||||
human_description:
|
||||
en_US: |
|
||||
Tags, classes and ids to remove from the page. Use comma separated values. Example: script, .ad, #footer
|
||||
zh_Hans: |
|
||||
要在最终输出中移除HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer
|
||||
placeholder:
|
||||
en_US: Use commas to separate multiple tags
|
||||
zh_Hans: 多个标签时使用半角逗号分隔
|
||||
form: form
|
||||
- name: replaceAllPathsWithAbsolutePaths
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: All AbsolutePaths
|
||||
zh_Hans: 使用绝对路径
|
||||
human_description:
|
||||
en_US: Replace all relative paths with absolute paths for images and links.
|
||||
zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。
|
||||
form: form
|
||||
- name: screenshot
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: screenshot
|
||||
zh_Hans: 截图
|
||||
human_description:
|
||||
en_US: Include a screenshot of the top of the page that you are scraping.
|
||||
zh_Hans: 提供正在抓取的页面的顶部的截图。
|
||||
form: form
|
||||
- name: waitFor
|
||||
type: number
|
||||
min: 0
|
||||
default: 0
|
||||
label:
|
||||
en_US: wait For
|
||||
zh_Hans: 等待时间
|
||||
@ -124,57 +99,54 @@ parameters:
|
||||
en_US: Wait x amount of milliseconds for the page to load to fetch content.
|
||||
zh_Hans: 等待x毫秒以使页面加载并获取内容。
|
||||
form: form
|
||||
- name: timeout
|
||||
type: number
|
||||
min: 0
|
||||
default: 30000
|
||||
label:
|
||||
en_US: Timeout
|
||||
human_description:
|
||||
en_US: Timeout in milliseconds for the request.
|
||||
zh_Hans: 请求的超时时间(以毫秒为单位)。
|
||||
form: form
|
||||
############## Extractor Options #######################
|
||||
- name: mode
|
||||
type: select
|
||||
options:
|
||||
- value: markdown
|
||||
label:
|
||||
en_US: markdown
|
||||
- value: llm-extraction
|
||||
label:
|
||||
en_US: llm-extraction
|
||||
- value: llm-extraction-from-raw-html
|
||||
label:
|
||||
en_US: llm-extraction-from-raw-html
|
||||
- value: llm-extraction-from-markdown
|
||||
label:
|
||||
en_US: llm-extraction-from-markdown
|
||||
label:
|
||||
en_US: Extractor Mode
|
||||
zh_Hans: 提取模式
|
||||
human_description:
|
||||
en_US: |
|
||||
The extraction mode to use. 'markdown': Returns the scraped markdown content, does not perform LLM extraction. 'llm-extraction': Extracts information from the cleaned and parsed content using LLM.
|
||||
zh_Hans: 使用的提取模式。“markdown”:返回抓取的markdown内容,不执行LLM提取。“llm-extractioin”:使用LLM按Extractor Schema从内容中提取信息。
|
||||
form: form
|
||||
- name: extractionPrompt
|
||||
type: string
|
||||
label:
|
||||
en_US: Extractor Prompt
|
||||
zh_Hans: 提取时的提示词
|
||||
human_description:
|
||||
en_US: A prompt describing what information to extract from the page, applicable for LLM extraction modes.
|
||||
zh_Hans: 当使用LLM提取模式时,用于给LLM描述提取规则。
|
||||
form: form
|
||||
- name: extractionSchema
|
||||
- name: schema
|
||||
type: string
|
||||
label:
|
||||
en_US: Extractor Schema
|
||||
zh_Hans: 提取时的结构
|
||||
placeholder:
|
||||
en_US: Please enter an object that can be serialized in JSON
|
||||
zh_Hans: 请输入可以json序列化的对象
|
||||
human_description:
|
||||
en_US: |
|
||||
The schema for the data to be extracted, required only for LLM extraction modes. Example: {
|
||||
The schema for the data to be extracted. Example: {
|
||||
"type": "object",
|
||||
"properties": {"company_mission": {"type": "string"}},
|
||||
"required": ["company_mission"]
|
||||
}
|
||||
zh_Hans: |
|
||||
当使用LLM提取模式时,使用该结构去提取,示例:{
|
||||
使用该结构去提取,示例:{
|
||||
"type": "object",
|
||||
"properties": {"company_mission": {"type": "string"}},
|
||||
"required": ["company_mission"]
|
||||
}
|
||||
form: form
|
||||
- name: systemPrompt
|
||||
type: string
|
||||
label:
|
||||
en_US: Extractor System Prompt
|
||||
zh_Hans: 提取时的系统提示词
|
||||
human_description:
|
||||
en_US: The system prompt to use for the extraction.
|
||||
zh_Hans: 用于提取的系统提示。
|
||||
form: form
|
||||
- name: prompt
|
||||
type: string
|
||||
label:
|
||||
en_US: Extractor Prompt
|
||||
zh_Hans: 提取时的提示词
|
||||
human_description:
|
||||
en_US: The prompt to use for the extraction without a schema.
|
||||
zh_Hans: 用于无schema时提取的提示词
|
||||
form: form
|
||||
|
||||
@ -1,28 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class SearchTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
"""
|
||||
the pageOptions and searchOptions comes from doc here:
|
||||
https://docs.firecrawl.dev/api-reference/endpoint/search
|
||||
"""
|
||||
app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'],
|
||||
base_url=self.runtime.credentials['base_url'])
|
||||
pageOptions = {}
|
||||
pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False)
|
||||
pageOptions['fetchPageContent'] = tool_parameters.get('fetchPageContent', True)
|
||||
pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False)
|
||||
pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False)
|
||||
searchOptions = {'limit': tool_parameters.get('limit')}
|
||||
search_result = app.search(
|
||||
query=tool_parameters['keyword'],
|
||||
pageOptions=pageOptions,
|
||||
searchOptions=searchOptions
|
||||
)
|
||||
|
||||
return self.create_json_message(search_result)
|
||||
@ -1,75 +0,0 @@
|
||||
identity:
|
||||
name: search
|
||||
author: ahasasjeb
|
||||
label:
|
||||
en_US: Search
|
||||
zh_Hans: 搜索
|
||||
description:
|
||||
human:
|
||||
en_US: Search, and output in Markdown format
|
||||
zh_Hans: 搜索,并且以Markdown格式输出
|
||||
llm: This tool can perform online searches and convert the results to Markdown format.
|
||||
parameters:
|
||||
- name: keyword
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: keyword
|
||||
zh_Hans: 关键词
|
||||
human_description:
|
||||
en_US: Input keywords to use Firecrawl API for search.
|
||||
zh_Hans: 输入关键词即可使用Firecrawl API进行搜索。
|
||||
llm_description: Efficiently extract keywords from user text.
|
||||
form: llm
|
||||
############## Page Options #######################
|
||||
- name: onlyMainContent
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: only Main Content
|
||||
zh_Hans: 仅抓取主要内容
|
||||
human_description:
|
||||
en_US: Only return the main content of the page excluding headers, navs, footers, etc.
|
||||
zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。
|
||||
form: form
|
||||
- name: fetchPageContent
|
||||
type: boolean
|
||||
default: true
|
||||
label:
|
||||
en_US: fetch Page Content
|
||||
zh_Hans: 抓取页面内容
|
||||
human_description:
|
||||
en_US: Fetch the content of each page. If false, defaults to a basic fast serp API.
|
||||
zh_Hans: 获取每个页面的内容。如果为否,则使用基本的快速搜索结果页面API。
|
||||
form: form
|
||||
- name: includeHtml
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: include Html
|
||||
zh_Hans: 包含HTML
|
||||
human_description:
|
||||
en_US: Include the HTML version of the content on page. Will output a html key in the response.
|
||||
zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。
|
||||
form: form
|
||||
- name: includeRawHtml
|
||||
type: boolean
|
||||
default: false
|
||||
label:
|
||||
en_US: include Raw Html
|
||||
zh_Hans: 包含原始HTML
|
||||
human_description:
|
||||
en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response.
|
||||
zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。
|
||||
form: form
|
||||
############## Search Options #######################
|
||||
- name: limit
|
||||
type: number
|
||||
min: 0
|
||||
label:
|
||||
en_US: Maximum results
|
||||
zh_Hans: 最大结果数量
|
||||
human_description:
|
||||
en_US: Maximum number of results. Max is 20 during beta.
|
||||
zh_Hans: 最大结果数量。在测试阶段,最大为20。
|
||||
form: form
|
||||
@ -9,17 +9,19 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
||||
class GaodeProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
if 'api_key' not in credentials or not credentials.get('api_key'):
|
||||
if "api_key" not in credentials or not credentials.get("api_key"):
|
||||
raise ToolProviderCredentialValidationError("Gaode API key is required.")
|
||||
|
||||
try:
|
||||
response = requests.get(url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}"
|
||||
"".format(address=urllib.parse.quote('广东省广州市天河区广州塔'),
|
||||
apikey=credentials.get('api_key')))
|
||||
if response.status_code == 200 and (response.json()).get('info') == 'OK':
|
||||
response = requests.get(
|
||||
url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}".format(
|
||||
address=urllib.parse.quote("广东省广州市天河区广州塔"), apikey=credentials.get("api_key")
|
||||
)
|
||||
)
|
||||
if response.status_code == 200 and (response.json()).get("info") == "OK":
|
||||
pass
|
||||
else:
|
||||
raise ToolProviderCredentialValidationError((response.json()).get('info'))
|
||||
raise ToolProviderCredentialValidationError((response.json()).get("info"))
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e))
|
||||
except Exception as e:
|
||||
|
||||
@ -8,50 +8,57 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class GaodeRepositoriesTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
city = tool_parameters.get('city', '')
|
||||
city = tool_parameters.get("city", "")
|
||||
if not city:
|
||||
return self.create_text_message('Please tell me your city')
|
||||
return self.create_text_message("Please tell me your city")
|
||||
|
||||
if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'):
|
||||
if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"):
|
||||
return self.create_text_message("Gaode API key is required.")
|
||||
|
||||
try:
|
||||
s = requests.session()
|
||||
api_domain = 'https://restapi.amap.com/v3'
|
||||
city_response = s.request(method='GET', headers={"Content-Type": "application/json; charset=utf-8"},
|
||||
url="{url}/config/district?keywords={keywords}"
|
||||
"&subdistrict=0&extensions=base&key={apikey}"
|
||||
"".format(url=api_domain, keywords=city,
|
||||
apikey=self.runtime.credentials.get('api_key')))
|
||||
api_domain = "https://restapi.amap.com/v3"
|
||||
city_response = s.request(
|
||||
method="GET",
|
||||
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||
url="{url}/config/district?keywords={keywords}&subdistrict=0&extensions=base&key={apikey}".format(
|
||||
url=api_domain, keywords=city, apikey=self.runtime.credentials.get("api_key")
|
||||
),
|
||||
)
|
||||
City_data = city_response.json()
|
||||
if city_response.status_code == 200 and City_data.get('info') == 'OK':
|
||||
if len(City_data.get('districts')) > 0:
|
||||
CityCode = City_data['districts'][0]['adcode']
|
||||
weatherInfo_response = s.request(method='GET',
|
||||
url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json"
|
||||
"".format(url=api_domain, citycode=CityCode,
|
||||
apikey=self.runtime.credentials.get('api_key')))
|
||||
if city_response.status_code == 200 and City_data.get("info") == "OK":
|
||||
if len(City_data.get("districts")) > 0:
|
||||
CityCode = City_data["districts"][0]["adcode"]
|
||||
weatherInfo_response = s.request(
|
||||
method="GET",
|
||||
url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json"
|
||||
"".format(url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")),
|
||||
)
|
||||
weatherInfo_data = weatherInfo_response.json()
|
||||
if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK':
|
||||
if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK":
|
||||
contents = []
|
||||
if len(weatherInfo_data.get('forecasts')) > 0:
|
||||
for item in weatherInfo_data['forecasts'][0]['casts']:
|
||||
if len(weatherInfo_data.get("forecasts")) > 0:
|
||||
for item in weatherInfo_data["forecasts"][0]["casts"]:
|
||||
content = {}
|
||||
content['date'] = item.get('date')
|
||||
content['week'] = item.get('week')
|
||||
content['dayweather'] = item.get('dayweather')
|
||||
content['daytemp_float'] = item.get('daytemp_float')
|
||||
content['daywind'] = item.get('daywind')
|
||||
content['nightweather'] = item.get('nightweather')
|
||||
content['nighttemp_float'] = item.get('nighttemp_float')
|
||||
content["date"] = item.get("date")
|
||||
content["week"] = item.get("week")
|
||||
content["dayweather"] = item.get("dayweather")
|
||||
content["daytemp_float"] = item.get("daytemp_float")
|
||||
content["daywind"] = item.get("daywind")
|
||||
content["nightweather"] = item.get("nightweather")
|
||||
content["nighttemp_float"] = item.get("nighttemp_float")
|
||||
contents.append(content)
|
||||
s.close()
|
||||
return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)))
|
||||
return self.create_text_message(
|
||||
self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))
|
||||
)
|
||||
s.close()
|
||||
return self.create_text_message(f'No weather information for {city} was found.')
|
||||
return self.create_text_message(f"No weather information for {city} was found.")
|
||||
except Exception as e:
|
||||
return self.create_text_message("Gaode API Key and Api Version is invalid. {}".format(e))
|
||||
|
||||
@ -7,16 +7,13 @@ class GetImgAIProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
# Example validation using the text2image tool
|
||||
Text2ImageTool().fork_tool_runtime(
|
||||
runtime={"credentials": credentials}
|
||||
).invoke(
|
||||
user_id='',
|
||||
Text2ImageTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke(
|
||||
user_id="",
|
||||
tool_parameters={
|
||||
"prompt": "A fire egg",
|
||||
"response_format": "url",
|
||||
"style": "photorealism",
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -8,18 +8,16 @@ from requests.exceptions import HTTPError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GetImgAIApp:
|
||||
def __init__(self, api_key: str | None = None, base_url: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url or 'https://api.getimg.ai/v1'
|
||||
self.base_url = base_url or "https://api.getimg.ai/v1"
|
||||
if not self.api_key:
|
||||
raise ValueError("API key is required")
|
||||
|
||||
def _prepare_headers(self):
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
return headers
|
||||
|
||||
def _request(
|
||||
@ -38,22 +36,20 @@ class GetImgAIApp:
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500:
|
||||
time.sleep(backoff_factor * (2 ** i))
|
||||
time.sleep(backoff_factor * (2**i))
|
||||
else:
|
||||
raise
|
||||
return None
|
||||
|
||||
def text2image(
|
||||
self, mode: str, **kwargs
|
||||
):
|
||||
data = kwargs['params']
|
||||
if not data.get('prompt'):
|
||||
def text2image(self, mode: str, **kwargs):
|
||||
data = kwargs["params"]
|
||||
if not data.get("prompt"):
|
||||
raise ValueError("Prompt is required")
|
||||
|
||||
endpoint = f'{self.base_url}/{mode}/text-to-image'
|
||||
endpoint = f"{self.base_url}/{mode}/text-to-image"
|
||||
headers = self._prepare_headers()
|
||||
logger.debug(f"Send request to {endpoint=} body={data}")
|
||||
response = self._request('POST', endpoint, data, headers)
|
||||
response = self._request("POST", endpoint, data, headers)
|
||||
if response is None:
|
||||
raise HTTPError("Failed to initiate getimg.ai after multiple retries")
|
||||
return response
|
||||
|
||||
@ -7,28 +7,28 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class Text2ImageTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
app = GetImgAIApp(api_key=self.runtime.credentials['getimg_api_key'], base_url=self.runtime.credentials['base_url'])
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
app = GetImgAIApp(
|
||||
api_key=self.runtime.credentials["getimg_api_key"], base_url=self.runtime.credentials["base_url"]
|
||||
)
|
||||
|
||||
options = {
|
||||
'style': tool_parameters.get('style'),
|
||||
'prompt': tool_parameters.get('prompt'),
|
||||
'aspect_ratio': tool_parameters.get('aspect_ratio'),
|
||||
'output_format': tool_parameters.get('output_format', 'jpeg'),
|
||||
'response_format': tool_parameters.get('response_format', 'url'),
|
||||
'width': tool_parameters.get('width'),
|
||||
'height': tool_parameters.get('height'),
|
||||
'steps': tool_parameters.get('steps'),
|
||||
'negative_prompt': tool_parameters.get('negative_prompt'),
|
||||
'prompt_2': tool_parameters.get('prompt_2'),
|
||||
"style": tool_parameters.get("style"),
|
||||
"prompt": tool_parameters.get("prompt"),
|
||||
"aspect_ratio": tool_parameters.get("aspect_ratio"),
|
||||
"output_format": tool_parameters.get("output_format", "jpeg"),
|
||||
"response_format": tool_parameters.get("response_format", "url"),
|
||||
"width": tool_parameters.get("width"),
|
||||
"height": tool_parameters.get("height"),
|
||||
"steps": tool_parameters.get("steps"),
|
||||
"negative_prompt": tool_parameters.get("negative_prompt"),
|
||||
"prompt_2": tool_parameters.get("prompt_2"),
|
||||
}
|
||||
options = {k: v for k, v in options.items() if v}
|
||||
|
||||
text2image_result = app.text2image(
|
||||
mode=tool_parameters.get('mode', 'essential-v2'),
|
||||
params=options,
|
||||
wait=True
|
||||
)
|
||||
text2image_result = app.text2image(mode=tool_parameters.get("mode", "essential-v2"), params=options, wait=True)
|
||||
|
||||
if not isinstance(text2image_result, str):
|
||||
text2image_result = json.dumps(text2image_result, ensure_ascii=False, indent=4)
|
||||
|
||||
@ -7,25 +7,25 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
||||
class GithubProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
if 'access_tokens' not in credentials or not credentials.get('access_tokens'):
|
||||
if "access_tokens" not in credentials or not credentials.get("access_tokens"):
|
||||
raise ToolProviderCredentialValidationError("Github API Access Tokens is required.")
|
||||
if 'api_version' not in credentials or not credentials.get('api_version'):
|
||||
api_version = '2022-11-28'
|
||||
if "api_version" not in credentials or not credentials.get("api_version"):
|
||||
api_version = "2022-11-28"
|
||||
else:
|
||||
api_version = credentials.get('api_version')
|
||||
api_version = credentials.get("api_version")
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/vnd.github+json",
|
||||
"Authorization": f"Bearer {credentials.get('access_tokens')}",
|
||||
"X-GitHub-Api-Version": api_version
|
||||
"X-GitHub-Api-Version": api_version,
|
||||
}
|
||||
|
||||
response = requests.get(
|
||||
url="https://api.github.com/search/users?q={account}".format(account='charli117'),
|
||||
headers=headers)
|
||||
url="https://api.github.com/search/users?q={account}".format(account="charli117"), headers=headers
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderCredentialValidationError((response.json()).get('message'))
|
||||
raise ToolProviderCredentialValidationError((response.json()).get("message"))
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. {}".format(e))
|
||||
except Exception as e:
|
||||
|
||||
@ -10,53 +10,61 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class GithubRepositoriesTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
top_n = tool_parameters.get('top_n', 5)
|
||||
query = tool_parameters.get('query', '')
|
||||
top_n = tool_parameters.get("top_n", 5)
|
||||
query = tool_parameters.get("query", "")
|
||||
if not query:
|
||||
return self.create_text_message('Please input symbol')
|
||||
return self.create_text_message("Please input symbol")
|
||||
|
||||
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
|
||||
if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"):
|
||||
return self.create_text_message("Github API Access Tokens is required.")
|
||||
if 'api_version' not in self.runtime.credentials or not self.runtime.credentials.get('api_version'):
|
||||
api_version = '2022-11-28'
|
||||
if "api_version" not in self.runtime.credentials or not self.runtime.credentials.get("api_version"):
|
||||
api_version = "2022-11-28"
|
||||
else:
|
||||
api_version = self.runtime.credentials.get('api_version')
|
||||
api_version = self.runtime.credentials.get("api_version")
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/vnd.github+json",
|
||||
"Authorization": f"Bearer {self.runtime.credentials.get('access_tokens')}",
|
||||
"X-GitHub-Api-Version": api_version
|
||||
"X-GitHub-Api-Version": api_version,
|
||||
}
|
||||
s = requests.session()
|
||||
api_domain = 'https://api.github.com'
|
||||
response = s.request(method='GET', headers=headers,
|
||||
url=f"{api_domain}/search/repositories?"
|
||||
f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc")
|
||||
api_domain = "https://api.github.com"
|
||||
response = s.request(
|
||||
method="GET",
|
||||
headers=headers,
|
||||
url=f"{api_domain}/search/repositories?q={quote(query)}&sort=stars&per_page={top_n}&order=desc",
|
||||
)
|
||||
response_data = response.json()
|
||||
if response.status_code == 200 and isinstance(response_data.get('items'), list):
|
||||
if response.status_code == 200 and isinstance(response_data.get("items"), list):
|
||||
contents = []
|
||||
if len(response_data.get('items')) > 0:
|
||||
for item in response_data.get('items'):
|
||||
if len(response_data.get("items")) > 0:
|
||||
for item in response_data.get("items"):
|
||||
content = {}
|
||||
updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ")
|
||||
content['owner'] = item['owner']['login']
|
||||
content['name'] = item['name']
|
||||
content['description'] = item['description'][:100] + '...' if len(item['description']) > 100 else item['description']
|
||||
content['url'] = item['html_url']
|
||||
content['star'] = item['watchers']
|
||||
content['forks'] = item['forks']
|
||||
content['updated'] = updated_at_object.strftime("%Y-%m-%d")
|
||||
updated_at_object = datetime.strptime(item["updated_at"], "%Y-%m-%dT%H:%M:%SZ")
|
||||
content["owner"] = item["owner"]["login"]
|
||||
content["name"] = item["name"]
|
||||
content["description"] = (
|
||||
item["description"][:100] + "..." if len(item["description"]) > 100 else item["description"]
|
||||
)
|
||||
content["url"] = item["html_url"]
|
||||
content["star"] = item["watchers"]
|
||||
content["forks"] = item["forks"]
|
||||
content["updated"] = updated_at_object.strftime("%Y-%m-%d")
|
||||
contents.append(content)
|
||||
s.close()
|
||||
return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)))
|
||||
return self.create_text_message(
|
||||
self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))
|
||||
)
|
||||
else:
|
||||
return self.create_text_message(f'No items related to {query} were found.')
|
||||
return self.create_text_message(f"No items related to {query} were found.")
|
||||
else:
|
||||
return self.create_text_message((response.json()).get('message'))
|
||||
return self.create_text_message((response.json()).get("message"))
|
||||
except Exception as e:
|
||||
return self.create_text_message("Github API Key and Api Version is invalid. {}".format(e))
|
||||
|
||||
@ -9,13 +9,13 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
||||
class GitlabProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
if 'access_tokens' not in credentials or not credentials.get('access_tokens'):
|
||||
if "access_tokens" not in credentials or not credentials.get("access_tokens"):
|
||||
raise ToolProviderCredentialValidationError("Gitlab Access Tokens is required.")
|
||||
|
||||
if 'site_url' not in credentials or not credentials.get('site_url'):
|
||||
site_url = 'https://gitlab.com'
|
||||
|
||||
if "site_url" not in credentials or not credentials.get("site_url"):
|
||||
site_url = "https://gitlab.com"
|
||||
else:
|
||||
site_url = credentials.get('site_url')
|
||||
site_url = credentials.get("site_url")
|
||||
|
||||
try:
|
||||
headers = {
|
||||
@ -23,12 +23,10 @@ class GitlabProvider(BuiltinToolProviderController):
|
||||
"Authorization": f"Bearer {credentials.get('access_tokens')}",
|
||||
}
|
||||
|
||||
response = requests.get(
|
||||
url= f"{site_url}/api/v4/user",
|
||||
headers=headers)
|
||||
response = requests.get(url=f"{site_url}/api/v4/user", headers=headers)
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderCredentialValidationError((response.json()).get('message'))
|
||||
raise ToolProviderCredentialValidationError((response.json()).get("message"))
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. {}".format(e))
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import urllib.parse
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union
|
||||
|
||||
@ -9,103 +10,133 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class GitlabCommitsTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
project = tool_parameters.get("project", "")
|
||||
repository = tool_parameters.get("repository", "")
|
||||
employee = tool_parameters.get("employee", "")
|
||||
start_time = tool_parameters.get("start_time", "")
|
||||
end_time = tool_parameters.get("end_time", "")
|
||||
change_type = tool_parameters.get("change_type", "all")
|
||||
|
||||
project = tool_parameters.get('project', '')
|
||||
employee = tool_parameters.get('employee', '')
|
||||
start_time = tool_parameters.get('start_time', '')
|
||||
end_time = tool_parameters.get('end_time', '')
|
||||
change_type = tool_parameters.get('change_type', 'all')
|
||||
|
||||
if not project:
|
||||
return self.create_text_message('Project is required')
|
||||
if not project and not repository:
|
||||
return self.create_text_message("Either project or repository is required")
|
||||
|
||||
if not start_time:
|
||||
start_time = (datetime.utcnow() - timedelta(days=1)).isoformat()
|
||||
if not end_time:
|
||||
end_time = datetime.utcnow().isoformat()
|
||||
|
||||
access_token = self.runtime.credentials.get('access_tokens')
|
||||
site_url = self.runtime.credentials.get('site_url')
|
||||
access_token = self.runtime.credentials.get("access_tokens")
|
||||
site_url = self.runtime.credentials.get("site_url")
|
||||
|
||||
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
|
||||
if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"):
|
||||
return self.create_text_message("Gitlab API Access Tokens is required.")
|
||||
if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
|
||||
site_url = 'https://gitlab.com'
|
||||
|
||||
if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
|
||||
site_url = "https://gitlab.com"
|
||||
|
||||
# Get commit content
|
||||
result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type)
|
||||
if repository:
|
||||
result = self.fetch_commits(
|
||||
site_url, access_token, repository, employee, start_time, end_time, change_type, is_repository=True
|
||||
)
|
||||
else:
|
||||
result = self.fetch_commits(
|
||||
site_url, access_token, project, employee, start_time, end_time, change_type, is_repository=False
|
||||
)
|
||||
|
||||
return [self.create_json_message(item) for item in result]
|
||||
|
||||
def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]:
|
||||
|
||||
def fetch_commits(
|
||||
self,
|
||||
site_url: str,
|
||||
access_token: str,
|
||||
identifier: str,
|
||||
employee: str,
|
||||
start_time: str,
|
||||
end_time: str,
|
||||
change_type: str,
|
||||
is_repository: bool,
|
||||
) -> list[dict[str, Any]]:
|
||||
domain = site_url
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
results = []
|
||||
|
||||
try:
|
||||
# Get all of projects
|
||||
url = f"{domain}/api/v4/projects"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
projects = response.json()
|
||||
if is_repository:
|
||||
# URL encode the repository path
|
||||
encoded_identifier = urllib.parse.quote(identifier, safe="")
|
||||
commits_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits"
|
||||
else:
|
||||
# Get all projects
|
||||
url = f"{domain}/api/v4/projects"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
projects = response.json()
|
||||
|
||||
filtered_projects = [p for p in projects if project == "*" or p['name'] == project]
|
||||
filtered_projects = [p for p in projects if identifier == "*" or p["name"] == identifier]
|
||||
|
||||
for project in filtered_projects:
|
||||
project_id = project['id']
|
||||
project_name = project['name']
|
||||
print(f"Project: {project_name}")
|
||||
for project in filtered_projects:
|
||||
project_id = project["id"]
|
||||
project_name = project["name"]
|
||||
print(f"Project: {project_name}")
|
||||
|
||||
# Get all of project commits
|
||||
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
|
||||
params = {
|
||||
'since': start_time,
|
||||
'until': end_time
|
||||
}
|
||||
if employee:
|
||||
params['author'] = employee
|
||||
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
|
||||
|
||||
commits_response = requests.get(commits_url, headers=headers, params=params)
|
||||
commits_response.raise_for_status()
|
||||
commits = commits_response.json()
|
||||
params = {"since": start_time, "until": end_time}
|
||||
if employee:
|
||||
params["author"] = employee
|
||||
|
||||
for commit in commits:
|
||||
commit_sha = commit['id']
|
||||
author_name = commit['author_name']
|
||||
commits_response = requests.get(commits_url, headers=headers, params=params)
|
||||
commits_response.raise_for_status()
|
||||
commits = commits_response.json()
|
||||
|
||||
for commit in commits:
|
||||
commit_sha = commit["id"]
|
||||
author_name = commit["author_name"]
|
||||
|
||||
if is_repository:
|
||||
diff_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/commits/{commit_sha}/diff"
|
||||
else:
|
||||
diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff"
|
||||
diff_response = requests.get(diff_url, headers=headers)
|
||||
diff_response.raise_for_status()
|
||||
diffs = diff_response.json()
|
||||
|
||||
for diff in diffs:
|
||||
# Calculate code lines of changed
|
||||
added_lines = diff['diff'].count('\n+')
|
||||
removed_lines = diff['diff'].count('\n-')
|
||||
total_changes = added_lines + removed_lines
|
||||
|
||||
if change_type == "new":
|
||||
if added_lines > 1:
|
||||
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')])
|
||||
results.append({
|
||||
"commit_sha": commit_sha,
|
||||
"author_name": author_name,
|
||||
"diff": final_code
|
||||
})
|
||||
else:
|
||||
if total_changes > 1:
|
||||
final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')])
|
||||
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
|
||||
results.append({
|
||||
"commit_sha": commit_sha,
|
||||
"author_name": author_name,
|
||||
"diff": final_code_escaped
|
||||
})
|
||||
diff_response = requests.get(diff_url, headers=headers)
|
||||
diff_response.raise_for_status()
|
||||
diffs = diff_response.json()
|
||||
|
||||
for diff in diffs:
|
||||
# Calculate code lines of changes
|
||||
added_lines = diff["diff"].count("\n+")
|
||||
removed_lines = diff["diff"].count("\n-")
|
||||
total_changes = added_lines + removed_lines
|
||||
|
||||
if change_type == "new":
|
||||
if added_lines > 1:
|
||||
final_code = "".join(
|
||||
[
|
||||
line[1:]
|
||||
for line in diff["diff"].split("\n")
|
||||
if line.startswith("+") and not line.startswith("+++")
|
||||
]
|
||||
)
|
||||
results.append({"commit_sha": commit_sha, "author_name": author_name, "diff": final_code})
|
||||
else:
|
||||
if total_changes > 1:
|
||||
final_code = "".join(
|
||||
[
|
||||
line[1:]
|
||||
for line in diff["diff"].split("\n")
|
||||
if (line.startswith("+") or line.startswith("-"))
|
||||
and not line.startswith("+++")
|
||||
and not line.startswith("---")
|
||||
]
|
||||
)
|
||||
final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code
|
||||
results.append(
|
||||
{"commit_sha": commit_sha, "author_name": author_name, "diff": final_code_escaped}
|
||||
)
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching data from GitLab: {e}")
|
||||
|
||||
return results
|
||||
|
||||
return results
|
||||
|
||||
@ -21,9 +21,20 @@ parameters:
|
||||
zh_Hans: 员工用户名
|
||||
llm_description: User name for GitLab
|
||||
form: llm
|
||||
- name: repository
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: repository
|
||||
zh_Hans: 仓库路径
|
||||
human_description:
|
||||
en_US: repository
|
||||
zh_Hans: 仓库路径,以namespace/project_name的形式。
|
||||
llm_description: Repository path for GitLab, like namespace/project_name.
|
||||
form: llm
|
||||
- name: project
|
||||
type: string
|
||||
required: true
|
||||
required: false
|
||||
label:
|
||||
en_US: project
|
||||
zh_Hans: 项目名
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import urllib.parse
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
@ -7,47 +8,85 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class GitlabFilesTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
project = tool_parameters.get('project', '')
|
||||
branch = tool_parameters.get('branch', '')
|
||||
path = tool_parameters.get('path', '')
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
project = tool_parameters.get("project", "")
|
||||
repository = tool_parameters.get("repository", "")
|
||||
branch = tool_parameters.get("branch", "")
|
||||
path = tool_parameters.get("path", "")
|
||||
|
||||
|
||||
if not project:
|
||||
return self.create_text_message('Project is required')
|
||||
if not project and not repository:
|
||||
return self.create_text_message("Either project or repository is required")
|
||||
if not branch:
|
||||
return self.create_text_message('Branch is required')
|
||||
|
||||
return self.create_text_message("Branch is required")
|
||||
if not path:
|
||||
return self.create_text_message('Path is required')
|
||||
return self.create_text_message("Path is required")
|
||||
|
||||
access_token = self.runtime.credentials.get('access_tokens')
|
||||
site_url = self.runtime.credentials.get('site_url')
|
||||
access_token = self.runtime.credentials.get("access_tokens")
|
||||
site_url = self.runtime.credentials.get("site_url")
|
||||
|
||||
if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'):
|
||||
if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"):
|
||||
return self.create_text_message("Gitlab API Access Tokens is required.")
|
||||
if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'):
|
||||
site_url = 'https://gitlab.com'
|
||||
|
||||
# Get project ID from project name
|
||||
project_id = self.get_project_id(site_url, access_token, project)
|
||||
if not project_id:
|
||||
return self.create_text_message(f"Project '{project}' not found.")
|
||||
if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
|
||||
site_url = "https://gitlab.com"
|
||||
|
||||
# Get commit content
|
||||
result = self.fetch(user_id, project_id, site_url, access_token, branch, path)
|
||||
# Get file content
|
||||
if repository:
|
||||
result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True)
|
||||
else:
|
||||
result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False)
|
||||
|
||||
return [self.create_json_message(item) for item in result]
|
||||
|
||||
def extract_project_name_and_path(self, path: str) -> tuple[str, str]:
|
||||
parts = path.split('/', 1)
|
||||
if len(parts) < 2:
|
||||
return None, None
|
||||
return parts[0], parts[1]
|
||||
|
||||
def fetch_files(
|
||||
self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool
|
||||
) -> list[dict[str, Any]]:
|
||||
domain = site_url
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
results = []
|
||||
|
||||
try:
|
||||
if is_repository:
|
||||
# URL encode the repository path
|
||||
encoded_identifier = urllib.parse.quote(identifier, safe="")
|
||||
tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}"
|
||||
else:
|
||||
# Get project ID from project name
|
||||
project_id = self.get_project_id(site_url, access_token, identifier)
|
||||
if not project_id:
|
||||
return self.create_text_message(f"Project '{identifier}' not found.")
|
||||
tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
|
||||
|
||||
response = requests.get(tree_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
items = response.json()
|
||||
|
||||
for item in items:
|
||||
item_path = item["path"]
|
||||
if item["type"] == "tree": # It's a directory
|
||||
results.extend(
|
||||
self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
|
||||
)
|
||||
else: # It's a file
|
||||
if is_repository:
|
||||
file_url = (
|
||||
f"{domain}/api/v4/projects/{encoded_identifier}/repository/files"
|
||||
f"/{item_path}/raw?ref={branch}"
|
||||
)
|
||||
else:
|
||||
file_url = (
|
||||
f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
|
||||
)
|
||||
|
||||
file_response = requests.get(file_url, headers=headers)
|
||||
file_response.raise_for_status()
|
||||
file_content = file_response.text
|
||||
results.append({"path": item_path, "branch": branch, "content": file_content})
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching data from GitLab: {e}")
|
||||
|
||||
return results
|
||||
|
||||
def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]:
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
@ -57,39 +96,8 @@ class GitlabFilesTool(BuiltinTool):
|
||||
response.raise_for_status()
|
||||
projects = response.json()
|
||||
for project in projects:
|
||||
if project['name'] == project_name:
|
||||
return project['id']
|
||||
if project["name"] == project_name:
|
||||
return project["id"]
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching project ID from GitLab: {e}")
|
||||
return None
|
||||
|
||||
def fetch(self,user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]:
|
||||
domain = site_url
|
||||
headers = {"PRIVATE-TOKEN": access_token}
|
||||
results = []
|
||||
|
||||
try:
|
||||
# List files and directories in the given path
|
||||
url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
items = response.json()
|
||||
|
||||
for item in items:
|
||||
item_path = item['path']
|
||||
if item['type'] == 'tree': # It's a directory
|
||||
results.extend(self.fetch(project_id, site_url, access_token, branch, item_path))
|
||||
else: # It's a file
|
||||
file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}"
|
||||
file_response = requests.get(file_url, headers=headers)
|
||||
file_response.raise_for_status()
|
||||
file_content = file_response.text
|
||||
results.append({
|
||||
"path": item_path,
|
||||
"branch": branch,
|
||||
"content": file_content
|
||||
})
|
||||
except requests.RequestException as e:
|
||||
print(f"Error fetching data from GitLab: {e}")
|
||||
|
||||
return results
|
||||
@ -10,9 +10,20 @@ description:
|
||||
zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。
|
||||
llm: A tool for query GitLab files, Input should be a exists file or directory path.
|
||||
parameters:
|
||||
- name: repository
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: repository
|
||||
zh_Hans: 仓库路径
|
||||
human_description:
|
||||
en_US: repository
|
||||
zh_Hans: 仓库路径,以namespace/project_name的形式。
|
||||
llm_description: Repository path for GitLab, like namespace/project_name.
|
||||
form: llm
|
||||
- name: project
|
||||
type: string
|
||||
required: true
|
||||
required: false
|
||||
label:
|
||||
en_US: project
|
||||
zh_Hans: 项目
|
||||
|
||||
@ -13,12 +13,8 @@ class GoogleProvider(BuiltinToolProviderController):
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"query": "test",
|
||||
"result_type": "link"
|
||||
},
|
||||
user_id="",
|
||||
tool_parameters={"query": "test", "result_type": "link"},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -9,7 +9,6 @@ SERP_API_URL = "https://serpapi.com/search"
|
||||
|
||||
|
||||
class GoogleSearchTool(BuiltinTool):
|
||||
|
||||
def _parse_response(self, response: dict) -> dict:
|
||||
result = {}
|
||||
if "knowledge_graph" in response:
|
||||
@ -17,25 +16,23 @@ class GoogleSearchTool(BuiltinTool):
|
||||
result["description"] = response["knowledge_graph"].get("description", "")
|
||||
if "organic_results" in response:
|
||||
result["organic_results"] = [
|
||||
{
|
||||
"title": item.get("title", ""),
|
||||
"link": item.get("link", ""),
|
||||
"snippet": item.get("snippet", "")
|
||||
}
|
||||
{"title": item.get("title", ""), "link": item.get("link", ""), "snippet": item.get("snippet", "")}
|
||||
for item in response["organic_results"]
|
||||
]
|
||||
return result
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
params = {
|
||||
"api_key": self.runtime.credentials['serpapi_api_key'],
|
||||
"q": tool_parameters['query'],
|
||||
"api_key": self.runtime.credentials["serpapi_api_key"],
|
||||
"q": tool_parameters["query"],
|
||||
"engine": "google",
|
||||
"google_domain": "google.com",
|
||||
"gl": "us",
|
||||
"hl": "en"
|
||||
"hl": "en",
|
||||
}
|
||||
response = requests.get(url=SERP_API_URL, params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
@ -8,10 +8,6 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
|
||||
class JsonExtractProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
GoogleTranslate().invoke(user_id='',
|
||||
tool_parameters={
|
||||
"content": "这是一段测试文本",
|
||||
"dest": "en"
|
||||
})
|
||||
GoogleTranslate().invoke(user_id="", tool_parameters={"content": "这是一段测试文本", "dest": "en"})
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
@ -7,46 +7,41 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class GoogleTranslate(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
invoke tools
|
||||
"""
|
||||
content = tool_parameters.get('content', '')
|
||||
content = tool_parameters.get("content", "")
|
||||
if not content:
|
||||
return self.create_text_message('Invalid parameter content')
|
||||
return self.create_text_message("Invalid parameter content")
|
||||
|
||||
dest = tool_parameters.get('dest', '')
|
||||
dest = tool_parameters.get("dest", "")
|
||||
if not dest:
|
||||
return self.create_text_message('Invalid parameter destination language')
|
||||
return self.create_text_message("Invalid parameter destination language")
|
||||
|
||||
try:
|
||||
result = self._translate(content, dest)
|
||||
return self.create_text_message(str(result))
|
||||
except Exception:
|
||||
return self.create_text_message('Translation service error, please check the network')
|
||||
return self.create_text_message("Translation service error, please check the network")
|
||||
|
||||
def _translate(self, content: str, dest: str) -> str:
|
||||
try:
|
||||
url = "https://translate.googleapis.com/translate_a/single"
|
||||
params = {
|
||||
"client": "gtx",
|
||||
"sl": "auto",
|
||||
"tl": dest,
|
||||
"dt": "t",
|
||||
"q": content
|
||||
}
|
||||
params = {"client": "gtx", "sl": "auto", "tl": dest, "dt": "t", "q": content}
|
||||
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
||||
" Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
|
||||
response_json = requests.get(
|
||||
url, params=params, headers=headers).json()
|
||||
response_json = requests.get(url, params=params, headers=headers).json()
|
||||
result = response_json[0]
|
||||
translated_text = ''.join([item[0] for item in result if item[0]])
|
||||
translated_text = "".join([item[0] for item in result if item[0]])
|
||||
return str(translated_text)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user