This commit is contained in:
takatost
2024-07-22 19:57:32 +08:00
372 changed files with 9779 additions and 1678 deletions

View File

@ -30,3 +30,4 @@
- feishu
- feishu_base
- slack
- tianditu

View File

@ -1,5 +1,5 @@
import base64
import random
from base64 import b64decode
from typing import Any, Union
from openai import OpenAI
@ -69,11 +69,50 @@ class DallE3Tool(BuiltinTool):
result = []
for image in response.data:
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={'mime_type': 'image/png'},
save_as=self.VARIABLE_KEY.IMAGE.value))
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
blob_message = self.create_blob_message(blob=blob_image,
meta={'mime_type': mime_type},
save_as=self.VARIABLE_KEY.IMAGE.value)
result.append(blob_message)
return result
@staticmethod
def _decode_image(base64_image: str) -> tuple[str, bytes]:
    """
    Decode a base64 encoded image string.

    Strings carrying a ``data:image...`` MIME prefix have the type extracted
    from the prefix; plain base64 strings are assumed to be PNG data.

    :param base64_image: Base64 encoded image string
    :return: A tuple of (MIME type, decoded image bytes)
    """
    if not DallE3Tool._is_plain_base64(base64_image):
        return DallE3Tool._extract_mime_and_data(base64_image)
    return 'image/png', base64.b64decode(base64_image)
@staticmethod
def _is_plain_base64(encoded_str: str) -> bool:
    """
    Determine whether the encoded string lacks a ``data:image`` MIME prefix.

    :param encoded_str: Base64 encoded image string
    :return: True when the string is plain base64, False when it is prefixed
    """
    has_mime_prefix = encoded_str.startswith('data:image')
    return not has_mime_prefix
@staticmethod
def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]:
    """
    Split a ``data:<mime>;base64,<payload>`` string into MIME type and bytes.

    :param encoded_str: Base64 encoded image string with MIME type prefix
    :return: A tuple of (MIME type, decoded image bytes)
    """
    # The prefix looks like "data:image/png;base64" — take the part before
    # the first ';', then the part after the ':' to get the MIME type.
    mime_type = encoded_str.split(';')[0].split(':')[1]
    payload = encoded_str.split(',')[1]
    return mime_type, base64.b64decode(payload)
@staticmethod
def _generate_random_id(length=8):
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 256 256"><rect width="256" height="256" fill="none"/><rect x="32" y="48" width="192" height="160" rx="8" fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="16"/><circle cx="156" cy="100" r="12"/><path d="M147.31,164,173,138.34a8,8,0,0,1,11.31,0L224,178.06" fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="16"/><path d="M32,168.69l54.34-54.35a8,8,0,0,1,11.32,0L191.31,208" fill="none" stroke="#1553ed" stroke-linecap="round" stroke-linejoin="round" stroke-width="16"/></svg>

After

Width:  |  Height:  |  Size: 617 B

View File

@ -0,0 +1,22 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.getimgai.tools.text2image import Text2ImageTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class GetImgAIProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        """
        Validate getimg.ai credentials by issuing a minimal generation request.

        :param credentials: provider credentials (API key, optional base URL)
        :raises ToolProviderCredentialValidationError: if the test call fails
        """
        # A tiny fixed prompt keeps the validation request cheap.
        params = {
            "prompt": "A fire egg",
            "response_format": "url",
            "style": "photorealism",
        }
        try:
            tool = Text2ImageTool().fork_tool_runtime(
                runtime={"credentials": credentials}
            )
            tool.invoke(user_id='', tool_parameters=params)
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,29 @@
identity:
author: Matri Qi
name: getimgai
label:
en_US: getimg.ai
zh_CN: getimg.ai
description:
en_US: GetImg API integration for image generation and scraping.
icon: icon.svg
tags:
- image
credentials_for_provider:
getimg_api_key:
type: secret-input
required: true
label:
en_US: getimg.ai API Key
placeholder:
en_US: Please input your getimg.ai API key
help:
en_US: Get your getimg.ai API key from your getimg.ai account settings. If you are using a self-hosted version, you may enter any key at your convenience.
url: https://dashboard.getimg.ai/api-keys
base_url:
type: text-input
required: false
label:
en_US: getimg.ai server's Base URL
placeholder:
en_US: https://api.getimg.ai/v1

View File

@ -0,0 +1,59 @@
import logging
import time
from collections.abc import Mapping
from typing import Any
import requests
from requests.exceptions import HTTPError
logger = logging.getLogger(__name__)
class GetImgAIApp:
def __init__(self, api_key: str | None = None, base_url: str | None = None):
self.api_key = api_key
self.base_url = base_url or 'https://api.getimg.ai/v1'
if not self.api_key:
raise ValueError("API key is required")
def _prepare_headers(self):
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
return headers
def _request(
self,
method: str,
url: str,
data: Mapping[str, Any] | None = None,
headers: Mapping[str, str] | None = None,
retries: int = 3,
backoff_factor: float = 0.3,
) -> Mapping[str, Any] | None:
for i in range(retries):
try:
response = requests.request(method, url, json=data, headers=headers)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500:
time.sleep(backoff_factor * (2 ** i))
else:
raise
return None
def text2image(
self, mode: str, **kwargs
):
data = kwargs['params']
if not data.get('prompt'):
raise ValueError("Prompt is required")
endpoint = f'{self.base_url}/{mode}/text-to-image'
headers = self._prepare_headers()
logger.debug(f"Send request to {endpoint=} body={data}")
response = self._request('POST', endpoint, data, headers)
if response is None:
raise HTTPError("Failed to initiate getimg.ai after multiple retries")
return response

View File

@ -0,0 +1,39 @@
import json
from typing import Any, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.getimgai.getimgai_appx import GetImgAIApp
from core.tools.tool.builtin_tool import BuiltinTool
class Text2ImageTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Generate an image via getimg.ai and return the raw API response as text.

        :param user_id: id of the invoking user (not sent to the API)
        :param tool_parameters: tool inputs; ``prompt`` is required, all other
            generation options are optional.
        :return: a text message carrying the JSON response, or an error message
        """
        # ``base_url`` is an optional credential (required: false in the
        # provider yaml), so use .get() — direct indexing raised KeyError for
        # users who left it unset. GetImgAIApp falls back to its default URL.
        app = GetImgAIApp(
            api_key=self.runtime.credentials['getimg_api_key'],
            base_url=self.runtime.credentials.get('base_url'),
        )

        options = {
            'style': tool_parameters.get('style'),
            'prompt': tool_parameters.get('prompt'),
            'aspect_ratio': tool_parameters.get('aspect_ratio'),
            'output_format': tool_parameters.get('output_format', 'jpeg'),
            'response_format': tool_parameters.get('response_format', 'url'),
            'width': tool_parameters.get('width'),
            'height': tool_parameters.get('height'),
            'steps': tool_parameters.get('steps'),
            'negative_prompt': tool_parameters.get('negative_prompt'),
            'prompt_2': tool_parameters.get('prompt_2'),
        }
        # Drop unset/falsy options so the API only receives explicit values.
        options = {k: v for k, v in options.items() if v}

        text2image_result = app.text2image(
            mode=tool_parameters.get('mode', 'essential-v2'),
            params=options,
            wait=True
        )

        if not isinstance(text2image_result, str):
            text2image_result = json.dumps(text2image_result, ensure_ascii=False, indent=4)

        if not text2image_result:
            return self.create_text_message("getimg.ai request failed.")

        return self.create_text_message(text2image_result)

View File

@ -0,0 +1,167 @@
identity:
name: text2image
author: Matri Qi
label:
en_US: text2image
icon: icon.svg
description:
human:
en_US: Generate image via getimg.ai.
llm: This tool is used to generate image from prompt or image via https://getimg.ai.
parameters:
- name: prompt
type: string
required: true
label:
en_US: prompt
human_description:
en_US: The text prompt used to generate the image. getimg.ai will generate an image based on this prompt.
llm_description: this prompt text will be used to generate image.
form: llm
- name: mode
type: select
required: false
label:
en_US: mode
human_description:
en_US: The getimg.ai mode to use. The mode determines the endpoint used to generate the image.
form: form
options:
- value: "essential-v2"
label:
en_US: essential-v2
- value: stable-diffusion-xl
label:
en_US: stable-diffusion-xl
- value: stable-diffusion
label:
en_US: stable-diffusion
- value: latent-consistency
label:
en_US: latent-consistency
- name: style
type: select
required: false
label:
en_US: style
human_description:
en_US: The style preset to use. The style preset guides the generation towards a particular style. It's just efficient for `Essential V2` mode.
form: form
options:
- value: photorealism
label:
en_US: photorealism
- value: anime
label:
en_US: anime
- value: art
label:
en_US: art
- name: aspect_ratio
type: select
required: false
label:
en_US: "aspect ratio"
human_description:
en_US: The aspect ratio of the generated image. It's just efficient for `Essential V2` mode.
form: form
options:
- value: "1:1"
label:
en_US: "1:1"
- value: "4:5"
label:
en_US: "4:5"
- value: "5:4"
label:
en_US: "5:4"
- value: "2:3"
label:
en_US: "2:3"
- value: "3:2"
label:
en_US: "3:2"
- value: "4:7"
label:
en_US: "4:7"
- value: "7:4"
label:
en_US: "7:4"
- name: output_format
type: select
required: false
label:
en_US: "output format"
human_description:
en_US: The file format of the generated image.
form: form
options:
- value: jpeg
label:
en_US: jpeg
- value: png
label:
en_US: png
- name: response_format
type: select
required: false
label:
en_US: "response format"
human_description:
en_US: The format in which the generated images are returned. Must be one of url or b64. URLs are only valid for 1 hour after the image has been generated.
form: form
options:
- value: url
label:
en_US: url
- value: b64
label:
en_US: b64
- name: model
type: string
required: false
label:
en_US: model
human_description:
en_US: Model ID supported by this pipeline and family. It's just efficient for `Stable Diffusion XL`, `Stable Diffusion`, `Latent Consistency` mode.
form: form
- name: negative_prompt
type: string
required: false
label:
en_US: negative prompt
human_description:
en_US: Text input that will not guide the image generation. It's just efficient for `Stable Diffusion XL`, `Stable Diffusion`, `Latent Consistency` mode.
form: form
- name: prompt_2
type: string
required: false
label:
en_US: prompt2
human_description:
en_US: Prompt sent to second tokenizer and text encoder. If not defined, prompt is used in both text-encoders. It's just efficient for `Stable Diffusion XL` mode.
form: form
- name: width
type: number
required: false
label:
en_US: width
human_description:
en_US: The width of the generated image in pixels. Width needs to be a multiple of 64.
form: form
- name: height
type: number
required: false
label:
en_US: height
human_description:
en_US: The height of the generated image in pixels. Height needs to be a multiple of 64.
form: form
- name: steps
type: number
required: false
label:
en_US: steps
human_description:
en_US: The number of denoising steps. More steps usually can produce higher quality images, but take more time to generate. It's just efficient for `Stable Diffusion XL`, `Stable Diffusion`, `Latent Consistency` mode.
form: form

View File

@ -19,28 +19,29 @@ class JSONDeleteTool(BuiltinTool):
content = tool_parameters.get('content', '')
if not content:
return self.create_text_message('Invalid parameter content')
# Get query
query = tool_parameters.get('query', '')
if not query:
return self.create_text_message('Invalid parameter query')
ensure_ascii = tool_parameters.get('ensure_ascii', True)
try:
result = self._delete(content, query)
result = self._delete(content, query, ensure_ascii)
return self.create_text_message(str(result))
except Exception as e:
return self.create_text_message(f'Failed to delete JSON content: {str(e)}')
def _delete(self, origin_json: str, query: str) -> str:
def _delete(self, origin_json: str, query: str, ensure_ascii: bool) -> str:
try:
input_data = json.loads(origin_json)
expr = parse('$.' + query.lstrip('$.')) # Ensure query path starts with $
matches = expr.find(input_data)
if not matches:
return json.dumps(input_data, ensure_ascii=True) # No changes if no matches found
return json.dumps(input_data, ensure_ascii=ensure_ascii) # No changes if no matches found
for match in matches:
if isinstance(match.context.value, dict):
# Delete key from dictionary
@ -53,7 +54,7 @@ class JSONDeleteTool(BuiltinTool):
parent = match.context.parent
if parent:
del parent.value[match.path.fields[-1]]
return json.dumps(input_data, ensure_ascii=True)
return json.dumps(input_data, ensure_ascii=ensure_ascii)
except Exception as e:
raise Exception(f"Delete operation failed: {str(e)}")
raise Exception(f"Delete operation failed: {str(e)}")

View File

@ -38,3 +38,15 @@ parameters:
pt_BR: JSONPath query to locate the element to delete
llm_description: JSONPath query to locate the element to delete
form: llm
- name: ensure_ascii
type: boolean
default: true
label:
en_US: Ensure ASCII
zh_Hans: 确保 ASCII
pt_BR: Ensure ASCII
human_description:
en_US: Ensure the JSON output is ASCII encoded
zh_Hans: 确保输出的 JSON 是 ASCII 编码
pt_BR: Ensure the JSON output is ASCII encoded
form: form

View File

@ -19,31 +19,31 @@ class JSONParseTool(BuiltinTool):
content = tool_parameters.get('content', '')
if not content:
return self.create_text_message('Invalid parameter content')
# get query
query = tool_parameters.get('query', '')
if not query:
return self.create_text_message('Invalid parameter query')
# get new value
new_value = tool_parameters.get('new_value', '')
if not new_value:
return self.create_text_message('Invalid parameter new_value')
# get insert position
index = tool_parameters.get('index')
# get create path
create_path = tool_parameters.get('create_path', False)
ensure_ascii = tool_parameters.get('ensure_ascii', True)
try:
result = self._insert(content, query, new_value, index, create_path)
result = self._insert(content, query, new_value, ensure_ascii, index, create_path)
return self.create_text_message(str(result))
except Exception:
return self.create_text_message('Failed to insert JSON content')
def _insert(self, origin_json, query, new_value, index=None, create_path=False):
def _insert(self, origin_json, query, new_value, ensure_ascii: bool, index=None, create_path=False):
try:
input_data = json.loads(origin_json)
expr = parse(query)
@ -51,9 +51,9 @@ class JSONParseTool(BuiltinTool):
new_value = json.loads(new_value)
except json.JSONDecodeError:
new_value = new_value
matches = expr.find(input_data)
if not matches and create_path:
# create new path
path_parts = query.strip('$').strip('.').split('.')
@ -91,7 +91,7 @@ class JSONParseTool(BuiltinTool):
else:
# replace old value with new value
match.full_path.update(input_data, new_value)
return json.dumps(input_data, ensure_ascii=True)
return json.dumps(input_data, ensure_ascii=ensure_ascii)
except Exception as e:
return str(e)
return str(e)

View File

@ -75,3 +75,15 @@ parameters:
zh_Hans:
pt_BR: "No"
form: form
- name: ensure_ascii
type: boolean
default: true
label:
en_US: Ensure ASCII
zh_Hans: 确保 ASCII
pt_BR: Ensure ASCII
human_description:
en_US: Ensure the JSON output is ASCII encoded
zh_Hans: 确保输出的 JSON 是 ASCII 编码
pt_BR: Ensure the JSON output is ASCII encoded
form: form

View File

@ -19,33 +19,34 @@ class JSONParseTool(BuiltinTool):
content = tool_parameters.get('content', '')
if not content:
return self.create_text_message('Invalid parameter content')
# get json filter
json_filter = tool_parameters.get('json_filter', '')
if not json_filter:
return self.create_text_message('Invalid parameter json_filter')
ensure_ascii = tool_parameters.get('ensure_ascii', True)
try:
result = self._extract(content, json_filter)
result = self._extract(content, json_filter, ensure_ascii)
return self.create_text_message(str(result))
except Exception:
return self.create_text_message('Failed to extract JSON content')
# Extract data from JSON content
def _extract(self, content: str, json_filter: str) -> str:
def _extract(self, content: str, json_filter: str, ensure_ascii: bool) -> str:
try:
input_data = json.loads(content)
expr = parse(json_filter)
result = [match.value for match in expr.find(input_data)]
if len(result) == 1:
result = result[0]
if isinstance(result, dict | list):
return json.dumps(result, ensure_ascii=True)
return json.dumps(result, ensure_ascii=ensure_ascii)
elif isinstance(result, str | int | float | bool) or result is None:
return str(result)
else:
return repr(result)
except Exception as e:
return str(e)
return str(e)

View File

@ -38,3 +38,15 @@ parameters:
pt_BR: JSON fields to be parsed
llm_description: JSON fields to be parsed
form: llm
- name: ensure_ascii
type: boolean
default: true
label:
en_US: Ensure ASCII
zh_Hans: 确保 ASCII
pt_BR: Ensure ASCII
human_description:
en_US: Ensure the JSON output is ASCII encoded
zh_Hans: 确保输出的 JSON 是 ASCII 编码
pt_BR: Ensure the JSON output is ASCII encoded
form: form

View File

@ -19,61 +19,62 @@ class JSONReplaceTool(BuiltinTool):
content = tool_parameters.get('content', '')
if not content:
return self.create_text_message('Invalid parameter content')
# get query
query = tool_parameters.get('query', '')
if not query:
return self.create_text_message('Invalid parameter query')
# get replace value
replace_value = tool_parameters.get('replace_value', '')
if not replace_value:
return self.create_text_message('Invalid parameter replace_value')
# get replace model
replace_model = tool_parameters.get('replace_model', '')
if not replace_model:
return self.create_text_message('Invalid parameter replace_model')
ensure_ascii = tool_parameters.get('ensure_ascii', True)
try:
if replace_model == 'pattern':
# get replace pattern
replace_pattern = tool_parameters.get('replace_pattern', '')
if not replace_pattern:
return self.create_text_message('Invalid parameter replace_pattern')
result = self._replace_pattern(content, query, replace_pattern, replace_value)
result = self._replace_pattern(content, query, replace_pattern, replace_value, ensure_ascii)
elif replace_model == 'key':
result = self._replace_key(content, query, replace_value)
result = self._replace_key(content, query, replace_value, ensure_ascii)
elif replace_model == 'value':
result = self._replace_value(content, query, replace_value)
result = self._replace_value(content, query, replace_value, ensure_ascii)
return self.create_text_message(str(result))
except Exception:
return self.create_text_message('Failed to replace JSON content')
# Replace pattern
def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str) -> str:
def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool) -> str:
try:
input_data = json.loads(content)
expr = parse(query)
matches = expr.find(input_data)
for match in matches:
new_value = match.value.replace(replace_pattern, replace_value)
match.full_path.update(input_data, new_value)
return json.dumps(input_data, ensure_ascii=True)
return json.dumps(input_data, ensure_ascii=ensure_ascii)
except Exception as e:
return str(e)
# Replace key
def _replace_key(self, content: str, query: str, replace_value: str) -> str:
def _replace_key(self, content: str, query: str, replace_value: str, ensure_ascii: bool) -> str:
try:
input_data = json.loads(content)
expr = parse(query)
matches = expr.find(input_data)
for match in matches:
parent = match.context.value
if isinstance(parent, dict):
@ -86,21 +87,21 @@ class JSONReplaceTool(BuiltinTool):
if isinstance(item, dict) and old_key in item:
value = item.pop(old_key)
item[replace_value] = value
return json.dumps(input_data, ensure_ascii=True)
return json.dumps(input_data, ensure_ascii=ensure_ascii)
except Exception as e:
return str(e)
# Replace value
def _replace_value(self, content: str, query: str, replace_value: str) -> str:
def _replace_value(self, content: str, query: str, replace_value: str, ensure_ascii: bool) -> str:
try:
input_data = json.loads(content)
expr = parse(query)
matches = expr.find(input_data)
for match in matches:
match.full_path.update(input_data, replace_value)
return json.dumps(input_data, ensure_ascii=True)
return json.dumps(input_data, ensure_ascii=ensure_ascii)
except Exception as e:
return str(e)
return str(e)

View File

@ -93,3 +93,15 @@ parameters:
zh_Hans: 字符串替换
pt_BR: replace string
form: form
- name: ensure_ascii
type: boolean
default: true
label:
en_US: Ensure ASCII
zh_Hans: 确保 ASCII
pt_BR: Ensure ASCII
human_description:
en_US: Ensure the JSON output is ASCII encoded
zh_Hans: 确保输出的 JSON 是 ASCII 编码
pt_BR: Ensure the JSON output is ASCII encoded
form: form

View File

@ -0,0 +1 @@
<svg height="30" width="30" viewBox="0 0 36 34" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" class="fill-accent-foreground transition-all group-hover:scale-110"><title>Spider v1 Logo</title><path fill-rule="evenodd" clip-rule="evenodd" d="M9.13883 7.06589V0.164429L13.0938 0.164429V6.175L14.5178 7.4346C15.577 6.68656 16.7337 6.27495 17.945 6.27495C19.1731 6.27495 20.3451 6.69807 21.4163 7.46593L22.8757 6.175V0.164429L26.8307 0.164429V7.06589V7.95679L26.1634 8.54706L24.0775 10.3922C24.3436 10.8108 24.5958 11.2563 24.8327 11.7262L26.0467 11.4215L28.6971 8.08749L31.793 10.5487L28.7257 14.407L28.3089 14.9313L27.6592 15.0944L26.2418 15.4502C26.3124 15.7082 26.3793 15.9701 26.4422 16.2355L28.653 16.6566L29.092 16.7402L29.4524 17.0045L35.3849 21.355L33.0461 24.5444L27.474 20.4581L27.0719 20.3816C27.1214 21.0613 27.147 21.7543 27.147 22.4577C27.147 22.5398 27.1466 22.6214 27.1459 22.7024L29.5889 23.7911L30.3219 24.1177L30.62 24.8629L33.6873 32.5312L30.0152 34L27.246 27.0769L26.7298 26.8469C25.5612 32.2432 22.0701 33.8808 17.945 33.8808C13.8382 33.8808 10.3598 32.2577 9.17593 26.9185L8.82034 27.0769L6.05109 34L2.37897 32.5312L5.44629 24.8629L5.74435 24.1177L6.47743 23.7911L8.74487 22.7806C8.74366 22.6739 8.74305 22.5663 8.74305 22.4577C8.74305 21.7616 8.76804 21.0758 8.81654 20.4028L8.52606 20.4581L2.95395 24.5444L0.615112 21.355L6.54761 17.0045L6.908 16.7402L7.34701 16.6566L9.44264 16.2575C9.50917 15.9756 9.5801 15.6978 9.65528 15.4242L8.34123 15.0944L7.69155 14.9313L7.27471 14.407L4.20739 10.5487L7.30328 8.08749L9.95376 11.4215L11.0697 11.7016C11.3115 11.2239 11.5692 10.7716 11.8412 10.3473L9.80612 8.54706L9.13883 7.95679V7.06589Z"></path></svg>

After

Width:  |  Height:  |  Size: 1.6 KiB

View File

@ -0,0 +1,14 @@
from typing import Any
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.spider.spiderApp import Spider
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class SpiderProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """
        Validate the Spider API key by performing a trivial scrape request.

        :param credentials: provider credentials containing ``spider_api_key``
        :raises ToolProviderCredentialValidationError: if the test call fails
        """
        try:
            client = Spider(api_key=credentials["spider_api_key"])
            client.scrape_url(url="https://spider.cloud")
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,27 @@
identity:
author: William Espegren
name: spider
label:
en_US: Spider
zh_CN: Spider
description:
en_US: Spider API integration, returning LLM-ready data by scraping & crawling websites.
zh_CN: Spider API 集成,通过爬取和抓取网站返回 LLM-ready 数据。
icon: icon.svg
tags:
- search
- utilities
credentials_for_provider:
spider_api_key:
type: secret-input
required: true
label:
en_US: Spider API Key
zh_CN: Spider API 密钥
placeholder:
en_US: Please input your Spider API key
zh_CN: 请输入您的 Spider API 密钥
help:
en_US: Get your Spider API key from your Spider dashboard
zh_CN: 从您的 Spider 仪表板中获取 Spider API 密钥。
url: https://spider.cloud/

View File

@ -0,0 +1,237 @@
import os
from typing import Literal, Optional, TypedDict
import requests
class RequestParamsDict(TypedDict, total=False):
    """Optional request parameters accepted by the Spider API endpoints."""
    url: Optional[str]
    request: Optional[Literal["http", "chrome", "smart"]]
    limit: Optional[int]
    return_format: Optional[Literal["raw", "markdown", "html2text", "text", "bytes"]]
    tld: Optional[bool]
    depth: Optional[int]
    cache: Optional[bool]
    budget: Optional[dict[str, int]]
    locale: Optional[str]
    cookies: Optional[str]
    stealth: Optional[bool]
    headers: Optional[dict[str, str]]
    anti_bot: Optional[bool]
    metadata: Optional[bool]
    viewport: Optional[dict[str, int]]
    encoding: Optional[str]
    subdomains: Optional[bool]
    user_agent: Optional[str]
    store_data: Optional[bool]
    gpt_config: Optional[list[str]]
    fingerprint: Optional[bool]
    storageless: Optional[bool]
    readability: Optional[bool]
    proxy_enabled: Optional[bool]
    respect_robots: Optional[bool]
    query_selector: Optional[str]
    full_resources: Optional[bool]
    request_timeout: Optional[int]
    run_in_background: Optional[bool]
    skip_config_checks: Optional[bool]


class Spider:
    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the Spider with an API key.

        :param api_key: A string of the API key for Spider. Defaults to the SPIDER_API_KEY environment variable.
        :raises ValueError: If no API key is provided.
        """
        self.api_key = api_key or os.getenv("SPIDER_API_KEY")
        if self.api_key is None:
            raise ValueError("No API key provided")

    def api_post(
        self,
        endpoint: str,
        data: dict,
        stream: bool,
        content_type: str = "application/json",
    ):
        """
        Send a POST request to the specified API endpoint.

        :param endpoint: The API endpoint to which the POST request is sent.
        :param data: The data (dictionary) to be sent in the POST request.
        :param stream: Boolean indicating if the response should be streamed.
        :return: The JSON response or the raw response stream if stream is True.
        """
        headers = self._prepare_headers(content_type)
        response = self._post_request(
            f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream
        )
        if stream:
            return response
        elif response.status_code == 200:
            return response.json()
        else:
            self._handle_error(response, f"post to {endpoint}")

    def api_get(
        self, endpoint: str, stream: bool, content_type: str = "application/json"
    ):
        """
        Send a GET request to the specified endpoint.

        :param endpoint: The API endpoint from which to retrieve data.
        :return: The JSON decoded response.
        """
        headers = self._prepare_headers(content_type)
        response = self._get_request(
            f"https://api.spider.cloud/v1/{endpoint}", headers, stream
        )
        if response.status_code == 200:
            return response.json()
        else:
            self._handle_error(response, f"get from {endpoint}")

    def get_credits(self):
        """
        Retrieve the account's remaining credits.

        :return: JSON response containing the number of credits left.
        """
        return self.api_get("credits", stream=False)

    def scrape_url(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Scrape data from the specified URL (a crawl limited to a single page).

        :param url: The URL from which to scrape data.
        :param params: Optional dictionary of additional parameters for the scrape request.
        :return: JSON response containing the scraping results.
        """
        # BUG FIX: the original dereferenced/mutated ``params`` before the
        # ``params or {}`` guard, so calling with the default (None) raised
        # TypeError. Normalize first.
        params = params or {}
        # Default to LLM-friendly markdown output if the caller didn't choose.
        if "return_format" not in params:
            params["return_format"] = "markdown"
        # A scrape is a crawl capped at one page.
        params["limit"] = 1
        return self.api_post(
            "crawl", {"url": url, **params}, stream, content_type
        )

    def crawl_url(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Start crawling at the specified URL.

        :param url: The URL to begin crawling.
        :param params: Optional dictionary with additional parameters to customize the crawl.
        :param stream: Boolean indicating if the response should be streamed. Defaults to False.
        :return: JSON response or the raw response stream if streaming enabled.
        """
        # BUG FIX: normalize ``params`` before use — see scrape_url.
        params = params or {}
        if "return_format" not in params:
            params["return_format"] = "markdown"
        return self.api_post(
            "crawl", {"url": url, **params}, stream, content_type
        )

    def links(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Retrieve links from the specified URL.

        :param url: The URL from which to extract links.
        :param params: Optional parameters for the link retrieval request.
        :return: JSON response containing the links.
        """
        return self.api_post(
            "links", {"url": url, **(params or {})}, stream, content_type
        )

    def extract_contacts(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Extract contact information from the specified URL.

        :param url: The URL from which to extract contact information.
        :param params: Optional parameters for the contact extraction.
        :return: JSON response containing extracted contact details.
        """
        return self.api_post(
            "pipeline/extract-contacts",
            {"url": url, **(params or {})},
            stream,
            content_type,
        )

    def label(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Apply labeling to data extracted from the specified URL.

        :param url: The URL to label data from.
        :param params: Optional parameters to guide the labeling process.
        :return: JSON response with labeled data.
        """
        return self.api_post(
            "pipeline/label", {"url": url, **(params or {})}, stream, content_type
        )

    def _prepare_headers(self, content_type: str = "application/json"):
        """Build the standard auth + content-type headers for every request."""
        return {
            "Content-Type": content_type,
            "Authorization": f"Bearer {self.api_key}",
            "User-Agent": "Spider-Client/0.0.27",
        }

    def _post_request(self, url: str, data, headers, stream=False):
        return requests.post(url, headers=headers, json=data, stream=stream)

    def _get_request(self, url: str, headers, stream=False):
        return requests.get(url, headers=headers, stream=stream)

    def _delete_request(self, url: str, headers, stream=False):
        return requests.delete(url, headers=headers, stream=stream)

    def _handle_error(self, response, action):
        """Raise a descriptive exception for a non-200 API response."""
        # 402/409/500 responses carry a JSON body with an "error" field.
        if response.status_code in [402, 409, 500]:
            error_message = response.json().get("error", "Unknown error occurred")
            raise Exception(
                f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}"
            )
        else:
            raise Exception(
                f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}"
            )

View File

@ -0,0 +1,47 @@
from typing import Any, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.spider.spiderApp import Spider
from core.tools.tool.builtin_tool import BuiltinTool
class ScrapeTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Scrape a single page or crawl a site via the Spider API and return the
        concatenated page contents as a text message.

        :param user_id: id of the invoking user (not sent to the API)
        :param tool_parameters: must contain ``url`` and ``mode`` ('scrape' or
            'crawl'); optional limit/depth/blacklist/whitelist/readability.
        """
        # initialize the app object with the api key
        app = Spider(api_key=self.runtime.credentials['spider_api_key'])

        url = tool_parameters['url']
        mode = tool_parameters['mode']

        # Comma-separated pattern strings are split into lists for the API.
        options = {
            'limit': tool_parameters.get('limit', 0),
            'depth': tool_parameters.get('depth', 0),
            'blacklist': tool_parameters.get('blacklist', '').split(',') if tool_parameters.get('blacklist') else [],
            'whitelist': tool_parameters.get('whitelist', '').split(',') if tool_parameters.get('whitelist') else [],
            'readability': tool_parameters.get('readability', False),
        }

        result = ""
        try:
            if mode == 'scrape':
                pages = app.scrape_url(url=url, params=options)
            elif mode == 'crawl':
                pages = app.crawl_url(url=url, params=options)
            else:
                # Unknown mode: nothing to fetch (mirrors the original, which
                # fell through and returned an empty message).
                pages = []
            for page in pages:
                result += "URL: " + page.get('url', '') + "\n"
                result += "CONTENT: " + page.get('content', '') + "\n\n"
        except Exception as e:
            # BUG FIX: create_text_message takes a single text argument (see
            # its other call sites); the original passed two positionals and
            # misspelled "occurred". Include the error detail in the message.
            return self.create_text_message(f"An error occurred: {str(e)}")

        return self.create_text_message(result)

View File

@ -0,0 +1,102 @@
identity:
name: scraper_crawler
author: William Espegren
label:
en_US: Web Scraper & Crawler
zh_Hans: 网页抓取与爬虫
description:
human:
en_US: A tool for scraping & crawling webpages. Input should be a url.
zh_Hans: 用于抓取和爬取网页的工具。输入应该是一个网址。
llm: A tool for scraping & crawling webpages. Input should be a url.
parameters:
- name: url
type: string
required: true
label:
en_US: URL
zh_Hans: 网址
human_description:
en_US: url to be scraped or crawled
zh_Hans: 要抓取或爬取的网址
llm_description: url to either be scraped or crawled
form: llm
- name: mode
type: select
required: true
options:
- value: scrape
label:
en_US: scrape
zh_Hans: 抓取
- value: crawl
label:
en_US: crawl
zh_Hans: 爬取
default: crawl
label:
en_US: Mode
zh_Hans: 模式
human_description:
en_US: used for selecting to either scrape the website or crawl the entire website following subpages
zh_Hans: 用于选择抓取网站或爬取整个网站及其子页面
form: form
- name: limit
type: number
required: false
label:
en_US: maximum number of pages to crawl
zh_Hans: 最大爬取页面数
human_description:
en_US: specify the maximum number of pages to crawl per website. the crawler will stop after reaching this limit.
zh_Hans: 指定每个网站要爬取的最大页面数。爬虫将在达到此限制后停止。
form: form
min: 0
default: 0
- name: depth
type: number
required: false
label:
en_US: maximum depth of pages to crawl
zh_Hans: 最大爬取深度
human_description:
en_US: the crawl limit for maximum depth.
zh_Hans: 最大爬取深度的限制。
form: form
min: 0
default: 0
- name: blacklist
type: string
required: false
label:
en_US: URL patterns to exclude
zh_Hans: 要排除的URL模式
human_description:
en_US: blacklist a set of paths that you do not want to crawl. you can use regex patterns to help with the list.
zh_Hans: 指定一组不想爬取的路径。您可以使用正则表达式模式来帮助定义列表。
placeholder:
en_US: /blog/*, /about
form: form
- name: whitelist
type: string
required: false
label:
en_US: URL patterns to include
zh_Hans: 要包含的URL模式
human_description:
en_US: Whitelist a set of paths that you want to crawl, ignoring all other routes that do not match the patterns. You can use regex patterns to help with the list.
zh_Hans: 指定一组要爬取的路径,忽略所有不匹配模式的其他路由。您可以使用正则表达式模式来帮助定义列表。
placeholder:
en_US: /blog/*, /about
form: form
- name: readability
type: boolean
required: false
label:
en_US: Pre-process the content for LLM usage
zh_Hans: 仅返回页面的主要内容
human_description:
en_US: Use Mozilla's readability to pre-process the content for reading. This may drastically improve the content for LLM usage.
zh_Hans: 如果启用,爬虫将仅返回页面的主要内容,不包括标题、导航、页脚等。
form: form
default: false

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 23 KiB

View File

@ -0,0 +1,21 @@
from typing import Any
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.tianditu.tools.poisearch import PoiSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class TiandituProvider(BuiltinToolProviderController):
    """Credential validator for the Tianditu builtin tool provider."""

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """Validate the API key by invoking the POI search tool once.

        :param credentials: credentials dict containing 'tianditu_api_key'
        :raises ToolProviderCredentialValidationError: when the probe fails
        """
        # NOTE(review): PoiSearchTool reads 'keyword'/'baseAddress' from its
        # parameters; 'content'/'specify' may only yield an "Invalid
        # parameter" text message instead of exercising the key — confirm.
        probe_parameters = {
            'content': '北京',
            'specify': '156110000',
        }
        try:
            probe_tool = PoiSearchTool().fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                }
            )
            probe_tool.invoke(user_id='', tool_parameters=probe_parameters)
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))

View File

@ -0,0 +1,32 @@
identity:
author: Listeng
name: tianditu
label:
en_US: Tianditu
zh_Hans: 天地图
pt_BR: Tianditu
description:
en_US: The Tianditu tool provided the functions of place name search, geocoding, static maps generation, etc. in China region.
zh_Hans: 天地图工具可以调用天地图的接口,实现中国区域内的地名搜索、地理编码、静态地图等功能。
pt_BR: The Tianditu tool provided the functions of place name search, geocoding, static maps generation, etc. in China region.
icon: icon.svg
tags:
- utilities
- travel
credentials_for_provider:
tianditu_api_key:
type: secret-input
required: true
label:
en_US: Tianditu API Key
zh_Hans: 天地图Key
pt_BR: Tianditu API key
placeholder:
en_US: Please input your Tianditu API key
zh_Hans: 请输入你的天地图Key
pt_BR: Please input your Tianditu API key
help:
en_US: Get your Tianditu API key from Tianditu
zh_Hans: 获取您的天地图Key
pt_BR: Get your Tianditu API key from Tianditu
url: http://lbs.tianditu.gov.cn/home.html

View File

@ -0,0 +1,33 @@
import json
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GeocoderTool(BuiltinTool):
    """Tianditu forward-geocoding tool: resolves a place name to coordinates."""

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Query the Tianditu geocoder for a keyword.

        :param user_id: id of the invoking user (not used by the API)
        :param tool_parameters: must contain a non-empty 'keyword'
        :return: a JSON message with the raw geocoder response
        """
        base_url = 'http://api.tianditu.gov.cn/geocoder'

        keyword = tool_parameters.get('keyword', '')
        if not keyword:
            return self.create_text_message('Invalid parameter keyword')

        tk = self.runtime.credentials['tianditu_api_key']

        params = {
            'keyWord': keyword,
        }

        # Let requests build the query string so the JSON payload (which
        # typically contains non-ASCII place names) is properly URL-encoded;
        # the previous manual concatenation sent raw unencoded characters.
        result = requests.get(base_url,
                              params={'ds': json.dumps(params, ensure_ascii=False),
                                      'tk': tk}).json()

        return self.create_json_message(result)

View File

@ -0,0 +1,26 @@
identity:
name: geocoder
author: Listeng
label:
en_US: Get coords converted from address name
zh_Hans: 地理编码
pt_BR: Get coords converted from address name
description:
human:
en_US: Geocoder
zh_Hans: 中国区域地理编码查询
pt_BR: Geocoder
llm: A tool for geocoder in China
parameters:
- name: keyword
type: string
required: true
label:
en_US: keyword
zh_Hans: 搜索的关键字
pt_BR: keyword
human_description:
en_US: keyword
zh_Hans: 搜索的关键字
pt_BR: keyword
form: llm

View File

@ -0,0 +1,45 @@
import json
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class PoiSearchTool(BuiltinTool):
    """Search POIs of a given type within 5 km of a base address (Tianditu)."""

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Geocode 'baseAddress', then search for 'keyword' POIs around it.

        :param user_id: id of the invoking user (not used by the API)
        :param tool_parameters: must contain non-empty 'keyword' and 'baseAddress'
        :return: a JSON message with the raw search response
        """
        geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder'
        base_url = 'http://api.tianditu.gov.cn/v2/search'

        keyword = tool_parameters.get('keyword', '')
        if not keyword:
            return self.create_text_message('Invalid parameter keyword')

        baseAddress = tool_parameters.get('baseAddress', '')
        if not baseAddress:
            return self.create_text_message('Invalid parameter baseAddress')

        tk = self.runtime.credentials['tianditu_api_key']

        # URL-encode the JSON payload via requests instead of raw string
        # concatenation (non-ASCII addresses were sent unencoded before).
        base_coords = requests.get(geocoder_base_url,
                                   params={'ds': json.dumps({'keyWord': baseAddress}, ensure_ascii=False),
                                           'tk': tk}).json()

        location = base_coords.get('location')
        if not location:
            # Geocoding failed (bad address, quota, bad key) — previously this
            # raised a bare KeyError; fail with a readable message instead.
            return self.create_text_message(f'Unable to geocode baseAddress: {baseAddress}')

        params = {
            'keyWord': keyword,
            'queryRadius': 5000,
            'queryType': 3,
            # str() guards against the API returning numeric lon/lat values,
            # which would make the old string concatenation raise TypeError.
            'pointLonlat': str(location['lon']) + ',' + str(location['lat']),
            'start': 0,
            'count': 100,
        }

        result = requests.get(base_url,
                              params={'postStr': json.dumps(params, ensure_ascii=False),
                                      'type': 'query',
                                      'tk': tk}).json()

        return self.create_json_message(result)

View File

@ -0,0 +1,38 @@
identity:
name: point_of_interest_search
author: Listeng
label:
en_US: Point of Interest search
zh_Hans: 兴趣点搜索
pt_BR: Point of Interest search
description:
human:
en_US: Search for certain types of points of interest around a location
zh_Hans: 搜索某个位置周边的5公里内某种类型的兴趣点
pt_BR: Search for certain types of points of interest around a location
llm: A tool for searching for certain types of points of interest around a location
parameters:
- name: keyword
type: string
required: true
label:
en_US: poi keyword
zh_Hans: 兴趣点的关键字
pt_BR: poi keyword
human_description:
en_US: poi keyword
zh_Hans: 兴趣点的关键字
pt_BR: poi keyword
form: llm
- name: baseAddress
type: string
required: true
label:
en_US: keyword of the current location
zh_Hans: 当前位置的关键字
pt_BR: keyword of the current location
human_description:
en_US: keyword of the current location
zh_Hans: 当前位置的关键字
pt_BR: keyword of the current location
form: llm

View File

@ -0,0 +1,36 @@
import json
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class PoiSearchTool(BuiltinTool):
    """Generate a 400x300 static map image centered on a geocoded keyword."""
    # NOTE(review): the class name matches the POI-search tool but this file
    # generates a static map — a copy-paste slip. Kept unchanged here in case
    # the tool loader resolves the class by name; consider renaming.

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Geocode 'keyword' and return a static map image centered on it.

        :param user_id: id of the invoking user (not used by the API)
        :param tool_parameters: must contain a non-empty 'keyword'
        :return: a blob message holding the PNG image bytes
        """
        geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder'
        base_url = 'http://api.tianditu.gov.cn/staticimage'

        keyword = tool_parameters.get('keyword', '')
        if not keyword:
            return self.create_text_message('Invalid parameter keyword')

        tk = self.runtime.credentials['tianditu_api_key']

        # URL-encode the JSON payload via requests instead of raw string
        # concatenation (non-ASCII keywords were sent unencoded before).
        keyword_coords = requests.get(geocoder_base_url,
                                      params={'ds': json.dumps({'keyWord': keyword}, ensure_ascii=False),
                                              'tk': tk}).json()

        location = keyword_coords.get('location')
        if not location:
            # Geocoding failed — previously this raised a bare KeyError.
            return self.create_text_message(f'Unable to geocode keyword: {keyword}')

        # str() guards against numeric lon/lat values, which would make the
        # old string concatenation raise TypeError.
        coords = str(location['lon']) + ',' + str(location['lat'])

        image = requests.get(base_url,
                             params={'center': coords,
                                     'markers': coords,
                                     'width': 400,
                                     'height': 300,
                                     'zoom': 14,
                                     'tk': tk}).content

        return self.create_blob_message(blob=image,
                                        meta={'mime_type': 'image/png'},
                                        save_as=self.VARIABLE_KEY.IMAGE.value)

View File

@ -0,0 +1,26 @@
identity:
name: generate_static_map
author: Listeng
label:
en_US: Generate a static map
zh_Hans: 生成静态地图
pt_BR: Generate a static map
description:
human:
en_US: Generate a static map
zh_Hans: 生成静态地图
pt_BR: Generate a static map
llm: A tool for generate a static map
parameters:
- name: keyword
type: string
required: true
label:
en_US: keyword
zh_Hans: 搜索的关键字
pt_BR: keyword
human_description:
en_US: keyword
zh_Hans: 搜索的关键字
pt_BR: keyword
form: llm

View File

@ -14,7 +14,7 @@ from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',

View File

@ -8,7 +8,7 @@ from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from copy import deepcopy
from enum import Enum
from typing import Any, Optional, Union
@ -190,8 +191,9 @@ class Tool(BaseModel, ABC):
return result
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]:
# update tool_parameters
# TODO: Fix type error.
if self.runtime.runtime_parameters:
tool_parameters.update(self.runtime.runtime_parameters)
@ -208,7 +210,7 @@ class Tool(BaseModel, ABC):
return result
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> dict[str, Any]:
"""
Transform tool parameters type
"""
@ -241,7 +243,7 @@ class Tool(BaseModel, ABC):
:return: the runtime parameters
"""
return self.parameters
return self.parameters or []
def get_all_runtime_parameters(self) -> list[ToolParameter]:
"""

View File

@ -1,4 +1,5 @@
import json
from collections.abc import Mapping
from copy import deepcopy
from datetime import datetime, timezone
from mimetypes import guess_type
@ -46,7 +47,7 @@ class ToolEngine:
if isinstance(tool_parameters, str):
# check if this tool has only one parameter
parameters = [
parameter for parameter in tool.get_runtime_parameters()
parameter for parameter in tool.get_runtime_parameters() or []
if parameter.form == ToolParameter.ToolParameterForm.LLM
]
if parameters and len(parameters) == 1:
@ -123,8 +124,8 @@ class ToolEngine:
return error_response, [], ToolInvokeMeta.error_instance(error_response)
@staticmethod
def workflow_invoke(tool: Tool, tool_parameters: dict,
user_id: str, workflow_id: str,
def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any],
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int,
) -> list[ToolInvokeMessage]:
@ -141,7 +142,9 @@ class ToolEngine:
if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1
response = tool.invoke(user_id, tool_parameters)
if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
response = tool.invoke(user_id=user_id, tool_parameters=tool_parameters)
# hit the callback handler
workflow_tool_callback.on_tool_end(

View File

@ -9,9 +9,9 @@ from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
from flask import current_app
from httpx import get
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import MessageFile
@ -26,25 +26,25 @@ class ToolFileManager:
"""
sign file to get a temporary url
"""
base_url = current_app.config.get('FILES_URL')
base_url = dify_config.FILES_URL
file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}'
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
secret_key = current_app.config['SECRET_KEY'].encode()
data_to_sign = f'file-preview|{tool_file_id}|{timestamp}|{nonce}'
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b''
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
return f'{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}'
@staticmethod
def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
secret_key = current_app.config['SECRET_KEY'].encode()
data_to_sign = f'file-preview|{file_id}|{timestamp}|{nonce}'
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b''
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
@ -53,23 +53,23 @@ class ToolFileManager:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= current_app.config.get('FILES_ACCESS_TIMEOUT')
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
@staticmethod
def create_file_by_raw(user_id: str, tenant_id: str,
conversation_id: Optional[str], file_binary: bytes,
mimetype: str
) -> ToolFile:
def create_file_by_raw(
user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str
) -> ToolFile:
"""
create file
"""
extension = guess_extension(mimetype) or '.bin'
unique_name = uuid4().hex
filename = f"tools/{tenant_id}/{unique_name}{extension}"
filename = f'tools/{tenant_id}/{unique_name}{extension}'
storage.save(filename, file_binary)
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
conversation_id=conversation_id, file_key=filename, mimetype=mimetype)
tool_file = ToolFile(
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype
)
db.session.add(tool_file)
db.session.commit()
@ -77,9 +77,12 @@ class ToolFileManager:
return tool_file
@staticmethod
def create_file_by_url(user_id: str, tenant_id: str,
conversation_id: str, file_url: str,
) -> ToolFile:
def create_file_by_url(
user_id: str,
tenant_id: str,
conversation_id: str,
file_url: str,
) -> ToolFile:
"""
create file
"""
@ -90,12 +93,17 @@ class ToolFileManager:
mimetype = guess_type(file_url)[0] or 'octet/stream'
extension = guess_extension(mimetype) or '.bin'
unique_name = uuid4().hex
filename = f"tools/{tenant_id}/{unique_name}{extension}"
filename = f'tools/{tenant_id}/{unique_name}{extension}'
storage.save(filename, blob)
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
conversation_id=conversation_id, file_key=filename,
mimetype=mimetype, original_url=file_url)
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filename,
mimetype=mimetype,
original_url=file_url,
)
db.session.add(tool_file)
db.session.commit()
@ -103,15 +111,15 @@ class ToolFileManager:
return tool_file
@staticmethod
def create_file_by_key(user_id: str, tenant_id: str,
conversation_id: str, file_key: str,
mimetype: str
) -> ToolFile:
def create_file_by_key(
user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str
) -> ToolFile:
"""
create file
"""
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
conversation_id=conversation_id, file_key=file_key, mimetype=mimetype)
tool_file = ToolFile(
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype
)
return tool_file
@staticmethod
@ -123,9 +131,13 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
tool_file: ToolFile = db.session.query(ToolFile).filter(
ToolFile.id == id,
).first()
tool_file: ToolFile = (
db.session.query(ToolFile)
.filter(
ToolFile.id == id,
)
.first()
)
if not tool_file:
return None
@ -143,18 +155,31 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
message_file: MessageFile = db.session.query(MessageFile).filter(
MessageFile.id == id,
).first()
message_file: MessageFile = (
db.session.query(MessageFile)
.filter(
MessageFile.id == id,
)
.first()
)
# get tool file id
tool_file_id = message_file.url.split('/')[-1]
# trim extension
tool_file_id = tool_file_id.split('.')[0]
# Check if message_file is not None
if message_file is not None:
# get tool file id
tool_file_id = message_file.url.split('/')[-1]
# trim extension
tool_file_id = tool_file_id.split('.')[0]
else:
tool_file_id = None
tool_file: ToolFile = db.session.query(ToolFile).filter(
ToolFile.id == tool_file_id,
).first()
tool_file: ToolFile = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
.first()
)
if not tool_file:
return None
@ -172,9 +197,13 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
tool_file: ToolFile = db.session.query(ToolFile).filter(
ToolFile.id == tool_file_id,
).first()
tool_file: ToolFile = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
.first()
)
if not tool_file:
return None

View File

@ -6,8 +6,7 @@ from os import listdir, path
from threading import Lock
from typing import Any, Union
from flask import current_app
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
@ -565,7 +564,7 @@ class ToolManager:
provider_type = provider_type
provider_id = provider_id
if provider_type == 'builtin':
return (current_app.config.get("CONSOLE_API_URL")
return (dify_config.CONSOLE_API_URL
+ "/console/api/workspaces/current/tool-provider/builtin/"
+ provider_id
+ "/icon")
@ -574,7 +573,7 @@ class ToolManager:
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.id == provider_id
)
).first()
return json.loads(provider.icon)
except:
return {
@ -593,4 +592,4 @@ class ToolManager:
else:
raise ValueError(f"provider type {provider_type} not found")
ToolManager.load_builtin_providers_cache()
ToolManager.load_builtin_providers_cache()

View File

@ -10,6 +10,7 @@ import unicodedata
from contextlib import contextmanager
from urllib.parse import unquote
import cloudscraper
import requests
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from newspaper import Article
@ -46,29 +47,34 @@ def get_url(url: str, user_agent: str = None) -> str:
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
if response.status_code == 200:
# check content-type
content_type = response.headers.get('Content-Type')
if content_type:
main_content_type = response.headers.get('Content-Type').split(';')[0].strip()
else:
content_disposition = response.headers.get('Content-Disposition', '')
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
extension = re.search(r'\.(\w+)$', filename)
if extension:
main_content_type = mimetypes.guess_type(filename)[0]
if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return ExtractProcessor.load_from_url(url, return_text=True)
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
scraper = cloudscraper.create_scraper()
response = scraper.get(url, headers=headers, allow_redirects=True, timeout=(120, 300))
if response.status_code != 200:
return "URL returned status code {}.".format(response.status_code)
# check content-type
content_type = response.headers.get('Content-Type')
if content_type:
main_content_type = response.headers.get('Content-Type').split(';')[0].strip()
else:
content_disposition = response.headers.get('Content-Disposition', '')
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
extension = re.search(r'\.(\w+)$', filename)
if extension:
main_content_type = mimetypes.guess_type(filename)[0]
if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return ExtractProcessor.load_from_url(url, return_text=True)
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(120, 300))
a = extract_using_readabilipy(response.text)
if not a['plain_text'] or not a['plain_text'].strip():