mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 14:08:18 +08:00
Merge main
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
- google
|
||||
- bing
|
||||
- perplexity
|
||||
- duckduckgo
|
||||
- searchapi
|
||||
- serper
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
@ -16,7 +17,7 @@ class GuardrailParameters(BaseModel):
|
||||
guardrail_version: str = Field(..., description="The version of the guardrail")
|
||||
source: str = Field(..., description="The source of the content")
|
||||
text: str = Field(..., description="The text to apply the guardrail to")
|
||||
aws_region: str = Field(default="us-east-1", description="AWS region for the Bedrock client")
|
||||
aws_region: str = Field(..., description="AWS region for the Bedrock client")
|
||||
|
||||
class ApplyGuardrailTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
@ -40,6 +41,8 @@ class ApplyGuardrailTool(BuiltinTool):
|
||||
source=params.source,
|
||||
content=[{"text": {"text": params.text}}]
|
||||
)
|
||||
|
||||
logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}")
|
||||
|
||||
# Check for empty response
|
||||
if not response:
|
||||
@ -69,7 +72,7 @@ class ApplyGuardrailTool(BuiltinTool):
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except boto3.exceptions.BotoCoreError as e:
|
||||
except BotoCoreError as e:
|
||||
error_message = f'AWS service error: {str(e)}'
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
@ -80,4 +83,4 @@ class ApplyGuardrailTool(BuiltinTool):
|
||||
except Exception as e:
|
||||
error_message = f'An unexpected error occurred: {str(e)}'
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
return self.create_text_message(text=error_message)
|
||||
@ -54,3 +54,14 @@ parameters:
|
||||
zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。
|
||||
llm_description: The content used for requesting guardrail review, which can be either user input or LLM output.
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
human_description:
|
||||
en_US: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
|
||||
zh_Hans: 请输入 Bedrock 客户端的 AWS 区域,例如 'us-east-1'。
|
||||
llm_description: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
|
||||
form: form
|
||||
|
||||
@ -0,0 +1,71 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
|
||||
class LambdaYamlToJsonTool(BuiltinTool):
    """Builtin tool that converts YAML text to JSON by calling an AWS Lambda function.

    The Lambda is invoked synchronously with ``{"body": <yaml_content>}`` and is
    expected to reply with a JSON document holding ``statusCode`` and ``body``.
    """

    # boto3 Lambda client; created lazily on the first `_invoke` call and reused.
    lambda_client: Any = None

    def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
        """Invoke ``lambda_name`` with the YAML text and return the reply's ``body``.

        Raises:
            RuntimeError: if the decoded reply is not a JSON object whose
                ``statusCode`` equals 200 (this includes replies that carry no
                ``statusCode`` at all, e.g. a Lambda error document).
        """
        msg = {"body": yaml_content}
        logger.info(json.dumps(msg))

        invoke_response = self.lambda_client.invoke(
            FunctionName=lambda_name,
            InvocationType='RequestResponse',
            Payload=json.dumps(msg),
        )
        response_str = invoke_response['Payload'].read().decode("utf-8")
        resp_json = json.loads(response_str)

        logger.info(resp_json)
        # Use .get() so a reply without "statusCode" is reported with the full
        # payload instead of surfacing as an opaque KeyError('statusCode').
        if not isinstance(resp_json, dict) or resp_json.get('statusCode') != 200:
            raise RuntimeError(f"Invalid status code: {response_str}")

        return resp_json['body']

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools

        Reads ``aws_region`` (optional), ``yaml_content`` and ``lambda_name`` from
        ``tool_parameters``. Returns a text message with the Lambda result, a
        prompt for the missing parameter, or ``Exception: <reason>`` on failure.
        """
        try:
            if not self.lambda_client:
                aws_region = tool_parameters.get('aws_region')  # todo: move aws_region out, and update client region
                if aws_region:
                    self.lambda_client = boto3.client("lambda", region_name=aws_region)
                else:
                    self.lambda_client = boto3.client("lambda")

            yaml_content = tool_parameters.get('yaml_content', '')
            if not yaml_content:
                return self.create_text_message('Please input yaml_content')

            lambda_name = tool_parameters.get('lambda_name', '')
            if not lambda_name:
                return self.create_text_message('Please input lambda_name')
            logger.debug(json.dumps(tool_parameters, indent=2, ensure_ascii=False))

            result = self._invoke_lambda(lambda_name, yaml_content)
            logger.debug(result)

            return self.create_text_message(result)
        except Exception as e:
            # Keep the traceback in the logs; the caller only sees the short text.
            logger.exception("lambda_yaml_to_json invocation failed")
            return self.create_text_message(f'Exception: {str(e)}')
        finally:
            # Previously placed after the try/except, where every path had
            # already returned, so the flush could never run.
            console_handler.flush()
|
||||
@ -0,0 +1,53 @@
|
||||
identity:
|
||||
name: lambda_yaml_to_json
|
||||
author: AWS
|
||||
label:
|
||||
en_US: LambdaYamlToJson
|
||||
zh_Hans: LambdaYamlToJson
|
||||
pt_BR: LambdaYamlToJson
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool to convert yaml to json using AWS Lambda.
|
||||
zh_Hans: 将 YAML 转为 JSON 的工具(通过AWS Lambda)。
|
||||
pt_BR: A tool to convert yaml to json using AWS Lambda.
|
||||
llm: A tool to convert yaml to json.
|
||||
parameters:
|
||||
- name: yaml_content
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: YAML content to convert
|
||||
zh_Hans: YAML 内容
|
||||
pt_BR: YAML content to convert
|
||||
human_description:
|
||||
en_US: The YAML content to convert to JSON
|
||||
zh_Hans: YAML 内容
|
||||
pt_BR: The YAML content to convert to JSON
|
||||
llm_description: The YAML content to convert to JSON
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of lambda
|
||||
human_description:
|
||||
en_US: region of lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of lambda
|
||||
llm_description: region of lambda
|
||||
form: form
|
||||
- name: lambda_name
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: name of lambda
|
||||
zh_Hans: Lambda 名称
|
||||
pt_BR: name of lambda
|
||||
human_description:
|
||||
en_US: name of lambda
|
||||
zh_Hans: Lambda 名称
|
||||
pt_BR: name of lambda
|
||||
form: form
|
||||
@ -78,9 +78,7 @@ class SageMakerReRankTool(BuiltinTool):
|
||||
sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
|
||||
|
||||
line = 9
|
||||
results_str = json.dumps(sorted_candidate_docs[:self.topk], ensure_ascii=False)
|
||||
return self.create_text_message(text=results_str)
|
||||
return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ]
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f'Exception {str(e)}, line : {line}')
|
||||
|
||||
return self.create_text_message(f'Exception {str(e)}, line : {line}')
|
||||
95
api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py
Normal file
95
api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py
Normal file
@ -0,0 +1,95 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class TTSModelType(Enum):
    """Speech-synthesis inference modes, keyed by their string identifier."""

    # Synthesis with one of the predefined voices.
    PresetVoice = "PresetVoice"
    # Voice cloning from a reference audio clip plus its transcript.
    CloneVoice = "CloneVoice"
    # Voice cloning from a reference audio clip only, with a language tag.
    CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
    # Predefined voice steered by an additional instruction text.
    InstructVoice = "InstructVoice"
|
||||
|
||||
class SageMakerTTSTool(BuiltinTool):
    """Builtin tool that synthesizes speech through a SageMaker inference endpoint.

    The boto3 clients and the endpoint name are taken from the tool parameters
    on the first `_invoke` call and cached on the instance afterwards.
    """

    sagemaker_client: Any = None
    # None until the first invocation supplies `sagemaker_endpoint`.
    sagemaker_endpoint: Union[str, None] = None
    s3_client: Any = None
    comprehend_client: Any = None

    def _detect_lang_code(self, content: str, map_dict: Union[dict, None] = None):
        """Return the language tag for the dominant language of ``content``.

        Args:
            content: text passed to Comprehend's dominant-language detection.
            map_dict: optional mapping from language code to tag. When omitted,
                the built-in mapping below is used. (The argument used to be
                overwritten unconditionally, so a caller-supplied mapping was
                silently ignored.)

        Returns:
            The mapped tag, or ``'<|zh|>'`` when the detected code is not in the
            mapping or no language was detected.
        """
        default_tag = '<|zh|>'
        if map_dict is None:
            map_dict = {
                "zh": "<|zh|>",
                "en": "<|en|>",
                "ja": "<|jp|>",
                "zh-TW": "<|yue|>",
                "ko": "<|ko|>",
            }

        response = self.comprehend_client.detect_dominant_language(Text=content)
        languages = response.get('Languages') or []
        if not languages:
            # Nothing detected: fall back instead of raising IndexError.
            return default_tag
        return map_dict.get(languages[0]['LanguageCode'], default_tag)

    def _build_tts_payload(self, model_type: str, content_text: str, model_role: str, prompt_text: str, prompt_audio: str, instruct_text: str):
        """Build the endpoint request body for the given inference mode.

        Each mode requires its own set of non-empty arguments:
        PresetVoice needs ``model_role``; CloneVoice needs ``prompt_text`` and
        ``prompt_audio``; CloneVoice_CrossLingual needs ``prompt_audio``;
        InstructVoice needs ``instruct_text`` and ``model_role``.

        Raises:
            RuntimeError: if ``model_type`` is unknown or its required arguments
                are missing.
        """
        if model_type == TTSModelType.PresetVoice.value and model_role:
            return {"tts_text": content_text, "role": model_role}
        if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
            return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
        if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
            lang_tag = self._detect_lang_code(content_text)
            return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag}
        if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
            return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}

        raise RuntimeError(f"Invalid params for {model_type}")

    def _invoke_sagemaker(self, payload: dict, endpoint: str):
        """POST ``payload`` as JSON to ``endpoint`` and return the decoded JSON reply."""
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=endpoint,
            Body=json.dumps(payload),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        return json.loads(json_str)

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools

        Returns a text message holding the ``s3_presign_url`` field of the
        endpoint reply, or ``Exception <reason>`` when anything fails.
        """
        try:
            if not self.sagemaker_client:
                aws_region = tool_parameters.get('aws_region')
                # Only pass region_name when one was configured, so boto3's
                # default region resolution still applies otherwise.
                client_kwargs = {"region_name": aws_region} if aws_region else {}
                self.sagemaker_client = boto3.client("sagemaker-runtime", **client_kwargs)
                self.s3_client = boto3.client("s3", **client_kwargs)
                self.comprehend_client = boto3.client('comprehend', **client_kwargs)

            if not self.sagemaker_endpoint:
                self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')

            tts_text = tool_parameters.get('tts_text')
            tts_infer_type = tool_parameters.get('tts_infer_type')

            voice = tool_parameters.get('voice')
            mock_voice_audio = tool_parameters.get('mock_voice_audio')
            mock_voice_text = tool_parameters.get('mock_voice_text')
            voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt')
            payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt)

            result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)

            return self.create_text_message(text=result['s3_presign_url'])

        except Exception as e:
            return self.create_text_message(f'Exception {str(e)}')
|
||||
149
api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml
Normal file
149
api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml
Normal file
@ -0,0 +1,149 @@
|
||||
identity:
|
||||
name: sagemaker_tts
|
||||
author: AWS
|
||||
label:
|
||||
en_US: SagemakerTTS
|
||||
zh_Hans: Sagemaker语音合成
|
||||
pt_BR: SagemakerTTS
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for Speech synthesis - https://github.com/aws-samples/dify-aws-tool
|
||||
zh_Hans: Sagemaker语音合成工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool 上的部署脚本
|
||||
pt_BR: A tool for Speech synthesis.
|
||||
llm: A tool for Speech synthesis. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
parameters:
|
||||
- name: sagemaker_endpoint
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: sagemaker endpoint for tts
|
||||
zh_Hans: 语音生成的SageMaker端点
|
||||
pt_BR: sagemaker endpoint for tts
|
||||
human_description:
|
||||
en_US: sagemaker endpoint for tts
|
||||
zh_Hans: 语音生成的SageMaker端点
|
||||
pt_BR: sagemaker endpoint for tts
|
||||
llm_description: sagemaker endpoint for tts
|
||||
form: form
|
||||
- name: tts_text
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: tts text
|
||||
zh_Hans: 语音合成原文
|
||||
pt_BR: tts text
|
||||
human_description:
|
||||
en_US: tts text
|
||||
zh_Hans: 语音合成原文
|
||||
pt_BR: tts text
|
||||
llm_description: tts text
|
||||
form: llm
|
||||
- name: tts_infer_type
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: tts infer type
|
||||
zh_Hans: 合成方式
|
||||
pt_BR: tts infer type
|
||||
human_description:
|
||||
en_US: tts infer type
|
||||
zh_Hans: 合成方式
|
||||
pt_BR: tts infer type
|
||||
llm_description: tts infer type
|
||||
options:
|
||||
- value: PresetVoice
|
||||
label:
|
||||
en_US: preset voice
|
||||
zh_Hans: 预置音色
|
||||
- value: CloneVoice
|
||||
label:
|
||||
en_US: clone voice
|
||||
zh_Hans: 克隆音色
|
||||
- value: CloneVoice_CrossLingual
|
||||
label:
|
||||
en_US: clone crossLingual voice
|
||||
zh_Hans: 克隆音色(跨语言)
|
||||
- value: InstructVoice
|
||||
label:
|
||||
en_US: instruct voice
|
||||
zh_Hans: 指令音色
|
||||
form: form
|
||||
- name: voice
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: preset voice
|
||||
zh_Hans: 预置音色
|
||||
pt_BR: preset voice
|
||||
human_description:
|
||||
en_US: preset voice
|
||||
zh_Hans: 预置音色
|
||||
pt_BR: preset voice
|
||||
llm_description: preset voice
|
||||
options:
|
||||
- value: 中文男
|
||||
label:
|
||||
en_US: zh-cn male
|
||||
zh_Hans: 中文男
|
||||
- value: 中文女
|
||||
label:
|
||||
en_US: zh-cn female
|
||||
zh_Hans: 中文女
|
||||
- value: 粤语女
|
||||
label:
|
||||
en_US: zh-TW female
|
||||
zh_Hans: 粤语女
|
||||
form: form
|
||||
- name: mock_voice_audio
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: clone voice link
|
||||
zh_Hans: 克隆音频链接
|
||||
pt_BR: clone voice link
|
||||
human_description:
|
||||
en_US: clone voice link
|
||||
zh_Hans: 克隆音频链接
|
||||
pt_BR: clone voice link
|
||||
llm_description: clone voice link
|
||||
form: llm
|
||||
- name: mock_voice_text
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: text of clone voice
|
||||
zh_Hans: 克隆音频对应文本
|
||||
pt_BR: text of clone voice
|
||||
human_description:
|
||||
en_US: text of clone voice
|
||||
zh_Hans: 克隆音频对应文本
|
||||
pt_BR: text of clone voice
|
||||
llm_description: text of clone voice
|
||||
form: llm
|
||||
- name: voice_instruct_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: instruct prompt for voice
|
||||
zh_Hans: 音色指令文本
|
||||
pt_BR: instruct prompt for voice
|
||||
human_description:
|
||||
en_US: instruct prompt for voice
|
||||
zh_Hans: 音色指令文本
|
||||
pt_BR: instruct prompt for voice
|
||||
llm_description: instruct prompt for voice
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
pt_BR: region of sagemaker endpoint
|
||||
human_description:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
pt_BR: region of sagemaker endpoint
|
||||
llm_description: region of sagemaker endpoint
|
||||
form: form
|
||||
@ -29,7 +29,7 @@ credentials_for_provider:
|
||||
en_US: Please input your OpenAI API key
|
||||
zh_Hans: 请输入你的 OpenAI API key
|
||||
pt_BR: Please input your OpenAI API key
|
||||
openai_organizaion_id:
|
||||
openai_organization_id:
|
||||
type: text-input
|
||||
required: false
|
||||
label:
|
||||
|
||||
@ -16,7 +16,7 @@ class DallE2Tool(BuiltinTool):
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
openai_organization = self.runtime.credentials.get('openai_organizaion_id', None)
|
||||
openai_organization = self.runtime.credentials.get('openai_organization_id', None)
|
||||
if not openai_organization:
|
||||
openai_organization = None
|
||||
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
|
||||
|
||||
@ -17,7 +17,7 @@ class DallE3Tool(BuiltinTool):
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
openai_organization = self.runtime.credentials.get('openai_organizaion_id', None)
|
||||
openai_organization = self.runtime.credentials.get('openai_organization_id', None)
|
||||
if not openai_organization:
|
||||
openai_organization = None
|
||||
openai_base_url = self.runtime.credentials.get('openai_base_url', None)
|
||||
|
||||
@ -4,7 +4,7 @@ from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class GihubProvider(BuiltinToolProviderController):
|
||||
class GithubProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
if 'access_tokens' not in credentials or not credentials.get('access_tokens'):
|
||||
|
||||
@ -9,7 +9,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class GihubRepositoriesTool(BuiltinTool):
|
||||
class GithubRepositoriesTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
|
||||
@ -60,7 +60,7 @@ class GitlabCommitsTool(BuiltinTool):
|
||||
project_name = project['name']
|
||||
print(f"Project: {project_name}")
|
||||
|
||||
# Get all of proejct commits
|
||||
# Get all of project commits
|
||||
commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits"
|
||||
params = {
|
||||
'since': start_time,
|
||||
@ -83,7 +83,7 @@ class GitlabCommitsTool(BuiltinTool):
|
||||
diffs = diff_response.json()
|
||||
|
||||
for diff in diffs:
|
||||
# Caculate code lines of changed
|
||||
# Calculate code lines of changed
|
||||
added_lines = diff['diff'].count('\n+')
|
||||
removed_lines = diff['diff'].count('\n-')
|
||||
total_changes = added_lines + removed_lines
|
||||
|
||||
@ -6,7 +6,7 @@ identity:
|
||||
zh_Hans: GitLab 提交内容查询
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for query GitLab commits, Input should be a exists username or projec.
|
||||
en_US: A tool for querying GitLab commits. Input should be an existing username or project.
|
||||
zh_Hans: 一个用于查询 GitLab 代码提交内容的工具,输入的内容应该是一个已存在的用户名或者项目名。
|
||||
llm: A tool for querying GitLab commits. Input should be an existing username or project.
|
||||
parameters:
|
||||
|
||||
@ -29,7 +29,7 @@ class OpenweatherTool(BuiltinTool):
|
||||
# request URL
|
||||
url = "https://api.openweathermap.org/data/2.5/weather"
|
||||
|
||||
# request parmas
|
||||
# request params
|
||||
params = {
|
||||
"q": city,
|
||||
"appid": self.runtime.credentials.get("api_key"),
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
<svg width="400" height="400" viewBox="0 0 400 400" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M101.008 42L190.99 124.905L190.99 124.886L190.99 42.1913H208.506L208.506 125.276L298.891 42V136.524L336 136.524V272.866H299.005V357.035L208.506 277.525L208.506 357.948H190.99L190.99 278.836L101.11 358V272.866H64V136.524H101.008V42ZM177.785 153.826H81.5159V255.564H101.088V223.472L177.785 153.826ZM118.625 231.149V319.392L190.99 255.655L190.99 165.421L118.625 231.149ZM209.01 254.812V165.336L281.396 231.068V272.866H281.489V318.491L209.01 254.812ZM299.005 255.564H318.484V153.826L222.932 153.826L299.005 222.751V255.564ZM281.375 136.524V81.7983L221.977 136.524L281.375 136.524ZM177.921 136.524H118.524V81.7983L177.921 136.524Z" fill="black"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 798 B |
46
api/core/tools/provider/builtin/perplexity/perplexity.py
Normal file
46
api/core/tools/provider/builtin/perplexity/perplexity.py
Normal file
@ -0,0 +1,46 @@
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.perplexity.tools.perplexity_search import PERPLEXITY_API_URL
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class PerplexityProvider(BuiltinToolProviderController):
    """Provider controller that validates a Perplexity API key."""

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """Check ``perplexity_api_key`` by sending a minimal chat-completion request.

        Raises:
            ToolProviderCredentialValidationError: if the request fails, times
                out, or does not answer with HTTP 200.
        """
        headers = {
            "Authorization": f"Bearer {credentials.get('perplexity_api_key')}",
            "Content-Type": "application/json"
        }

        # Smallest useful request: a five-token reply is enough to prove the key works.
        payload = {
            "model": "llama-3.1-sonar-small-128k-online",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "Hello"
                }
            ],
            "max_tokens": 5,
            "temperature": 0.1,
            "top_p": 0.9,
            "stream": False
        }

        try:
            # requests applies no timeout by default; without one an unresponsive
            # server would block credential validation forever. A timeout raises
            # requests.Timeout, which is a RequestException and is handled below.
            response = requests.post(PERPLEXITY_API_URL, json=payload, headers=headers, timeout=30)
            response.raise_for_status()
        except requests.RequestException as e:
            raise ToolProviderCredentialValidationError(
                f"Failed to validate Perplexity API key: {str(e)}"
            ) from e

        # raise_for_status() only covers 4xx/5xx; reject any other non-200 answer too.
        if response.status_code != 200:
            raise ToolProviderCredentialValidationError(
                f"Perplexity API key is invalid. Status code: {response.status_code}"
            )
|
||||
26
api/core/tools/provider/builtin/perplexity/perplexity.yaml
Normal file
26
api/core/tools/provider/builtin/perplexity/perplexity.yaml
Normal file
@ -0,0 +1,26 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: perplexity
|
||||
label:
|
||||
en_US: Perplexity
|
||||
zh_Hans: Perplexity
|
||||
description:
|
||||
en_US: Perplexity.AI
|
||||
zh_Hans: Perplexity.AI
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- search
|
||||
credentials_for_provider:
|
||||
perplexity_api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Perplexity API key
|
||||
zh_Hans: Perplexity API key
|
||||
placeholder:
|
||||
en_US: Please input your Perplexity API key
|
||||
zh_Hans: 请输入你的 Perplexity API key
|
||||
help:
|
||||
en_US: Get your Perplexity API key from Perplexity
|
||||
zh_Hans: 从 Perplexity 获取您的 Perplexity API key
|
||||
url: https://www.perplexity.ai/settings/api
|
||||
@ -0,0 +1,72 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
PERPLEXITY_API_URL = "https://api.perplexity.ai/chat/completions"
|
||||
|
||||
class PerplexityAITool(BuiltinTool):
    """Builtin tool that answers a query through the Perplexity chat-completions API."""

    def _parse_response(self, response: dict) -> dict:
        """Parse the response from Perplexity AI API

        Returns a dict with ``content``, ``role`` and ``citations``. When the
        response carries no choices, a placeholder answer is returned instead.
        """
        choices = response.get('choices')
        if choices:
            # Tolerate a first choice without "message" rather than raising KeyError.
            message = choices[0].get('message') or {}
            return {
                'content': message.get('content', ''),
                'role': message.get('role', ''),
                'citations': response.get('citations', [])
            }
        return {'content': 'Unable to get a valid response', 'role': 'assistant', 'citations': []}

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Send ``tool_parameters['query']`` to Perplexity and return the parsed answer.

        Returns:
            Two messages with the same parsed result: one JSON message and one
            text message containing its pretty-printed JSON.

        Raises:
            KeyError: if ``query`` or the ``perplexity_api_key`` credential is missing.
            requests.RequestException: on connection errors, timeouts or HTTP error status.
        """
        headers = {
            "Authorization": f"Bearer {self.runtime.credentials['perplexity_api_key']}",
            "Content-Type": "application/json"
        }

        payload = {
            "model": tool_parameters.get('model', 'llama-3.1-sonar-small-128k-online'),
            "messages": [
                {
                    "role": "system",
                    "content": "Be precise and concise."
                },
                {
                    "role": "user",
                    "content": tool_parameters['query']
                }
            ],
            "max_tokens": tool_parameters.get('max_tokens', 4096),
            "temperature": tool_parameters.get('temperature', 0.7),
            "top_p": tool_parameters.get('top_p', 1),
            "top_k": tool_parameters.get('top_k', 5),
            "presence_penalty": tool_parameters.get('presence_penalty', 0),
            "frequency_penalty": tool_parameters.get('frequency_penalty', 1),
            "stream": False
        }

        # Optional filters are forwarded only when they hold a usable value.
        recency_filter = tool_parameters.get('search_recency_filter')
        if recency_filter:
            payload['search_recency_filter'] = recency_filter

        return_citations = tool_parameters.get('return_citations')
        if return_citations is not None:
            payload['return_citations'] = return_citations

        domain_filter = tool_parameters.get('search_domain_filter')
        if isinstance(domain_filter, str):
            # A blank string (for example an empty form default) means "no filter";
            # previously it was sent to the API as [""].
            domain_filter = [domain_filter.strip()] if domain_filter.strip() else []
        if isinstance(domain_filter, list) and domain_filter:
            payload['search_domain_filter'] = domain_filter

        # requests has no default timeout; bound connect and read time so a
        # stalled connection cannot hang the tool call indefinitely.
        response = requests.post(url=PERPLEXITY_API_URL, json=payload, headers=headers, timeout=(10, 300))
        response.raise_for_status()
        valuable_res = self._parse_response(response.json())

        return [
            self.create_json_message(valuable_res),
            self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2))
        ]
|
||||
@ -0,0 +1,178 @@
|
||||
identity:
|
||||
name: perplexity
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Perplexity Search
|
||||
description:
|
||||
human:
|
||||
en_US: Search information using Perplexity AI's language models.
|
||||
llm: This tool is used to search information using Perplexity AI's language models.
|
||||
parameters:
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query
|
||||
zh_Hans: 查询
|
||||
human_description:
|
||||
en_US: The text query to be processed by the AI model.
|
||||
zh_Hans: 要由 AI 模型处理的文本查询。
|
||||
form: llm
|
||||
- name: model
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
human_description:
|
||||
en_US: The Perplexity AI model to use for generating the response.
|
||||
zh_Hans: 用于生成响应的 Perplexity AI 模型。
|
||||
form: form
|
||||
default: "llama-3.1-sonar-small-128k-online"
|
||||
options:
|
||||
- value: llama-3.1-sonar-small-128k-online
|
||||
label:
|
||||
en_US: llama-3.1-sonar-small-128k-online
|
||||
zh_Hans: llama-3.1-sonar-small-128k-online
|
||||
- value: llama-3.1-sonar-large-128k-online
|
||||
label:
|
||||
en_US: llama-3.1-sonar-large-128k-online
|
||||
zh_Hans: llama-3.1-sonar-large-128k-online
|
||||
- value: llama-3.1-sonar-huge-128k-online
|
||||
label:
|
||||
en_US: llama-3.1-sonar-huge-128k-online
|
||||
zh_Hans: llama-3.1-sonar-huge-128k-online
|
||||
- name: max_tokens
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Max Tokens
|
||||
zh_Hans: 最大令牌数
|
||||
pt_BR: Máximo de Tokens
|
||||
human_description:
|
||||
en_US: The maximum number of tokens to generate in the response.
|
||||
zh_Hans: 在响应中生成的最大令牌数。
|
||||
pt_BR: O número máximo de tokens a serem gerados na resposta.
|
||||
form: form
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
- name: temperature
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Temperature
|
||||
zh_Hans: 温度
|
||||
pt_BR: Temperatura
|
||||
human_description:
|
||||
en_US: Controls randomness in the output. Lower values make the output more focused and deterministic.
|
||||
zh_Hans: 控制输出的随机性。较低的值使输出更加集中和确定。
|
||||
form: form
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_k
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Top K
|
||||
zh_Hans: 取样数量
|
||||
human_description:
|
||||
en_US: The number of top results to consider for response generation.
|
||||
zh_Hans: 用于生成响应的顶部结果数量。
|
||||
form: form
|
||||
default: 5
|
||||
min: 1
|
||||
max: 100
|
||||
- name: top_p
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Top P
|
||||
zh_Hans: Top P
|
||||
human_description:
|
||||
en_US: Controls diversity via nucleus sampling.
|
||||
zh_Hans: 通过核心采样控制多样性。
|
||||
form: form
|
||||
default: 1
|
||||
min: 0.1
|
||||
max: 1
|
||||
step: 0.1
|
||||
- name: presence_penalty
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Presence Penalty
|
||||
zh_Hans: 存在惩罚
|
||||
human_description:
|
||||
en_US: Positive values penalize new tokens based on whether they appear in the text so far.
|
||||
zh_Hans: 正值会根据新词元是否已经出现在文本中来对其进行惩罚。
|
||||
form: form
|
||||
default: 0
|
||||
min: -1.0
|
||||
max: 1.0
|
||||
step: 0.1
|
||||
- name: frequency_penalty
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Frequency Penalty
|
||||
zh_Hans: 频率惩罚
|
||||
human_description:
|
||||
en_US: Positive values penalize new tokens based on their existing frequency in the text so far.
|
||||
zh_Hans: 正值会根据新词元在文本中已经出现的频率来对其进行惩罚。
|
||||
form: form
|
||||
default: 1
|
||||
min: 0.1
|
||||
max: 1.0
|
||||
step: 0.1
|
||||
- name: return_citations
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Return Citations
|
||||
zh_Hans: 返回引用
|
||||
human_description:
|
||||
en_US: Whether to return citations in the response.
|
||||
zh_Hans: 是否在响应中返回引用。
|
||||
form: form
|
||||
default: true
|
||||
- name: search_domain_filter
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Search Domain Filter
|
||||
zh_Hans: 搜索域过滤器
|
||||
human_description:
|
||||
en_US: Domain to filter the search results.
|
||||
zh_Hans: 用于过滤搜索结果的域名。
|
||||
form: form
|
||||
default: ""
|
||||
- name: search_recency_filter
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: Search Recency Filter
|
||||
zh_Hans: 搜索时间过滤器
|
||||
human_description:
|
||||
en_US: Filter for search results based on recency.
|
||||
zh_Hans: 基于时间筛选搜索结果。
|
||||
form: form
|
||||
default: "month"
|
||||
options:
|
||||
- value: day
|
||||
label:
|
||||
en_US: Day
|
||||
zh_Hans: 天
|
||||
- value: week
|
||||
label:
|
||||
en_US: Week
|
||||
zh_Hans: 周
|
||||
- value: month
|
||||
label:
|
||||
en_US: Month
|
||||
zh_Hans: 月
|
||||
- value: year
|
||||
label:
|
||||
en_US: Year
|
||||
zh_Hans: 年
|
||||
@ -35,20 +35,20 @@ def sha256base64(data):
|
||||
return digest
|
||||
|
||||
|
||||
def parse_url(request_url):
    """Split ``request_url`` into schema, host and path and return them as a ``Url``.

    The schema keeps its trailing ``://`` and the path keeps its leading ``/``.

    Raises:
        ValueError: if ``request_url`` contains no ``://`` separator.
        AssembleHeaderException: if the part after ``://`` has an empty host or
            no ``/`` starting a path.
    """
    stidx = request_url.index("://")
    host = request_url[stidx + 3:]
    schema = request_url[: stidx + 3]
    # find() returns -1 when there is no path, so the guard below can actually
    # fire; index() raised ValueError first and made the "-1" case unreachable.
    edidx = host.find("/")
    if edidx <= 0:
        raise AssembleHeaderException("invalid request url:" + request_url)
    path = host[edidx:]
    host = host[:edidx]
    return Url(host, path, schema)
|
||||
|
||||
def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
|
||||
u = parse_url(requset_url)
|
||||
def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""):
|
||||
u = parse_url(request_url)
|
||||
host = u.host
|
||||
path = u.path
|
||||
now = datetime.now()
|
||||
@ -69,7 +69,7 @@ def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""):
|
||||
)
|
||||
values = {"host": host, "date": date, "authorization": authorization}
|
||||
|
||||
return requset_url + "?" + urlencode(values)
|
||||
return request_url + "?" + urlencode(values)
|
||||
|
||||
|
||||
def get_body(appid, text):
|
||||
|
||||
@ -42,6 +42,6 @@ class ScrapeTool(BuiltinTool):
|
||||
result += "URL: " + i.get('url', '') + "\n"
|
||||
result += "CONTENT: " + i.get('content', '') + "\n\n"
|
||||
except Exception as e:
|
||||
return self.create_text_message("An error occured", str(e))
|
||||
return self.create_text_message("An error occurred", str(e))
|
||||
|
||||
return self.create_text_message(result)
|
||||
|
||||
@ -17,11 +17,8 @@ class StepfunTool(BuiltinTool):
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
base_url = self.runtime.credentials.get('stepfun_base_url', None)
|
||||
if not base_url:
|
||||
base_url = None
|
||||
else:
|
||||
base_url = str(URL(base_url) / 'v1')
|
||||
base_url = self.runtime.credentials.get('stepfun_base_url', 'https://api.stepfun.com')
|
||||
base_url = str(URL(base_url) / 'v1')
|
||||
|
||||
client = OpenAI(
|
||||
api_key=self.runtime.credentials['stepfun_api_key'],
|
||||
|
||||
@ -8,7 +8,7 @@ from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(retrival_method='keyword_search',
|
||||
documents = RetrievalService.retrieve(retrieval_method='keyword_search',
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k
|
||||
@ -173,7 +173,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
else:
|
||||
if self.top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
|
||||
documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -63,7 +63,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(retrival_method='keyword_search',
|
||||
documents = RetrievalService.retrieve(retrieval_method='keyword_search',
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k
|
||||
@ -72,7 +72,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
else:
|
||||
if self.top_k > 0:
|
||||
# retrieval source
|
||||
documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
|
||||
documents = RetrievalService.retrieve(retrieval_method=retrieval_model.get('search_method', 'semantic_search'),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
|
||||
@ -18,7 +18,7 @@ from core.tools.tool.tool import Tool
|
||||
|
||||
|
||||
class DatasetRetrieverTool(Tool):
|
||||
retrival_tool: DatasetRetrieverBaseTool
|
||||
retrieval_tool: DatasetRetrieverBaseTool
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_tools(tenant_id: str,
|
||||
@ -43,7 +43,7 @@ class DatasetRetrieverTool(Tool):
|
||||
# Agent only support SINGLE mode
|
||||
original_retriever_mode = retrieve_config.retrieve_strategy
|
||||
retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
|
||||
retrival_tools = feature.to_dataset_retriever_tool(
|
||||
retrieval_tools = feature.to_dataset_retriever_tool(
|
||||
tenant_id=tenant_id,
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=retrieve_config,
|
||||
@ -51,20 +51,23 @@ class DatasetRetrieverTool(Tool):
|
||||
invoke_from=invoke_from,
|
||||
hit_callback=hit_callback
|
||||
)
|
||||
if retrieval_tools is None or len(retrieval_tools) == 0:
|
||||
return []
|
||||
|
||||
# restore retrieve strategy
|
||||
retrieve_config.retrieve_strategy = original_retriever_mode
|
||||
|
||||
# convert retrival tools to Tools
|
||||
# convert retrieval tools to Tools
|
||||
tools = []
|
||||
for retrival_tool in retrival_tools:
|
||||
for retrieval_tool in retrieval_tools:
|
||||
tool = DatasetRetrieverTool(
|
||||
retrival_tool=retrival_tool,
|
||||
identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')),
|
||||
retrieval_tool=retrieval_tool,
|
||||
identity=ToolIdentity(provider='', author='', name=retrieval_tool.name, label=I18nObject(en_US='', zh_Hans='')),
|
||||
parameters=[],
|
||||
is_team_authorization=True,
|
||||
description=ToolDescription(
|
||||
human=I18nObject(en_US='', zh_Hans=''),
|
||||
llm=retrival_tool.description),
|
||||
llm=retrieval_tool.description),
|
||||
runtime=DatasetRetrieverTool.Runtime()
|
||||
)
|
||||
|
||||
@ -96,8 +99,7 @@ class DatasetRetrieverTool(Tool):
|
||||
yield self.create_text_message(text='please input query')
|
||||
else:
|
||||
# invoke dataset retriever tool
|
||||
result = self.retrival_tool._run(query=query)
|
||||
|
||||
result = self.retrieval_tool._run(query=query)
|
||||
yield self.create_text_message(text=result)
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||
|
||||
@ -189,8 +189,8 @@ def extract_text_blocks_as_plain_text(paragraph_html):
|
||||
|
||||
|
||||
def plain_text_leaf_node(element):
|
||||
# Extract all text, stripped of any child HTML elements and normalise it
|
||||
plain_text = normalise_text(element.get_text())
|
||||
# Extract all text, stripped of any child HTML elements and normalize it
|
||||
plain_text = normalize_text(element.get_text())
|
||||
if plain_text != "" and element.name == "li":
|
||||
plain_text = "* {}, ".format(plain_text)
|
||||
if plain_text == "":
|
||||
@ -231,8 +231,8 @@ def plain_element(element, content_digests, node_indexes):
|
||||
# For leaf node elements, extract the text content, discarding any HTML tags
|
||||
# 1. Get element contents as text
|
||||
plain_text = element.get_text()
|
||||
# 2. Normalise the extracted text string to a canonical representation
|
||||
plain_text = normalise_text(plain_text)
|
||||
# 2. Normalize the extracted text string to a canonical representation
|
||||
plain_text = normalize_text(plain_text)
|
||||
# 3. Update element content to be plain text
|
||||
element.string = plain_text
|
||||
elif is_text(element):
|
||||
@ -243,7 +243,7 @@ def plain_element(element, content_digests, node_indexes):
|
||||
element = type(element)("")
|
||||
else:
|
||||
plain_text = element.string
|
||||
plain_text = normalise_text(plain_text)
|
||||
plain_text = normalize_text(plain_text)
|
||||
element = type(element)(plain_text)
|
||||
else:
|
||||
# If not a leaf node or leaf type call recursively on child nodes, replacing
|
||||
@ -267,12 +267,12 @@ def add_node_indexes(element, node_index="0"):
|
||||
return element
|
||||
|
||||
|
||||
def normalise_text(text):
|
||||
"""Normalise unicode and whitespace."""
|
||||
# Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
|
||||
def normalize_text(text):
|
||||
"""Normalize unicode and whitespace."""
|
||||
# Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
|
||||
text = strip_control_characters(text)
|
||||
text = normalise_unicode(text)
|
||||
text = normalise_whitespace(text)
|
||||
text = normalize_unicode(text)
|
||||
text = normalize_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
@ -291,14 +291,14 @@ def strip_control_characters(text):
|
||||
return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])
|
||||
|
||||
|
||||
def normalise_unicode(text):
|
||||
"""Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
|
||||
def normalize_unicode(text):
|
||||
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
|
||||
normal_form = "NFKC"
|
||||
text = unicodedata.normalize(normal_form, text)
|
||||
return text
|
||||
|
||||
|
||||
def normalise_whitespace(text):
|
||||
def normalize_whitespace(text):
|
||||
"""Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
|
||||
text = regex.sub(r"\s+", " ", text)
|
||||
# Remove leading and trailing whitespace
|
||||
|
||||
Reference in New Issue
Block a user