mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
feat: support more model types and builtin tools on aws/sagemaker (#8061)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
@ -3,6 +3,7 @@ import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
@ -16,7 +17,7 @@ class GuardrailParameters(BaseModel):
|
||||
guardrail_version: str = Field(..., description="The version of the guardrail")
|
||||
source: str = Field(..., description="The source of the content")
|
||||
text: str = Field(..., description="The text to apply the guardrail to")
|
||||
aws_region: str = Field(default="us-east-1", description="AWS region for the Bedrock client")
|
||||
aws_region: str = Field(..., description="AWS region for the Bedrock client")
|
||||
|
||||
class ApplyGuardrailTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
@ -40,6 +41,8 @@ class ApplyGuardrailTool(BuiltinTool):
|
||||
source=params.source,
|
||||
content=[{"text": {"text": params.text}}]
|
||||
)
|
||||
|
||||
logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}")
|
||||
|
||||
# Check for empty response
|
||||
if not response:
|
||||
@ -69,7 +72,7 @@ class ApplyGuardrailTool(BuiltinTool):
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except boto3.exceptions.BotoCoreError as e:
|
||||
except BotoCoreError as e:
|
||||
error_message = f'AWS service error: {str(e)}'
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
@ -80,4 +83,4 @@ class ApplyGuardrailTool(BuiltinTool):
|
||||
except Exception as e:
|
||||
error_message = f'An unexpected error occurred: {str(e)}'
|
||||
logger.error(error_message, exc_info=True)
|
||||
return self.create_text_message(text=error_message)
|
||||
return self.create_text_message(text=error_message)
|
||||
@ -54,3 +54,14 @@ parameters:
|
||||
zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。
|
||||
llm_description: The content used for requesting guardrail review, which can be either user input or LLM output.
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
human_description:
|
||||
en_US: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
|
||||
zh_Hans: 请输入 Bedrock 客户端的 AWS 区域,例如 'us-east-1'。
|
||||
llm_description: Please enter the AWS region for the Bedrock client, for example 'us-east-1'.
|
||||
form: form
|
||||
|
||||
@ -0,0 +1,71 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
# Dedicated stream handler for this module; kept in a module-level name so
# other code in this file can refer to it.
console_handler = logging.StreamHandler()

# Configure root logging at INFO, then attach the handler above to this
# module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
|
||||
|
||||
|
||||
class LambdaYamlToJsonTool(BuiltinTool):
    """Builtin tool that converts YAML to JSON by delegating to an AWS Lambda.

    The Lambda is invoked synchronously with ``{"body": <yaml text>}`` and its
    reply is expected to be a JSON object with ``statusCode`` equal to 200 and
    the converted document under ``body``.
    """

    # boto3 Lambda client; created on the first valid invocation and reused.
    lambda_client: Any = None

    def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str:
        """Call the Lambda function and return the ``body`` of its reply.

        Args:
            lambda_name: Name of the Lambda function to invoke.
            yaml_content: YAML text sent to the function as ``body``.

        Returns:
            The ``body`` field of the function's JSON reply.

        Raises:
            Exception: If the function reports an error, or its reply is not
                a JSON object with ``statusCode`` 200 and a ``body`` field.
        """
        msg = {"body": yaml_content}
        logger.info(json.dumps(msg))

        invoke_response = self.lambda_client.invoke(
            FunctionName=lambda_name,
            InvocationType='RequestResponse',
            Payload=json.dumps(msg),
        )
        response_str = invoke_response['Payload'].read().decode("utf-8")

        # When the function itself fails, the invoke response carries a
        # FunctionError marker and the payload has no statusCode; report that
        # explicitly instead of failing later with a bare KeyError.
        if invoke_response.get('FunctionError'):
            raise Exception(f"Lambda function error: {response_str}")

        resp_json = json.loads(response_str)
        logger.info(resp_json)

        if not isinstance(resp_json, dict) or resp_json.get('statusCode') != 200:
            raise Exception(f"Invalid status code: {response_str}")
        if 'body' not in resp_json:
            raise Exception(f"Missing body in response: {response_str}")

        return resp_json['body']

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Validate the parameters, call the Lambda and wrap the result.

        Args:
            user_id: Identifier of the calling user (not used here).
            tool_parameters: Expects ``yaml_content`` and ``lambda_name``;
                ``aws_region`` is optional and only consulted when the client
                is first created.

        Returns:
            A text message with the converted content, a prompt asking for a
            missing parameter, or ``Exception: <detail>`` on failure.
        """
        try:
            # Validate cheap inputs first so that no AWS client is built for a
            # request that cannot be served anyway.
            yaml_content = tool_parameters.get('yaml_content', '')
            if not yaml_content:
                return self.create_text_message('Please input yaml_content')

            lambda_name = tool_parameters.get('lambda_name', '')
            if not lambda_name:
                return self.create_text_message('Please input lambda_name')

            if not self.lambda_client:
                # todo: move aws_region out, and update client region
                aws_region = tool_parameters.get('aws_region')
                client_kwargs = {"region_name": aws_region} if aws_region else {}
                self.lambda_client = boto3.client("lambda", **client_kwargs)

            # default=str keeps this purely diagnostic dump from aborting the
            # invocation when a parameter value is not JSON-serialisable.
            logger.debug(json.dumps(tool_parameters, indent=2, ensure_ascii=False, default=str))

            result = self._invoke_lambda(lambda_name, yaml_content)
            logger.debug(result)

            return self.create_text_message(result)
        except Exception as e:
            # Tool boundary: surface the failure to the user as text, but keep
            # the traceback in the log.
            logger.exception("lambda_yaml_to_json invocation failed")
            return self.create_text_message(f'Exception: {str(e)}')
|
||||
|
||||
console_handler.flush()
|
||||
@ -0,0 +1,53 @@
|
||||
identity:
|
||||
name: lambda_yaml_to_json
|
||||
author: AWS
|
||||
label:
|
||||
en_US: LambdaYamlToJson
|
||||
zh_Hans: LambdaYamlToJson
|
||||
pt_BR: LambdaYamlToJson
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool to convert yaml to json using AWS Lambda.
|
||||
zh_Hans: 将 YAML 转为 JSON 的工具(通过AWS Lambda)。
|
||||
pt_BR: A tool to convert yaml to json using AWS Lambda.
|
||||
llm: A tool to convert yaml to json.
|
||||
parameters:
|
||||
- name: yaml_content
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: YAML content to convert for
|
||||
zh_Hans: YAML 内容
|
||||
pt_BR: YAML content to convert for
|
||||
human_description:
|
||||
en_US: YAML content to convert for
|
||||
zh_Hans: YAML 内容
|
||||
pt_BR: YAML content to convert for
|
||||
llm_description: YAML content to convert for
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of lambda
|
||||
human_description:
|
||||
en_US: region of lambda
|
||||
zh_Hans: Lambda 所在的region
|
||||
pt_BR: region of lambda
|
||||
llm_description: region of lambda
|
||||
form: form
|
||||
# The tool implementation refuses to run without lambda_name
# ("Please input lambda_name"), so the manifest marks it as required.
- name: lambda_name
  type: string
  required: true
  label:
    en_US: name of lambda
    zh_Hans: Lambda 名称
    pt_BR: name of lambda
  human_description:
    en_US: name of lambda
    zh_Hans: Lambda 名称
    pt_BR: name of lambda
  llm_description: name of lambda
  form: form
|
||||
@ -78,9 +78,7 @@ class SageMakerReRankTool(BuiltinTool):
|
||||
sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
|
||||
|
||||
line = 9
|
||||
results_str = json.dumps(sorted_candidate_docs[:self.topk], ensure_ascii=False)
|
||||
return self.create_text_message(text=results_str)
|
||||
return [ self.create_json_message(res) for res in sorted_candidate_docs[:self.topk] ]
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f'Exception {str(e)}, line : {line}')
|
||||
|
||||
return self.create_text_message(f'Exception {str(e)}, line : {line}')
|
||||
95
api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py
Normal file
95
api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py
Normal file
@ -0,0 +1,95 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class TTSModelType(Enum):
    """Inference modes understood by the SageMaker TTS payload builder.

    Every member's value is the same string as its name; the payload builder
    compares the incoming ``tts_infer_type`` parameter against these values.
    """

    # Synthesis with a named speaker role.
    PresetVoice = "PresetVoice"

    # Voice cloning from a reference audio plus its transcript.
    CloneVoice = "CloneVoice"

    # Voice cloning from a reference audio only, with a detected language tag.
    CloneVoice_CrossLingual = "CloneVoice_CrossLingual"

    # Named speaker role steered by an instruction text.
    InstructVoice = "InstructVoice"
|
||||
|
||||
class SageMakerTTSTool(BuiltinTool):
    """Builtin tool that synthesises speech through a SageMaker endpoint.

    The tool builds a JSON payload for the selected inference mode, posts it
    to the endpoint and returns the ``s3_presign_url`` from the reply as a
    text message.
    """

    # AWS clients and the endpoint name are resolved on the first invocation
    # and then reused for the lifetime of the tool instance.
    sagemaker_client: Any = None
    sagemaker_endpoint: str = None
    s3_client: Any = None
    comprehend_client: Any = None

    def _detect_lang_code(self, content: str, map_dict: dict = None):
        """Return the language tag for ``content`` using Amazon Comprehend.

        Args:
            content: Text whose dominant language should be detected.
            map_dict: Optional mapping from Comprehend language codes to
                language tags; the built-in mapping is used when omitted.

        Returns:
            The tag mapped to the first detected language, or ``'<|zh|>'``
            when nothing was detected or the language is not in the mapping.
        """
        if map_dict is None:
            map_dict = {
                "zh": "<|zh|>",
                "en": "<|en|>",
                "ja": "<|jp|>",
                "zh-TW": "<|yue|>",
                "ko": "<|ko|>",
            }

        response = self.comprehend_client.detect_dominant_language(Text=content)
        languages = response.get('Languages') or []
        if not languages:
            # Nothing detected: fall back to the same default tag that is used
            # for unmapped languages.
            return '<|zh|>'

        language_code = languages[0]['LanguageCode']
        return map_dict.get(language_code, '<|zh|>')

    def _build_tts_payload(self, model_type: str, content_text: str, model_role: str, prompt_text: str,
                           prompt_audio: str, instruct_text: str):
        """Build the endpoint request body for the given inference mode.

        Args:
            model_type: One of the ``TTSModelType`` values.
            content_text: Text to synthesise.
            model_role: Speaker role (PresetVoice and InstructVoice).
            prompt_text: Transcript of the reference audio (CloneVoice).
            prompt_audio: Reference audio (CloneVoice, CloneVoice_CrossLingual).
            instruct_text: Instruction text (InstructVoice).

        Returns:
            The payload dict for the selected mode.

        Raises:
            RuntimeError: If the mode is unknown or one of the arguments the
                mode needs is empty.
        """
        if model_type == TTSModelType.PresetVoice.value and model_role:
            return {"tts_text": content_text, "role": model_role}

        if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
            return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}

        if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
            lang_tag = self._detect_lang_code(content_text)
            return {"tts_text": content_text, "prompt_audio": prompt_audio, "lang_tag": lang_tag}

        if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
            return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}

        raise RuntimeError(f"Invalid params for {model_type}")

    def _invoke_sagemaker(self, payload: dict, endpoint: str):
        """POST ``payload`` as JSON to ``endpoint`` and return the decoded JSON reply."""
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=endpoint,
            Body=json.dumps(payload),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        return json.loads(json_str)

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Run one text-to-speech request.

        Args:
            user_id: Identifier of the calling user (not used here).
            tool_parameters: Reads ``tts_text``, ``tts_infer_type``, ``voice``,
                ``mock_voice_audio``, ``mock_voice_text`` and
                ``voice_instruct_prompt``; ``aws_region`` and
                ``sagemaker_endpoint`` are only consulted while the
                corresponding attribute is still unset.

        Returns:
            A text message with the ``s3_presign_url`` from the endpoint
            reply, a prompt asking for the missing text, or
            ``Exception <detail>`` on failure.
        """
        try:
            tts_text = tool_parameters.get('tts_text')
            if not tts_text:
                return self.create_text_message('Please input tts_text')

            if not self.sagemaker_client:
                aws_region = tool_parameters.get('aws_region')
                client_kwargs = {"region_name": aws_region} if aws_region else {}
                s3_client = boto3.client("s3", **client_kwargs)
                comprehend_client = boto3.client('comprehend', **client_kwargs)
                sagemaker_client = boto3.client("sagemaker-runtime", **client_kwargs)
                # sagemaker_client is the "already initialised" guard, so it is
                # assigned last: a failure while building any client must not
                # leave the other attributes permanently unset.
                self.s3_client = s3_client
                self.comprehend_client = comprehend_client
                self.sagemaker_client = sagemaker_client

            if not self.sagemaker_endpoint:
                self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint')

            tts_infer_type = tool_parameters.get('tts_infer_type')
            voice = tool_parameters.get('voice')
            mock_voice_audio = tool_parameters.get('mock_voice_audio')
            mock_voice_text = tool_parameters.get('mock_voice_text')
            voice_instruct_prompt = tool_parameters.get('voice_instruct_prompt')

            payload = self._build_tts_payload(tts_infer_type, tts_text, voice, mock_voice_text,
                                              mock_voice_audio, voice_instruct_prompt)
            result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)

            return self.create_text_message(text=result['s3_presign_url'])
        except Exception as e:
            # Tool boundary: report the failure to the user as text.
            return self.create_text_message(f'Exception {str(e)}')
|
||||
149
api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml
Normal file
149
api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml
Normal file
@ -0,0 +1,149 @@
|
||||
identity:
|
||||
name: sagemaker_tts
|
||||
author: AWS
|
||||
label:
|
||||
en_US: SagemakerTTS
|
||||
zh_Hans: Sagemaker语音合成
|
||||
pt_BR: SagemakerTTS
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for Speech synthesis - https://github.com/aws-samples/dify-aws-tool
|
||||
zh_Hans: Sagemaker语音合成工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本
|
||||
pt_BR: A tool for Speech synthesis.
|
||||
llm: A tool for Speech synthesis. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
parameters:
|
||||
- name: sagemaker_endpoint
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: sagemaker endpoint for tts
|
||||
zh_Hans: 语音生成的SageMaker端点
|
||||
pt_BR: sagemaker endpoint for tts
|
||||
human_description:
|
||||
en_US: sagemaker endpoint for tts
|
||||
zh_Hans: 语音生成的SageMaker端点
|
||||
pt_BR: sagemaker endpoint for tts
|
||||
llm_description: sagemaker endpoint for tts
|
||||
form: form
|
||||
- name: tts_text
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: tts text
|
||||
zh_Hans: 语音合成原文
|
||||
pt_BR: tts text
|
||||
human_description:
|
||||
en_US: tts text
|
||||
zh_Hans: 语音合成原文
|
||||
pt_BR: tts text
|
||||
llm_description: tts text
|
||||
form: llm
|
||||
- name: tts_infer_type
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: tts infer type
|
||||
zh_Hans: 合成方式
|
||||
pt_BR: tts infer type
|
||||
human_description:
|
||||
en_US: tts infer type
|
||||
zh_Hans: 合成方式
|
||||
pt_BR: tts infer type
|
||||
llm_description: tts infer type
|
||||
options:
|
||||
- value: PresetVoice
|
||||
label:
|
||||
en_US: preset voice
|
||||
zh_Hans: 预置音色
|
||||
- value: CloneVoice
|
||||
label:
|
||||
en_US: clone voice
|
||||
zh_Hans: 克隆音色
|
||||
- value: CloneVoice_CrossLingual
|
||||
label:
|
||||
en_US: clone crossLingual voice
|
||||
zh_Hans: 克隆音色(跨语言)
|
||||
- value: InstructVoice
|
||||
label:
|
||||
en_US: instruct voice
|
||||
zh_Hans: 指令音色
|
||||
form: form
|
||||
- name: voice
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: preset voice
|
||||
zh_Hans: 预置音色
|
||||
pt_BR: preset voice
|
||||
human_description:
|
||||
en_US: preset voice
|
||||
zh_Hans: 预置音色
|
||||
pt_BR: preset voice
|
||||
llm_description: preset voice
|
||||
options:
|
||||
- value: 中文男
|
||||
label:
|
||||
en_US: zh-cn male
|
||||
zh_Hans: 中文男
|
||||
- value: 中文女
|
||||
label:
|
||||
en_US: zh-cn female
|
||||
zh_Hans: 中文女
|
||||
- value: 粤语女
|
||||
label:
|
||||
en_US: zh-TW female
|
||||
zh_Hans: 粤语女
|
||||
form: form
|
||||
- name: mock_voice_audio
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: clone voice link
|
||||
zh_Hans: 克隆音频链接
|
||||
pt_BR: clone voice link
|
||||
human_description:
|
||||
en_US: clone voice link
|
||||
zh_Hans: 克隆音频链接
|
||||
pt_BR: clone voice link
|
||||
llm_description: clone voice link
|
||||
form: llm
|
||||
- name: mock_voice_text
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: text of clone voice
|
||||
zh_Hans: 克隆音频对应文本
|
||||
pt_BR: text of clone voice
|
||||
human_description:
|
||||
en_US: text of clone voice
|
||||
zh_Hans: 克隆音频对应文本
|
||||
pt_BR: text of clone voice
|
||||
llm_description: text of clone voice
|
||||
form: llm
|
||||
- name: voice_instruct_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: instruct prompt for voice
|
||||
zh_Hans: 音色指令文本
|
||||
pt_BR: instruct prompt for voice
|
||||
human_description:
|
||||
en_US: instruct prompt for voice
|
||||
zh_Hans: 音色指令文本
|
||||
pt_BR: instruct prompt for voice
|
||||
llm_description: instruct prompt for voice
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
pt_BR: region of sagemaker endpoint
|
||||
human_description:
|
||||
en_US: region of sagemaker endpoint
|
||||
zh_Hans: SageMaker 端点所在的region
|
||||
pt_BR: region of sagemaker endpoint
|
||||
llm_description: region of sagemaker endpoint
|
||||
form: form
|
||||
Reference in New Issue
Block a user