Compare commits
1 commit
| Author | SHA1 | Date |
|---|---|---|
|  | 2bfcd227e8 |  |

.gitignore (vendored): 3 changed lines
@@ -145,9 +145,6 @@ docker/volumes/db/data/*
docker/volumes/redis/data/*
docker/volumes/weaviate/*
docker/volumes/qdrant/*
docker/volumes/etcd/*
docker/volumes/minio/*
docker/volumes/milvus/*

sdks/python-client/build
sdks/python-client/dist

README.md: 11 changed lines
@@ -21,6 +21,17 @@
    <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

<p align="center">
  <a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
    Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
  </a>
  <ul align="center" style="text-decoration: none; list-style: none;">
    <li> US EST: 09:00 (9:00 AM)</li>
    <li> CET: 15:00 (3:00 PM)</li>
    <li> CST: 22:00 (10:00 PM)</li>
  </ul>
</p>

<p align="center">
  <a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
    Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs

@@ -5,7 +5,7 @@

1. Start the docker-compose stack

   The backend requires some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.

   ```bash
   cd ../docker
   docker-compose -f docker-compose.middleware.yaml -p dify up -d
   ```
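A quick health check after bringing the stack up (a sketch, not part of the diff; it relies only on the `-p dify` project name used in the command above):

```bash
docker-compose -p dify ps   # the middleware containers should all report State: Up
```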
@@ -15,7 +15,7 @@

3. Generate a `SECRET_KEY` in the `.env` file.

   ```bash
   sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
   openssl rand -base64 42
   ```
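The key only needs to be a high-entropy random string, so on machines without `openssl` a Python equivalent works too (a sketch, not part of the diff):

```bash
python3 -c "import secrets; print(secrets.token_urlsafe(42))"
```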
3.5 If you use Anaconda, create a new environment and activate it:

   ```bash
   ```

@@ -46,7 +46,7 @@

   ```
   pip install -r requirements.txt --upgrade --force-reinstall
   ```

6. Start backend:

   ```bash
   flask run --host 0.0.0.0 --port=5001 --debug
   ```

@@ -26,7 +26,6 @@ from config import CloudEditionConfig, Config
from extensions import (
    ext_celery,
    ext_code_based_extension,
    ext_compress,
    ext_database,
    ext_hosting_provider,
    ext_login,
@@ -97,7 +96,6 @@ def create_app(test_config=None) -> Flask:
def initialize_extensions(app):
    # Since the application instance is now created, pass it to each Flask
    # extension instance to bind it to the Flask application instance (app)
    ext_compress.init_app(app)
    ext_code_based_extension.init()
    ext_database.init_app(app)
    ext_migrate.init(app, db)

@@ -90,7 +90,7 @@ class Config:
        # ------------------------
        # General Configurations.
        # ------------------------
        self.CURRENT_VERSION = "0.5.9"
        self.CURRENT_VERSION = "0.5.8"
        self.COMMIT_SHA = get_env('COMMIT_SHA')
        self.EDITION = "SELF_HOSTED"
        self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -293,8 +293,6 @@ class Config:

        self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')

        self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')


class CloudEditionConfig(Config):

@@ -27,9 +27,7 @@ from fields.app_fields import (
from libs.login import login_required
from models.model import App, AppModelConfig, Site
from services.app_model_config_service import AppModelConfigService
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.tool_manager import ToolManager
from core.entities.application_entities import AgentToolEntity


def _get_app(app_id, tenant_id):
    app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
@@ -238,42 +236,7 @@ class AppApi(Resource):
    def get(self, app_id):
        """Get app detail"""
        app_id = str(app_id)
        app: App = _get_app(app_id, current_user.current_tenant_id)

        # get original app model config
        model_config: AppModelConfig = app.app_model_config
        agent_mode = model_config.agent_mode_dict
        # decrypt agent tool parameters if it's secret-input
        for tool in agent_mode.get('tools') or []:
            agent_tool_entity = AgentToolEntity(**tool)
            # get tool
            try:
                tool_runtime = ToolManager.get_agent_tool_runtime(
                    tenant_id=current_user.current_tenant_id,
                    agent_tool=agent_tool_entity,
                    agent_callback=None
                )
                manager = ToolParameterConfigurationManager(
                    tenant_id=current_user.current_tenant_id,
                    tool_runtime=tool_runtime,
                    provider_name=agent_tool_entity.provider_id,
                    provider_type=agent_tool_entity.provider_type,
                )

                # get decrypted parameters
                if agent_tool_entity.tool_parameters:
                    parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
                    masked_parameter = manager.mask_tool_parameters(parameters or {})
                else:
                    masked_parameter = {}

                # override tool parameters
                tool['tool_parameters'] = masked_parameter
            except Exception as e:
                pass

        # override agent mode
        model_config.agent_mode = json.dumps(agent_mode)
        app = _get_app(app_id, current_user.current_tenant_id)

        return app

@@ -1,4 +1,3 @@
import json

from flask import request
from flask_login import current_user
@@ -8,9 +7,6 @@ from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.entities.application_entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.login import login_required
@@ -42,88 +38,6 @@ class ModelConfigResource(Resource):
        )
        new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)

        # get original app model config
        original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
            AppModelConfig.id == app.app_model_config_id
        ).first()
        agent_mode = original_app_model_config.agent_mode_dict
        # decrypt agent tool parameters if it's secret-input
        parameter_map = {}
        masked_parameter_map = {}
        tool_map = {}
        for tool in agent_mode.get('tools') or []:
            agent_tool_entity = AgentToolEntity(**tool)
            # get tool
            try:
                tool_runtime = ToolManager.get_agent_tool_runtime(
                    tenant_id=current_user.current_tenant_id,
                    agent_tool=agent_tool_entity,
                    agent_callback=None
                )
                manager = ToolParameterConfigurationManager(
                    tenant_id=current_user.current_tenant_id,
                    tool_runtime=tool_runtime,
                    provider_name=agent_tool_entity.provider_id,
                    provider_type=agent_tool_entity.provider_type,
                )
            except Exception as e:
                continue

            # get decrypted parameters
            if agent_tool_entity.tool_parameters:
                parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
                masked_parameter = manager.mask_tool_parameters(parameters or {})
            else:
                parameters = {}
                masked_parameter = {}

            key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
            masked_parameter_map[key] = masked_parameter
            parameter_map[key] = parameters
            tool_map[key] = tool_runtime

        # encrypt agent tool parameters if it's secret-input
        agent_mode = new_app_model_config.agent_mode_dict
        for tool in agent_mode.get('tools') or []:
            agent_tool_entity = AgentToolEntity(**tool)

            # get tool
            key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
            if key in tool_map:
                tool_runtime = tool_map[key]
            else:
                try:
                    tool_runtime = ToolManager.get_agent_tool_runtime(
                        tenant_id=current_user.current_tenant_id,
                        agent_tool=agent_tool_entity,
                        agent_callback=None
                    )
                except Exception as e:
                    continue

            manager = ToolParameterConfigurationManager(
                tenant_id=current_user.current_tenant_id,
                tool_runtime=tool_runtime,
                provider_name=agent_tool_entity.provider_id,
                provider_type=agent_tool_entity.provider_type,
            )
            manager.delete_tool_parameters_cache()

            # override parameters if it equals to masked parameters
            if agent_tool_entity.tool_parameters:
                if key not in masked_parameter_map:
                    continue

                if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
                    agent_tool_entity.tool_parameters = parameter_map[key]

            # encrypt parameters
            if agent_tool_entity.tool_parameters:
                tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})

        # update app model config
        new_app_model_config.agent_mode = json.dumps(agent_mode)

        db.session.add(new_app_model_config)
        db.session.flush()

@@ -82,30 +82,6 @@ class ToolBuiltinProviderIconApi(Resource):
        icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
        return send_file(io.BytesIO(icon_bytes), mimetype=minetype)

class ToolModelProviderIconApi(Resource):
    @setup_required
    def get(self, provider):
        icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider)
        return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)

class ToolModelProviderListToolsApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        user_id = current_user.id
        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument('provider', type=str, required=True, nullable=False, location='args')

        args = parser.parse_args()

        return ToolManageService.list_model_tool_provider_tools(
            user_id,
            tenant_id,
            args['provider'],
        )

class ToolApiProviderAddApi(Resource):
    @setup_required
@@ -307,8 +283,6 @@ api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provide
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools')
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')

@@ -200,8 +200,8 @@ class DatasetSegmentApi(DatasetApiResource):
        parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
        args = parser.parse_args()

        SegmentService.segment_create_args_validate(args, document)
        segment = SegmentService.update_segment(args, segment, document, dataset)
        SegmentService.segment_create_args_validate(args['segments'], document)
        segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
        return {
            'data': marshal(segment, segment_fields),
            'doc_form': document.doc_form

@@ -195,10 +195,6 @@ class AssistantApplicationRunner(AppRunner):
        if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
            agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

        db.session.refresh(conversation)
        db.session.refresh(message)
        db.session.close()

        # start agent runner
        if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
            assistant_cot_runner = AssistantCotApplicationRunner(

@@ -192,8 +192,6 @@ class BasicApplicationRunner(AppRunner):
            model=app_orchestration_config.model_config.model
        )

        db.session.close()

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=app_orchestration_config.model_config.parameters,

@@ -89,10 +89,6 @@ class GenerateTaskPipeline:
        Process generate task pipeline.
        :return:
        """
        db.session.refresh(self._conversation)
        db.session.refresh(self._message)
        db.session.close()

        if stream:
            return self._process_stream_response()
        else:
@@ -307,7 +303,6 @@ class GenerateTaskPipeline:
                .first()
            )
            db.session.refresh(agent_thought)
            db.session.close()

            if agent_thought:
                response = {
@@ -335,8 +330,6 @@ class GenerateTaskPipeline:
                .filter(MessageFile.id == event.message_file_id)
                .first()
            )
            db.session.close()

            # get extension
            if '.' in message_file.url:
                extension = f'.{message_file.url.split(".")[-1]}'
@@ -420,7 +413,6 @@ class GenerateTaskPipeline:
        usage = llm_result.usage

        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
        self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens

@@ -201,7 +201,7 @@ class ApplicationManager:
            logger.exception("Unknown Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
        finally:
            db.session.close()
            db.session.remove()

    def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
                         queue_manager: ApplicationQueueManager,
@@ -233,6 +233,8 @@ class ApplicationManager:
            else:
                logger.exception(e)
                raise e
        finally:
            db.session.remove()

    def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
            -> AppOrchestrationConfigEntity:
@@ -649,7 +651,6 @@ class ApplicationManager:

            db.session.add(conversation)
            db.session.commit()
            db.session.refresh(conversation)
        else:
            conversation = (
                db.session.query(Conversation)
@@ -688,7 +689,6 @@ class ApplicationManager:

        db.session.add(message)
        db.session.commit()
        db.session.refresh(message)

        for file in application_generate_entity.files:
            message_file = MessageFile(

@@ -114,7 +114,6 @@ class BaseAssistantApplicationRunner(AppRunner):
        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.message_id == self.message.id,
        ).count()
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
@@ -155,9 +154,9 @@ class BaseAssistantApplicationRunner(AppRunner):
        """
        convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            agent_tool=tool,
        tool_entity = ToolManager.get_tool_runtime(
            provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
            tenant_id=self.application_generate_entity.tenant_id,
            agent_callback=self.agent_callback
        )
        tool_entity.load_variables(self.variables_pool)
@@ -172,11 +171,33 @@ class BaseAssistantApplicationRunner(AppRunner):
            }
        )

        parameters = tool_entity.get_all_runtime_parameters()
        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue
        runtime_parameters = {}

        parameters = tool_entity.parameters or []
        user_parameters = tool_entity.get_runtime_parameters() or []

        # override parameters
        for parameter in user_parameters:
            # check if parameter in tool parameters
            found = False
            for tool_parameter in parameters:
                if tool_parameter.name == parameter.name:
                    found = True
                    break

            if found:
                # override parameter
                tool_parameter.type = parameter.type
                tool_parameter.form = parameter.form
                tool_parameter.required = parameter.required
                tool_parameter.default = parameter.default
                tool_parameter.options = parameter.options
                tool_parameter.llm_description = parameter.llm_description
            else:
                # add new parameter
                parameters.append(parameter)

        for parameter in parameters:
            parameter_type = 'string'
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.STRING:
@@ -192,16 +213,59 @@ class BaseAssistantApplicationRunner(AppRunner):
            else:
                raise ValueError(f"parameter type {parameter.type} is not supported")

            message_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }
            if parameter.form == ToolParameter.ToolParameterForm.FORM:
                # get tool parameter from form
                tool_parameter_config = tool.tool_parameters.get(parameter.name)
                if not tool_parameter_config:
                    # get default value
                    tool_parameter_config = parameter.default
                    if not tool_parameter_config and parameter.required:
                        raise ValueError(f"tool parameter {parameter.name} not found in tool config")

                if parameter.type == ToolParameter.ToolParameterType.SELECT:
                    # check if tool_parameter_config in options
                    options = list(map(lambda x: x.value, parameter.options))
                    if tool_parameter_config not in options:
                        raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")

                # convert tool parameter config to correct type
                try:
                    if parameter.type == ToolParameter.ToolParameterType.NUMBER:
                        # check if tool parameter is integer
                        if isinstance(tool_parameter_config, int):
                            tool_parameter_config = tool_parameter_config
                        elif isinstance(tool_parameter_config, float):
                            tool_parameter_config = tool_parameter_config
                        elif isinstance(tool_parameter_config, str):
                            if '.' in tool_parameter_config:
                                tool_parameter_config = float(tool_parameter_config)
                            else:
                                tool_parameter_config = int(tool_parameter_config)
                    elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
                        tool_parameter_config = bool(tool_parameter_config)
                    elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
                        tool_parameter_config = str(tool_parameter_config)
                    elif parameter.type == ToolParameter.ToolParameterType:
                        tool_parameter_config = str(tool_parameter_config)
                except Exception as e:
                    raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")

                # save tool parameter to tool entity memory
                runtime_parameters[parameter.name] = tool_parameter_config

            elif parameter.form == ToolParameter.ToolParameterForm.LLM:
                message_tool.parameters['properties'][parameter.name] = {
                    "type": parameter_type,
                    "description": parameter.llm_description or '',
                }

                if len(enum) > 0:
                    message_tool.parameters['properties'][parameter.name]['enum'] = enum
            if len(enum) > 0:
                message_tool.parameters['properties'][parameter.name]['enum'] = enum

            if parameter.required:
                message_tool.parameters['required'].append(parameter.name)
            if parameter.required:
                message_tool.parameters['required'].append(parameter.name)

        tool_entity.runtime.runtime_parameters.update(runtime_parameters)

        return message_tool, tool_entity

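The NUMBER branch above accepts ints and floats as-is and decides between `int()` and `float()` for strings by checking for a decimal point. A condensed sketch of that rule (the helper name is hypothetical, not code from the diff):

```python
def coerce_number(value):
    """Hypothetical helper mirroring the NUMBER coercion branch above."""
    if isinstance(value, (int, float)):
        return value  # ints and floats pass through unchanged
    if isinstance(value, str):
        # '.' decides float vs int, as in the hunk above
        return float(value) if '.' in value else int(value)
    raise ValueError(f"{value!r} is not a number")
```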
@@ -241,9 +305,6 @@ class BaseAssistantApplicationRunner(AppRunner):
        tool_runtime_parameters = tool.get_runtime_parameters() or []

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = 'string'
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.STRING:
@@ -259,17 +320,18 @@ class BaseAssistantApplicationRunner(AppRunner):
            else:
                raise ValueError(f"parameter type {parameter.type} is not supported")

            prompt_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }
            if parameter.form == ToolParameter.ToolParameterForm.LLM:
                prompt_tool.parameters['properties'][parameter.name] = {
                    "type": parameter_type,
                    "description": parameter.llm_description or '',
                }

                if len(enum) > 0:
                    prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
            if len(enum) > 0:
                prompt_tool.parameters['properties'][parameter.name]['enum'] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters['required']:
                    prompt_tool.parameters['required'].append(parameter.name)
            if parameter.required:
                if parameter.name not in prompt_tool.parameters['required']:
                    prompt_tool.parameters['required'].append(parameter.name)

        return prompt_tool

@@ -342,16 +404,13 @@ class BaseAssistantApplicationRunner(AppRunner):
                created_by=self.user_id,
            )
            db.session.add(message_file)
            db.session.commit()
            db.session.refresh(message_file)

            result.append((
                message_file,
                message.save_as
            ))

        db.session.close()

        db.session.commit()

        return result

    def create_agent_thought(self, message_id: str, message: str,
@@ -388,8 +447,6 @@ class BaseAssistantApplicationRunner(AppRunner):

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

@@ -407,10 +464,6 @@ class BaseAssistantApplicationRunner(AppRunner):
        """
        Save agent thought
        """
        agent_thought = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.id == agent_thought.id
        ).first()

        if thought is not None:
            agent_thought.thought = thought

@@ -461,7 +514,6 @@ class BaseAssistantApplicationRunner(AppRunner):
            agent_thought.tool_labels_str = json.dumps(labels)

        db.session.commit()
        db.session.close()

    def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
        """
@@ -534,14 +586,9 @@ class BaseAssistantApplicationRunner(AppRunner):
        """
        convert tool variables to db variables
        """
        db_variables = db.session.query(ToolConversationVariables).filter(
            ToolConversationVariables.conversation_id == self.message.conversation_id,
        ).first()

        db_variables.updated_at = datetime.utcnow()
        db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
        db.session.close()

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
@@ -597,6 +644,4 @@ class BaseAssistantApplicationRunner(AppRunner):
            if message.answer:
                result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result
@@ -28,9 +28,6 @@ from models.model import Conversation, Message


class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
    _is_first_iteration = True
    _ignore_observation_providers = ['wenxin']

    def run(self, conversation: Conversation,
            message: Message,
            query: str,
@@ -45,8 +42,10 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
        agent_scratchpad: list[AgentScratchpadUnit] = []
        self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)

        if 'Observation' not in app_orchestration_config.model_config.stop:
            if app_orchestration_config.model_config.provider not in self._ignore_observation_providers:
        # check model mode
        if self.app_orchestration_config.model_config.mode == "completion":
            # TODO: stop words
            if 'Observation' not in app_orchestration_config.model_config.stop:
                app_orchestration_config.model_config.stop.append('Observation')

        # override inputs
@@ -203,7 +202,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                )
            )

            scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
            agent_scratchpad.append(scratchpad)

            # get llm usage
@@ -257,15 +255,9 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
            # invoke tool
            error_response = None
            try:
                if isinstance(tool_call_args, str):
                    try:
                        tool_call_args = json.loads(tool_call_args)
                    except json.JSONDecodeError:
                        pass

                tool_response = tool_instance.invoke(
                    user_id=self.user_id,
                    tool_parameters=tool_call_args
                    tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
                )
                # transform tool response to llm friendly response
                tool_response = self.transform_tool_invoke_messages(tool_response)
@@ -474,7 +466,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
            if isinstance(message, AssistantPromptMessage):
                current_scratchpad = AgentScratchpadUnit(
                    agent_response=message.content,
                    thought=message.content or 'I am thinking about how to help you',
                    thought=message.content,
                    action_str='',
                    action=None,
                    observation=None,
@@ -554,8 +546,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):

        result = ''
        for scratchpad in agent_scratchpad:
            result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \
                next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available')
            result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"

        return result

@@ -630,24 +621,21 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
            ))

            # add assistant message
            if len(agent_scratchpad) > 0 and not self._is_first_iteration:
            if len(agent_scratchpad) > 0:
                prompt_messages.append(AssistantPromptMessage(
                    content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''),
                    content=(agent_scratchpad[-1].thought or '')
                ))

            # add user message
            if len(agent_scratchpad) > 0 and not self._is_first_iteration:
            if len(agent_scratchpad) > 0:
                prompt_messages.append(UserPromptMessage(
                    content=(agent_scratchpad[-1].observation or 'It seems that no response is available'),
                    content=(agent_scratchpad[-1].observation or ''),
                ))

            self._is_first_iteration = False

            return prompt_messages
        elif mode == "completion":
            # parse agent scratchpad
            agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
            self._is_first_iteration = False
            # parse prompt messages
            return [UserPromptMessage(
                content=first_prompt.replace("{{instruction}}", instruction)

@@ -1,54 +0,0 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional

from extensions.ext_redis import redis_client


class ToolParameterCacheType(Enum):
    PARAMETER = "tool_parameter"

class ToolParameterCache:
    def __init__(self,
        tenant_id: str,
        provider: str,
        tool_name: str,
        cache_type: ToolParameterCacheType
    ):
        self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"

    def get(self) -> Optional[dict]:
        """
        Get cached model provider credentials.

        :return:
        """
        cached_tool_parameter = redis_client.get(self.cache_key)
        if cached_tool_parameter:
            try:
                cached_tool_parameter = cached_tool_parameter.decode('utf-8')
                cached_tool_parameter = json.loads(cached_tool_parameter)
            except JSONDecodeError:
                return None

            return cached_tool_parameter
        else:
            return None

    def set(self, parameters: dict) -> None:
        """
        Cache model provider credentials.

        :param credentials: provider credentials
        :return:
        """
        redis_client.setex(self.cache_key, 86400, json.dumps(parameters))

    def delete(self) -> None:
        """
        Delete cached model provider credentials.

        :return:
        """
        redis_client.delete(self.cache_key)
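For context, a minimal usage sketch of the `ToolParameterCache` class this hunk deletes (the identifiers come from the code above; the tenant, provider, and tool values are hypothetical):

```python
cache = ToolParameterCache(
    tenant_id="tenant-123",                      # hypothetical values
    provider="example_provider",
    tool_name="example_tool",
    cache_type=ToolParameterCacheType.PARAMETER,
)
cache.set({"api_key": "<encrypted>"})  # stored via setex with a 86400s (24h) TTL
params = cache.get()                   # dict on hit, None on miss or invalid JSON
cache.delete()                         # evict the entry
```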
@@ -82,8 +82,6 @@ class HostingConfiguration:
                RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM),
                RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM),
                RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING),
                RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING),
                RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING),
            ]
        )
        quotas.append(trial_quota)

@@ -62,8 +62,7 @@ class IndexingRunner:
                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

                # transform
                documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                                            processing_rule.to_dict())
                documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
                # save segment
                self._load_segments(dataset, dataset_document, documents)

@@ -121,8 +120,7 @@ class IndexingRunner:
                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

                # transform
                documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                                            processing_rule.to_dict())
                documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
                # save segment
                self._load_segments(dataset, dataset_document, documents)

@@ -188,7 +186,7 @@ class IndexingRunner:
                first()

            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()
            index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
            self._load(
                index_processor=index_processor,
                dataset=dataset,
@@ -416,14 +414,9 @@ class IndexingRunner:
            if separator:
                separator = separator.replace('\\n', '\n')

            if 'chunk_overlap' in segmentation and segmentation['chunk_overlap']:
                chunk_overlap = segmentation['chunk_overlap']
            else:
                chunk_overlap = 0

            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=chunk_overlap,
                chunk_overlap=segmentation.get('chunk_overlap', 0),
                fixed_separator=separator,
                separators=["\n\n", "。", ".", " ", ""],
                embedding_model_instance=embedding_model_instance
@@ -757,7 +750,7 @@ class IndexingRunner:
        index_processor.load(dataset, documents)

    def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
                   text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]:
                   text_docs: list[Document], process_rule: dict) -> list[Document]:
        # get embedding model instance
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
@@ -775,8 +768,7 @@ class IndexingRunner:
            )

        documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
                                              process_rule=process_rule, tenant_id=dataset.tenant_id,
                                              doc_language=doc_language)
                                              process_rule=process_rule)

        return documents

@@ -47,14 +47,11 @@ class TokenBufferMemory:
                files, message.app_model_config
            )

            if not file_objs:
                prompt_messages.append(UserPromptMessage(content=message.query))
            else:
                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
                for file_obj in file_objs:
                    prompt_message_contents.append(file_obj.prompt_message_content)
            prompt_message_contents = [TextPromptMessageContent(data=message.query)]
            for file_obj in file_objs:
                prompt_message_contents.append(file_obj.prompt_message_content)

                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=message.query))

@@ -17,7 +17,7 @@ class ModelType(Enum):
    SPEECH2TEXT = "speech2text"
    MODERATION = "moderation"
    TTS = "tts"
    TEXT2IMG = "text2img"
    # TEXT2IMG = "text2img"

    @classmethod
    def value_of(cls, origin_model_type: str) -> "ModelType":
@@ -36,8 +36,6 @@ class ModelType(Enum):
            return cls.SPEECH2TEXT
        elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
            return cls.TTS
        elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
            return cls.TEXT2IMG
        elif origin_model_type == cls.MODERATION.value:
            return cls.MODERATION
        else:
@@ -61,11 +59,10 @@ class ModelType(Enum):
            return 'tts'
        elif self == self.MODERATION:
            return 'moderation'
        elif self == self.TEXT2IMG:
            return 'text2img'
        else:
            raise ValueError(f'invalid model type {self}')


class FetchFrom(Enum):
    """
    Enum class for fetch from.

@@ -1,48 +0,0 @@
from abc import abstractmethod
from typing import IO, Optional

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel


class Text2ImageModel(AIModel):
    """
    Model class for text2img model.
    """
    model_type: ModelType = ModelType.TEXT2IMG

    def invoke(self, model: str, credentials: dict, prompt: str,
               model_parameters: dict, user: Optional[str] = None) \
            -> list[IO[bytes]]:
        """
        Invoke Text2Image model

        :param model: model name
        :param credentials: model credentials
        :param prompt: prompt for image generation
        :param model_parameters: model parameters
        :param user: unique user id

        :return: image bytes
        """
        try:
            return self._invoke(model, credentials, prompt, model_parameters, user)
        except Exception as e:
            raise self._transform_invoke_error(e)

    @abstractmethod
    def _invoke(self, model: str, credentials: dict, prompt: str,
                model_parameters: dict, user: Optional[str] = None) \
            -> list[IO[bytes]]:
        """
        Invoke Text2Image model

        :param model: model name
        :param credentials: model credentials
        :param prompt: prompt for image generation
        :param model_parameters: model parameters
        :param user: unique user id

        :return: image bytes
        """
        raise NotImplementedError
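Since this hunk deletes the entire `Text2ImageModel` base class, a minimal subclass sketch of the contract it defined may help (the `EchoText2ImageModel` provider is hypothetical, for illustration only):

```python
from io import BytesIO
from typing import IO, Optional


class EchoText2ImageModel(Text2ImageModel):
    """Hypothetical provider used only to illustrate the removed contract."""

    def _invoke(self, model: str, credentials: dict, prompt: str,
                model_parameters: dict, user: Optional[str] = None) -> list[IO[bytes]]:
        # a real provider would call its image-generation API here;
        # this stub just returns the prompt bytes as one "image" stream
        return [BytesIO(prompt.encode('utf-8'))]
```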
@@ -7,7 +7,6 @@
- togetherai
- ollama
- mistralai
- groq
- replicate
- huggingface_hub
- zhipuai

@@ -424,25 +424,8 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):

        if isinstance(message, UserPromptMessage):
            message_text = f"{human_prompt} {content}"
            if not isinstance(message.content, list):
                message_text = f"{ai_prompt} {content}"
            else:
                message_text = ""
                for sub_message in message.content:
                    if sub_message.type == PromptMessageContentType.TEXT:
                        message_text += f"{human_prompt} {sub_message.data}"
                    elif sub_message.type == PromptMessageContentType.IMAGE:
                        message_text += f"{human_prompt} [IMAGE]"
        elif isinstance(message, AssistantPromptMessage):
            if not isinstance(message.content, list):
                message_text = f"{ai_prompt} {content}"
            else:
                message_text = ""
                for sub_message in message.content:
                    if sub_message.type == PromptMessageContentType.TEXT:
                        message_text += f"{ai_prompt} {sub_message.data}"
                    elif sub_message.type == PromptMessageContentType.IMAGE:
                        message_text += f"{ai_prompt} [IMAGE]"
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemPromptMessage):
            message_text = content
        else:

@@ -524,172 +524,5 @@ EMBEDDING_BASE_MODELS = [
                currency='USD',
            )
        )
    ),
    AzureBaseModel(
        base_model_name='text-embedding-3-small',
        entity=AIModelEntity(
            model='fake-deployment-name',
            label=I18nObject(
                en_US='fake-deployment-name-label'
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: 8191,
                ModelPropertyKey.MAX_CHUNKS: 32,
            },
            pricing=PriceConfig(
                input=0.00002,
                unit=0.001,
                currency='USD',
            )
        )
    ),
    AzureBaseModel(
        base_model_name='text-embedding-3-large',
        entity=AIModelEntity(
            model='fake-deployment-name',
            label=I18nObject(
                en_US='fake-deployment-name-label'
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: 8191,
                ModelPropertyKey.MAX_CHUNKS: 32,
            },
            pricing=PriceConfig(
                input=0.00013,
                unit=0.001,
                currency='USD',
            )
        )
    )
]
SPEECH2TEXT_BASE_MODELS = [
    AzureBaseModel(
        base_model_name='whisper-1',
        entity=AIModelEntity(
            model='fake-deployment-name',
            label=I18nObject(
                en_US='fake-deployment-name-label'
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.SPEECH2TEXT,
            model_properties={
                ModelPropertyKey.FILE_UPLOAD_LIMIT: 25,
                ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm'
            }
        )
    )
]
TTS_BASE_MODELS = [
    AzureBaseModel(
        base_model_name='tts-1',
        entity=AIModelEntity(
            model='fake-deployment-name',
            label=I18nObject(
                en_US='fake-deployment-name-label'
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={
                ModelPropertyKey.DEFAULT_VOICE: 'alloy',
                ModelPropertyKey.VOICES: [
                    {
                        'mode': 'alloy',
                        'name': 'Alloy',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'echo',
                        'name': 'Echo',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'fable',
                        'name': 'Fable',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'onyx',
                        'name': 'Onyx',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'nova',
                        'name': 'Nova',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'shimmer',
                        'name': 'Shimmer',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                ],
                ModelPropertyKey.WORD_LIMIT: 120,
                ModelPropertyKey.AUDOI_TYPE: 'mp3',
                ModelPropertyKey.MAX_WORKERS: 5
            },
            pricing=PriceConfig(
                input=0.015,
                unit=0.001,
                currency='USD',
            )
        )
    ),
    AzureBaseModel(
        base_model_name='tts-1-hd',
        entity=AIModelEntity(
            model='fake-deployment-name',
            label=I18nObject(
                en_US='fake-deployment-name-label'
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={
                ModelPropertyKey.DEFAULT_VOICE: 'alloy',
                ModelPropertyKey.VOICES: [
                    {
                        'mode': 'alloy',
                        'name': 'Alloy',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'echo',
                        'name': 'Echo',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'fable',
                        'name': 'Fable',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'onyx',
                        'name': 'Onyx',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'nova',
                        'name': 'Nova',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                    {
                        'mode': 'shimmer',
                        'name': 'Shimmer',
                        'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
                    },
                ],
                ModelPropertyKey.WORD_LIMIT: 120,
                ModelPropertyKey.AUDOI_TYPE: 'mp3',
                ModelPropertyKey.MAX_WORKERS: 5
            },
            pricing=PriceConfig(
                input=0.03,
                unit=0.001,
                currency='USD',
            )
        )
    )
]

@@ -15,8 +15,6 @@ help:
supported_model_types:
  - llm
  - text-embedding
  - speech2text
  - tts
configurate_methods:
  - customizable-model
model_credential_schema:
@@ -101,36 +99,6 @@ model_credential_schema:
          show_on:
            - variable: __model_type
              value: text-embedding
        - label:
            en_US: text-embedding-3-small
          value: text-embedding-3-small
          show_on:
            - variable: __model_type
              value: text-embedding
        - label:
            en_US: text-embedding-3-large
          value: text-embedding-3-large
          show_on:
            - variable: __model_type
              value: text-embedding
        - label:
            en_US: whisper-1
          value: whisper-1
          show_on:
            - variable: __model_type
              value: speech2text
        - label:
            en_US: tts-1
          value: tts-1
          show_on:
            - variable: __model_type
              value: tts
        - label:
            en_US: tts-1-hd
          value: tts-1-hd
          show_on:
            - variable: __model_type
              value: tts
      placeholder:
        zh_Hans: 在此输入您的模型版本
        en_US: Enter your model version

@@ -1,82 +0,0 @@
import copy
from typing import IO, Optional

from openai import AzureOpenAI

from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel


class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
    """
    Model class for OpenAI Speech to text model.
    """

    def _invoke(self, model: str, credentials: dict,
                file: IO[bytes], user: Optional[str] = None) \
            -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        return self._speech2text_invoke(model, credentials, file)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            audio_file_path = self._get_demo_file_path()

            with open(audio_file_path, 'rb') as audio_file:
                self._speech2text_invoke(model, credentials, audio_file)
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :return: text for given audio file
        """
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)

        # init model client
        client = AzureOpenAI(**credentials_kwargs)

        response = client.audio.transcriptions.create(model=model, file=file)

        return response.text

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
        return ai_model_entity.entity


    @staticmethod
    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
        for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
            if ai_model_entity.base_model_name == base_model_name:
                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
                ai_model_entity_copy.entity.model = model
                ai_model_entity_copy.entity.label.en_US = model
                ai_model_entity_copy.entity.label.zh_Hans = model
                return ai_model_entity_copy

        return None
@@ -1,174 +0,0 @@
import concurrent.futures
import copy
from functools import reduce
from io import BytesIO
from typing import Optional

from flask import Response, stream_with_context
from openai import AzureOpenAI
from pydub import AudioSegment

from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
from extensions.ext_storage import storage


class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
    """
    Model class for OpenAI Speech to text model.
    """

    def _invoke(self, model: str, tenant_id: str, credentials: dict,
                content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
        """
        _invoke text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :param streaming: output is streaming
        :param user: unique user id
        :return: text translated to audio file
        """
        audio_type = self._get_model_audio_type(model, credentials)
        if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
            voice = self._get_model_default_voice(model, credentials)
        if streaming:
            return Response(stream_with_context(self._tts_invoke_streaming(model=model,
                                                                            credentials=credentials,
                                                                            content_text=content_text,
                                                                            tenant_id=tenant_id,
                                                                            voice=voice)),
                            status=200, mimetype=f'audio/{audio_type}')
        else:
            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)

    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
        """
        validate credentials text2speech model

        :param model: model name
        :param credentials: model credentials
        :param user: unique user id
        :return: text translated to audio file
        """
        try:
            self._tts_invoke(
                model=model,
                credentials=credentials,
                content_text='Hello Dify!',
                voice=self._get_model_default_voice(model, credentials),
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response:
        """
        _tts_invoke text2speech model

        :param model: model name
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :return: text translated to audio file
        """
        audio_type = self._get_model_audio_type(model, credentials)
        word_limit = self._get_model_word_limit(model, credentials)
        max_workers = self._get_model_workers_limit(model, credentials)
        try:
            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
            audio_bytes_list = list()

            # Create a thread pool and map the function to the list of sentences
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice,
                                           credentials=credentials) for sentence in sentences]
                for future in futures:
                    try:
                        if future.result():
                            audio_bytes_list.append(future.result())
                    except Exception as ex:
                        raise InvokeBadRequestError(str(ex))

            if len(audio_bytes_list) > 0:
                audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
                                  audio_bytes_list if audio_bytes]
                combined_segment = reduce(lambda x, y: x + y, audio_segments)
                buffer: BytesIO = BytesIO()
                combined_segment.export(buffer, format=audio_type)
                buffer.seek(0)
                return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

    # Todo: To improve the streaming function
    def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
                              voice: str) -> any:
        """
        _tts_invoke_streaming text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :return: text translated to audio file
        """
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)
        if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
            voice = self._get_model_default_voice(model, credentials)
        word_limit = self._get_model_word_limit(model, credentials)
        audio_type = self._get_model_audio_type(model, credentials)
        tts_file_id = self._get_file_name(content_text)
        file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
        try:
            client = AzureOpenAI(**credentials_kwargs)
            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
            for sentence in sentences:
                response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
                # response.stream_to_file(file_path)
                storage.save(file_path, response.read())
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

    def _process_sentence(self, sentence: str, model: str,
                          voice, credentials: dict):
        """
        _tts_invoke openai text2speech model api

        :param model: model name
        :param credentials: model credentials
        :param voice: model timbre
        :param sentence: text content to be translated
        :return: text translated to audio file
        """
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)
        client = AzureOpenAI(**credentials_kwargs)
        response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
        if isinstance(response.read(), bytes):
            return response.read()

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
        return ai_model_entity.entity


    @staticmethod
    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
        for ai_model_entity in TTS_BASE_MODELS:
            if ai_model_entity.base_model_name == base_model_name:
                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
                ai_model_entity_copy.entity.model = model
                ai_model_entity_copy.entity.label.en_US = model
                ai_model_entity_copy.entity.label.zh_Hans = model
                return ai_model_entity_copy

        return None
@@ -108,7 +108,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
        try:
            response = post(url, headers=headers, data=dumps(data))
        except Exception as e:
            raise InvokeConnectionError(str(e))
            raise InvokeConnectionError(e)

        if response.status_code != 200:
            try:

@@ -472,7 +472,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
        else:
            raise ValueError(f"Got unknown type {message}")

        if message.name:
        if message.name is not None:
            message_dict["user_name"] = message.name

        return message_dict

@ -1,11 +0,0 @@
<svg width="112" height="24" viewBox="0 0 112 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M57.4336 17.092C56.4746 16.5453 55.7216 15.7924 55.1749 14.8244C54.6283 13.8564 54.3594 12.763 54.3594 11.544C54.3594 10.3251 54.6283 9.2137 55.1749 8.24571C55.7216 7.27772 56.4746 6.52485 57.4336 5.98708C58.3926 5.4493 59.4861 5.18042 60.6961 5.18042C61.6999 5.18042 62.623 5.3776 63.4476 5.77197C64.2722 6.16633 64.9445 6.73995 65.4554 7.49284L64.568 8.13816C64.1199 7.51076 63.5642 7.04469 62.9009 6.731C62.2377 6.41729 61.5027 6.26492 60.705 6.26492C59.7281 6.26492 58.8498 6.48899 58.0789 6.92818C57.2992 7.36736 56.6986 7.98579 56.2505 8.79244C55.8113 9.59014 55.5872 10.5133 55.5872 11.553C55.5872 12.5926 55.8113 13.5159 56.2505 14.3136C56.6896 15.1112 57.2992 15.7297 58.0789 16.1778C58.8587 16.617 59.7281 16.8411 60.705 16.8411C61.5027 16.8411 62.2377 16.6888 62.9009 16.375C63.5642 16.0613 64.1199 15.5953 64.568 14.9678L65.4554 15.6132C64.9445 16.366 64.2722 16.9396 63.4476 17.334C62.623 17.7284 61.7089 17.9255 60.6961 17.9255C59.4771 17.9255 58.3926 17.6568 57.4336 17.11V17.092Z" fill="#F55036"/>
<path d="M67.2754 0H68.4763V17.8181H67.2754V0Z" fill="#F55036"/>
<path d="M73.6754 17.092C72.7254 16.5454 71.9725 15.7924 71.4347 14.8244C70.888 13.8564 70.6191 12.763 70.6191 11.544C70.6191 10.3251 70.888 9.23163 71.4347 8.26364C71.9814 7.29566 72.7254 6.54277 73.6754 5.99604C74.6255 5.4493 75.6921 5.18042 76.8841 5.18042C78.0762 5.18042 79.1338 5.4493 80.0928 5.99604C81.0429 6.54277 81.7957 7.29566 82.3335 8.26364C82.8803 9.23163 83.1492 10.3251 83.1492 11.544C83.1492 12.763 82.8803 13.8564 82.3335 14.8244C81.7868 15.7924 81.0429 16.5454 80.0928 17.092C79.1427 17.6387 78.0673 17.9076 76.8841 17.9076C75.7011 17.9076 74.6344 17.6387 73.6754 17.092ZM79.4655 16.1599C80.2273 15.7118 80.8277 15.0843 81.2669 14.2867C81.7062 13.489 81.9302 12.5747 81.9302 11.553C81.9302 10.5312 81.7062 9.61703 81.2669 8.81933C80.8277 8.02164 80.2273 7.39425 79.4655 6.9461C78.7036 6.49796 77.8431 6.27389 76.8841 6.27389C75.9251 6.27389 75.0646 6.49796 74.3028 6.9461C73.5409 7.39425 72.9405 8.02164 72.5013 8.81933C72.0621 9.61703 71.838 10.5312 71.838 11.553C71.838 12.5747 72.0621 13.489 72.5013 14.2867C72.9405 15.0843 73.5409 15.7118 74.3028 16.1599C75.0646 16.608 75.9251 16.8322 76.8841 16.8322C77.8431 16.8322 78.7036 16.608 79.4655 16.1599Z" fill="#F55036"/>
<path d="M96.2799 5.27905V17.8091H95.1237V15.1203C94.7114 15.9986 94.0929 16.6887 93.2774 17.1728C92.4618 17.6567 91.5027 17.9077 90.4003 17.9077C88.769 17.9077 87.4873 17.4506 86.5553 16.5364C85.6231 15.6222 85.166 14.3136 85.166 12.6017V5.27905H86.367V12.5031C86.367 13.9102 86.7255 14.9858 87.4515 15.7207C88.1775 16.4557 89.1903 16.8232 90.4989 16.8232C91.9061 16.8232 93.0264 16.384 93.851 15.5057C94.6756 14.6272 95.0878 13.4442 95.0878 11.9563V5.27905H96.2889H96.2799Z" fill="#F55036"/>
<path d="M110.952 0V17.8181H109.777V14.8604C109.284 15.8374 108.585 16.5902 107.689 17.119C106.793 17.6479 105.78 17.9077 104.642 17.9077C103.503 17.9077 102.419 17.6389 101.469 17.0922C100.528 16.5454 99.7838 15.7925 99.246 14.8336C98.7083 13.8745 98.4395 12.781 98.4395 11.5441C98.4395 10.3073 98.7083 9.2138 99.246 8.24582C99.7838 7.27783 100.519 6.52496 101.469 5.98718C102.41 5.44941 103.468 5.18053 104.642 5.18053C105.816 5.18053 106.766 5.44044 107.653 5.96925C108.541 6.49807 109.24 7.23301 109.75 8.17411V0H110.952ZM107.295 16.16C108.057 15.7119 108.657 15.0844 109.096 14.2868C109.535 13.4891 109.759 12.5749 109.759 11.5531C109.759 10.5313 109.535 9.61713 109.096 8.81944C108.657 8.02174 108.057 7.39434 107.295 6.9462C106.533 6.49807 105.672 6.27399 104.713 6.27399C103.754 6.27399 102.894 6.49807 102.132 6.9462C101.37 7.39434 100.77 8.02174 100.331 8.81944C99.8914 9.61713 99.6673 10.5313 99.6673 11.5531C99.6673 12.5749 99.8914 13.4891 100.331 14.2868C100.77 15.0844 101.37 15.7119 102.132 16.16C102.894 16.6081 103.754 16.8322 104.713 16.8322C105.672 16.8322 106.533 16.6081 107.295 16.16Z" fill="#F55036"/>
<path d="M30.6085 5.27024C27.077 5.27024 24.209 8.13835 24.209 11.6697C24.209 15.201 27.077 18.0692 30.6085 18.0692C34.1399 18.0692 37.0079 15.201 37.0079 11.6697C37.0079 8.13835 34.1399 5.27921 30.6085 5.27024ZM30.6085 15.6672C28.4036 15.6672 26.611 13.8746 26.611 11.6697C26.611 9.46486 28.4036 7.67228 30.6085 7.67228C32.8133 7.67228 34.6059 9.46486 34.6059 11.6697C34.6059 13.8746 32.8133 15.6672 30.6085 15.6672Z" fill="black"/>
<path d="M6.45358 5.23422C2.92222 5.19837 0.036187 8.0396 0.000335591 11.571C-0.0355158 15.1023 2.80571 17.9974 6.33706 18.0242C6.37292 18.0242 6.41773 18.0242 6.45358 18.0242H8.55986V15.6311H6.45358C4.24873 15.658 2.43823 13.8923 2.41134 11.6785C2.38445 9.47365 4.15014 7.66315 6.36395 7.63626C6.39084 7.63626 6.4267 7.63626 6.45358 7.63626C8.65844 7.63626 10.46 9.42884 10.46 11.6337V17.5222C10.46 19.7092 8.67637 21.4929 6.48943 21.5197C5.44078 21.5197 4.44591 21.0895 3.71095 20.3455L2.01698 22.0395C3.1911 23.2227 4.7865 23.8949 6.45358 23.9128H6.54321C10.0298 23.859 12.8351 21.0357 12.853 17.5491V11.4724C12.7635 8.00374 9.93116 5.23422 6.46254 5.23422H6.45358Z" fill="black"/>
<path d="M51.2406 11.5082C51.151 8.03961 48.3187 5.27009 44.8501 5.27009C41.3187 5.23423 38.4237 8.07545 38.3968 11.6068C38.361 15.1382 41.2022 18.0331 44.7335 18.0601C44.7694 18.0601 44.8143 18.0601 44.8501 18.0601H46.9563V15.667H44.8501C42.6452 15.6939 40.8347 13.9282 40.8078 11.7144C40.7809 9.5095 42.5467 7.69902 44.7604 7.67213C44.7874 7.67213 44.8232 7.67213 44.8501 7.67213C47.055 7.67213 48.8565 9.46469 48.8565 11.6696V23.626L51.2406 23.6528V11.5082Z" fill="black"/>
<path d="M14.6808 18.0602H17.0649V11.6607C17.0649 9.45589 18.8575 7.66332 21.0623 7.66332C21.7883 7.66332 22.4695 7.8605 23.0611 8.2011L24.2621 6.12172C23.3209 5.57498 22.2276 5.27024 21.0713 5.27024C17.5399 5.27024 14.6719 8.13835 14.6719 11.6697V18.0692L14.6808 18.0602Z" fill="black"/>
</svg>
Before Width: | Height: | Size: 5.8 KiB
@ -1,4 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="24" height="24" rx="12" fill="#F55036"/>
<path d="M12.146 6.00022C9.87734 5.97718 8.02325 7.80249 8.00022 10.0712C7.97718 12.3398 9.80249 14.1997 12.0712 14.217C12.0942 14.217 12.123 14.217 12.146 14.217H13.4992V12.6796H12.146C10.7295 12.6968 9.56641 11.5625 9.54913 10.1403C9.53186 8.72377 10.6662 7.56065 12.0884 7.54337C12.1057 7.54337 12.1287 7.54337 12.146 7.54337C13.5625 7.54337 14.7199 8.69498 14.7199 10.1115V13.8945C14.7199 15.2995 13.574 16.4453 12.169 16.4626C11.4953 16.4626 10.8562 16.1862 10.384 15.7083L9.29578 16.7965C10.0501 17.5566 11.075 17.9885 12.146 18H12.2036C14.4435 17.9654 16.2457 16.1516 16.2572 13.9117V10.0078C16.1997 7.77945 14.3801 6.00022 12.1518 6.00022H12.146Z" fill="white"/>
</svg>
Before Width: | Height: | Size: 828 B
@ -1,29 +0,0 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class GroqProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials
        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            model_instance.validate_credentials(
                model='llama2-70b-4096',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
@ -1,32 +0,0 @@
provider: groq
label:
  zh_Hans: GroqCloud
  en_US: GroqCloud
description:
  en_US: GroqCloud provides access to the Groq Cloud API, which hosts models like LLama2 and Mixtral.
  zh_Hans: GroqCloud 提供对 Groq Cloud API 的访问,其中托管了 LLama2 和 Mixtral 等模型。
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
background: "#F5F5F4"
help:
  title:
    en_US: Get your API Key from GroqCloud
    zh_Hans: 从 GroqCloud 获取 API Key
  url:
    en_US: https://console.groq.com/
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
@ -1,25 +0,0 @@
model: llama2-70b-4096
label:
  zh_Hans: Llama-2-70B-4096
  en_US: Llama-2-70B-4096
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 4096
pricing:
  input: '0.7'
  output: '0.8'
  unit: '0.000001'
  currency: USD
@ -1,26 +0,0 @@
from collections.abc import Generator
from typing import Optional, Union

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        credentials['mode'] = 'chat'
        credentials['endpoint_url'] = 'https://api.groq.com/openai/v1'
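The deleted Groq integration above is thin by design: it reuses the OpenAI-compatible model runner and only pins `mode` and `endpoint_url`. A rough sketch of the equivalent raw HTTP call (the payload shape follows the OpenAI chat-completions convention; the key is a placeholder):

```python
import requests

GROQ_API_KEY = 'gsk-...'  # placeholder

response = requests.post(
    'https://api.groq.com/openai/v1/chat/completions',
    headers={'Authorization': f'Bearer {GROQ_API_KEY}'},
    json={
        'model': 'llama2-70b-4096',
        'messages': [{'role': 'user', 'content': 'Hello'}],
        'max_tokens': 512,
    },
    timeout=30,
)
print(response.json()['choices'][0]['message']['content'])
```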
@ -1,25 +0,0 @@
model: mixtral-8x7b-32768
label:
  zh_Hans: Mixtral-8x7b-Instruct-v0.1
  en_US: Mixtral-8x7b-Instruct-v0.1
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 20480
pricing:
  input: '0.27'
  output: '0.27'
  unit: '0.000001'
  currency: USD
@ -1,32 +1,20 @@
from os.path import abspath, dirname, join
from threading import Lock

from transformers import AutoTokenizer


class JinaTokenizer:
    _tokenizer = None
    _lock = Lock()

    @classmethod
    def _get_tokenizer(cls):
        if cls._tokenizer is None:
            with cls._lock:
                if cls._tokenizer is None:
                    base_path = abspath(__file__)
                    gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
                    cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
        return cls._tokenizer

    @classmethod
    def _get_num_tokens_by_jina_base(cls, text: str) -> int:
    @staticmethod
    def _get_num_tokens_by_jina_base(text: str) -> int:
        """
        use jina tokenizer to get num tokens
        """
        tokenizer = cls._get_tokenizer()
        base_path = abspath(__file__)
        gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
        tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
        tokens = tokenizer.encode(text)
        return len(tokens)

    @classmethod
    def get_num_tokens(cls, text: str) -> int:
        return cls._get_num_tokens_by_jina_base(text)
    @staticmethod
    def get_num_tokens(text: str) -> int:
        return JinaTokenizer._get_num_tokens_by_jina_base(text)
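The version of `JinaTokenizer` being removed here loads the Hugging Face tokenizer once behind double-checked locking, while the replacement reloads it on every call. The caching pattern itself, as a minimal standalone sketch:

```python
from threading import Lock


class LazySingleton:
    _instance = None
    _lock = Lock()

    @classmethod
    def get(cls):
        # The first check skips the lock on the hot path; the second check,
        # taken under the lock, stops two racing threads from both loading.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = object()  # stand-in for an expensive load
        return cls._instance
```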
@ -57,7 +57,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
        try:
            response = post(url, headers=headers, data=dumps(data))
        except Exception as e:
            raise InvokeConnectionError(str(e))
            raise InvokeConnectionError(e)

        if response.status_code != 200:
            try:
@ -59,7 +59,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
        try:
            response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10)
        except Exception as e:
            raise InvokeConnectionError(str(e))
            raise InvokeConnectionError(e)

        if response.status_code != 200:
            try:
@ -65,7 +65,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
        try:
            response = post(url, headers=headers, data=dumps(data))
        except Exception as e:
            raise InvokeConnectionError(str(e))
            raise InvokeConnectionError(e)

        if response.status_code != 200:
            raise InvokeServerUnavailableError(response.text)

Before Width: | Height: | Size: 7.2 KiB After Width: | Height: | Size: 2.3 KiB
@ -24,7 +24,7 @@ parameter_rules:
    min: 1
    max: 8000
  - name: safe_prompt
    default: false
    defulat: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.

@ -24,7 +24,7 @@ parameter_rules:
    min: 1
    max: 8000
  - name: safe_prompt
    default: false
    defulat: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.

@ -24,7 +24,7 @@ parameter_rules:
    min: 1
    max: 8000
  - name: safe_prompt
    default: false
    defulat: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.

@ -24,7 +24,7 @@ parameter_rules:
    min: 1
    max: 2048
  - name: safe_prompt
    default: false
    defulat: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.

@ -24,7 +24,7 @@ parameter_rules:
    min: 1
    max: 8000
  - name: safe_prompt
    default: false
    defulat: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
@ -2,4 +2,4 @@ model: whisper-1
model_type: speech2text
model_properties:
  file_upload_limit: 25
  supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm
  supported_file_extensions: mp3,mp4,mpeg,mpga,m4a,wav,webm
@ -25,7 +25,6 @@ from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    DefaultParameterName,
    FetchFrom,
    ModelFeature,
    ModelPropertyKey,
    ModelType,
    ParameterRule,
@ -167,23 +166,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
        """
        generate custom model entities from credentials
        """
        support_function_call = False
        features = []
        function_calling_type = credentials.get('function_calling_type', 'no_call')
        if function_calling_type == 'function_call':
            features = [ModelFeature.TOOL_CALL]
            support_function_call = True
        endpoint_url = credentials["endpoint_url"]
        # if not endpoint_url.endswith('/'):
        #     endpoint_url += '/'
        # if 'https://api.openai.com/v1/' == endpoint_url:
        #     features = [ModelFeature.STREAM_TOOL_CALL]
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.LLM,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            features=features if support_function_call else [],
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
                ModelPropertyKey.MODE: credentials.get('mode'),
@ -207,6 +194,14 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                max=1,
                precision=2
            ),
            ParameterRule(
                name="top_k",
                label=I18nObject(en_US="Top K"),
                type=ParameterType.INT,
                default=int(credentials.get('top_k', 1)),
                min=1,
                max=100
            ),
            ParameterRule(
                name=DefaultParameterName.FREQUENCY_PENALTY.value,
                label=I18nObject(en_US="Frequency Penalty"),
@ -237,7 +232,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                output=Decimal(credentials.get('output_price', 0)),
                unit=Decimal(credentials.get('unit', 0)),
                currency=credentials.get('currency', "USD")
            ),
            )
        )

        if credentials['mode'] == 'chat':
@ -297,22 +292,14 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            raise ValueError("Unsupported completion type for model configuration.")

        # annotate tools with names, descriptions, etc.
        function_calling_type = credentials.get('function_calling_type', 'no_call')
        formatted_tools = []
        if tools:
            if function_calling_type == 'function_call':
                data['functions'] = [{
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters
                } for tool in tools]
            elif function_calling_type == 'tool_call':
                data["tool_choice"] = "auto"
            data["tool_choice"] = "auto"

                for tool in tools:
                    formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
            for tool in tools:
                formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))

                data["tools"] = formatted_tools
            data["tools"] = formatted_tools

        if stop:
            data["stop"] = stop
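The two branches above emit the two request shapes that OpenAI-compatible servers accept for tool use: the legacy `functions` field and the newer `tools`/`tool_choice` fields. Illustrative payloads (not taken from the diff) showing the difference:

```python
# 'function_call' branch: legacy function-calling format.
legacy_payload = {
    'model': 'gpt-3.5-turbo',
    'messages': [{'role': 'user', 'content': 'Weather in Paris?'}],
    'functions': [{
        'name': 'get_weather',
        'description': 'Look up the current weather',
        'parameters': {'type': 'object', 'properties': {'city': {'type': 'string'}}},
    }],
}

# 'tool_call' branch: newer tools format with explicit tool_choice.
tools_payload = {
    'model': 'gpt-3.5-turbo',
    'messages': [{'role': 'user', 'content': 'Weather in Paris?'}],
    'tool_choice': 'auto',
    'tools': [{
        'type': 'function',
        'function': {
            'name': 'get_weather',
            'description': 'Look up the current weather',
            'parameters': {'type': 'object', 'properties': {'city': {'type': 'string'}}},
        },
    }],
}
```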
@ -380,9 +367,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

        for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
            if chunk:
                # ignore sse comments
                #ignore sse comments
                if chunk.startswith(':'):
                    continue
                    continue
                decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
                chunk_json = None
                try:
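The loop above is a hand-rolled server-sent-events reader: lines starting with `:` are SSE comments and `data:` lines carry the JSON payload. A standalone sketch of that per-line handling (assuming one SSE line per chunk):

```python
import json

def parse_sse_line(line: str):
    # SSE comments begin with ':' and carry no payload.
    if not line or line.startswith(':'):
        return None
    payload = line.strip().removeprefix('data:').strip()
    if payload == '[DONE]':  # OpenAI-style end-of-stream marker
        return None
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        return None

print(parse_sse_line('data: {"choices": []}'))  # {'choices': []}
print(parse_sse_line(': keep-alive'))           # None
```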
@ -465,13 +452,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

        response_content = ''
        tool_calls = None
        function_calling_type = credentials.get('function_calling_type', 'no_call')

        if completion_type is LLMMode.CHAT:
            response_content = output.get('message', {})['content']
            if function_calling_type == 'tool_call':
                tool_calls = output.get('message', {}).get('tool_calls')
            elif function_calling_type == 'function_call':
                tool_calls = output.get('message', {}).get('function_call')
            tool_calls = output.get('message', {}).get('tool_calls')

        elif completion_type is LLMMode.COMPLETION:
            response_content = output['text']
@ -479,10 +463,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
        assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])

        if tool_calls:
            if function_calling_type == 'tool_call':
                assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
            elif function_calling_type == 'function_call':
                assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
            assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)

        usage = response_json.get("usage")
        if usage:
@ -541,34 +522,33 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls:
                # message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
                #                               in
                #                               message.tool_calls]

                function_call = message.tool_calls[0]
                message_dict["function_call"] = {
                    "name": function_call.function.name,
                    "arguments": function_call.function.arguments,
                }
                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
                                              in
                                              message.tool_calls]
                # function_call = message.tool_calls[0]
                # message_dict["function_call"] = {
                #     "name": function_call.function.name,
                #     "arguments": function_call.function.arguments,
                # }
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ToolPromptMessage):
            message = cast(ToolPromptMessage, message)
            # message_dict = {
            #     "role": "tool",
            #     "content": message.content,
            #     "tool_call_id": message.tool_call_id
            # }
            message_dict = {
                "role": "function",
                "role": "tool",
                "content": message.content,
                "name": message.tool_call_id
                "tool_call_id": message.tool_call_id
            }
            # message_dict = {
            #     "role": "function",
            #     "content": message.content,
            #     "name": message.tool_call_id
            # }
        else:
            raise ValueError(f"Got unknown type {message}")

        if message.name:
        if message.name is not None:
            message_dict["name"] = message.name

        return message_dict
@ -713,26 +693,3 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            tool_calls.append(tool_call)

        return tool_calls

    def _extract_response_function_call(self, response_function_call) \
            -> AssistantPromptMessage.ToolCall:
        """
        Extract function call from response

        :param response_function_call: response function call
        :return: tool call
        """
        tool_call = None
        if response_function_call:
            function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                name=response_function_call['name'],
                arguments=response_function_call['arguments']
            )

            tool_call = AssistantPromptMessage.ToolCall(
                id=response_function_call['name'],
                type="function",
                function=function
            )

        return tool_call
@ -75,28 +75,6 @@ model_credential_schema:
        value: llm
    default: '4096'
    type: text-input
  - variable: function_calling_type
    show_on:
      - variable: __model_type
        value: llm
    label:
      en_US: Function calling
    type: select
    required: false
    default: no_call
    options:
      - value: function_call
        label:
          en_US: Support
          zh_Hans: 支持
      # - value: tool_call
      #   label:
      #     en_US: Tool Call
      #     zh_Hans: Tool Call
      - value: no_call
        label:
          en_US: Not Support
          zh_Hans: 不支持
  - variable: stream_mode_delimiter
    label:
      zh_Hans: 流模式返回结果的分隔符
@ -53,7 +53,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
            # cloud not connect to the server
            raise InvokeAuthorizationError(f"Invalid server URL: {e}")
        except Exception as e:
            raise InvokeConnectionError(str(e))
            raise InvokeConnectionError(e)

        if response.status_code != 200:
            if response.status_code == 400:
@ -308,7 +308,6 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                type=ParameterType.INT,
                use_template='max_tokens',
                min=1,
                max=credentials.get('context_length', 2048),
                default=512,
                label=I18nObject(
                    zh_Hans='最大生成长度',
@ -44,9 +44,6 @@ class XinferenceRerankModel(RerankModel):
                docs=[]
            )

        if credentials['server_url'].endswith('/'):
            credentials['server_url'] = credentials['server_url'][:-1]

        # initialize client
        client = Client(
            base_url=credentials['server_url']
@ -1,10 +1,10 @@
from os import path
from threading import Lock
from time import time

from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session
from yarl import URL


class XinferenceModelExtraParameter:
@ -55,10 +55,7 @@ class XinferenceHelper:
        get xinference model extra parameter like model_format and model_handle_type
        """

        if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
            raise RuntimeError('model_uid is empty')

        url = str(URL(server_url) / 'v1' / 'models' / model_uid)
        url = path.join(server_url, 'v1/models', model_uid)

        # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
        session = Session()
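The URL construction change above trades `yarl` for `os.path.join`. The two are not equivalent: `yarl` normalizes slashes regardless of how `server_url` ends, while `os.path.join` is a filesystem-path API that joins with the OS separator (backslashes on Windows). A quick comparison:

```python
from os import path

from yarl import URL

server_url = 'http://127.0.0.1:9997'

# yarl's '/' operator always produces exactly one slash between segments.
print(str(URL(server_url) / 'v1' / 'models' / 'model-uid'))
# http://127.0.0.1:9997/v1/models/model-uid

# os.path.join happens to match on POSIX, but uses the OS path separator.
print(path.join(server_url, 'v1/models', 'model-uid'))
# http://127.0.0.1:9997/v1/models/model-uid  (on POSIX)
```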
@ -69,6 +66,7 @@ class XinferenceHelper:
            response = session.get(url, timeout=10)
        except (MissingSchema, ConnectionError, Timeout) as e:
            raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')

        if response.status_code != 200:
            raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
@ -3,7 +3,6 @@ import csv
from typing import Optional

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document
@ -37,7 +36,7 @@ class CSVExtractor(BaseExtractor):
                docs = self._read_from_file(csvfile)
        except UnicodeDecodeError as e:
            if self._autodetect_encoding:
                detected_encodings = detect_file_encodings(self._file_path)
                detected_encodings = detect_filze_encodings(self._file_path)
                for encoding in detected_encodings:
                    try:
                        with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
@ -10,7 +10,7 @@ from core.rag.models.document import Document


class WordExtractor(BaseExtractor):
    """Load docx files.
    """Load pdf files.


    Args:
@ -46,16 +46,14 @@ class WordExtractor(BaseExtractor):

    def extract(self) -> list[Document]:
        """Load given path as single page."""
        from docx import Document as docx_Document
        import docx2txt

        document = docx_Document(self.file_path)
        doc_texts = [paragraph.text for paragraph in document.paragraphs]
        content = '\n'.join(doc_texts)

        return [Document(
            page_content=content,
            metadata={"source": self.file_path},
        )]
        return [
            Document(
                page_content=docx2txt.process(self.file_path),
                metadata={"source": self.file_path},
            )
        ]

    @staticmethod
    def _is_valid_url(url: str) -> bool:
@ -52,7 +52,7 @@ class BaseIndexProcessor(ABC):

            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=segmentation.get('chunk_overlap', 0),
                chunk_overlap=0,
                fixed_separator=separator,
                separators=["\n\n", "。", ".", " ", ""],
                embedding_model_instance=embedding_model_instance
@ -61,7 +61,7 @@ class BaseIndexProcessor(ABC):
            # Automatic segmentation
            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
                chunk_overlap=0,
                separators=["\n\n", "。", ".", " ", ""],
                embedding_model_instance=embedding_model_instance
            )
@ -7,6 +7,7 @@ from typing import Optional

import pandas as pd
from flask import Flask, current_app
from flask_login import current_user
from werkzeug.datastructures import FileStorage

from core.generator.llm_generator import LLMGenerator
@ -30,7 +31,7 @@ class QAIndexProcessor(BaseIndexProcessor):

    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
        splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
                                      embedding_model_instance=kwargs.get('embedding_model_instance'))
                                      embedding_model_instance=None)

        # Split the text documents into nodes.
        all_documents = []
@ -65,10 +66,10 @@ class QAIndexProcessor(BaseIndexProcessor):
        for doc in sub_documents:
            document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={
                'flask_app': current_app._get_current_object(),
                'tenant_id': kwargs.get('tenant_id'),
                'tenant_id': current_user.current_tenant.id,
                'document_node': doc,
                'all_qa_documents': all_qa_documents,
                'document_language': kwargs.get('doc_language', 'English')})
                'document_language': kwargs.get('document_language', 'English')})
            threads.append(document_format_thread)
            document_format_thread.start()
        for thread in threads:
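The `flask_app` kwarg above is needed because threads don't inherit Flask's application context; the thread target has to re-enter it before touching `current_app` or the database. The usual pattern, sketched with illustrative names:

```python
import threading

from flask import Flask, current_app

def _worker(flask_app: Flask, payload: dict):
    # Without this, context-bound objects inside the thread fail with
    # 'Working outside of application context'.
    with flask_app.app_context():
        print(payload)

thread = threading.Thread(
    target=_worker,
    kwargs={'flask_app': current_app._get_current_object(), 'payload': {}},
)
thread.start()
```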
@ -30,7 +30,7 @@ def _split_text_with_regex(
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({re.escape(separator)})", text)
            _splits = re.split(f"({separator})", text)
            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
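The `re.escape` line above matters whenever the configured separator contains regex metacharacters; without it the separator is compiled as a pattern rather than matched literally. A standalone demonstration (not Dify code):

```python
import re

text = 'a.b.c'
separator = '.'  # '.' is a regex metacharacter

# Unescaped, '.' matches any character, so the text splits everywhere.
print(re.split(f'({separator})', text))
# ['', 'a', '', '.', '', 'b', '', '.', '', 'c', '']

# Escaped, the separator is matched literally.
print(re.split(f'({re.escape(separator)})', text))
# ['a', '.', 'b', '.', 'c']
```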
@ -94,7 +94,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
            documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document] ) -> list[Document]:
    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
@ -119,7 +119,7 @@ parameters: # Parameter list
- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc.
- `parameters` Parameter list
  - `name` Parameter name, unique, no duplication with other parameters
  - `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` five types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using the `secret-input` type
  - `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box
  - `required` Required or not
    - In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
    - In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
@ -119,7 +119,7 @@ parameters: # 参数列表
- `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等
- `parameters` 参数列表
  - `name` 参数名称,唯一,不允许和其他参数重名
  - `type` 参数类型,目前支持`string`、`number`、`boolean`、`select`、`secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型
  - `type` 参数类型,目前支持`string`、`number`、`boolean`、`select` 四种类型,分别对应字符串、数字、布尔值、下拉框
  - `required` 是否必填
    - 在`llm`模式下,如果参数为必填,则会要求Agent必须要推理出这个参数
    - 在`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数
@ -8,19 +8,15 @@ class I18nObject(BaseModel):
    Model class for i18n object.
    """
    zh_Hans: Optional[str] = None
    pt_BR: Optional[str] = None
    en_US: str

    def __init__(self, **data):
        super().__init__(**data)
        if not self.zh_Hans:
            self.zh_Hans = self.en_US
        if not self.pt_BR:
            self.pt_BR = self.en_US

    def to_dict(self) -> dict:
        return {
            'zh_Hans': self.zh_Hans,
            'en_US': self.en_US,
            'pt_BR': self.pt_BR
        }
        }
@ -100,7 +100,6 @@ class ToolParameter(BaseModel):
        NUMBER = "number"
        BOOLEAN = "boolean"
        SELECT = "select"
        SECRET_INPUT = "secret-input"

    class ToolParameterForm(Enum):
        SCHEMA = "schema"  # should be set while adding tool
@ -305,24 +304,4 @@ class ToolRuntimeVariablePool(BaseModel):
            value=value,
        )

        self.pool.append(variable)

class ModelToolPropertyKey(Enum):
    IMAGE_PARAMETER_NAME = "image_parameter_name"

class ModelToolConfiguration(BaseModel):
    """
    Model tool configuration
    """
    type: str = Field(..., description="The type of the model tool")
    model: str = Field(..., description="The model")
    label: I18nObject = Field(..., description="The label of the model tool")
    properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool")

class ModelToolProviderConfiguration(BaseModel):
    """
    Model tool provider configuration
    """
    provider: str = Field(..., description="The provider of the model tool")
    models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
    label: I18nObject = Field(..., description="The label of the model tool")
        self.pool.append(variable)
@ -13,7 +13,6 @@ class UserToolProvider(BaseModel):
        BUILTIN = "builtin"
        APP = "app"
        API = "api"
        MODEL = "model"

    id: str
    author: str
@ -1,20 +0,0 @@
provider: anthropic
label:
  en_US: Anthropic Model Tools
  zh_Hans: Anthropic 模型能力
  pt_BR: Anthropic Model Tools
models:
  - type: llm
    model: claude-3-sonnet-20240229
    label:
      zh_Hans: Claude3 Sonnet 视觉
      en_US: Claude3 Sonnet Vision
    properties:
      image_parameter_name: image_id
  - type: llm
    model: claude-3-opus-20240229
    label:
      zh_Hans: Claude3 Opus 视觉
      en_US: Claude3 Opus Vision
    properties:
      image_parameter_name: image_id
@ -1,13 +0,0 @@
provider: google
label:
  en_US: Google Model Tools
  zh_Hans: Google 模型能力
  pt_BR: Google Model Tools
models:
  - type: llm
    model: gemini-pro-vision
    label:
      zh_Hans: Gemini Pro 视觉
      en_US: Gemini Pro Vision
    properties:
      image_parameter_name: image_id
@ -1,13 +0,0 @@
provider: openai
label:
  en_US: OpenAI Model Tools
  zh_Hans: OpenAI 模型能力
  pt_BR: OpenAI Model Tools
models:
  - type: llm
    model: gpt-4-vision-preview
    label:
      zh_Hans: GPT-4 视觉
      en_US: GPT-4 Vision
    properties:
      image_parameter_name: image_id
@ -1,13 +0,0 @@
provider: zhipuai
label:
  en_US: ZhipuAI Model Tools
  zh_Hans: ZhipuAI 模型能力
  pt_BR: ZhipuAI Model Tools
models:
  - type: llm
    model: glm-4v
    label:
      zh_Hans: GLM-4 视觉
      en_US: GLM-4 Vision
    properties:
      image_parameter_name: image_id
@ -1,19 +1,14 @@
- google
- bing
- duckduckgo
- dalle
- azuredalle
- wikipedia
- model.openai
- model.google
- model.anthropic
- yahoo
- wikipedia
- arxiv
- pubmed
- dalle
- azuredalle
- stablediffusion
- webscraper
- model.zhipuai
- aippt
- youtube
- wolframalpha
- maths
@ -23,5 +18,3 @@
- vectorizer
- gaode
- wecom
- qrcode
- dingtalk
@ -4,24 +4,24 @@ from yaml import FullLoader, load

from core.tools.entities.user_entities import UserToolProvider

position = {}

class BuiltinToolProviderSort:
    _position = {}

    @classmethod
    def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
        if not cls._position:
    @staticmethod
    def sort(providers: list[UserToolProvider]) -> list[UserToolProvider]:
        global position
        if not position:
            tmp_position = {}
            file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
            with open(file_path) as f:
                for pos, val in enumerate(load(f, Loader=FullLoader)):
                    tmp_position[val] = pos
            cls._position = tmp_position
            position = tmp_position

        def sort_compare(provider: UserToolProvider) -> int:
            if provider.type == UserToolProvider.ProviderType.MODEL:
                return cls._position.get(f'model.{provider.name}', 10000)
            return cls._position.get(provider.name, 10000)
            # if provider.type == UserToolProvider.ProviderType.MODEL:
            #     return position.get(f'model_provider.{provider.name}', 10000)
            return position.get(provider.name, 10000)

        sorted_providers = sorted(providers, key=sort_compare)

Before Width: | Height: | Size: 1.9 KiB
@ -1,11 +0,0 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class AIPPTProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        try:
            AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__')
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
@ -1,42 +0,0 @@
identity:
  author: Dify
  name: aippt
  label:
    en_US: AIPPT
    zh_Hans: AIPPT
  description:
    en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
    zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
  icon: icon.png
credentials_for_provider:
  aippt_access_key:
    type: secret-input
    required: true
    label:
      en_US: AIPPT API key
      zh_Hans: AIPPT API key
      pt_BR: AIPPT API key
    help:
      en_US: Please input your AIPPT API key
      zh_Hans: 请输入你的 AIPPT API key
      pt_BR: Please input your AIPPT API key
    placeholder:
      en_US: Please input your AIPPT API key
      zh_Hans: 请输入你的 AIPPT API key
      pt_BR: Please input your AIPPT API key
    url: https://www.aippt.cn
  aippt_secret_key:
    type: secret-input
    required: true
    label:
      en_US: AIPPT Secret key
      zh_Hans: AIPPT Secret key
      pt_BR: AIPPT Secret key
    help:
      en_US: Please input your AIPPT Secret key
      zh_Hans: 请输入你的 AIPPT Secret key
      pt_BR: Please input your AIPPT Secret key
    placeholder:
      en_US: Please input your AIPPT Secret key
      zh_Hans: 请输入你的 AIPPT Secret key
      pt_BR: Please input your AIPPT Secret key
@ -1,541 +0,0 @@
from base64 import b64encode
from hashlib import sha1
from hmac import new as hmac_new
from json import loads as json_loads
from threading import Lock
from time import sleep, time
from typing import Any

from httpx import get, post
from requests import get as requests_get
from yarl import URL

from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from core.tools.tool.builtin_tool import BuiltinTool


class AIPPTGenerateTool(BuiltinTool):
    """
    A tool for generating a ppt
    """

    _api_base_url = URL('https://co.aippt.cn/api')
    _api_token_cache = {}
    _api_token_cache_lock = Lock()
    _style_cache = {}
    _style_cache_lock = Lock()

    _task = {}
    _task_type_map = {
        'auto': 1,
        'markdown': 7,
    }

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        Invokes the AIPPT generate tool with the given user ID and tool parameters.

        Args:
            user_id (str): The ID of the user invoking the tool.
            tool_parameters (dict[str, Any]): The parameters for the tool

        Returns:
            ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages.
        """
        title = tool_parameters.get('title', '')
        if not title:
            return self.create_text_message('Please provide a title for the ppt')

        model = tool_parameters.get('model', 'aippt')
        if not model:
            return self.create_text_message('Please provide a model for the ppt')

        outline = tool_parameters.get('outline', '')

        # create task
        task_id = self._create_task(
            type=self._task_type_map['auto' if not outline else 'markdown'],
            title=title,
            content=outline,
            user_id=user_id
        )

        # get suit
        color = tool_parameters.get('color')
        style = tool_parameters.get('style')

        if color == '__default__':
            color_id = ''
        else:
            color_id = int(color.split('-')[1])

        if style == '__default__':
            style_id = ''
        else:
            style_id = int(style.split('-')[1])

        suit_id = self._get_suit(style_id=style_id, colour_id=color_id)

        # generate outline
        if not outline:
            self._generate_outline(
                task_id=task_id,
                model=model,
                user_id=user_id
            )

        # generate content
        self._generate_content(
            task_id=task_id,
            model=model,
            user_id=user_id
        )

        # generate ppt
        _, ppt_url = self._generate_ppt(
            task_id=task_id,
            suit_id=suit_id,
            user_id=user_id
        )

        return self.create_text_message('''the ppt has been created successfully,'''
                                        f'''the ppt url is {ppt_url}'''
                                        '''please give the ppt url to user and direct user to download it.''')

    def _create_task(self, type: int, title: str, content: str, user_id: str) -> str:
        """
        Create a task

        :param type: the task type
        :param title: the task title
        :param content: the task content

        :return: the task ID
        """
        headers = {
            'x-channel': '',
            'x-api-key': self.runtime.credentials['aippt_access_key'],
            'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
        }
        response = post(
            str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'),
            headers=headers,
            files={
                'type': ('', str(type)),
                'title': ('', title),
                'content': ('', content)
            }
        )

        if response.status_code != 200:
            raise Exception(f'Failed to connect to aippt: {response.text}')

        response = response.json()
        if response.get('code') != 0:
            raise Exception(f'Failed to create task: {response.get("msg")}')

        return response.get('data', {}).get('id')

    def _generate_outline(self, task_id: str, model: str, user_id: str) -> str:
        api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \
            self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline'
        api_url %= {'task_id': task_id}

        headers = {
            'x-channel': '',
            'x-api-key': self.runtime.credentials['aippt_access_key'],
            'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
        }

        response = requests_get(
            url=api_url,
            headers=headers,
            stream=True,
            timeout=(10, 60)
        )

        if response.status_code != 200:
            raise Exception(f'Failed to connect to aippt: {response.text}')

        outline = ''
        for chunk in response.iter_lines(delimiter=b'\n\n'):
            if not chunk:
                continue

            event = ''
            lines = chunk.decode('utf-8').split('\n')
            for line in lines:
                if line.startswith('event:'):
                    event = line[6:]
                elif line.startswith('data:'):
                    data = line[5:]
                    if event == 'message':
                        try:
                            data = json_loads(data)
                            outline += data.get('content', '')
                        except Exception as e:
                            pass
                    elif event == 'close':
                        break
                    elif event == 'error' or event == 'filter':
                        raise Exception(f'Failed to generate outline: {data}')

        return outline

    def _generate_content(self, task_id: str, model: str, user_id: str) -> str:
        api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \
            self._api_base_url / 'ai' / 'chat' / 'wx' / 'content'
        api_url %= {'task_id': task_id}

        headers = {
            'x-channel': '',
            'x-api-key': self.runtime.credentials['aippt_access_key'],
            'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
        }

        response = requests_get(
            url=api_url,
            headers=headers,
            stream=True,
            timeout=(10, 60)
        )

        if response.status_code != 200:
            raise Exception(f'Failed to connect to aippt: {response.text}')

        if model == 'aippt':
            content = ''
            for chunk in response.iter_lines(delimiter=b'\n\n'):
                if not chunk:
                    continue

                event = ''
                lines = chunk.decode('utf-8').split('\n')
                for line in lines:
                    if line.startswith('event:'):
                        event = line[6:]
                    elif line.startswith('data:'):
                        data = line[5:]
                        if event == 'message':
                            try:
                                data = json_loads(data)
                                content += data.get('content', '')
                            except Exception as e:
                                pass
                        elif event == 'close':
                            break
                        elif event == 'error' or event == 'filter':
                            raise Exception(f'Failed to generate content: {data}')

            return content
        elif model == 'wenxin':
            response = response.json()
            if response.get('code') != 0:
                raise Exception(f'Failed to generate content: {response.get("msg")}')

            return response.get('data', '')

        return ''

    def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
        """
        Generate a ppt

        :param task_id: the task ID
        :param suit_id: the suit ID
        :return: the cover url of the ppt and the ppt url
        """
        headers = {
            'x-channel': '',
            'x-api-key': self.runtime.credentials['aippt_access_key'],
            'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
        }

        response = post(
            str(self._api_base_url / 'design' / 'v2' / 'save'),
            headers=headers,
            data={
                'task_id': task_id,
                'template_id': suit_id
            }
        )

        if response.status_code != 200:
            raise Exception(f'Failed to connect to aippt: {response.text}')

        response = response.json()
        if response.get('code') != 0:
            raise Exception(f'Failed to generate ppt: {response.get("msg")}')

        id = response.get('data', {}).get('id')
        cover_url = response.get('data', {}).get('cover_url')

        response = post(
            str(self._api_base_url / 'download' / 'export' / 'file'),
            headers=headers,
            data={
                'id': id,
                'format': 'ppt',
                'files_to_zip': False,
                'edit': True
            }
        )

        if response.status_code != 200:
            raise Exception(f'Failed to connect to aippt: {response.text}')

        response = response.json()
        if response.get('code') != 0:
            raise Exception(f'Failed to generate ppt: {response.get("msg")}')

        export_code = response.get('data')
        if not export_code:
            raise Exception('Failed to generate ppt, the export code is empty')

        current_iteration = 0
        while current_iteration < 50:
            # get ppt url
            response = post(
                str(self._api_base_url / 'download' / 'export' / 'file' / 'result'),
                headers=headers,
                data={
                    'task_key': export_code
                }
            )

            if response.status_code != 200:
                raise Exception(f'Failed to connect to aippt: {response.text}')

            response = response.json()
            if response.get('code') != 0:
                raise Exception(f'Failed to generate ppt: {response.get("msg")}')

            if response.get('msg') == '导出中':
                current_iteration += 1
                sleep(2)
                continue

            ppt_url = response.get('data', [])
            if len(ppt_url) == 0:
                raise Exception('Failed to generate ppt, the ppt url is empty')

            return cover_url, ppt_url[0]

        raise Exception('Failed to generate ppt, the export is timeout')

    @classmethod
    def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
        """
        Get API token

        :param credentials: the credentials
        :return: the API token
        """
        access_key = credentials['aippt_access_key']
        secret_key = credentials['aippt_secret_key']

        cache_key = f'{access_key}#@#{user_id}'

        with cls._api_token_cache_lock:
            # clear expired tokens
            now = time()
            for key in list(cls._api_token_cache.keys()):
                if cls._api_token_cache[key]['expire'] < now:
                    del cls._api_token_cache[key]

            if cache_key in cls._api_token_cache:
                return cls._api_token_cache[cache_key]['token']

        # get token
        headers = {
            'x-api-key': access_key,
            'x-timestamp': str(int(now)),
            'x-signature': cls._calculate_sign(access_key, secret_key, int(now))
        }

        param = {
            'uid': user_id,
            'channel': ''
        }

        response = get(
            str(cls._api_base_url / 'grant' / 'token'),
            params=param,
            headers=headers
        )

        if response.status_code != 200:
            raise Exception(f'Failed to connect to aippt: {response.text}')
        response = response.json()
        if response.get('code') != 0:
            raise Exception(f'Failed to connect to aippt: {response.get("msg")}')

        token = response.get('data', {}).get('token')
        expire = response.get('data', {}).get('time_expire')

        with cls._api_token_cache_lock:
            cls._api_token_cache[cache_key] = {
                'token': token,
                'expire': now + expire
            }

        return token

    @classmethod
    def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
        return b64encode(
            hmac_new(
                key=secret_key.encode('utf-8'),
                msg=f'GET@/api/grant/token/@{timestamp}'.encode(),
                digestmod=sha1
            ).digest()
        ).decode('utf-8')
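The `x-signature` header computed by `_calculate_sign` above is an HMAC-SHA1 over the fixed string `GET@/api/grant/token/@{timestamp}`, base64-encoded. A standalone check of the same computation (the secret is a placeholder):

```python
from base64 import b64encode
from hashlib import sha1
from hmac import new as hmac_new

def calculate_sign(secret_key: str, timestamp: int) -> str:
    return b64encode(
        hmac_new(
            key=secret_key.encode('utf-8'),
            msg=f'GET@/api/grant/token/@{timestamp}'.encode(),
            digestmod=sha1,
        ).digest()
    ).decode('utf-8')

print(calculate_sign('demo-secret', 1700000000))
```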
@classmethod
|
||||
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
"""
|
||||
Get styles
|
||||
"""
|
||||
|
||||
# check cache
|
||||
with cls._style_cache_lock:
|
||||
# clear expired styles
|
||||
now = time()
|
||||
for key in list(cls._style_cache.keys()):
|
||||
if cls._style_cache[key]['expire'] < now:
|
||||
del cls._style_cache[key]
|
||||
|
||||
        key = f'{credentials["aippt_access_key"]}#@#{user_id}'
        if key in cls._style_cache:
            return cls._style_cache[key]['colors'], cls._style_cache[key]['styles']

        headers = {
            'x-channel': '',
            'x-api-key': credentials['aippt_access_key'],
            'x-token': cls._get_api_token(credentials=credentials, user_id=user_id)
        }
        response = get(
            str(cls._api_base_url / 'template_component' / 'suit' / 'select'),
            headers=headers
        )

        if response.status_code != 200:
            raise Exception(f'Failed to connect to aippt: {response.text}')

        response = response.json()

        if response.get('code') != 0:
            raise Exception(f'Failed to connect to aippt: {response.get("msg")}')

        colors = [{
            'id': f'id-{item.get("id")}',
            'name': item.get('name'),
            'en_name': item.get('en_name', item.get('name')),
        } for item in response.get('data', {}).get('colour') or []]
        styles = [{
            'id': f'id-{item.get("id")}',
            'name': item.get('title'),
        } for item in response.get('data', {}).get('suit_style') or []]

        with cls._style_cache_lock:
            cls._style_cache[key] = {
                'colors': colors,
                'styles': styles,
                'expire': now + 60 * 60
            }

        return colors, styles

    def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
        """
        Get styles

        :param user_id: the id of the user requesting the styles
        :return: Tuple[list[dict[id, color]], list[dict[id, style]]]
        """
        if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'):
            raise Exception('Please provide aippt credentials')

        return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)

    def _get_suit(self, style_id: int, colour_id: int) -> int:
        """
        Get suit
        """
        headers = {
            'x-channel': '',
            'x-api-key': self.runtime.credentials['aippt_access_key'],
            'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__')
        }
        response = get(
            str(self._api_base_url / 'template_component' / 'suit' / 'search'),
            headers=headers,
            params={
                'style_id': style_id,
                'colour_id': colour_id,
                'page': 1,
                'page_size': 1
            }
        )

        if response.status_code != 200:
            raise Exception(f'Failed to connect to aippt: {response.text}')

        response = response.json()

        if response.get('code') != 0:
            raise Exception(f'Failed to connect to aippt: {response.get("msg")}')

        if len(response.get('data', {}).get('list') or []) > 0:
            return response.get('data', {}).get('list')[0].get('id')

        raise Exception('Failed to get suit, the suit does not exist, please check the style and color')

    def get_runtime_parameters(self) -> list[ToolParameter]:
        """
        Get runtime parameters

        Override this method to add runtime parameters to the tool.
        """
        try:
            colors, styles = self.get_styles(user_id='__dify_system__')
        except Exception:
            colors, styles = [
                {'id': -1, 'name': '__default__', 'en_name': '__default__'}
            ], [
                {'id': -1, 'name': '__default__', 'en_name': '__default__'}
            ]

        return [
            ToolParameter(
                name='color',
                label=I18nObject(zh_Hans='颜色', en_US='Color'),
                human_description=I18nObject(zh_Hans='颜色', en_US='Color'),
                type=ToolParameter.ToolParameterType.SELECT,
                form=ToolParameter.ToolParameterForm.FORM,
                required=False,
                default=colors[0]['id'],
                options=[
                    ToolParameterOption(
                        value=color['id'],
                        label=I18nObject(zh_Hans=color['name'], en_US=color['en_name'])
                    ) for color in colors
                ]
            ),
            ToolParameter(
                name='style',
                label=I18nObject(zh_Hans='风格', en_US='Style'),
                human_description=I18nObject(zh_Hans='风格', en_US='Style'),
                type=ToolParameter.ToolParameterType.SELECT,
                form=ToolParameter.ToolParameterForm.FORM,
                required=False,
                default=styles[0]['id'],
                options=[
                    ToolParameterOption(
                        value=style['id'],
                        label=I18nObject(zh_Hans=style['name'], en_US=style['name'])
                    ) for style in styles
                ]
            ),
        ]
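The lookup above serves cached colors and styles per access-key/user pair and stamps each entry with an `expire` time an hour out (the freshness check presumably happens just above this excerpt). A minimal sketch of the same thread-safe TTL-cache pattern, with hypothetical names (`_cache`, `_cache_lock`, `fetch`) standing in for the real class attributes:

```python
import threading
import time

_cache: dict[str, dict] = {}
_cache_lock = threading.Lock()


def get_with_ttl(key: str, fetch, ttl: int = 60 * 60):
    entry = _cache.get(key)
    # Serve from the cache only while the entry is still fresh.
    if entry and entry['expire'] > time.time():
        return entry['value']
    value = fetch()
    # Guard the write so concurrent refreshes don't interleave.
    with _cache_lock:
        _cache[key] = {'value': value, 'expire': time.time() + ttl}
    return value
```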
@ -1,54 +0,0 @@
identity:
  name: aippt
  author: Dify
  label:
    en_US: AIPPT
    zh_Hans: AIPPT
description:
  human:
    en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop
    zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底
  llm: A tool used to generate PPT with AI, input your content topic, and let AI generate PPT for you.
parameters:
  - name: title
    type: string
    required: true
    label:
      en_US: Title
      zh_Hans: 标题
    human_description:
      en_US: The title of the PPT.
      zh_Hans: PPT的标题。
    llm_description: The title of the PPT, which will be used to generate the PPT outline.
    form: llm
  - name: outline
    type: string
    required: false
    label:
      en_US: Outline
      zh_Hans: 大纲
    human_description:
      en_US: The outline of the PPT
      zh_Hans: PPT的大纲
    llm_description: The outline of the PPT, which will be used to generate the PPT content. Provide it if you have one.
    form: llm
  - name: llm
    type: select
    required: true
    label:
      en_US: LLM model
      zh_Hans: 生成大纲的LLM
    options:
      - value: aippt
        label:
          en_US: AIPPT default model
          zh_Hans: AIPPT默认模型
      - value: wenxin
        label:
          en_US: Wenxin ErnieBot
          zh_Hans: 文心一言
    default: aippt
    human_description:
      en_US: The LLM model used for generating the PPT outline.
      zh_Hans: 用于生成PPT大纲的LLM模型。
    form: form
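For reference, a tool declared with this schema receives its `form: llm` fields from the model and its `form: form` fields from the UI. A hypothetical invocation payload (the values are illustrative, not from the source):

```python
tool_parameters = {
    'title': 'Quarterly Sales Review',     # required, filled in by the LLM
    'outline': '1. Results\n2. Pipeline',  # optional outline, provided when known
    'llm': 'aippt',                        # select field: 'aippt' or 'wenxin'
}
```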
BIN api/core/tools/provider/builtin/bing/_assets/icon.png (new file, 4.5 KiB)
@ -1,40 +0,0 @@
<svg viewBox="-29.62167543756803 0.1 574.391675437568 799.8100000000002" xmlns="http://www.w3.org/2000/svg" width="1888"
     height="2500">
    <linearGradient id="a" gradientUnits="userSpaceOnUse" x1="286.383" x2="542.057" y1="284.169" y2="569.112">
        <stop offset="0" stop-color="#37bdff"/>
        <stop offset=".25" stop-color="#26c6f4"/>
        <stop offset=".5" stop-color="#15d0e9"/>
        <stop offset=".75" stop-color="#3bd6df"/>
        <stop offset="1" stop-color="#62dcd4"/>
    </linearGradient>
    <linearGradient id="b" gradientUnits="userSpaceOnUse" x1="108.979" x2="100.756" y1="675.98" y2="43.669">
        <stop offset="0" stop-color="#1b48ef"/>
        <stop offset=".5" stop-color="#2080f1"/>
        <stop offset="1" stop-color="#26b8f4"/>
    </linearGradient>
    <linearGradient id="c" gradientUnits="userSpaceOnUse" x1="256.823" x2="875.632" y1="649.719" y2="649.719">
        <stop offset="0" stop-color="#39d2ff"/>
        <stop offset=".5" stop-color="#248ffa"/>
        <stop offset="1" stop-color="#104cf5"/>
    </linearGradient>
    <linearGradient id="d" gradientUnits="userSpaceOnUse" x1="256.823" x2="875.632" y1="649.719" y2="649.719">
        <stop offset="0" stop-color="#fff"/>
        <stop offset="1"/>
    </linearGradient>
    <path d="M249.97 277.48c-.12.96-.12 2.05-.12 3.12 0 4.16.83 8.16 2.33 11.84l1.34 2.76 5.3 13.56 27.53 70.23 24.01 61.33c6.85 12.38 17.82 22.1 31.05 27.28l4.11 1.51c.16.05.43.05.65.11l65.81 22.63v.05l25.16 8.64 1.72.58c.06 0 .16.06.22.06 4.96 1.25 9.82 2.93 14.46 4.98 10.73 4.63 20.46 11.23 28.77 19.28 3.35 3.2 6.43 6.65 9.28 10.33a88.64 88.64 0 0 1 6.64 9.72c8.78 14.58 13.82 31.72 13.82 49.97 0 3.26-.16 6.41-.49 9.61-.11 1.41-.28 2.77-.49 4.12v.11c-.22 1.43-.49 2.91-.76 4.36-.28 1.41-.54 2.81-.86 4.21-.05.16-.11.33-.17.49-.3 1.42-.68 2.82-1.07 4.23-.35 1.33-.79 2.7-1.28 3.99a42.96 42.96 0 0 1-1.51 4.16c-.49 1.4-1.07 2.82-1.72 4.16-1.78 4.11-3.9 8.06-6.28 11.83a97.889 97.889 0 0 1-10.47 13.95c30.88-33.2 51.41-76.07 56.52-123.51.86-7.78 1.3-15.67 1.3-23.61 0-5.07-.22-10.09-.55-15.13-3.89-56.89-29.79-107.77-69.32-144.08-10.9-10.09-22.81-19.07-35.62-26.69l-24.2-12.37-122.63-62.93a30.15 30.15 0 0 0-11.93-2.44c-15.88 0-28.99 12.11-30.55 27.56z"
          fill="#7f7f7f"/>
    <path d="M249.97 277.48c-.12.96-.12 2.05-.12 3.12 0 4.16.83 8.16 2.33 11.84l1.34 2.76 5.3 13.56 27.53 70.23 24.01 61.33c6.85 12.38 17.82 22.1 31.05 27.28l4.11 1.51c.16.05.43.05.65.11l65.81 22.63v.05l25.16 8.64 1.72.58c.06 0 .16.06.22.06 4.96 1.25 9.82 2.93 14.46 4.98 10.73 4.63 20.46 11.23 28.77 19.28 3.35 3.2 6.43 6.65 9.28 10.33a88.64 88.64 0 0 1 6.64 9.72c8.78 14.58 13.82 31.72 13.82 49.97 0 3.26-.16 6.41-.49 9.61-.11 1.41-.28 2.77-.49 4.12v.11c-.22 1.43-.49 2.91-.76 4.36-.28 1.41-.54 2.81-.86 4.21-.05.16-.11.33-.17.49-.3 1.42-.68 2.82-1.07 4.23-.35 1.33-.79 2.7-1.28 3.99a42.96 42.96 0 0 1-1.51 4.16c-.49 1.4-1.07 2.82-1.72 4.16-1.78 4.11-3.9 8.06-6.28 11.83a97.889 97.889 0 0 1-10.47 13.95c30.88-33.2 51.41-76.07 56.52-123.51.86-7.78 1.3-15.67 1.3-23.61 0-5.07-.22-10.09-.55-15.13-3.89-56.89-29.79-107.77-69.32-144.08-10.9-10.09-22.81-19.07-35.62-26.69l-24.2-12.37-122.63-62.93a30.15 30.15 0 0 0-11.93-2.44c-15.88 0-28.99 12.11-30.55 27.56z"
          fill="url(#a)"/>
    <path d="M31.62.1C14.17.41.16 14.69.16 32.15v559.06c.07 3.9.29 7.75.57 11.66.25 2.06.52 4.2.9 6.28 7.97 44.87 47.01 78.92 94.15 78.92 16.53 0 32.03-4.21 45.59-11.53.08-.06.22-.14.29-.14l4.88-2.95 19.78-11.64 25.16-14.93.06-496.73c0-33.01-16.52-62.11-41.81-79.4-.6-.36-1.18-.74-1.71-1.17L50.12 5.56C45.16 2.28 39.18.22 32.77.1z"
          fill="#7f7f7f"/>
    <path d="M31.62.1C14.17.41.16 14.69.16 32.15v559.06c.07 3.9.29 7.75.57 11.66.25 2.06.52 4.2.9 6.28 7.97 44.87 47.01 78.92 94.15 78.92 16.53 0 32.03-4.21 45.59-11.53.08-.06.22-.14.29-.14l4.88-2.95 19.78-11.64 25.16-14.93.06-496.73c0-33.01-16.52-62.11-41.81-79.4-.6-.36-1.18-.74-1.71-1.17L50.12 5.56C45.16 2.28 39.18.22 32.77.1z"
          fill="url(#b)"/>
    <path d="M419.81 510.84L194.72 644.26l-3.24 1.95v.71l-25.16 14.9-19.77 11.67-4.85 2.93-.33.16c-13.53 7.35-29.04 11.51-45.56 11.51-47.13 0-86.22-34.03-94.16-78.92 3.77 32.84 14.96 63.41 31.84 90.04 34.76 54.87 93.54 93.04 161.54 99.67h41.58c36.78-3.84 67.49-18.57 99.77-38.46l49.64-30.36c22.36-14.33 83.05-49.58 100.93-69.36 3.89-4.33 7.4-8.97 10.47-13.94 2.38-3.78 4.5-7.73 6.28-11.84.6-1.4 1.17-2.76 1.72-4.15.52-1.38 1.01-2.77 1.51-4.18.93-2.7 1.67-5.41 2.38-8.2.36-1.59.69-3.16 1.02-4.72 1.08-5.89 1.67-11.94 1.67-18.21 0-18.25-5.04-35.39-13.77-49.95-2-3.4-4.2-6.65-6.64-9.72-2.85-3.7-5.93-7.13-9.28-10.33-8.31-8.05-18.01-14.65-28.77-19.29-4.64-2.05-9.48-3.74-14.46-4.97-.06 0-.16-.06-.22-.06l-1.72-.58z"
          fill="#7f7f7f"/>
    <path d="M419.81 510.84L194.72 644.26l-3.24 1.95v.71l-25.16 14.9-19.77 11.67-4.85 2.93-.33.16c-13.53 7.35-29.04 11.51-45.56 11.51-47.13 0-86.22-34.03-94.16-78.92 3.77 32.84 14.96 63.41 31.84 90.04 34.76 54.87 93.54 93.04 161.54 99.67h41.58c36.78-3.84 67.49-18.57 99.77-38.46l49.64-30.36c22.36-14.33 83.05-49.58 100.93-69.36 3.89-4.33 7.4-8.97 10.47-13.94 2.38-3.78 4.5-7.73 6.28-11.84.6-1.4 1.17-2.76 1.72-4.15.52-1.38 1.01-2.77 1.51-4.18.93-2.7 1.67-5.41 2.38-8.2.36-1.59.69-3.16 1.02-4.72 1.08-5.89 1.67-11.94 1.67-18.21 0-18.25-5.04-35.39-13.77-49.95-2-3.4-4.2-6.65-6.64-9.72-2.85-3.7-5.93-7.13-9.28-10.33-8.31-8.05-18.01-14.65-28.77-19.29-4.64-2.05-9.48-3.74-14.46-4.97-.06 0-.16-.06-.22-.06l-1.72-.58z"
          fill="url(#c)"/>
    <path d="M512 595.46c0 6.27-.59 12.33-1.68 18.22-.32 1.56-.65 3.12-1.02 4.7-.7 2.8-1.44 5.51-2.37 8.22-.49 1.4-.99 2.8-1.51 4.16-.54 1.4-1.12 2.76-1.73 4.16a87.873 87.873 0 0 1-6.26 11.83 96.567 96.567 0 0 1-10.48 13.94c-17.88 19.79-78.57 55.04-100.93 69.37l-49.64 30.36c-36.39 22.42-70.77 38.29-114.13 39.38-2.05.06-4.06.11-6.05.11-2.8 0-5.56-.05-8.33-.16-73.42-2.8-137.45-42.25-174.38-100.54a213.368 213.368 0 0 1-31.84-90.04c7.94 44.89 47.03 78.92 94.16 78.92 16.52 0 32.03-4.17 45.56-11.51l.33-.17 4.85-2.92 19.77-11.67 25.16-14.9v-.71l3.24-1.95 225.09-133.43 17.33-10.27 1.72.58c.05 0 .16.06.22.06 4.98 1.23 9.83 2.92 14.46 4.97 10.76 4.64 20.45 11.24 28.77 19.29a92.13 92.13 0 0 1 9.28 10.33c2.44 3.07 4.64 6.32 6.64 9.72 8.73 14.56 13.77 31.7 13.77 49.95z"
          fill="#7f7f7f" opacity=".15"/>
    <path d="M512 595.46c0 6.27-.59 12.33-1.68 18.22-.32 1.56-.65 3.12-1.02 4.7-.7 2.8-1.44 5.51-2.37 8.22-.49 1.4-.99 2.8-1.51 4.16-.54 1.4-1.12 2.76-1.73 4.16a87.873 87.873 0 0 1-6.26 11.83 96.567 96.567 0 0 1-10.48 13.94c-17.88 19.79-78.57 55.04-100.93 69.37l-49.64 30.36c-36.39 22.42-70.77 38.29-114.13 39.38-2.05.06-4.06.11-6.05.11-2.8 0-5.56-.05-8.33-.16-73.42-2.8-137.45-42.25-174.38-100.54a213.368 213.368 0 0 1-31.84-90.04c7.94 44.89 47.03 78.92 94.16 78.92 16.52 0 32.03-4.17 45.56-11.51l.33-.17 4.85-2.92 19.77-11.67 25.16-14.9v-.71l3.24-1.95 225.09-133.43 17.33-10.27 1.72.58c.05 0 .16.06.22.06 4.98 1.23 9.83 2.92 14.46 4.97 10.76 4.64 20.45 11.24 28.77 19.29a92.13 92.13 0 0 1 9.28 10.33c2.44 3.07 4.64 6.32 6.64 9.72 8.73 14.56 13.77 31.7 13.77 49.95z"
          fill="url(#d)" opacity=".15"/>
</svg>

Before: 6.9 KiB
@ -9,7 +9,7 @@ identity:
    en_US: Bing Search
    zh_Hans: Bing 搜索
    pt_BR: Bing Search
  icon: icon.svg
  icon: icon.png
credentials_for_provider:
  subscription_key:
    type: secret-input
@ -1,7 +0,0 @@
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">

<!-- Uploaded to: SVG Repo, www.svgrepo.com, Transformed by: SVG Repo Mixer Tools -->
<svg fill="#4aa4f8" width="800px" height="800px" viewBox="0 0 1024 1024" xmlns="http://www.w3.org/2000/svg" class="icon" stroke="#4aa4f8">

<g id="SVGRepo_bgCarrier" stroke-width="0"/>

Before: 1.1 KiB
@ -1,8 +0,0 @@
from core.tools.provider.builtin.dingtalk.tools.dingtalk_group_bot import DingTalkGroupBotTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class DingTalkProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        DingTalkGroupBotTool()
        pass
@ -1,13 +0,0 @@
identity:
  author: Bowen Liang
  name: dingtalk
  label:
    en_US: DingTalk
    zh_Hans: 钉钉
    pt_BR: DingTalk
  description:
    en_US: DingTalk group robot
    zh_Hans: 钉钉群机器人
    pt_BR: DingTalk group robot
  icon: icon.svg
credentials_for_provider:
@ -1,83 +0,0 @@
import base64
import hashlib
import hmac
import logging
import time
import urllib.parse
from typing import Any, Union

import httpx

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class DingTalkGroupBotTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke the tool.

        DingTalk custom group robot API docs:
        https://open.dingtalk.com/document/orgapp/custom-robot-access
        """
        content = tool_parameters.get('content')
        if not content:
            return self.create_text_message('Invalid parameter content')

        access_token = tool_parameters.get('access_token')
        if not access_token:
            return self.create_text_message('Invalid parameter access_token. '
                                            'Regarding information about security details, '
                                            'please refer to the DingTalk docs: '
                                            'https://open.dingtalk.com/document/robots/customize-robot-security-settings')

        sign_secret = tool_parameters.get('sign_secret')
        if not sign_secret:
            return self.create_text_message('Invalid parameter sign_secret. '
                                            'Regarding information about security details, '
                                            'please refer to the DingTalk docs: '
                                            'https://open.dingtalk.com/document/robots/customize-robot-security-settings')

        msgtype = 'text'
        api_url = 'https://oapi.dingtalk.com/robot/send'
        headers = {
            'Content-Type': 'application/json',
        }
        params = {
            'access_token': access_token,
        }

        self._apply_security_mechanism(params, sign_secret)

        payload = {
            "msgtype": msgtype,
            "text": {
                "content": content,
            }
        }

        try:
            res = httpx.post(api_url, headers=headers, params=params, json=payload)
            if res.is_success:
                return self.create_text_message("Text message sent successfully")
            else:
                return self.create_text_message(
                    f"Failed to send the text message, status code: {res.status_code}, response: {res.text}")
        except Exception as e:
            return self.create_text_message("Failed to send message to group chat bot. {}".format(e))

    @staticmethod
    def _apply_security_mechanism(params: dict[str, Any], sign_secret: str):
        try:
            timestamp = str(round(time.time() * 1000))
            secret_enc = sign_secret.encode('utf-8')
            string_to_sign = f'{timestamp}\n{sign_secret}'
            string_to_sign_enc = string_to_sign.encode('utf-8')
            hmac_code = hmac.new(secret_enc, string_to_sign_enc, digestmod=hashlib.sha256).digest()
            sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))

            params['timestamp'] = timestamp
            params['sign'] = sign
        except Exception:
            msg = "Failed to apply security mechanism to the request."
            logging.exception(msg)
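The signing helper above implements DingTalk's webhook signature scheme: HMAC-SHA256 over `timestamp\nsecret`, keyed with the secret, then base64- and URL-encoded. A standalone sketch for sanity-checking the computation (the secret is a placeholder, not from the source):

```python
import base64
import hashlib
import hmac
import time
import urllib.parse


def dingtalk_sign(sign_secret: str, timestamp_ms: str) -> str:
    # HMAC-SHA256 keyed with the secret, over "timestamp\nsecret".
    digest = hmac.new(
        sign_secret.encode('utf-8'),
        f'{timestamp_ms}\n{sign_secret}'.encode('utf-8'),
        digestmod=hashlib.sha256,
    ).digest()
    # Base64 first, then URL-encode so the value is safe as a query parameter.
    return urllib.parse.quote_plus(base64.b64encode(digest))


ts = str(round(time.time() * 1000))
# Appended to the webhook call as &timestamp=...&sign=...
print(ts, dingtalk_sign('SEC_example_secret', ts))
```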
@ -1,52 +0,0 @@
identity:
  name: dingtalk_group_bot
  author: Bowen Liang
  label:
    en_US: Send Group Message
    zh_Hans: 发送群消息
    pt_BR: Send Group Message
  icon: icon.svg
description:
  human:
    en_US: Sending a group message on DingTalk via the webhook of group bot
    zh_Hans: 通过钉钉的群机器人webhook发送群消息
    pt_BR: Sending a group message on DingTalk via the webhook of group bot
  llm: A tool for sending messages to a chat group on DingTalk (钉钉).
parameters:
  - name: access_token
    type: secret-input
    required: true
    label:
      en_US: access token
      zh_Hans: access token
      pt_BR: access token
    human_description:
      en_US: access_token in the group robot webhook
      zh_Hans: 群自定义机器人webhook中access_token字段的值
      pt_BR: access_token in the group robot webhook
    form: form
  - name: sign_secret
    type: secret-input
    required: true
    label:
      en_US: secret key for signing
      zh_Hans: 加签秘钥
      pt_BR: secret key for signing
    human_description:
      en_US: secret key for signing
      zh_Hans: 加签秘钥
      pt_BR: secret key for signing
    form: form
  - name: content
    type: string
    required: true
    label:
      en_US: content
      zh_Hans: 消息内容
      pt_BR: content
    human_description:
      en_US: Content to send to the group.
      zh_Hans: 群消息文本
      pt_BR: Content to send to the group.
    llm_description: Content of the message
    form: llm
BIN api/core/tools/provider/builtin/gaode/_assets/icon.png (new file, 1.7 KiB)
@ -1 +0,0 @@
<svg version="1.1" xmlns="http://www.w3.org/2000/svg" height="1024" width="1024" viewBox="0 0 1024 1024"><path d="M699.052008 894.366428l-253.855434-159.289336-115.571175 114.768472 44.746184-157.983686L887.097839 163.880236 312.470884 651.791088l-205.584597-128.995835L887.097839 163.876212 699.056031 894.364417zM348.039293 321.886051h122.859882L348.039293 374.779976V321.886051z m675.960707 0v-75.373642C1024 109.706813 917.443646 0 782.927466 0H698.090373v224.076951l-80.471512 34.642986V0H242.63167C108.113477 0-0.002012 109.706813-0.002012 246.51442V321.886051h195.143419v80.471513H0v376.276746C0 915.439906 108.115489 1024 242.63167 1024h374.985179v-145.906923l80.471512 51.270412V1024h84.837093C917.445658 1024 1024 915.439906 1024 778.63431V402.357564h-172.255308l20.717391-80.471513H1024z" fill="#0093FD"></path></svg>

Before: 828 B
@ -9,7 +9,7 @@ identity:
    en_US: Autonavi Open Platform service toolkit.
    zh_Hans: 高德开放平台服务工具包。
    pt_BR: Kit de ferramentas de serviço Autonavi Open Platform.
  icon: icon.svg
  icon: icon.png
credentials_for_provider:
  api_key:
    type: secret-input
BIN api/core/tools/provider/builtin/github/_assets/icon.png (new file, 1.4 KiB)
@ -1,17 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg width="800px" height="800px" viewBox="0 0 20 20" version="1.1" xmlns="http://www.w3.org/2000/svg"
     xmlns:xlink="http://www.w3.org/1999/xlink">
    <title>github [#142]</title>
    <desc>Created with Sketch.</desc>
    <defs>
    </defs>
    <g id="Page-1" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
        <g id="Dribbble-Light-Preview" transform="translate(-140.000000, -7559.000000)" fill="#000000">
            <g id="icons" transform="translate(56.000000, 160.000000)">
                <path d="M94,7399 C99.523,7399 104,7403.59 104,7409.253 C104,7413.782 101.138,7417.624 97.167,7418.981 C96.66,7419.082 96.48,7418.762 96.48,7418.489 C96.48,7418.151 96.492,7417.047 96.492,7415.675 C96.492,7414.719 96.172,7414.095 95.813,7413.777 C98.04,7413.523 100.38,7412.656 100.38,7408.718 C100.38,7407.598 99.992,7406.684 99.35,7405.966 C99.454,7405.707 99.797,7404.664 99.252,7403.252 C99.252,7403.252 98.414,7402.977 96.505,7404.303 C95.706,7404.076 94.85,7403.962 94,7403.958 C93.15,7403.962 92.295,7404.076 91.497,7404.303 C89.586,7402.977 88.746,7403.252 88.746,7403.252 C88.203,7404.664 88.546,7405.707 88.649,7405.966 C88.01,7406.684 87.619,7407.598 87.619,7408.718 C87.619,7412.646 89.954,7413.526 92.175,7413.785 C91.889,7414.041 91.63,7414.493 91.54,7415.156 C90.97,7415.418 89.522,7415.871 88.63,7414.304 C88.63,7414.304 88.101,7413.319 87.097,7413.247 C87.097,7413.247 86.122,7413.234 87.029,7413.87 C87.029,7413.87 87.684,7414.185 88.139,7415.37 C88.139,7415.37 88.726,7417.2 91.508,7416.58 C91.513,7417.437 91.522,7418.245 91.522,7418.489 C91.522,7418.76 91.338,7419.077 90.839,7418.982 C86.865,7417.627 84,7413.783 84,7409.253 C84,7403.59 88.478,7399 94,7399"
                      id="github-[#142]">
                </path>
            </g>
        </g>
    </g>
</svg>

Before: 1.8 KiB
@ -9,7 +9,7 @@ identity:
    en_US: GitHub is an online software source code hosting service.
    zh_Hans: GitHub是一个在线软件源代码托管服务平台。
    pt_BR: GitHub é uma plataforma online para serviços de hospedagem de código fonte de software.
  icon: icon.svg
  icon: icon.png
credentials_for_provider:
  access_tokens:
    type: secret-input
@ -1,7 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<svg width="800px" height="800px" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
    <g>
        <path fill="none" d="M0 0h24v24H0z"/>
        <path d="M16 17v-1h-3v-3h3v2h2v2h-1v2h-2v2h-2v-3h2v-1h1zm5 4h-4v-2h2v-2h2v4zM3 3h8v8H3V3zm2 2v4h4V5H5zm8-2h8v8h-8V3zm2 2v4h4V5h-4zM3 13h8v8H3v-8zm2 2v4h4v-4H5zm13-2h3v2h-3v-2zM6 6h2v2H6V6zm0 10h2v2H6v-2zM16 6h2v2h-2V6z"/>
    </g>
</svg>

Before: 428 B
@ -1,16 +0,0 @@
from typing import Any

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class QRCodeProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        try:
            QRCodeGeneratorTool().invoke(user_id='',
                                         tool_parameters={
                                             'content': 'Dify 123 😊'
                                         })
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
@ -1,12 +0,0 @@
identity:
  author: Bowen Liang
  name: qrcode
  label:
    en_US: QRCode
    zh_Hans: 二维码工具
    pt_BR: QRCode
  description:
    en_US: A tool for generating QR code (quick-response code) images.
    zh_Hans: 一个二维码工具
    pt_BR: A tool for generating QR code (quick-response code) images.
  icon: icon.svg
@ -1,69 +0,0 @@
import io
import logging
from typing import Any, Union

from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q
from qrcode.image.base import BaseImage
from qrcode.image.pure import PyPNGImage
from qrcode.main import QRCode

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class QRCodeGeneratorTool(BuiltinTool):
    error_correction_levels = {
        'L': ERROR_CORRECT_L,  # <=7%
        'M': ERROR_CORRECT_M,  # <=15%
        'Q': ERROR_CORRECT_Q,  # <=25%
        'H': ERROR_CORRECT_H,  # <=30%
    }

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke the tool.
        """
        # get text content
        content = tool_parameters.get('content', '')
        if not content:
            return self.create_text_message('Invalid parameter content')

        # get border size
        border = tool_parameters.get('border', 0)
        if border < 0 or border > 100:
            return self.create_text_message('Invalid parameter border')

        # get error_correction
        error_correction = tool_parameters.get('error_correction', '')
        if error_correction not in self.error_correction_levels.keys():
            return self.create_text_message('Invalid parameter error_correction')

        try:
            image = self._generate_qrcode(content, border, error_correction)
            image_bytes = self._image_to_byte_array(image)
            return self.create_blob_message(blob=image_bytes,
                                            meta={'mime_type': 'image/png'},
                                            save_as=self.VARIABLE_KEY.IMAGE.value)
        except Exception:
            logging.exception(f'Failed to generate QR code for content: {content}')
            return self.create_text_message('Failed to generate QR code')

    def _generate_qrcode(self, content: str, border: int, error_correction: str) -> BaseImage:
        qr = QRCode(
            image_factory=PyPNGImage,
            error_correction=self.error_correction_levels.get(error_correction),
            border=border,
        )
        qr.add_data(data=content)
        qr.make(fit=True)
        img = qr.make_image()
        return img

    @staticmethod
    def _image_to_byte_array(image: BaseImage) -> bytes:
        byte_stream = io.BytesIO()
        image.save(byte_stream)
        return byte_stream.getvalue()
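The tool builds PNG bytes with the `qrcode` package's pure-Python PNG backend. A minimal standalone sketch of the same pipeline, assuming `qrcode` (with its `pypng` extra) is installed:

```python
import io

from qrcode.constants import ERROR_CORRECT_M
from qrcode.image.pure import PyPNGImage
from qrcode.main import QRCode

# Same setup as the tool: PNG image factory, medium error correction, small border.
qr = QRCode(image_factory=PyPNGImage, error_correction=ERROR_CORRECT_M, border=2)
qr.add_data('https://example.com')
qr.make(fit=True)  # let the library pick the smallest QR version that fits

buf = io.BytesIO()
qr.make_image().save(buf)  # PyPNGImage writes PNG bytes to any stream
with open('qr.png', 'wb') as f:
    f.write(buf.getvalue())
```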
@ -1,76 +0,0 @@
identity:
  name: qrcode_generator
  author: Bowen Liang
  label:
    en_US: Generate QR Code
    zh_Hans: 生成二维码
    pt_BR: Generate QR Code
description:
  human:
    en_US: A tool for generating QR code image
    zh_Hans: 一个用于生成二维码的工具
    pt_BR: A tool for generating QR code image
  llm: A tool for generating QR code image
parameters:
  - name: content
    type: string
    required: true
    label:
      en_US: content text for QR code
      zh_Hans: 二维码文本内容
      pt_BR: content text for QR code
    human_description:
      en_US: content text for QR code
      zh_Hans: 二维码文本内容
      pt_BR: content text for QR code
    form: llm
  - name: error_correction
    type: select
    required: true
    default: M
    label:
      en_US: Error Correction
      zh_Hans: 容错等级
      pt_BR: Error Correction
    human_description:
      en_US: Error correction level (L, M, Q or H, from low to high); the higher the level, the larger the generated QR code and the better the error correction
      zh_Hans: 容错等级,可设置为低、中、偏高或高,从低到高,生成的二维码越大且容错效果越好
      pt_BR: Error correction level (L, M, Q or H, from low to high); the higher the level, the larger the generated QR code and the better the error correction
    options:
      - value: L
        label:
          en_US: Low
          zh_Hans: 低
          pt_BR: Low
      - value: M
        label:
          en_US: Medium
          zh_Hans: 中
          pt_BR: Medium
      - value: Q
        label:
          en_US: Quartile
          zh_Hans: 偏高
          pt_BR: Quartile
      - value: H
        label:
          en_US: High
          zh_Hans: 高
          pt_BR: High
    form: form
  - name: border
    type: number
    required: true
    default: 2
    min: 0
    max: 100
    label:
      en_US: border size
      zh_Hans: 边框粗细
      pt_BR: border size
    human_description:
      en_US: border size (default to 2)
      zh_Hans: 边框粗细的格数(默认为2)
      pt_BR: border size (default to 2)
    llm: border size, default to 2
    form: form
@ -2,11 +2,11 @@ import io
import json
from base64 import b64decode, b64encode
from copy import deepcopy
from os.path import join
from typing import Any, Union

from httpx import get, post
from PIL import Image
from yarl import URL

from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool):

        # set model
        try:
            url = str(URL(base_url) / 'sdapi' / 'v1' / 'options')
            url = join(base_url, 'sdapi/v1/options')
            response = post(url, data=json.dumps({
                'sd_model_checkpoint': model
            }))
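The two adjacent `url =` lines above are the old and new sides of this change: one builds the endpoint with yarl's `/` operator, the other with `os.path.join`. A quick illustration of what each produces (a sketch; note that `os.path.join` joins with the platform's path separator, so the two are only interchangeable on POSIX systems):

```python
from os.path import join

from yarl import URL

base_url = 'http://127.0.0.1:7860'  # placeholder Stable Diffusion WebUI address

# yarl builds the URL segment by segment and handles slashes itself.
print(str(URL(base_url) / 'sdapi' / 'v1' / 'options'))  # http://127.0.0.1:7860/sdapi/v1/options

# os.path.join happens to produce the same string here, but it is a filesystem
# path joiner, not a general-purpose URL joiner.
print(join(base_url, 'sdapi/v1/options'))               # http://127.0.0.1:7860/sdapi/v1/options
```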
@ -153,21 +153,8 @@ class StableDiffusionTool(BuiltinTool):
        if not model:
            raise ToolProviderCredentialValidationError('Please input model')

        api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
        response = get(url=api_url, timeout=10)
        if response.status_code == 404:
            # try draw a picture
            self._invoke(
                user_id='test',
                tool_parameters={
                    'prompt': 'a cat',
                    'width': 1024,
                    'height': 1024,
                    'steps': 1,
                    'lora': '',
                }
            )
        elif response.status_code != 200:
        response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
        if response.status_code != 200:
            raise ToolProviderCredentialValidationError('Failed to get models')
        else:
            models = [d['model_name'] for d in response.json()]
@ -178,23 +165,6 @@ class StableDiffusionTool(BuiltinTool):
        except Exception as e:
            raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')

    def get_sd_models(self) -> list[str]:
        """
        get sd models
        """
        try:
            base_url = self.runtime.credentials.get('base_url', None)
            if not base_url:
                return []
            api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models')
            response = get(url=api_url, timeout=10)
            if response.status_code != 200:
                return []
            else:
                return [d['model_name'] for d in response.json()]
        except Exception as e:
            return []

    def img2img(self, base_url: str, lora: str, image_binary: bytes,
                prompt: str, negative_prompt: str,
                width: int, height: int, steps: int) \
@ -222,7 +192,7 @@ class StableDiffusionTool(BuiltinTool):
        draw_options['prompt'] = prompt

        try:
            url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
            url = join(base_url, 'sdapi/v1/img2img')
            response = post(url, data=json.dumps(draw_options), timeout=120)
            if response.status_code != 200:
                return self.create_text_message('Failed to generate image')
@ -255,7 +225,7 @@ class StableDiffusionTool(BuiltinTool):
        draw_options['negative_prompt'] = negative_prompt

        try:
            url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
            url = join(base_url, 'sdapi/v1/txt2img')
            response = post(url, data=json.dumps(draw_options), timeout=120)
            if response.status_code != 200:
                return self.create_text_message('Failed to generate image')
@ -299,29 +269,5 @@ class StableDiffusionTool(BuiltinTool):
                    label=I18nObject(en_US=i.name, zh_Hans=i.name)
                ) for i in self.list_default_image_variables()])
        )

        if self.runtime.credentials:
            try:
                models = self.get_sd_models()
                if len(models) != 0:
                    parameters.append(
                        ToolParameter(name='model',
                                      label=I18nObject(en_US='Model', zh_Hans='Model'),
                                      human_description=I18nObject(
                                          en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
                                          zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档',
                                      ),
                                      type=ToolParameter.ToolParameterType.SELECT,
                                      form=ToolParameter.ToolParameterForm.FORM,
                                      llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion',
                                      required=True,
                                      default=models[0],
                                      options=[ToolParameterOption(
                                          value=i,
                                          label=I18nObject(en_US=i, zh_Hans=i)
                                      ) for i in models])
                    )
            except:
                pass

        return parameters
@ -1 +0,0 @@
<svg width="2500" height="2500" viewBox="0 0 256 256" xmlns="http://www.w3.org/2000/svg" preserveAspectRatio="xMidYMid"><g fill="#CF272D"><path d="M127.86 222.304c-52.005 0-94.164-42.159-94.164-94.163 0-52.005 42.159-94.163 94.164-94.163 52.004 0 94.162 42.158 94.162 94.163 0 52.004-42.158 94.163-94.162 94.163zm0-222.023C57.245.281 0 57.527 0 128.141 0 198.756 57.245 256 127.86 256c70.614 0 127.859-57.244 127.859-127.859 0-70.614-57.245-127.86-127.86-127.86z"/><path d="M133.116 96.297c0-14.682 11.903-26.585 26.586-26.585 14.683 0 26.585 11.903 26.585 26.585 0 14.684-11.902 26.586-26.585 26.586-14.683 0-26.586-11.902-26.586-26.586M133.116 159.983c0-14.682 11.903-26.586 26.586-26.586 14.683 0 26.585 11.904 26.585 26.586 0 14.683-11.902 26.586-26.585 26.586-14.683 0-26.586-11.903-26.586-26.586M69.431 159.983c0-14.682 11.904-26.586 26.586-26.586 14.683 0 26.586 11.904 26.586 26.586 0 14.683-11.903 26.586-26.586 26.586-14.682 0-26.586-11.903-26.586-26.586M69.431 96.298c0-14.683 11.904-26.585 26.586-26.585 14.683 0 26.586 11.902 26.586 26.585 0 14.684-11.903 26.586-26.586 26.586-14.682 0-26.586-11.902-26.586-26.586"/></g></svg>

Before: 1.1 KiB
@ -1,41 +0,0 @@
from typing import Any, Union

from langchain.utilities import TwilioAPIWrapper

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class SendMessageTool(BuiltinTool):
    """
    A tool for sending messages using the Twilio API.

    Args:
        user_id (str): The ID of the user invoking the tool.
        tool_parameters (Dict[str, Any]): The parameters required for sending the message.

    Returns:
        Union[ToolInvokeMessage, List[ToolInvokeMessage]]: The result of invoking the tool, which includes the status of the message sending operation.
    """

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        account_sid = self.runtime.credentials["account_sid"]
        auth_token = self.runtime.credentials["auth_token"]
        from_number = self.runtime.credentials["from_number"]

        message = tool_parameters["message"]
        to_number = tool_parameters["to_number"]

        if to_number.startswith("whatsapp:"):
            # the sender must carry the same channel prefix as the recipient
            from_number = f"whatsapp:{from_number}"

        twilio = TwilioAPIWrapper(
            account_sid=account_sid, auth_token=auth_token, from_number=from_number
        )

        # Sending the message through Twilio
        twilio.run(message, to_number)

        return self.create_text_message(text="Message sent successfully.")
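For a standalone check of the same call path, a minimal sketch using langchain's `TwilioAPIWrapper` (credentials and phone numbers below are placeholders, not from the source):

```python
from langchain.utilities import TwilioAPIWrapper

# Placeholder credentials; real values come from the Twilio console.
twilio = TwilioAPIWrapper(
    account_sid="ACxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
    auth_token="your_auth_token",
    from_number="+15551234567",
)

# run(body, to) sends the message and returns the Twilio message SID.
sid = twilio.run("Hello from Dify!", "+15557654321")
print(sid)
```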
@ -1,40 +0,0 @@
identity:
  name: send_message
  author: Yash Parmar
  label:
    en_US: SendMessage
    zh_Hans: 发送消息
    pt_BR: SendMessage
description:
  human:
    en_US: Send SMS or Twilio Messaging Channels messages.
    zh_Hans: 发送SMS或Twilio消息通道消息。
    pt_BR: Send SMS or Twilio Messaging Channels messages.
  llm: Send SMS or Twilio Messaging Channels messages. Supports different channels including WhatsApp.
parameters:
  - name: message
    type: string
    required: true
    label:
      en_US: Message
      zh_Hans: 消息内容
      pt_BR: Message
    human_description:
      en_US: The content of the message to be sent.
      zh_Hans: 要发送的消息内容。
      pt_BR: The content of the message to be sent.
    llm_description: The content of the message to be sent.
    form: llm
  - name: to_number
    type: string
    required: true
    label:
      en_US: To Number
      zh_Hans: 收信号码
      pt_BR: Para Número
    human_description:
      en_US: The recipient's phone number. Prefix with 'whatsapp:' for WhatsApp messages, e.g., "whatsapp:+1234567890".
      zh_Hans: 收件人的电话号码。WhatsApp消息前缀为'whatsapp:',例如,"whatsapp:+1234567890"。
      pt_BR: The recipient's phone number. Prefix with 'whatsapp:' for WhatsApp messages, e.g., "whatsapp:+1234567890".
    llm_description: The recipient's phone number. Prefix with 'whatsapp:' for WhatsApp messages, e.g., "whatsapp:+1234567890".
    form: llm