mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
fix bug
This commit is contained in:
@ -30,3 +30,4 @@
|
||||
- feishu
|
||||
- feishu_base
|
||||
- slack
|
||||
- tianditu
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import base64
|
||||
import random
|
||||
from base64 import b64decode
|
||||
from typing import Any, Union
|
||||
|
||||
from openai import OpenAI
|
||||
@ -69,11 +69,50 @@ class DallE3Tool(BuiltinTool):
|
||||
result = []
|
||||
|
||||
for image in response.data:
|
||||
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
|
||||
meta={'mime_type': 'image/png'},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
|
||||
blob_message = self.create_blob_message(blob=blob_image,
|
||||
meta={'mime_type': mime_type},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value)
|
||||
result.append(blob_message)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _decode_image(base64_image: str) -> tuple[str, bytes]:
|
||||
"""
|
||||
Decode a base64 encoded image. If the image is not prefixed with a MIME type,
|
||||
it assumes 'image/png' as the default.
|
||||
|
||||
:param base64_image: Base64 encoded image string
|
||||
:return: A tuple containing the MIME type and the decoded image bytes
|
||||
"""
|
||||
if DallE3Tool._is_plain_base64(base64_image):
|
||||
return 'image/png', base64.b64decode(base64_image)
|
||||
else:
|
||||
return DallE3Tool._extract_mime_and_data(base64_image)
|
||||
|
||||
@staticmethod
|
||||
def _is_plain_base64(encoded_str: str) -> bool:
|
||||
"""
|
||||
Check if the given encoded string is plain base64 without a MIME type prefix.
|
||||
|
||||
:param encoded_str: Base64 encoded image string
|
||||
:return: True if the string is plain base64, False otherwise
|
||||
"""
|
||||
return not encoded_str.startswith('data:image')
|
||||
|
||||
@staticmethod
|
||||
def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]:
|
||||
"""
|
||||
Extract MIME type and image data from a base64 encoded string with a MIME type prefix.
|
||||
|
||||
:param encoded_str: Base64 encoded image string with MIME type prefix
|
||||
:return: A tuple containing the MIME type and the decoded image bytes
|
||||
"""
|
||||
mime_type = encoded_str.split(';')[0].split(':')[1]
|
||||
image_data_base64 = encoded_str.split(',')[1]
|
||||
decoded_data = base64.b64decode(image_data_base64)
|
||||
return mime_type, decoded_data
|
||||
|
||||
@staticmethod
|
||||
def _generate_random_id(length=8):
|
||||
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
|
||||
|
||||
@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 256 256"><rect width="256" height="256" fill="none"/><rect x="32" y="48" width="192" height="160" rx="8" fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="16"/><circle cx="156" cy="100" r="12"/><path d="M147.31,164,173,138.34a8,8,0,0,1,11.31,0L224,178.06" fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="16"/><path d="M32,168.69l54.34-54.35a8,8,0,0,1,11.32,0L191.31,208" fill="none" stroke="#1553ed" stroke-linecap="round" stroke-linejoin="round" stroke-width="16"/></svg>
|
||||
|
After Width: | Height: | Size: 617 B |
22
api/core/tools/provider/builtin/getimgai/getimgai.py
Normal file
22
api/core/tools/provider/builtin/getimgai/getimgai.py
Normal file
@ -0,0 +1,22 @@
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.getimgai.tools.text2image import Text2ImageTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class GetImgAIProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
# Example validation using the text2image tool
|
||||
Text2ImageTool().fork_tool_runtime(
|
||||
runtime={"credentials": credentials}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"prompt": "A fire egg",
|
||||
"response_format": "url",
|
||||
"style": "photorealism",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
|
||||
29
api/core/tools/provider/builtin/getimgai/getimgai.yaml
Normal file
29
api/core/tools/provider/builtin/getimgai/getimgai.yaml
Normal file
@ -0,0 +1,29 @@
|
||||
identity:
|
||||
author: Matri Qi
|
||||
name: getimgai
|
||||
label:
|
||||
en_US: getimg.ai
|
||||
zh_CN: getimg.ai
|
||||
description:
|
||||
en_US: GetImg API integration for image generation and scraping.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- image
|
||||
credentials_for_provider:
|
||||
getimg_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: getimg.ai API Key
|
||||
placeholder:
|
||||
en_US: Please input your getimg.ai API key
|
||||
help:
|
||||
en_US: Get your getimg.ai API key from your getimg.ai account settings. If you are using a self-hosted version, you may enter any key at your convenience.
|
||||
url: https://dashboard.getimg.ai/api-keys
|
||||
base_url:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
en_US: getimg.ai server's Base URL
|
||||
placeholder:
|
||||
en_US: https://api.getimg.ai/v1
|
||||
59
api/core/tools/provider/builtin/getimgai/getimgai_appx.py
Normal file
59
api/core/tools/provider/builtin/getimgai/getimgai_appx.py
Normal file
@ -0,0 +1,59 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GetImgAIApp:
|
||||
def __init__(self, api_key: str | None = None, base_url: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url or 'https://api.getimg.ai/v1'
|
||||
if not self.api_key:
|
||||
raise ValueError("API key is required")
|
||||
|
||||
def _prepare_headers(self):
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
return headers
|
||||
|
||||
def _request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
data: Mapping[str, Any] | None = None,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
retries: int = 3,
|
||||
backoff_factor: float = 0.3,
|
||||
) -> Mapping[str, Any] | None:
|
||||
for i in range(retries):
|
||||
try:
|
||||
response = requests.request(method, url, json=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500:
|
||||
time.sleep(backoff_factor * (2 ** i))
|
||||
else:
|
||||
raise
|
||||
return None
|
||||
|
||||
def text2image(
|
||||
self, mode: str, **kwargs
|
||||
):
|
||||
data = kwargs['params']
|
||||
if not data.get('prompt'):
|
||||
raise ValueError("Prompt is required")
|
||||
|
||||
endpoint = f'{self.base_url}/{mode}/text-to-image'
|
||||
headers = self._prepare_headers()
|
||||
logger.debug(f"Send request to {endpoint=} body={data}")
|
||||
response = self._request('POST', endpoint, data, headers)
|
||||
if response is None:
|
||||
raise HTTPError("Failed to initiate getimg.ai after multiple retries")
|
||||
return response
|
||||
39
api/core/tools/provider/builtin/getimgai/tools/text2image.py
Normal file
39
api/core/tools/provider/builtin/getimgai/tools/text2image.py
Normal file
@ -0,0 +1,39 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.getimgai.getimgai_appx import GetImgAIApp
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class Text2ImageTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
app = GetImgAIApp(api_key=self.runtime.credentials['getimg_api_key'], base_url=self.runtime.credentials['base_url'])
|
||||
|
||||
options = {
|
||||
'style': tool_parameters.get('style'),
|
||||
'prompt': tool_parameters.get('prompt'),
|
||||
'aspect_ratio': tool_parameters.get('aspect_ratio'),
|
||||
'output_format': tool_parameters.get('output_format', 'jpeg'),
|
||||
'response_format': tool_parameters.get('response_format', 'url'),
|
||||
'width': tool_parameters.get('width'),
|
||||
'height': tool_parameters.get('height'),
|
||||
'steps': tool_parameters.get('steps'),
|
||||
'negative_prompt': tool_parameters.get('negative_prompt'),
|
||||
'prompt_2': tool_parameters.get('prompt_2'),
|
||||
}
|
||||
options = {k: v for k, v in options.items() if v}
|
||||
|
||||
text2image_result = app.text2image(
|
||||
mode=tool_parameters.get('mode', 'essential-v2'),
|
||||
params=options,
|
||||
wait=True
|
||||
)
|
||||
|
||||
if not isinstance(text2image_result, str):
|
||||
text2image_result = json.dumps(text2image_result, ensure_ascii=False, indent=4)
|
||||
|
||||
if not text2image_result:
|
||||
return self.create_text_message("getimg.ai request failed.")
|
||||
|
||||
return self.create_text_message(text2image_result)
|
||||
167
api/core/tools/provider/builtin/getimgai/tools/text2image.yaml
Normal file
167
api/core/tools/provider/builtin/getimgai/tools/text2image.yaml
Normal file
@ -0,0 +1,167 @@
|
||||
identity:
|
||||
name: text2image
|
||||
author: Matri Qi
|
||||
label:
|
||||
en_US: text2image
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: Generate image via getimg.ai.
|
||||
llm: This tool is used to generate image from prompt or image via https://getimg.ai.
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: prompt
|
||||
human_description:
|
||||
en_US: The text prompt used to generate the image. The getimg.aier will generate an image based on this prompt.
|
||||
llm_description: this prompt text will be used to generate image.
|
||||
form: llm
|
||||
- name: mode
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: mode
|
||||
human_description:
|
||||
en_US: The getimg.ai mode to use. The mode determines the endpoint used to generate the image.
|
||||
form: form
|
||||
options:
|
||||
- value: "essential-v2"
|
||||
label:
|
||||
en_US: essential-v2
|
||||
- value: stable-diffusion-xl
|
||||
label:
|
||||
en_US: stable-diffusion-xl
|
||||
- value: stable-diffusion
|
||||
label:
|
||||
en_US: stable-diffusion
|
||||
- value: latent-consistency
|
||||
label:
|
||||
en_US: latent-consistency
|
||||
- name: style
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: style
|
||||
human_description:
|
||||
en_US: The style preset to use. The style preset guides the generation towards a particular style. It's just efficient for `Essential V2` mode.
|
||||
form: form
|
||||
options:
|
||||
- value: photorealism
|
||||
label:
|
||||
en_US: photorealism
|
||||
- value: anime
|
||||
label:
|
||||
en_US: anime
|
||||
- value: art
|
||||
label:
|
||||
en_US: art
|
||||
- name: aspect_ratio
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: "aspect ratio"
|
||||
human_description:
|
||||
en_US: The aspect ratio of the generated image. It's just efficient for `Essential V2` mode.
|
||||
form: form
|
||||
options:
|
||||
- value: "1:1"
|
||||
label:
|
||||
en_US: "1:1"
|
||||
- value: "4:5"
|
||||
label:
|
||||
en_US: "4:5"
|
||||
- value: "5:4"
|
||||
label:
|
||||
en_US: "5:4"
|
||||
- value: "2:3"
|
||||
label:
|
||||
en_US: "2:3"
|
||||
- value: "3:2"
|
||||
label:
|
||||
en_US: "3:2"
|
||||
- value: "4:7"
|
||||
label:
|
||||
en_US: "4:7"
|
||||
- value: "7:4"
|
||||
label:
|
||||
en_US: "7:4"
|
||||
- name: output_format
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: "output format"
|
||||
human_description:
|
||||
en_US: The file format of the generated image.
|
||||
form: form
|
||||
options:
|
||||
- value: jpeg
|
||||
label:
|
||||
en_US: jpeg
|
||||
- value: png
|
||||
label:
|
||||
en_US: png
|
||||
- name: response_format
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: "response format"
|
||||
human_description:
|
||||
en_US: The format in which the generated images are returned. Must be one of url or b64. URLs are only valid for 1 hour after the image has been generated.
|
||||
form: form
|
||||
options:
|
||||
- value: url
|
||||
label:
|
||||
en_US: url
|
||||
- value: b64
|
||||
label:
|
||||
en_US: b64
|
||||
- name: model
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: model
|
||||
human_description:
|
||||
en_US: Model ID supported by this pipeline and family. It's just efficient for `Stable Diffusion XL`, `Stable Diffusion`, `Latent Consistency` mode.
|
||||
form: form
|
||||
- name: negative_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: negative prompt
|
||||
human_description:
|
||||
en_US: Text input that will not guide the image generation. It's just efficient for `Stable Diffusion XL`, `Stable Diffusion`, `Latent Consistency` mode.
|
||||
form: form
|
||||
- name: prompt_2
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: prompt2
|
||||
human_description:
|
||||
en_US: Prompt sent to second tokenizer and text encoder. If not defined, prompt is used in both text-encoders. It's just efficient for `Stable Diffusion XL` mode.
|
||||
form: form
|
||||
- name: width
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: width
|
||||
human_description:
|
||||
en_US: he width of the generated image in pixels. Width needs to be multiple of 64.
|
||||
form: form
|
||||
- name: height
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: height
|
||||
human_description:
|
||||
en_US: he height of the generated image in pixels. Height needs to be multiple of 64.
|
||||
form: form
|
||||
- name: steps
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: steps
|
||||
human_description:
|
||||
en_US: The number of denoising steps. More steps usually can produce higher quality images, but take more time to generate. It's just efficient for `Stable Diffusion XL`, `Stable Diffusion`, `Latent Consistency` mode.
|
||||
form: form
|
||||
@ -19,28 +19,29 @@ class JSONDeleteTool(BuiltinTool):
|
||||
content = tool_parameters.get('content', '')
|
||||
if not content:
|
||||
return self.create_text_message('Invalid parameter content')
|
||||
|
||||
|
||||
# Get query
|
||||
query = tool_parameters.get('query', '')
|
||||
if not query:
|
||||
return self.create_text_message('Invalid parameter query')
|
||||
|
||||
|
||||
ensure_ascii = tool_parameters.get('ensure_ascii', True)
|
||||
try:
|
||||
result = self._delete(content, query)
|
||||
result = self._delete(content, query, ensure_ascii)
|
||||
return self.create_text_message(str(result))
|
||||
except Exception as e:
|
||||
return self.create_text_message(f'Failed to delete JSON content: {str(e)}')
|
||||
|
||||
def _delete(self, origin_json: str, query: str) -> str:
|
||||
def _delete(self, origin_json: str, query: str, ensure_ascii: bool) -> str:
|
||||
try:
|
||||
input_data = json.loads(origin_json)
|
||||
expr = parse('$.' + query.lstrip('$.')) # Ensure query path starts with $
|
||||
|
||||
|
||||
matches = expr.find(input_data)
|
||||
|
||||
|
||||
if not matches:
|
||||
return json.dumps(input_data, ensure_ascii=True) # No changes if no matches found
|
||||
|
||||
return json.dumps(input_data, ensure_ascii=ensure_ascii) # No changes if no matches found
|
||||
|
||||
for match in matches:
|
||||
if isinstance(match.context.value, dict):
|
||||
# Delete key from dictionary
|
||||
@ -53,7 +54,7 @@ class JSONDeleteTool(BuiltinTool):
|
||||
parent = match.context.parent
|
||||
if parent:
|
||||
del parent.value[match.path.fields[-1]]
|
||||
|
||||
return json.dumps(input_data, ensure_ascii=True)
|
||||
|
||||
return json.dumps(input_data, ensure_ascii=ensure_ascii)
|
||||
except Exception as e:
|
||||
raise Exception(f"Delete operation failed: {str(e)}")
|
||||
raise Exception(f"Delete operation failed: {str(e)}")
|
||||
|
||||
@ -38,3 +38,15 @@ parameters:
|
||||
pt_BR: JSONPath query to locate the element to delete
|
||||
llm_description: JSONPath query to locate the element to delete
|
||||
form: llm
|
||||
- name: ensure_ascii
|
||||
type: boolean
|
||||
default: true
|
||||
label:
|
||||
en_US: Ensure ASCII
|
||||
zh_Hans: 确保 ASCII
|
||||
pt_BR: Ensure ASCII
|
||||
human_description:
|
||||
en_US: Ensure the JSON output is ASCII encoded
|
||||
zh_Hans: 确保输出的 JSON 是 ASCII 编码
|
||||
pt_BR: Ensure the JSON output is ASCII encoded
|
||||
form: form
|
||||
|
||||
@ -19,31 +19,31 @@ class JSONParseTool(BuiltinTool):
|
||||
content = tool_parameters.get('content', '')
|
||||
if not content:
|
||||
return self.create_text_message('Invalid parameter content')
|
||||
|
||||
|
||||
# get query
|
||||
query = tool_parameters.get('query', '')
|
||||
if not query:
|
||||
return self.create_text_message('Invalid parameter query')
|
||||
|
||||
|
||||
# get new value
|
||||
new_value = tool_parameters.get('new_value', '')
|
||||
if not new_value:
|
||||
return self.create_text_message('Invalid parameter new_value')
|
||||
|
||||
|
||||
# get insert position
|
||||
index = tool_parameters.get('index')
|
||||
|
||||
|
||||
# get create path
|
||||
create_path = tool_parameters.get('create_path', False)
|
||||
|
||||
|
||||
ensure_ascii = tool_parameters.get('ensure_ascii', True)
|
||||
try:
|
||||
result = self._insert(content, query, new_value, index, create_path)
|
||||
result = self._insert(content, query, new_value, ensure_ascii, index, create_path)
|
||||
return self.create_text_message(str(result))
|
||||
except Exception:
|
||||
return self.create_text_message('Failed to insert JSON content')
|
||||
|
||||
|
||||
def _insert(self, origin_json, query, new_value, index=None, create_path=False):
|
||||
def _insert(self, origin_json, query, new_value, ensure_ascii: bool, index=None, create_path=False):
|
||||
try:
|
||||
input_data = json.loads(origin_json)
|
||||
expr = parse(query)
|
||||
@ -51,9 +51,9 @@ class JSONParseTool(BuiltinTool):
|
||||
new_value = json.loads(new_value)
|
||||
except json.JSONDecodeError:
|
||||
new_value = new_value
|
||||
|
||||
|
||||
matches = expr.find(input_data)
|
||||
|
||||
|
||||
if not matches and create_path:
|
||||
# create new path
|
||||
path_parts = query.strip('$').strip('.').split('.')
|
||||
@ -91,7 +91,7 @@ class JSONParseTool(BuiltinTool):
|
||||
else:
|
||||
# replace old value with new value
|
||||
match.full_path.update(input_data, new_value)
|
||||
|
||||
return json.dumps(input_data, ensure_ascii=True)
|
||||
|
||||
return json.dumps(input_data, ensure_ascii=ensure_ascii)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
return str(e)
|
||||
|
||||
@ -75,3 +75,15 @@ parameters:
|
||||
zh_Hans: 否
|
||||
pt_BR: "No"
|
||||
form: form
|
||||
- name: ensure_ascii
|
||||
type: boolean
|
||||
default: true
|
||||
label:
|
||||
en_US: Ensure ASCII
|
||||
zh_Hans: 确保 ASCII
|
||||
pt_BR: Ensure ASCII
|
||||
human_description:
|
||||
en_US: Ensure the JSON output is ASCII encoded
|
||||
zh_Hans: 确保输出的 JSON 是 ASCII 编码
|
||||
pt_BR: Ensure the JSON output is ASCII encoded
|
||||
form: form
|
||||
|
||||
@ -19,33 +19,34 @@ class JSONParseTool(BuiltinTool):
|
||||
content = tool_parameters.get('content', '')
|
||||
if not content:
|
||||
return self.create_text_message('Invalid parameter content')
|
||||
|
||||
|
||||
# get json filter
|
||||
json_filter = tool_parameters.get('json_filter', '')
|
||||
if not json_filter:
|
||||
return self.create_text_message('Invalid parameter json_filter')
|
||||
|
||||
ensure_ascii = tool_parameters.get('ensure_ascii', True)
|
||||
try:
|
||||
result = self._extract(content, json_filter)
|
||||
result = self._extract(content, json_filter, ensure_ascii)
|
||||
return self.create_text_message(str(result))
|
||||
except Exception:
|
||||
return self.create_text_message('Failed to extract JSON content')
|
||||
|
||||
# Extract data from JSON content
|
||||
def _extract(self, content: str, json_filter: str) -> str:
|
||||
def _extract(self, content: str, json_filter: str, ensure_ascii: bool) -> str:
|
||||
try:
|
||||
input_data = json.loads(content)
|
||||
expr = parse(json_filter)
|
||||
result = [match.value for match in expr.find(input_data)]
|
||||
|
||||
|
||||
if len(result) == 1:
|
||||
result = result[0]
|
||||
|
||||
|
||||
if isinstance(result, dict | list):
|
||||
return json.dumps(result, ensure_ascii=True)
|
||||
return json.dumps(result, ensure_ascii=ensure_ascii)
|
||||
elif isinstance(result, str | int | float | bool) or result is None:
|
||||
return str(result)
|
||||
else:
|
||||
return repr(result)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
return str(e)
|
||||
|
||||
@ -38,3 +38,15 @@ parameters:
|
||||
pt_BR: JSON fields to be parsed
|
||||
llm_description: JSON fields to be parsed
|
||||
form: llm
|
||||
- name: ensure_ascii
|
||||
type: boolean
|
||||
default: true
|
||||
label:
|
||||
en_US: Ensure ASCII
|
||||
zh_Hans: 确保 ASCII
|
||||
pt_BR: Ensure ASCII
|
||||
human_description:
|
||||
en_US: Ensure the JSON output is ASCII encoded
|
||||
zh_Hans: 确保输出的 JSON 是 ASCII 编码
|
||||
pt_BR: Ensure the JSON output is ASCII encoded
|
||||
form: form
|
||||
|
||||
@ -19,61 +19,62 @@ class JSONReplaceTool(BuiltinTool):
|
||||
content = tool_parameters.get('content', '')
|
||||
if not content:
|
||||
return self.create_text_message('Invalid parameter content')
|
||||
|
||||
|
||||
# get query
|
||||
query = tool_parameters.get('query', '')
|
||||
if not query:
|
||||
return self.create_text_message('Invalid parameter query')
|
||||
|
||||
|
||||
# get replace value
|
||||
replace_value = tool_parameters.get('replace_value', '')
|
||||
if not replace_value:
|
||||
return self.create_text_message('Invalid parameter replace_value')
|
||||
|
||||
|
||||
# get replace model
|
||||
replace_model = tool_parameters.get('replace_model', '')
|
||||
if not replace_model:
|
||||
return self.create_text_message('Invalid parameter replace_model')
|
||||
|
||||
ensure_ascii = tool_parameters.get('ensure_ascii', True)
|
||||
try:
|
||||
if replace_model == 'pattern':
|
||||
# get replace pattern
|
||||
replace_pattern = tool_parameters.get('replace_pattern', '')
|
||||
if not replace_pattern:
|
||||
return self.create_text_message('Invalid parameter replace_pattern')
|
||||
result = self._replace_pattern(content, query, replace_pattern, replace_value)
|
||||
result = self._replace_pattern(content, query, replace_pattern, replace_value, ensure_ascii)
|
||||
elif replace_model == 'key':
|
||||
result = self._replace_key(content, query, replace_value)
|
||||
result = self._replace_key(content, query, replace_value, ensure_ascii)
|
||||
elif replace_model == 'value':
|
||||
result = self._replace_value(content, query, replace_value)
|
||||
result = self._replace_value(content, query, replace_value, ensure_ascii)
|
||||
return self.create_text_message(str(result))
|
||||
except Exception:
|
||||
return self.create_text_message('Failed to replace JSON content')
|
||||
|
||||
# Replace pattern
|
||||
def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str) -> str:
|
||||
def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool) -> str:
|
||||
try:
|
||||
input_data = json.loads(content)
|
||||
expr = parse(query)
|
||||
|
||||
|
||||
matches = expr.find(input_data)
|
||||
|
||||
|
||||
for match in matches:
|
||||
new_value = match.value.replace(replace_pattern, replace_value)
|
||||
match.full_path.update(input_data, new_value)
|
||||
|
||||
return json.dumps(input_data, ensure_ascii=True)
|
||||
|
||||
return json.dumps(input_data, ensure_ascii=ensure_ascii)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
|
||||
|
||||
# Replace key
|
||||
def _replace_key(self, content: str, query: str, replace_value: str) -> str:
|
||||
def _replace_key(self, content: str, query: str, replace_value: str, ensure_ascii: bool) -> str:
|
||||
try:
|
||||
input_data = json.loads(content)
|
||||
expr = parse(query)
|
||||
|
||||
|
||||
matches = expr.find(input_data)
|
||||
|
||||
|
||||
for match in matches:
|
||||
parent = match.context.value
|
||||
if isinstance(parent, dict):
|
||||
@ -86,21 +87,21 @@ class JSONReplaceTool(BuiltinTool):
|
||||
if isinstance(item, dict) and old_key in item:
|
||||
value = item.pop(old_key)
|
||||
item[replace_value] = value
|
||||
return json.dumps(input_data, ensure_ascii=True)
|
||||
return json.dumps(input_data, ensure_ascii=ensure_ascii)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
|
||||
|
||||
# Replace value
|
||||
def _replace_value(self, content: str, query: str, replace_value: str) -> str:
|
||||
def _replace_value(self, content: str, query: str, replace_value: str, ensure_ascii: bool) -> str:
|
||||
try:
|
||||
input_data = json.loads(content)
|
||||
expr = parse(query)
|
||||
|
||||
|
||||
matches = expr.find(input_data)
|
||||
|
||||
|
||||
for match in matches:
|
||||
match.full_path.update(input_data, replace_value)
|
||||
|
||||
return json.dumps(input_data, ensure_ascii=True)
|
||||
|
||||
return json.dumps(input_data, ensure_ascii=ensure_ascii)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
return str(e)
|
||||
|
||||
@ -93,3 +93,15 @@ parameters:
|
||||
zh_Hans: 字符串替换
|
||||
pt_BR: replace string
|
||||
form: form
|
||||
- name: ensure_ascii
|
||||
type: boolean
|
||||
default: true
|
||||
label:
|
||||
en_US: Ensure ASCII
|
||||
zh_Hans: 确保 ASCII
|
||||
pt_BR: Ensure ASCII
|
||||
human_description:
|
||||
en_US: Ensure the JSON output is ASCII encoded
|
||||
zh_Hans: 确保输出的 JSON 是 ASCII 编码
|
||||
pt_BR: Ensure the JSON output is ASCII encoded
|
||||
form: form
|
||||
|
||||
1
api/core/tools/provider/builtin/spider/_assets/icon.svg
Normal file
1
api/core/tools/provider/builtin/spider/_assets/icon.svg
Normal file
@ -0,0 +1 @@
|
||||
<svg height="30" width="30" viewBox="0 0 36 34" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" class="fill-accent-foreground transition-all group-hover:scale-110"><title>Spider v1 Logo</title><path fill-rule="evenodd" clip-rule="evenodd" d="M9.13883 7.06589V0.164429L13.0938 0.164429V6.175L14.5178 7.4346C15.577 6.68656 16.7337 6.27495 17.945 6.27495C19.1731 6.27495 20.3451 6.69807 21.4163 7.46593L22.8757 6.175V0.164429L26.8307 0.164429V7.06589V7.95679L26.1634 8.54706L24.0775 10.3922C24.3436 10.8108 24.5958 11.2563 24.8327 11.7262L26.0467 11.4215L28.6971 8.08749L31.793 10.5487L28.7257 14.407L28.3089 14.9313L27.6592 15.0944L26.2418 15.4502C26.3124 15.7082 26.3793 15.9701 26.4422 16.2355L28.653 16.6566L29.092 16.7402L29.4524 17.0045L35.3849 21.355L33.0461 24.5444L27.474 20.4581L27.0719 20.3816C27.1214 21.0613 27.147 21.7543 27.147 22.4577C27.147 22.5398 27.1466 22.6214 27.1459 22.7024L29.5889 23.7911L30.3219 24.1177L30.62 24.8629L33.6873 32.5312L30.0152 34L27.246 27.0769L26.7298 26.8469C25.5612 32.2432 22.0701 33.8808 17.945 33.8808C13.8382 33.8808 10.3598 32.2577 9.17593 26.9185L8.82034 27.0769L6.05109 34L2.37897 32.5312L5.44629 24.8629L5.74435 24.1177L6.47743 23.7911L8.74487 22.7806C8.74366 22.6739 8.74305 22.5663 8.74305 22.4577C8.74305 21.7616 8.76804 21.0758 8.81654 20.4028L8.52606 20.4581L2.95395 24.5444L0.615112 21.355L6.54761 17.0045L6.908 16.7402L7.34701 16.6566L9.44264 16.2575C9.50917 15.9756 9.5801 15.6978 9.65528 15.4242L8.34123 15.0944L7.69155 14.9313L7.27471 14.407L4.20739 10.5487L7.30328 8.08749L9.95376 11.4215L11.0697 11.7016C11.3115 11.2239 11.5692 10.7716 11.8412 10.3473L9.80612 8.54706L9.13883 7.95679V7.06589Z"></path></svg>
|
||||
|
After Width: | Height: | Size: 1.6 KiB |
14
api/core/tools/provider/builtin/spider/spider.py
Normal file
14
api/core/tools/provider/builtin/spider/spider.py
Normal file
@ -0,0 +1,14 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.spider.spiderApp import Spider
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class SpiderProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
app = Spider(api_key=credentials["spider_api_key"])
|
||||
app.scrape_url(url="https://spider.cloud")
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
27
api/core/tools/provider/builtin/spider/spider.yaml
Normal file
27
api/core/tools/provider/builtin/spider/spider.yaml
Normal file
@ -0,0 +1,27 @@
|
||||
identity:
|
||||
author: William Espegren
|
||||
name: spider
|
||||
label:
|
||||
en_US: Spider
|
||||
zh_CN: Spider
|
||||
description:
|
||||
en_US: Spider API integration, returning LLM-ready data by scraping & crawling websites.
|
||||
zh_CN: Spider API 集成,通过爬取和抓取网站返回 LLM-ready 数据。
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
- utilities
|
||||
credentials_for_provider:
|
||||
spider_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Spider API Key
|
||||
zh_CN: Spider API 密钥
|
||||
placeholder:
|
||||
en_US: Please input your Spider API key
|
||||
zh_CN: 请输入您的 Spider API 密钥
|
||||
help:
|
||||
en_US: Get your Spider API key from your Spider dashboard
|
||||
zh_CN: 从您的 Spider 仪表板中获取 Spider API 密钥。
|
||||
url: https://spider.cloud/
|
||||
237
api/core/tools/provider/builtin/spider/spiderApp.py
Normal file
237
api/core/tools/provider/builtin/spider/spiderApp.py
Normal file
@ -0,0 +1,237 @@
|
||||
import os
|
||||
from typing import Literal, Optional, TypedDict
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class RequestParamsDict(TypedDict, total=False):
    """Optional request parameters accepted by the Spider API endpoints."""
    url: Optional[str]
    request: Optional[Literal["http", "chrome", "smart"]]
    limit: Optional[int]
    return_format: Optional[Literal["raw", "markdown", "html2text", "text", "bytes"]]
    tld: Optional[bool]
    depth: Optional[int]
    cache: Optional[bool]
    budget: Optional[dict[str, int]]
    locale: Optional[str]
    cookies: Optional[str]
    stealth: Optional[bool]
    headers: Optional[dict[str, str]]
    anti_bot: Optional[bool]
    metadata: Optional[bool]
    viewport: Optional[dict[str, int]]
    encoding: Optional[str]
    subdomains: Optional[bool]
    user_agent: Optional[str]
    store_data: Optional[bool]
    gpt_config: Optional[list[str]]
    fingerprint: Optional[bool]
    storageless: Optional[bool]
    readability: Optional[bool]
    proxy_enabled: Optional[bool]
    respect_robots: Optional[bool]
    query_selector: Optional[str]
    full_resources: Optional[bool]
    request_timeout: Optional[int]
    run_in_background: Optional[bool]
    skip_config_checks: Optional[bool]


class Spider:
    """Minimal HTTP client for the Spider (spider.cloud) scraping API."""

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the Spider with an API key.

        :param api_key: A string of the API key for Spider. Defaults to the SPIDER_API_KEY environment variable.
        :raises ValueError: If no API key is provided.
        """
        self.api_key = api_key or os.getenv("SPIDER_API_KEY")
        if self.api_key is None:
            raise ValueError("No API key provided")

    def api_post(
        self,
        endpoint: str,
        data: dict,
        stream: bool,
        content_type: str = "application/json",
    ):
        """
        Send a POST request to the specified API endpoint.

        :param endpoint: The API endpoint to which the POST request is sent.
        :param data: The data (dictionary) to be sent in the POST request.
        :param stream: Boolean indicating if the response should be streamed.
        :param content_type: Content-Type header value for the request.
        :return: The JSON response or the raw response stream if stream is True.
        """
        headers = self._prepare_headers(content_type)
        response = self._post_request(
            f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream
        )

        if stream:
            return response
        elif response.status_code == 200:
            return response.json()
        else:
            self._handle_error(response, f"post to {endpoint}")

    def api_get(
        self, endpoint: str, stream: bool, content_type: str = "application/json"
    ):
        """
        Send a GET request to the specified endpoint.

        :param endpoint: The API endpoint from which to retrieve data.
        :param stream: Boolean indicating if the response should be streamed.
        :param content_type: Content-Type header value for the request.
        :return: The JSON decoded response.
        """
        headers = self._prepare_headers(content_type)
        response = self._get_request(
            f"https://api.spider.cloud/v1/{endpoint}", headers, stream
        )
        if response.status_code == 200:
            return response.json()
        else:
            self._handle_error(response, f"get from {endpoint}")

    def get_credits(self):
        """
        Retrieve the account's remaining credits.

        :return: JSON response containing the number of credits left.
        """
        return self.api_get("credits", stream=False)

    def scrape_url(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Scrape data from the specified URL (single page: the crawl limit is forced to 1).

        :param url: The URL from which to scrape data.
        :param params: Optional dictionary of additional parameters for the scrape request.
        :param stream: Boolean indicating if the response should be streamed. Defaults to False.
        :param content_type: Content-Type header value for the request.
        :return: JSON response containing the scraping results.
        """
        # Copy to avoid mutating the caller's dict; also fixes the crash the
        # original code had when params was None ("in" test on None).
        params = dict(params) if params else {}

        # Default to markdown output if the caller did not choose a format.
        params.setdefault("return_format", "markdown")

        # A scrape is a crawl restricted to a single page.
        params["limit"] = 1

        return self.api_post(
            "crawl", {"url": url, **params}, stream, content_type
        )

    def crawl_url(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Start crawling at the specified URL.

        :param url: The URL to begin crawling.
        :param params: Optional dictionary with additional parameters to customize the crawl.
        :param stream: Boolean indicating if the response should be streamed. Defaults to False.
        :param content_type: Content-Type header value for the request.
        :return: JSON response or the raw response stream if streaming enabled.
        """
        # Copy to avoid mutating the caller's dict; also fixes the crash the
        # original code had when params was None ("in" test on None).
        params = dict(params) if params else {}

        # Default to markdown output if the caller did not choose a format.
        params.setdefault("return_format", "markdown")

        return self.api_post(
            "crawl", {"url": url, **params}, stream, content_type
        )

    def links(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Retrieve links from the specified URL.

        :param url: The URL from which to extract links.
        :param params: Optional parameters for the link retrieval request.
        :param stream: Boolean indicating if the response should be streamed. Defaults to False.
        :param content_type: Content-Type header value for the request.
        :return: JSON response containing the links.
        """
        return self.api_post(
            "links", {"url": url, **(params or {})}, stream, content_type
        )

    def extract_contacts(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Extract contact information from the specified URL.

        :param url: The URL from which to extract contact information.
        :param params: Optional parameters for the contact extraction.
        :param stream: Boolean indicating if the response should be streamed. Defaults to False.
        :param content_type: Content-Type header value for the request.
        :return: JSON response containing extracted contact details.
        """
        return self.api_post(
            "pipeline/extract-contacts",
            {"url": url, **(params or {})},
            stream,
            content_type,
        )

    def label(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Apply labeling to data extracted from the specified URL.

        :param url: The URL to label data from.
        :param params: Optional parameters to guide the labeling process.
        :param stream: Boolean indicating if the response should be streamed. Defaults to False.
        :param content_type: Content-Type header value for the request.
        :return: JSON response with labeled data.
        """
        return self.api_post(
            "pipeline/label", {"url": url, **(params or {})}, stream, content_type
        )

    def _prepare_headers(self, content_type: str = "application/json"):
        """Build the standard headers (auth + content type) for every request."""
        return {
            "Content-Type": content_type,
            "Authorization": f"Bearer {self.api_key}",
            "User-Agent": "Spider-Client/0.0.27",
        }

    def _post_request(self, url: str, data, headers, stream=False):
        """Thin wrapper around requests.post, kept separate for testability."""
        return requests.post(url, headers=headers, json=data, stream=stream)

    def _get_request(self, url: str, headers, stream=False):
        """Thin wrapper around requests.get, kept separate for testability."""
        return requests.get(url, headers=headers, stream=stream)

    def _delete_request(self, url: str, headers, stream=False):
        """Thin wrapper around requests.delete, kept separate for testability."""
        return requests.delete(url, headers=headers, stream=stream)

    def _handle_error(self, response, action):
        """
        Raise a descriptive exception for a failed API call.

        :param response: The failed requests response object.
        :param action: Human-readable description of the attempted action.
        :raises Exception: Always raises; message depends on the status code.
        """
        if response.status_code in [402, 409, 500]:
            error_message = response.json().get("error", "Unknown error occurred")
            raise Exception(
                f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}"
            )
        else:
            raise Exception(
                f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}"
            )
|
||||
@ -0,0 +1,47 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.spider.spiderApp import Spider
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class ScrapeTool(BuiltinTool):
    """Builtin tool that scrapes a single page or crawls a site via the Spider API."""

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke the Spider scraper/crawler.

        :param user_id: ID of the invoking user (unused here, required by the Tool interface).
        :param tool_parameters: expects 'url' and 'mode' ('scrape' or 'crawl');
            optional 'limit', 'depth', 'blacklist', 'whitelist', 'readability'.
        :return: a text message with "URL:/CONTENT:" pairs, or an error message.
        """
        # initialize the app object with the api key
        app = Spider(api_key=self.runtime.credentials['spider_api_key'])

        url = tool_parameters['url']
        mode = tool_parameters['mode']

        # blacklist/whitelist arrive as comma-separated strings; split into lists.
        options = {
            'limit': tool_parameters.get('limit', 0),
            'depth': tool_parameters.get('depth', 0),
            'blacklist': tool_parameters.get('blacklist', '').split(',') if tool_parameters.get('blacklist') else [],
            'whitelist': tool_parameters.get('whitelist', '').split(',') if tool_parameters.get('whitelist') else [],
            'readability': tool_parameters.get('readability', False),
        }

        result = ""

        try:
            if mode == 'scrape':
                pages = app.scrape_url(url=url, params=options)
            elif mode == 'crawl':
                pages = app.crawl_url(url=url, params=options)
            else:
                # Previously an unknown mode silently returned an empty message.
                return self.create_text_message(f"Invalid mode: {mode}")

            for page in pages:
                result += "URL: " + page.get('url', '') + "\n"
                result += "CONTENT: " + page.get('content', '') + "\n\n"
        except Exception as e:
            # Fix: create_text_message takes a single text argument; the old
            # two-argument call raised TypeError inside the error handler
            # (and misspelled "occurred").
            return self.create_text_message(f"An error occurred: {e}")

        return self.create_text_message(result)
|
||||
@ -0,0 +1,102 @@
|
||||
identity:
|
||||
name: scraper_crawler
|
||||
author: William Espegren
|
||||
label:
|
||||
en_US: Web Scraper & Crawler
|
||||
zh_Hans: 网页抓取与爬虫
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for scraping & crawling webpages. Input should be a url.
|
||||
zh_Hans: 用于抓取和爬取网页的工具。输入应该是一个网址。
|
||||
llm: A tool for scraping & crawling webpages. Input should be a url.
|
||||
parameters:
|
||||
- name: url
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: URL
|
||||
zh_Hans: 网址
|
||||
human_description:
|
||||
en_US: url to be scraped or crawled
|
||||
zh_Hans: 要抓取或爬取的网址
|
||||
llm_description: url to either be scraped or crawled
|
||||
form: llm
|
||||
- name: mode
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- value: scrape
|
||||
label:
|
||||
en_US: scrape
|
||||
zh_Hans: 抓取
|
||||
- value: crawl
|
||||
label:
|
||||
en_US: crawl
|
||||
zh_Hans: 爬取
|
||||
default: crawl
|
||||
label:
|
||||
en_US: Mode
|
||||
zh_Hans: 模式
|
||||
human_description:
|
||||
en_US: used for selecting to either scrape the website or crawl the entire website following subpages
|
||||
zh_Hans: 用于选择抓取网站或爬取整个网站及其子页面
|
||||
form: form
|
||||
- name: limit
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: maximum number of pages to crawl
|
||||
zh_Hans: 最大爬取页面数
|
||||
human_description:
|
||||
en_US: specify the maximum number of pages to crawl per website. the crawler will stop after reaching this limit.
|
||||
zh_Hans: 指定每个网站要爬取的最大页面数。爬虫将在达到此限制后停止。
|
||||
form: form
|
||||
min: 0
|
||||
default: 0
|
||||
- name: depth
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: maximum depth of pages to crawl
|
||||
zh_Hans: 最大爬取深度
|
||||
human_description:
|
||||
en_US: the crawl limit for maximum depth.
|
||||
zh_Hans: 最大爬取深度的限制。
|
||||
form: form
|
||||
min: 0
|
||||
default: 0
|
||||
- name: blacklist
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: url patterns to exclude
|
||||
zh_Hans: 要排除的URL模式
|
||||
human_description:
|
||||
en_US: blacklist a set of paths that you do not want to crawl. you can use regex patterns to help with the list.
|
||||
zh_Hans: 指定一组不想爬取的路径。您可以使用正则表达式模式来帮助定义列表。
|
||||
placeholder:
|
||||
en_US: /blog/*, /about
|
||||
form: form
|
||||
- name: whitelist
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: URL patterns to include
|
||||
zh_Hans: 要包含的URL模式
|
||||
human_description:
|
||||
en_US: Whitelist a set of paths that you want to crawl, ignoring all other routes that do not match the patterns. You can use regex patterns to help with the list.
|
||||
zh_Hans: 指定一组要爬取的路径,忽略所有不匹配模式的其他路由。您可以使用正则表达式模式来帮助定义列表。
|
||||
placeholder:
|
||||
en_US: /blog/*, /about
|
||||
form: form
|
||||
- name: readability
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Pre-process the content for LLM usage
|
||||
      zh_Hans: 为 LLM 使用预处理内容
|
||||
human_description:
|
||||
en_US: Use Mozilla's readability to pre-process the content for reading. This may drastically improve the content for LLM usage.
|
||||
zh_Hans: 如果启用,爬虫将仅返回页面的主要内容,不包括标题、导航、页脚等。
|
||||
form: form
|
||||
default: false
|
||||
21
api/core/tools/provider/builtin/tianditu/_assets/icon.svg
Normal file
21
api/core/tools/provider/builtin/tianditu/_assets/icon.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 23 KiB |
21
api/core/tools/provider/builtin/tianditu/tianditu.py
Normal file
21
api/core/tools/provider/builtin/tianditu/tianditu.py
Normal file
@ -0,0 +1,21 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.tianditu.tools.poisearch import PoiSearchTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class TiandituProvider(BuiltinToolProviderController):
    """Tool provider for the Tianditu (天地图) map services."""

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """
        Validate the Tianditu API key by performing a sample POI search.

        :param credentials: must contain 'tianditu_api_key'.
        :raises ToolProviderCredentialValidationError: if the sample invocation fails.
        """
        try:
            PoiSearchTool().fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                }
            ).invoke(user_id='',
                     # Fix: PoiSearchTool._invoke reads 'keyword' and
                     # 'baseAddress'; the previous 'content'/'specify' keys
                     # were ignored, so the tool returned an "Invalid
                     # parameter" message without ever exercising the API key.
                     tool_parameters={
                         'keyword': '北京大学',
                         'baseAddress': '北京',
                     })
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
|
||||
32
api/core/tools/provider/builtin/tianditu/tianditu.yaml
Normal file
32
api/core/tools/provider/builtin/tianditu/tianditu.yaml
Normal file
@ -0,0 +1,32 @@
|
||||
identity:
|
||||
author: Listeng
|
||||
name: tianditu
|
||||
label:
|
||||
en_US: Tianditu
|
||||
zh_Hans: 天地图
|
||||
pt_BR: Tianditu
|
||||
description:
|
||||
en_US: The Tianditu tool provided the functions of place name search, geocoding, static maps generation, etc. in China region.
|
||||
zh_Hans: 天地图工具可以调用天地图的接口,实现中国区域内的地名搜索、地理编码、静态地图等功能。
|
||||
pt_BR: The Tianditu tool provided the functions of place name search, geocoding, static maps generation, etc. in China region.
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- utilities
|
||||
- travel
|
||||
credentials_for_provider:
|
||||
tianditu_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Tianditu API Key
|
||||
zh_Hans: 天地图Key
|
||||
pt_BR: Tianditu API key
|
||||
placeholder:
|
||||
en_US: Please input your Tianditu API key
|
||||
zh_Hans: 请输入你的天地图Key
|
||||
pt_BR: Please input your Tianditu API key
|
||||
help:
|
||||
en_US: Get your Tianditu API key from Tianditu
|
||||
zh_Hans: 获取您的天地图Key
|
||||
pt_BR: Get your Tianditu API key from Tianditu
|
||||
url: http://lbs.tianditu.gov.cn/home.html
|
||||
33
api/core/tools/provider/builtin/tianditu/tools/geocoder.py
Normal file
33
api/core/tools/provider/builtin/tianditu/tools/geocoder.py
Normal file
@ -0,0 +1,33 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class GeocoderTool(BuiltinTool):
    """Tianditu forward-geocoding tool: resolve a Chinese place name to coordinates."""

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke the Tianditu geocoder.

        :param user_id: ID of the invoking user (unused, required by the Tool interface).
        :param tool_parameters: expects 'keyword', the place name to geocode.
        :return: JSON message with the geocoder response, or a text error message.
        """
        base_url = 'http://api.tianditu.gov.cn/geocoder'

        keyword = tool_parameters.get('keyword', '')
        if not keyword:
            return self.create_text_message('Invalid parameter keyword')

        tk = self.runtime.credentials['tianditu_api_key']

        params = {
            'keyWord': keyword,
        }

        # Let requests build the query string so the JSON payload (which may
        # contain non-ASCII place names) is URL-encoded; the previous manual
        # concatenation sent raw unencoded characters. A timeout prevents the
        # tool from hanging on an unresponsive endpoint.
        result = requests.get(
            base_url,
            params={'ds': json.dumps(params, ensure_ascii=False), 'tk': tk},
            timeout=10,
        ).json()

        return self.create_json_message(result)
|
||||
26
api/core/tools/provider/builtin/tianditu/tools/geocoder.yaml
Normal file
26
api/core/tools/provider/builtin/tianditu/tools/geocoder.yaml
Normal file
@ -0,0 +1,26 @@
|
||||
identity:
|
||||
name: geocoder
|
||||
author: Listeng
|
||||
label:
|
||||
en_US: Get coords converted from address name
|
||||
zh_Hans: 地理编码
|
||||
pt_BR: Get coords converted from address name
|
||||
description:
|
||||
human:
|
||||
en_US: Geocoder
|
||||
zh_Hans: 中国区域地理编码查询
|
||||
pt_BR: Geocoder
|
||||
llm: A tool for geocoder in China
|
||||
parameters:
|
||||
- name: keyword
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: keyword
|
||||
zh_Hans: 搜索的关键字
|
||||
pt_BR: keyword
|
||||
human_description:
|
||||
en_US: keyword
|
||||
zh_Hans: 搜索的关键字
|
||||
pt_BR: keyword
|
||||
form: llm
|
||||
45
api/core/tools/provider/builtin/tianditu/tools/poisearch.py
Normal file
45
api/core/tools/provider/builtin/tianditu/tools/poisearch.py
Normal file
@ -0,0 +1,45 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class PoiSearchTool(BuiltinTool):
    """Search points of interest within 5 km of a base address via Tianditu."""

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke the Tianditu POI search.

        :param user_id: ID of the invoking user (unused, required by the Tool interface).
        :param tool_parameters: expects 'keyword' (POI type to search for) and
            'baseAddress' (place name used as the search center).
        :return: JSON message with the search response, or a text error message.
        """
        geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder'
        base_url = 'http://api.tianditu.gov.cn/v2/search'

        keyword = tool_parameters.get('keyword', '')
        if not keyword:
            return self.create_text_message('Invalid parameter keyword')

        baseAddress = tool_parameters.get('baseAddress', '')
        if not baseAddress:
            return self.create_text_message('Invalid parameter baseAddress')

        tk = self.runtime.credentials['tianditu_api_key']

        # Geocode the base address first to get the search center. requests'
        # params= encodes the non-ASCII JSON payload; timeout avoids hangs.
        base_coords = requests.get(
            geocoder_base_url,
            params={'ds': json.dumps({'keyWord': baseAddress}, ensure_ascii=False), 'tk': tk},
            timeout=10,
        ).json()

        # assumes the geocoder payload carries a top-level 'location' with
        # 'lon'/'lat' — TODO confirm against the Tianditu API docs.
        location = base_coords['location']
        # str() guards against numeric lon/lat values, which would make the
        # original bare '+' string concatenation raise TypeError.
        point = str(location['lon']) + ',' + str(location['lat'])

        params = {
            'keyWord': keyword,
            'queryRadius': 5000,
            'queryType': 3,
            'pointLonlat': point,
            'start': 0,
            'count': 100,
        }

        result = requests.get(
            base_url,
            params={'postStr': json.dumps(params, ensure_ascii=False), 'type': 'query', 'tk': tk},
            timeout=10,
        ).json()

        return self.create_json_message(result)
|
||||
@ -0,0 +1,38 @@
|
||||
identity:
|
||||
name: point_of_interest_search
|
||||
author: Listeng
|
||||
label:
|
||||
en_US: Point of Interest search
|
||||
zh_Hans: 兴趣点搜索
|
||||
pt_BR: Point of Interest search
|
||||
description:
|
||||
human:
|
||||
en_US: Search for certain types of points of interest around a location
|
||||
zh_Hans: 搜索某个位置周边的5公里内某种类型的兴趣点
|
||||
pt_BR: Search for certain types of points of interest around a location
|
||||
llm: A tool for searching for certain types of points of interest around a location
|
||||
parameters:
|
||||
- name: keyword
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: poi keyword
|
||||
zh_Hans: 兴趣点的关键字
|
||||
pt_BR: poi keyword
|
||||
human_description:
|
||||
en_US: poi keyword
|
||||
zh_Hans: 兴趣点的关键字
|
||||
pt_BR: poi keyword
|
||||
form: llm
|
||||
- name: baseAddress
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: base current point
|
||||
zh_Hans: 当前位置的关键字
|
||||
pt_BR: base current point
|
||||
human_description:
|
||||
en_US: base current point
|
||||
zh_Hans: 当前位置的关键字
|
||||
pt_BR: base current point
|
||||
form: llm
|
||||
36
api/core/tools/provider/builtin/tianditu/tools/staticmap.py
Normal file
36
api/core/tools/provider/builtin/tianditu/tools/staticmap.py
Normal file
@ -0,0 +1,36 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
# NOTE(review): this class name duplicates PoiSearchTool in poisearch.py and
# probably should be StaticMapTool — confirm how the builtin tool loader
# resolves classes before renaming.
class PoiSearchTool(BuiltinTool):
    """Generate a 400x300 static map image centered on a geocoded keyword."""

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke the Tianditu static-map generator.

        :param user_id: ID of the invoking user (unused, required by the Tool interface).
        :param tool_parameters: expects 'keyword', the place name to center the map on.
        :return: blob message with the PNG image, or a text error message.
        """
        geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder'
        base_url = 'http://api.tianditu.gov.cn/staticimage'

        keyword = tool_parameters.get('keyword', '')
        if not keyword:
            return self.create_text_message('Invalid parameter keyword')

        tk = self.runtime.credentials['tianditu_api_key']

        # Geocode the keyword to obtain map-center coordinates. requests'
        # params= encodes the non-ASCII JSON payload; timeout avoids hangs.
        keyword_coords = requests.get(
            geocoder_base_url,
            params={'ds': json.dumps({'keyWord': keyword}, ensure_ascii=False), 'tk': tk},
            timeout=10,
        ).json()

        # str() guards against numeric lon/lat values, which would make the
        # original bare '+' string concatenation raise TypeError.
        location = keyword_coords['location']
        coords = str(location['lon']) + ',' + str(location['lat'])

        image_bytes = requests.get(
            base_url,
            params={'center': coords, 'markers': coords, 'width': 400, 'height': 300, 'zoom': 14, 'tk': tk},
            timeout=10,
        ).content

        return self.create_blob_message(blob=image_bytes,
                                        meta={'mime_type': 'image/png'},
                                        save_as=self.VARIABLE_KEY.IMAGE.value)
|
||||
@ -0,0 +1,26 @@
|
||||
identity:
|
||||
name: generate_static_map
|
||||
author: Listeng
|
||||
label:
|
||||
en_US: Generate a static map
|
||||
zh_Hans: 生成静态地图
|
||||
pt_BR: Generate a static map
|
||||
description:
|
||||
human:
|
||||
en_US: Generate a static map
|
||||
zh_Hans: 生成静态地图
|
||||
pt_BR: Generate a static map
|
||||
llm: A tool for generate a static map
|
||||
parameters:
|
||||
- name: keyword
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: keyword
|
||||
zh_Hans: 搜索的关键字
|
||||
pt_BR: keyword
|
||||
human_description:
|
||||
en_US: keyword
|
||||
zh_Hans: 搜索的关键字
|
||||
pt_BR: keyword
|
||||
form: llm
|
||||
@ -14,7 +14,7 @@ from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
|
||||
@ -8,7 +8,7 @@ from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
@ -190,8 +191,9 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
return result
|
||||
|
||||
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
||||
def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]:
|
||||
# update tool_parameters
|
||||
# TODO: Fix type error.
|
||||
if self.runtime.runtime_parameters:
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
|
||||
@ -208,7 +210,7 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
return result
|
||||
|
||||
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Transform tool parameters type
|
||||
"""
|
||||
@ -241,7 +243,7 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
:return: the runtime parameters
|
||||
"""
|
||||
return self.parameters
|
||||
return self.parameters or []
|
||||
|
||||
def get_all_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from mimetypes import guess_type
|
||||
@ -46,7 +47,7 @@ class ToolEngine:
|
||||
if isinstance(tool_parameters, str):
|
||||
# check if this tool has only one parameter
|
||||
parameters = [
|
||||
parameter for parameter in tool.get_runtime_parameters()
|
||||
parameter for parameter in tool.get_runtime_parameters() or []
|
||||
if parameter.form == ToolParameter.ToolParameterForm.LLM
|
||||
]
|
||||
if parameters and len(parameters) == 1:
|
||||
@ -123,8 +124,8 @@ class ToolEngine:
|
||||
return error_response, [], ToolInvokeMeta.error_instance(error_response)
|
||||
|
||||
@staticmethod
|
||||
def workflow_invoke(tool: Tool, tool_parameters: dict,
|
||||
user_id: str, workflow_id: str,
|
||||
def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any],
|
||||
user_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int,
|
||||
) -> list[ToolInvokeMessage]:
|
||||
@ -141,7 +142,9 @@ class ToolEngine:
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
if tool.runtime and tool.runtime.runtime_parameters:
|
||||
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
|
||||
response = tool.invoke(user_id=user_id, tool_parameters=tool_parameters)
|
||||
|
||||
# hit the callback handler
|
||||
workflow_tool_callback.on_tool_end(
|
||||
|
||||
@ -9,9 +9,9 @@ from mimetypes import guess_extension, guess_type
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import current_app
|
||||
from httpx import get
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import MessageFile
|
||||
@ -26,25 +26,25 @@ class ToolFileManager:
|
||||
"""
|
||||
sign file to get a temporary url
|
||||
"""
|
||||
base_url = current_app.config.get('FILES_URL')
|
||||
base_url = dify_config.FILES_URL
|
||||
file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}'
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = current_app.config['SECRET_KEY'].encode()
|
||||
data_to_sign = f'file-preview|{tool_file_id}|{timestamp}|{nonce}'
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b''
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
return f'{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}'
|
||||
|
||||
@staticmethod
|
||||
def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
"""
|
||||
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||
secret_key = current_app.config['SECRET_KEY'].encode()
|
||||
data_to_sign = f'file-preview|{file_id}|{timestamp}|{nonce}'
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b''
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
@ -53,23 +53,23 @@ class ToolFileManager:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= current_app.config.get('FILES_ACCESS_TIMEOUT')
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_raw(user_id: str, tenant_id: str,
|
||||
conversation_id: Optional[str], file_binary: bytes,
|
||||
mimetype: str
|
||||
) -> ToolFile:
|
||||
def create_file_by_raw(
|
||||
user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
"""
|
||||
extension = guess_extension(mimetype) or '.bin'
|
||||
unique_name = uuid4().hex
|
||||
filename = f"tools/{tenant_id}/{unique_name}{extension}"
|
||||
filename = f'tools/{tenant_id}/{unique_name}{extension}'
|
||||
storage.save(filename, file_binary)
|
||||
|
||||
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
|
||||
conversation_id=conversation_id, file_key=filename, mimetype=mimetype)
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype
|
||||
)
|
||||
|
||||
db.session.add(tool_file)
|
||||
db.session.commit()
|
||||
@ -77,9 +77,12 @@ class ToolFileManager:
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_url(user_id: str, tenant_id: str,
|
||||
conversation_id: str, file_url: str,
|
||||
) -> ToolFile:
|
||||
def create_file_by_url(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str,
|
||||
file_url: str,
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
"""
|
||||
@ -90,12 +93,17 @@ class ToolFileManager:
|
||||
mimetype = guess_type(file_url)[0] or 'octet/stream'
|
||||
extension = guess_extension(mimetype) or '.bin'
|
||||
unique_name = uuid4().hex
|
||||
filename = f"tools/{tenant_id}/{unique_name}{extension}"
|
||||
filename = f'tools/{tenant_id}/{unique_name}{extension}'
|
||||
storage.save(filename, blob)
|
||||
|
||||
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
|
||||
conversation_id=conversation_id, file_key=filename,
|
||||
mimetype=mimetype, original_url=file_url)
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filename,
|
||||
mimetype=mimetype,
|
||||
original_url=file_url,
|
||||
)
|
||||
|
||||
db.session.add(tool_file)
|
||||
db.session.commit()
|
||||
@ -103,15 +111,15 @@ class ToolFileManager:
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_key(user_id: str, tenant_id: str,
|
||||
conversation_id: str, file_key: str,
|
||||
mimetype: str
|
||||
) -> ToolFile:
|
||||
def create_file_by_key(
|
||||
user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str
|
||||
) -> ToolFile:
|
||||
"""
|
||||
create file
|
||||
"""
|
||||
tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id,
|
||||
conversation_id=conversation_id, file_key=file_key, mimetype=mimetype)
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype
|
||||
)
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
@ -123,9 +131,13 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile = db.session.query(ToolFile).filter(
|
||||
ToolFile.id == id,
|
||||
).first()
|
||||
tool_file: ToolFile = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@ -143,18 +155,31 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile = db.session.query(MessageFile).filter(
|
||||
MessageFile.id == id,
|
||||
).first()
|
||||
message_file: MessageFile = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split('/')[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split('.')[0]
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split('/')[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split('.')[0]
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile = db.session.query(ToolFile).filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
).first()
|
||||
|
||||
tool_file: ToolFile = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@ -172,9 +197,13 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile = db.session.query(ToolFile).filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
).first()
|
||||
tool_file: ToolFile = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
|
||||
@ -6,8 +6,7 @@ from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import Any, Union
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
@ -565,7 +564,7 @@ class ToolManager:
|
||||
provider_type = provider_type
|
||||
provider_id = provider_id
|
||||
if provider_type == 'builtin':
|
||||
return (current_app.config.get("CONSOLE_API_URL")
|
||||
return (dify_config.CONSOLE_API_URL
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
+ provider_id
|
||||
+ "/icon")
|
||||
@ -574,7 +573,7 @@ class ToolManager:
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.id == provider_id
|
||||
)
|
||||
).first()
|
||||
return json.loads(provider.icon)
|
||||
except:
|
||||
return {
|
||||
@ -593,4 +592,4 @@ class ToolManager:
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
ToolManager.load_builtin_providers_cache()
|
||||
ToolManager.load_builtin_providers_cache()
|
||||
|
||||
@ -10,6 +10,7 @@ import unicodedata
|
||||
from contextlib import contextmanager
|
||||
from urllib.parse import unquote
|
||||
|
||||
import cloudscraper
|
||||
import requests
|
||||
from bs4 import BeautifulSoup, CData, Comment, NavigableString
|
||||
from newspaper import Article
|
||||
@ -46,29 +47,34 @@ def get_url(url: str, user_agent: str = None) -> str:
|
||||
supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"]
|
||||
response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10))
|
||||
|
||||
if response.status_code == 200:
|
||||
# check content-type
|
||||
content_type = response.headers.get('Content-Type')
|
||||
if content_type:
|
||||
main_content_type = response.headers.get('Content-Type').split(';')[0].strip()
|
||||
else:
|
||||
content_disposition = response.headers.get('Content-Disposition', '')
|
||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
||||
if filename_match:
|
||||
filename = unquote(filename_match.group(1))
|
||||
extension = re.search(r'\.(\w+)$', filename)
|
||||
if extension:
|
||||
main_content_type = mimetypes.guess_type(filename)[0]
|
||||
|
||||
if main_content_type not in supported_content_types:
|
||||
return "Unsupported content-type [{}] of URL.".format(main_content_type)
|
||||
|
||||
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
|
||||
return ExtractProcessor.load_from_url(url, return_text=True)
|
||||
|
||||
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(120, 300))
|
||||
elif response.status_code == 403:
|
||||
scraper = cloudscraper.create_scraper()
|
||||
response = scraper.get(url, headers=headers, allow_redirects=True, timeout=(120, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
return "URL returned status code {}.".format(response.status_code)
|
||||
|
||||
# check content-type
|
||||
content_type = response.headers.get('Content-Type')
|
||||
if content_type:
|
||||
main_content_type = response.headers.get('Content-Type').split(';')[0].strip()
|
||||
else:
|
||||
content_disposition = response.headers.get('Content-Disposition', '')
|
||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
||||
if filename_match:
|
||||
filename = unquote(filename_match.group(1))
|
||||
extension = re.search(r'\.(\w+)$', filename)
|
||||
if extension:
|
||||
main_content_type = mimetypes.guess_type(filename)[0]
|
||||
|
||||
if main_content_type not in supported_content_types:
|
||||
return "Unsupported content-type [{}] of URL.".format(main_content_type)
|
||||
|
||||
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
|
||||
return ExtractProcessor.load_from_url(url, return_text=True)
|
||||
|
||||
response = requests.get(url, headers=headers, allow_redirects=True, timeout=(120, 300))
|
||||
a = extract_using_readabilipy(response.text)
|
||||
|
||||
if not a['plain_text'] or not a['plain_text'].strip():
|
||||
|
||||
Reference in New Issue
Block a user