Merge main

This commit is contained in:
Yeuoly
2024-09-10 14:05:20 +08:00
650 changed files with 15950 additions and 4747 deletions

View File

@ -1,5 +1,6 @@
- google
- bing
- perplexity
- duckduckgo
- searchapi
- serper

View File

@ -3,6 +3,7 @@ import logging
from typing import Any, Union
import boto3
from botocore.exceptions import BotoCoreError
from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage
@ -16,7 +17,7 @@ class GuardrailParameters(BaseModel):
guardrail_version: str = Field(..., description="The version of the guardrail")
source: str = Field(..., description="The source of the content")
text: str = Field(..., description="The text to apply the guardrail to")
aws_region: str = Field(default="us-east-1", description="AWS region for the Bedrock client")
aws_region: str = Field(..., description="AWS region for the Bedrock client")
class ApplyGuardrailTool(BuiltinTool):
def _invoke(self,
@ -40,6 +41,8 @@ class ApplyGuardrailTool(BuiltinTool):
source=params.source,
content=[{"text": {"text": params.text}}]
)
logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}")
# Check for empty response
if not response:
@ -69,7 +72,7 @@ class ApplyGuardrailTool(BuiltinTool):
return self.create_text_message(text=result)
except boto3.exceptions.BotoCoreError as e:
except BotoCoreError as e:
error_message = f'AWS service error: {str(e)}'
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
@ -80,4 +83,4 @@ class ApplyGuardrailTool(BuiltinTool):
except Exception as e:
error_message = f'An unexpected error occurred: {str(e)}'
logger.error(error_message, exc_info=True)
return self.create_text_message(text=error_message)
return self.create_text_message(text=error_message)

View File

@ -54,3 +54,14 @@ parameters:
zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。
llm_description: The content used for requesting guardrail review, which can be either user input or LLM output.
form: llm
- name: aws_region
type: string
required: true
label:
en_US: AWS Region
zh_Hans: AWS 区域
human_description:
en_US: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
zh_Hans: 请输入 Bedrock 客户端的 AWS 区域,例如 'us-east-1'。
llm_description: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
form: form

View File

@ -0,0 +1,71 @@
import json
import logging
from typing import Any, Union
import boto3
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
console_handler = logging.StreamHandler()
logger.addHandler(console_handler)
class LambdaYamlToJsonTool(BuiltinTool):
    """Built-in tool that converts YAML text to JSON by invoking an AWS Lambda function."""

    # Lazily-created boto3 Lambda client; cached on the tool instance after first use.
    lambda_client: Any = None

    def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
        """Synchronously invoke ``lambda_name`` with the YAML payload and return the response body.

        Raises:
            Exception: if the Lambda response's ``statusCode`` is not 200.
        """
        msg = {
            "body": yaml_content
        }
        logger.info(json.dumps(msg))

        # RequestResponse = synchronous invocation; the converted JSON comes back in the payload.
        invoke_response = self.lambda_client.invoke(FunctionName=lambda_name,
                                                    InvocationType='RequestResponse',
                                                    Payload=json.dumps(msg))
        response_body = invoke_response['Payload']
        response_str = response_body.read().decode("utf-8")
        resp_json = json.loads(response_str)
        logger.info(resp_json)

        if resp_json['statusCode'] != 200:
            raise Exception(f"Invalid status code: {response_str}")

        return resp_json['body']

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Validate parameters, lazily build the Lambda client, and run the YAML-to-JSON conversion.

        Any failure is reported back to the caller as a text message rather than raised.
        """
        try:
            if not self.lambda_client:
                aws_region = tool_parameters.get('aws_region')  # todo: move aws_region out, and update client region
                if aws_region:
                    self.lambda_client = boto3.client("lambda", region_name=aws_region)
                else:
                    # Fall back to the default region from the environment/credential chain.
                    self.lambda_client = boto3.client("lambda")

            yaml_content = tool_parameters.get('yaml_content', '')
            if not yaml_content:
                return self.create_text_message('Please input yaml_content')

            lambda_name = tool_parameters.get('lambda_name', '')
            if not lambda_name:
                return self.create_text_message('Please input lambda_name')

            logger.debug(f'{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}')

            result = self._invoke_lambda(lambda_name, yaml_content)
            logger.debug(result)

            return self.create_text_message(result)
        except Exception as e:
            return self.create_text_message(f'Exception: {str(e)}')

        # NOTE(review): unreachable — both the try path and the except path return above.
        console_handler.flush()

View File

@ -0,0 +1,53 @@
# Tool manifest for the LambdaYamlToJson built-in tool (see lambda_yaml_to_json.py).
identity:
  name: lambda_yaml_to_json
  author: AWS
  label:
    en_US: LambdaYamlToJson
    zh_Hans: LambdaYamlToJson
    pt_BR: LambdaYamlToJson
  icon: icon.svg
description:
  human:
    en_US: A tool to convert yaml to json using AWS Lambda.
    zh_Hans: 将 YAML 转为 JSON 的工具通过AWS Lambda
    pt_BR: A tool to convert yaml to json using AWS Lambda.
  llm: A tool to convert yaml to json.
parameters:
  - name: yaml_content
    type: string
    required: true
    label:
      en_US: YAML content to convert for
      zh_Hans: YAML 内容
      pt_BR: YAML content to convert for
    human_description:
      en_US: YAML content to convert for
      zh_Hans: YAML 内容
      pt_BR: YAML content to convert for
    llm_description: YAML content to convert for
    form: llm
  - name: aws_region
    type: string
    required: false
    label:
      en_US: region of lambda
      zh_Hans: Lambda 所在的region
      pt_BR: region of lambda
    human_description:
      en_US: region of lambda
      zh_Hans: Lambda 所在的region
      pt_BR: region of lambda
    llm_description: region of lambda
    form: form
  # NOTE(review): declared optional, but lambda_yaml_to_json.py returns an error when it is
  # missing — confirm whether required should be true. It also has no llm_description.
  - name: lambda_name
    type: string
    required: false
    label:
      en_US: name of lambda
      zh_Hans: Lambda 名称
      pt_BR: name of lambda
    human_description:
      en_US: name of lambda
      zh_Hans: Lambda 名称
      pt_BR: name of lambda
    form: form

View File

@ -78,9 +78,7 @@ class SageMakerReRankTool(BuiltinTool):
sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
line = 9
results_str = json.dumps(sorted_candidate_docs[:self.topk], ensure_ascii=False)
return self.create_text_message(text=results_str)
return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ]
except Exception as e:
return self.create_text_message(f'Exception {str(e)}, line : {line}')
return self.create_text_message(f'Exception {str(e)}, line : {line}')

View File

@ -0,0 +1,95 @@
import json
from enum import Enum
from typing import Any, Union
import boto3
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class TTSModelType(Enum):
    """Inference modes accepted by the TTS SageMaker endpoint (payload shapes in _build_tts_payload)."""

    PresetVoice = "PresetVoice"                          # requires a built-in voice ("role")
    CloneVoice = "CloneVoice"                            # requires reference audio plus its transcript
    CloneVoice_CrossLingual = "CloneVoice_CrossLingual"  # requires reference audio; language tag is auto-detected
    InstructVoice = "InstructVoice"                      # requires a role plus an instruction text
class SageMakerTTSTool(BuiltinTool):
    """Built-in tool that synthesizes speech through a SageMaker endpoint.

    Returns the presigned S3 URL of the generated audio as a text message.
    """

    # boto3 clients and the endpoint name are resolved lazily on first _invoke
    # and cached on the instance.
    sagemaker_client: Any = None
    sagemaker_endpoint: str = None
    s3_client: Any = None
    comprehend_client: Any = None

    def _detect_lang_code(self, content: str, map_dict: dict = None):
        """Detect the dominant language of ``content`` via AWS Comprehend and map it
        to the endpoint's language tag; falls back to '<|zh|>' for unmapped languages.
        """
        # NOTE(review): the map_dict parameter is immediately overwritten below and
        # therefore has no effect — confirm whether callers were meant to override it.
        map_dict = {
            "zh": "<|zh|>",
            "en": "<|en|>",
            "ja": "<|jp|>",
            "zh-TW": "<|yue|>",
            "ko": "<|ko|>"
        }

        response = self.comprehend_client.detect_dominant_language(Text=content)
        language_code = response['Languages'][0]['LanguageCode']
        return map_dict.get(language_code, '<|zh|>')

    def _build_tts_payload(self, model_type: str, content_text: str, model_role: str,
                           prompt_text: str, prompt_audio: str, instruct_text: str):
        """Build the request payload matching the chosen inference mode.

        Raises:
            RuntimeError: when the parameters required by ``model_type`` are missing.
        """
        if model_type == TTSModelType.PresetVoice.value and model_role:
            return {"tts_text": content_text, "role": model_role}
        if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
            return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
        if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
            # Cross-lingual cloning also needs a language tag inferred from the text itself.
            lang_tag = self._detect_lang_code(content_text)
            return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag}
        if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
            return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}

        raise RuntimeError(f"Invalid params for {model_type}")

    def _invoke_sagemaker(self, payload: dict, endpoint: str):
        """POST ``payload`` as JSON to the SageMaker endpoint and return the decoded JSON response."""
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=endpoint,
            Body=json.dumps(payload),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        json_obj = json.loads(json_str)
        return json_obj

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Build the TTS payload from tool parameters, call the endpoint, and return the audio URL.

        Any failure is reported back to the caller as a text message rather than raised.
        """
        try:
            if not self.sagemaker_client:
                aws_region = tool_parameters.get('aws_region')
                if aws_region:
                    self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
                    self.s3_client = boto3.client("s3", region_name=aws_region)
                    self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
                else:
                    # Default region comes from the environment/credential chain.
                    self.sagemaker_client = boto3.client("sagemaker-runtime")
                    self.s3_client = boto3.client("s3")
                    self.comprehend_client = boto3.client('comprehend')

            if not self.sagemaker_endpoint:
                self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')

            tts_text = tool_parameters.get('tts_text')
            tts_infer_type = tool_parameters.get('tts_infer_type')

            voice = tool_parameters.get('voice')
            mock_voice_audio = tool_parameters.get('mock_voice_audio')
            mock_voice_text = tool_parameters.get('mock_voice_text')
            voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt')
            payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt)

            result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)

            # NOTE(review): assumes the endpoint uploads the audio to S3 and always returns
            # 's3_presign_url' in its response — confirm against the endpoint contract.
            return self.create_text_message(text=result['s3_presign_url'])

        except Exception as e:
            return self.create_text_message(f'Exception {str(e)}')

View File

@ -0,0 +1,149 @@
# Tool manifest for the SageMaker TTS built-in tool (see sagemaker_tts.py).
identity:
  name: sagemaker_tts
  author: AWS
  label:
    en_US: SagemakerTTS
    zh_Hans: Sagemaker语音合成
    pt_BR: SagemakerTTS
  icon: icon.svg
description:
  human:
    en_US: A tool for Speech synthesis - https://github.com/aws-samples/dify-aws-tool
    zh_Hans: Sagemaker语音合成工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本
    pt_BR: A tool for Speech synthesis.
  llm: A tool for Speech synthesis. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool
parameters:
  - name: sagemaker_endpoint
    type: string
    required: true
    label:
      en_US: sagemaker endpoint for tts
      zh_Hans: 语音生成的SageMaker端点
      pt_BR: sagemaker endpoint for tts
    human_description:
      en_US: sagemaker endpoint for tts
      zh_Hans: 语音生成的SageMaker端点
      pt_BR: sagemaker endpoint for tts
    llm_description: sagemaker endpoint for tts
    form: form
  - name: tts_text
    type: string
    required: true
    label:
      en_US: tts text
      zh_Hans: 语音合成原文
      pt_BR: tts text
    human_description:
      en_US: tts text
      zh_Hans: 语音合成原文
      pt_BR: tts text
    llm_description: tts text
    form: llm
  # Selects which payload shape sagemaker_tts.py builds (values mirror TTSModelType).
  - name: tts_infer_type
    type: select
    required: false
    label:
      en_US: tts infer type
      zh_Hans: 合成方式
      pt_BR: tts infer type
    human_description:
      en_US: tts infer type
      zh_Hans: 合成方式
      pt_BR: tts infer type
    llm_description: tts infer type
    options:
      - value: PresetVoice
        label:
          en_US: preset voice
          zh_Hans: 预置音色
      - value: CloneVoice
        label:
          en_US: clone voice
          zh_Hans: 克隆音色
      - value: CloneVoice_CrossLingual
        label:
          en_US: clone crossLingual voice
          zh_Hans: 克隆音色(跨语言)
      - value: InstructVoice
        label:
          en_US: instruct voice
          zh_Hans: 指令音色
    form: form
  - name: voice
    type: select
    required: false
    label:
      en_US: preset voice
      zh_Hans: 预置音色
      pt_BR: preset voice
    human_description:
      en_US: preset voice
      zh_Hans: 预置音色
      pt_BR: preset voice
    llm_description: preset voice
    options:
      - value: 中文男
        label:
          en_US: zh-cn male
          zh_Hans: 中文男
      - value: 中文女
        label:
          en_US: zh-cn female
          zh_Hans: 中文女
      - value: 粤语女
        label:
          en_US: zh-TW female
          zh_Hans: 粤语女
    form: form
  - name: mock_voice_audio
    type: string
    required: false
    label:
      en_US: clone voice link
      zh_Hans: 克隆音频链接
      pt_BR: clone voice link
    human_description:
      en_US: clone voice link
      zh_Hans: 克隆音频链接
      pt_BR: clone voice link
    llm_description: clone voice link
    form: llm
  - name: mock_voice_text
    type: string
    required: false
    label:
      en_US: text of clone voice
      zh_Hans: 克隆音频对应文本
      pt_BR: text of clone voice
    human_description:
      en_US: text of clone voice
      zh_Hans: 克隆音频对应文本
      pt_BR: text of clone voice
    llm_description: text of clone voice
    form: llm
  - name: voice_instruct_prompt
    type: string
    required: false
    label:
      en_US: instruct prompt for voice
      zh_Hans: 音色指令文本
      pt_BR: instruct prompt for voice
    human_description:
      en_US: instruct prompt for voice
      zh_Hans: 音色指令文本
      pt_BR: instruct prompt for voice
    llm_description: instruct prompt for voice
    form: llm
  - name: aws_region
    type: string
    required: false
    label:
      en_US: region of sagemaker endpoint
      zh_Hans: SageMaker 端点所在的region
      pt_BR: region of sagemaker endpoint
    human_description:
      en_US: region of sagemaker endpoint
      zh_Hans: SageMaker 端点所在的region
      pt_BR: region of sagemaker endpoint
    llm_description: region of sagemaker endpoint
    form: form

View File

@ -29,7 +29,7 @@ credentials_for_provider:
en_US: Please input your OpenAI API key
zh_Hans: 请输入你的 OpenAI API key
pt_BR: Please input your OpenAI API key
openai_organizaion_id:
openai_organization_id:
type: text-input
required: false
label:

View File

@ -16,7 +16,7 @@ class DallE2Tool(BuiltinTool):
"""
invoke tools
"""
openai_organization = self.runtime.credentials.get('openai_organizaion_id', None)
openai_organization = self.runtime.credentials.get('openai_organization_id', None)
if not openai_organization:
openai_organization = None
openai_base_url = self.runtime.credentials.get('openai_base_url', None)

View File

@ -17,7 +17,7 @@ class DallE3Tool(BuiltinTool):
"""
invoke tools
"""
openai_organization = self.runtime.credentials.get('openai_organizaion_id', None)
openai_organization = self.runtime.credentials.get('openai_organization_id', None)
if not openai_organization:
openai_organization = None
openai_base_url = self.runtime.credentials.get('openai_base_url', None)

View File

@ -4,7 +4,7 @@ from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class GihubProvider(BuiltinToolProviderController):
class GithubProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
if 'access_tokens' not in credentials or not credentials.get('access_tokens'):

View File

@ -9,7 +9,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GihubRepositoriesTool(BuiltinTool):
class GithubRepositoriesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools

View File

@ -60,7 +60,7 @@ class GitlabCommitsTool(BuiltinTool):
project_name = project['name']
print(f"Project: {project_name}")
# Get all of proejct commits
# Get all of project commits
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
params = {
'since': start_time,
@ -83,7 +83,7 @@ class GitlabCommitsTool(BuiltinTool):
diffs = diff_response.json()
for diff in diffs:
# Caculate code lines of changed
# Calculate code lines of changed
added_lines = diff['diff'].count('\n+')
removed_lines = diff['diff'].count('\n-')
total_changes = added_lines + removed_lines

View File

@ -6,7 +6,7 @@ identity:
zh_Hans: GitLab 提交内容查询
description:
human:
en_US: A tool for query GitLab commits, Input should be a exists username or projec.
en_US: A tool for query GitLab commits, Input should be a exists username or project.
zh_Hans: 一个用于查询 GitLab 代码提交内容的工具,输入的内容应该是一个已存在的用户名或者项目名。
llm: A tool for query GitLab commits, Input should be a exists username or project.
parameters:

View File

@ -29,7 +29,7 @@ class OpenweatherTool(BuiltinTool):
# request URL
url = "https://api.openweathermap.org/data/2.5/weather"
# request parmas
# request params
params = {
"q": city,
"appid": self.runtime.credentials.get("api_key"),

View File

@ -0,0 +1,3 @@
<svg width="400" height="400" viewBox="0 0 400 400" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M101.008 42L190.99 124.905L190.99 124.886L190.99 42.1913H208.506L208.506 125.276L298.891 42V136.524L336 136.524V272.866H299.005V357.035L208.506 277.525L208.506 357.948H190.99L190.99 278.836L101.11 358V272.866H64V136.524H101.008V42ZM177.785 153.826H81.5159V255.564H101.088V223.472L177.785 153.826ZM118.625 231.149V319.392L190.99 255.655L190.99 165.421L118.625 231.149ZM209.01 254.812V165.336L281.396 231.068V272.866H281.489V318.491L209.01 254.812ZM299.005 255.564H318.484V153.826L222.932 153.826L299.005 222.751V255.564ZM281.375 136.524V81.7983L221.977 136.524L281.375 136.524ZM177.921 136.524H118.524V81.7983L177.921 136.524Z" fill="black"/>
</svg>

After

Width:  |  Height:  |  Size: 798 B

View File

@ -0,0 +1,46 @@
from typing import Any
import requests
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.perplexity.tools.perplexity_search import PERPLEXITY_API_URL
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class PerplexityProvider(BuiltinToolProviderController):
    """Tool provider that validates a Perplexity API key by issuing a minimal chat completion."""

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """Probe the Perplexity API with the configured key.

        Raises:
            ToolProviderCredentialValidationError: when the request fails or the key is rejected.
        """
        headers = {
            "Authorization": f"Bearer {credentials.get('perplexity_api_key')}",
            "Content-Type": "application/json"
        }
        # Cheapest possible probe: tiny fixed prompt with a 5-token completion cap.
        payload = {
            "model": "llama-3.1-sonar-small-128k-online",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "Hello"
                }
            ],
            "max_tokens": 5,
            "temperature": 0.1,
            "top_p": 0.9,
            "stream": False
        }
        try:
            response = requests.post(PERPLEXITY_API_URL, json=payload, headers=headers)
            response.raise_for_status()
        except requests.RequestException as e:
            raise ToolProviderCredentialValidationError(
                f"Failed to validate Perplexity API key: {str(e)}"
            )

        # NOTE(review): effectively dead — raise_for_status() above already raises for
        # any non-2xx status, so this branch can only be reached with a 2xx != 200.
        if response.status_code != 200:
            raise ToolProviderCredentialValidationError(
                f"Perplexity API key is invalid. Status code: {response.status_code}"
            )

View File

@ -0,0 +1,26 @@
# Provider manifest for the Perplexity tool provider (validated by perplexity.py).
identity:
  author: Dify
  name: perplexity
  label:
    en_US: Perplexity
    zh_Hans: Perplexity
  description:
    en_US: Perplexity.AI
    zh_Hans: Perplexity.AI
  icon: icon.svg
  tags:
    - search
credentials_for_provider:
  perplexity_api_key:
    type: secret-input
    required: true
    label:
      en_US: Perplexity API key
      zh_Hans: Perplexity API key
    placeholder:
      en_US: Please input your Perplexity API key
      zh_Hans: 请输入你的 Perplexity API key
    help:
      en_US: Get your Perplexity API key from Perplexity
      zh_Hans: 从 Perplexity 获取您的 Perplexity API key
    url: https://www.perplexity.ai/settings/api

View File

@ -0,0 +1,72 @@
import json
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
# Perplexity's OpenAI-compatible chat-completions endpoint; also imported by the
# provider (perplexity.py) for credential validation.
PERPLEXITY_API_URL = "https://api.perplexity.ai/chat/completions"


class PerplexityAITool(BuiltinTool):
    """Built-in tool that answers a search query with Perplexity AI's online language models."""

    def _parse_response(self, response: dict) -> dict:
        """Reduce the raw API response to a {'content', 'role', 'citations'} dict.

        Returns a placeholder message when the response carries no choices.
        """
        if 'choices' in response and len(response['choices']) > 0:
            message = response['choices'][0]['message']
            return {
                'content': message.get('content', ''),
                'role': message.get('role', ''),
                'citations': response.get('citations', [])
            }
        else:
            return {'content': 'Unable to get a valid response', 'role': 'assistant', 'citations': []}

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Send the query to Perplexity and return both a JSON message and a pretty-printed text message.

        Raises:
            requests.HTTPError: via raise_for_status() when the API returns a non-2xx status.
            KeyError: when 'query' is missing from tool_parameters.
        """
        headers = {
            "Authorization": f"Bearer {self.runtime.credentials['perplexity_api_key']}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": tool_parameters.get('model', 'llama-3.1-sonar-small-128k-online'),
            "messages": [
                {
                    "role": "system",
                    "content": "Be precise and concise."
                },
                {
                    "role": "user",
                    "content": tool_parameters['query']
                }
            ],
            "max_tokens": tool_parameters.get('max_tokens', 4096),
            "temperature": tool_parameters.get('temperature', 0.7),
            "top_p": tool_parameters.get('top_p', 1),
            "top_k": tool_parameters.get('top_k', 5),
            "presence_penalty": tool_parameters.get('presence_penalty', 0),
            "frequency_penalty": tool_parameters.get('frequency_penalty', 1),
            "stream": False
        }

        # Optional search parameters are forwarded only when explicitly provided.
        if 'search_recency_filter' in tool_parameters:
            payload['search_recency_filter'] = tool_parameters['search_recency_filter']
        if 'return_citations' in tool_parameters:
            payload['return_citations'] = tool_parameters['return_citations']
        if 'search_domain_filter' in tool_parameters:
            # The API expects a list of domains; normalize a bare string into a one-element list.
            if isinstance(tool_parameters['search_domain_filter'], str):
                payload['search_domain_filter'] = [tool_parameters['search_domain_filter']]
            elif isinstance(tool_parameters['search_domain_filter'], list):
                payload['search_domain_filter'] = tool_parameters['search_domain_filter']

        response = requests.post(url=PERPLEXITY_API_URL, json=payload, headers=headers)
        response.raise_for_status()
        valuable_res = self._parse_response(response.json())

        return [
            self.create_json_message(valuable_res),
            self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2))
        ]

View File

@ -0,0 +1,178 @@
# Tool manifest for the Perplexity Search built-in tool (see perplexity_search.py).
# Parameter defaults here mirror the fallbacks hard-coded in the Python tool.
identity:
  name: perplexity
  author: Dify
  label:
    en_US: Perplexity Search
description:
  human:
    en_US: Search information using Perplexity AI's language models.
  llm: This tool is used to search information using Perplexity AI's language models.
parameters:
  - name: query
    type: string
    required: true
    label:
      en_US: Query
      zh_Hans: 查询
    human_description:
      en_US: The text query to be processed by the AI model.
      zh_Hans: 要由 AI 模型处理的文本查询。
    form: llm
  - name: model
    type: select
    required: false
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    human_description:
      en_US: The Perplexity AI model to use for generating the response.
      zh_Hans: 用于生成响应的 Perplexity AI 模型。
    form: form
    default: "llama-3.1-sonar-small-128k-online"
    options:
      - value: llama-3.1-sonar-small-128k-online
        label:
          en_US: llama-3.1-sonar-small-128k-online
          zh_Hans: llama-3.1-sonar-small-128k-online
      - value: llama-3.1-sonar-large-128k-online
        label:
          en_US: llama-3.1-sonar-large-128k-online
          zh_Hans: llama-3.1-sonar-large-128k-online
      - value: llama-3.1-sonar-huge-128k-online
        label:
          en_US: llama-3.1-sonar-huge-128k-online
          zh_Hans: llama-3.1-sonar-huge-128k-online
  - name: max_tokens
    type: number
    required: false
    label:
      en_US: Max Tokens
      zh_Hans: 最大令牌数
      pt_BR: Máximo de Tokens
    human_description:
      en_US: The maximum number of tokens to generate in the response.
      zh_Hans: 在响应中生成的最大令牌数。
      pt_BR: O número máximo de tokens a serem gerados na resposta.
    form: form
    default: 4096
    min: 1
    max: 4096
  - name: temperature
    type: number
    required: false
    label:
      en_US: Temperature
      zh_Hans: 温度
      pt_BR: Temperatura
    human_description:
      en_US: Controls randomness in the output. Lower values make the output more focused and deterministic.
      zh_Hans: 控制输出的随机性。较低的值使输出更加集中和确定。
    form: form
    default: 0.7
    min: 0
    max: 1
  - name: top_k
    type: number
    required: false
    label:
      en_US: Top K
      zh_Hans: 取样数量
    human_description:
      en_US: The number of top results to consider for response generation.
      zh_Hans: 用于生成响应的顶部结果数量。
    form: form
    default: 5
    min: 1
    max: 100
  - name: top_p
    type: number
    required: false
    label:
      en_US: Top P
      zh_Hans: Top P
    human_description:
      en_US: Controls diversity via nucleus sampling.
      zh_Hans: 通过核心采样控制多样性。
    form: form
    default: 1
    min: 0.1
    max: 1
    step: 0.1
  - name: presence_penalty
    type: number
    required: false
    label:
      en_US: Presence Penalty
      zh_Hans: 存在惩罚
    human_description:
      en_US: Positive values penalize new tokens based on whether they appear in the text so far.
      zh_Hans: 正值会根据新词元是否已经出现在文本中来对其进行惩罚。
    form: form
    default: 0
    min: -1.0
    max: 1.0
    step: 0.1
  - name: frequency_penalty
    type: number
    required: false
    label:
      en_US: Frequency Penalty
      zh_Hans: 频率惩罚
    human_description:
      en_US: Positive values penalize new tokens based on their existing frequency in the text so far.
      zh_Hans: 正值会根据新词元在文本中已经出现的频率来对其进行惩罚。
    form: form
    default: 1
    min: 0.1
    max: 1.0
    step: 0.1
  - name: return_citations
    type: boolean
    required: false
    label:
      en_US: Return Citations
      zh_Hans: 返回引用
    human_description:
      en_US: Whether to return citations in the response.
      zh_Hans: 是否在响应中返回引用。
    form: form
    default: true
  - name: search_domain_filter
    type: string
    required: false
    label:
      en_US: Search Domain Filter
      zh_Hans: 搜索域过滤器
    human_description:
      en_US: Domain to filter the search results.
      zh_Hans: 用于过滤搜索结果的域名。
    form: form
    default: ""
  - name: search_recency_filter
    type: select
    required: false
    label:
      en_US: Search Recency Filter
      zh_Hans: 搜索时间过滤器
    human_description:
      en_US: Filter for search results based on recency.
      zh_Hans: 基于时间筛选搜索结果。
    form: form
    default: "month"
    # NOTE(review): the zh_Hans labels below are empty in the source — confirm
    # whether translations (天/周/月/年) were lost in transit.
    options:
      - value: day
        label:
          en_US: Day
          zh_Hans:
      - value: week
        label:
          en_US: Week
          zh_Hans:
      - value: month
        label:
          en_US: Month
          zh_Hans:
      - value: year
        label:
          en_US: Year
          zh_Hans:

View File

@ -35,20 +35,20 @@ def sha256base64(data):
return digest
def parse_url(requset_url):
stidx = requset_url.index("://")
host = requset_url[stidx + 3 :]
schema = requset_url[: stidx + 3]
def parse_url(request_url):
stidx = request_url.index("://")
host = request_url[stidx + 3 :]
schema = request_url[: stidx + 3]
edidx = host.index("/")
if edidx <= 0:
raise AssembleHeaderException("invalid request url:" + requset_url)
raise AssembleHeaderException("invalid request url:" + request_url)
path = host[edidx:]
host = host[:edidx]
u = Url(host, path, schema)
return u
def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
u = parse_url(requset_url)
def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""):
u = parse_url(request_url)
host = u.host
path = u.path
now = datetime.now()
@ -69,7 +69,7 @@ def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
)
values = {"host": host, "date": date, "authorization": authorization}
return requset_url + "?" + urlencode(values)
return request_url + "?" + urlencode(values)
def get_body(appid, text):

View File

@ -42,6 +42,6 @@ class ScrapeTool(BuiltinTool):
result += "URL: " + i.get('url', '') + "\n"
result += "CONTENT: " + i.get('content', '') + "\n\n"
except Exception as e:
return self.create_text_message("An error occured", str(e))
return self.create_text_message("An error occurred", str(e))
return self.create_text_message(result)

View File

@ -17,11 +17,8 @@ class StepfunTool(BuiltinTool):
"""
invoke tools
"""
base_url = self.runtime.credentials.get('stepfun_base_url', None)
if not base_url:
base_url = None
else:
base_url = str(URL(base_url) / 'v1')
base_url = self.runtime.credentials.get('stepfun_base_url', 'https://api.stepfun.com')
base_url = str(URL(base_url) / 'v1')
client = OpenAI(
api_key=self.runtime.credentials['stepfun_api_key'],

View File

@ -8,7 +8,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrival_methods import RetrievalMethod
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(retrival_method='keyword_search',
documents = RetrievalService.retrieve(retrieval_method='keyword_search',
dataset_id=dataset.id,
query=query,
top_k=self.top_k
@ -173,7 +173,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'],
dataset_id=dataset.id,
query=query,
top_k=self.top_k,

View File

@ -2,7 +2,7 @@
from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval.retrival_methods import RetrievalMethod
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
@ -63,7 +63,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(retrival_method='keyword_search',
documents = RetrievalService.retrieve(retrieval_method='keyword_search',
dataset_id=dataset.id,
query=query,
top_k=self.top_k
@ -72,7 +72,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
documents = RetrievalService.retrieve(retrieval_method=retrieval_model.get('search_method', 'semantic_search'),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,

View File

@ -18,7 +18,7 @@ from core.tools.tool.tool import Tool
class DatasetRetrieverTool(Tool):
retrival_tool: DatasetRetrieverBaseTool
retrieval_tool: DatasetRetrieverBaseTool
@staticmethod
def get_dataset_tools(tenant_id: str,
@ -43,7 +43,7 @@ class DatasetRetrieverTool(Tool):
# Agent only support SINGLE mode
original_retriever_mode = retrieve_config.retrieve_strategy
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
retrival_tools = feature.to_dataset_retriever_tool(
retrieval_tools = feature.to_dataset_retriever_tool(
tenant_id=tenant_id,
dataset_ids=dataset_ids,
retrieve_config=retrieve_config,
@ -51,20 +51,23 @@ class DatasetRetrieverTool(Tool):
invoke_from=invoke_from,
hit_callback=hit_callback
)
if retrieval_tools is None or len(retrieval_tools) == 0:
return []
# restore retrieve strategy
retrieve_config.retrieve_strategy = original_retriever_mode
# convert retrival tools to Tools
# convert retrieval tools to Tools
tools = []
for retrival_tool in retrival_tools:
for retrieval_tool in retrieval_tools:
tool = DatasetRetrieverTool(
retrival_tool=retrival_tool,
identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')),
retrieval_tool=retrieval_tool,
identity=ToolIdentity(provider='', author='', name=retrieval_tool.name, label=I18nObject(en_US='', zh_Hans='')),
parameters=[],
is_team_authorization=True,
description=ToolDescription(
human=I18nObject(en_US='', zh_Hans=''),
llm=retrival_tool.description),
llm=retrieval_tool.description),
runtime=DatasetRetrieverTool.Runtime()
)
@ -96,8 +99,7 @@ class DatasetRetrieverTool(Tool):
yield self.create_text_message(text='please input query')
else:
# invoke dataset retriever tool
result = self.retrival_tool._run(query=query)
result = self.retrieval_tool._run(query=query)
yield self.create_text_message(text=result)
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:

View File

@ -189,8 +189,8 @@ def extract_text_blocks_as_plain_text(paragraph_html):
def plain_text_leaf_node(element):
# Extract all text, stripped of any child HTML elements and normalise it
plain_text = normalise_text(element.get_text())
# Extract all text, stripped of any child HTML elements and normalize it
plain_text = normalize_text(element.get_text())
if plain_text != "" and element.name == "li":
plain_text = "* {}, ".format(plain_text)
if plain_text == "":
@ -231,8 +231,8 @@ def plain_element(element, content_digests, node_indexes):
# For leaf node elements, extract the text content, discarding any HTML tags
# 1. Get element contents as text
plain_text = element.get_text()
# 2. Normalise the extracted text string to a canonical representation
plain_text = normalise_text(plain_text)
# 2. Normalize the extracted text string to a canonical representation
plain_text = normalize_text(plain_text)
# 3. Update element content to be plain text
element.string = plain_text
elif is_text(element):
@ -243,7 +243,7 @@ def plain_element(element, content_digests, node_indexes):
element = type(element)("")
else:
plain_text = element.string
plain_text = normalise_text(plain_text)
plain_text = normalize_text(plain_text)
element = type(element)(plain_text)
else:
# If not a leaf node or leaf type call recursively on child nodes, replacing
@ -267,12 +267,12 @@ def add_node_indexes(element, node_index="0"):
return element
def normalise_text(text):
"""Normalise unicode and whitespace."""
# Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
def normalize_text(text):
"""Normalize unicode and whitespace."""
# Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
text = strip_control_characters(text)
text = normalise_unicode(text)
text = normalise_whitespace(text)
text = normalize_unicode(text)
text = normalize_whitespace(text)
return text
@ -291,14 +291,14 @@ def strip_control_characters(text):
return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])
def normalise_unicode(text):
"""Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
def normalize_unicode(text):
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
normal_form = "NFKC"
text = unicodedata.normalize(normal_form, text)
return text
def normalise_whitespace(text):
def normalize_whitespace(text):
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
text = regex.sub(r"\s+", " ", text)
# Remove leading and trailing whitespace