mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
Merge branch 'main' into feat/support-extractor-tools
This commit is contained in:
@ -202,6 +202,20 @@ TIDB_VECTOR_USER=xxx.root
|
||||
TIDB_VECTOR_PASSWORD=xxxxxx
|
||||
TIDB_VECTOR_DATABASE=dify
|
||||
|
||||
# Tidb on qdrant configuration
|
||||
TIDB_ON_QDRANT_URL=http://127.0.0.1
|
||||
TIDB_ON_QDRANT_API_KEY=dify
|
||||
TIDB_ON_QDRANT_CLIENT_TIMEOUT=20
|
||||
TIDB_ON_QDRANT_GRPC_ENABLED=false
|
||||
TIDB_ON_QDRANT_GRPC_PORT=6334
|
||||
TIDB_PUBLIC_KEY=dify
|
||||
TIDB_PRIVATE_KEY=dify
|
||||
TIDB_API_URL=http://127.0.0.1
|
||||
TIDB_IAM_API_URL=http://127.0.0.1
|
||||
TIDB_REGION=regions/aws-us-east-1
|
||||
TIDB_PROJECT_ID=dify
|
||||
TIDB_SPEND_LIMIT=100
|
||||
|
||||
# Chroma configuration
|
||||
CHROMA_HOST=127.0.0.1
|
||||
CHROMA_PORT=8000
|
||||
@ -249,6 +263,14 @@ VIKINGDB_SCHEMA=http
|
||||
VIKINGDB_CONNECTION_TIMEOUT=30
|
||||
VIKINGDB_SOCKET_TIMEOUT=30
|
||||
|
||||
# OceanBase Vector configuration
|
||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||
OCEANBASE_VECTOR_PORT=2881
|
||||
OCEANBASE_VECTOR_USER=root@test
|
||||
OCEANBASE_VECTOR_PASSWORD=
|
||||
OCEANBASE_VECTOR_DATABASE=test
|
||||
OCEANBASE_MEMORY_LIMIT=6G
|
||||
|
||||
# Upload configuration
|
||||
UPLOAD_FILE_SIZE_LIMIT=15
|
||||
UPLOAD_FILE_BATCH_LIMIT=5
|
||||
|
||||
@ -55,7 +55,12 @@ RUN apt-get update \
|
||||
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
|
||||
&& apt-get update \
|
||||
# For Security
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
|
||||
&& apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
|
||||
&& if [ "$(dpkg --print-architecture)" = "amd64" ]; then \
|
||||
apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1+b1; \
|
||||
else \
|
||||
apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1; \
|
||||
fi \
|
||||
# install a chinese font to support the use of tools like matplotlib
|
||||
&& apt-get install -y fonts-noto-cjk \
|
||||
&& apt-get autoremove -y \
|
||||
|
||||
@ -279,6 +279,7 @@ def migrate_knowledge_vector_database():
|
||||
VectorType.VIKINGDB,
|
||||
VectorType.UPSTASH,
|
||||
VectorType.COUCHBASE,
|
||||
VectorType.OCEANBASE,
|
||||
}
|
||||
page = 1
|
||||
while True:
|
||||
|
||||
@ -16,11 +16,13 @@ from configs.middleware.storage.supabase_storage_config import SupabaseStorageCo
|
||||
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
|
||||
from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
|
||||
from configs.middleware.vdb.chroma_config import ChromaConfig
|
||||
from configs.middleware.vdb.couchbase_config import CouchbaseConfig
|
||||
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
|
||||
from configs.middleware.vdb.milvus_config import MilvusConfig
|
||||
from configs.middleware.vdb.myscale_config import MyScaleConfig
|
||||
from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
|
||||
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
|
||||
from configs.middleware.vdb.oracle_config import OracleConfig
|
||||
from configs.middleware.vdb.pgvector_config import PGVectorConfig
|
||||
@ -257,5 +259,7 @@ class MiddlewareConfig(
|
||||
VikingDBConfig,
|
||||
UpstashConfig,
|
||||
TidbOnQdrantConfig,
|
||||
OceanBaseVectorConfig,
|
||||
BaiduVectorDBConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
35
api/configs/middleware/vdb/oceanbase_config.py
Normal file
35
api/configs/middleware/vdb/oceanbase_config.py
Normal file
@ -0,0 +1,35 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class OceanBaseVectorConfig(BaseSettings):
    """
    Settings for connecting to the OceanBase Vector database.

    All fields are optional; every field defaults to ``None`` except the
    port, which defaults to 2881.
    """

    OCEANBASE_VECTOR_HOST: Optional[str] = Field(
        default=None,
        description="Hostname or IP address of the OceanBase Vector server (e.g. 'localhost')",
    )

    OCEANBASE_VECTOR_PORT: Optional[PositiveInt] = Field(
        default=2881,
        description="Port number on which the OceanBase Vector server is listening (default is 2881)",
    )

    OCEANBASE_VECTOR_USER: Optional[str] = Field(
        default=None,
        description="Username for authenticating with the OceanBase Vector database",
    )

    OCEANBASE_VECTOR_PASSWORD: Optional[str] = Field(
        default=None,
        description="Password for authenticating with the OceanBase Vector database",
    )

    OCEANBASE_VECTOR_DATABASE: Optional[str] = Field(
        default=None,
        description="Name of the OceanBase Vector database to connect to",
    )
|
||||
@ -63,3 +63,8 @@ class TidbOnQdrantConfig(BaseSettings):
|
||||
description="Tidb project id",
|
||||
default=None,
|
||||
)
|
||||
|
||||
TIDB_SPEND_LIMIT: Optional[int] = Field(
|
||||
description="Tidb spend limit",
|
||||
default=100,
|
||||
)
|
||||
|
||||
@ -628,6 +628,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
| VectorType.OCEANBASE
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
@ -669,6 +670,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
| VectorType.OCEANBASE
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
|
||||
import requests
|
||||
from flask_restful import Resource, reqparse
|
||||
from packaging import version
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
@ -47,43 +48,15 @@ class VersionApi(Resource):
|
||||
|
||||
|
||||
def _has_new_version(*, latest_version: str, current_version: str) -> bool:
|
||||
def parse_version(version: str) -> tuple:
|
||||
# Split version into parts and pre-release suffix if any
|
||||
parts = version.split("-")
|
||||
version_parts = parts[0].split(".")
|
||||
pre_release = parts[1] if len(parts) > 1 else None
|
||||
try:
|
||||
latest = version.parse(latest_version)
|
||||
current = version.parse(current_version)
|
||||
|
||||
# Validate version format
|
||||
if len(version_parts) != 3:
|
||||
raise ValueError(f"Invalid version format: {version}")
|
||||
|
||||
try:
|
||||
# Convert version parts to integers
|
||||
major, minor, patch = map(int, version_parts)
|
||||
return (major, minor, patch, pre_release)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid version format: {version}")
|
||||
|
||||
latest = parse_version(latest_version)
|
||||
current = parse_version(current_version)
|
||||
|
||||
# Compare major, minor, and patch versions
|
||||
for latest_part, current_part in zip(latest[:3], current[:3]):
|
||||
if latest_part > current_part:
|
||||
return True
|
||||
elif latest_part < current_part:
|
||||
return False
|
||||
|
||||
# If versions are equal, check pre-release suffixes
|
||||
if latest[3] is None and current[3] is not None:
|
||||
return True
|
||||
elif latest[3] is not None and current[3] is None:
|
||||
# Compare versions
|
||||
return latest > current
|
||||
except version.InvalidVersion:
|
||||
logging.warning(f"Invalid version format: latest={latest_version}, current={current_version}")
|
||||
return False
|
||||
elif latest[3] is not None and current[3] is not None:
|
||||
# Simple string comparison for pre-release versions
|
||||
return latest[3] > current[3]
|
||||
|
||||
return False
|
||||
|
||||
|
||||
api.add_resource(VersionApi, "/version")
|
||||
|
||||
@ -230,7 +230,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@ -331,10 +331,26 @@ class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
return data
|
||||
|
||||
|
||||
api.add_resource(DocumentAddByTextApi, "/datasets/<uuid:dataset_id>/document/create_by_text")
|
||||
api.add_resource(DocumentAddByFileApi, "/datasets/<uuid:dataset_id>/document/create_by_file")
|
||||
api.add_resource(DocumentUpdateByTextApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
|
||||
api.add_resource(DocumentUpdateByFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file")
|
||||
api.add_resource(
|
||||
DocumentAddByTextApi,
|
||||
"/datasets/<uuid:dataset_id>/document/create_by_text",
|
||||
"/datasets/<uuid:dataset_id>/document/create-by-text",
|
||||
)
|
||||
api.add_resource(
|
||||
DocumentAddByFileApi,
|
||||
"/datasets/<uuid:dataset_id>/document/create_by_file",
|
||||
"/datasets/<uuid:dataset_id>/document/create-by-file",
|
||||
)
|
||||
api.add_resource(
|
||||
DocumentUpdateByTextApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
|
||||
)
|
||||
api.add_resource(
|
||||
DocumentUpdateByFileApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
|
||||
)
|
||||
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
|
||||
|
||||
@ -14,4 +14,4 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
||||
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
|
||||
|
||||
@ -37,6 +37,17 @@ def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
|
||||
return rule
|
||||
|
||||
|
||||
def _get_o1_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
    """
    Build the ``max_completion_tokens`` parameter rule.

    The rule starts from the shared MAX_TOKENS template; its ``default``,
    ``min`` and ``max`` attributes are then overwritten with the given bounds.

    :param default: value stored in the rule's ``default`` attribute
    :param min_val: value stored in the rule's ``min`` attribute
    :param max_val: value stored in the rule's ``max`` attribute
    :return: the configured parameter rule
    """
    template = PARAMETER_RULE_TEMPLATE[DefaultParameterName.MAX_TOKENS]
    o1_rule = ParameterRule(name="max_completion_tokens", **template)
    # Overrides are applied after construction, in this fixed order.
    for attr_name, bound in (("default", default), ("min", min_val), ("max", max_val)):
        setattr(o1_rule, attr_name, bound)
    return o1_rule
|
||||
|
||||
|
||||
class AzureBaseModel(BaseModel):
|
||||
base_model_name: str
|
||||
entity: AIModelEntity
|
||||
@ -1098,14 +1109,6 @@ LLM_BASE_MODELS = [
|
||||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||
@ -1116,7 +1119,7 @@ LLM_BASE_MODELS = [
|
||||
required=False,
|
||||
options=["text", "json_object"],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=32768),
|
||||
_get_o1_max_tokens(default=512, min_val=1, max_val=32768),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=15.00,
|
||||
@ -1143,14 +1146,6 @@ LLM_BASE_MODELS = [
|
||||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name="response_format",
|
||||
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
|
||||
@ -1161,7 +1156,7 @@ LLM_BASE_MODELS = [
|
||||
required=False,
|
||||
options=["text", "json_object"],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=65536),
|
||||
_get_o1_max_tokens(default=512, min_val=1, max_val=65536),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=3.00,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
model: hunyuan-standard-256k
|
||||
model: hunyuan-standard-256K
|
||||
label:
|
||||
zh_Hans: hunyuan-standard-256k
|
||||
en_US: hunyuan-standard-256k
|
||||
zh_Hans: hunyuan-standard-256K
|
||||
en_US: hunyuan-standard-256K
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
@ -0,0 +1,3 @@
|
||||
<svg width="1200" height="925" viewBox="0 0 1200 925" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M780.152 250.999L907.882 462.174C907.882 462.174 880.925 510.854 867.43 535.21C834.845 594.039 764.171 612.49 710.442 508.333L420.376 0H0L459.926 803.307C552.303 964.663 787.366 964.663 879.743 803.307C989.874 610.952 1089.87 441.97 1200 249.646L1052.28 0H639.519L780.152 250.999Z" fill="#3366FF"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 417 B |
83
api/core/model_runtime/model_providers/vessl_ai/llm/llm.py
Normal file
83
api/core/model_runtime/model_providers/vessl_ai/llm/llm.py
Normal file
@ -0,0 +1,83 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
DefaultParameterName,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
PriceConfig,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
||||
|
||||
class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
    """LLM implementation for VESSL AI, built on the OpenAI-API-compatible model class."""

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        Build the schema of a user-configured model from its credentials.

        :param model: model name, used as both the entity's model id and its label
        :param credentials: credential values; ``mode`` selects chat or completion,
            and the optional keys ``temperature``, ``top_p``, ``top_k``,
            ``max_tokens_to_sample``, ``input_price``, ``output_price``, ``unit``
            and ``currency`` override the defaults used below
        :return: entity carrying the model properties, parameter rules and pricing
        :raises ValueError: if ``mode`` is neither ``"chat"`` nor ``"completion"``
            (including when it is missing)
        """
        features = []
        # Read the mode once, tolerantly: the entity below and the validation at
        # the end must agree on the same value.
        mode = credentials.get("mode")

        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.LLM,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            features=features,
            model_properties={
                ModelPropertyKey.MODE: mode,
            },
            parameter_rules=[
                ParameterRule(
                    name=DefaultParameterName.TEMPERATURE.value,
                    label=I18nObject(en_US="Temperature"),
                    type=ParameterType.FLOAT,
                    default=float(credentials.get("temperature", 0.7)),
                    min=0,
                    max=2,
                    precision=2,
                ),
                ParameterRule(
                    name=DefaultParameterName.TOP_P.value,
                    label=I18nObject(en_US="Top P"),
                    type=ParameterType.FLOAT,
                    default=float(credentials.get("top_p", 1)),
                    min=0,
                    max=1,
                    precision=2,
                ),
                ParameterRule(
                    name=DefaultParameterName.TOP_K.value,
                    label=I18nObject(en_US="Top K"),
                    type=ParameterType.INT,
                    default=int(credentials.get("top_k", 50)),
                    min=-2147483647,
                    max=2147483647,
                    precision=0,
                ),
                ParameterRule(
                    name=DefaultParameterName.MAX_TOKENS.value,
                    label=I18nObject(en_US="Max Tokens"),
                    type=ParameterType.INT,
                    default=512,
                    min=1,
                    # Upper bound comes from the credentials, falling back to 4096.
                    max=int(credentials.get("max_tokens_to_sample", 4096)),
                ),
            ],
            pricing=PriceConfig(
                input=Decimal(credentials.get("input_price", 0)),
                output=Decimal(credentials.get("output_price", 0)),
                unit=Decimal(credentials.get("unit", 0)),
                currency=credentials.get("currency", "USD"),
            ),
        )

        if mode == "chat":
            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
        elif mode == "completion":
            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
        else:
            # Report the value that was actually inspected. The previous message
            # looked up a 'completion_type' key that is never set, so this branch
            # raised KeyError instead of the intended ValueError.
            raise ValueError(f"Unknown completion mode {mode}")

        return entity
|
||||
10
api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py
Normal file
10
api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py
Normal file
@ -0,0 +1,10 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VesslAIProvider(ModelProvider):
    """Model provider entry for VESSL AI."""

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Accept any provider-level credentials; no check is performed here.

        :param credentials: provider credentials (unused)
        :return:
        """
        return None
|
||||
@ -0,0 +1,56 @@
|
||||
provider: vessl_ai
|
||||
label:
|
||||
en_US: vessl_ai
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
background: "#F1EFED"
|
||||
help:
|
||||
title:
|
||||
en_US: How to deploy VESSL AI LLM Model Endpoint
|
||||
url:
|
||||
en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
credential_form_schemas:
|
||||
- variable: endpoint_url
|
||||
label:
|
||||
en_US: endpoint url
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
en_US: Enter the url of your endpoint url
|
||||
- variable: api_key
|
||||
required: true
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
placeholder:
|
||||
en_US: Enter your VESSL AI secret key
|
||||
- variable: mode
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
en_US: Completion mode
|
||||
type: select
|
||||
required: false
|
||||
default: chat
|
||||
placeholder:
|
||||
en_US: Select completion mode
|
||||
options:
|
||||
- value: completion
|
||||
label:
|
||||
en_US: Completion
|
||||
- value: chat
|
||||
label:
|
||||
en_US: Chat
|
||||
@ -115,6 +115,7 @@ class _CommonWenxin:
|
||||
"ernie-character-8k-0321": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k",
|
||||
"ernie-4.0-turbo-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k",
|
||||
"ernie-4.0-turbo-8k-preview": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview",
|
||||
"ernie-4.0-turbo-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-128k",
|
||||
"yi_34b_chat": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat",
|
||||
"embedding-v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1",
|
||||
"bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en",
|
||||
|
||||
@ -0,0 +1,40 @@
|
||||
model: ernie-4.0-turbo-128k
|
||||
label:
|
||||
en_US: Ernie-4.0-turbo-128K
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0.1
|
||||
max: 1.0
|
||||
default: 0.8
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 2
|
||||
max: 4096
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
default: 1.0
|
||||
min: 1.0
|
||||
max: 2.0
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: disable_search
|
||||
label:
|
||||
zh_Hans: 禁用搜索
|
||||
en_US: Disable Search
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 禁用模型自行进行外部搜索。
|
||||
en_US: Disable the model to perform external search.
|
||||
required: false
|
||||
@ -34,6 +34,8 @@ class RetrievalService:
|
||||
reranking_mode: Optional[str] = "reranking_model",
|
||||
weights: Optional[dict] = None,
|
||||
):
|
||||
if not query:
|
||||
return []
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
@ -3,11 +3,13 @@ import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymochow import MochowClient
|
||||
from pymochow.auth.bce_credentials import BceCredentials
|
||||
from pymochow.configuration import Configuration
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
|
||||
from pymochow.exception import ServerError
|
||||
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
|
||||
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
|
||||
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
|
||||
|
||||
@ -116,6 +118,7 @@ class BaiduVector(BaseVector):
|
||||
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
|
||||
anns = AnnSearch(
|
||||
vector_field=self.field_vector,
|
||||
vector_floats=query_vector,
|
||||
@ -149,7 +152,13 @@ class BaiduVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def delete(self) -> None:
|
||||
self._db.drop_table(table_name=self._collection_name)
|
||||
try:
|
||||
self._db.drop_table(table_name=self._collection_name)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.TABLE_NOT_EXIST:
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
def _init_client(self, config) -> MochowClient:
|
||||
config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
|
||||
@ -166,7 +175,14 @@ class BaiduVector(BaseVector):
|
||||
if exists:
|
||||
return self._client.database(self._client_config.database)
|
||||
else:
|
||||
return self._client.create_database(database_name=self._client_config.database)
|
||||
try:
|
||||
self._client.create_database(database_name=self._client_config.database)
|
||||
except ServerError as e:
|
||||
if e.code == ServerErrCode.DB_ALREADY_EXIST:
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
return
|
||||
|
||||
def _table_existed(self) -> bool:
|
||||
tables = self._db.list_table()
|
||||
@ -175,7 +191,7 @@ class BaiduVector(BaseVector):
|
||||
def _create_table(self, dimension: int) -> None:
|
||||
# Try to grab distributed lock and create table
|
||||
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
with redis_client.lock(lock_name, timeout=60):
|
||||
table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
|
||||
if redis_client.get(table_exist_cache_key):
|
||||
return
|
||||
@ -238,15 +254,14 @@ class BaiduVector(BaseVector):
|
||||
description="Table for Dify",
|
||||
)
|
||||
|
||||
# Wait for table created
|
||||
while True:
|
||||
time.sleep(1)
|
||||
table = self._db.describe_table(self._collection_name)
|
||||
if table.state == TableState.NORMAL:
|
||||
break
|
||||
redis_client.set(table_exist_cache_key, 1, ex=3600)
|
||||
|
||||
# Wait for table created
|
||||
while True:
|
||||
time.sleep(1)
|
||||
table = self._db.describe_table(self._collection_name)
|
||||
if table.state == TableState.NORMAL:
|
||||
break
|
||||
|
||||
|
||||
class BaiduVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
|
||||
|
||||
0
api/core/rag/datasource/vdb/oceanbase/__init__.py
Normal file
0
api/core/rag/datasource/vdb/oceanbase/__init__.py
Normal file
209
api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
Normal file
209
api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
Normal file
@ -0,0 +1,209 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pyobvector import VECTOR, ObVecClient
|
||||
from sqlalchemy import JSON, Column, String, func
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256}
|
||||
DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64}
|
||||
OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW"
|
||||
DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"
|
||||
|
||||
|
||||
class OceanBaseVectorConfig(BaseModel):
    """Connection parameters used to open an OceanBase vector client."""

    host: str
    port: int
    user: str
    password: str
    database: str

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        """
        Reject a config in which a required connection setting is missing or empty.

        ``password`` is not checked here, so an empty password passes this validator.

        :param values: raw input values, inspected before field validation
        :return: the same ``values`` mapping, unmodified
        :raises ValueError: naming the first OCEANBASE_VECTOR_* setting that is
            absent or falsy
        """
        required_settings = (
            ("host", "OCEANBASE_VECTOR_HOST"),
            ("port", "OCEANBASE_VECTOR_PORT"),
            ("user", "OCEANBASE_VECTOR_USER"),
            ("database", "OCEANBASE_VECTOR_DATABASE"),
        )
        for field_name, setting_name in required_settings:
            # .get() so an absent key yields the same descriptive error as an
            # empty value instead of a bare KeyError.
            if not values.get(field_name):
                raise ValueError(f"config {setting_name} is required")
        return values
|
||||
|
||||
|
||||
class OceanBaseVector(BaseVector):
    """Vector store that keeps documents in an OceanBase table accessed through ``ObVecClient``."""

    def __init__(self, collection_name: str, config: OceanBaseVectorConfig):
        """
        Open a client for the configured OceanBase instance.

        :param collection_name: name of the table holding this collection
        :param config: host, port, user, password and database to connect with
        """
        super().__init__(collection_name)
        self._config = config
        # Last ef_search value sent to the server; stays -1 until
        # search_by_vector is called with a different value.
        self._hnsw_ef_search = -1
        self._client = ObVecClient(
            uri=f"{self._config.host}:{self._config.port}",
            user=self._config.user,
            password=self._config.password,
            db_name=self._config.database,
        )

    def get_type(self) -> str:
        """Return the vector store type identifier (``VectorType.OCEANBASE``)."""
        return VectorType.OCEANBASE

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """
        Create the collection table if needed, then insert the documents.

        The vector dimension is taken from the first embedding, so
        ``embeddings`` must not be empty.

        :param texts: documents to store
        :param embeddings: one embedding per document, in the same order
        """
        self._vec_dim = len(embeddings[0])
        self._create_collection()
        self.add_texts(texts, embeddings)

    def _create_collection(self) -> None:
        """
        Create the table and its HNSW vector index under a Redis lock.

        Returns early when the Redis marker key is set or the table already
        exists. After creating the table, checks the server's
        ``ob_vector_memory_limit_percentage`` parameter and sets it to 30 if
        any reported value is 0.

        :raises ValueError: if the server reports no
            ``ob_vector_memory_limit_percentage`` parameter
        :raises Exception: if the parameter has to be changed and the ALTER
            SYSTEM statement fails
        """
        lock_name = "vector_indexing_lock_" + self._collection_name
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = "vector_indexing_" + self._collection_name
            if redis_client.get(collection_exist_cache_key):
                return

            if self._client.check_table_exists(self._collection_name):
                return

            self.delete()

            cols = [
                Column("id", String(36), primary_key=True, autoincrement=False),
                Column("vector", VECTOR(self._vec_dim)),
                Column("text", LONGTEXT),
                Column("metadata", JSON),
            ]
            vidx_params = self._client.prepare_index_params()
            vidx_params.add_index(
                field_name="vector",
                index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE,
                index_name="vector_index",
                metric_type=DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
                params=DEFAULT_OCEANBASE_HNSW_BUILD_PARAM,
            )

            self._client.create_table_with_index_params(
                table_name=self._collection_name,
                columns=cols,
                vidxs=vidx_params,
            )

            params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'")
            # Column 6 of each returned row is read as the parameter's value.
            vals = [int(row[6]) for row in params]
            if not vals:
                # Raise instead of print()/exit(1): terminating the interpreter
                # from library code would take the whole host process down.
                raise ValueError("ob_vector_memory_limit_percentage not found in parameters.")
            if any(val == 0 for val in vals):
                try:
                    self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
                except Exception as e:
                    raise Exception(
                        "Failed to set ob_vector_memory_limit_percentage. "
                        + "Maybe the database user has insufficient privilege."
                    ) from e
            # Marker key lets later calls skip the checks above for an hour.
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """
        Insert documents and their embeddings, one row per document.

        :param documents: documents to store; ids come from ``_get_uuids``
        :param embeddings: one embedding per document, in the same order
        """
        ids = self._get_uuids(documents)
        for id, doc, emb in zip(ids, documents, embeddings):
            self._client.insert(
                table_name=self._collection_name,
                data={
                    "id": id,
                    "vector": emb,
                    "text": doc.page_content,
                    "metadata": doc.metadata,
                },
            )

    def text_exists(self, id: str) -> bool:
        """Return True if the lookup for ``id`` reports a non-zero row count."""
        cur = self._client.get(table_name=self._collection_name, id=id)
        return cur.rowcount != 0

    def delete_by_ids(self, ids: list[str]) -> None:
        """Delete the rows whose ids are listed in ``ids``."""
        self._client.delete(table_name=self._collection_name, ids=ids)

    def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
        """
        Return the ids of rows whose metadata field ``key`` equals ``value``.

        :param key: metadata key, used as a JSON path component
        :param value: value the metadata field is compared against
        :return: ids of the matching rows
        """
        # NOTE(review): key and value are interpolated into the filter text
        # unescaped. Confirm callers only pass trusted values, or switch to a
        # bound/escaped form if the client supports one.
        cur = self._client.get(
            table_name=self._collection_name,
            where_clause=f"metadata->>'$.{key}' = '{value}'",
            output_column_name=["id"],
        )
        return [row[0] for row in cur]

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        """Delete every row whose metadata field ``key`` equals ``value``."""
        ids = self.get_ids_by_metadata_field(key, value)
        # Nothing matched: skip issuing a delete with an empty id list.
        if ids:
            self.delete_by_ids(ids)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Full-text search is not implemented here; always returns an empty list."""
        return []

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """
        Run an approximate nearest-neighbour search using L2 distance.

        :param query_vector: embedding to search with
        :param kwargs: ``top_k`` (default 10) limits the number of results;
            ``ef_search`` updates the server's HNSW ef_search setting when it
            differs from the last value sent
        :return: matching documents, each with ``metadata["score"]`` set
        """
        ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
        if ef_search != self._hnsw_ef_search:
            self._client.set_ob_hnsw_ef_search(ef_search)
            self._hnsw_ef_search = ef_search
        topk = kwargs.get("top_k", 10)
        cur = self._client.ann_search(
            table_name=self._collection_name,
            vec_column_name="vector",
            vec_data=query_vector,
            topk=topk,
            distance_func=func.l2_distance,
            output_column_names=["text", "metadata"],
            with_dist=True,
        )
        docs = []
        for text, metadata, distance in cur:
            metadata = json.loads(metadata)
            # Distance is mapped to a score as 1 - d / sqrt(2).
            # NOTE(review): this goes negative when d > sqrt(2) - confirm the
            # expected distance range for the embeddings in use.
            metadata["score"] = 1 - distance / math.sqrt(2)
            docs.append(
                Document(
                    page_content=text,
                    metadata=metadata,
                )
            )
        return docs

    def delete(self) -> None:
        """Drop the collection's table if it exists."""
        self._client.drop_table_if_exist(self._collection_name)
|
||||
|
||||
|
||||
class OceanBaseVectorFactory(AbstractVectorFactory):
    """Factory that builds ``OceanBaseVector`` instances from ``dify_config``."""

    def init_vector(
        self,
        dataset: Dataset,
        attributes: list,
        embeddings: Embeddings,
    ) -> BaseVector:
        """
        Create the OceanBase vector store for a dataset.

        The collection name is the lower-cased ``class_prefix`` from the
        dataset's index struct when one exists. Otherwise a name is generated
        from the dataset id and a new index struct is written to the dataset.

        :param dataset: dataset whose collection is being opened
        :param attributes: unused by this factory
        :param embeddings: unused by this factory
        :return: vector store bound to the dataset's collection
        """
        index_struct = dataset.index_struct_dict
        if not index_struct:
            collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OCEANBASE, collection_name))
        else:
            prefix: str = index_struct["vector_store"]["class_prefix"]
            collection_name = prefix.lower()

        connection = OceanBaseVectorConfig(
            host=dify_config.OCEANBASE_VECTOR_HOST,
            port=dify_config.OCEANBASE_VECTOR_PORT,
            user=dify_config.OCEANBASE_VECTOR_USER,
            # An unset password is passed on as an empty string.
            password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
            database=dify_config.OCEANBASE_VECTOR_DATABASE,
        )
        return OceanBaseVector(collection_name, connection)
|
||||
@ -4,6 +4,7 @@ import uuid
|
||||
import requests
|
||||
from requests.auth import HTTPDigestAuth
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import TidbAuthBinding
|
||||
@ -36,7 +37,7 @@ class TidbService:
|
||||
}
|
||||
|
||||
spending_limit = {
|
||||
"monthly": 100,
|
||||
"monthly": dify_config.TIDB_SPEND_LIMIT,
|
||||
}
|
||||
password = str(uuid.uuid4()).replace("-", "")[:16]
|
||||
display_name = str(uuid.uuid4()).replace("-", "")[:16]
|
||||
@ -208,7 +209,7 @@ class TidbService:
|
||||
}
|
||||
|
||||
spending_limit = {
|
||||
"monthly": 10,
|
||||
"monthly": dify_config.TIDB_SPEND_LIMIT,
|
||||
}
|
||||
password = str(uuid.uuid4()).replace("-", "")[:16]
|
||||
display_name = str(uuid.uuid4()).replace("-", "")
|
||||
|
||||
@ -134,6 +134,10 @@ class Vector:
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
|
||||
|
||||
return TidbOnQdrantVectorFactory
|
||||
case VectorType.OCEANBASE:
|
||||
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
|
||||
|
||||
return OceanBaseVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
|
||||
@ -21,3 +21,4 @@ class VectorType(str, Enum):
|
||||
VIKINGDB = "vikingdb"
|
||||
UPSTASH = "upstash"
|
||||
TIDB_ON_QDRANT = "tidb_on_qdrant"
|
||||
OCEANBASE = "oceanbase"
|
||||
|
||||
@ -27,18 +27,17 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
:return:
|
||||
"""
|
||||
docs = []
|
||||
doc_id = []
|
||||
doc_id = set()
|
||||
unique_documents = []
|
||||
dify_documents = [item for item in documents if item.provider == "dify"]
|
||||
external_documents = [item for item in documents if item.provider == "external"]
|
||||
for document in dify_documents:
|
||||
if document.metadata["doc_id"] not in doc_id:
|
||||
doc_id.append(document.metadata["doc_id"])
|
||||
for document in documents:
|
||||
if document.provider == "dify" and document.metadata["doc_id"] not in doc_id:
|
||||
doc_id.add(document.metadata["doc_id"])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
for document in external_documents:
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
elif document.provider == "external":
|
||||
if document not in unique_documents:
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
documents = unique_documents
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RerankMode(Enum):
|
||||
class RerankMode(str, Enum):
|
||||
RERANKING_MODEL = "reranking_model"
|
||||
WEIGHTED_SCORE = "weighted_score"
|
||||
|
||||
@ -22,6 +22,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaK
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
@ -361,10 +362,39 @@ class DatasetRetrieval:
|
||||
reranking_enable: bool = True,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
if not available_datasets:
|
||||
return []
|
||||
threads = []
|
||||
all_documents = []
|
||||
dataset_ids = [dataset.id for dataset in available_datasets]
|
||||
index_type = None
|
||||
index_type_check = all(
|
||||
item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
|
||||
)
|
||||
if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL):
|
||||
raise ValueError(
|
||||
"The configured knowledge base list have different indexing technique, please set reranking model."
|
||||
)
|
||||
index_type = available_datasets[0].indexing_technique
|
||||
if index_type == "high_quality":
|
||||
embedding_model_check = all(
|
||||
item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
|
||||
)
|
||||
embedding_model_provider_check = all(
|
||||
item.embedding_model_provider == available_datasets[0].embedding_model_provider
|
||||
for item in available_datasets
|
||||
)
|
||||
if (
|
||||
reranking_enable
|
||||
and reranking_mode == "weighted_score"
|
||||
and (not embedding_model_check or not embedding_model_provider_check)
|
||||
):
|
||||
raise ValueError(
|
||||
"The configured knowledge base list have different embedding model, please set reranking model."
|
||||
)
|
||||
if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
|
||||
weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
|
||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
retrieval_thread = threading.Thread(
|
||||
|
||||
@ -33,7 +33,9 @@ class BarChartTool(BuiltinTool):
|
||||
if axis:
|
||||
axis = [label[:10] + "..." if len(label) > 10 else label for label in axis]
|
||||
ax.set_xticklabels(axis, rotation=45, ha="right")
|
||||
ax.bar(axis, data)
|
||||
# ensure all labels, including duplicates, are correctly displayed
|
||||
ax.bar(range(len(data)), data)
|
||||
ax.set_xticks(range(len(data)))
|
||||
else:
|
||||
ax.bar(range(len(data)), data)
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import random
|
||||
import uuid
|
||||
@ -8,7 +6,7 @@ import httpx
|
||||
from websocket import WebSocket
|
||||
from yarl import URL
|
||||
|
||||
from core.file.file_manager import _get_encoded_string
|
||||
from core.file.file_manager import download
|
||||
from core.file.models import File
|
||||
|
||||
|
||||
@ -29,8 +27,7 @@ class ComfyUiClient:
|
||||
return response.content
|
||||
|
||||
def upload_image(self, image_file: File) -> dict:
|
||||
image_content = base64.b64decode(_get_encoded_string(image_file))
|
||||
file = io.BytesIO(image_content)
|
||||
file = download(image_file)
|
||||
files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"}
|
||||
res = httpx.post(str(self.base_url / "upload/image"), files=files)
|
||||
return res.json()
|
||||
@ -47,12 +44,7 @@ class ComfyUiClient:
|
||||
ws.connect(ws_address)
|
||||
return ws, client_id
|
||||
|
||||
def set_prompt(
|
||||
self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
find the first KSampler, then can find the prompt node through it.
|
||||
"""
|
||||
def set_prompt_by_ksampler(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "") -> dict:
|
||||
prompt = origin_prompt.copy()
|
||||
id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
|
||||
k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0]
|
||||
@ -64,9 +56,20 @@ class ComfyUiClient:
|
||||
negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
|
||||
prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
|
||||
|
||||
if image_name != "":
|
||||
image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0]
|
||||
prompt.get(image_loader)["inputs"]["image"] = image_name
|
||||
return prompt
|
||||
|
||||
def set_prompt_images_by_ids(self, origin_prompt: dict, image_names: list[str], image_ids: list[str]) -> dict:
    """Return a copy of the workflow with image names assigned to the given nodes.

    ``image_ids[i]`` is the id of the workflow node whose
    ``["inputs"]["image"]`` is set to ``image_names[i]``.

    The workflow is deep-copied first. A shallow ``dict.copy()`` shares the
    nested node dictionaries, so writing into them would silently modify the
    caller's ``origin_prompt`` as well.

    Raises:
        KeyError: a node id is not in the workflow, or the node has no "inputs".
        IndexError: ``image_ids`` is longer than ``image_names``.
    """
    import copy

    prompt = copy.deepcopy(origin_prompt)
    # Index lookup (not zip) is deliberate: too many node ids must raise
    # rather than being silently truncated.
    for index, image_node_id in enumerate(image_ids):
        prompt[image_node_id]["inputs"]["image"] = image_names[index]
    return prompt
|
||||
|
||||
def set_prompt_images_by_default(self, origin_prompt: dict, image_names: list[str]) -> dict:
    """Return a copy of the workflow with image names assigned to its LoadImage nodes.

    Nodes whose ``class_type`` is ``"LoadImage"`` are paired with
    ``image_names`` in workflow iteration order. Pairing stops at the shorter
    of the two sequences, so surplus nodes or surplus names are left alone.

    The workflow is deep-copied first. A shallow ``dict.copy()`` shares the
    nested node dictionaries, so writing into them would silently modify the
    caller's ``origin_prompt`` as well.

    Raises:
        KeyError: a node has no "class_type", or a LoadImage node has no "inputs".
    """
    import copy

    prompt = copy.deepcopy(origin_prompt)
    load_image_nodes = [
        node_id for node_id, details in prompt.items() if details["class_type"] == "LoadImage"
    ]
    for node_id, image_name in zip(load_image_nodes, image_names):
        prompt[node_id]["inputs"]["image"] = image_name
    return prompt
|
||||
|
||||
def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from core.file import FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolParameterValidationError
|
||||
from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
@ -10,19 +12,46 @@ class ComfyUIWorkflowTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
comfyui = ComfyUiClient(self.runtime.credentials["base_url"])
|
||||
|
||||
positive_prompt = tool_parameters.get("positive_prompt")
|
||||
negative_prompt = tool_parameters.get("negative_prompt")
|
||||
positive_prompt = tool_parameters.get("positive_prompt", "")
|
||||
negative_prompt = tool_parameters.get("negative_prompt", "")
|
||||
images = tool_parameters.get("images") or []
|
||||
workflow = tool_parameters.get("workflow_json")
|
||||
image_name = ""
|
||||
if image := tool_parameters.get("image"):
|
||||
image_names = []
|
||||
for image in images:
|
||||
if image.type != FileType.IMAGE:
|
||||
continue
|
||||
image_name = comfyui.upload_image(image).get("name")
|
||||
image_names.append(image_name)
|
||||
|
||||
set_prompt_with_ksampler = True
|
||||
if "{{positive_prompt}}" in workflow:
|
||||
set_prompt_with_ksampler = False
|
||||
workflow = workflow.replace("{{positive_prompt}}", positive_prompt)
|
||||
workflow = workflow.replace("{{negative_prompt}}", negative_prompt)
|
||||
|
||||
try:
|
||||
origin_prompt = json.loads(workflow)
|
||||
prompt = json.loads(workflow)
|
||||
except:
|
||||
return self.create_text_message("the Workflow JSON is not correct")
|
||||
|
||||
prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name)
|
||||
if set_prompt_with_ksampler:
|
||||
try:
|
||||
prompt = comfyui.set_prompt_by_ksampler(prompt, positive_prompt, negative_prompt)
|
||||
except:
|
||||
raise ToolParameterValidationError(
|
||||
"Failed set prompt with KSampler, try replace prompt to {{positive_prompt}} in your workflow json"
|
||||
)
|
||||
|
||||
if image_names:
|
||||
if image_ids := tool_parameters.get("image_ids"):
|
||||
image_ids = image_ids.split(",")
|
||||
try:
|
||||
prompt = comfyui.set_prompt_images_by_ids(prompt, image_names, image_ids)
|
||||
except:
|
||||
raise ToolParameterValidationError("the Image Node ID List not match your upload image files.")
|
||||
else:
|
||||
prompt = comfyui.set_prompt_images_by_default(prompt, image_names)
|
||||
|
||||
images = comfyui.generate_image_by_prompt(prompt)
|
||||
result = []
|
||||
for img in images:
|
||||
|
||||
@ -24,12 +24,12 @@ parameters:
|
||||
zh_Hans: 负面提示词
|
||||
llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
|
||||
form: llm
|
||||
- name: image
|
||||
type: file
|
||||
- name: images
|
||||
type: files
|
||||
label:
|
||||
en_US: Input Image
|
||||
en_US: Input Images
|
||||
zh_Hans: 输入的图片
|
||||
llm_description: The input image, used to transfer to the comfyui workflow to generate another image.
|
||||
llm_description: The input images, used to transfer to the comfyui workflow to generate another image.
|
||||
form: llm
|
||||
- name: workflow_json
|
||||
type: string
|
||||
@ -40,3 +40,15 @@ parameters:
|
||||
en_US: exported from ComfyUI workflow
|
||||
zh_Hans: 从ComfyUI的工作流中导出
|
||||
form: form
|
||||
- name: image_ids
|
||||
type: string
|
||||
label:
|
||||
en_US: Image Node ID List
|
||||
zh_Hans: 图片节点ID列表
|
||||
placeholder:
|
||||
en_US: Use commas to separate multiple node ID
|
||||
zh_Hans: 多个节点ID时使用半角逗号分隔
|
||||
human_description:
|
||||
en_US: When the workflow has multiple image nodes, enter the ID list of these nodes, and the images will be passed to ComfyUI in the order of the list.
|
||||
zh_Hans: 当工作流有多个图片节点时,输入这些节点的ID列表,图片将按列表顺序传给ComfyUI
|
||||
form: form
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import mimetypes
|
||||
from collections.abc import Generator
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from threading import Lock, Thread
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from configs import dify_config
|
||||
@ -647,4 +647,5 @@ class ToolManager:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
|
||||
ToolManager.load_builtin_providers_cache()
|
||||
# preload builtin tool providers
|
||||
Thread(target=ToolManager.load_builtin_providers_cache, name="pre_load_builtin_providers_cache", daemon=True).start()
|
||||
|
||||
@ -153,6 +153,7 @@ class AnswerStreamGeneratorRouter:
|
||||
NodeType.IF_ELSE,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
NodeType.ITERATION,
|
||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER,
|
||||
}:
|
||||
answer_dependencies[answer_node_id].append(source_node_id)
|
||||
else:
|
||||
|
||||
@ -5,6 +5,7 @@ import json
|
||||
import docx
|
||||
import pandas as pd
|
||||
import pypdfium2
|
||||
import yaml
|
||||
from unstructured.partition.email import partition_email
|
||||
from unstructured.partition.epub import partition_epub
|
||||
from unstructured.partition.msg import partition_msg
|
||||
@ -101,6 +102,8 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
||||
return _extract_text_from_msg(file_content)
|
||||
case "application/json":
|
||||
return _extract_text_from_json(file_content)
|
||||
case "application/x-yaml" | "text/yaml":
|
||||
return _extract_text_from_yaml(file_content)
|
||||
case _:
|
||||
raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
|
||||
|
||||
@ -112,6 +115,8 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
|
||||
return _extract_text_from_plain_text(file_content)
|
||||
case ".json":
|
||||
return _extract_text_from_json(file_content)
|
||||
case ".yaml" | ".yml":
|
||||
return _extract_text_from_yaml(file_content)
|
||||
case ".pdf":
|
||||
return _extract_text_from_pdf(file_content)
|
||||
case ".doc" | ".docx":
|
||||
@ -149,6 +154,15 @@ def _extract_text_from_json(file_content: bytes) -> str:
|
||||
raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e
|
||||
|
||||
|
||||
def _extract_text_from_yaml(file_content: bytes) -> str:
    """Decode *file_content* as UTF-8 YAML and re-serialise every document as text.

    Raises:
        TextExtractionError: the bytes are not valid UTF-8 or not valid YAML.
    """
    try:
        text = file_content.decode("utf-8")
        documents = yaml.safe_load_all(text)
        # Dumping stays inside the try block: safe_load_all is lazy, so parse
        # errors only surface while dump_all consumes the documents.
        rendered = yaml.dump_all(documents, allow_unicode=True, sort_keys=False)
    except (UnicodeDecodeError, yaml.YAMLError) as exc:
        raise TextExtractionError(f"Failed to decode or parse YAML file: {exc}") from exc
    return rendered
|
||||
|
||||
|
||||
def _extract_text_from_pdf(file_content: bytes) -> str:
|
||||
try:
|
||||
pdf_file = io.BytesIO(file_content)
|
||||
|
||||
@ -94,7 +94,7 @@ class Response:
|
||||
@property
|
||||
def is_file(self):
|
||||
content_type = self.content_type
|
||||
content_disposition = self.response.headers.get("Content-Disposition", "")
|
||||
content_disposition = self.response.headers.get("content-disposition", "")
|
||||
|
||||
return "attachment" in content_disposition or (
|
||||
not any(non_file in content_type for non_file in NON_FILE_CONTENT_TYPES)
|
||||
@ -103,7 +103,7 @@ class Response:
|
||||
|
||||
@property
|
||||
def content_type(self) -> str:
|
||||
return self.headers.get("Content-Type", "")
|
||||
return self.headers.get("content-type", "")
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
|
||||
@ -142,10 +142,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
Extract files from response
|
||||
"""
|
||||
files = []
|
||||
is_file = response.is_file
|
||||
content_type = response.content_type
|
||||
content = response.content
|
||||
|
||||
if content_type:
|
||||
if is_file and content_type:
|
||||
# extract filename from url
|
||||
filename = path.basename(url)
|
||||
# extract extension if possible
|
||||
|
||||
@ -327,7 +327,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
if variable is None:
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
if isinstance(variable, NoneSegment):
|
||||
continue
|
||||
inputs[variable_selector.variable] = ""
|
||||
inputs[variable_selector.variable] = variable.to_object()
|
||||
|
||||
memory = node_data.memory
|
||||
@ -349,13 +349,11 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if variable is None:
|
||||
return []
|
||||
if isinstance(variable, FileSegment):
|
||||
elif isinstance(variable, FileSegment):
|
||||
return [variable.value]
|
||||
if isinstance(variable, ArrayFileSegment):
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
return variable.value
|
||||
# FIXME: Temporary fix for empty array,
|
||||
# all variables added to variable pool should be a Segment instance.
|
||||
if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0:
|
||||
elif isinstance(variable, NoneSegment | ArrayAnySegment):
|
||||
return []
|
||||
raise ValueError(f"Invalid variable type: {type(variable)}")
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from core.variables import (
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectSegment,
|
||||
ArrayObjectVariable,
|
||||
ArraySegment,
|
||||
ArrayStringSegment,
|
||||
ArrayStringVariable,
|
||||
FileSegment,
|
||||
@ -79,7 +80,7 @@ def build_segment(value: Any, /) -> Segment:
|
||||
if isinstance(value, list):
|
||||
items = [build_segment(item) for item in value]
|
||||
types = {item.value_type for item in items}
|
||||
if len(types) != 1:
|
||||
if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items):
|
||||
return ArrayAnySegment(value=value)
|
||||
match types.pop():
|
||||
case SegmentType.STRING:
|
||||
|
||||
@ -121,6 +121,7 @@ conversation_fields = {
|
||||
"from_account_name": fields.String,
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotation": fields.Nested(annotation_fields, allow_null=True),
|
||||
"model_config": fields.Nested(simple_model_config_fields),
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_fields),
|
||||
@ -182,6 +183,7 @@ conversation_detail_fields = {
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"introduction": fields.String,
|
||||
"model_config": fields.Nested(model_config_fields),
|
||||
@ -197,6 +199,7 @@ simple_conversation_fields = {
|
||||
"status": fields.String,
|
||||
"introduction": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
|
||||
@ -396,7 +396,7 @@ class AppModelConfig(db.Model):
|
||||
"file_upload": self.file_upload_dict,
|
||||
}
|
||||
|
||||
def from_model_config_dict(self, model_config: dict):
|
||||
def from_model_config_dict(self, model_config: Mapping[str, Any]):
|
||||
self.opening_statement = model_config.get("opening_statement")
|
||||
self.suggested_questions = (
|
||||
json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None
|
||||
|
||||
23
api/poetry.lock
generated
23
api/poetry.lock
generated
@ -7269,6 +7269,22 @@ files = [
|
||||
ed25519 = ["PyNaCl (>=1.4.0)"]
|
||||
rsa = ["cryptography"]
|
||||
|
||||
[[package]]
|
||||
name = "pyobvector"
|
||||
version = "0.1.6"
|
||||
description = "A python SDK for OceanBase Vector Store, based on SQLAlchemy, compatible with Milvus API."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "pyobvector-0.1.6-py3-none-any.whl", hash = "sha256:0d700e865a85b4716b9a03384189e49288cd9d5f3cef88aed4740bc82d5fd136"},
|
||||
{file = "pyobvector-0.1.6.tar.gz", hash = "sha256:05551addcac8c596992d5e38b480c83ca3481c6cfc6f56a1a1bddfb2e6ae037e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.26.0,<2.0.0"
|
||||
pymysql = ">=1.1.1,<2.0.0"
|
||||
sqlalchemy = ">=2.0.32,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "pyopenssl"
|
||||
version = "24.2.1"
|
||||
@ -8677,6 +8693,11 @@ files = [
|
||||
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"},
|
||||
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"},
|
||||
{file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"},
|
||||
{file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"},
|
||||
{file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"},
|
||||
{file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"},
|
||||
{file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"},
|
||||
{file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"},
|
||||
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"},
|
||||
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"},
|
||||
{file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"},
|
||||
@ -10919,4 +10940,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "52552faf5f4823056eb48afe05349ab2f0e9a5bc42105211ccbbb54b59e27b59"
|
||||
content-hash = "ef927b98c33d704d680e08db0e5c7d9a4e05454c66fcd6a5f656a65eb08e886b"
|
||||
|
||||
@ -247,6 +247,7 @@ pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
|
||||
pgvector = "0.2.5"
|
||||
pymilvus = "~2.4.4"
|
||||
pymochow = "1.3.1"
|
||||
pyobvector = "~0.1.6"
|
||||
qdrant-client = "1.7.3"
|
||||
tcvectordb = "1.3.2"
|
||||
tidb-vector = "0.0.9"
|
||||
|
||||
3
api/services/app_dsl_service/__init__.py
Normal file
3
api/services/app_dsl_service/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .service import AppDslService
|
||||
|
||||
__all__ = ["AppDslService"]
|
||||
34
api/services/app_dsl_service/exc.py
Normal file
34
api/services/app_dsl_service/exc.py
Normal file
@ -0,0 +1,34 @@
|
||||
class AppDslError(ValueError):
    """Base class for every error raised while importing an app DSL.

    It derives from ``ValueError`` so existing ``except ValueError`` handlers
    keep catching all of the errors below. New code can catch ``AppDslError``
    to handle DSL problems without also swallowing unrelated ``ValueError``s.
    """


class DSLVersionNotSupportedError(AppDslError):
    """Raised when the imported DSL version is not supported by the current Dify version."""


class InvalidYAMLFormatError(AppDslError):
    """Raised when the provided YAML format is invalid."""


class MissingAppDataError(AppDslError):
    """Raised when the app data is missing in the provided DSL."""


class InvalidAppModeError(AppDslError):
    """Raised when the app mode is invalid."""


class MissingWorkflowDataError(AppDslError):
    """Raised when the workflow data is missing in the provided DSL."""


class MissingModelConfigError(AppDslError):
    """Raised when the model config data is missing in the provided DSL."""


class FileSizeLimitExceededError(AppDslError):
    """Raised when the file size exceeds the allowed limit."""


class EmptyContentError(AppDslError):
    """Raised when the content fetched from the URL is empty."""


class ContentDecodingError(AppDslError):
    """Raised when there is an error decoding the content."""
|
||||
@ -1,8 +1,11 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import yaml # type: ignore
|
||||
import yaml
|
||||
from packaging import version
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from events.app_event import app_model_config_was_updated, app_was_created
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
@ -11,6 +14,18 @@ from models.model import App, AppMode, AppModelConfig
|
||||
from models.workflow import Workflow
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
from .exc import (
|
||||
ContentDecodingError,
|
||||
DSLVersionNotSupportedError,
|
||||
EmptyContentError,
|
||||
FileSizeLimitExceededError,
|
||||
InvalidAppModeError,
|
||||
InvalidYAMLFormatError,
|
||||
MissingAppDataError,
|
||||
MissingModelConfigError,
|
||||
MissingWorkflowDataError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
current_dsl_version = "0.1.2"
|
||||
@ -30,32 +45,21 @@ class AppDslService:
|
||||
:param args: request args
|
||||
:param account: Account instance
|
||||
"""
|
||||
try:
|
||||
max_size = 10 * 1024 * 1024 # 10MB
|
||||
timeout = httpx.Timeout(10.0)
|
||||
with httpx.stream("GET", url.strip(), follow_redirects=True, timeout=timeout) as response:
|
||||
response.raise_for_status()
|
||||
total_size = 0
|
||||
content = b""
|
||||
for chunk in response.iter_bytes():
|
||||
total_size += len(chunk)
|
||||
if total_size > max_size:
|
||||
raise ValueError("File size exceeds the limit of 10MB")
|
||||
content += chunk
|
||||
except httpx.HTTPStatusError as http_err:
|
||||
raise ValueError(f"HTTP error occurred: {http_err}")
|
||||
except httpx.RequestError as req_err:
|
||||
raise ValueError(f"Request error occurred: {req_err}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch DSL from URL: {e}")
|
||||
max_size = 10 * 1024 * 1024 # 10MB
|
||||
response = ssrf_proxy.get(url.strip(), follow_redirects=True, timeout=(10, 10))
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
|
||||
if len(content) > max_size:
|
||||
raise FileSizeLimitExceededError("File size exceeds the limit of 10MB")
|
||||
|
||||
if not content:
|
||||
raise ValueError("Empty content from url")
|
||||
raise EmptyContentError("Empty content from url")
|
||||
|
||||
try:
|
||||
data = content.decode("utf-8")
|
||||
except UnicodeDecodeError as e:
|
||||
raise ValueError(f"Error decoding content: {e}")
|
||||
raise ContentDecodingError(f"Error decoding content: {e}")
|
||||
|
||||
return cls.import_and_create_new_app(tenant_id, data, args, account)
|
||||
|
||||
@ -71,14 +75,14 @@ class AppDslService:
|
||||
try:
|
||||
import_data = yaml.safe_load(data)
|
||||
except yaml.YAMLError:
|
||||
raise ValueError("Invalid YAML format in data argument.")
|
||||
raise InvalidYAMLFormatError("Invalid YAML format in data argument.")
|
||||
|
||||
# check or repair dsl version
|
||||
import_data = cls._check_or_fix_dsl(import_data)
|
||||
import_data = _check_or_fix_dsl(import_data)
|
||||
|
||||
app_data = import_data.get("app")
|
||||
if not app_data:
|
||||
raise ValueError("Missing app in data argument")
|
||||
raise MissingAppDataError("Missing app in data argument")
|
||||
|
||||
# get app basic info
|
||||
name = args.get("name") or app_data.get("name")
|
||||
@ -90,11 +94,18 @@ class AppDslService:
|
||||
|
||||
# import dsl and create app
|
||||
app_mode = AppMode.value_of(app_data.get("mode"))
|
||||
|
||||
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_data = import_data.get("workflow")
|
||||
if not workflow_data or not isinstance(workflow_data, dict):
|
||||
raise MissingWorkflowDataError(
|
||||
"Missing workflow in data argument when app mode is advanced-chat or workflow"
|
||||
)
|
||||
|
||||
app = cls._import_and_create_new_workflow_based_app(
|
||||
tenant_id=tenant_id,
|
||||
app_mode=app_mode,
|
||||
workflow_data=import_data.get("workflow"),
|
||||
workflow_data=workflow_data,
|
||||
account=account,
|
||||
name=name,
|
||||
description=description,
|
||||
@ -104,10 +115,16 @@ class AppDslService:
|
||||
use_icon_as_answer_icon=use_icon_as_answer_icon,
|
||||
)
|
||||
elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}:
|
||||
model_config = import_data.get("model_config")
|
||||
if not model_config or not isinstance(model_config, dict):
|
||||
raise MissingModelConfigError(
|
||||
"Missing model_config in data argument when app mode is chat, agent-chat or completion"
|
||||
)
|
||||
|
||||
app = cls._import_and_create_new_model_config_based_app(
|
||||
tenant_id=tenant_id,
|
||||
app_mode=app_mode,
|
||||
model_config_data=import_data.get("model_config"),
|
||||
model_config_data=model_config,
|
||||
account=account,
|
||||
name=name,
|
||||
description=description,
|
||||
@ -117,7 +134,7 @@ class AppDslService:
|
||||
use_icon_as_answer_icon=use_icon_as_answer_icon,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
raise InvalidAppModeError("Invalid app mode")
|
||||
|
||||
return app
|
||||
|
||||
@ -132,26 +149,32 @@ class AppDslService:
|
||||
try:
|
||||
import_data = yaml.safe_load(data)
|
||||
except yaml.YAMLError:
|
||||
raise ValueError("Invalid YAML format in data argument.")
|
||||
raise InvalidYAMLFormatError("Invalid YAML format in data argument.")
|
||||
|
||||
# check or repair dsl version
|
||||
import_data = cls._check_or_fix_dsl(import_data)
|
||||
import_data = _check_or_fix_dsl(import_data)
|
||||
|
||||
app_data = import_data.get("app")
|
||||
if not app_data:
|
||||
raise ValueError("Missing app in data argument")
|
||||
raise MissingAppDataError("Missing app in data argument")
|
||||
|
||||
# import dsl and overwrite app
|
||||
app_mode = AppMode.value_of(app_data.get("mode"))
|
||||
if app_mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
raise ValueError("Only support import workflow in advanced-chat or workflow app.")
|
||||
raise InvalidAppModeError("Only support import workflow in advanced-chat or workflow app.")
|
||||
|
||||
if app_data.get("mode") != app_model.mode:
|
||||
raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}")
|
||||
|
||||
workflow_data = import_data.get("workflow")
|
||||
if not workflow_data or not isinstance(workflow_data, dict):
|
||||
raise MissingWorkflowDataError(
|
||||
"Missing workflow in data argument when app mode is advanced-chat or workflow"
|
||||
)
|
||||
|
||||
return cls._import_and_overwrite_workflow_based_app(
|
||||
app_model=app_model,
|
||||
workflow_data=import_data.get("workflow"),
|
||||
workflow_data=workflow_data,
|
||||
account=account,
|
||||
)
|
||||
|
||||
@ -186,35 +209,12 @@ class AppDslService:
|
||||
|
||||
return yaml.dump(export_data, allow_unicode=True)
|
||||
|
||||
@classmethod
def _check_or_fix_dsl(cls, import_data: dict) -> dict:
    """
    Normalise the header fields of an imported DSL document.

    The dict is modified in place: a missing or empty "version" becomes "0.1.0"
    and "kind" is forced to "app". A version different from
    ``current_dsl_version`` is only logged, never rejected.

    :param import_data: import data
    :return: the same ``import_data`` object
    """
    import_data["version"] = import_data.get("version") or "0.1.0"

    # A missing/empty kind is by definition not "app", so one comparison covers both cases.
    if import_data.get("kind") != "app":
        import_data["kind"] = "app"

    dsl_version = import_data.get("version")
    if dsl_version == current_dsl_version:
        return import_data

    # Currently only one DSL version, so no difference checks or compatibility fixes will be performed.
    message = (
        f"DSL version {import_data.get('version')} is not compatible "
        f"with current version {current_dsl_version}, related to "
        f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}."
    )
    logger.warning(message)
    return import_data
|
||||
|
||||
@classmethod
|
||||
def _import_and_create_new_workflow_based_app(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
workflow_data: dict,
|
||||
workflow_data: Mapping[str, Any],
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
@ -238,7 +238,9 @@ class AppDslService:
|
||||
:param use_icon_as_answer_icon: use app icon as answer icon
|
||||
"""
|
||||
if not workflow_data:
|
||||
raise ValueError("Missing workflow in data argument when app mode is advanced-chat or workflow")
|
||||
raise MissingWorkflowDataError(
|
||||
"Missing workflow in data argument when app mode is advanced-chat or workflow"
|
||||
)
|
||||
|
||||
app = cls._create_app(
|
||||
tenant_id=tenant_id,
|
||||
@ -277,7 +279,7 @@ class AppDslService:
|
||||
|
||||
@classmethod
|
||||
def _import_and_overwrite_workflow_based_app(
|
||||
cls, app_model: App, workflow_data: dict, account: Account
|
||||
cls, app_model: App, workflow_data: Mapping[str, Any], account: Account
|
||||
) -> Workflow:
|
||||
"""
|
||||
Import app dsl and overwrite workflow based app
|
||||
@ -287,7 +289,9 @@ class AppDslService:
|
||||
:param account: Account instance
|
||||
"""
|
||||
if not workflow_data:
|
||||
raise ValueError("Missing workflow in data argument when app mode is advanced-chat or workflow")
|
||||
raise MissingWorkflowDataError(
|
||||
"Missing workflow in data argument when app mode is advanced-chat or workflow"
|
||||
)
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
@ -323,7 +327,7 @@ class AppDslService:
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
model_config_data: dict,
|
||||
model_config_data: Mapping[str, Any],
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
@ -345,7 +349,9 @@ class AppDslService:
|
||||
:param icon_background: app icon background
|
||||
"""
|
||||
if not model_config_data:
|
||||
raise ValueError("Missing model_config in data argument when app mode is chat, agent-chat or completion")
|
||||
raise MissingModelConfigError(
|
||||
"Missing model_config in data argument when app mode is chat, agent-chat or completion"
|
||||
)
|
||||
|
||||
app = cls._create_app(
|
||||
tenant_id=tenant_id,
|
||||
@ -448,3 +454,34 @@ class AppDslService:
|
||||
raise ValueError("Missing app configuration, please check.")
|
||||
|
||||
export_data["model_config"] = app_model_config.to_dict()
|
||||
|
||||
|
||||
def _check_or_fix_dsl(import_data: dict[str, Any]) -> Mapping[str, Any]:
    """
    Check or fix dsl

    Fills in "0.1.0" when "version" is missing or empty, forces "kind" to "app",
    then compares the imported version with ``current_dsl_version``. The dict is
    modified in place and returned.

    :param import_data: import data
    :return: the same ``import_data`` object after normalisation
    :raises DSLVersionNotSupportedError: if the imported DSL version is newer than the current version

    NOTE(review): a version string that ``version.parse`` cannot parse still
    propagates that parser's own error, as before.
    """
    if not import_data.get("version"):
        import_data["version"] = "0.1.0"

    # A missing/empty kind is by definition not "app", so one comparison covers both cases.
    if import_data.get("kind") != "app":
        import_data["kind"] = "app"

    # YAML resolves an unquoted scalar such as `version: 0.1` to a float, while
    # version.parse() only accepts text, so the comparison is done on the string form.
    imported_version = str(import_data["version"])
    imported = version.parse(imported_version)
    current = version.parse(current_dsl_version)

    if imported > current:
        raise DSLVersionNotSupportedError(
            f"The imported DSL version {imported_version} is newer than "
            f"the current supported version {current_dsl_version}. "
            "Please upgrade your Dify instance to import this configuration."
        )

    # Compare parsed versions so that an equal version spelled differently
    # (e.g. "0.1" vs "0.1.0") is not reported as older.
    if imported < current:
        logger.warning(
            f"DSL version {imported_version} is older than "
            f"the current version {current_dsl_version}. "
            "This may cause compatibility issues."
        )

    return import_data
|
||||
@ -736,11 +736,12 @@ class DocumentService:
|
||||
dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model
|
||||
|
||||
documents = []
|
||||
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
|
||||
if document_data.get("original_document_id"):
|
||||
document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)
|
||||
documents.append(document)
|
||||
batch = document.batch
|
||||
else:
|
||||
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
|
||||
# save process rule
|
||||
if not dataset_process_rule:
|
||||
process_rule = document_data["process_rule"]
|
||||
@ -921,7 +922,7 @@ class DocumentService:
|
||||
if duplicate_document_ids:
|
||||
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
|
||||
|
||||
return documents, batch
|
||||
return documents, batch
|
||||
|
||||
@staticmethod
|
||||
def check_documents_upload_quota(count: int, features: FeatureModel):
|
||||
|
||||
@ -35,7 +35,7 @@ class FileService:
|
||||
filename = file.filename
|
||||
if not filename:
|
||||
raise FileNotExistsError
|
||||
extension = filename.split(".")[-1]
|
||||
extension = filename.split(".")[-1].lower()
|
||||
if len(filename) > 200:
|
||||
filename = filename.split(".")[0][:200] + "." + extension
|
||||
|
||||
|
||||
@ -84,5 +84,10 @@ VOLC_EMBEDDING_ENDPOINT_ID=
|
||||
# 360 AI Credentials
|
||||
ZHINAO_API_KEY=
|
||||
|
||||
# VESSL AI Credentials
|
||||
VESSL_AI_MODEL_NAME=
|
||||
VESSL_AI_API_KEY=
|
||||
VESSL_AI_ENDPOINT_URL=
|
||||
|
||||
# Gitee AI Credentials
|
||||
GITEE_AI_API_KEY=
|
||||
GITEE_AI_API_KEY=
|
||||
131
api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py
Normal file
131
api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py
Normal file
@ -0,0 +1,131 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.vessl_ai.llm.llm import VesslAILargeLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials():
    """Credential validation rejects a bad API key and a bad endpoint, and accepts the configured pair."""
    llm = VesslAILargeLanguageModel()
    model_name = os.environ.get("VESSL_AI_MODEL_NAME")
    valid_credentials = {
        "api_key": os.environ.get("VESSL_AI_API_KEY"),
        "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
        "mode": "chat",
    }

    # Each rejected variant differs from the valid credentials in exactly one field.
    rejected_variants = (
        {**valid_credentials, "api_key": "invalid_key"},
        {**valid_credentials, "endpoint_url": "http://invalid_url"},
    )
    for bad_credentials in rejected_variants:
        with pytest.raises(CredentialsValidateFailedError):
            llm.validate_credentials(model=model_name, credentials=bad_credentials)

    # The unmodified credentials must validate without raising.
    llm.validate_credentials(model=model_name, credentials=valid_credentials)
|
||||
|
||||
|
||||
def test_invoke_model():
    """A non-streaming invocation returns an LLMResult with non-empty message content."""
    llm = VesslAILargeLanguageModel()

    credentials = {
        "api_key": os.environ.get("VESSL_AI_API_KEY"),
        "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
        "mode": "chat",
    }
    conversation = [
        SystemPromptMessage(content="You are a helpful AI assistant."),
        UserPromptMessage(content="Who are you?"),
    ]
    sampling = {"temperature": 1.0, "top_k": 2, "top_p": 0.5}

    result = llm.invoke(
        model=os.environ.get("VESSL_AI_MODEL_NAME"),
        credentials=credentials,
        prompt_messages=conversation,
        model_parameters=sampling,
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
|
||||
|
||||
|
||||
def test_invoke_stream_model():
    """A streaming invocation returns a generator whose items are assistant-message chunks."""
    llm = VesslAILargeLanguageModel()

    credentials = {
        "api_key": os.environ.get("VESSL_AI_API_KEY"),
        "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
        "mode": "chat",
    }
    conversation = [
        SystemPromptMessage(content="You are a helpful AI assistant."),
        UserPromptMessage(content="Who are you?"),
    ]
    sampling = {"temperature": 1.0, "top_k": 2, "top_p": 0.5}

    stream = llm.invoke(
        model=os.environ.get("VESSL_AI_MODEL_NAME"),
        credentials=credentials,
        prompt_messages=conversation,
        model_parameters=sampling,
        stop=["How"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)

    # Every streamed item must have the chunk -> delta -> assistant message shape.
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
|
||||
|
||||
|
||||
def test_get_num_tokens():
    """Token counting for the fixed two-message prompt yields the integer 21."""
    llm = VesslAILargeLanguageModel()

    credentials = {
        "api_key": os.environ.get("VESSL_AI_API_KEY"),
        "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
    }
    conversation = [
        SystemPromptMessage(content="You are a helpful AI assistant."),
        UserPromptMessage(content="Hello World!"),
    ]

    token_count = llm.get_num_tokens(
        model=os.environ.get("VESSL_AI_MODEL_NAME"),
        credentials=credentials,
        prompt_messages=conversation,
    )

    assert isinstance(token_count, int)
    assert token_count == 21
|
||||
71
api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py
Normal file
71
api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py
Normal file
@ -0,0 +1,71 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import (
|
||||
OceanBaseVector,
|
||||
OceanBaseVectorConfig,
|
||||
)
|
||||
from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
AbstractVectorTest,
|
||||
get_example_text,
|
||||
setup_mock_redis,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def oceanbase_vector():
    """Provide an OceanBaseVector for ``dify_test_collection`` configured for a local instance."""
    connection = OceanBaseVectorConfig(
        host="127.0.0.1",
        port="2881",
        user="root@test",
        database="test",
        password="test",
    )
    return OceanBaseVector("dify_test_collection", config=connection)
|
||||
|
||||
|
||||
class OceanBaseVectorTest(AbstractVectorTest):
    """AbstractVectorTest specialisation with the expectations used for the OceanBase store."""

    def __init__(self, vector: OceanBaseVector):
        super().__init__()
        self.vector = vector

    def search_by_vector(self):
        """Vector search with the example embedding is expected to return no hits."""
        vector_hits = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(vector_hits) == 0

    def search_by_full_text(self):
        """Full-text search with the example text is expected to return no hits."""
        text_hits = self.vector.search_by_full_text(query=get_example_text())
        assert len(text_hits) == 0

    def text_exists(self):
        """The example document id is expected to be reported as existing."""
        found = self.vector.text_exists(self.example_doc_id)
        assert found == True  # noqa: E712 - equality check kept as originally written

    def get_ids_by_metadata_field(self):
        """Lookup by ``document_id`` metadata is expected to return no ids."""
        matching_ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
        assert len(matching_ids) == 0
|
||||
|
||||
|
||||
@pytest.fixture
def setup_mock_oceanbase_client():
    """Replace ``ObVecClient`` in the oceanbase_vector module with a MagicMock for the test's duration."""
    patcher = patch("core.rag.datasource.vdb.oceanbase.oceanbase_vector.ObVecClient", new_callable=MagicMock)
    fake_client = patcher.start()
    try:
        yield fake_client
    finally:
        # Always restore the real class, even if the test body failed.
        patcher.stop()
|
||||
|
||||
|
||||
@pytest.fixture
def setup_mock_oceanbase_vector(oceanbase_vector):
    """Yield the ``oceanbase_vector`` fixture with its ``_client`` attribute patched out."""
    patcher = patch.object(oceanbase_vector, "_client")
    patcher.start()
    try:
        yield oceanbase_vector
    finally:
        # Put the original attribute back once the test is done.
        patcher.stop()
|
||||
|
||||
|
||||
def test_oceanbase_vector(
    setup_mock_redis,
    setup_mock_oceanbase_client,
    setup_mock_oceanbase_vector,
    oceanbase_vector,
):
    """Run the whole AbstractVectorTest suite against the mocked OceanBase vector store."""
    suite = OceanBaseVectorTest(oceanbase_vector)
    suite.run_all_tests()
|
||||
@ -430,3 +430,37 @@ def test_multi_colons_parse(setup_http_mock):
|
||||
assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "")
|
||||
assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "")
|
||||
# assert "http://example3.com" == resp.get("headers", {}).get("referer")
|
||||
|
||||
|
||||
def test_image_file(monkeypatch):
    """A GET request for an image URL yields exactly one entry under ``files`` in the node outputs."""
    from types import SimpleNamespace

    def fake_create_file_by_raw(*args, **kwargs):
        # Stand-in for ToolFileManager.create_file_by_raw: ignores its arguments
        # and returns an object exposing only ``id``.
        return SimpleNamespace(id="1")

    monkeypatch.setattr(
        "core.tools.tool_file_manager.ToolFileManager.create_file_by_raw",
        fake_create_file_by_raw,
    )

    node_config = {
        "id": "1",
        "data": {
            "title": "http",
            "desc": "",
            "method": "get",
            "url": "https://cloud.dify.ai/logo/logo-site.png",
            "authorization": {
                "type": "no-auth",
                "config": None,
            },
            "params": "",
            "headers": "",
            "body": None,
        },
    }
    node = init_http_node(config=node_config)

    result = node._run()
    assert result.process_data is not None
    assert result.outputs is not None

    outputs = result.outputs
    assert len(outputs.get("files", [])) == 1
|
||||
|
||||
@ -22,17 +22,3 @@ from controllers.console.version import _has_new_version
|
||||
)
|
||||
def test_has_new_version(latest_version, current_version, expected):
|
||||
assert _has_new_version(latest_version=latest_version, current_version=current_version) == expected
|
||||
|
||||
|
||||
def test_has_new_version_invalid_input():
    """``_has_new_version`` raises ValueError for each of the listed (latest, current) pairs."""
    rejected_pairs = (
        ("1.0", "1.0.0"),
        ("1.0.0", "1.0"),
        ("invalid", "1.0.0"),
        ("1.0.0", "invalid"),
    )
    for latest, current in rejected_pairs:
        with pytest.raises(ValueError):
            _has_new_version(latest_version=latest, current_version=current)
|
||||
|
||||
125
api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
Normal file
125
api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
Normal file
@ -0,0 +1,125 @@
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end import EndStreamParam
|
||||
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class TestLLMNode:
    """Tests for ``LLMNode._fetch_files`` against different values stored at ``["sys", "files"]``."""

    @pytest.fixture
    def llm_node(self):
        """Build an LLMNode with vision enabled on ``["sys", "files"]`` and an initially empty variable pool."""
        vision = VisionConfig(
            enabled=True,
            configs=VisionConfigOptions(
                variable_selector=["sys", "files"],
                detail=ImagePromptMessageContent.DETAIL.HIGH,
            ),
        )
        node_data = LLMNodeData(
            title="Test LLM",
            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
            prompt_template=[],
            memory=None,
            context=ContextConfig(enabled=False),
            vision=vision,
        )
        pool = VariablePool(
            system_variables={},
            user_inputs={},
        )

        node_config = {
            "id": "1",
            "data": node_data.model_dump(),
        }
        init_params = GraphInitParams(
            tenant_id="1",
            app_id="1",
            workflow_type=WorkflowType.WORKFLOW,
            workflow_id="1",
            graph_config={},
            user_id="1",
            user_from=UserFrom.ACCOUNT,
            invoke_from=InvokeFrom.SERVICE_API,
            call_depth=0,
        )
        graph = Graph(
            root_node_id="1",
            answer_stream_generate_routes=AnswerStreamGenerateRoute(
                answer_dependencies={},
                answer_generate_route={},
            ),
            end_stream_param=EndStreamParam(
                end_dependencies={},
                end_stream_variable_selector_mapping={},
            ),
        )
        runtime_state = GraphRuntimeState(
            variable_pool=pool,
            start_at=0,
        )

        return LLMNode(
            id="1",
            config=node_config,
            graph_init_params=init_params,
            graph=graph,
            graph_runtime_state=runtime_state,
        )

    def test_fetch_files_with_file_segment(self, llm_node):
        """A single File stored at the selector comes back as a one-element list."""
        selector = ["sys", "files"]
        image = File(
            id="1",
            tenant_id="test",
            type=FileType.IMAGE,
            filename="test.jpg",
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="1",
        )
        llm_node.graph_runtime_state.variable_pool.add(selector, image)

        fetched = llm_node._fetch_files(selector=["sys", "files"])
        assert fetched == [image]

    def test_fetch_files_with_array_file_segment(self, llm_node):
        """An ArrayFileSegment stored at the selector comes back as the same list of files."""
        first_image = File(
            id="1",
            tenant_id="test",
            type=FileType.IMAGE,
            filename="test1.jpg",
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="1",
        )
        second_image = File(
            id="2",
            tenant_id="test",
            type=FileType.IMAGE,
            filename="test2.jpg",
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="2",
        )
        images = [first_image, second_image]
        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=images))

        fetched = llm_node._fetch_files(selector=["sys", "files"])
        assert fetched == images

    def test_fetch_files_with_none_segment(self, llm_node):
        """A NoneSegment stored at the selector yields an empty list."""
        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())

        fetched = llm_node._fetch_files(selector=["sys", "files"])
        assert fetched == []

    def test_fetch_files_with_array_any_segment(self, llm_node):
        """An empty ArrayAnySegment stored at the selector yields an empty list."""
        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))

        fetched = llm_node._fetch_files(selector=["sys", "files"])
        assert fetched == []

    def test_fetch_files_with_non_existent_variable(self, llm_node):
        """With nothing stored at the selector the result is an empty list."""
        fetched = llm_node._fetch_files(selector=["sys", "files"])
        assert fetched == []
|
||||
100
api/tests/unit_tests/oss/__mock/aliyun_oss.py
Normal file
100
api/tests/unit_tests/oss/__mock/aliyun_oss.py
Normal file
@ -0,0 +1,100 @@
|
||||
import os
|
||||
import posixpath
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from oss2 import Bucket
|
||||
from oss2.models import GetObjectResult, PutObjectResult
|
||||
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
get_example_bucket,
|
||||
get_example_data,
|
||||
get_example_filename,
|
||||
get_example_filepath,
|
||||
get_example_folder,
|
||||
)
|
||||
|
||||
|
||||
class MockResponse:
    """Plain value holder carrying ``status``, ``headers`` and ``request_id`` as given."""

    def __init__(self, status, headers, request_id):
        # Store the three values unchanged under the same attribute names.
        self.status, self.headers, self.request_id = status, headers, request_id
|
||||
|
||||
|
||||
class MockAliyunOssClass:
    """Replacement methods for ``oss2.Bucket`` that check their arguments against the shared example fixtures."""

    def __init__(
        self,
        auth,
        endpoint,
        bucket_name,
        is_cname=False,
        session=None,
        connect_timeout=None,
        app_name="",
        enable_crc=True,
        proxies=None,
        region=None,
        cloudbox_id=None,
        is_path_style=False,
        is_verify_object_strict=True,
    ):
        # All constructor arguments are ignored; state comes solely from the example helpers.
        canned_headers = {
            "etag": "ee8de918d05640145b18f70f4c3aa602",
            "x-oss-version-id": "CAEQNhiBgMDJgZCA0BYiIDc4MGZjZGI2OTBjOTRmNTE5NmU5NmFhZjhjYmY0****",
        }
        self.bucket_name = get_example_bucket()
        self.key = posixpath.join(get_example_folder(), get_example_filename())
        self.content = get_example_data()
        self.filepath = get_example_filepath()
        self.resp = MockResponse(200, canned_headers, "request_id")

    def put_object(self, key, data, headers=None, progress_callback=None):
        """Check that the example key and data are being stored, then return a PutObjectResult."""
        assert key == self.key
        assert data == self.content
        return PutObjectResult(self.resp)

    def get_object(self, key, byte_range=None, headers=None, progress_callback=None, process=None, params=None):
        """Check the key and return a GetObjectResult-spec'd mock whose ``read()`` gives the example data."""
        assert key == self.key

        fake_result = MagicMock(GetObjectResult)
        fake_result.read.return_value = self.content
        return fake_result

    def get_object_to_file(
        self, key, filename, byte_range=None, headers=None, progress_callback=None, process=None, params=None
    ):
        """Check that the example key is requested into the example file path; nothing is written."""
        assert key == self.key
        assert filename == self.filepath

    def object_exists(self, key, headers=None):
        """Check the key and report the object as existing."""
        assert key == self.key
        return True

    def delete_object(self, key, params=None, headers=None):
        """Check the key, flag the canned response with a delete marker and return it."""
        assert key == self.key
        self.resp.headers["x-oss-delete-marker"] = True
        return self.resp
|
||||
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
def setup_aliyun_oss_mock(monkeypatch: MonkeyPatch):
    """When ``MOCK`` is true, swap the listed ``Bucket`` methods for their MockAliyunOssClass versions."""
    replaced_methods = (
        "__init__",
        "put_object",
        "get_object",
        "get_object_to_file",
        "object_exists",
        "delete_object",
    )
    if MOCK:
        for method_name in replaced_methods:
            monkeypatch.setattr(Bucket, method_name, getattr(MockAliyunOssClass, method_name))

    yield

    # Explicitly roll the patches back once the test has finished.
    if MOCK:
        monkeypatch.undo()
|
||||
@ -0,0 +1,22 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from oss2 import Auth
|
||||
|
||||
from extensions.storage.aliyun_oss_storage import AliyunOssStorage
|
||||
from tests.unit_tests.oss.__mock.aliyun_oss import setup_aliyun_oss_mock
|
||||
from tests.unit_tests.oss.__mock.base import (
|
||||
BaseStorageTest,
|
||||
get_example_bucket,
|
||||
get_example_folder,
|
||||
)
|
||||
|
||||
|
||||
class TestAliyunOss(BaseStorageTest):
    """Runs the BaseStorageTest cases against AliyunOssStorage."""

    @pytest.fixture(autouse=True)
    def setup_method(self, setup_aliyun_oss_mock):
        """Executed before each test method."""
        # Auth.__init__ is stubbed out while the storage object is being built.
        with patch.object(Auth, "__init__", return_value=None):
            storage = AliyunOssStorage()
            self.storage = storage
            storage.bucket_name = get_example_bucket()
            storage.folder = get_example_folder()
|
||||
@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
from packaging import version
|
||||
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_dsl_service.exc import DSLVersionNotSupportedError
|
||||
from services.app_dsl_service.service import _check_or_fix_dsl, current_dsl_version
|
||||
|
||||
|
||||
class TestAppDSLService:
    """Tests for the module-level ``_check_or_fix_dsl`` helper."""

    def test_check_or_fix_dsl_missing_version(self):
        """An empty document gets the default version and kind."""
        fixed = _check_or_fix_dsl({})
        assert fixed["version"] == "0.1.0"
        assert fixed["kind"] == "app"

    def test_check_or_fix_dsl_missing_kind(self):
        """A document without ``kind`` gets kind "app"."""
        fixed = _check_or_fix_dsl({"version": "0.1.0"})
        assert fixed["kind"] == "app"

    def test_check_or_fix_dsl_older_version(self):
        """An older version is accepted and left unchanged."""
        document = {"version": "0.0.9", "kind": "app"}
        fixed = _check_or_fix_dsl(document)
        assert fixed["version"] == "0.0.9"

    def test_check_or_fix_dsl_current_version(self):
        """The current version is accepted and left unchanged."""
        document = {"version": current_dsl_version, "kind": "app"}
        fixed = _check_or_fix_dsl(document)
        assert fixed["version"] == current_dsl_version

    def test_check_or_fix_dsl_newer_version(self):
        """A version one minor release ahead of the current one is rejected."""
        parsed_current = version.parse(current_dsl_version)
        next_minor = parsed_current.minor + 1
        newer_version = f"{parsed_current.major}.{next_minor}.0"
        document = {"version": newer_version, "kind": "app"}
        with pytest.raises(DSLVersionNotSupportedError):
            _check_or_fix_dsl(document)

    def test_check_or_fix_dsl_invalid_kind(self):
        """A kind other than "app" is overwritten with "app"."""
        document = {"version": current_dsl_version, "kind": "invalid"}
        fixed = _check_or_fix_dsl(document)
        assert fixed["kind"] == "app"
|
||||
Reference in New Issue
Block a user