Mirror of https://github.com/langgenius/dify.git (synced 2026-02-13 06:45:21 +08:00)

Compare commits

52 Commits
- eabfd84ceb
- d320d1468d
- b47fa27a35
- 68ad9a91b2
- c17a4165c1
- 96c171805a
- 9a536979ab
- 46a5294d94
- ec181649ae
- 4fdcb30ff8
- 07add06c59
- a7b33b55e8
- 0cbbaf3f68
- c564f32ab6
- 7c2c949f01
- 066168da52
- 1df71ec64d
- a9ee52f2d7
- 7b225a5ab0
- d7a6f25c63
- f46792334c
- ee3936916f
- 109de52fe2
- 10dd0f3fa0
- 2f064c68bc
- 079583eaa4
- 0e82072323
- 678ad6b7eb
- 63e34e5227
- c606295ea6
- 27d72e30ad
- 5660878f7b
- 12e55b2cac
- 97e094dfd8
- 9622fbb62f
- cc8dc6d35e
- 215661ef91
- 5a3e09518c
- ebba124c5c
- a62325ac87
- 1d2ab2126c
- b07dea836c
- f9d00e0498
- 757ceda063
- d27e3ab99d
- ce930f19b9
- 3b14939d66
- 279caf033c
- eff280f3e7
- 7c70eb87bc
- 6ef401a9f0
- b29a36f461
.github/workflows/api-tests.yml (vendored, 3 changes)
@@ -75,7 +75,7 @@ jobs:
      - name: Run Workflow
        run: poetry run -C api bash dev/pytest/pytest_workflow.sh

-     - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
+     - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale)
        uses: hoverkraft-tech/compose-action@v2.0.0
        with:
          compose-file: |
@@ -89,5 +89,6 @@ jobs:
            pgvecto-rs
            pgvector
            chroma
+           myscale
      - name: Test Vector Stores
        run: poetry run -C api bash dev/pytest/pytest_vdb.sh
@@ -83,7 +83,7 @@ OCI_REGION=your-region
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

-# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
+# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
VECTOR_STORE=weaviate

# Weaviate configuration
@@ -106,6 +106,14 @@ MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false

+# MyScale configuration
+MYSCALE_HOST=127.0.0.1
+MYSCALE_PORT=8123
+MYSCALE_USER=default
+MYSCALE_PASSWORD=
+MYSCALE_DATABASE=default
+MYSCALE_FTS_PARAMS=
+
# Relyt configuration
RELYT_HOST=127.0.0.1
RELYT_PORT=5432
@@ -151,6 +159,16 @@ CHROMA_DATABASE=default_database
CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
CHROMA_AUTH_CREDENTIALS=difyai123456

+# AnalyticDB configuration
+ANALYTICDB_KEY_ID=your-ak
+ANALYTICDB_KEY_SECRET=your-sk
+ANALYTICDB_REGION_ID=cn-hangzhou
+ANALYTICDB_INSTANCE_ID=gp-ab123456
+ANALYTICDB_ACCOUNT=testaccount
+ANALYTICDB_PASSWORD=testpassword
+ANALYTICDB_NAMESPACE=dify
+ANALYTICDB_NAMESPACE_PASSWORD=difypassword
+
# OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1
OPENSEARCH_PORT=9200
@@ -237,4 +255,4 @@ WORKFLOW_CALL_MAX_DEPTH=5

# App configuration
APP_MAX_EXECUTION_TIME=1200
+APP_MAX_ACTIVE_REQUESTS=0
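For orientation, a minimal sketch of a client built from the new MYSCALE_* variables above, assuming MyScale is reachable over its ClickHouse-compatible HTTP interface; the clickhouse-connect dependency and the SELECT 1 probe are illustrative, not part of this changeset:

```python
import os

import clickhouse_connect  # assumed client library; MyScale speaks the ClickHouse protocol

# Read the connection settings introduced in the .env block above.
client = clickhouse_connect.get_client(
    host=os.getenv("MYSCALE_HOST", "127.0.0.1"),
    port=int(os.getenv("MYSCALE_PORT", "8123")),
    username=os.getenv("MYSCALE_USER", "default"),
    password=os.getenv("MYSCALE_PASSWORD", ""),
    database=os.getenv("MYSCALE_DATABASE", "default"),
)

print(client.command("SELECT 1"))  # simple connectivity probe
```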
@@ -337,6 +337,14 @@ def migrate_knowledge_vector_database():
                        "vector_store": {"class_prefix": collection_name}
                    }
                    dataset.index_struct = json.dumps(index_struct_dict)
+               elif vector_type == VectorType.ANALYTICDB:
+                   dataset_id = dataset.id
+                   collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                   index_struct_dict = {
+                       "type": VectorType.ANALYTICDB,
+                       "vector_store": {"class_prefix": collection_name}
+                   }
+                   dataset.index_struct = json.dumps(index_struct_dict)
                else:
                    raise ValueError(f"Vector store {vector_type} is not supported.")
@@ -31,6 +31,10 @@ class AppExecutionConfig(BaseSettings):
        description='execution timeout in seconds for app execution',
        default=1200,
    )
+   APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
+       description='max active request per app, 0 means unlimited',
+       default=0,
+   )


class CodeExecutionSandboxConfig(BaseSettings):
@@ -396,6 +400,11 @@ class DataSetConfig(BaseSettings):
        default=30,
    )

+   DATASET_OPERATOR_ENABLED: bool = Field(
+       description='whether to enable dataset operator',
+       default=False,
+   )
+

class WorkspaceConfig(BaseSettings):
    """
@@ -10,8 +10,10 @@ from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorag
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
+from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
+from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
from configs.middleware.vdb.oracle_config import OracleConfig
from configs.middleware.vdb.pgvector_config import PGVectorConfig
@@ -183,8 +185,10 @@ class MiddlewareConfig(

    # configs of vdb and vdb providers
    VectorStoreConfig,
+   AnalyticdbConfig,
    ChromaConfig,
    MilvusConfig,
+   MyScaleConfig,
    OpenSearchConfig,
    OracleConfig,
    PGVectorConfig,
api/configs/middleware/vdb/analyticdb_config.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from typing import Optional

from pydantic import BaseModel, Field


class AnalyticdbConfig(BaseModel):
    """
    Configuration for connecting to AnalyticDB.
    Refer to the following documentation for details on obtaining credentials:
    https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
    """

    ANALYTICDB_KEY_ID: Optional[str] = Field(
        default=None,
        description="The Access Key ID provided by Alibaba Cloud for authentication."
    )
    ANALYTICDB_KEY_SECRET: Optional[str] = Field(
        default=None,
        description="The Secret Access Key corresponding to the Access Key ID for secure access."
    )
    ANALYTICDB_REGION_ID: Optional[str] = Field(
        default=None,
        description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
    )
    ANALYTICDB_INSTANCE_ID: Optional[str] = Field(
        default=None,
        description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')."
    )
    ANALYTICDB_ACCOUNT: Optional[str] = Field(
        default=None,
        description="The account name used to log in to the AnalyticDB instance."
    )
    ANALYTICDB_PASSWORD: Optional[str] = Field(
        default=None,
        description="The password associated with the AnalyticDB account for authentication."
    )
    ANALYTICDB_NAMESPACE: Optional[str] = Field(
        default=None,
        description="The namespace within AnalyticDB for schema isolation."
    )
    ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field(
        default=None,
        description="The password for accessing the specified namespace within the AnalyticDB instance."
    )
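A minimal sketch of how a BaseModel mixin like AnalyticdbConfig picks its values up from the environment once it is composed into the BaseSettings-derived MiddlewareConfig shown earlier; the reduced classes below are illustrative stand-ins, not the real ones:

```python
import os

from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings


class AnalyticdbConfig(BaseModel):
    ANALYTICDB_REGION_ID: str | None = Field(default=None)


# MiddlewareConfig in this changeset inherits many such mixins; this is a cut-down stand-in.
class MiddlewareConfig(AnalyticdbConfig, BaseSettings):
    pass


os.environ["ANALYTICDB_REGION_ID"] = "cn-hangzhou"
print(MiddlewareConfig().ANALYTICDB_REGION_ID)  # -> cn-hangzhou
```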
api/configs/middleware/vdb/myscale_config.py (new file, 39 lines)
@@ -0,0 +1,39 @@
from typing import Optional

from pydantic import BaseModel, Field, PositiveInt


class MyScaleConfig(BaseModel):
    """
    MyScale configs
    """

    MYSCALE_HOST: Optional[str] = Field(
        description='MyScale host',
        default=None,
    )

    MYSCALE_PORT: Optional[PositiveInt] = Field(
        description='MyScale port',
        default=8123,
    )

    MYSCALE_USER: Optional[str] = Field(
        description='MyScale user',
        default=None,
    )

    MYSCALE_PASSWORD: Optional[str] = Field(
        description='MyScale password',
        default=None,
    )

    MYSCALE_DATABASE: Optional[str] = Field(
        description='MyScale database name',
        default=None,
    )

    MYSCALE_FTS_PARAMS: Optional[str] = Field(
        description='MyScale fts index parameters',
        default=None,
    )
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

    CURRENT_VERSION: str = Field(
        description='Dify version',
-       default='0.6.13',
+       default='0.6.14',
    )

    COMMIT_SHA: str = Field(
api/constants/tts_auto_play_timeout.py (new file, 4 lines)
@@ -0,0 +1,4 @@
TTS_AUTO_PLAY_TIMEOUT = 5

# sleep 20 ms (40 ms => 1280-byte audio file, 20 ms => 640-byte audio file)
TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02
@@ -15,6 +15,7 @@ from fields.app_fields import (
    app_pagination_fields,
)
from libs.login import login_required
+from services.app_dsl_service import AppDslService
from services.app_service import AppService

ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
@@ -97,8 +98,42 @@ class AppImportApi(Resource):
        parser.add_argument('icon_background', type=str, location='json')
        args = parser.parse_args()

-       app_service = AppService()
-       app = app_service.import_app(current_user.current_tenant_id, args['data'], args, current_user)
+       app = AppDslService.import_and_create_new_app(
+           tenant_id=current_user.current_tenant_id,
+           data=args['data'],
+           args=args,
+           account=current_user
+       )

        return app, 201


+class AppImportFromUrlApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(app_detail_fields_with_site)
+    @cloud_edition_billing_resource_check('apps')
+    def post(self):
+        """Import app from url"""
+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('url', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('name', type=str, location='json')
+        parser.add_argument('description', type=str, location='json')
+        parser.add_argument('icon', type=str, location='json')
+        parser.add_argument('icon_background', type=str, location='json')
+        args = parser.parse_args()
+
+        app = AppDslService.import_and_create_new_app_from_url(
+            tenant_id=current_user.current_tenant_id,
+            url=args['url'],
+            args=args,
+            account=current_user
+        )
+
+        return app, 201

@@ -134,6 +169,7 @@ class AppApi(Resource):
        parser.add_argument('description', type=str, location='json')
        parser.add_argument('icon', type=str, location='json')
        parser.add_argument('icon_background', type=str, location='json')
+       parser.add_argument('max_active_requests', type=int, location='json')
        args = parser.parse_args()

        app_service = AppService()
@@ -176,9 +212,13 @@ class AppCopyApi(Resource):
        parser.add_argument('icon_background', type=str, location='json')
        args = parser.parse_args()

-       app_service = AppService()
-       data = app_service.export_app(app_model)
-       app = app_service.import_app(current_user.current_tenant_id, data, args, current_user)
+       data = AppDslService.export_dsl(app_model=app_model)
+       app = AppDslService.import_and_create_new_app(
+           tenant_id=current_user.current_tenant_id,
+           data=data,
+           args=args,
+           account=current_user
+       )

        return app, 201

@@ -194,10 +234,8 @@ class AppExportApi(Resource):
        if not current_user.is_editor:
            raise Forbidden()

-       app_service = AppService()
-
        return {
-           "data": app_service.export_app(app_model)
+           "data": AppDslService.export_dsl(app_model=app_model)
        }


@@ -321,6 +359,7 @@ class AppTraceApi(Resource):

api.add_resource(AppListApi, '/apps')
api.add_resource(AppImportApi, '/apps/import')
+api.add_resource(AppImportFromUrlApi, '/apps/import/url')
api.add_resource(AppApi, '/apps/<uuid:app_id>')
api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
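A hedged sketch of calling the new import-from-URL endpoint registered above; the base URL and bearer token are assumptions, only the /apps/import/url path, the request fields, and the 201 status come from this diff:

```python
import requests

resp = requests.post(
    "http://localhost:5001/console/api/apps/import/url",  # assumed console base URL
    json={
        "url": "https://example.com/my-app.yml",  # DSL file to import
        "name": "Imported app",  # optional override, per the parser above
    },
    headers={"Authorization": "Bearer <console-access-token>"},
)
print(resp.status_code)  # 201 on success, per the handler above
```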
@@ -81,15 +81,36 @@ class ChatMessageTextApi(Resource):
    @account_initialization_required
    @get_app_model
    def post(self, app_model):
+       from werkzeug.exceptions import InternalServerError
+
        try:
+           parser = reqparse.RequestParser()
+           parser.add_argument('message_id', type=str, location='json')
+           parser.add_argument('text', type=str, location='json')
+           parser.add_argument('voice', type=str, location='json')
+           parser.add_argument('streaming', type=bool, location='json')
+           args = parser.parse_args()
+
+           message_id = args.get('message_id', None)
+           text = args.get('text', None)
+           if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                   and app_model.workflow
+                   and app_model.workflow.features_dict):
+               text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
+               voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+           else:
+               try:
+                   voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
+                       'voice')
+               except Exception:
+                   voice = None
            response = AudioService.transcript_tts(
                app_model=app_model,
-               text=request.form['text'],
-               voice=request.form['voice'],
-               streaming=False
+               text=text,
+               message_id=message_id,
+               voice=voice
            )

-           return {'data': response.data.decode('latin1')}
+           return response
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
@@ -19,7 +19,12 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    AppInvokeQuotaExceededError,
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@@ -75,7 +80,7 @@ class CompletionMessageApi(Resource):
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
-       except ValueError as e:
+       except (ValueError, AppInvokeQuotaExceededError) as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
@@ -141,7 +146,7 @@ class ChatMessageApi(Resource):
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
-       except ValueError as e:
+       except (ValueError, AppInvokeQuotaExceededError) as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
@@ -13,12 +13,14 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.errors.error import AppInvokeQuotaExceededError
from fields.workflow_fields import workflow_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models.model import App, AppMode
+from services.app_dsl_service import AppDslService
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService
@@ -127,8 +129,7 @@ class DraftWorkflowImportApi(Resource):
        parser.add_argument('data', type=str, required=True, nullable=False, location='json')
        args = parser.parse_args()

-       workflow_service = WorkflowService()
-       workflow = workflow_service.import_draft_workflow(
+       workflow = AppDslService.import_and_overwrite_workflow(
            app_model=app_model,
            data=args['data'],
            account=current_user
@@ -279,7 +280,7 @@ class DraftWorkflowRunApi(Resource):
            )

            return helper.compact_generate_response(response)
-       except ValueError as e:
+       except (ValueError, AppInvokeQuotaExceededError) as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
@@ -25,7 +25,7 @@ from fields.document_fields import document_status_fields
from libs.login import login_required
from models.dataset import Dataset, Document, DocumentSegment
from models.model import ApiToken, UploadFile
-from services.dataset_service import DatasetService, DocumentService
+from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService


def _validate_name(name):
@@ -85,6 +85,12 @@ class DatasetListApi(Resource):
            else:
                item['embedding_available'] = True

+           if item.get('permission') == 'partial_members':
+               part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
+               item.update({'partial_member_list': part_users_list})
+           else:
+               item.update({'partial_member_list': []})
+
        response = {
            'data': data,
            'has_more': len(datasets) == limit,
@@ -108,8 +114,8 @@ class DatasetListApi(Resource):
                            help='Invalid indexing technique.')
        args = parser.parse_args()

-       # The role of the current user in the ta table must be admin, owner, or editor
-       if not current_user.is_editor:
+       # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+       if not current_user.is_dataset_editor:
            raise Forbidden()

        try:
@@ -140,6 +146,10 @@ class DatasetApi(Resource):
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        data = marshal(dataset, dataset_detail_fields)
+       if data.get('permission') == 'partial_members':
+           part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
+           data.update({'partial_member_list': part_users_list})
+
        # check embedding setting
        provider_manager = ProviderManager()
        configurations = provider_manager.get_configurations(
@@ -163,6 +173,11 @@ class DatasetApi(Resource):
            data['embedding_available'] = False
        else:
            data['embedding_available'] = True

+       if data.get('permission') == 'partial_members':
+           part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
+           data.update({'partial_member_list': part_users_list})
+
        return data, 200

    @setup_required
@@ -188,17 +203,21 @@ class DatasetApi(Resource):
                            nullable=True,
                            help='Invalid indexing technique.')
        parser.add_argument('permission', type=str, location='json', choices=(
-           'only_me', 'all_team_members'), help='Invalid permission.')
+           'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.'
+       )
        parser.add_argument('embedding_model', type=str,
                            location='json', help='Invalid embedding model.')
        parser.add_argument('embedding_model_provider', type=str,
                            location='json', help='Invalid embedding model provider.')
        parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
+       parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
        args = parser.parse_args()
+       data = request.get_json()

-       # The role of the current user in the ta table must be admin, owner, or editor
-       if not current_user.is_editor:
-           raise Forbidden()
+       # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+       DatasetPermissionService.check_permission(
+           current_user, dataset, data.get('permission'), data.get('partial_member_list')
+       )

        dataset = DatasetService.update_dataset(
            dataset_id_str, args, current_user)
@@ -206,7 +225,20 @@ class DatasetApi(Resource):
        if dataset is None:
            raise NotFound("Dataset not found.")

-       return marshal(dataset, dataset_detail_fields), 200
+       result_data = marshal(dataset, dataset_detail_fields)
+       tenant_id = current_user.current_tenant_id
+
+       if data.get('partial_member_list') and data.get('permission') == 'partial_members':
+           DatasetPermissionService.update_partial_member_list(
+               tenant_id, dataset_id_str, data.get('partial_member_list')
+           )
+       else:
+           DatasetPermissionService.clear_partial_member_list(dataset_id_str)
+
+       partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
+       result_data.update({'partial_member_list': partial_member_list})
+
+       return result_data, 200

    @setup_required
    @login_required
@@ -215,11 +247,12 @@ class DatasetApi(Resource):
        dataset_id_str = str(dataset_id)

        # The role of the current user in the ta table must be admin, owner, or editor
-       if not current_user.is_editor:
+       if not current_user.is_editor or current_user.is_dataset_operator:
            raise Forbidden()

        try:
            if DatasetService.delete_dataset(dataset_id_str, current_user):
+               DatasetPermissionService.clear_partial_member_list(dataset_id_str)
                return {'result': 'success'}, 204
            else:
                raise NotFound("Dataset not found.")
@@ -515,7 +548,7 @@ class DatasetRetrievalSettingApi(Resource):
                        RetrievalMethod.SEMANTIC_SEARCH
                    ]
                }
-           case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
+           case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
                return {
                    'retrieval_method': [
                        RetrievalMethod.SEMANTIC_SEARCH,
@@ -539,7 +572,7 @@ class DatasetRetrievalSettingMockApi(Resource):
                        RetrievalMethod.SEMANTIC_SEARCH
                    ]
                }
-           case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
+           case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
                return {
                    'retrieval_method': [
                        RetrievalMethod.SEMANTIC_SEARCH,
@@ -569,6 +602,27 @@ class DatasetErrorDocs(Resource):
        }, 200


+class DatasetPermissionUserListApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id):
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+
+        partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
+
+        return {
+            'data': partial_members_list,
+        }, 200


api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
@@ -582,3 +636,4 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
+api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')
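A sketch of the request shape the updated dataset-update handler above now accepts; the HTTP method, base URL, and IDs are assumptions, while the permission value and field names come from the parser in this diff:

```python
import requests

# Hypothetical console request; DatasetApi is registered at /datasets/<uuid:dataset_id>.
resp = requests.patch(
    "http://localhost:5001/console/api/datasets/<dataset-uuid>",
    json={
        "permission": "partial_members",
        "partial_member_list": ["<account-uuid-1>", "<account-uuid-2>"],
    },
    headers={"Authorization": "Bearer <console-access-token>"},
)
print(resp.json().get("partial_member_list"))  # echoed back by the handler
```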
@@ -228,7 +228,7 @@ class DatasetDocumentListApi(Resource):
            raise NotFound('Dataset not found.')

        # The role of the current user in the ta table must be admin, owner, or editor
-       if not current_user.is_editor:
+       if not current_user.is_dataset_editor:
            raise Forbidden()

        try:
@@ -294,6 +294,11 @@ class DatasetInitApi(Resource):
        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
                            location='json')
        args = parser.parse_args()

+       # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+       if not current_user.is_dataset_editor:
+           raise Forbidden()
+
        if args['indexing_technique'] == 'high_quality':
            try:
                model_manager = ModelManager()
@@ -757,14 +762,18 @@ class DocumentStatusApi(DocumentResource):
        dataset = DatasetService.get_dataset(dataset_id)
        if dataset is None:
            raise NotFound("Dataset not found.")

+       # The role of the current user in the ta table must be admin, owner, or editor
+       if not current_user.is_dataset_editor:
+           raise Forbidden()
+
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)

-       document = self.get_document(dataset_id, document_id)
+       # check user's permission
+       DatasetService.check_dataset_permission(dataset, current_user)

-       # The role of the current user in the ta table must be admin, owner, or editor
-       if not current_user.is_editor:
-           raise Forbidden()
+       document = self.get_document(dataset_id, document_id)

        indexing_cache_key = 'document_{}_indexing'.format(document.id)
        cache_result = redis_client.get(indexing_cache_key)
@@ -955,10 +964,11 @@ class DocumentRenameApi(DocumentResource):
    @account_initialization_required
    @marshal_with(document_fields)
    def post(self, dataset_id, document_id):
-       # The role of the current user in the ta table must be admin or owner
-       if not current_user.is_admin_or_owner:
+       # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+       if not current_user.is_dataset_editor:
            raise Forbidden()

        dataset = DatasetService.get_dataset(dataset_id)
+       DatasetService.check_dataset_operator_permission(current_user, dataset)
        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=True, nullable=False, location='json')
        args = parser.parse_args()
@@ -19,6 +19,7 @@ from controllers.console.app.error import (
from controllers.console.explore.wraps import InstalledAppResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
+from models.model import AppMode
from services.audio_service import AudioService
from services.errors.audio import (
    AudioTooLargeServiceError,
@@ -70,16 +71,33 @@ class ChatAudioApi(InstalledAppResource):

class ChatTextApi(InstalledAppResource):
    def post(self, installed_app):
-       app_model = installed_app.app
+       from flask_restful import reqparse
+
+       app_model = installed_app.app
        try:
+           parser = reqparse.RequestParser()
+           parser.add_argument('message_id', type=str, required=False, location='json')
+           parser.add_argument('voice', type=str, location='json')
+           parser.add_argument('streaming', type=bool, location='json')
+           args = parser.parse_args()
+
+           message_id = args.get('message_id')
+           if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                   and app_model.workflow
+                   and app_model.workflow.features_dict):
+               text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
+               voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+           else:
+               try:
+                   voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
+               except Exception:
+                   voice = None
            response = AudioService.transcript_tts(
                app_model=app_model,
                text=request.form['text'],
-               voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
-               streaming=False
+               message_id=message_id,
+               voice=voice
            )
-           return {'data': response.data.decode('latin1')}
+           return response
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
@@ -108,3 +126,5 @@ class ChatTextApi(InstalledAppResource):

api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
+# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
+#                  endpoint='installed_app_text_with_message_id')
@@ -36,7 +36,7 @@ class TagListApi(Resource):
    @account_initialization_required
    def post(self):
        # The role of the current user in the ta table must be admin, owner, or editor
-       if not current_user.is_editor:
+       if not (current_user.is_editor or current_user.is_dataset_editor):
            raise Forbidden()

        parser = reqparse.RequestParser()
@@ -68,7 +68,7 @@ class TagUpdateDeleteApi(Resource):
    def patch(self, tag_id):
        tag_id = str(tag_id)
        # The role of the current user in the ta table must be admin, owner, or editor
-       if not current_user.is_editor:
+       if not (current_user.is_editor or current_user.is_dataset_editor):
            raise Forbidden()

        parser = reqparse.RequestParser()
@@ -109,8 +109,8 @@ class TagBindingCreateApi(Resource):
    @login_required
    @account_initialization_required
    def post(self):
-       # The role of the current user in the ta table must be admin, owner, or editor
-       if not current_user.is_editor:
+       # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+       if not (current_user.is_editor or current_user.is_dataset_editor):
            raise Forbidden()

        parser = reqparse.RequestParser()
@@ -134,8 +134,8 @@ class TagBindingDeleteApi(Resource):
    @login_required
    @account_initialization_required
    def post(self):
-       # The role of the current user in the ta table must be admin, owner, or editor
-       if not current_user.is_editor:
+       # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
+       if not (current_user.is_editor or current_user.is_dataset_editor):
            raise Forbidden()

        parser = reqparse.RequestParser()
@@ -131,7 +131,20 @@ class MemberUpdateRoleApi(Resource):
        return {'result': 'success'}


+class DatasetOperatorMemberListApi(Resource):
+    """List all members of current tenant."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(account_with_role_list_fields)
+    def get(self):
+        members = TenantService.get_dataset_operator_members(current_user.current_tenant)
+        return {'result': 'success', 'accounts': members}, 200
+
+
api.add_resource(MemberListApi, '/workspaces/current/members')
api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')
+api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators')
@@ -3,8 +3,9 @@ from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new

-from flask import abort, current_app, request
+from flask import abort, request

+from configs import dify_config
from extensions.ext_database import db
from models.model import EndUser

@@ -12,12 +13,12 @@ from models.model import EndUser
def inner_api_only(view):
    @wraps(view)
    def decorated(*args, **kwargs):
-       if not current_app.config['INNER_API']:
+       if not dify_config.INNER_API:
            abort(404)

        # get header 'X-Inner-Api-Key'
        inner_api_key = request.headers.get('X-Inner-Api-Key')
-       if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']:
+       if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
            abort(404)

        return view(*args, **kwargs)
@@ -28,7 +29,7 @@ def inner_api_only(view):
def inner_api_user_auth(view):
    @wraps(view)
    def decorated(*args, **kwargs):
-       if not current_app.config['INNER_API']:
+       if not dify_config.INNER_API:
            return view(*args, **kwargs)

        # get header 'X-Inner-Api-Key'
@@ -1,7 +1,7 @@
-from flask import current_app
from flask_restful import Resource, fields, marshal_with

+from configs import dify_config
from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
@@ -78,7 +78,7 @@ class AppParameterApi(Resource):
                "transfer_methods": ["remote_url", "local_file"]
            }}),
        'system_parameters': {
-           'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+           'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
        }
    }
@@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
-from models.model import App, EndUser
+from models.model import App, AppMode, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
    AudioTooLargeServiceError,
@@ -72,19 +72,30 @@ class AudioApi(Resource):
class TextApi(Resource):
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
    def post(self, app_model: App, end_user: EndUser):
-       parser = reqparse.RequestParser()
-       parser.add_argument('text', type=str, required=True, nullable=False, location='json')
-       parser.add_argument('voice', type=str, location='json')
-       parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
-       args = parser.parse_args()
-
        try:
+           parser = reqparse.RequestParser()
+           parser.add_argument('message_id', type=str, required=False, location='json')
+           parser.add_argument('voice', type=str, location='json')
+           parser.add_argument('streaming', type=bool, location='json')
+           args = parser.parse_args()
+
+           message_id = args.get('message_id')
+           if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                   and app_model.workflow
+                   and app_model.workflow.features_dict):
+               text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
+               voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+           else:
+               try:
+                   voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
+                       'voice')
+               except Exception:
+                   voice = None
            response = AudioService.transcript_tts(
                app_model=app_model,
-               text=args['text'],
-               end_user=end_user,
-               voice=args.get('voice'),
-               streaming=args['streaming']
+               message_id=message_id,
+               end_user=end_user.external_user_id,
+               voice=voice
            )

            return response
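A sketch of what a caller of the reworked service-API TextApi might send; the /v1 base path and API key are assumptions, while the message_id/voice/streaming fields come from the parser above:

```python
import requests

resp = requests.post(
    "http://localhost:5001/v1/text-to-audio",  # assumed service-API base path
    json={"message_id": "<message-uuid>", "streaming": False},
    headers={"Authorization": "Bearer <app-api-key>"},
)
print(resp.status_code)
```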
@@ -17,7 +17,12 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    AppInvokeQuotaExceededError,
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@@ -69,7 +74,7 @@ class CompletionApi(Resource):
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
-       except ValueError as e:
+       except (ValueError, AppInvokeQuotaExceededError) as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
@@ -132,7 +137,7 @@ class ChatApi(Resource):
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
-       except ValueError as e:
+       except (ValueError, AppInvokeQuotaExceededError) as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
@@ -14,7 +14,12 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+    AppInvokeQuotaExceededError,
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import App, AppMode, EndUser
@@ -59,7 +64,7 @@ class WorkflowRunApi(Resource):
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
-       except ValueError as e:
+       except (ValueError, AppInvokeQuotaExceededError) as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
@@ -1,6 +1,6 @@
-from flask import current_app
from flask_restful import Resource

+from configs import dify_config
from controllers.service_api import api


@@ -9,7 +9,7 @@ class IndexApi(Resource):
        return {
            "welcome": "Dify OpenAPI",
            "api_version": "v1",
-           "server_version": current_app.config['CURRENT_VERSION']
+           "server_version": dify_config.CURRENT_VERSION,
        }
@@ -1,6 +1,6 @@
-from flask import current_app
from flask_restful import fields, marshal_with

+from configs import dify_config
from controllers.web import api
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
@@ -75,7 +75,7 @@ class AppParameterApi(WebApiResource):
                "transfer_methods": ["remote_url", "local_file"]
            }}),
        'system_parameters': {
-           'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+           'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
        }
    }
@@ -19,7 +19,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
-from models.model import App
+from models.model import App, AppMode
from services.audio_service import AudioService
from services.errors.audio import (
    AudioTooLargeServiceError,
@@ -69,16 +69,35 @@ class AudioApi(WebApiResource):

class TextApi(WebApiResource):
    def post(self, app_model: App, end_user):
+       from flask_restful import reqparse
        try:
+           parser = reqparse.RequestParser()
+           parser.add_argument('message_id', type=str, required=False, location='json')
+           parser.add_argument('voice', type=str, location='json')
+           parser.add_argument('streaming', type=bool, location='json')
+           args = parser.parse_args()
+
+           message_id = args.get('message_id')
+           if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                   and app_model.workflow
+                   and app_model.workflow.features_dict):
+               text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
+               voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+           else:
+               try:
+                   voice = args.get('voice') if args.get(
+                       'voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
+               except Exception:
+                   voice = None
+
            response = AudioService.transcript_tts(
                app_model=app_model,
                text=request.form['text'],
-               voice=request.form['voice'] if request.form.get('voice') else None,
-               streaming=False
+               message_id=message_id,
+               end_user=end_user.external_user_id,
+               voice=voice
            )

-           return {'data': response.data.decode('latin1')}
+           return response
        except services.errors.app_model_config.AppModelConfigBrokenError:
            logging.exception("App model config broken.")
            raise AppUnavailableError()
@@ -1,8 +1,8 @@
-from flask import current_app
from flask_restful import fields, marshal_with
from werkzeug.exceptions import Forbidden

+from configs import dify_config
from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
@@ -84,7 +84,7 @@ class AppSiteInfo:
        self.can_replace_logo = can_replace_logo

        if can_replace_logo:
-           base_url = current_app.config.get('FILES_URL')
+           base_url = dify_config.FILES_URL
            remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
            replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
            self.custom_config = {
api/core/app/apps/advanced_chat/app_generator_tts_publisher.py (new file, 135 lines)
@@ -0,0 +1,135 @@
import base64
import concurrent.futures
import logging
import queue
import re
import threading

from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueTextChunkEvent
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType


class AudioTrunk:
    def __init__(self, status: str, audio):
        self.audio = audio
        self.status = status


def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
    if not text_content or text_content.isspace():
        return
    return model_instance.invoke_tts(
        content_text=text_content.strip(),
        user="responding_tts",
        tenant_id=tenant_id,
        voice=voice
    )


def _process_future(future_queue, audio_queue):
    while True:
        try:
            future = future_queue.get()
            if future is None:
                break
            for audio in future.result():
                audio_base64 = base64.b64encode(bytes(audio))
                audio_queue.put(AudioTrunk("responding", audio=audio_base64))
        except Exception as e:
            logging.getLogger(__name__).warning(e)
            break
    audio_queue.put(AudioTrunk("finish", b''))


class AppGeneratorTTSPublisher:

    def __init__(self, tenant_id: str, voice: str):
        self.logger = logging.getLogger(__name__)
        self.tenant_id = tenant_id
        self.msg_text = ''
        self._audio_queue = queue.Queue()
        self._msg_queue = queue.Queue()
        self.match = re.compile(r'[。.!?]')
        self.model_manager = ModelManager()
        self.model_instance = self.model_manager.get_default_model_instance(
            tenant_id=self.tenant_id,
            model_type=ModelType.TTS
        )
        self.voices = self.model_instance.get_tts_voices()
        values = [voice.get('value') for voice in self.voices]
        self.voice = voice
        if not voice or voice not in values:
            self.voice = self.voices[0].get('value')
        self.MAX_SENTENCE = 2
        self._last_audio_event = None
        # keep the thread object before starting it (Thread.start() returns None)
        self._runtime_thread = threading.Thread(target=self._runtime)
        self._runtime_thread.start()
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)

    def publish(self, message):
        try:
            self._msg_queue.put(message)
        except Exception as e:
            self.logger.warning(e)

    def _runtime(self):
        future_queue = queue.Queue()
        threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
        while True:
            try:
                message = self._msg_queue.get()
                if message is None:
                    if self.msg_text and len(self.msg_text.strip()) > 0:
                        futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
                                                              self.model_instance, self.tenant_id, self.voice)
                        future_queue.put(futures_result)
                    break
                elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
                    self.msg_text += message.event.chunk.delta.message.content
                elif isinstance(message.event, QueueTextChunkEvent):
                    self.msg_text += message.event.text
                self.last_message = message
                sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
                if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
                    self.MAX_SENTENCE += 1
                    text_content = ''.join(sentence_arr)
                    futures_result = self.executor.submit(_invoiceTTS, text_content,
                                                          self.model_instance,
                                                          self.tenant_id,
                                                          self.voice)
                    future_queue.put(futures_result)
                    if text_tmp:
                        self.msg_text = text_tmp
                    else:
                        self.msg_text = ''

            except Exception as e:
                self.logger.warning(e)
                break
        future_queue.put(None)

    def checkAndGetAudio(self) -> AudioTrunk | None:
        try:
            if self._last_audio_event and self._last_audio_event.status == "finish":
                if self.executor:
                    self.executor.shutdown(wait=False)
                return self.last_message
            audio = self._audio_queue.get_nowait()
            if audio and audio.status == "finish":
                self.executor.shutdown(wait=False)
                self._runtime_thread = None
            if audio:
                self._last_audio_event = audio
            return audio
        except queue.Empty:
            return None

    def _extract_sentence(self, org_text):
        tx = self.match.finditer(org_text)
        start = 0
        result = []
        for i in tx:
            end = i.regs[0][1]
            result.append(org_text[start:end])
            start = end
        return result, org_text[start:]
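To make the buffering behavior concrete, a standalone sketch of the same sentence-splitting logic the publisher uses before flushing whole sentences to TTS; only the standard re module is involved, and the sample text is illustrative:

```python
import re

match = re.compile(r'[。.!?]')  # same sentence terminators as the publisher above


def extract_sentence(org_text: str) -> tuple[list[str], str]:
    """Split finished sentences off the front; return (sentences, leftover fragment)."""
    start = 0
    result = []
    for m in match.finditer(org_text):
        end = m.end()
        result.append(org_text[start:end])
        start = end
    return result, org_text[start:]


sentences, leftover = extract_sentence("Hello there. How are you? Still typ")
print(sentences)  # ['Hello there.', ' How are you?']
print(leftover)   # ' Still typ'
```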
@@ -255,6 +255,12 @@ class AdvancedChatAppRunner(AppRunner):
                )
                index += 1
                time.sleep(0.01)
+       else:
+           queue_manager.publish(
+               QueueTextChunkEvent(
+                   text=text
+               ), PublishFrom.APPLICATION_MANAGER
+           )

        queue_manager.publish(
            QueueStopEvent(stopped_by=stopped_by),
@@ -4,6 +4,8 @@ import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast

+from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
+from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
    AdvancedChatAppGenerateEntity,
@@ -33,6 +35,8 @@ from core.app.entities.task_entities import (
    ChatbotAppStreamResponse,
    ChatflowStreamGenerateRoute,
    ErrorStreamResponse,
+   MessageAudioEndStreamResponse,
+   MessageAudioStreamResponse,
    MessageEndStreamResponse,
    StreamResponse,
)
@@ -71,13 +75,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
    _iteration_nested_relations: dict[str, list[str]]

    def __init__(
-           self, application_generate_entity: AdvancedChatAppGenerateEntity,
-           workflow: Workflow,
-           queue_manager: AppQueueManager,
-           conversation: Conversation,
-           message: Message,
-           user: Union[Account, EndUser],
-           stream: bool
+       self, application_generate_entity: AdvancedChatAppGenerateEntity,
+       workflow: Workflow,
+       queue_manager: AppQueueManager,
+       conversation: Conversation,
+       message: Message,
+       user: Union[Account, EndUser],
+       stream: bool
    ) -> None:
        """
        Initialize AdvancedChatAppGenerateTaskPipeline.
@@ -129,7 +133,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
            self._application_generate_entity.query
        )

-       generator = self._process_stream_response(
+       generator = self._wrapper_process_stream_response(
            trace_manager=self._application_generate_entity.trace_manager
        )
        if self._stream:
@@ -138,7 +142,7 @@
        return self._to_blocking_response(generator)

    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-           -> ChatbotAppBlockingResponse:
+       -> ChatbotAppBlockingResponse:
        """
        Process blocking response.
        :return:
@@ -169,7 +173,7 @@
        raise Exception('Queue listening stopped unexpectedly.')

    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-           -> Generator[ChatbotAppStreamResponse, None, None]:
+       -> Generator[ChatbotAppStreamResponse, None, None]:
        """
        To stream response.
        :return:
@@ -182,14 +186,68 @@
                stream_response=stream_response
            )

+   def _listenAudioMsg(self, publisher, task_id: str):
+       if not publisher:
+           return None
+       audio_msg: AudioTrunk = publisher.checkAndGetAudio()
+       if audio_msg and audio_msg.status != "finish":
+           return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
+       return None
+
+   def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
+           Generator[StreamResponse, None, None]:
+
+       publisher = None
+       task_id = self._application_generate_entity.task_id
+       tenant_id = self._application_generate_entity.app_config.tenant_id
+       features_dict = self._workflow.features_dict
+
+       if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
+               'text_to_speech'].get('autoPlay') == 'enabled':
+           publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
+       for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
+           while True:
+               audio_response = self._listenAudioMsg(publisher, task_id=task_id)
+               if audio_response:
+                   yield audio_response
+               else:
+                   break
+           yield response
+
+       start_listener_time = time.time()
+       # timeout
+       while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
+           try:
+               if not publisher:
+                   break
+               audio_trunk = publisher.checkAndGetAudio()
+               if audio_trunk is None:
+                   # release cpu
+                   # sleep 20 ms (40 ms => 1280-byte audio file, 20 ms => 640-byte audio file)
+                   time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
+                   continue
+               if audio_trunk.status == "finish":
+                   break
+               else:
+                   start_listener_time = time.time()
+                   yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
+           except Exception as e:
+               logger.error(e)
+               break
+       yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
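The drain loop above follows a common poll-with-sliding-timeout pattern; a reduced, self-contained sketch of it, with a plain queue standing in for the publisher (names and values below are illustrative):

```python
import queue
import time

TIMEOUT = 5        # seconds without audio before giving up, like TTS_AUTO_PLAY_TIMEOUT
YIELD_CPU = 0.02   # 20 ms sleep between polls, like TTS_AUTO_PLAY_YIELD_CPU_TIME


def drain(audio_queue: "queue.Queue[bytes | None]"):
    """Yield remaining audio chunks; None marks the end of the stream."""
    deadline_base = time.time()
    while time.time() - deadline_base < TIMEOUT:
        try:
            chunk = audio_queue.get_nowait()
        except queue.Empty:
            time.sleep(YIELD_CPU)  # release the CPU, as in the pipeline above
            continue
        if chunk is None:  # "finish" marker
            break
        deadline_base = time.time()  # reset the timeout whenever progress is made
        yield chunk


q = queue.Queue()
for part in (b'aa', b'bb', None):
    q.put(part)
print(list(drain(q)))  # [b'aa', b'bb']
```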
|
||||
def _process_stream_response(
|
||||
self, trace_manager: Optional[TraceQueueManager] = None
|
||||
self,
|
||||
publisher: AppGeneratorTTSPublisher,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
if publisher:
|
||||
publisher.publish(message=message)
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
@@ -301,7 +359,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                    continue

                if not self._is_stream_out_support(
                    event=event
                ):
                    continue

@@ -318,7 +376,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                yield self._ping_stream_response()
            else:
                continue

+        if publisher:
+            publisher.publish(None)
        if self._conversation_name_generate_thread:
            self._conversation_name_generate_thread.join()

@@ -402,7 +461,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
        return stream_generate_routes

    def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
            -> list[str]:
        """
        Get answer start at node id.
        :param graph: graph
@@ -457,7 +516,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                start_node_id = target_node_id
                start_node_ids.append(start_node_id)
            elif node_type == NodeType.START.value or \
                    node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
                start_node_id = source_node_id
                start_node_ids.append(start_node_id)
            else:
@@ -515,7 +574,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

            # all route chunks are generated
            if self._task_state.current_stream_generate_state.current_route_position == len(
                self._task_state.current_stream_generate_state.generate_route
            ):
                self._task_state.current_stream_generate_state = None

@@ -525,7 +584,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
        :return:
        """
        if not self._task_state.current_stream_generate_state:
-            return None
+            return

        route_chunks = self._task_state.current_stream_generate_state.generate_route[
            self._task_state.current_stream_generate_state.current_route_position:]
@@ -573,7 +632,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
            # get route chunk node execution info
            route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
            if (route_chunk_node_execution_info.node_type == NodeType.LLM
                    and latest_node_execution_info.node_type == NodeType.LLM):
                # only LLM support chunk stream output
                self._task_state.current_stream_generate_state.current_route_position += 1
                continue
@@ -643,7 +702,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

            # all route chunks are generated
            if self._task_state.current_stream_generate_state.current_route_position == len(
                self._task_state.current_stream_generate_state.generate_route
            ):
                self._task_state.current_stream_generate_state = None

@@ -51,7 +51,6 @@ class AppQueueManager:
        listen_timeout = current_app.config.get("APP_MAX_EXECUTION_TIME")
        start_time = time.time()
        last_ping_time = 0

        while True:
            try:
                message = self._q.get(timeout=1)

@@ -1,7 +1,10 @@
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union

+from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
+from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import (
    InvokeFrom,
@@ -25,6 +28,8 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
    ErrorStreamResponse,
+    MessageAudioEndStreamResponse,
+    MessageAudioStreamResponse,
    StreamResponse,
    TextChunkStreamResponse,
    TextReplaceStreamResponse,
@@ -105,7 +110,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            db.session.refresh(self._user)
            db.session.close()

-        generator = self._process_stream_response(
+        generator = self._wrapper_process_stream_response(
            trace_manager=self._application_generate_entity.trace_manager
        )
        if self._stream:
@@ -161,8 +166,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                stream_response=stream_response
            )

    def _listenAudioMsg(self, publisher, task_id: str):
        if not publisher:
            return None
        audio_msg: AudioTrunk = publisher.checkAndGetAudio()
        if audio_msg and audio_msg.status != "finish":
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None

    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
            Generator[StreamResponse, None, None]:

        publisher = None
        task_id = self._application_generate_entity.task_id
        tenant_id = self._application_generate_entity.app_config.tenant_id
        features_dict = self._workflow.features_dict

        if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
                'text_to_speech'].get('autoPlay') == 'enabled':
            publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
            while True:
                audio_response = self._listenAudioMsg(publisher, task_id=task_id)
                if audio_response:
                    yield audio_response
                else:
                    break
            yield response

        start_listener_time = time.time()
        while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
            try:
                if not publisher:
                    break
                audio_trunk = publisher.checkAndGetAudio()
                if audio_trunk is None:
                    # release cpu
                    # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
                    time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
                    continue
                if audio_trunk.status == "finish":
                    break
                else:
                    yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
            except Exception as e:
                logger.error(e)
                break
        yield MessageAudioEndStreamResponse(audio='', task_id=task_id)


    def _process_stream_response(
        self,
        publisher: AppGeneratorTTSPublisher,
        trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        """
@@ -170,6 +225,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        :return:
        """
        for message in self._queue_manager.listen():
+            if publisher:
+                publisher.publish(message=message)
            event = message.event

            if isinstance(event, QueueErrorEvent):
@@ -251,6 +308,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            else:
                continue

+        if publisher:
+            publisher.publish(None)


    def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
        """
        Save workflow app log.

@@ -69,6 +69,7 @@ class WorkflowTaskState(TaskState):

    iteration_nested_node_ids: list[str] = None


class AdvancedChatTaskState(WorkflowTaskState):
    """
    AdvancedChatTaskState entity
@@ -86,6 +87,8 @@ class StreamEvent(Enum):
    ERROR = "error"
    MESSAGE = "message"
    MESSAGE_END = "message_end"
+    TTS_MESSAGE = "tts_message"
+    TTS_MESSAGE_END = "tts_message_end"
    MESSAGE_FILE = "message_file"
    MESSAGE_REPLACE = "message_replace"
    AGENT_THOUGHT = "agent_thought"
@@ -130,6 +133,22 @@ class MessageStreamResponse(StreamResponse):
    answer: str


+class MessageAudioStreamResponse(StreamResponse):
+    """
+    MessageStreamResponse entity
+    """
+    event: StreamEvent = StreamEvent.TTS_MESSAGE
+    audio: str
+
+
+class MessageAudioEndStreamResponse(StreamResponse):
+    """
+    MessageStreamResponse entity
+    """
+    event: StreamEvent = StreamEvent.TTS_MESSAGE_END
+    audio: str
+

class MessageEndStreamResponse(StreamResponse):
    """
    MessageEndStreamResponse entity
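On the wire, the two new events surface as extra SSE payloads alongside the usual `message` chunks. A hypothetical example of what a client sees; the field layout follows the entities above, but the ids and base64 audio value are invented:

# Illustrative payloads only; real audio values are base64-encoded chunks
# produced by AppGeneratorTTSPublisher.
tts_message = {
    "event": "tts_message",
    "task_id": "7a90...",       # invented task id
    "audio": "SUQzBAAAAAAA",    # invented base64 mp3 fragment
}
tts_message_end = {
    "event": "tts_message_end",
    "task_id": "7a90...",
    "audio": "",                # empty: signals no more audio will follow
}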
@@ -186,6 +205,7 @@ class WorkflowStartStreamResponse(StreamResponse):
    """
    WorkflowStartStreamResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
@@ -205,6 +225,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
    """
    WorkflowFinishStreamResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
@@ -232,6 +253,7 @@ class NodeStartStreamResponse(StreamResponse):
    """
    NodeStartStreamResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
@@ -273,6 +295,7 @@ class NodeFinishStreamResponse(StreamResponse):
    """
    NodeFinishStreamResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
@@ -323,10 +346,12 @@ class NodeFinishStreamResponse(StreamResponse):
            }
        }


class IterationNodeStartStreamResponse(StreamResponse):
    """
    NodeStartStreamResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
@@ -344,10 +369,12 @@ class IterationNodeStartStreamResponse(StreamResponse):
    workflow_run_id: str
    data: Data


class IterationNodeNextStreamResponse(StreamResponse):
    """
    NodeStartStreamResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
@@ -365,10 +392,12 @@ class IterationNodeNextStreamResponse(StreamResponse):
    workflow_run_id: str
    data: Data


class IterationNodeCompletedStreamResponse(StreamResponse):
    """
    NodeCompletedStreamResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
@@ -393,10 +422,12 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
    workflow_run_id: str
    data: Data


class TextChunkStreamResponse(StreamResponse):
    """
    TextChunkStreamResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
@@ -411,6 +442,7 @@ class TextReplaceStreamResponse(StreamResponse):
    """
    TextReplaceStreamResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
@@ -473,6 +505,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
    """
    ChatbotAppBlockingResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
@@ -492,6 +525,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
    """
    CompletionAppBlockingResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
@@ -510,6 +544,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
    """
    WorkflowAppBlockingResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
@@ -528,10 +563,12 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
    workflow_run_id: str
    data: Data


class WorkflowIterationState(BaseModel):
    """
    WorkflowIterationState entity
    """

    class Data(BaseModel):
        """
        Data entity
api/core/app/features/rate_limiting/__init__.py  (new file, 1 line)
@@ -0,0 +1 @@
from .rate_limit import RateLimit

api/core/app/features/rate_limiting/rate_limit.py  (new file, 120 lines)
@@ -0,0 +1,120 @@
import logging
import time
import uuid
from collections.abc import Generator
from datetime import timedelta
from typing import Optional, Union

from core.errors.error import AppInvokeQuotaExceededError
from extensions.ext_redis import redis_client

logger = logging.getLogger(__name__)


class RateLimit:
    _MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
    _ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
    _UNLIMITED_REQUEST_ID = "unlimited_request_id"
    _REQUEST_MAX_ALIVE_TIME = 10 * 60  # 10 minutes
    _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60  # recalculate request_count from request_detail every 5 minutes
    _instance_dict = {}

    def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
        if client_id not in cls._instance_dict:
            instance = super().__new__(cls)
            cls._instance_dict[client_id] = instance
        return cls._instance_dict[client_id]

    def __init__(self, client_id: str, max_active_requests: int):
        self.max_active_requests = max_active_requests
        if hasattr(self, 'initialized'):
            return
        self.initialized = True
        self.client_id = client_id
        self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
        self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
        self.last_recalculate_time = float('-inf')
        self.flush_cache(use_local_value=True)

    def flush_cache(self, use_local_value=False):
        self.last_recalculate_time = time.time()
        # flush max active requests
        if use_local_value or not redis_client.exists(self.max_active_requests_key):
            with redis_client.pipeline() as pipe:
                pipe.set(self.max_active_requests_key, self.max_active_requests)
                pipe.expire(self.max_active_requests_key, timedelta(days=1))
                pipe.execute()
        else:
            with redis_client.pipeline() as pipe:
                self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
                redis_client.expire(self.max_active_requests_key, timedelta(days=1))

        # flush max active requests (in-transit request list)
        if not redis_client.exists(self.active_requests_key):
            return
        request_details = redis_client.hgetall(self.active_requests_key)
        redis_client.expire(self.active_requests_key, timedelta(days=1))
        timeout_requests = [k for k, v in request_details.items() if
                            time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
        if timeout_requests:
            redis_client.hdel(self.active_requests_key, *timeout_requests)

    def enter(self, request_id: Optional[str] = None) -> str:
        if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
            self.flush_cache()
        if self.max_active_requests <= 0:
            return RateLimit._UNLIMITED_REQUEST_ID
        if not request_id:
            request_id = RateLimit.gen_request_key()

        active_requests_count = redis_client.hlen(self.active_requests_key)
        if active_requests_count >= self.max_active_requests:
            raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
                                              "concurrent requests allowed is {}.".format(self.max_active_requests))
        redis_client.hset(self.active_requests_key, request_id, str(time.time()))
        return request_id

    def exit(self, request_id: str):
        if request_id == RateLimit._UNLIMITED_REQUEST_ID:
            return
        redis_client.hdel(self.active_requests_key, request_id)

    @staticmethod
    def gen_request_key() -> str:
        return str(uuid.uuid4())

    def generate(self, generator: Union[Generator, callable, dict], request_id: str):
        if isinstance(generator, dict):
            return generator
        else:
            return RateLimitGenerator(self, generator, request_id)


class RateLimitGenerator:
    def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
        self.rate_limit = rate_limit
        if callable(generator):
            self.generator = generator()
        else:
            self.generator = generator
        self.request_id = request_id
        self.closed = False

    def __iter__(self):
        return self

    def __next__(self):
        if self.closed:
            raise StopIteration
        try:
            return next(self.generator)
        except StopIteration:
            self.close()
            raise

    def close(self):
        if not self.closed:
            self.closed = True
            self.rate_limit.exit(self.request_id)
            if self.generator is not None and hasattr(self.generator, 'close'):
                self.generator.close()
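A sketch of how the new limiter is meant to wrap a streaming generate call. It assumes a reachable Redis behind `extensions.ext_redis`, and the client id and workload below are placeholders. Because `RateLimitGenerator` calls `exit()` when the wrapped generator is exhausted or closed, the caller never releases the slot by hand:

from core.app.features.rate_limiting import RateLimit
from core.errors.error import AppInvokeQuotaExceededError

rate_limit = RateLimit(client_id="app-123", max_active_requests=5)  # placeholder values

def generate_response():              # placeholder streaming workload
    yield "chunk-1"
    yield "chunk-2"

try:
    request_id = rate_limit.enter()   # raises AppInvokeQuotaExceededError when saturated
    for chunk in rate_limit.generate(generate_response, request_id):
        print(chunk)                  # slot is released automatically on exhaustion
except AppInvokeQuotaExceededError as e:
    print(e)

Note the dict fast path in `generate()`: a plain dict (a blocking response) is returned unchanged, and only generators get the auto-releasing wrapper.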
@@ -4,6 +4,8 @@ import time
from collections.abc import Generator
from typing import Optional, Union, cast

+from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
+from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
@@ -32,6 +34,8 @@ from core.app.entities.task_entities import (
    CompletionAppStreamResponse,
    EasyUITaskState,
    ErrorStreamResponse,
+    MessageAudioEndStreamResponse,
+    MessageAudioStreamResponse,
    MessageEndStreamResponse,
    StreamResponse,
)
@@ -87,6 +91,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
        """
        super().__init__(application_generate_entity, queue_manager, user, stream)
        self._model_config = application_generate_entity.model_conf
+        self._app_config = application_generate_entity.app_config
        self._conversation = conversation
        self._message = message

@@ -102,7 +107,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
        self._conversation_name_generate_thread = None

    def process(
        self,
    ) -> Union[
        ChatbotAppBlockingResponse,
        CompletionAppBlockingResponse,
@@ -123,7 +128,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                self._application_generate_entity.query
            )

-        generator = self._process_stream_response(
+        generator = self._wrapper_process_stream_response(
            trace_manager=self._application_generate_entity.trace_manager
        )
        if self._stream:
@@ -202,14 +207,64 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                stream_response=stream_response
            )

    def _listenAudioMsg(self, publisher, task_id: str):
        if publisher is None:
            return None
        audio_msg: AudioTrunk = publisher.checkAndGetAudio()
        if audio_msg and audio_msg.status != "finish":
            # audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None

    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
            Generator[StreamResponse, None, None]:

        tenant_id = self._application_generate_entity.app_config.tenant_id
        task_id = self._application_generate_entity.task_id
        publisher = None
        text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
        if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
            publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
            while True:
                audio_response = self._listenAudioMsg(publisher, task_id)
                if audio_response:
                    yield audio_response
                else:
                    break
            yield response

        start_listener_time = time.time()
        # timeout
        while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
            if publisher is None:
                break
            audio = publisher.checkAndGetAudio()
            if audio is None:
                # release cpu
                # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
                time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
                continue
            if audio.status == "finish":
                break
            else:
                start_listener_time = time.time()
                yield MessageAudioStreamResponse(audio=audio.audio,
                                                 task_id=task_id)
        yield MessageAudioEndStreamResponse(audio='', task_id=task_id)

    def _process_stream_response(
-        self, trace_manager: Optional[TraceQueueManager] = None
+        self,
+        publisher: AppGeneratorTTSPublisher,
+        trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
+            if publisher:
+                publisher.publish(message)
            event = message.event

            if isinstance(event, QueueErrorEvent):
@@ -272,12 +327,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                yield self._ping_stream_response()
            else:
                continue

+        if publisher:
+            publisher.publish(None)
        if self._conversation_name_generate_thread:
            self._conversation_name_generate_thread.join()

    def _save_message(
        self, trace_manager: Optional[TraceQueueManager] = None
    ) -> None:
        """
        Save message.

@@ -31,6 +31,13 @@ class QuotaExceededError(Exception):
    description = "Quota Exceeded"


+class AppInvokeQuotaExceededError(Exception):
+    """
+    Custom exception raised when the quota for an app has been exceeded.
+    """
+    description = "App Invoke Quota Exceeded"


class ModelCurrentlyNotSupportError(Exception):
    """
    Custom exception raised when the model not support

@@ -730,7 +730,7 @@ class IndexingRunner:
            self._check_document_paused_status(dataset_document.id)

            tokens = 0
            if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
                if embedding_model_instance:
                    tokens += sum(
                        embedding_model_instance.get_text_embedding_num_tokens(
                            [document.page_content]

@@ -264,7 +264,7 @@ class ModelInstance:
            user=user
        )

-    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \
+    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \
            -> str:
        """
        Invoke large language tts model
@@ -287,8 +287,7 @@ class ModelInstance:
            content_text=content_text,
            user=user,
            tenant_id=tenant_id,
-            voice=voice,
-            streaming=streaming
+            voice=voice
        )

    def _round_robin_invoke(self, function: Callable, *args, **kwargs):
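With the `streaming` flag gone, `invoke_tts` always streams; despite the unchanged `-> str` annotation, the provider implementations now yield raw audio byte chunks. A consumption sketch, where the ModelInstance setup, tenant id and voice are placeholders:

# Assumes `model_instance` is an already-configured ModelInstance for a TTS model.
audio_stream = model_instance.invoke_tts(
    content_text="Hello from the pipeline!",
    tenant_id="tenant-123",      # placeholder
    voice="alloy",               # placeholder voice name
)
with open("out.mp3", "wb") as f:
    for chunk in audio_stream:   # mp3 byte chunks, 1 KiB each in the new implementations
        f.write(chunk)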
@@ -1,4 +1,6 @@
import hashlib
+import logging
+import re
import subprocess
import uuid
from abc import abstractmethod
@@ -10,7 +12,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.ai_model import AIModel


+logger = logging.getLogger(__name__)
class TTSModel(AIModel):
    """
    Model class for ttstext model.
@@ -20,7 +22,7 @@ class TTSModel(AIModel):
    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

-    def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
+    def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
               user: Optional[str] = None):
        """
        Invoke large language model
@@ -35,14 +37,15 @@ class TTSModel(AIModel):
        :return: translated audio file
        """
        try:
+            logger.info(f"Invoke TTS model: {model} , invoke content : {content_text}")
            self._is_ffmpeg_installed()
-            return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming,
+            return self._invoke(model=model, credentials=credentials, user=user,
                                content_text=content_text, voice=voice, tenant_id=tenant_id)
        except Exception as e:
            raise self._transform_invoke_error(e)

    @abstractmethod
-    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
+    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
               user: Optional[str] = None):
        """
        Invoke large language model
@@ -123,26 +126,26 @@ class TTSModel(AIModel):
        return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]

    @staticmethod
-    def _split_text_into_sentences(text: str, limit: int, delimiters=None):
-        if delimiters is None:
-            delimiters = set('。!?;\n')
-
-        buf = []
-        word_count = 0
-        for char in text:
-            buf.append(char)
-            if char in delimiters:
-                if word_count >= limit:
-                    yield ''.join(buf)
-                    buf = []
-                    word_count = 0
-                else:
-                    word_count += 1
-            else:
-                word_count += 1
-
-        if buf:
-            yield ''.join(buf)
+    def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'):
+        match = re.compile(pattern)
+        tx = match.finditer(org_text)
+        start = 0
+        result = []
+        one_sentence = ''
+        for i in tx:
+            end = i.regs[0][1]
+            tmp = org_text[start:end]
+            if len(one_sentence + tmp) > max_length:
+                result.append(one_sentence)
+                one_sentence = ''
+            one_sentence += tmp
+            start = end
+        last_sens = org_text[start:]
+        if last_sens:
+            one_sentence += last_sens
+        if one_sentence != '':
+            result.append(one_sentence)
+        return result

    @staticmethod
    def _is_ffmpeg_installed():
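The replacement splitter is regex-driven and length-bounded: it accumulates delimiter-terminated sentences until appending the next one would push past `max_length`, then starts a new segment, so every segment except possibly the last ends at a sentence boundary. A quick check of that behavior (runnable inside the api codebase; the sample text and limit are arbitrary):

from core.model_runtime.model_providers.__base.tts_model import TTSModel

# Split at '.', '!', '?' (and '。'), keeping each segment under 10 characters.
segments = TTSModel._split_text_into_sentences("One. Two! Three? Four.", max_length=10)
print(segments)   # ['One. Two!', ' Three?', ' Four.']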
@@ -33,3 +33,4 @@
- deepseek
- hunyuan
- siliconflow
+- perfxcloud

@@ -4,7 +4,7 @@ from functools import reduce
from io import BytesIO
from typing import Optional

-from flask import Response, stream_with_context
+from flask import Response
from openai import AzureOpenAI
from pydub import AudioSegment

@@ -14,7 +14,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
-from extensions.ext_storage import storage


class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
@@ -23,7 +22,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
    """

    def _invoke(self, model: str, tenant_id: str, credentials: dict,
-                content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
+                content_text: str, voice: str, user: Optional[str] = None) -> any:
        """
        _invoke text2speech model

@@ -32,30 +31,23 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
-        :param streaming: output is streaming
        :param user: unique user id
        :return: text translated to audio file
        """
-        audio_type = self._get_model_audio_type(model, credentials)
        if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
            voice = self._get_model_default_voice(model, credentials)
-        if streaming:
-            return Response(stream_with_context(self._tts_invoke_streaming(model=model,
-                                                                           credentials=credentials,
-                                                                           content_text=content_text,
-                                                                           tenant_id=tenant_id,
-                                                                           voice=voice)),
-                            status=200, mimetype=f'audio/{audio_type}')
-        else:
-            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
-
-    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
+        return self._tts_invoke_streaming(model=model,
+                                          credentials=credentials,
+                                          content_text=content_text,
+                                          voice=voice)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        validate credentials text2speech model

        :param model: model name
        :param credentials: model credentials
        :param user: unique user id
        :return: text translated to audio file
        """
        try:
@@ -82,7 +74,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
        word_limit = self._get_model_word_limit(model, credentials)
        max_workers = self._get_model_workers_limit(model, credentials)
        try:
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
+            sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
            audio_bytes_list = []

            # Create a thread pool and map the function to the list of sentences
@@ -107,34 +99,37 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

-    # Todo: To improve the streaming function
-    def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
+    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
                              voice: str) -> any:
        """
        _tts_invoke_streaming text2speech model

        :param model: model name
-        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :return: text translated to audio file
        """
-        # transform credentials to kwargs for model instance
-        credentials_kwargs = self._to_credential_kwargs(credentials)
-        if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
-            voice = self._get_model_default_voice(model, credentials)
-        word_limit = self._get_model_word_limit(model, credentials)
-        audio_type = self._get_model_audio_type(model, credentials)
-        tts_file_id = self._get_file_name(content_text)
-        file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
        try:
            # doc: https://platform.openai.com/docs/guides/text-to-speech
+            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = AzureOpenAI(**credentials_kwargs)
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
-            for sentence in sentences:
-                response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
-                # response.stream_to_file(file_path)
-                storage.save(file_path, response.read())
+            # max font is 4096,there is 3500 limit for each request
+            max_length = 3500
+            if len(content_text) > max_length:
+                sentences = self._split_text_into_sentences(content_text, max_length=max_length)
+                executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
+                futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
+                                           response_format="mp3",
+                                           input=sentences[i], voice=voice) for i in range(len(sentences))]
+                for index, future in enumerate(futures):
+                    yield from future.result().__enter__().iter_bytes(1024)
+
+            else:
+                response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
+                                                                              response_format="mp3",
+                                                                              input=content_text.strip())
+
+                yield from response.__enter__().iter_bytes(1024)
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

@@ -162,7 +157,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):


    @staticmethod
-    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None:
        for ai_model_entity in TTS_BASE_MODELS:
            if ai_model_entity.base_model_name == base_model_name:
                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
@@ -170,5 +165,4 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
                ai_model_entity_copy.entity.label.en_US = model
                ai_model_entity_copy.entity.label.zh_Hans = model
                return ai_model_entity_copy

        return None

@@ -66,6 +66,10 @@ provider_credential_schema:
      label:
        en_US: Europe (Frankfurt)
        zh_Hans: 欧洲 (法兰克福)
+    - value: eu-west-2
+      label:
+        en_US: Eu west London (London)
+        zh_Hans: 欧洲西部 (伦敦)
    - value: us-gov-west-1
      label:
        en_US: AWS GovCloud (US-West)

@@ -7,7 +7,7 @@ features:
  - agent-thought
model_properties:
  mode: chat
-  context_size: 32000
+  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature

@@ -7,7 +7,7 @@ features:
  - agent-thought
model_properties:
  mode: chat
-  context_size: 32000
+  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature

@@ -21,7 +21,7 @@ model_properties:
  - mode: 'shimmer'
    name: 'Shimmer'
    language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
-  word_limit: 120
+  word_limit: 3500
  audio_type: 'mp3'
  max_workers: 5
pricing:

@@ -21,7 +21,7 @@ model_properties:
  - mode: 'shimmer'
    name: 'Shimmer'
    language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
-  word_limit: 120
+  word_limit: 3500
  audio_type: 'mp3'
  max_workers: 5
pricing:

@@ -3,7 +3,7 @@ from functools import reduce
from io import BytesIO
from typing import Optional

-from flask import Response, stream_with_context
+from flask import Response
from openai import OpenAI
from pydub import AudioSegment

@@ -11,7 +11,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.openai._common import _CommonOpenAI
-from extensions.ext_storage import storage


class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
@@ -20,7 +19,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
    """

    def _invoke(self, model: str, tenant_id: str, credentials: dict,
-                content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
+                content_text: str, voice: str, user: Optional[str] = None) -> any:
        """
        _invoke text2speech model

@@ -29,22 +28,17 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
-        :param streaming: output is streaming
        :param user: unique user id
        :return: text translated to audio file
        """
-        audio_type = self._get_model_audio_type(model, credentials)
-
        if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
            voice = self._get_model_default_voice(model, credentials)
-        if streaming:
-            return Response(stream_with_context(self._tts_invoke_streaming(model=model,
-                                                                           credentials=credentials,
-                                                                           content_text=content_text,
-                                                                           tenant_id=tenant_id,
-                                                                           voice=voice)),
-                            status=200, mimetype=f'audio/{audio_type}')
-        else:
-            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
+        # if streaming:
+        return self._tts_invoke_streaming(model=model,
+                                          credentials=credentials,
+                                          content_text=content_text,
+                                          voice=voice)

    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
        """
@@ -79,7 +73,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
        word_limit = self._get_model_word_limit(model, credentials)
        max_workers = self._get_model_workers_limit(model, credentials)
        try:
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
+            sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
            audio_bytes_list = []

            # Create a thread pool and map the function to the list of sentences
@@ -104,34 +98,40 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

-    # Todo: To improve the streaming function
-    def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
+
+    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
                              voice: str) -> any:
        """
        _tts_invoke_streaming text2speech model

        :param model: model name
-        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param content_text: text content to be translated
        :param voice: model timbre
        :return: text translated to audio file
        """
-        # transform credentials to kwargs for model instance
-        credentials_kwargs = self._to_credential_kwargs(credentials)
-        if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
-            voice = self._get_model_default_voice(model, credentials)
-        word_limit = self._get_model_word_limit(model, credentials)
-        audio_type = self._get_model_audio_type(model, credentials)
-        tts_file_id = self._get_file_name(content_text)
-        file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
        try:
            # doc: https://platform.openai.com/docs/guides/text-to-speech
+            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = OpenAI(**credentials_kwargs)
-            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
-            for sentence in sentences:
-                response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
-                # response.stream_to_file(file_path)
-                storage.save(file_path, response.read())
+            if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
+                voice = self._get_model_default_voice(model, credentials)
+            word_limit = self._get_model_word_limit(model, credentials)
+            if len(content_text) > word_limit:
+                sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
+                executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
+                futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model,
+                                           response_format="mp3",
+                                           input=sentences[i], voice=voice) for i in range(len(sentences))]
+                for index, future in enumerate(futures):
+                    yield from future.result().__enter__().iter_bytes(1024)
+
+            else:
+                response = client.audio.speech.with_streaming_response.create(model=model, voice=voice,
+                                                                              response_format="mp3",
+                                                                              input=content_text.strip())
+
+                yield from response.__enter__().iter_bytes(1024)
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))
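Both providers now stream through the SDK's `with_streaming_response` API instead of writing synthesized files to storage, splitting long text and fanning segments out to a small thread pool. One detail worth noting: with this API the HTTP request is only issued when the response context is entered. A minimal sketch of the core call (assumes `OPENAI_API_KEY` is set; the model and voice values are placeholders):

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY in the environment

def stream_tts(text: str, model: str = "tts-1", voice: str = "alloy"):
    """Yield mp3 byte chunks for one text segment (illustrative sketch)."""
    with client.audio.speech.with_streaming_response.create(
        model=model, voice=voice, response_format="mp3", input=text.strip()
    ) as response:                             # the request is sent on entering the context
        yield from response.iter_bytes(1024)   # 1 KiB chunks, matching the diff

with open("out.mp3", "wb") as f:
    for chunk in stream_tts("Hello from the TTS pipeline!"):
        f.write(chunk)

Using a `with` block also guarantees the underlying connection is closed once the chunks are consumed, which the bare `__enter__()` calls in the diff do not.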
@@ -616,30 +616,34 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls:
-                # message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
-                #                               in
-                #                               message.tool_calls]
-
-                function_call = message.tool_calls[0]
-                message_dict["function_call"] = {
-                    "name": function_call.function.name,
-                    "arguments": function_call.function.arguments,
-                }
+                function_calling_type = credentials.get('function_calling_type', 'no_call')
+                if function_calling_type == 'tool_call':
+                    message_dict["tool_calls"] = [tool_call.dict() for tool_call in
+                                                  message.tool_calls]
+                elif function_calling_type == 'function_call':
+                    function_call = message.tool_calls[0]
+                    message_dict["function_call"] = {
+                        "name": function_call.function.name,
+                        "arguments": function_call.function.arguments,
+                    }
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ToolPromptMessage):
            message = cast(ToolPromptMessage, message)
-            # message_dict = {
-            #     "role": "tool",
-            #     "content": message.content,
-            #     "tool_call_id": message.tool_call_id
-            # }
-            message_dict = {
-                "role": "tool" if credentials and credentials.get('function_calling_type', 'no_call') == 'tool_call' else "function",
-                "content": message.content,
-                "name": message.tool_call_id
-            }
+            function_calling_type = credentials.get('function_calling_type', 'no_call')
+            if function_calling_type == 'tool_call':
+                message_dict = {
+                    "role": "tool",
+                    "content": message.content,
+                    "tool_call_id": message.tool_call_id
+                }
+            elif function_calling_type == 'function_call':
+                message_dict = {
+                    "role": "function",
+                    "content": message.content,
+                    "name": message.tool_call_id
+                }
        else:
            raise ValueError(f"Got unknown type {message}")
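The rewrite keys both branches off a single `function_calling_type` credential: `tool_call` produces the OpenAI tools-style shapes, while `function_call` produces the legacy functions-style ones. A worked example of the two assistant and tool message shapes this code emits (all values are illustrative):

# Assistant message carrying one tool invocation, serialized both ways.
# function_calling_type == 'tool_call' (modern OpenAI tools API):
assistant_tool_call = {
    "role": "assistant",
    "content": None,
    "tool_calls": [{
        "id": "call_abc123",    # illustrative id
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"city": "Berlin"}'},
    }],
}
# function_calling_type == 'function_call' (legacy functions API):
assistant_function_call = {
    "role": "assistant",
    "content": None,
    "function_call": {"name": "get_weather", "arguments": '{"city": "Berlin"}'},
}
# The matching tool-result message differs the same way: role "tool" with
# "tool_call_id" in the modern shape, role "function" with "name" in the legacy one.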
File diff suppressed because one or more lines are too long
(new image: 21 KiB)
File diff suppressed because one or more lines are too long
(new image: 48 KiB)
@@ -0,0 +1,61 @@
model: Qwen-14B-Chat-Int4
label:
  en_US: Qwen-14B-Chat-Int4
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 600
    min: 1
    max: 1248
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
  input: '0.000'
  output: '0.000'
  unit: '0.000'
  currency: RMB
@@ -0,0 +1,61 @@
model: Qwen1.5-110B-Chat-GPTQ-Int4
label:
  en_US: Qwen1.5-110B-Chat-GPTQ-Int4
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 128
    min: 1
    max: 256
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
  input: '0.000'
  output: '0.000'
  unit: '0.000'
  currency: RMB
@@ -0,0 +1,61 @@
model: Qwen1.5-72B-Chat-GPTQ-Int4
label:
  en_US: Qwen1.5-72B-Chat-GPTQ-Int4
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 600
    min: 1
    max: 2000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
  input: '0.000'
  output: '0.000'
  unit: '0.000'
  currency: RMB
@@ -0,0 +1,61 @@
model: Qwen1.5-7B
label:
  en_US: Qwen1.5-7B
model_type: llm
features:
  - agent-thought
model_properties:
  mode: completion
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 600
    min: 1
    max: 2000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
  input: '0.000'
  output: '0.000'
  unit: '0.000'
  currency: RMB
@@ -0,0 +1,63 @@
model: Qwen2-72B-Instruct-GPTQ-Int4
label:
  en_US: Qwen2-72B-Instruct-GPTQ-Int4
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 600
    min: 1
    max: 2000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment.
pricing:
  input: '0.000'
  output: '0.000'
  unit: '0.000'
  currency: RMB
@ -0,0 +1,63 @@
model: Qwen2-7B
label:
  en_US: Qwen2-7B
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: completion
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    type: float
    default: 0.3
    min: 0.0
    max: 2.0
    help:
      zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。
      en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls how much the probability distribution over candidate words is smoothed during generation. A higher temperature flattens the distribution, allowing more low-probability words to be selected and producing more diverse output; a lower temperature sharpens the distribution, making high-probability words more likely and producing more deterministic output.
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 600
    min: 1
    max: 2000
    help:
      zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。
      en_US: Specifies the maximum number of tokens the model may generate. It defines an upper limit on generation but does not guarantee that this many tokens will be produced every time.
  - name: top_p
    use_template: top_p
    type: float
    default: 0.8
    min: 0.1
    max: 0.9
    help:
      zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。
      en_US: The probability threshold for nucleus sampling during generation. For example, at 0.8 only the smallest set of the most likely tokens whose probabilities sum to at least 0.8 is kept as the candidate set. The value range is (0, 1.0); higher values increase randomness, lower values increase determinism.
  - name: top_k
    type: int
    min: 0
    max: 99
    label:
      zh_Hans: 取样数量
      en_US: Top k
    help:
      zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。
      en_US: The size of the candidate set sampled from during generation. For example, at 50 only the 50 highest-scoring tokens of a single step form the sampling candidate set. Higher values increase randomness; lower values increase determinism.
  - name: repetition_penalty
    required: false
    type: float
    default: 1.1
    label:
      en_US: Repetition penalty
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the degree of repetition in generated output. Increasing repetition_penalty reduces repetition; 1.0 means no penalty.
pricing:
  input: '0.000'
  output: '0.000'
  unit: '0.000'
  currency: RMB
@ -0,0 +1,6 @@
- Qwen2-72B-Instruct-GPTQ-Int4
- Qwen2-7B
- Qwen1.5-110B-Chat-GPTQ-Int4
- Qwen1.5-72B-Chat-GPTQ-Int4
- Qwen1.5-7B
- Qwen-14B-Chat-Int4
110
api/core/model_runtime/model_providers/perfxcloud/llm/llm.py
Normal file
@ -0,0 +1,110 @@
from collections.abc import Generator
from typing import Optional, Union
from urllib.parse import urlparse

import tiktoken

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
)
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel


class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)

        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    # refactored from the OpenAI model runtime; uses cl100k_base to count tokens
    def _num_tokens_from_string(self, model: str, text: str,
                                tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Calculate num tokens for text completion model with tiktoken package.

        :param model: model name
        :param text: prompt text
        :param tools: tools for tool calling
        :return: number of tokens
        """
        encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = len(encoding.encode(text))

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    # refactored from the OpenAI model runtime; uses cl100k_base to count tokens
    def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        encoding = tiktoken.get_encoding("cl100k_base")
        tokens_per_message = 3
        tokens_per_name = 1

        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Cast str(value) in case the message value is not a string
                # This occurs with function messages
                # TODO: The current token calculation method for the image type is not implemented,
                #  which needs to download the image and then get the resolution for calculation,
                #  and will increase the request delay
                if isinstance(value, list):
                    text = ''
                    for item in value:
                        if isinstance(item, dict) and item['type'] == 'text':
                            text += item['text']

                    value = text

                if key == "tool_calls":
                    for tool_call in value:
                        for t_key, t_value in tool_call.items():
                            num_tokens += len(encoding.encode(t_key))
                            if t_key == "function":
                                for f_key, f_value in t_value.items():
                                    num_tokens += len(encoding.encode(f_key))
                                    num_tokens += len(encoding.encode(f_value))
                            else:
                                num_tokens += len(encoding.encode(t_key))
                                num_tokens += len(encoding.encode(t_value))
                else:
                    num_tokens += len(encoding.encode(str(value)))

                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        credentials['mode'] = 'chat'
        credentials['openai_api_key'] = credentials['api_key']
        if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
            credentials['openai_api_base'] = 'https://cloud.perfxlab.cn'
        else:
            parsed_url = urlparse(credentials['endpoint_url'])
            credentials['openai_api_base'] = f"{parsed_url.scheme}://{parsed_url.netloc}"
@ -0,0 +1,32 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class PerfXCloudProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials
        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # Use the `Qwen2-72B-Instruct-GPTQ-Int4` model for validation,
            # no matter what model you pass in, text completion model or chat model
            model_instance.validate_credentials(
                model='Qwen2-72B-Instruct-GPTQ-Int4',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
@ -0,0 +1,42 @@
provider: perfxcloud
label:
  en_US: PerfXCloud
  zh_Hans: PerfXCloud
description:
  en_US: PerfXCloud (Pengfeng Technology) is an AI development and deployment platform tailored for developers and enterprises, providing inference capabilities for a variety of models.
  zh_Hans: PerfXCloud(澎峰科技)为开发者和企业量身打造的AI开发和部署平台,提供多种模型的推理能力。
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
background: "#e3f0ff"
help:
  title:
    en_US: Get your API Key from PerfXCloud
    zh_Hans: 从 PerfXCloud 获取 API Key
  url:
    en_US: https://cloud.perfxlab.cn/panel/token
supported_model_types:
  - llm
  - text-embedding
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    - variable: endpoint_url
      label:
        zh_Hans: 自定义 API endpoint 地址
        en_US: Custom API endpoint URL
      type: text-input
      required: false
      placeholder:
        zh_Hans: Base URL, e.g. https://cloud.perfxlab.cn/v1
        en_US: Base URL, e.g. https://cloud.perfxlab.cn/v1
@ -0,0 +1,4 @@
model: BAAI/bge-m3
model_type: text-embedding
model_properties:
  context_size: 32768
@ -0,0 +1,250 @@
import json
import time
from decimal import Decimal
from typing import Optional
from urllib.parse import urljoin

import numpy as np
import requests

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelPropertyKey,
    ModelType,
    PriceConfig,
    PriceType,
)
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat


class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
    """
    Model class for an OpenAI API-compatible text embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """

        # Prepare headers and payload for the request
        headers = {
            'Content-Type': 'application/json'
        }

        api_key = credentials.get('api_key')
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
            endpoint_url = 'https://cloud.perfxlab.cn/v1/'
        else:
            endpoint_url = credentials.get('endpoint_url')
            if not endpoint_url.endswith('/'):
                endpoint_url += '/'

        endpoint_url = urljoin(endpoint_url, 'embeddings')

        extra_model_kwargs = {}
        if user:
            extra_model_kwargs['user'] = user

        extra_model_kwargs['encoding_format'] = 'float'

        # get model properties
        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        inputs = []
        indices = []
        used_tokens = 0

        for i, text in enumerate(texts):

            # Here token count is only an approximation based on the GPT2 tokenizer
            # TODO: Optimize for better token estimation and chunking
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
                # if num tokens is larger than context length, only use the start
                inputs.append(text[0: cutoff])
            else:
                inputs.append(text)
            indices += [i]

        batched_embeddings = []
        _iter = range(0, len(inputs), max_chunks)

        for i in _iter:
            # Prepare the payload for the request
            payload = {
                'input': inputs[i: i + max_chunks],
                'model': model,
                **extra_model_kwargs
            }

            # Make the request to the OpenAI API
            response = requests.post(
                endpoint_url,
                headers=headers,
                data=json.dumps(payload),
                timeout=(10, 300)
            )

            response.raise_for_status()  # Raise an exception for HTTP errors
            response_data = response.json()

            # Extract embeddings and used tokens from the response
            embeddings_batch = [data['embedding'] for data in response_data['data']]
            embedding_used_tokens = response_data['usage']['total_tokens']

            used_tokens += embedding_used_tokens
            batched_embeddings += embeddings_batch

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=used_tokens
        )

        return TextEmbeddingResult(
            embeddings=batched_embeddings,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Approximate number of tokens for given messages using GPT2 tokenizer

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            headers = {
                'Content-Type': 'application/json'
            }

            api_key = credentials.get('api_key')

            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
                endpoint_url = 'https://cloud.perfxlab.cn/v1/'
            else:
                endpoint_url = credentials.get('endpoint_url')
                if not endpoint_url.endswith('/'):
                    endpoint_url += '/'

            endpoint_url = urljoin(endpoint_url, 'embeddings')

            payload = {
                'input': 'ping',
                'model': model
            }

            response = requests.post(
                url=endpoint_url,
                headers=headers,
                data=json.dumps(payload),
                timeout=(10, 300)
            )

            if response.status_code != 200:
                raise CredentialsValidateFailedError(
                    f'Credentials validation failed with status code {response.status_code}')

            try:
                json_result = response.json()
            except json.JSONDecodeError as e:
                raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error')

            if 'model' not in json_result:
                raise CredentialsValidateFailedError(
                    'Credentials validation failed: invalid response')
        except CredentialsValidateFailedError:
            raise
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        generate custom model entities from credentials
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.TEXT_EMBEDDING,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                ModelPropertyKey.MAX_CHUNKS: 1,
            },
            parameter_rules=[],
            pricing=PriceConfig(
                input=Decimal(credentials.get('input_price', 0)),
                unit=Decimal(credentials.get('unit', 0)),
                currency=credentials.get('currency', "USD")
            )
        )

        return entity

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage
@ -129,7 +129,7 @@ model_properties:
    - mode: "sambert-waan-v1"
      name: "Waan(泰语女声)"
      language: [ "th-TH" ]
  word_limit: 120
  word_limit: 7000
  audio_type: 'mp3'
  max_workers: 5
pricing:
@ -1,17 +1,21 @@
import concurrent.futures
import threading
from functools import reduce
from io import BytesIO
from queue import Queue
from typing import Optional

import dashscope
from flask import Response, stream_with_context
from dashscope import SpeechSynthesizer
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult
from flask import Response
from pydub import AudioSegment

from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.tongyi._common import _CommonTongyi
from extensions.ext_storage import storage


class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
@ -19,7 +23,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
    Model class for Tongyi Speech to text model.
    """

    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool,
    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
                user: Optional[str] = None) -> any:
        """
        _invoke text2speech model
@ -29,22 +33,17 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
        :param credentials: model credentials
        :param voice: model timbre
        :param content_text: text content to be translated
        :param streaming: output is streaming
        :param user: unique user id
        :return: text translated to audio file
        """
        audio_type = self._get_model_audio_type(model, credentials)
        if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
        if not voice or voice not in [d['value'] for d in
                                      self.get_tts_model_voices(model=model, credentials=credentials)]:
            voice = self._get_model_default_voice(model, credentials)
        if streaming:
            return Response(stream_with_context(self._tts_invoke_streaming(model=model,
                                                                           credentials=credentials,
                                                                           content_text=content_text,
                                                                           voice=voice,
                                                                           tenant_id=tenant_id)),
                            status=200, mimetype=f'audio/{audio_type}')
        else:
            return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)

        return self._tts_invoke_streaming(model=model,
                                          credentials=credentials,
                                          content_text=content_text,
                                          voice=voice)

    def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
        """
@ -79,7 +78,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
        word_limit = self._get_model_word_limit(model, credentials)
        max_workers = self._get_model_workers_limit(model, credentials)
        try:
            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
            sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
            audio_bytes_list = []

            # Create a thread pool and map the function to the list of sentences
@ -105,14 +104,12 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

    # Todo: To improve the streaming function
    def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
                              voice: str) -> any:
        """
        _tts_invoke_streaming text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param voice: model timbre
        :param content_text: text content to be translated
@ -120,18 +117,32 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
        """
        word_limit = self._get_model_word_limit(model, credentials)
        audio_type = self._get_model_audio_type(model, credentials)
        tts_file_id = self._get_file_name(content_text)
        file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
        try:
            sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
            for sentence in sentences:
                response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000,
                                                                      api_key=credentials.get('dashscope_api_key'),
                                                                      text=sentence.strip(),
                                                                      format=audio_type, word_timestamp_enabled=True,
                                                                      phoneme_timestamp_enabled=True)
                if isinstance(response.get_audio_data(), bytes):
                    storage.save(file_path, response.get_audio_data())
            audio_queue: Queue = Queue()
            callback = Callback(queue=audio_queue)

            def invoke_remote(content, v, api_key, cb, at, wl):
                if len(content) < word_limit:
                    sentences = [content]
                else:
                    sentences = list(self._split_text_into_sentences(org_text=content, max_length=wl))
                for sentence in sentences:
                    SpeechSynthesizer.call(model=v, sample_rate=16000,
                                           api_key=api_key,
                                           text=sentence.strip(),
                                           callback=cb,
                                           format=at, word_timestamp_enabled=True,
                                           phoneme_timestamp_enabled=True)

            threading.Thread(target=invoke_remote, args=(
                content_text, voice, credentials.get('dashscope_api_key'), callback, audio_type, word_limit)).start()

            while True:
                audio = audio_queue.get()
                if audio is None:
                    break
                yield audio

        except Exception as ex:
            raise InvokeBadRequestError(str(ex))
@ -152,3 +163,29 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
                                               format=audio_type)
        if isinstance(response.get_audio_data(), bytes):
            return response.get_audio_data()


class Callback(ResultCallback):

    def __init__(self, queue: Queue):
        self._queue = queue

    def on_open(self):
        pass

    def on_complete(self):
        self._queue.put(None)
        self._queue.task_done()

    def on_error(self, response: SpeechSynthesisResponse):
        self._queue.put(None)
        self._queue.task_done()

    def on_close(self):
        self._queue.put(None)
        self._queue.task_done()

    def on_event(self, result: SpeechSynthesisResult):
        ad = result.get_audio_frame()
        if ad:
            self._queue.put(ad)
@ -29,7 +29,7 @@ model_credential_schema:
      label:
        zh_Hans: 服务器URL
        en_US: Server url
      type: secret-input
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入 Triton Inference Server 的服务器地址,如 http://192.168.1.100:8000
@ -0,0 +1,40 @@
model: ernie-4.0-turbo-8k-preview
label:
  en_US: Ernie-4.0-turbo-8k-preview
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0.1
    max: 1.0
    default: 0.8
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 2
    max: 2048
  - name: presence_penalty
    use_template: presence_penalty
    default: 1.0
    min: 1.0
    max: 2.0
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: response_format
    use_template: response_format
  - name: disable_search
    label:
      zh_Hans: 禁用搜索
      en_US: Disable Search
    type: boolean
    help:
      zh_Hans: 禁用模型自行进行外部搜索。
      en_US: Disable the model from performing external searches on its own.
    required: false
@ -138,6 +138,7 @@ class ErnieBotModel:
        'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
        'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
        'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
        'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
    }

    function_calling_supports = [
@ -149,6 +150,7 @@ class ErnieBotModel:
        'ernie-3.5-4k-0205',
        'ernie-3.5-128k',
        'ernie-4.0-8k',
        'ernie-4.0-turbo-8k-preview'
    ]

    api_key: str = ''
@ -32,7 +32,7 @@ model_credential_schema:
      label:
        zh_Hans: 服务器URL
        en_US: Server url
      type: secret-input
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
@ -20,7 +20,7 @@ class ZhipuaiProvider(ModelProvider):
            model_instance = self.get_model_instance(ModelType.LLM)

            model_instance.validate_credentials(
                model='chatglm_turbo',
                model='glm-4',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
import httpx

from ..core._base_api import BaseAPI
from ..core._base_type import NOT_GIVEN, Headers, NotGiven
from ..core._base_type import NOT_GIVEN, Body, Headers, NotGiven
from ..core._http_client import make_user_request_input
from ..types.image import ImagesResponded
@ -28,7 +28,9 @@ class Images(BaseAPI):
        size: Optional[str] | NotGiven = NOT_GIVEN,
        style: Optional[str] | NotGiven = NOT_GIVEN,
        user: str | NotGiven = NOT_GIVEN,
        request_id: Optional[str] | NotGiven = NOT_GIVEN,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        disable_strict_validation: Optional[bool] | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> ImagesResponded:
@ -46,9 +48,12 @@ class Images(BaseAPI):
                "size": size,
                "style": style,
                "user": user,
                "request_id": request_id,
            },
            options=make_user_request_input(
                extra_headers=extra_headers, timeout=timeout
                extra_headers=extra_headers,
                extra_body=extra_body,
                timeout=timeout
            ),
            cast_type=_cast_type,
            enable_stream=False,
@ -11,7 +11,7 @@ from tenacity import retry
from tenacity.stop import stop_after_attempt

from . import _errors
from ._base_type import NOT_GIVEN, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
from ._base_type import NOT_GIVEN, AnyMapping, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
from ._files import make_httpx_files
from ._request_opt import ClientRequestParam, UserRequestInput
@ -358,6 +358,7 @@ def make_user_request_input(
    max_retries: int | None = None,
    timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
    extra_headers: Headers = None,
    extra_body: Body | None = None,
    query: Query | None = None,
) -> UserRequestInput:
    options: UserRequestInput = {}
@ -370,5 +371,7 @@ def make_user_request_input(
        options['timeout'] = timeout
    if query is not None:
        options["params"] = query
    if extra_body is not None:
        options["extra_json"] = cast(AnyMapping, extra_body)

    return options
@ -1,7 +1,6 @@
from typing import Any

from flask import current_app

from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba import Jieba
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
@ -14,8 +13,8 @@ class Keyword:
        self._keyword_processor = self._init_keyword()

    def _init_keyword(self) -> BaseKeyword:
        config = current_app.config
        keyword_type = config.get('KEYWORD_STORE')
        config = dify_config
        keyword_type = config.KEYWORD_STORE

        if not keyword_type:
            raise ValueError("Keyword store must be specified.")
0
api/core/rag/datasource/vdb/analyticdb/__init__.py
Normal file
332
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
Normal file
@ -0,0 +1,332 @@
import json
from typing import Any

from pydantic import BaseModel

_import_err_msg = (
    "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
    "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from flask import current_app

from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset


class AnalyticdbConfig(BaseModel):
    access_key_id: str
    access_key_secret: str
    region_id: str
    instance_id: str
    account: str
    account_password: str
    namespace: str = "dify"
    namespace_password: str = None
    metrics: str = "cosine"
    read_timeout: int = 60000

    def to_analyticdb_client_params(self):
        return {
            "access_key_id": self.access_key_id,
            "access_key_secret": self.access_key_secret,
            "region_id": self.region_id,
            "read_timeout": self.read_timeout,
        }


class AnalyticdbVector(BaseVector):
    _instance = None
    _init = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, collection_name: str, config: AnalyticdbConfig):
        # collection_name must be updated every time
        self._collection_name = collection_name.lower()
        if AnalyticdbVector._init:
            return
        try:
            from alibabacloud_gpdb20160503.client import Client
            from alibabacloud_tea_openapi import models as open_api_models
        except ImportError:
            raise ImportError(_import_err_msg)
        self.config = config
        self._client_config = open_api_models.Config(
            user_agent="dify", **config.to_analyticdb_client_params()
        )
        self._client = Client(self._client_config)
        self._initialize()
        AnalyticdbVector._init = True

    def _initialize(self) -> None:
        self._initialize_vector_database()
        self._create_namespace_if_not_exists()

    def _initialize_vector_database(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        request = gpdb_20160503_models.InitVectorDatabaseRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            manager_account=self.config.account,
            manager_account_password=self.config.account_password,
        )
        self._client.init_vector_database(request)

    def _create_namespace_if_not_exists(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException
        try:
            request = gpdb_20160503_models.DescribeNamespaceRequest(
                dbinstance_id=self.config.instance_id,
                region_id=self.config.region_id,
                namespace=self.config.namespace,
                manager_account=self.config.account,
                manager_account_password=self.config.account_password,
            )
            self._client.describe_namespace(request)
        except TeaException as e:
            if e.statusCode == 404:
                request = gpdb_20160503_models.CreateNamespaceRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    manager_account=self.config.account,
                    manager_account_password=self.config.account_password,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                )
                self._client.create_namespace(request)
            else:
                raise ValueError(
                    f"failed to create namespace {self.config.namespace}: {e}"
                )

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException
        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            try:
                request = gpdb_20160503_models.DescribeCollectionRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                    collection=self._collection_name,
                )
                self._client.describe_collection(request)
            except TeaException as e:
                if e.statusCode == 404:
                    metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
                    full_text_retrieval_fields = "page_content"
                    request = gpdb_20160503_models.CreateCollectionRequest(
                        dbinstance_id=self.config.instance_id,
                        region_id=self.config.region_id,
                        manager_account=self.config.account,
                        manager_account_password=self.config.account_password,
                        namespace=self.config.namespace,
                        collection=self._collection_name,
                        dimension=embedding_dimension,
                        metrics=self.config.metrics,
                        metadata=metadata,
                        full_text_retrieval_fields=full_text_retrieval_fields,
                    )
                    self._client.create_collection(request)
                else:
                    raise ValueError(
                        f"failed to create collection {self._collection_name}: {e}"
                    )
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def get_type(self) -> str:
        return VectorType.ANALYTICDB

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        dimension = len(embeddings[0])
        self._create_collection_if_not_exists(dimension)
        self.add_texts(texts, embeddings)

    def add_texts(
        self, documents: list[Document], embeddings: list[list[float]], **kwargs
    ):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
        for doc, embedding in zip(documents, embeddings, strict=True):
            metadata = {
                "ref_doc_id": doc.metadata["doc_id"],
                "page_content": doc.page_content,
                "metadata_": json.dumps(doc.metadata),
            }
            rows.append(
                gpdb_20160503_models.UpsertCollectionDataRequestRows(
                    vector=embedding,
                    metadata=metadata,
                )
            )
        request = gpdb_20160503_models.UpsertCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            rows=rows,
        )
        self._client.upsert_collection_data(request)

    def text_exists(self, id: str) -> bool:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            metrics=self.config.metrics,
            include_values=True,
            vector=None,
            content=None,
            top_k=1,
            filter=f"ref_doc_id='{id}'"
        )
        response = self._client.query_collection_data(request)
        return len(response.body.matches.match) > 0

    def delete_by_ids(self, ids: list[str]) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        ids_str = ",".join(f"'{id}'" for id in ids)
        ids_str = f"({ids_str})"
        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"ref_doc_id IN {ids_str}",
        )
        self._client.delete_collection_data(request)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
        )
        self._client.delete_collection_data(request)

    def search_by_vector(
        self, query_vector: list[float], **kwargs: Any
    ) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        score_threshold = (
            kwargs.get("score_threshold", 0.0)
            if kwargs.get("score_threshold", 0.0)
            else 0.0
        )
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=query_vector,
            content=None,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    metadata=json.loads(match.metadata.get("metadata_")),
                )
                documents.append(doc)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        score_threshold = (
            kwargs.get("score_threshold", 0.0)
            if kwargs.get("score_threshold", 0.0)
            else 0.0
        )
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=None,
            content=query,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    metadata=json.loads(match.metadata.get("metadata_")),
                )
                documents.append(doc)
        return documents

    def delete(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        request = gpdb_20160503_models.DeleteCollectionRequest(
            collection=self._collection_name,
            dbinstance_id=self.config.instance_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            region_id=self.config.region_id,
        )
        self._client.delete_collection(request)


class AnalyticdbVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings):
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"][
                "class_prefix"
            ]
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
            )
        config = current_app.config
        return AnalyticdbVector(
            collection_name,
            AnalyticdbConfig(
                access_key_id=config.get("ANALYTICDB_KEY_ID"),
                access_key_secret=config.get("ANALYTICDB_KEY_SECRET"),
                region_id=config.get("ANALYTICDB_REGION_ID"),
                instance_id=config.get("ANALYTICDB_INSTANCE_ID"),
                account=config.get("ANALYTICDB_ACCOUNT"),
                account_password=config.get("ANALYTICDB_PASSWORD"),
                namespace=config.get("ANALYTICDB_NAMESPACE"),
                namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"),
            ),
        )
0
api/core/rag/datasource/vdb/myscale/__init__.py
Normal file
170
api/core/rag/datasource/vdb/myscale/myscale_vector.py
Normal file
@ -0,0 +1,170 @@
import json
import logging
import uuid
from enum import Enum
from typing import Any

from clickhouse_connect import get_client
from flask import current_app
from pydantic import BaseModel

from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from models.dataset import Dataset


class MyScaleConfig(BaseModel):
    host: str
    port: int
    user: str
    password: str
    database: str
    fts_params: str


class SortOrder(Enum):
    ASC = "ASC"
    DESC = "DESC"


class MyScaleVector(BaseVector):

    def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
        super().__init__(collection_name)
        self._config = config
        self._metric = metric
        self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC
        self._client = get_client(
            host=config.host,
            port=config.port,
            username=config.user,
            password=config.password,
        )
        self._client.command("SET allow_experimental_object_type=1")

    def get_type(self) -> str:
        return VectorType.MYSCALE

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        dimension = len(embeddings[0])
        self._create_collection(dimension)
        return self.add_texts(documents=texts, embeddings=embeddings, **kwargs)

    def _create_collection(self, dimension: int):
        logging.info(f"create MyScale collection {self._collection_name} with dimension {dimension}")
        self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
        fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
        sql = f"""
            CREATE TABLE IF NOT EXISTS {self._config.database}.{self._collection_name}(
                id String,
                text String,
                vector Array(Float32),
                metadata JSON,
                CONSTRAINT cons_vec_len CHECK length(vector) = {dimension},
                VECTOR INDEX vidx vector TYPE DEFAULT('metric_type = {self._metric}'),
                INDEX text_idx text TYPE fts{fts_params}
            ) ENGINE = MergeTree ORDER BY id
        """
        self._client.command(sql)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        ids = []
        columns = ["id", "text", "vector", "metadata"]
        values = []
        for i, doc in enumerate(documents):
            doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
            row = (
                doc_id,
                self.escape_str(doc.page_content),
                embeddings[i],
                json.dumps(doc.metadata) if doc.metadata else {}
            )
            values.append(str(row))
            ids.append(doc_id)
        sql = f"""
            INSERT INTO {self._config.database}.{self._collection_name}
            ({",".join(columns)}) VALUES {",".join(values)}
        """
        self._client.command(sql)
        return ids

    @staticmethod
    def escape_str(value: Any) -> str:
        return "".join(f"\\{c}" if c in ("\\", "'") else c for c in str(value))

    def text_exists(self, id: str) -> bool:
        results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
        return results.row_count > 0

    def delete_by_ids(self, ids: list[str]) -> None:
        self._client.command(
            f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}")

    def get_ids_by_metadata_field(self, key: str, value: str):
        rows = self._client.query(
            f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
        ).result_rows
        return [row[0] for row in rows]

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._client.command(
            f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
        )

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        return self._search(f"TextSearch(text, '{query}')", SortOrder.DESC, **kwargs)

    def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 5)
        score_threshold = kwargs.get("score_threshold", 0.0)
        where_str = f"WHERE dist < {1 - score_threshold}" if \
            self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
        sql = f"""
            SELECT text, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
            {where_str} ORDER BY dist {order.value} LIMIT {top_k}
        """
        try:
            return [
                Document(
                    page_content=r["text"],
                    metadata=r["metadata"],
                )
                for r in self._client.query(sql).named_results()
            ]
        except Exception as e:
            logging.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []

    def delete(self) -> None:
        self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")


class MyScaleVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))

        config = current_app.config
        return MyScaleVector(
            collection_name=collection_name,
            config=MyScaleConfig(
                host=config.get("MYSCALE_HOST", "localhost"),
                port=int(config.get("MYSCALE_PORT", 8123)),
                user=config.get("MYSCALE_USER", "default"),
                password=config.get("MYSCALE_PASSWORD", ""),
                database=config.get("MYSCALE_DATABASE", "default"),
                fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
            ),
        )
@ -57,6 +57,9 @@ class Vector:
            case VectorType.MILVUS:
                from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
                return MilvusVectorFactory
            case VectorType.MYSCALE:
                from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
                return MyScaleVectorFactory
            case VectorType.PGVECTOR:
                from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
                return PGVectorFactory
@ -84,6 +87,9 @@ class Vector:
            case VectorType.OPENSEARCH:
                from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
                return OpenSearchVectorFactory
            case VectorType.ANALYTICDB:
                from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
                return AnalyticdbVectorFactory
            case _:
                raise ValueError(f"Vector store {vector_type} is not supported.")
@@ -2,8 +2,10 @@ from enum import Enum


class VectorType(str, Enum):
+   ANALYTICDB = 'analyticdb'
    CHROMA = 'chroma'
    MILVUS = 'milvus'
+   MYSCALE = 'myscale'
    PGVECTOR = 'pgvector'
    PGVECTO_RS = 'pgvecto-rs'
    QDRANT = 'qdrant'
@@ -46,7 +46,6 @@ class FirecrawlApp:
            raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')

-
    def crawl_url(self, url, params=None) -> str:
        start_time = time.time()
        headers = self._prepare_headers()
        json_data = {'url': url}
        if params:
@@ -18,8 +18,8 @@ class MarkdownExtractor(BaseExtractor):
    def __init__(
        self,
        file_path: str,
-       remove_hyperlinks: bool = True,
-       remove_images: bool = True,
+       remove_hyperlinks: bool = False,
+       remove_images: bool = False,
        encoding: Optional[str] = None,
        autodetect_encoding: bool = True,
    ):
@@ -8,7 +8,7 @@ We have defined a series of helper methods in the `Tool` class to help developer

### Message Return

-Dify supports various message types such as `text`, `link`, `image`, and `file BLOB`. You can return different types of messages to the LLM and users through the following interfaces.
+Dify supports various message types such as `text`, `link`, `json`, `image`, and `file BLOB`. You can return different types of messages to the LLM and users through the following interfaces.

Please note, some parameters in the following interfaces will be introduced in later sections.
@@ -67,6 +67,18 @@ If you need to return the raw data of a file, such as images, audio, video, PPT,
    """
```

+#### JSON
+If you need to return formatted JSON, you can use the following interface. This is commonly used for data transmission between nodes in a workflow; in agent mode, most LLMs can also read and understand JSON.
+
+- `object` A Python dictionary object that will be automatically serialized into JSON
+
+```python
+def create_json_message(self, object: dict) -> ToolInvokeMessage:
+    """
+    create a json message
+    """
+```
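By way of illustration, here is a minimal sketch of a tool using this interface. The `WeatherTool` class and its payload fields are hypothetical, not part of this diff; only `create_json_message` and the `_invoke` signature come from the docs above:

```python
from typing import Any

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class WeatherTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
        # The dict is serialized to JSON automatically, so a downstream
        # workflow node can consume it as structured data.
        return self.create_json_message({
            'city': tool_parameters.get('city', 'Berlin'),
            'temperature_c': 21.5,
        })
```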
### Shortcut Tools

In large model applications, we have two common needs:
@@ -145,19 +145,25 @@ parameters: # Parameter list

- The `identity` field is mandatory; it contains the basic information of the tool, including name, author, label, description, etc.
- `parameters` Parameter list
-    - `name` Parameter name, unique, no duplication with other parameters
-    - `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` four types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using `secret-input` type
-    - `required` Required or not
+    - `name` (Mandatory) Parameter name, must be unique and must not duplicate other parameter names.
+    - `type` (Mandatory) Parameter type, currently supports five types: `string`, `number`, `boolean`, `select`, `secret-input`, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using the `secret-input` type.
+    - `label` (Mandatory) Parameter label, for frontend display
+    - `form` (Mandatory) Form type, currently supports two types: `llm` and `form`.
+        - In an agent app, `llm` indicates that the parameter is inferred by the LLM itself, while `form` indicates that the parameter can be pre-set for the tool.
+        - In a workflow app, both `llm` and `form` need to be filled out on the frontend, but the parameters of `llm` will be used as input variables for the tool node.
+    - `required` Indicates whether the parameter is required
        - In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
        - In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
    - `options` Parameter options
        - In `llm` mode, Dify will pass all options to the LLM, and the LLM can infer based on these options
        - In `form` mode, when `type` is `select`, the frontend will display these options
    - `default` Default value
-    - `label` Parameter label, for frontend display
    - `min` Minimum value, can be set when the parameter type is `number`.
    - `max` Maximum value, can be set when the parameter type is `number`.
    - `placeholder` The prompt text for input boxes. It can be set when the form type is `form` and the parameter type is `string`, `number`, or `secret-input`. It supports multiple languages.
    - `human_description` Introduction for frontend display, supports multiple languages
    - `llm_description` Introduction passed to the LLM; to help the LLM better understand this parameter, we suggest writing as much detail about it as possible here.
-    - `form` Form type, currently supports `llm`, `form` two types, corresponding to Agent self-inference and frontend filling
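To make the YAML-to-entity mapping concrete, here is a hedged sketch of the `ToolParameter` object such an entry deserializes into. The field values are illustrative, and the `I18nObject` import path, its `en_US`/`zh_Hans` keyword arguments, and the `required` field are assumptions inferred from this diff rather than verified API:

```python
from core.tools.entities.common_entities import I18nObject  # assumed location
from core.tools.entities.tool_entities import (
    ToolParameter,
    ToolParameterForm,
    ToolParameterType,
)

# Roughly what a `- name: query` YAML parameter entry becomes in memory.
query_param = ToolParameter(
    name='query',
    label=I18nObject(en_US='Query string', zh_Hans='查询语句'),
    human_description=I18nObject(en_US='used for searching', zh_Hans='用于搜索网页内容'),
    type=ToolParameterType.STRING,
    form=ToolParameterForm.LLM,
    llm_description='key words for searching',
    required=True,  # assumed field, mirroring the `required` key in the YAML schema
)
```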

## 4. Add Tool Logic
@@ -196,7 +202,7 @@ The overall logic of the tool is in the `_invoke` method, this method accepts tw

### Return Data

-When the tool returns, you can choose to return one message or multiple messages, here we return one message, using `create_text_message` and `create_link_message` can create a text message or a link message.
+When the tool returns, you can choose to return one message or multiple messages; here we return one message. `create_text_message` and `create_link_message` create a text message or a link message, respectively. If you want to return multiple messages, you can use `[self.create_text_message('msg1'), self.create_text_message('msg2')]` to create a list of messages.
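For completeness, a minimal sketch of a multi-message return. The `ExampleSearchTool` class and its placeholder values are hypothetical, not from this diff; `create_text_message` and `create_link_message` are the interfaces documented above:

```python
from typing import Any, Union

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class ExampleSearchTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        # Returning a list sends several messages in one invocation:
        # a human-readable summary plus the source link.
        summary = f"results for {tool_parameters['query']}"  # placeholder
        return [
            self.create_text_message(summary),
            self.create_link_message('https://example.com/source'),  # placeholder URL
        ]
```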

## 5. Add Provider Code

@@ -205,8 +211,6 @@ Finally, we need to create a provider class under the provider module to impleme

Create `google.py` under the `google` module, the content is as follows.

```python
-from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
-from core.tools.tool.tool import Tool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
@@ -8,7 +8,7 @@

### Message Return

-Dify supports multiple message types such as `text`, `link`, `image`, and `file BLOB`. You can return different types of messages to the LLM and users through the following interfaces.
+Dify supports multiple message types such as `text`, `link`, `image`, `file BLOB`, and `JSON`. You can return different types of messages to the LLM and users through the following interfaces.

Note that some of the parameters in the interfaces below will be introduced in later sections.
@@ -67,6 +67,18 @@ Dify supports multiple message types such as `text`, `link`, `image`, and `file BLOB`
    """
```

+#### JSON
+If you need to return formatted JSON, you can use the following interface. This is commonly used for data transmission between nodes in a workflow; in agent mode, most large models can also read and understand JSON.
+
+- `object` A Python dictionary object that will be automatically serialized into JSON
+
+```python
+def create_json_message(self, object: dict) -> ToolInvokeMessage:
+    """
+    create a json message
+    """
+```
### Shortcut Tools

In large model applications, we have two common needs:

@@ -97,8 +109,8 @@ Dify supports multiple message types such as `text`, `link`, `image`, and `file BLOB`
```python
def get_url(self, url: str, user_agent: str = None) -> str:
    """
-    get url
+    get url from the crawled result
    """
```

### Variable Pool
@@ -140,8 +140,12 @@ parameters: # Parameter list

- The `identity` field is mandatory; it contains the basic information of the tool, including name, author, label, description, etc.
- `parameters` Parameter list
-    - `name` Parameter name, must be unique and must not duplicate other parameter names
-    - `type` Parameter type, currently supports five types: `string`, `number`, `boolean`, `select`, `secret-input`, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using the `secret-input` type
+    - `name` (Mandatory) Parameter name, must be unique and must not duplicate other parameter names
+    - `type` (Mandatory) Parameter type, currently supports five types: `string`, `number`, `boolean`, `select`, `secret-input`, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using the `secret-input` type
+    - `label` (Mandatory) Parameter label, for frontend display
+    - `form` (Mandatory) Form type, currently supports two types: `llm` and `form`
+        - In an agent app, `llm` means the parameter is inferred by the LLM itself, while `form` means the parameter can be pre-set before the tool is used
+        - In a workflow app, both `llm` and `form` need to be filled in on the frontend, but the parameters of `llm` will be used as input variables for the tool node
    - `required` Whether the parameter is required
        - In `llm` mode, if the parameter is required, the Agent must infer this parameter
        - In `form` mode, if the parameter is required, the user must fill in this parameter on the frontend before the conversation starts
@@ -149,10 +153,12 @@ parameters: # Parameter list
        - In `llm` mode, Dify will pass all options to the LLM, and the LLM can infer based on these options
        - In `form` mode, when `type` is `select`, the frontend will display these options
    - `default` Default value
-    - `label` Parameter label, for frontend display
    - `min` Minimum value, can be set when the parameter type is `number`
    - `max` Maximum value, can be set when the parameter type is `number`
    - `human_description` Introduction for frontend display, supports multiple languages
+    - `placeholder` Prompt text for the input box; can be set when the form type is `form` and the parameter type is `string`, `number`, or `secret-input`; supports multiple languages
    - `llm_description` Introduction passed to the LLM; to help the LLM better understand this parameter, we suggest writing as much detail about it as possible here
-    - `form` Form type, currently supports two types: `llm` and `form`, corresponding to Agent self-inference and frontend filling
## 4. Prepare the Tool Code
Once the tool configuration is complete, we can start writing the tool code, which mainly implements the tool's logic.

@@ -176,7 +182,6 @@ class GoogleSearchTool(BuiltinTool):
        query = tool_parameters['query']
        result_type = tool_parameters['result_type']
        api_key = self.runtime.credentials['serpapi_api_key']
-       # TODO: search with serpapi
        result = SerpAPI(api_key).run(query, result_type=result_type)

        if result_type == 'text':
@@ -188,7 +193,7 @@ class GoogleSearchTool(BuiltinTool):
The overall logic of the tool lives in the `_invoke` method, which accepts two parameters, `user_id` and `tool_parameters`, representing the user ID and the tool parameters.

### Return Data
-When the tool returns, you can choose to return one message or multiple messages; here we return one message. `create_text_message` and `create_link_message` create a text message or a link message, respectively.
+When the tool returns, you can choose to return one message or multiple messages; here we return one message. `create_text_message` and `create_link_message` create a text message or a link message, respectively. To return multiple messages, build a list, for example `[self.create_text_message('msg1'), self.create_text_message('msg2')]`.
## 5. Prepare the Provider Code
Finally, we need to create a provider class under the provider module to implement the provider's credential verification logic. If credential verification fails, a `ToolProviderCredentialValidationError` exception will be raised.

@@ -196,8 +201,6 @@ class GoogleSearchTool(BuiltinTool):
Create `google.py` under the `google` module, with the following content.

```python
-from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
-from core.tools.tool.tool import Tool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.errors import ToolProviderCredentialValidationError
@@ -142,7 +142,8 @@ class ToolParameter(BaseModel):

    name: str = Field(..., description="The name of the parameter")
    label: I18nObject = Field(..., description="The label presented to the user")
-   human_description: I18nObject = Field(..., description="The description presented to the user")
+   human_description: Optional[I18nObject] = Field(None, description="The description presented to the user")
+   placeholder: Optional[I18nObject] = Field(None, description="The placeholder presented to the user")
    type: ToolParameterType = Field(..., description="The type of the parameter")
    form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
    llm_description: Optional[str] = None
0    api/core/tools/provider/builtin/cogview/__init__.py    Normal file
BIN  api/core/tools/provider/builtin/cogview/_assets/icon.png    Normal file (22 KiB)
27   api/core/tools/provider/builtin/cogview/cogview.py    Normal file
@@ -0,0 +1,27 @@
""" Provide the input parameters type for the cogview provider class """
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.cogview.tools.cogview3 import CogView3Tool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class COGVIEWProvider(BuiltinToolProviderController):
|
||||
""" cogview provider """
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
CogView3Tool().fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。",
|
||||
"size": "square",
|
||||
"n": 1
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e)) from e
|
||||
|
||||
61   api/core/tools/provider/builtin/cogview/cogview.yaml    Normal file
@@ -0,0 +1,61 @@
identity:
  author: Waffle
  name: cogview
  label:
    en_US: CogView
    zh_Hans: CogView 绘画
    pt_BR: CogView
  description:
    en_US: CogView art
    zh_Hans: CogView 绘画
    pt_BR: CogView art
  icon: icon.png
  tags:
    - image
    - productivity
credentials_for_provider:
  zhipuai_api_key:
    type: secret-input
    required: true
    label:
      en_US: ZhipuAI API key
      zh_Hans: ZhipuAI API key
      pt_BR: ZhipuAI API key
    help:
      en_US: Please input your ZhipuAI API key
      zh_Hans: 请输入你的 ZhipuAI API key
      pt_BR: Please input your ZhipuAI API key
    placeholder:
      en_US: Please input your ZhipuAI API key
      zh_Hans: 请输入你的 ZhipuAI API key
      pt_BR: Please input your ZhipuAI API key
  zhipuai_organizaion_id:
    type: text-input
    required: false
    label:
      en_US: ZhipuAI organization ID
      zh_Hans: ZhipuAI organization ID
      pt_BR: ZhipuAI organization ID
    help:
      en_US: Please input your ZhipuAI organization ID
      zh_Hans: 请输入你的 ZhipuAI organization ID
      pt_BR: Please input your ZhipuAI organization ID
    placeholder:
      en_US: Please input your ZhipuAI organization ID
      zh_Hans: 请输入你的 ZhipuAI organization ID
      pt_BR: Please input your ZhipuAI organization ID
  zhipuai_base_url:
    type: text-input
    required: false
    label:
      en_US: ZhipuAI base URL
      zh_Hans: ZhipuAI base URL
      pt_BR: ZhipuAI base URL
    help:
      en_US: Please input your ZhipuAI base URL
      zh_Hans: 请输入你的 ZhipuAI base URL
      pt_BR: Please input your ZhipuAI base URL
    placeholder:
      en_US: Please input your ZhipuAI base URL
      zh_Hans: 请输入你的 ZhipuAI base URL
      pt_BR: Please input your ZhipuAI base URL
69   api/core/tools/provider/builtin/cogview/tools/cogview3.py    Normal file
@@ -0,0 +1,69 @@
import random
from typing import Any, Union

from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class CogView3Tool(BuiltinTool):
    """ CogView3 Tool """

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any]
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke CogView3 tool
        """
        client = ZhipuAI(
            base_url=self.runtime.credentials['zhipuai_base_url'],
            api_key=self.runtime.credentials['zhipuai_api_key'],
        )
        size_mapping = {
            'square': '1024x1024',
            'vertical': '1024x1792',
            'horizontal': '1792x1024',
        }
        # prompt
        prompt = tool_parameters.get('prompt', '')
        if not prompt:
            return self.create_text_message('Please input prompt')
        # get size
        size = size_mapping[tool_parameters.get('size', 'square')]
        # get n
        n = tool_parameters.get('n', 1)
        # get quality
        quality = tool_parameters.get('quality', 'standard')
        if quality not in ['standard', 'hd']:
            return self.create_text_message('Invalid quality')
        # get style
        style = tool_parameters.get('style', 'vivid')
        if style not in ['natural', 'vivid']:
            return self.create_text_message('Invalid style')
        # set extra body; reusing the same seed makes generations reproducible
        seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
        extra_body = {'seed': seed_id}
        response = client.images.generations(
            prompt=prompt,
            model="cogview-3",
            size=size,
            n=n,
            extra_body=extra_body,
            style=style,
            quality=quality,
            response_format='b64_json'
        )
        result = []
        for image in response.data:
            result.append(self.create_image_message(image=image.url))
        result.append(self.create_text_message(
            f'\nGenerate image source to Seed ID: {seed_id}'))
        return result

    @staticmethod
    def _generate_random_id(length=8):
        characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
        random_id = ''.join(random.choices(characters, k=length))
        return random_id
123  api/core/tools/provider/builtin/cogview/tools/cogview3.yaml    Normal file
@@ -0,0 +1,123 @@
identity:
  name: cogview3
  author: Waffle
  label:
    en_US: CogView 3
    zh_Hans: CogView 3 绘画
    pt_BR: CogView 3
  description:
    en_US: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt
    zh_Hans: CogView 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像
    pt_BR: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt
description:
  human:
    en_US: CogView 3 is a text to image tool
    zh_Hans: CogView 3 是一个文本到图像的工具
    pt_BR: CogView 3 is a text to image tool
  llm: CogView 3 is a tool used to generate images from text
parameters:
  - name: prompt
    type: string
    required: true
    label:
      en_US: Prompt
      zh_Hans: 提示词
      pt_BR: Prompt
    human_description:
      en_US: Image prompt, you can check the official documentation of CogView 3
      zh_Hans: 图像提示词,您可以查看CogView 3 的官方文档
      pt_BR: Image prompt, you can check the official documentation of CogView 3
    llm_description: Image prompt of CogView 3, you should describe the image you want to generate as a list of words, as detailed as possible
    form: llm
  - name: size
    type: select
    required: true
    human_description:
      en_US: selecting the image size
      zh_Hans: 选择图像大小
      pt_BR: selecting the image size
    label:
      en_US: Image size
      zh_Hans: 图像大小
      pt_BR: Image size
    form: form
    options:
      - value: square
        label:
          en_US: Square(1024x1024)
          zh_Hans: 方(1024x1024)
          pt_BR: Square(1024x1024)
      - value: vertical
        label:
          en_US: Vertical(1024x1792)
          zh_Hans: 竖屏(1024x1792)
          pt_BR: Vertical(1024x1792)
      - value: horizontal
        label:
          en_US: Horizontal(1792x1024)
          zh_Hans: 横屏(1792x1024)
          pt_BR: Horizontal(1792x1024)
    default: square
  - name: n
    type: number
    required: true
    human_description:
      en_US: selecting the number of images
      zh_Hans: 选择图像数量
      pt_BR: selecting the number of images
    label:
      en_US: Number of images
      zh_Hans: 图像数量
      pt_BR: Number of images
    form: form
    min: 1
    max: 1
    default: 1
  - name: quality
    type: select
    required: true
    human_description:
      en_US: selecting the image quality
      zh_Hans: 选择图像质量
      pt_BR: selecting the image quality
    label:
      en_US: Image quality
      zh_Hans: 图像质量
      pt_BR: Image quality
    form: form
    options:
      - value: standard
        label:
          en_US: Standard
          zh_Hans: 标准
          pt_BR: Standard
      - value: hd
        label:
          en_US: HD
          zh_Hans: 高清
          pt_BR: HD
    default: standard
  - name: style
    type: select
    required: true
    human_description:
      en_US: selecting the image style
      zh_Hans: 选择图像风格
      pt_BR: selecting the image style
    label:
      en_US: Image style
      zh_Hans: 图像风格
      pt_BR: Image style
    form: form
    options:
      - value: vivid
        label:
          en_US: Vivid
          zh_Hans: 生动
          pt_BR: Vivid
      - value: natural
        label:
          en_US: Natural
          zh_Hans: 自然
          pt_BR: Natural
    default: vivid
@@ -1,3 +1,4 @@
+import logging
import time
from collections.abc import Mapping
from typing import Any
@@ -5,6 +6,7 @@ from typing import Any
import requests
from requests.exceptions import HTTPError

+logger = logging.getLogger(__name__)

class FirecrawlApp:
    def __init__(self, api_key: str | None = None, base_url: str | None = None):
@@ -48,6 +50,7 @@ class FirecrawlApp:
        headers = self._prepare_headers()
        data = {'url': url, **kwargs}
        response = self._request('POST', endpoint, data, headers)
+       logger.debug(f"Sent request to {endpoint=} body={data}")
        if response is None:
            raise HTTPError("Failed to scrape URL after multiple retries")
        return response
@@ -57,6 +60,7 @@ class FirecrawlApp:
        headers = self._prepare_headers()
        data = {'query': query, **kwargs}
        response = self._request('POST', endpoint, data, headers)
+       logger.debug(f"Sent request to {endpoint=} body={data}")
        if response is None:
            raise HTTPError("Failed to perform search after multiple retries")
        return response
@@ -66,8 +70,9 @@ class FirecrawlApp:
    ):
        endpoint = f'{self.base_url}/v0/crawl'
        headers = self._prepare_headers(idempotency_key)
-       data = {'url': url, **kwargs}
+       data = {'url': url, **kwargs['params']}
        response = self._request('POST', endpoint, data, headers)
+       logger.debug(f"Sent request to {endpoint=} body={data}")
        if response is None:
            raise HTTPError("Failed to initiate crawl after multiple retries")
        job_id: str = response['jobId']
@@ -238,7 +238,7 @@ class ApiTool(Tool):
                return int(value)
            elif property['type'] == 'number':
                # check if it is a float
-               if '.' in value:
+               if '.' in str(value):
                    return float(value)
                else:
                    return int(value)
@@ -5,22 +5,34 @@ from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData


+class Condition(BaseModel):
+    """
+    Condition entity
+    """
+    variable_selector: list[str]
+    comparison_operator: Literal[
+        # for string or array
+        "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
+        # for number
+        "=", "≠", ">", "<", "≥", "≤", "null", "not null"
+    ]
+    value: Optional[str] = None
+
+
class IfElseNodeData(BaseNodeData):
    """
    Answer Node Data.
    """
-    class Condition(BaseModel):
-        """
-        Condition entity
-        """
-        variable_selector: list[str]
-        comparison_operator: Literal[
-            # for string or array
-            "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
-            # for number
-            "=", "≠", ">", "<", "≥", "≤", "null", "not null"
-        ]
-        value: Optional[str] = None
-
-    logical_operator: Literal["and", "or"] = "and"
-    conditions: list[Condition]
+    class Case(BaseModel):
+        """
+        Case entity representing a single logical condition group
+        """
+        case_id: str
+        logical_operator: Literal["and", "or"]
+        conditions: list[Condition]
+
+    logical_operator: Optional[Literal["and", "or"]] = "and"
+    conditions: Optional[list[Condition]] = None
+
+    cases: Optional[list[Case]] = None
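For orientation, a minimal sketch of how the new `cases` structure might be populated. The selector values, case IDs, and the `title` field are illustrative assumptions, not taken from this diff; the entity names and fields come from the code above:

```python
from core.workflow.nodes.if_else.entities import Condition, IfElseNodeData

# Two branches: the first case whose group evaluates true wins
# (the _run method below short-circuits on the first passing case).
node_data = IfElseNodeData(
    title='branch on status',  # assumed BaseNodeData field
    cases=[
        IfElseNodeData.Case(
            case_id='case_1',
            logical_operator='and',
            conditions=[
                Condition(variable_selector=['start', 'status'],
                          comparison_operator='is', value='ok'),
            ],
        ),
        IfElseNodeData.Case(
            case_id='case_2',
            logical_operator='or',
            conditions=[
                Condition(variable_selector=['start', 'status'],
                          comparison_operator='empty'),
            ],
        ),
    ],
)
```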
@@ -4,7 +4,8 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.if_else.entities import IfElseNodeData
+from core.workflow.nodes.if_else.entities import Condition, IfElseNodeData
+from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
@@ -29,68 +30,48 @@ class IfElseNode(BaseNode):
            "condition_results": []
        }

+        input_conditions = []
+        final_result = False
+        selected_case_id = None
        try:
-            logical_operator = node_data.logical_operator
-            input_conditions = []
-            for condition in node_data.conditions:
-                actual_value = variable_pool.get_variable_value(
-                    variable_selector=condition.variable_selector
-                )
+            # Check if the new cases structure is used
+            if node_data.cases:
+                for case in node_data.cases:
+                    input_conditions, group_result = self.process_conditions(variable_pool, case.conditions)
+                    # Apply the logical operator for the current case
+                    final_result = all(group_result) if case.logical_operator == "and" else any(group_result)
+
+                    process_datas["condition_results"].append(
+                        {
+                            "group": case.model_dump(),
+                            "results": group_result,
+                            "final_result": final_result,
+                        }
+                    )
+
+                    # Break if a case passes (logical short-circuit)
+                    if final_result:
+                        selected_case_id = case.case_id  # Capture the ID of the passing case
+                        break
+            else:
+                # Fallback to old structure if cases are not defined
+                input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions)
+                final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)
+                selected_case_id = "true" if final_result else "false"
+                process_datas["condition_results"].append(
+                    {
+                        "group": "default",
+                        "results": group_result,
+                        "final_result": final_result
+                    }
+                )
-
-                expected_value = condition.value
-
-                input_conditions.append({
-                    "actual_value": actual_value,
-                    "expected_value": expected_value,
-                    "comparison_operator": condition.comparison_operator
-                })

            node_inputs["conditions"] = input_conditions

-            for input_condition in input_conditions:
-                actual_value = input_condition["actual_value"]
-                expected_value = input_condition["expected_value"]
-                comparison_operator = input_condition["comparison_operator"]
-
-                if comparison_operator == "contains":
-                    compare_result = self._assert_contains(actual_value, expected_value)
-                elif comparison_operator == "not contains":
-                    compare_result = self._assert_not_contains(actual_value, expected_value)
-                elif comparison_operator == "start with":
-                    compare_result = self._assert_start_with(actual_value, expected_value)
-                elif comparison_operator == "end with":
-                    compare_result = self._assert_end_with(actual_value, expected_value)
-                elif comparison_operator == "is":
-                    compare_result = self._assert_is(actual_value, expected_value)
-                elif comparison_operator == "is not":
-                    compare_result = self._assert_is_not(actual_value, expected_value)
-                elif comparison_operator == "empty":
-                    compare_result = self._assert_empty(actual_value)
-                elif comparison_operator == "not empty":
-                    compare_result = self._assert_not_empty(actual_value)
-                elif comparison_operator == "=":
-                    compare_result = self._assert_equal(actual_value, expected_value)
-                elif comparison_operator == "≠":
-                    compare_result = self._assert_not_equal(actual_value, expected_value)
-                elif comparison_operator == ">":
-                    compare_result = self._assert_greater_than(actual_value, expected_value)
-                elif comparison_operator == "<":
-                    compare_result = self._assert_less_than(actual_value, expected_value)
-                elif comparison_operator == "≥":
-                    compare_result = self._assert_greater_than_or_equal(actual_value, expected_value)
-                elif comparison_operator == "≤":
-                    compare_result = self._assert_less_than_or_equal(actual_value, expected_value)
-                elif comparison_operator == "null":
-                    compare_result = self._assert_null(actual_value)
-                elif comparison_operator == "not null":
-                    compare_result = self._assert_not_null(actual_value)
-                else:
-                    continue
-
-                process_datas["condition_results"].append({
-                    **input_condition,
-                    "result": compare_result
-                })
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
@@ -99,21 +80,102 @@ class IfElseNode(BaseNode):
                error=str(e)
            )

-        if logical_operator == "and":
-            compare_result = False not in [condition["result"] for condition in process_datas["condition_results"]]
-        else:
-            compare_result = True in [condition["result"] for condition in process_datas["condition_results"]]
+        outputs = {"result": final_result, "selected_case_id": selected_case_id}

-        return NodeRunResult(
+        data = NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            process_data=process_datas,
-            edge_source_handle="false" if not compare_result else "true",
-            outputs={
-                "result": compare_result
-            }
+            edge_source_handle=selected_case_id if selected_case_id else "false",  # Use case ID or 'default'
+            outputs=outputs
        )

+        return data
+
+    def evaluate_condition(
+        self, actual_value: Optional[str | list], expected_value: str, comparison_operator: str
+    ) -> bool:
+        """
+        Evaluate condition
+        :param actual_value: actual value
+        :param expected_value: expected value
+        :param comparison_operator: comparison operator
+
+        :return: bool
+        """
+        if comparison_operator == "contains":
+            return self._assert_contains(actual_value, expected_value)
+        elif comparison_operator == "not contains":
+            return self._assert_not_contains(actual_value, expected_value)
+        elif comparison_operator == "start with":
+            return self._assert_start_with(actual_value, expected_value)
+        elif comparison_operator == "end with":
+            return self._assert_end_with(actual_value, expected_value)
+        elif comparison_operator == "is":
+            return self._assert_is(actual_value, expected_value)
+        elif comparison_operator == "is not":
+            return self._assert_is_not(actual_value, expected_value)
+        elif comparison_operator == "empty":
+            return self._assert_empty(actual_value)
+        elif comparison_operator == "not empty":
+            return self._assert_not_empty(actual_value)
+        elif comparison_operator == "=":
+            return self._assert_equal(actual_value, expected_value)
+        elif comparison_operator == "≠":
+            return self._assert_not_equal(actual_value, expected_value)
+        elif comparison_operator == ">":
+            return self._assert_greater_than(actual_value, expected_value)
+        elif comparison_operator == "<":
+            return self._assert_less_than(actual_value, expected_value)
+        elif comparison_operator == "≥":
+            return self._assert_greater_than_or_equal(actual_value, expected_value)
+        elif comparison_operator == "≤":
+            return self._assert_less_than_or_equal(actual_value, expected_value)
+        elif comparison_operator == "null":
+            return self._assert_null(actual_value)
+        elif comparison_operator == "not null":
+            return self._assert_not_null(actual_value)
+        else:
+            raise ValueError(f"Invalid comparison operator: {comparison_operator}")
+
+    def process_conditions(self, variable_pool: VariablePool, conditions: list[Condition]):
+        input_conditions = []
+        group_result = []
+
+        for condition in conditions:
+            actual_value = variable_pool.get_variable_value(
+                variable_selector=condition.variable_selector
+            )
+
+            if condition.value is not None:
+                variable_template_parser = VariableTemplateParser(template=condition.value)
-                expected_value = variable_template_parser.extract_variable_selectors()
+                variable_selectors = variable_template_parser.extract_variable_selectors()
+                if variable_selectors:
+                    for variable_selector in variable_selectors:
+                        value = variable_pool.get_variable_value(
+                            variable_selector=variable_selector.value_selector
+                        )
+                        expected_value = variable_template_parser.format({variable_selector.variable: value})
+                else:
+                    expected_value = condition.value
+            else:
+                expected_value = None
+
+            comparison_operator = condition.comparison_operator
+            input_conditions.append(
+                {
+                    "actual_value": actual_value,
+                    "expected_value": expected_value,
+                    "comparison_operator": comparison_operator
+                }
+            )
+
+            result = self.evaluate_condition(actual_value, expected_value, comparison_operator)
+            group_result.append(result)
+
+        return input_conditions, group_result

    def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
        """
        Assert contains
@@ -7,4 +7,5 @@ def handle(sender, **kwargs):
    document_id = sender
    dataset_id = kwargs.get('dataset_id')
    doc_form = kwargs.get('doc_form')
-   clean_document_task.delay(document_id, dataset_id, doc_form)
+   file_id = kwargs.get('file_id')
+   clean_document_task.delay(document_id, dataset_id, doc_form, file_id)
Some files were not shown because too many files have changed in this diff.