Mirror of https://github.com/langgenius/dify.git (synced 2026-02-04 18:57:51 +08:00)
Compare commits

346 Commits

fix/external-knowledge-retrieval-issues ... release/0.10.0-beta
| SHA1 | Author | Date | |
|---|---|---|---|
| 24d06eb741 | |||
| a36ef8430e | |||
| 6e9129a333 | |||
| 86427468fd | |||
| 638d9250e4 | |||
| dba67cd87a | |||
| c3811c4c23 | |||
| ff7a467a1a | |||
| c73b5e2410 | |||
| 421fde0e85 | |||
| b3fdd618a1 | |||
| 81e57bcc73 | |||
| 7156753a99 | |||
| 0d310b503b | |||
| 03018823d8 | |||
| 885192db38 | |||
| ea18dd1571 | |||
| 7838f9f3a3 | |||
| 8fe5028f74 | |||
| de3c5751db | |||
| 5ee7e03c1b | |||
| 7a405b86c9 | |||
| ffc3f33670 | |||
| 857055b797 | |||
| d15ba3939d | |||
| 1ec83e4969 | |||
| 9275760599 | |||
| d97d3ff5fc | |||
| ea6734f550 | |||
| f73751843f | |||
| dbfbc56de7 | |||
| 70c5b23089 | |||
| 2ec6ffe478 | |||
| 793205afc5 | |||
| c6b74daa0a | |||
| 23ce1fb1ba | |||
| 29188e0562 | |||
| d9773c963f | |||
| 1206b1eb96 | |||
| 93af87a9e0 | |||
| 7a6970e570 | |||
| ea584e94bd | |||
| 42b02b3a5f | |||
| 44f6a536d2 | |||
| d7b8e071dd | |||
| 3f1aa1f9e2 | |||
| 82024a65cd | |||
| 9f492d527a | |||
| faec1c50f9 | |||
| 2888c068c1 | |||
| d77521e65f | |||
| b463468a14 | |||
| 820076beec | |||
| fd6476b941 | |||
| 7abe76f07e | |||
| dba24766e8 | |||
| 12492c0d5d | |||
| 2740d68cd1 | |||
| 7218f718ad | |||
| 2ae0aceef8 | |||
| a5c1132087 | |||
| 691a7f727a | |||
| 7e2984b6e2 | |||
| 44e81dbbc8 | |||
| 944cfd2b68 | |||
| 6d2682c751 | |||
| d2971e84bb | |||
| b67b81bf8f | |||
| c05902404d | |||
| 1e6d5f2c48 | |||
| e2b1464db2 | |||
| f0285a53d2 | |||
| 00f91b5dc4 | |||
| dc3e86b82a | |||
| d239c5b54d | |||
| 23abccd3a6 | |||
| 2520e40059 | |||
| 1d4ed3d9e7 | |||
| eed8ab9348 | |||
| 112aaf6e1b | |||
| 094a1a1458 | |||
| 955fa4345a | |||
| ac5e381a1a | |||
| ae9b9f867a | |||
| 8fd04e5313 | |||
| 3904782647 | |||
| 288be3fbd8 | |||
| f7f836d6f1 | |||
| 5dedcb74a5 | |||
| b95d0fa9a9 | |||
| 543503c398 | |||
| 3f16caf244 | |||
| 54133dfbde | |||
| b491c93b1c | |||
| 2a6d9c3211 | |||
| c6691bd297 | |||
| 2a0b30de5c | |||
| a7d53abba9 | |||
| 296253a365 | |||
| c89cefe526 | |||
| 1d027fa065 | |||
| 9ce9a52a86 | |||
| c74424ed85 | |||
| 719ef9cef9 | |||
| 0ab525a691 | |||
| 6fdcf6ee21 | |||
| d01e97c1fc | |||
| 87e560de8a | |||
| f8d26e46ac | |||
| 195ac19774 | |||
| 0281eb796d | |||
| 9fe2f321ae | |||
| 5f76e665a1 | |||
| 81568752c0 | |||
| ceb1dde714 | |||
| 3209fdca53 | |||
| dc5010d833 | |||
| 8b26ae6532 | |||
| 66953d57a2 | |||
| afc9630cd0 | |||
| 7e8bafe186 | |||
| 6c5fcd1ffc | |||
| 7602d22133 | |||
| 5ec91e8507 | |||
| 466966f027 | |||
| 212d04ea27 | |||
| 0cb50dd4a5 | |||
| ab19fccf3d | |||
| 4ed46e3fed | |||
| 9fd2f798ff | |||
| 146be41b1d | |||
| ce6ae5732a | |||
| edf462c640 | |||
| d580fc1e9d | |||
| 5544791031 | |||
| 099746dd59 | |||
| c6f53c9030 | |||
| 8236f8fed8 | |||
| 2b0c39ed3f | |||
| 396a240e68 | |||
| 8bd9d8f6ba | |||
| aa7ae4c5f1 | |||
| 49b7acf52e | |||
| 466ac987f5 | |||
| 49972939a9 | |||
| 80f167ca02 | |||
| f652ae0d98 | |||
| 4dbf56675a | |||
| f5d1f5a20a | |||
| fd9b71c4d7 | |||
| 1df41cef4c | |||
| 602d2486bd | |||
| 403fede432 | |||
| 9f66e6e357 | |||
| affb2e38a1 | |||
| 31d87f85b8 | |||
| 54105e85ff | |||
| 5ec604500c | |||
| 96d2582d89 | |||
| a10b0db102 | |||
| 5dd556b4c8 | |||
| a4c6d0b94b | |||
| 323a835de9 | |||
| 0076577764 | |||
| 9a3b7345c4 | |||
| 2ebf5f5ffa | |||
| 02f494c0de | |||
| f0e81e3918 | |||
| aa8499efac | |||
| ea40b1dcb2 | |||
| a689cd6fd4 | |||
| 32b6c7063a | |||
| 97056dad30 | |||
| 264f7c2139 | |||
| 007a6fd14a | |||
| c159b7a781 | |||
| 6c9c3faf78 | |||
| d933ebb845 | |||
| b60c7a5826 | |||
| 0b94218378 | |||
| 97cc9a5615 | |||
| f6d0fd9848 | |||
| b863dd7de2 | |||
| b0e7a22a27 | |||
| 565a835947 | |||
| fe94c876fb | |||
| 67a34bdd7a | |||
| 8c785e268b | |||
| 65a6265ff6 | |||
| 08d3cb1912 | |||
| 48d8b01d81 | |||
| 38edb06897 | |||
| dc919c2a6c | |||
| e7a6a0ab01 | |||
| 61d989f413 | |||
| 976efd93a1 | |||
| 0e2f78b3a6 | |||
| b3529d3ccc | |||
| d69b453729 | |||
| 2f658de155 | |||
| a691700b48 | |||
| c5317d8f58 | |||
| 822f03f3cd | |||
| 101e56baaa | |||
| 3a8f516dfc | |||
| 912030c9a1 | |||
| 687661eef7 | |||
| 8efc63a705 | |||
| dca4f9fe9c | |||
| 51597629b1 | |||
| 76a07513ba | |||
| dae62bef78 | |||
| 2a6629d435 | |||
| 41f0ce1012 | |||
| e90b055c47 | |||
| 94e40d4ed9 | |||
| c34fc071e0 | |||
| c014ae43e1 | |||
| 9851153d38 | |||
| cfbabb8383 | |||
| b78e90679d | |||
| ec1bfdc723 | |||
| e20019f6e9 | |||
| 2122cfb152 | |||
| c2b8beffac | |||
| 985651454a | |||
| f9c1d06e91 | |||
| 657f1d2de8 | |||
| 6e2192c1e0 | |||
| e05b20eb91 | |||
| 5117e08def | |||
| 34691ca6c9 | |||
| aa40047b08 | |||
| eca17767fe | |||
| 51cec1b9ba | |||
| 651547c3ef | |||
| 8fbdaa604c | |||
| 1bcb30647f | |||
| bc245a25bf | |||
| 85b25ebe1b | |||
| b50e94d681 | |||
| 91c0657cf6 | |||
| 0da06128e3 | |||
| 0c4af3a1d2 | |||
| 5628b293f8 | |||
| fff40aae58 | |||
| b3b87b3e4c | |||
| 9a23cd08d8 | |||
| cf61ca24e3 | |||
| 58a56add9c | |||
| b362031baf | |||
| 7ad409b3d9 | |||
| 876ea90fe9 | |||
| 0eb442f954 | |||
| 4554ac3ef8 | |||
| eaa7d114dc | |||
| 581228be74 | |||
| 02da0219ff | |||
| d0bbe43dab | |||
| 16acdc9be4 | |||
| a6999b5d02 | |||
| 33bfa4758e | |||
| db63c2c219 | |||
| bea4ec5998 | |||
| 74333db4c8 | |||
| 0019fb9f8b | |||
| 47615ac8fb | |||
| d7c8bced9b | |||
| 57f178902f | |||
| 4586de48d6 | |||
| 6549519fa5 | |||
| ae098ad121 | |||
| 20922fde1c | |||
| 079c802b5c | |||
| efcd462a69 | |||
| 843c8ad306 | |||
| 594bf96922 | |||
| ade385c9c1 | |||
| baed068231 | |||
| 42f5334ae4 | |||
| 3c4ab0632d | |||
| bc5f109308 | |||
| 97b2a42cc3 | |||
| 939df16655 | |||
| 9362ae045c | |||
| 257c515178 | |||
| 6b7520ccc2 | |||
| 85eeaee95a | |||
| 99bf3ff565 | |||
| 36ae154ca2 | |||
| ef93d60534 | |||
| 6c9a6b99e0 | |||
| b73f05fdf0 | |||
| 26bca75884 | |||
| e2962da1b8 | |||
| 1b9ebb8037 | |||
| a945a45b06 | |||
| be829a8103 | |||
| 9432d41e60 | |||
| 0beeb4ab3e | |||
| d7e057be44 | |||
| 81b11c08d0 | |||
| 83a5cdfff9 | |||
| c837218bc9 | |||
| 68552893ef | |||
| 5ba93ed064 | |||
| 959107f553 | |||
| 443d929137 | |||
| 1e04418023 | |||
| aeda8869bc | |||
| 10eed02ec4 | |||
| 2472c4f890 | |||
| 0455e4e1a5 | |||
| 251ab5418f | |||
| 38e6e40900 | |||
| b3a3672857 | |||
| 53a3c199ec | |||
| fca5af5073 | |||
| 77d0aac1d3 | |||
| fd0f8f33b5 | |||
| 0be99ad01c | |||
| a05d16375e | |||
| 0480bb03c3 | |||
| 19dfc6d9a8 | |||
| d361675159 | |||
| 23ae150298 | |||
| 81383d7c74 | |||
| 573f653789 | |||
| f1b61861b6 | |||
| 8ecee8abce | |||
| e9ce9c1f47 | |||
| 944fea4cc9 | |||
| 25c029877a | |||
| 9c31c56115 | |||
| 56507c9f7a | |||
| b322dda3f6 | |||
| 52d69dd55b | |||
| 0451c5590c | |||
| 2498c238b2 | |||
| 6e15d7f777 | |||
| f6caf0915b | |||
| 09aa14ca82 | |||
| 394f06a27a | |||
| 6fafd410d2 | |||
| 1668df104f | |||
| d376b8540e |
.github/workflows/build-push.yml (4 changes, vendored)

@@ -5,7 +5,7 @@ on:
     branches:
       - "main"
       - "deploy/dev"
-      - "fix/external-knowledge-retrieval-issues"
+      - "release/0.10.0-beta"
   release:
     types: [published]

@@ -126,7 +126,7 @@ jobs:
       with:
         images: ${{ env[matrix.image_name_env] }}
         tags: |
-          type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-beta') }}
+          type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-') }}
           type=ref,event=branch
           type=sha,enable=true,priority=100,prefix=,suffix=,format=long
           type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
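The only functional change in the tag metadata is the `latest` gate: previously any tag not containing `-beta` received `latest`; now any tag containing a `-` (so `-beta3`, `-rc1`, `-fix1`, and the like) is excluded, and only plain release tags are tagged `latest`. A minimal sketch of the predicate in Python, mirroring the Actions expression (this helper is illustrative, not part of the workflow file):

```python
def should_tag_latest(github_ref: str) -> bool:
    """Mirrors: startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-')"""
    return github_ref.startswith("refs/tags/") and "-" not in github_ref


assert should_tag_latest("refs/tags/0.10.0")            # stable release gets `latest`
assert not should_tag_latest("refs/tags/0.10.0-beta3")  # pre-release does not
assert not should_tag_latest("refs/heads/main")         # branch pushes never do
```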
@@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001
 # The time in seconds after the signature is rejected
 FILES_ACCESS_TIMEOUT=300
 
+# Access token expiration time in minutes
+ACCESS_TOKEN_EXPIRE_MINUTES=60
+
 # celery configuration
 CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
 
@@ -39,7 +42,7 @@ DB_DATABASE=dify
 
 # Storage configuration
 # use for store upload files, private keys...
-# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs
+# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs, supabase
 STORAGE_TYPE=local
 STORAGE_LOCAL_PATH=storage
 S3_USE_AWS_MANAGED_IAM=false
@@ -99,11 +102,16 @@ VOLCENGINE_TOS_ACCESS_KEY=your-access-key
 VOLCENGINE_TOS_SECRET_KEY=your-secret-key
 VOLCENGINE_TOS_REGION=your-region
 
+# Supabase Storage Configuration
+SUPABASE_BUCKET_NAME=your-bucket-name
+SUPABASE_API_KEY=your-access-key
+SUPABASE_URL=your-server-url
+
 # CORS configuration
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
-# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
+# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb
 VECTOR_STORE=weaviate
 
 # Weaviate configuration
@@ -203,10 +211,30 @@ OPENSEARCH_USER=admin
 OPENSEARCH_PASSWORD=admin
 OPENSEARCH_SECURE=true
 
+# Baidu configuration
+BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
+BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
+BAIDU_VECTOR_DB_ACCOUNT=root
+BAIDU_VECTOR_DB_API_KEY=dify
+BAIDU_VECTOR_DB_DATABASE=dify
+BAIDU_VECTOR_DB_SHARD=1
+BAIDU_VECTOR_DB_REPLICAS=3
+
+# ViKingDB configuration
+VIKINGDB_ACCESS_KEY=your-ak
+VIKINGDB_SECRET_KEY=your-sk
+VIKINGDB_REGION=cn-shanghai
+VIKINGDB_HOST=api-vikingdb.xxx.volces.com
+VIKINGDB_SCHEMA=http
+VIKINGDB_CONNECTION_TIMEOUT=30
+VIKINGDB_SOCKET_TIMEOUT=30
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5
 UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
+UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
+UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 
 # Model Configuration
 MULTIMODAL_SEND_IMAGE_FORMAT=base64
@@ -284,6 +312,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
 WORKFLOW_MAX_EXECUTION_STEPS=500
 WORKFLOW_MAX_EXECUTION_TIME=1200
 WORKFLOW_CALL_MAX_DEPTH=5
+MAX_VARIABLE_SIZE=204800
 
 # App configuration
 APP_MAX_EXECUTION_TIME=1200
api/.vscode/launch.json.example (15 changes, vendored)

@@ -1,8 +1,15 @@
 {
     "version": "0.2.0",
+    "compounds": [
+        {
+            "name": "Launch Flask and Celery",
+            "configurations": ["Python: Flask", "Python: Celery"]
+        }
+    ],
     "configurations": [
         {
             "name": "Python: Flask",
+            "consoleName": "Flask",
             "type": "debugpy",
             "request": "launch",
             "python": "${workspaceFolder}/.venv/bin/python",
@@ -17,12 +24,12 @@
         },
         "args": [
             "run",
             "--host=0.0.0.0",
             "--port=5001"
         ]
     },
     {
         "name": "Python: Celery",
+        "consoleName": "Celery",
         "type": "debugpy",
         "request": "launch",
         "python": "${workspaceFolder}/.venv/bin/python",
@@ -45,10 +52,10 @@
             "-c",
             "1",
             "--loglevel",
-            "info",
+            "DEBUG",
             "-Q",
             "dataset,generation,mail,ops_trace,app_deletion"
         ]
-    },
+    }
     ]
 }
@@ -118,7 +118,7 @@ def create_app() -> Flask:
 
     logging.basicConfig(
         level=app.config.get("LOG_LEVEL"),
-        format=app.config.get("LOG_FORMAT"),
+        format=app.config["LOG_FORMAT"],
         datefmt=app.config.get("LOG_DATEFORMAT"),
         handlers=log_handlers,
         force=True,
@@ -135,6 +135,7 @@ def create_app() -> Flask:
         return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
 
     for handler in logging.root.handlers:
+        assert handler.formatter
         handler.formatter.converter = time_converter
     initialize_extensions(app)
     register_blueprints(app)
@@ -183,7 +184,7 @@ def load_user_from_request(request_from_flask_login):
         decoded = PassportService().verify(auth_token)
         user_id = decoded.get("user_id")
 
-        logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
+        logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
         if logged_in_account:
             contexts.tenant_id.set(logged_in_account.current_tenant_id)
         return logged_in_account
@@ -19,7 +19,7 @@ from extensions.ext_redis import redis_client
 from libs.helper import email as email_validate
 from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
-from models.account import Tenant
+from models import Tenant
 from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
@@ -347,6 +347,14 @@ def migrate_knowledge_vector_database():
                 index_name = Dataset.gen_collection_name_by_id(dataset_id)
                 index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
                 dataset.index_struct = json.dumps(index_struct_dict)
+            elif vector_type == VectorType.BAIDU:
+                dataset_id = dataset.id
+                collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                index_struct_dict = {
+                    "type": VectorType.BAIDU,
+                    "vector_store": {"class_prefix": collection_name},
+                }
+                dataset.index_struct = json.dumps(index_struct_dict)
             else:
                 raise ValueError(f"Vector store {vector_type} is not supported.")
 
@@ -449,14 +457,14 @@ def convert_to_agent_apps():
     # fetch first 1000 apps
     sql_query = """SELECT a.id AS id FROM apps a
         INNER JOIN app_model_configs am ON a.app_model_config_id=am.id
-        WHERE a.mode = 'chat' 
-        AND am.agent_mode is not null 
+        WHERE a.mode = 'chat'
+        AND am.agent_mode is not null
         AND (
-            am.agent_mode like '%"strategy": "function_call"%' 
+            am.agent_mode like '%"strategy": "function_call"%'
             OR am.agent_mode like '%"strategy": "react"%'
-        ) 
+        )
         AND (
-            am.agent_mode like '{"enabled": true%' 
+            am.agent_mode like '{"enabled": true%'
             OR am.agent_mode like '{"max_iteration": %'
         ) ORDER BY a.created_at DESC LIMIT 1000
     """
@@ -1,4 +1,4 @@
-from typing import Annotated, Optional
+from typing import Annotated, Literal, Optional
 
 from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
 from pydantic_settings import BaseSettings
@@ -11,11 +11,11 @@ class SecurityConfig(BaseSettings):
     Security-related configurations for the application
     """
 
-    SECRET_KEY: Optional[str] = Field(
+    SECRET_KEY: str = Field(
         description="Secret key for secure session cookie signing."
         "Make sure you are changing this key for your deployment with a strong key."
         "Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
-        default=None,
+        default="",
     )
 
     RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
@@ -177,6 +177,16 @@ class FileUploadConfig(BaseSettings):
         default=10,
     )
 
+    UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
+        description="video file size limit in Megabytes for uploading files",
+        default=100,
+    )
+
+    UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
+        description="audio file size limit in Megabytes for uploading files",
+        default=50,
+    )
+
     BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
         description="Maximum number of files allowed in a batch upload operation",
         default=20,
@@ -355,14 +365,14 @@ class WorkflowConfig(BaseSettings):
     )
 
     MAX_VARIABLE_SIZE: PositiveInt = Field(
-        description="Maximum size in bytes for a single variable in workflows. Default to 5KB.",
-        default=5 * 1024,
+        description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
+        default=200 * 1024,
     )
 
 
-class OAuthConfig(BaseSettings):
+class AuthConfig(BaseSettings):
     """
-    Configuration for OAuth authentication
+    Configuration for authentication and OAuth
     """
 
     OAUTH_REDIRECT_PATH: str = Field(
@@ -371,7 +381,7 @@ class OAuthConfig(BaseSettings):
     )
 
     GITHUB_CLIENT_ID: Optional[str] = Field(
-        description="GitHub OAuth client secret",
+        description="GitHub OAuth client ID",
         default=None,
     )
 
@@ -390,6 +400,11 @@ class OAuthConfig(BaseSettings):
         default=None,
     )
 
+    ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field(
+        description="Expiration time for access tokens in minutes",
+        default=60,
+    )
+
 
 class ModerationConfig(BaseSettings):
     """
@@ -474,6 +489,7 @@ class RagEtlConfig(BaseSettings):
     Configuration for RAG ETL processes
     """
 
+    # TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config
     ETL_TYPE: str = Field(
         description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
         default="dify",
@@ -535,7 +551,7 @@ class IndexingConfig(BaseSettings):
 
 
 class ImageFormatConfig(BaseSettings):
-    MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
+    MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
         description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
         default="base64",
     )
@@ -607,6 +623,7 @@ class PositionConfig(BaseSettings):
 class FeatureConfig(
     # place the configs in alphabet order
     AppExecutionConfig,
+    AuthConfig,  # Changed from OAuthConfig to AuthConfig
     BillingConfig,
     CodeExecutionSandboxConfig,
     DataSetConfig,
@@ -621,14 +638,13 @@ class FeatureConfig(
     MailConfig,
     ModelLoadBalanceConfig,
     ModerationConfig,
-    OAuthConfig,
+    PositionConfig,
     RagEtlConfig,
     SecurityConfig,
     ToolConfig,
     UpdateConfig,
     WorkflowConfig,
     WorkspaceConfig,
-    PositionConfig,
     # hosted services config
     HostedServiceConfig,
     CeleryBeatConfig,
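Typing `MULTIMODAL_SEND_IMAGE_FORMAT` as `Literal["base64", "url"]` instead of `str` moves validation to configuration load: a typo in the environment variable now fails fast rather than flowing silently into image handling. A small standalone sketch of the pattern (the demo class name is hypothetical):

```python
from typing import Literal

from pydantic import ValidationError
from pydantic_settings import BaseSettings


class DemoImageFormatConfig(BaseSettings):
    MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = "base64"


DemoImageFormatConfig(MULTIMODAL_SEND_IMAGE_FORMAT="url")  # accepted

try:
    DemoImageFormatConfig(MULTIMODAL_SEND_IMAGE_FORMAT="hex")  # rejected at load time
except ValidationError as exc:
    print(exc)
```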
@@ -12,6 +12,7 @@ from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageConfig
 from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
 from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
 from configs.middleware.storage.oci_storage_config import OCIStorageConfig
+from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig
 from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
 from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
 from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
@@ -27,6 +28,7 @@ from configs.middleware.vdb.qdrant_config import QdrantConfig
 from configs.middleware.vdb.relyt_config import RelytConfig
 from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
 from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
+from configs.middleware.vdb.vikingdb_config import VikingDBConfig
 from configs.middleware.vdb.weaviate_config import WeaviateConfig
 
 
@@ -222,6 +224,7 @@ class MiddlewareConfig(
     HuaweiCloudOBSStorageConfig,
     OCIStorageConfig,
     S3StorageConfig,
+    SupabaseStorageConfig,
     TencentCloudCOSStorageConfig,
     VolcengineTOSStorageConfig,
     # configs of vdb and vdb providers
@@ -241,5 +244,6 @@ class MiddlewareConfig(
     WeaviateConfig,
     ElasticsearchConfig,
     InternalTestConfig,
+    VikingDBConfig,
 ):
     pass
api/configs/middleware/storage/supabase_storage_config.py (new file, 24 lines)

from typing import Optional

from pydantic import BaseModel, Field


class SupabaseStorageConfig(BaseModel):
    """
    Configuration settings for Supabase Object Storage Service
    """

    SUPABASE_BUCKET_NAME: Optional[str] = Field(
        description="Name of the Supabase bucket to store and retrieve objects (e.g., 'dify-bucket')",
        default=None,
    )

    SUPABASE_API_KEY: Optional[str] = Field(
        description="API KEY for authenticating with Supabase",
        default=None,
    )

    SUPABASE_URL: Optional[str] = Field(
        description="URL of the Supabase",
        default=None,
    )
api/configs/middleware/vdb/baidu_vector_config.py (new file, 45 lines)

from typing import Optional

from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings


class BaiduVectorDBConfig(BaseSettings):
    """
    Configuration settings for Baidu Vector Database
    """

    BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
        description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
        default=None,
    )

    BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
        description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
        default=30000,
    )

    BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
        description="Account for authenticating with the Baidu Vector Database",
        default=None,
    )

    BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
        description="API key for authenticating with the Baidu Vector Database service",
        default=None,
    )

    BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
        description="Name of the specific Baidu Vector Database to connect to",
        default=None,
    )

    BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
        description="Number of shards for the Baidu Vector Database (default is 1)",
        default=1,
    )

    BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
        description="Number of replicas for the Baidu Vector Database (default is 3)",
        default=3,
    )
api/configs/middleware/vdb/vikingdb_config.py (new file, 37 lines)

from typing import Optional

from pydantic import BaseModel, Field


class VikingDBConfig(BaseModel):
    """
    Configuration for connecting to Volcengine VikingDB.
    Refer to the following documentation for details on obtaining credentials:
    https://www.volcengine.com/docs/6291/65568
    """

    VIKINGDB_ACCESS_KEY: Optional[str] = Field(
        default=None, description="The Access Key provided by Volcengine VikingDB for API authentication."
    )
    VIKINGDB_SECRET_KEY: Optional[str] = Field(
        default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication."
    )
    VIKINGDB_REGION: Optional[str] = Field(
        default="cn-shanghai",
        description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
    )
    VIKINGDB_HOST: Optional[str] = Field(
        default="api-vikingdb.mlp.cn-shanghai.volces.com",
        description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
            'api-vikingdb.mlp.cn-shanghai.volces.com')",
    )
    VIKINGDB_SCHEME: Optional[str] = Field(
        default="http",
        description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
    )
    VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(
        default=30, description="The connection timeout of the Volcengine VikingDB service."
    )
    VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field(
        default=30, description="The socket timeout of the Volcengine VikingDB service."
    )
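These config classes are plain pydantic models whose field names double as environment variable names once they are mixed into the settings hierarchy, and pydantic-settings fills them from the process environment. A minimal standalone sketch of that mechanism (the demo class name is hypothetical):

```python
import os
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class DemoVikingDBConfig(BaseSettings):
    # The field name is also the environment variable that overrides the default.
    VIKINGDB_HOST: Optional[str] = Field(default="api-vikingdb.mlp.cn-shanghai.volces.com")
    VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(default=30)


os.environ["VIKINGDB_HOST"] = "api-vikingdb.volces.com"
print(DemoVikingDBConfig().VIKINGDB_HOST)  # -> api-vikingdb.volces.com
```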
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.9.1-fix1",
+        default="0.10.0-beta3",
     )
 
     COMMIT_SHA: str = Field(
@@ -1,2 +1,21 @@
+from configs import dify_config
+
 HIDDEN_VALUE = "[__HIDDEN__]"
 UUID_NIL = "00000000-0000-0000-0000-000000000000"
+
+IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
+IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
+
+VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
+VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
+
+AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
+AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
+
+DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
+DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
+
+if dify_config.ETL_TYPE == "Unstructured":
+    DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
+    DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub"))
+    DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
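One subtlety in these extension lists: a comprehension like `[ext.upper() for ext in IMAGE_EXTENSIONS]` is fully evaluated before `extend` mutates the list, so each list simply ends up holding both lower- and upper-case variants. A quick standalone illustration:

```python
exts = ["jpg", "png"]
exts.extend([ext.upper() for ext in exts])  # comprehension runs first, then extend
print(exts)  # ['jpg', 'png', 'JPG', 'PNG'] -- no infinite growth
```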
@@ -1,7 +1,9 @@
 from contextvars import ContextVar
+from typing import TYPE_CHECKING
 
-from core.workflow.entities.variable_pool import VariablePool
+if TYPE_CHECKING:
+    from core.workflow.entities.variable_pool import VariablePool
 
 tenant_id: ContextVar[str] = ContextVar("tenant_id")
 
-workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
+workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
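Guarding the `VariablePool` import behind `TYPE_CHECKING` removes the runtime dependency on `core.workflow` (the usual fix for an import cycle), while the quoted `ContextVar["VariablePool"]` annotation stays resolvable for type checkers. The general pattern, with a hypothetical heavy module standing in:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never executed at runtime.
    from heavy_module import HeavyType  # hypothetical module


def process(value: "HeavyType") -> None:  # forward reference via the quoted name
    ...
```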
@@ -22,7 +22,8 @@ from fields.conversation_fields import (
 )
 from libs.helper import DatetimeString
 from libs.login import login_required
-from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
+from models import Conversation, EndUser, Message, MessageAnnotation
+from models.model import AppMode
 
 
 class CompletionConversationApi(Resource):
@@ -12,7 +12,7 @@ from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
 from fields.app_fields import app_site_fields
 from libs.login import login_required
-from models.model import Site
+from models import Site
 
 
 def parse_app_site_args():
@@ -13,14 +13,14 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.app.segments import factory
 from core.errors.error import AppInvokeQuotaExceededError
+from factories import variable_factory
 from fields.workflow_fields import workflow_fields
 from fields.workflow_run_fields import workflow_run_node_execution_fields
 from libs import helper
 from libs.helper import TimestampField, uuid_value
 from libs.login import current_user, login_required
-from models.model import App, AppMode
+from models import App
+from models.model import AppMode
 from services.app_dsl_service import AppDslService
 from services.app_generate_service import AppGenerateService
 from services.errors.app import WorkflowHashNotEqualError
@@ -101,9 +101,13 @@ class DraftWorkflowApi(Resource):
 
         try:
             environment_variables_list = args.get("environment_variables") or []
-            environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
+            environment_variables = [
+                variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
+            ]
             conversation_variables_list = args.get("conversation_variables") or []
-            conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
+            conversation_variables = [
+                variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
+            ]
             workflow = workflow_service.sync_draft_workflow(
                 app_model=app_model,
                 graph=args["graph"],
@@ -273,17 +277,15 @@ class DraftWorkflowRunApi(Resource):
         parser.add_argument("files", type=list, required=False, location="json")
         args = parser.parse_args()
 
-        try:
-            response = AppGenerateService.generate(
-                app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
-            )
-
-            return helper.compact_generate_response(response)
-        except (ValueError, AppInvokeQuotaExceededError) as e:
-            raise e
-        except Exception as e:
-            logging.exception("internal server error.")
-            raise InternalServerError()
+        response = AppGenerateService.generate(
+            app_model=app_model,
+            user=current_user,
+            args=args,
+            invoke_from=InvokeFrom.DEBUGGER,
+            streaming=True,
+        )
+
+        return helper.compact_generate_response(response)
 
 
 class WorkflowTaskStopApi(Resource):
@@ -7,7 +7,8 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
 from libs.login import login_required
-from models.model import App, AppMode
+from models import App
+from models.model import AppMode
 from services.workflow_app_service import WorkflowAppService
@@ -13,7 +13,8 @@ from fields.workflow_run_fields import (
 )
 from libs.helper import uuid_value
 from libs.login import login_required
-from models.model import App, AppMode
+from models import App
+from models.model import AppMode
 from services.workflow_run_service import WorkflowRunService
@@ -10,11 +10,11 @@ from controllers.console import api
 from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from enums import WorkflowRunTriggeredFrom
 from extensions.ext_database import db
 from libs.helper import DatetimeString
 from libs.login import login_required
 from models.model import AppMode
-from models.workflow import WorkflowRunTriggeredFrom
 
 
 class WorkflowDailyRunsStatistic(Resource):
@@ -5,7 +5,8 @@ from typing import Optional, Union
 from controllers.console.app.error import AppNotFoundError
 from extensions.ext_database import db
 from libs.login import current_user
-from models.model import App, AppMode
+from models import App
+from models.model import AppMode
 
 
 def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
@@ -15,7 +15,7 @@ from controllers.console.setup import setup_required
 from extensions.ext_database import db
 from libs.helper import email as email_validate
 from libs.password import hash_password, valid_password
-from models.account import Account
+from models import Account
 from services.account_service import AccountService
 from services.errors.account import RateLimitExceededError
@@ -7,9 +7,9 @@ from flask_restful import Resource, reqparse
 import services
 from controllers.console import api
 from controllers.console.setup import setup_required
-from libs.helper import email, get_remote_ip
+from libs.helper import email, extract_remote_ip
 from libs.password import valid_password
-from models.account import Account
+from models import Account
 from services.account_service import AccountService, TenantService
 
 
@@ -40,17 +40,16 @@ class LoginApi(Resource):
                 "data": "workspace not found, please contact system admin to invite you to join in a workspace",
             }
 
-        token = AccountService.login(account, ip_address=get_remote_ip(request))
+        token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
 
-        return {"result": "success", "data": token}
+        return {"result": "success", "data": token_pair.model_dump()}
 
 
 class LogoutApi(Resource):
     @setup_required
     def get(self):
         account = cast(Account, flask_login.current_user)
-        token = request.headers.get("Authorization", "").split(" ")[1]
-        AccountService.logout(account=account, token=token)
+        AccountService.logout(account=account)
         flask_login.logout_user()
         return {"result": "success"}
 
@@ -106,5 +105,19 @@ class ResetPasswordApi(Resource):
         return {"result": "success"}
 
 
+class RefreshTokenApi(Resource):
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("refresh_token", type=str, required=True, location="json")
+        args = parser.parse_args()
+
+        try:
+            new_token_pair = AccountService.refresh_token(args["refresh_token"])
+            return {"result": "success", "data": new_token_pair.model_dump()}
+        except Exception as e:
+            return {"result": "fail", "data": str(e)}, 401
+
+
 api.add_resource(LoginApi, "/login")
 api.add_resource(LogoutApi, "/logout")
+api.add_resource(RefreshTokenApi, "/refresh-token")
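With `AccountService.login` now returning an access/refresh token pair and `/refresh-token` registered, a console client can renew its session without re-authenticating. A hedged client-side sketch follows; the base URL and error handling are assumptions based only on the handlers above:

```python
import requests

BASE_URL = "http://127.0.0.1:5001/console/api"  # assumed console API prefix


def login(email: str, password: str) -> dict:
    resp = requests.post(f"{BASE_URL}/login", json={"email": email, "password": password})
    resp.raise_for_status()
    return resp.json()["data"]  # token_pair.model_dump(): access_token + refresh_token


def refresh(refresh_token: str) -> dict:
    # RefreshTokenApi answers 401 with {"result": "fail"} for invalid tokens.
    resp = requests.post(f"{BASE_URL}/refresh-token", json={"refresh_token": refresh_token})
    if resp.status_code == 401:
        raise RuntimeError("refresh token rejected; a full login is required")
    return resp.json()["data"]
```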
@@ -9,9 +9,10 @@ from flask_restful import Resource
 from configs import dify_config
 from constants.languages import languages
 from extensions.ext_database import db
-from libs.helper import get_remote_ip
+from libs.helper import extract_remote_ip
 from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
-from models.account import Account, AccountStatus
+from models import Account
+from models.account import AccountStatus
 from services.account_service import AccountService, RegisterService, TenantService
 
 from .. import api
@@ -81,9 +82,14 @@ class OAuthCallback(Resource):
 
         TenantService.create_owner_tenant_if_not_exist(account)
 
-        token = AccountService.login(account, ip_address=get_remote_ip(request))
+        token_pair = AccountService.login(
+            account=account,
+            ip_address=extract_remote_ip(request),
+        )
 
-        return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
+        return redirect(
+            f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
+        )
 
 
 def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
@@ -15,8 +15,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
 from extensions.ext_database import db
 from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
 from libs.login import login_required
-from models.dataset import Document
-from models.source import DataSourceOauthBinding
+from models import DataSourceOauthBinding, Document
 from services.dataset_service import DatasetService, DocumentService
 from tasks.document_indexing_sync_task import document_indexing_sync_task
@@ -24,8 +24,8 @@ from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
 from fields.document_fields import document_status_fields
 from libs.login import login_required
-from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
-from models.model import ApiToken, UploadFile
+from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
+from models.dataset import DatasetPermissionEnum
 from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
 
 
@@ -617,6 +617,8 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.CHROMA
                 | VectorType.TENCENT
                 | VectorType.PGVECTO_RS
+                | VectorType.BAIDU
+                | VectorType.VIKINGDB
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (
@@ -653,6 +655,8 @@ class DatasetRetrievalSettingMockApi(Resource):
                 | VectorType.CHROMA
                 | VectorType.TENCENT
                 | VectorType.PGVECTO_RS
+                | VectorType.BAIDU
+                | VectorType.VIKINGDB
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (
@@ -46,8 +46,7 @@ from fields.document_fields import (
     document_with_segments_fields,
 )
 from libs.login import login_required
-from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
-from models.model import UploadFile
+from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
 from services.dataset_service import DatasetService, DocumentService
 from tasks.add_document_to_index_task import add_document_to_index_task
 from tasks.remove_document_from_index_task import remove_document_from_index_task
@@ -24,7 +24,7 @@ from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.segment_fields import segment_fields
 from libs.login import login_required
-from models.dataset import DocumentSegment
+from models import DocumentSegment
 from services.dataset_service import DatasetService, DocumentService, SegmentService
 from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
 from tasks.disable_segment_from_index_task import disable_segment_from_index_task
@@ -1,9 +1,12 @@
+import urllib.parse
+
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal_with
 
 import services
 from configs import dify_config
+from constants import DOCUMENT_EXTENSIONS
 from controllers.console import api
 from controllers.console.datasets.error import (
     FileTooLargeError,
@@ -13,9 +16,10 @@ from controllers.console.datasets.error import (
 )
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
-from fields.file_fields import file_fields, upload_config_fields
+from core.helper import ssrf_proxy
+from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields
 from libs.login import login_required
-from services.file_service import ALLOWED_EXTENSIONS, UNSTRUCTURED_ALLOWED_EXTENSIONS, FileService
+from services.file_service import FileService
 
 PREVIEW_WORDS_LIMIT = 3000
 
@@ -51,7 +55,7 @@ class FileApi(Resource):
         if len(request.files) > 1:
             raise TooManyFilesError()
         try:
-            upload_file = FileService.upload_file(file, current_user)
+            upload_file = FileService.upload_file(file=file, user=current_user)
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:
@@ -75,11 +79,24 @@ class FileSupportTypeApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        etl_type = dify_config.ETL_TYPE
-        allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
-        return {"allowed_extensions": allowed_extensions}
+        return {"allowed_extensions": DOCUMENT_EXTENSIONS}
+
+
+class RemoteFileInfoApi(Resource):
+    @marshal_with(remote_file_info_fields)
+    def get(self, url):
+        decoded_url = urllib.parse.unquote(url)
+        try:
+            response = ssrf_proxy.head(decoded_url)
+            return {
+                "file_type": response.headers.get("Content-Type", "application/octet-stream"),
+                "file_length": int(response.headers.get("Content-Length", 0)),
+            }
+        except Exception as e:
+            return {"error": str(e)}, 400
 
 
 api.add_resource(FileApi, "/files/upload")
 api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
 api.add_resource(FileSupportTypeApi, "/files/support-type")
+api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
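Because the new route captures the target URL as a path segment (`/remote-files/<path:url>`) and the handler `unquote`s it, callers should percent-encode the remote URL; server-side, the HEAD request is issued through `ssrf_proxy` rather than a bare HTTP client. A hedged usage sketch (the base URL is an assumption):

```python
import urllib.parse

import requests

BASE_URL = "http://127.0.0.1:5001/console/api"  # assumed console API prefix


def remote_file_info(remote_url: str) -> dict:
    # Encode everything (safe="") so slashes in the target URL
    # survive the trip through the <path:url> route parameter.
    quoted = urllib.parse.quote(remote_url, safe="")
    resp = requests.get(f"{BASE_URL}/remote-files/{quoted}")
    resp.raise_for_status()
    return resp.json()  # {"file_type": ..., "file_length": ...}
```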
@@ -11,7 +11,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from extensions.ext_database import db
 from fields.installed_app_fields import installed_app_list_fields
 from libs.login import login_required
-from models.model import App, InstalledApp, RecommendedApp
+from models import App, InstalledApp, RecommendedApp
 from services.account_service import TenantService
@@ -18,7 +18,7 @@ message_fields = {
     "inputs": fields.Raw,
     "query": fields.String,
     "answer": fields.String,
-    "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+    "message_files": fields.List(fields.Nested(message_file_fields)),
     "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
     "created_at": TimestampField,
 }
@@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound
 from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
 from libs.login import login_required
-from models.model import InstalledApp
+from models import InstalledApp
 
 
 def installed_app_required(view=None):
@@ -4,7 +4,7 @@ from flask import request
 from flask_restful import Resource, reqparse
 
 from configs import dify_config
-from libs.helper import StrLen, email, get_remote_ip
+from libs.helper import StrLen, email, extract_remote_ip
 from libs.password import valid_password
 from models.model import DifySetup
 from services.account_service import RegisterService, TenantService
@@ -46,7 +46,7 @@ class SetupApi(Resource):
 
         # setup
         RegisterService.setup(
-            email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
+            email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
         )
 
         return {"result": "success"}, 201
@@ -20,7 +20,7 @@ from extensions.ext_database import db
 from fields.member_fields import account_fields
 from libs.helper import TimestampField, timezone
 from libs.login import login_required
-from models.account import AccountIntegrate, InvitationCode
+from models import AccountIntegrate, InvitationCode
 from services.account_service import AccountService
 from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -360,16 +360,15 @@ class ToolWorkflowProviderCreateApi(Resource):
         args = reqparser.parse_args()
 
         return WorkflowToolManageService.create_workflow_tool(
-            user_id,
-            tenant_id,
-            args["workflow_app_id"],
-            args["name"],
-            args["label"],
-            args["icon"],
-            args["description"],
-            args["parameters"],
-            args["privacy_policy"],
-            args.get("labels", []),
+            user_id=user_id,
+            tenant_id=tenant_id,
+            workflow_app_id=args["workflow_app_id"],
+            name=args["name"],
+            label=args["label"],
+            icon=args["icon"],
+            description=args["description"],
+            parameters=args["parameters"],
+            privacy_policy=args["privacy_policy"],
         )
@@ -198,7 +198,7 @@ class WebappLogoWorkspaceApi(Resource):
             raise UnsupportedFileTypeError()
 
         try:
-            upload_file = FileService.upload_file(file, current_user, True)
+            upload_file = FileService.upload_file(file=file, user=current_user)
 
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
@@ -21,7 +21,36 @@ class ImagePreviewApi(Resource):
             return {"content": "Invalid request."}, 400
 
         try:
-            generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign)
+            generator, mimetype = FileService.get_image_preview(
+                file_id=file_id,
+                timestamp=timestamp,
+                nonce=nonce,
+                sign=sign,
+            )
         except services.errors.file.UnsupportedFileTypeError:
             raise UnsupportedFileTypeError()
 
         return Response(generator, mimetype=mimetype)
 
 
+class FilePreviewApi(Resource):
+    def get(self, file_id):
+        file_id = str(file_id)
+
+        timestamp = request.args.get("timestamp")
+        nonce = request.args.get("nonce")
+        sign = request.args.get("sign")
+
+        if not timestamp or not nonce or not sign:
+            return {"content": "Invalid request."}, 400
+
+        try:
+            generator, mimetype = FileService.get_signed_file_preview(
+                file_id=file_id,
+                timestamp=timestamp,
+                nonce=nonce,
+                sign=sign,
+            )
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
@@ -49,4 +78,5 @@ class WorkspaceWebappLogoApi(Resource):
 
 
 api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
+api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/file-preview")
 api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")
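`FilePreviewApi` reuses the `timestamp`/`nonce`/`sign` query parameters that `ImagePreviewApi` already required, and the `.env` example above caps signature age with `FILES_ACCESS_TIMEOUT=300`. The diff does not show `FileService.get_signed_file_preview` itself; the following is only an illustration of how such signed preview links are commonly verified (the HMAC construction and field order are assumptions, not Dify's actual code):

```python
import hashlib
import hmac
import time

SECRET_KEY = b"change-me"    # stand-in for the application's SECRET_KEY
FILES_ACCESS_TIMEOUT = 300   # seconds, matching the .env default above


def verify_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    # Reject signatures older than the access timeout.
    if time.time() - int(timestamp) > FILES_ACCESS_TIMEOUT:
        return False
    # Recompute the MAC over the signed fields and compare in constant time.
    message = f"{file_id}:{timestamp}:{nonce}".encode()
    expected = hmac.new(SECRET_KEY, message, hashlib.sha256).hexdigest()
    return hmac.compare_digest(expected, sign)
```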
@@ -48,7 +48,7 @@ class MessageListApi(Resource):
         "tool_input": fields.String,
         "created_at": TimestampField,
         "observation": fields.String,
-        "message_files": fields.List(fields.String, attribute="files"),
+        "message_files": fields.List(fields.String),
     }
 
     message_fields = {
@@ -58,7 +58,7 @@ class MessageListApi(Resource):
         "inputs": fields.Raw,
         "query": fields.String,
         "answer": fields.String(attribute="re_sign_file_url_answer"),
-        "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+        "message_files": fields.List(fields.Nested(message_file_fields)),
         "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
         "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
         "created_at": TimestampField,
@@ -1,3 +1,5 @@
+import urllib.parse
+
 from flask import request
 from flask_restful import marshal_with
 
@@ -5,7 +7,8 @@ import services
 from controllers.web import api
 from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
 from controllers.web.wraps import WebApiResource
-from fields.file_fields import file_fields
+from core.helper import ssrf_proxy
+from fields.file_fields import file_fields, remote_file_info_fields
 from services.file_service import FileService
 
 
@@ -31,4 +34,19 @@ class FileApi(WebApiResource):
         return upload_file, 201
 
 
+class RemoteFileInfoApi(WebApiResource):
+    @marshal_with(remote_file_info_fields)
+    def get(self, url):
+        decoded_url = urllib.parse.unquote(url)
+        try:
+            response = ssrf_proxy.head(decoded_url)
+            return {
+                "file_type": response.headers.get("Content-Type", "application/octet-stream"),
+                "file_length": int(response.headers.get("Content-Length", 0)),
+            }
+        except Exception as e:
+            return {"error": str(e)}, 400
+
+
 api.add_resource(FileApi, "/files/upload")
+api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
@@ -22,6 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError
 from core.model_runtime.errors.invoke import InvokeError
 from fields.conversation_fields import message_file_fields
 from fields.message_fields import agent_thought_fields
+from fields.raws import FilesContainedField
 from libs import helper
 from libs.helper import TimestampField, uuid_value
 from models.model import AppMode
@@ -58,10 +59,10 @@ class MessageListApi(WebApiResource):
         "id": fields.String,
         "conversation_id": fields.String,
         "parent_message_id": fields.String,
-        "inputs": fields.Raw,
+        "inputs": FilesContainedField,
         "query": fields.String,
         "answer": fields.String(attribute="re_sign_file_url_answer"),
-        "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+        "message_files": fields.List(fields.Nested(message_file_fields)),
         "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
         "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
         "created_at": TimestampField,

@@ -17,7 +17,7 @@ message_fields = {
     "inputs": fields.Raw,
     "query": fields.String,
     "answer": fields.String,
-    "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
+    "message_files": fields.List(fields.Nested(message_file_fields)),
     "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
     "created_at": TimestampField,
 }
@@ -16,13 +16,14 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
-from core.file.message_file_parser import MessageFileParser
+from core.file import file_manager
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMUsage
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    LLMUsage,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
@@ -40,9 +41,9 @@ from core.tools.entities.tool_entities import (
 from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tools.tool.tool import Tool
 from core.tools.tool_manager import ToolManager
-from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 from extensions.ext_database import db
-from models.model import Conversation, Message, MessageAgentThought
+from factories import file_factory
+from models.model import Conversation, Message, MessageAgentThought, MessageFile
 from models.tools import ToolConversationVariables
 
 logger = logging.getLogger(__name__)
@@ -66,23 +67,6 @@ class BaseAgentRunner(AppRunner):
         db_variables: Optional[ToolConversationVariables] = None,
         model_instance: ModelInstance = None,
     ) -> None:
-        """
-        Agent runner
-        :param tenant_id: tenant id
-        :param application_generate_entity: application generate entity
-        :param conversation: conversation
-        :param app_config: app generate entity
-        :param model_config: model config
-        :param config: dataset config
-        :param queue_manager: queue manager
-        :param message: message
-        :param user_id: user id
-        :param memory: memory
-        :param prompt_messages: prompt messages
-        :param variables_pool: variables pool
-        :param db_variables: db variables
-        :param model_instance: model instance
-        """
         self.tenant_id = tenant_id
         self.application_generate_entity = application_generate_entity
         self.conversation = conversation
@@ -180,7 +164,7 @@ class BaseAgentRunner(AppRunner):
             if parameter.form != ToolParameter.ToolParameterForm.LLM:
                 continue
 
-            parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
+            parameter_type = parameter.type.as_normal_type()
             enum = []
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
@@ -265,7 +249,7 @@ class BaseAgentRunner(AppRunner):
             if parameter.form != ToolParameter.ToolParameterForm.LLM:
                 continue
 
-            parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
+            parameter_type = parameter.type.as_normal_type()
             enum = []
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
@@ -511,26 +495,24 @@ class BaseAgentRunner(AppRunner):
         return result
 
     def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
-        message_file_parser = MessageFileParser(
-            tenant_id=self.tenant_id,
-            app_id=self.app_config.app_id,
-        )
-
-        files = message.message_files
+        files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
         if files:
             file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
 
             if file_extra_config:
-                file_objs = message_file_parser.transform_message_files(files, file_extra_config)
+                file_objs = file_factory.build_from_message_files(
+                    message_files=files, tenant_id=self.tenant_id, config=file_extra_config
+                )
             else:
                 file_objs = []
 
             if not file_objs:
                 return UserPromptMessage(content=message.query)
             else:
-                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
+                prompt_message_contents: list[PromptMessageContent] = []
+                prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                 for file_obj in file_objs:
-                    prompt_message_contents.append(file_obj.prompt_message_content)
+                    prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
 
                 return UserPromptMessage(content=prompt_message_contents)
         else:
@@ -1,9 +1,11 @@
 import json
 
 from core.agent.cot_agent_runner import CotAgentRunner
-from core.model_runtime.entities.message_entities import (
+from core.file import file_manager
+from core.model_runtime.entities import (
     AssistantPromptMessage,
     PromptMessage,
+    PromptMessageContent,
     SystemPromptMessage,
     TextPromptMessageContent,
     UserPromptMessage,
@@ -32,9 +34,10 @@ class CotChatAgentRunner(CotAgentRunner):
         Organize user query
         """
         if self.files:
-            prompt_message_contents = [TextPromptMessageContent(data=query)]
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=query))
             for file_obj in self.files:
-                prompt_message_contents.append(file_obj.prompt_message_content)
+                prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
 
             prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         else:
@@ -7,10 +7,15 @@ from typing import Any, Optional, Union
 from core.agent.base_agent_runner import BaseAgentRunner
 from core.app.apps.base_app_queue_manager import PublishFrom
 from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from core.model_runtime.entities.message_entities import (
+from core.file import file_manager
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    LLMResult,
+    LLMResultChunk,
+    LLMResultChunkDelta,
+    LLMUsage,
     PromptMessage,
+    PromptMessageContent,
     PromptMessageContentType,
     SystemPromptMessage,
     TextPromptMessageContent,

@@ -390,9 +395,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         Organize user query
         """
         if self.files:
-            prompt_message_contents = [TextPromptMessageContent(data=query)]
+            prompt_message_contents: list[PromptMessageContent] = []
+            prompt_message_contents.append(TextPromptMessageContent(data=query))
             for file_obj in self.files:
-                prompt_message_contents.append(file_obj.prompt_message_content)
+                prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))

             prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
         else:
@@ -53,12 +53,11 @@ class BasicVariablesConfigManager:
                     VariableEntity(
                         type=variable_type,
                         variable=variable.get("variable"),
-                        description=variable.get("description"),
+                        description=variable.get("description", ""),
                         label=variable.get("label"),
                         required=variable.get("required", False),
                         max_length=variable.get("max_length"),
-                        options=variable.get("options"),
-                        default=variable.get("default"),
+                        options=variable.get("options", []),
                     )
                 )

@@ -1,11 +1,12 @@
+from collections.abc import Sequence
 from enum import Enum
 from typing import Any, Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

-from core.file.file_obj import FileExtraConfig
+from core.file import FileExtraConfig, FileTransferMethod, FileType
 from core.model_runtime.entities.message_entities import PromptMessageRole
-from models import AppMode
+from models.model import AppMode


 class ModelConfigEntity(BaseModel):

@@ -69,7 +70,7 @@ class PromptTemplateEntity(BaseModel):
         ADVANCED = "advanced"

     @classmethod
-    def value_of(cls, value: str) -> "PromptType":
+    def value_of(cls, value: str):
         """
         Get value of given mode.


@@ -93,6 +94,8 @@ class VariableEntityType(str, Enum):
     PARAGRAPH = "paragraph"
     NUMBER = "number"
     EXTERNAL_DATA_TOOL = "external_data_tool"
+    FILE = "file"
+    FILE_LIST = "file-list"


 class VariableEntity(BaseModel):

@@ -102,13 +105,14 @@ class VariableEntity(BaseModel):

     variable: str
     label: str
-    description: Optional[str] = None
+    description: str = ""
     type: VariableEntityType
     required: bool = False
     max_length: Optional[int] = None
-    options: Optional[list[str]] = None
-    default: Optional[str] = None
-    hint: Optional[str] = None
+    options: Sequence[str] = Field(default_factory=list)
+    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
+    allowed_file_extensions: Sequence[str] = Field(default_factory=list)
+    allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)


 class ExternalDataVariableEntity(BaseModel):

@@ -136,7 +140,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
         MULTIPLE = "multiple"

     @classmethod
-    def value_of(cls, value: str) -> "RetrieveStrategy":
+    def value_of(cls, value: str):
         """
         Get value of given mode.

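The `VariableEntity` hunk above trades `Optional[...] = None` defaults for non-optional fields with `Field(default_factory=list)`, which is the idiomatic way to give a pydantic model a fresh mutable default per instance. A minimal sketch of the resulting model (the `FileType` values here are assumptions for illustration; the real enum is in `core.file`):

```python
# Sketch only, assuming pydantic v2; FileType members are illustrative.
from collections.abc import Sequence
from enum import Enum

from pydantic import BaseModel, Field


class FileType(str, Enum):
    IMAGE = "image"
    DOCUMENT = "document"


class VariableEntity(BaseModel):
    variable: str
    label: str
    description: str = ""  # was Optional[str] = None
    options: Sequence[str] = Field(default_factory=list)  # was Optional[list[str]] = None
    allowed_file_types: Sequence[FileType] = Field(default_factory=list)


v = VariableEntity(variable="q", label="Question")
# Callers no longer need `var.options or []` guards: the defaults are concrete.
assert v.options == [] and v.description == ""
```

This is why the `_validate_input` hunk later in this diff can drop `options = var.options or []` in favor of `options = var.options`.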
@@ -1,12 +1,13 @@
 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import Any

-from core.file.file_obj import FileExtraConfig
+from core.file.models import FileExtraConfig
+from models import FileUploadConfig


 class FileUploadConfigManager:
     @classmethod
-    def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
+    def convert(cls, config: Mapping[str, Any], is_vision: bool = True):
         """
         Convert model config to model config

@@ -15,19 +16,18 @@ class FileUploadConfigManager:
         """
         file_upload_dict = config.get("file_upload")
         if file_upload_dict:
-            if file_upload_dict.get("image"):
-                if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
-                    image_config = {
-                        "number_limits": file_upload_dict["image"]["number_limits"],
-                        "transfer_methods": file_upload_dict["image"]["transfer_methods"],
-                    }
+            if file_upload_dict.get("enabled"):
+                data = {
+                    "image_config": {
+                        "number_limits": file_upload_dict["number_limits"],
+                        "transfer_methods": file_upload_dict["allowed_file_upload_methods"],
+                    }
+                }

-                    if is_vision:
-                        image_config["detail"] = file_upload_dict["image"]["detail"]
+                if is_vision:
+                    data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")

-                    return FileExtraConfig(image_config=image_config)
-
-        return None
+                return FileExtraConfig.model_validate(data)

     @classmethod
     def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:

@@ -39,29 +39,7 @@ class FileUploadConfigManager:
         """
         if not config.get("file_upload"):
            config["file_upload"] = {}
-
-        if not isinstance(config["file_upload"], dict):
-            raise ValueError("file_upload must be of dict type")
-
-        # check image config
-        if not config["file_upload"].get("image"):
-            config["file_upload"]["image"] = {"enabled": False}
-
-        if config["file_upload"]["image"]["enabled"]:
-            number_limits = config["file_upload"]["image"]["number_limits"]
-            if number_limits < 1 or number_limits > 6:
-                raise ValueError("number_limits must be in [1, 6]")
-
-            if is_vision:
-                detail = config["file_upload"]["image"]["detail"]
-                if detail not in {"high", "low"}:
-                    raise ValueError("detail must be in ['high', 'low']")
-
-            transfer_methods = config["file_upload"]["image"]["transfer_methods"]
-            if not isinstance(transfer_methods, list):
-                raise ValueError("transfer_methods must be of list type")
-            for method in transfer_methods:
-                if method not in {"remote_url", "local_file"}:
-                    raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
+        else:
+            FileUploadConfig.model_validate(config["file_upload"])

         return config, ["file_upload"]
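The rewritten `convert()` above reads a flattened `file_upload` config: `enabled`, `number_limits`, and `allowed_file_upload_methods` move to the top level, with only the vision `detail` left under an optional `image` sub-dict. A sketch of the shape it appears to expect, with only logic taken from the hunk (the exact schema is defined by `models.FileUploadConfig`, so treat the sample dict as an assumption):

```python
# Sample config dict is an assumption inferred from the keys the diff reads.
features = {
    "file_upload": {
        "enabled": True,
        "number_limits": 3,
        "allowed_file_upload_methods": ["remote_url", "local_file"],
        "image": {"detail": "high"},  # consulted only when is_vision=True
    }
}

file_upload = features["file_upload"]
if file_upload.get("enabled"):
    data = {
        "image_config": {
            "number_limits": file_upload["number_limits"],
            "transfer_methods": file_upload["allowed_file_upload_methods"],
        }
    }
    # Falls back to "low" when the image sub-dict is absent.
    data["image_config"]["detail"] = file_upload.get("image", {}).get("detail", "low")
    print(data)  # what FileExtraConfig.model_validate(data) would receive
```

Note the validation side moves the same direction: hand-rolled `isinstance`/range checks are replaced by a single `FileUploadConfig.model_validate(...)` call.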
@@ -17,6 +17,6 @@ class WorkflowVariablesConfigManager:

         # variables
         for variable in user_input_form:
-            variables.append(VariableEntity(**variable))
+            variables.append(VariableEntity.model_validate(variable))

         return variables
@@ -20,10 +20,11 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
-from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
+from enums import CreatedByRole
 from extensions.ext_database import db
+from factories import file_factory
 from models.account import Account
 from models.model import App, Conversation, EndUser, Message
 from models.workflow import Workflow

@@ -95,10 +96,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):

         # parse files
         files = args["files"] if args.get("files") else []
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
+            file_objs = file_factory.build_from_mappings(
+                mappings=files,
+                tenant_id=app_model.tenant_id,
+                user_id=user.id,
+                role=role,
+                config=file_extra_config,
+            )
         else:
             file_objs = []

@@ -106,8 +113,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

         # get tracing instance
-        user_id = user.id if isinstance(user, Account) else user.session_id
-        trace_manager = TraceQueueManager(app_model.id, user_id)
+        trace_manager = TraceQueueManager(
+            app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
+        )

         if invoke_from == InvokeFrom.DEBUGGER:
             # always enable retriever resource in debugger mode

@@ -119,7 +127,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             task_id=str(uuid.uuid4()),
             app_config=app_config,
             conversation_id=conversation.id if conversation else None,
-            inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
+            inputs=conversation.inputs
+            if conversation
+            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id"),
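The pattern in this generator recurs in the chat, agent-chat, completion, and workflow generators below: `MessageFileParser.validate_and_transform_files_arg(...)` is replaced everywhere by `file_factory.build_from_mappings(...)` plus a `CreatedByRole` resolved once per request. A simplified, runnable stand-in showing the call shape (the real factory lives in `factories.file_factory`; the `File` dataclass and field names here are illustrative):

```python
# Sketch only: File and build_from_mappings are simplified stand-ins that
# mirror the keyword-only call shape shown in the diff.
from dataclasses import dataclass


@dataclass
class File:
    tenant_id: str
    user_id: str
    role: str
    transfer_method: str
    remote_url: str | None = None


def build_from_mappings(*, mappings, tenant_id, user_id, role, config=None) -> list[File]:
    # One factory owns mapping->File conversion instead of each generator.
    return [
        File(
            tenant_id=tenant_id,
            user_id=user_id,
            role=role,
            transfer_method=m["transfer_method"],
            remote_url=m.get("url"),
        )
        for m in mappings
    ]


role = "account"  # CreatedByRole.ACCOUNT for Account callers, else END_USER
files = build_from_mappings(
    mappings=[{"transfer_method": "remote_url", "url": "https://example.com/a.png"}],
    tenant_id="t-1",
    user_id="u-1",
    role=role,
)
print(files)
```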
@@ -1,30 +1,26 @@
 import logging
-import os
 from collections.abc import Mapping
 from typing import Any, cast

 from sqlalchemy import select
 from sqlalchemy.orm import Session

+from configs import dify_config
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
-from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
-from core.app.entities.app_invoke_entities import (
-    AdvancedChatAppGenerateEntity,
-    InvokeFrom,
-)
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.queue_entities import (
     QueueAnnotationReplyEvent,
     QueueStopEvent,
     QueueTextChunkEvent,
 )
 from core.moderation.base import ModerationError
-from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.node_entities import UserFrom
+from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.workflow_entry import WorkflowEntry
+from enums import UserFrom
 from extensions.ext_database import db
 from models.model import App, Conversation, EndUser, Message
 from models.workflow import ConversationVariable, WorkflowType

@@ -44,12 +40,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         conversation: Conversation,
         message: Message,
     ) -> None:
-        """
-        :param application_generate_entity: application generate entity
-        :param queue_manager: application queue manager
-        :param conversation: conversation
-        :param message: message
-        """
         super().__init__(queue_manager)

         self.application_generate_entity = application_generate_entity

@@ -57,10 +47,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         self.message = message

     def run(self) -> None:
-        """
-        Run application
-        :return:
-        """
         app_config = self.application_generate_entity.app_config
         app_config = cast(AdvancedChatAppConfig, app_config)

@@ -81,7 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         user_id = self.application_generate_entity.user_id

         workflow_callbacks: list[WorkflowCallback] = []
-        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
+        if dify_config.DEBUG:
             workflow_callbacks.append(WorkflowLoggingCallback())

         if self.application_generate_entity.single_iteration_run:

@@ -201,15 +187,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         query: str,
         message_id: str,
     ) -> bool:
-        """
-        Handle input moderation
-        :param app_record: app record
-        :param app_generate_entity: application generate entity
-        :param inputs: inputs
-        :param query: query
-        :param message_id: message id
-        :return:
-        """
         try:
             # process sensitive_word_avoidance
             _, inputs, query = self.moderation_for_inputs(

@@ -229,14 +206,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
     def handle_annotation_reply(
         self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
     ) -> bool:
-        """
-        Handle annotation reply
-        :param app_record: app record
-        :param message: message
-        :param query: query
-        :param app_generate_entity: application generate entity
-        """
         # annotation reply
         annotation_reply = self.query_app_annotations_to_reply(
             app_record=app_record,
             message=message,

@@ -258,8 +227,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
     def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
         """
         Direct output
-        :param text: text
-        :return:
         """
         self._publish_event(QueueTextChunkEvent(text=text))

@@ -1,7 +1,7 @@
 import json
 import logging
 import time
-from collections.abc import Generator
+from collections.abc import Generator, Mapping
 from typing import Any, Optional, Union

 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME

@@ -9,6 +9,7 @@ from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGenerator
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import (
     AdvancedChatAppGenerateEntity,
+    InvokeFrom,
 )
 from core.app.entities.queue_entities import (
     QueueAdvancedChatMessageEndEvent,

@@ -50,12 +51,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from enums import CreatedByRole
+from enums.workflow_nodes import NodeType
 from events.message_event import message_was_created
 from extensions.ext_database import db
+from models import Conversation, EndUser, Message, MessageFile
 from models.account import Account
-from models.model import Conversation, EndUser, Message
 from models.workflow import (
     Workflow,
+    WorkflowNodeExecution,
     WorkflowRunStatus,
 )

@@ -72,6 +76,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     _workflow: Workflow
     _user: Union[Account, EndUser]
     _workflow_system_variables: dict[SystemVariableKey, Any]
+    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]

     def __init__(
         self,

@@ -115,8 +120,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         }

         self._task_state = WorkflowTaskState()
+        self._wip_workflow_node_executions = {}

         self._conversation_name_generate_thread = None
+        self._recorded_files: list[Mapping[str, Any]] = []

     def process(self):
         """

@@ -295,6 +302,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             elif isinstance(event, QueueNodeSucceededEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_success(event)

+                # Record files if it's an answer node or end node
+                if event.node_type in [NodeType.ANSWER, NodeType.END]:
+                    self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
+
                 response = self._workflow_node_finish_to_stream_response(
                     event=event,
                     task_id=self._application_generate_entity.task_id,

@@ -361,7 +372,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     start_at=graph_runtime_state.start_at,
                     total_tokens=graph_runtime_state.total_tokens,
                     total_steps=graph_runtime_state.node_run_steps,
-                    outputs=json.dumps(event.outputs) if event.outputs else None,
+                    outputs=event.outputs,
                     conversation_id=self._conversation.id,
                     trace_manager=trace_manager,
                 )

@@ -487,10 +498,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             self._conversation_name_generate_thread.join()

     def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
-        """
-        Save message.
-        :return:
-        """
         self._refetch_message()

         self._message.answer = self._task_state.answer

@@ -498,6 +505,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._message.message_metadata = (
             json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
         )
+        message_files = [
+            MessageFile(
+                message_id=self._message.id,
+                type=file["type"],
+                transfer_method=file["transfer_method"],
+                url=file["remote_url"],
+                belongs_to="assistant",
+                upload_file_id=file["related_id"],
+                created_by_role=CreatedByRole.ACCOUNT
+                if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
+                else CreatedByRole.END_USER,
+                created_by=self._message.from_account_id or self._message.from_end_user_id or "",
+            )
+            for file in self._recorded_files
+        ]
+        db.session.add_all(message_files)

         if graph_runtime_state and graph_runtime_state.llm_usage:
             usage = graph_runtime_state.llm_usage

@@ -537,7 +560,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             del extras["metadata"]["annotation_reply"]

         return MessageEndStreamResponse(
-            task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
+            task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
         )

     def _handle_output_moderation_chunk(self, text: str) -> bool:
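Taken together, the pipeline hunks above add one piece of state, `_recorded_files`: file mappings harvested from ANSWER/END node outputs during the run, then replayed both into `MessageFile` rows at save time and into the `files` field of the end-of-message stream response. A toy version of the row-building step, using only the dict keys the diff reads (the dicts here are illustrative, not the real node-output schema):

```python
# Sketch only: recorded_files entries mimic the keys consumed by the diff.
recorded_files = [
    {"type": "image", "transfer_method": "local_file", "remote_url": None, "related_id": "f-1"},
]


def message_file_rows(message_id: str, created_by: str) -> list[dict]:
    # Mirrors the list comprehension feeding db.session.add_all(...) above.
    return [
        {
            "message_id": message_id,
            "type": f["type"],
            "transfer_method": f["transfer_method"],
            "url": f["remote_url"],
            "belongs_to": "assistant",
            "upload_file_id": f["related_id"],
            "created_by": created_by,
        }
        for f in recorded_files
    ]


print(message_file_rows("msg-1", "acct-1"))
```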
@@ -17,12 +17,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
-from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
+from enums import CreatedByRole
 from extensions.ext_database import db
-from models.account import Account
-from models.model import App, EndUser
+from factories import file_factory
+from models import Account, App, EndUser

 logger = logging.getLogger(__name__)

@@ -49,7 +49,12 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
     ) -> dict: ...

     def generate(
-        self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
+        self,
+        app_model: App,
+        user: Union[Account, EndUser],
+        args: Any,
+        invoke_from: InvokeFrom,
+        stream: bool = True,
     ) -> Union[dict, Generator[dict, None, None]]:
         """
         Generate App response.

@@ -97,12 +102,19 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             # always enable retriever resource in debugger mode
             override_model_config_dict["retriever_resource"] = {"enabled": True}

+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
+
         # parse files
-        files = args["files"] if args.get("files") else []
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
+        files = args.get("files") or []
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
+            file_objs = file_factory.build_from_mappings(
+                mappings=files,
+                tenant_id=app_model.tenant_id,
+                user_id=user.id,
+                role=role,
+                config=file_extra_config,
+            )
         else:
             file_objs = []

@@ -115,8 +127,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         )

         # get tracing instance
-        user_id = user.id if isinstance(user, Account) else user.session_id
-        trace_manager = TraceQueueManager(app_model.id, user_id)
+        trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)

         # init application generate entity
         application_generate_entity = AgentChatAppGenerateEntity(

@@ -124,7 +135,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             app_config=app_config,
             model_conf=ModelConfigConverter.convert(app_config),
             conversation_id=conversation.id if conversation else None,
-            inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
+            inputs=conversation.inputs
+            if conversation
+            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id"),
@@ -1,35 +1,92 @@
 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

-from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
+from core.app.app_config.entities import VariableEntityType
+from core.file import File, FileExtraConfig
+from factories import file_factory
+
+if TYPE_CHECKING:
+    from core.app.app_config.entities import AppConfig, VariableEntity
+    from enums import CreatedByRole


 class BaseAppGenerator:
-    def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
+    def _prepare_user_inputs(
+        self,
+        *,
+        user_inputs: Optional[Mapping[str, Any]],
+        app_config: "AppConfig",
+        user_id: str,
+        role: "CreatedByRole",
+    ) -> Mapping[str, Any]:
         user_inputs = user_inputs or {}
         # Filter input variables from form configuration, handle required fields, default values, and option values
         variables = app_config.variables
-        filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
-        filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
-        return filtered_inputs
+        user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
+        user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
+        # Convert files in inputs to File
+        entity_dictionary = {item.variable: item for item in app_config.variables}
+        # Convert single file to File
+        files_inputs = {
+            k: file_factory.build_from_mapping(
+                mapping=v,
+                tenant_id=app_config.tenant_id,
+                user_id=user_id,
+                role=role,
+                config=FileExtraConfig(
+                    allowed_file_types=entity_dictionary[k].allowed_file_types,
+                    allowed_extensions=entity_dictionary[k].allowed_file_extensions,
+                    allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                ),
+            )
+            for k, v in user_inputs.items()
+            if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
+        }
+        # Convert list of files to File
+        file_list_inputs = {
+            k: file_factory.build_from_mappings(
+                mappings=v,
+                tenant_id=app_config.tenant_id,
+                user_id=user_id,
+                role=role,
+                config=FileExtraConfig(
+                    allowed_file_types=entity_dictionary[k].allowed_file_types,
+                    allowed_extensions=entity_dictionary[k].allowed_file_extensions,
+                    allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                ),
+            )
+            for k, v in user_inputs.items()
+            if isinstance(v, list)
+            # Ensure skip List<File>
+            and all(isinstance(item, dict) for item in v)
+            and entity_dictionary[k].type == VariableEntityType.FILE_LIST
+        }
+        # Merge all inputs
+        user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}

-    def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
-        user_input_value = inputs.get(var.variable)
-        if var.required and not user_input_value:
-            raise ValueError(f"{var.variable} is required in input form")
-        if not var.required and not user_input_value:
-            # TODO: should we return None here if the default value is None?
-            return var.default or ""
-        if (
-            var.type
-            in {
-                VariableEntityType.TEXT_INPUT,
-                VariableEntityType.SELECT,
-                VariableEntityType.PARAGRAPH,
-            }
-            and user_input_value
-            and not isinstance(user_input_value, str)
-        ):
-            raise ValueError("Invalid input type")
+        # Check if all files are converted to File
+        if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
+            raise ValueError("Invalid input type")
+        if any(
+            filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
+        ):
+            raise ValueError("Invalid input type")
+
+        return user_inputs
+
+    def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
+        user_input_value = inputs.get(var.variable)
+        if not user_input_value:
+            if var.required:
+                raise ValueError(f"{var.variable} is required in input form")
+            else:
+                return None
+
+        if var.type in {
+            VariableEntityType.TEXT_INPUT,
+            VariableEntityType.SELECT,
+            VariableEntityType.PARAGRAPH,
+        } and not isinstance(user_input_value, str):
+            raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
         if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
             # may raise ValueError if user_input_value is not a valid number

@@ -41,12 +98,24 @@ class BaseAppGenerator:
             except ValueError:
                 raise ValueError(f"{var.variable} in input form must be a valid number")
         if var.type == VariableEntityType.SELECT:
-            options = var.options or []
+            options = var.options
             if user_input_value not in options:
                 raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
         elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
-            if var.max_length and user_input_value and len(user_input_value) > var.max_length:
+            if var.max_length and len(user_input_value) > var.max_length:
                 raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
+        elif var.type == VariableEntityType.FILE:
+            if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
+                raise ValueError(f"{var.variable} in input form must be a file")
+        elif var.type == VariableEntityType.FILE_LIST:
+            if not (
+                isinstance(user_input_value, list)
+                and (
+                    all(isinstance(item, dict) for item in user_input_value)
+                    or all(isinstance(item, File) for item in user_input_value)
+                )
+            ):
+                raise ValueError(f"{var.variable} in input form must be a list of files")

         return user_input_value

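The `_prepare_user_inputs` rewrite above does three things: validate each declared variable, convert any dict (or list of dicts) bound to a FILE / FILE_LIST variable into file objects, and reject any mapping left over, since a dict that survives conversion means a file payload arrived for a non-file variable. A toy, dependency-free version of that flow (the `("File", ...)` tuple stands in for a real file object):

```python
# Minimal sketch of the _prepare_user_inputs control flow; types simplified.
def prepare_user_inputs(user_inputs: dict, file_vars: set[str]) -> dict:
    # 1) dicts bound to a declared file variable become file objects
    converted = {
        k: ("File", v)
        for k, v in user_inputs.items()
        if k in file_vars and isinstance(v, dict)
    }
    merged = {**user_inputs, **converted}
    # 2) any dict still present means a mapping arrived for a non-file variable
    if any(isinstance(v, dict) for v in merged.values()):
        raise ValueError("Invalid input type")
    return merged


print(prepare_user_inputs({"q": "hi", "photo": {"transfer_method": "remote_url"}}, {"photo"}))
# {'q': 'hi', 'photo': ('File', {'transfer_method': 'remote_url'})}
```

Note also the behavior change in `_validate_input`: an absent optional value now returns `None` instead of `var.default or ""`, since the `default` field was dropped from `VariableEntity` earlier in this diff.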
@@ -27,7 +27,7 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
 from models.model import App, AppMode, Message, MessageAnnotation

 if TYPE_CHECKING:
-    from core.file.file_obj import FileVar
+    from core.file.models import File


 class AppRunner:

@@ -37,7 +37,7 @@ class AppRunner:
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template_entity: PromptTemplateEntity,
         inputs: dict[str, str],
-        files: list["FileVar"],
+        files: list["File"],
         query: Optional[str] = None,
     ) -> int:
         """

@@ -137,7 +137,7 @@ class AppRunner:
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template_entity: PromptTemplateEntity,
         inputs: dict[str, str],
-        files: list["FileVar"],
+        files: list["File"],
         query: Optional[str] = None,
         context: Optional[str] = None,
         memory: Optional[TokenBufferMemory] = None,
@@ -17,10 +17,11 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
-from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
+from enums import CreatedByRole
 from extensions.ext_database import db
+from factories import file_factory
 from models.account import Account
 from models.model import App, EndUser

@@ -99,12 +100,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             # always enable retriever resource in debugger mode
             override_model_config_dict["retriever_resource"] = {"enabled": True}

+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
+
         # parse files
         files = args["files"] if args.get("files") else []
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
+            file_objs = file_factory.build_from_mappings(
+                mappings=files,
+                tenant_id=app_model.tenant_id,
+                user_id=user.id,
+                role=role,
+                config=file_extra_config,
+            )
         else:
             file_objs = []

@@ -117,7 +125,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         )

         # get tracing instance
-        trace_manager = TraceQueueManager(app_model.id)
+        trace_manager = TraceQueueManager(app_id=app_model.id)

         # init application generate entity
         application_generate_entity = ChatAppGenerateEntity(

@@ -125,15 +133,17 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             app_config=app_config,
             model_conf=ModelConfigConverter.convert(app_config),
             conversation_id=conversation.id if conversation else None,
-            inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
+            inputs=conversation.inputs
+            if conversation
+            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
             query=query,
             files=file_objs,
+            parent_message_id=args.get("parent_message_id"),
             user_id=user.id,
-            stream=stream,
             invoke_from=invoke_from,
             extras=extras,
             trace_manager=trace_manager,
+            stream=stream,
         )

         # init generate records
@@ -17,12 +17,12 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
-from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
+from enums import CreatedByRole
 from extensions.ext_database import db
-from models.account import Account
-from models.model import App, EndUser, Message
+from factories import file_factory
+from models import Account, App, EndUser, Message
 from services.errors.app import MoreLikeThisDisabledError
 from services.errors.message import MessageNotExistsError

@@ -88,12 +88,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             tenant_id=app_model.tenant_id, config=args.get("model_config")
         )

+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
+
         # parse files
         files = args["files"] if args.get("files") else []
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
+            file_objs = file_factory.build_from_mappings(
+                mappings=files,
+                tenant_id=app_model.tenant_id,
+                user_id=user.id,
+                role=role,
+                config=file_extra_config,
+            )
         else:
             file_objs = []

@@ -103,6 +110,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         )

         # get tracing instance
         user_id = user.id if isinstance(user, Account) else user.session_id
         trace_manager = TraceQueueManager(app_model.id)

         # init application generate entity

@@ -110,7 +118,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             task_id=str(uuid.uuid4()),
             app_config=app_config,
             model_conf=ModelConfigConverter.convert(app_config),
-            inputs=self._get_cleaned_inputs(inputs, app_config),
+            inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
             query=query,
             files=file_objs,
             user_id=user.id,

@@ -251,10 +259,16 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         override_model_config_dict["model"] = model_dict

         # parse files
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
-        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
+        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
+            file_objs = file_factory.build_from_mappings(
+                mappings=message.message_files,
+                tenant_id=app_model.tenant_id,
+                user_id=user.id,
+                role=role,
+                config=file_extra_config,
+            )
         else:
             file_objs = []

@@ -26,7 +26,7 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from extensions.ext_database import db
-from models.account import Account
+from models import Account
 from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError

@@ -235,13 +235,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         for file in application_generate_entity.files:
             message_file = MessageFile(
                 message_id=message.id,
-                type=file.type.value,
-                transfer_method=file.transfer_method.value,
+                type=file.type,
+                transfer_method=file.transfer_method,
                 belongs_to="user",
-                url=file.url,
+                url=file.remote_url,
                 upload_file_id=file.related_id,
                 created_by_role=("account" if account_id else "end_user"),
-                created_by=account_id or end_user_id,
+                created_by=account_id or end_user_id or "",
             )
             db.session.add(message_file)
             db.session.commit()

@@ -3,7 +3,7 @@ import logging
 import os
 import threading
 import uuid
-from collections.abc import Generator
+from collections.abc import Generator, Mapping, Sequence
 from typing import Any, Literal, Optional, Union, overload

 from flask import Flask, current_app

@@ -20,13 +20,12 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
-from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
+from enums import CreatedByRole
 from extensions.ext_database import db
-from models.account import Account
-from models.model import App, EndUser
-from models.workflow import Workflow
+from factories import file_factory
+from models import Account, App, EndUser, Workflow

 logger = logging.getLogger(__name__)

@@ -63,49 +62,46 @@ class WorkflowAppGenerator(BaseAppGenerator):
         app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
-        args: dict,
+        args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         stream: bool = True,
         call_depth: int = 0,
         workflow_thread_pool_id: Optional[str] = None,
     ):
-        """
-        Generate App response.
+        files: Sequence[Mapping[str, Any]] = args.get("files") or []

-        :param app_model: App
-        :param workflow: Workflow
-        :param user: account or end user
-        :param args: request args
-        :param invoke_from: invoke from source
-        :param stream: is stream
-        :param call_depth: call depth
-        :param workflow_thread_pool_id: workflow thread pool id
-        """
-        inputs = args["inputs"]
+        role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER

         # parse files
-        files = args["files"] if args.get("files") else []
-        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
-        if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
-        else:
-            file_objs = []
+        system_files = file_factory.build_from_mappings(
+            mappings=files,
+            tenant_id=app_model.tenant_id,
+            user_id=user.id,
+            role=role,
+            config=file_extra_config,
+        )

         # convert to app config
-        app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
+        app_config = WorkflowAppConfigManager.get_app_config(
+            app_model=app_model,
+            workflow=workflow,
+        )

         # get tracing instance
-        user_id = user.id if isinstance(user, Account) else user.session_id
-        trace_manager = TraceQueueManager(app_model.id, user_id)
+        trace_manager = TraceQueueManager(
+            app_id=app_model.id,
+            user_id=user.id if isinstance(user, Account) else user.session_id,
+        )

+        inputs: Mapping[str, Any] = args["inputs"]
+        workflow_run_id = str(uuid.uuid4())
         # init application generate entity
         application_generate_entity = WorkflowAppGenerateEntity(
             task_id=str(uuid.uuid4()),
             app_config=app_config,
-            inputs=self._get_cleaned_inputs(inputs, app_config),
-            files=file_objs,
+            inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
+            files=system_files,
             user_id=user.id,
             stream=stream,
             invoke_from=invoke_from,
@@ -1,20 +1,19 @@
 import logging
-import os
 from typing import Optional, cast

+from configs import dify_config
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
-from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
 from core.app.entities.app_invoke_entities import (
     InvokeFrom,
     WorkflowAppGenerateEntity,
 )
-from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.node_entities import UserFrom
+from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
 from core.workflow.workflow_entry import WorkflowEntry
+from enums import UserFrom
 from extensions.ext_database import db
 from models.model import App, EndUser
 from models.workflow import WorkflowType

@@ -71,7 +70,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             db.session.close()

         workflow_callbacks: list[WorkflowCallback] = []
-        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
+        if dify_config.DEBUG:
             workflow_callbacks.append(WorkflowLoggingCallback())

         # if only single iteration run is requested

@@ -1,4 +1,3 @@
-import json
 import logging
 import time
 from collections.abc import Generator

@@ -52,6 +51,7 @@ from models.workflow import (
     Workflow,
     WorkflowAppLog,
     WorkflowAppLogCreatedFrom,
+    WorkflowNodeExecution,
     WorkflowRun,
     WorkflowRunStatus,
 )

@@ -69,6 +69,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
     _task_state: WorkflowTaskState
     _application_generate_entity: WorkflowAppGenerateEntity
     _workflow_system_variables: dict[SystemVariableKey, Any]
+    _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]

     def __init__(
         self,

@@ -103,6 +104,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         }

         self._task_state = WorkflowTaskState()
+        self._wip_workflow_node_executions = {}

     def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
         """

@@ -331,9 +333,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     start_at=graph_runtime_state.start_at,
                     total_tokens=graph_runtime_state.total_tokens,
                     total_steps=graph_runtime_state.node_run_steps,
-                    outputs=json.dumps(event.outputs)
-                    if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
-                    else None,
+                    outputs=event.outputs,
                     conversation_id=None,
                     trace_manager=trace_manager,
                 )

@@ -20,7 +20,6 @@ from core.app.entities.queue_entities import (
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
-from core.workflow.entities.node_entities import NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.event import (
     GraphEngineEvent,

@@ -45,6 +44,7 @@ from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.iteration.entities import IterationNodeData
 from core.workflow.nodes.node_mapping import node_classes
 from core.workflow.workflow_entry import WorkflowEntry
+from enums import NodeType
 from extensions.ext_database import db
 from models.model import App
 from models.workflow import Workflow

@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from enum import Enum
 from typing import Any, Optional


@@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
 from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
 from core.entities.provider_configuration import ProviderModelBundle
-from core.file.file_obj import FileVar
+from core.file.models import File
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.ops.ops_trace_manager import TraceQueueManager

@@ -22,7 +22,7 @@ class InvokeFrom(Enum):
     DEBUGGER = "debugger"

     @classmethod
-    def value_of(cls, value: str) -> "InvokeFrom":
+    def value_of(cls, value: str):
         """
         Get value of given mode.


@@ -81,7 +81,7 @@ class AppGenerateEntity(BaseModel):
     app_config: AppConfig

     inputs: Mapping[str, Any]
-    files: list[FileVar] = []
+    files: Sequence[File]
     user_id: str

     # extras

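The `files: list[FileVar] = []` to `files: Sequence[File]` change above makes the field required and read-only in spirit: no shared mutable default, and tuples are accepted too. A minimal repro of the new declaration with a stand-in `File` model (assuming pydantic v2):

```python
# Sketch only; File is a stand-in for core.file.models.File.
from collections.abc import Sequence

from pydantic import BaseModel


class File(BaseModel):
    id: str


class AppGenerateEntity(BaseModel):
    inputs: dict
    files: Sequence[File]  # required; no `= []` mutable default


e = AppGenerateEntity(inputs={}, files=[File(id="f1")])
print(e.files)  # [File(id='f1')]
```

Callers now have to pass `files` explicitly, which is consistent with the generator hunks above that always build a (possibly empty) list via the file factory.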
@@ -6,8 +6,9 @@ from pydantic import BaseModel, field_validator

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
+from core.workflow.entities.node_entities import NodeRunMetadataKey
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from enums import NodeType


 class QueueEvent(str, Enum):

@@ -1,3 +1,4 @@
+from collections.abc import Mapping, Sequence
 from enum import Enum
 from typing import Any, Optional


@@ -119,6 +120,7 @@ class MessageEndStreamResponse(StreamResponse):
     event: StreamEvent = StreamEvent.MESSAGE_END
     id: str
     metadata: dict = {}
+    files: Optional[Sequence[Mapping[str, Any]]] = None


 class MessageFileStreamResponse(StreamResponse):

@@ -211,7 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
         created_by: Optional[dict] = None
         created_at: int
         finished_at: int
-        files: Optional[list[dict]] = []
+        files: Optional[Sequence[Mapping[str, Any]]] = []

     event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
     workflow_run_id: str

@@ -296,7 +298,7 @@ class NodeFinishStreamResponse(StreamResponse):
         execution_metadata: Optional[dict] = None
         created_at: int
         finished_at: int
-        files: Optional[list[dict]] = []
+        files: Optional[Sequence[Mapping[str, Any]]] = []
         parallel_id: Optional[str] = None
         parallel_start_node_id: Optional[str] = None
         parent_parallel_id: Optional[str] = None

@@ -1,18 +0,0 @@
-import re
-
-from core.workflow.entities.variable_pool import VariablePool
-
-from . import SegmentGroup, factory
-
-VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
-
-
-def convert_template(*, template: str, variable_pool: VariablePool):
-    parts = re.split(VARIABLE_PATTERN, template)
-    segments = []
-    for part in filter(lambda x: x, parts):
-        if "." in part and (value := variable_pool.get(part.split("."))):
-            segments.append(value)
-        else:
-            segments.append(factory.build_segment(part))
-    return SegmentGroup(value=segments)
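The deleted module above matched workflow variable selectors of the form `{{#node_id.variable#}}` (with up to ten dotted segments) and split a template around them. The regex itself is easy to check in isolation:

```python
# The pattern below is copied verbatim from the removed file; the template
# string is a made-up example.
import re

VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")

template = "Hello {{#start.name#}}, score: {{#calc.result.value#}}"
# re.split keeps the captured selector between the literal text chunks.
print(re.split(VARIABLE_PATTERN, template))
# ['Hello ', 'start.name', ', score: ', 'calc.result.value', '']
```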
@@ -1,8 +1,10 @@
 import logging
 from threading import Thread
 from typing import Optional, Union

 from flask import Flask, current_app

+from configs import dify_config
 from core.app.entities.app_invoke_entities import (
     AdvancedChatAppGenerateEntity,
     AgentChatAppGenerateEntity,

@@ -83,7 +85,9 @@ class MessageCycleManage:
                 name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
                 conversation.name = name
             except Exception as e:
-                logging.exception(f"generate conversation name failed: {e}")
+                if dify_config.DEBUG:
+                    logging.exception(f"generate conversation name failed: {e}")
                 pass

             db.session.merge(conversation)
             db.session.commit()

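This hunk and the two runner hunks earlier share one idea: environment sniffing (`os.environ.get("DEBUG", ...)`) and unconditional exception logging are both replaced by a single typed `dify_config.DEBUG` flag. A small stand-in sketch of the gating pattern:

```python
# Sketch only: _Config stands in for `from configs import dify_config`.
import logging


class _Config:
    DEBUG = False


dify_config = _Config()

try:
    raise RuntimeError("LLM naming call failed")
except Exception as e:
    # Naming a conversation is best-effort: swallow the error either way,
    # but only emit the noisy traceback when DEBUG is on.
    if dify_config.DEBUG:
        logging.exception(f"generate conversation name failed: {e}")
```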
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
@ -27,27 +28,25 @@ from core.app.entities.task_entities import (
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.file.file_obj import FileVar
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from enums import CreatedByRole, NodeType, WorkflowRunTriggeredFrom
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
CreatedByRole,
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunTriggeredFrom,
|
||||
)
|
||||
|
||||
|
||||
@ -57,6 +56,7 @@ class WorkflowCycleManage:
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: WorkflowTaskState
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
|
||||
def _handle_workflow_run_start(self) -> WorkflowRun:
|
||||
max_sequence = (
|
||||
@ -116,7 +116,7 @@ class WorkflowCycleManage:
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Optional[str] = None,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowRun:
|
||||
@ -132,8 +132,10 @@ class WorkflowCycleManage:
|
||||
"""
|
||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||
|
||||
outputs = WorkflowEntry.handle_special_values(outputs)
|
||||
|
||||
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
||||
workflow_run.outputs = outputs
|
||||
workflow_run.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
@ -251,6 +253,8 @@ class WorkflowCycleManage:
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
|
||||
self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
||||
@@ -262,21 +266,39 @@ class WorkflowCycleManage:
        workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)

        inputs = WorkflowEntry.handle_special_values(event.inputs)
        process_data = WorkflowEntry.handle_special_values(event.process_data)
        outputs = WorkflowEntry.handle_special_values(event.outputs)
        execution_metadata = (
            json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
        )
        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
        elapsed_time = (finished_at - event.start_at).total_seconds()

        db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
            {
                WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
                WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
                WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
                WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
                WorkflowNodeExecution.execution_metadata: execution_metadata,
                WorkflowNodeExecution.finished_at: finished_at,
                WorkflowNodeExecution.elapsed_time: elapsed_time,
            }
        )

        db.session.commit()
        db.session.close()
        process_data = WorkflowEntry.handle_special_values(event.process_data)

        workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
        workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
        workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
        workflow_node_execution.execution_metadata = (
            json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
        )
        workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
        workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
        workflow_node_execution.execution_metadata = execution_metadata
        workflow_node_execution.finished_at = finished_at
        workflow_node_execution.elapsed_time = elapsed_time

        db.session.commit()
        db.session.refresh(workflow_node_execution)
        db.session.close()
        self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)

        return workflow_node_execution
@@ -289,19 +311,36 @@ class WorkflowCycleManage:
        workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)

        inputs = WorkflowEntry.handle_special_values(event.inputs)
        process_data = WorkflowEntry.handle_special_values(event.process_data)
        outputs = WorkflowEntry.handle_special_values(event.outputs)
        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
        elapsed_time = (finished_at - event.start_at).total_seconds()

        db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
            {
                WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
                WorkflowNodeExecution.error: event.error,
                WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
                WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
                WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
                WorkflowNodeExecution.finished_at: finished_at,
                WorkflowNodeExecution.elapsed_time: elapsed_time,
            }
        )

        db.session.commit()
        db.session.close()
        process_data = WorkflowEntry.handle_special_values(event.process_data)

        workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
        workflow_node_execution.error = event.error
        workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
        workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
        workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
        workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
        workflow_node_execution.finished_at = finished_at
        workflow_node_execution.elapsed_time = elapsed_time

        db.session.commit()
        db.session.refresh(workflow_node_execution)
        db.session.close()
        self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)

        return workflow_node_execution
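Both handlers above replace a chain of ORM attribute writes with a single `query(...).update({...})` keyed by column objects, computing `finished_at` and `elapsed_time` once up front. A minimal runnable sketch of that SQLAlchemy bulk-update shape (the `Execution` model, engine, and values are illustrative stand-ins, not Dify's models):

```python
from datetime import datetime, timezone

from sqlalchemy import Column, DateTime, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Execution(Base):  # stand-in for WorkflowNodeExecution
    __tablename__ = "executions"
    id = Column(String, primary_key=True)
    status = Column(String)
    finished_at = Column(DateTime)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Execution(id="exec-1", status="running"))
    session.commit()

    # One UPDATE statement: a column-keyed dict, no per-attribute flush.
    session.query(Execution).filter(Execution.id == "exec-1").update(
        {
            Execution.status: "succeeded",
            Execution.finished_at: datetime.now(timezone.utc).replace(tzinfo=None),
        }
    )
    session.commit()
```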
@@ -603,7 +642,7 @@ class WorkflowCycleManage:
            ),
        )

    def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
    def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
        """
        Fetch files from node outputs
        :param outputs_dict: node outputs dict
@@ -612,15 +651,15 @@ class WorkflowCycleManage:
        if not outputs_dict:
            return []

        files = []
        for output_var, output_value in outputs_dict.items():
            file_vars = self._fetch_files_from_variable_value(output_value)
            if file_vars:
                files.extend(file_vars)
        files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
        # Remove None
        files = [file for file in files if file]
        # Flatten list
        files = [file for sublist in files for file in sublist]

        return files
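The rewrite above swaps the accumulate-in-a-loop version for three comprehensions: map each output value, drop empty results, then flatten one level. A standalone illustration of that filter-then-flatten pipeline (the `fetch` helper and sample data are invented for the example):

```python
def fetch(value):
    """Stand-in for _fetch_files_from_variable_value: returns a list or None."""
    return value or None


outputs_dict = {"a": [1, 2], "b": None, "c": [3]}

files = [fetch(v) for v in outputs_dict.values()]  # [[1, 2], None, [3]]
files = [f for f in files if f]                    # drop falsy -> [[1, 2], [3]]
files = [f for sublist in files for f in sublist]  # flatten one level -> [1, 2, 3]
assert files == [1, 2, 3]
```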
    def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]:
    def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
        """
        Fetch files from variable value
        :param value: variable value
@@ -632,17 +671,17 @@ class WorkflowCycleManage:
        files = []
        if isinstance(value, list):
            for item in value:
                file_var = self._get_file_var_from_value(item)
                if file_var:
                    files.append(file_var)
                file = self._get_file_var_from_value(item)
                if file:
                    files.append(file)
        elif isinstance(value, dict):
            file_var = self._get_file_var_from_value(value)
            if file_var:
                files.append(file_var)
            file = self._get_file_var_from_value(value)
            if file:
                files.append(file)

        return files
    def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]:
    def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
        """
        Get file var from value
        :param value: variable value
@@ -651,14 +690,11 @@ class WorkflowCycleManage:
        if not value:
            return None

        if isinstance(value, dict):
            if "__variant" in value and value["__variant"] == FileVar.__name__:
                return value
        elif isinstance(value, FileVar):
        if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
            return value
        elif isinstance(value, File):
            return value.to_dict()

        return None
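The detection logic changes from the old `__variant` class-name marker to the `dify_model_identity` constant introduced in `api/core/file/constants.py`, so serialized files are recognized without importing the removed `FileVar` class. A small sketch of that identity-marker check (simplified dicts):

```python
FILE_MODEL_IDENTITY = "__dify__file__"  # value from api/core/file/constants.py


def looks_like_serialized_file(value) -> bool:
    """True when a plain dict carries the File model's identity marker."""
    return isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY


assert looks_like_serialized_file({"dify_model_identity": "__dify__file__", "filename": "a.png"})
assert not looks_like_serialized_file({"filename": "a.png"})
```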
    def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
        """
        Refetch workflow run
@@ -678,17 +714,7 @@ class WorkflowCycleManage:
        :param node_execution_id: workflow node execution id
        :return:
        """
        workflow_node_execution = (
            db.session.query(WorkflowNodeExecution)
            .filter(
                WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
                WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
                WorkflowNodeExecution.workflow_id == self._workflow.id,
                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
                WorkflowNodeExecution.node_execution_id == node_execution_id,
            )
            .first()
        )
        workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id)

        if not workflow_node_execution:
            raise Exception(f"Workflow node execution not found: {node_execution_id}")
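Refetching a node execution now reads the in-memory `_wip_workflow_node_executions` dict (filled when the node starts, popped when it finishes) instead of re-querying the table with a five-condition filter on every event. A cut-down sketch of that register/look-up/pop lifecycle, assuming the same one-run-at-a-time usage (names shortened, records simplified to dicts):

```python
class CycleManager:
    def __init__(self):
        # node_execution_id -> in-flight execution record
        self._wip: dict[str, dict] = {}

    def on_node_start(self, node_execution_id: str) -> dict:
        execution = {"id": node_execution_id, "status": "running"}
        self._wip[node_execution_id] = execution  # register while in flight
        return execution

    def _refetch(self, node_execution_id: str) -> dict:
        execution = self._wip.get(node_execution_id)
        if not execution:
            raise Exception(f"Workflow node execution not found: {node_execution_id}")
        return execution

    def on_node_finish(self, node_execution_id: str) -> dict:
        execution = self._refetch(node_execution_id)
        execution["status"] = "succeeded"
        self._wip.pop(node_execution_id)  # drop once persisted
        return execution


manager = CycleManager()
manager.on_node_start("n-1")
assert manager.on_node_finish("n-1")["status"] == "succeeded"
```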
@@ -5,6 +5,7 @@ from typing import Optional, cast
import numpy as np
from sqlalchemy.exc import IntegrityError

from configs import dify_config
from core.embedding.embedding_constant import EmbeddingInputType
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
@@ -110,6 +111,8 @@ class CacheEmbedding(Embeddings):
            embedding_results = embedding_result.embeddings[0]
            embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
        except Exception as ex:
            if dify_config.DEBUG:
                logging.exception(f"Failed to embed query text: {ex}")
            raise ex

        try:
@@ -122,6 +125,8 @@ class CacheEmbedding(Embeddings):
            encoded_str = encoded_vector.decode("utf-8")
            redis_client.setex(embedding_cache_key, 600, encoded_str)
        except Exception as ex:
            logging.exception("Failed to add embedding to redis %s", ex)
            if dify_config.DEBUG:
                logging.exception("Failed to add embedding to redis %s", ex)
            raise ex

        return embedding_results
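For context, the surrounding code caches each query embedding in Redis for 600 seconds as a base64-encoded vector via `setex`. A self-contained sketch of that cache-aside shape, with an in-memory stand-in for the Redis client and a random normalized vector standing in for the model call:

```python
import base64

import numpy as np


class FakeRedis:
    """In-memory stand-in for the redis client's get/setex used here."""

    def __init__(self):
        self._store: dict[str, str] = {}

    def get(self, key):
        return self._store.get(key)

    def setex(self, key, ttl_seconds, value):
        self._store[key] = value  # TTL ignored in this sketch


redis_client = FakeRedis()


def embed_query(text: str) -> list[float]:
    key = f"embedding:demo-model:{hash(text)}"
    cached = redis_client.get(key)
    if cached is not None:
        return np.frombuffer(base64.b64decode(cached), dtype="float32").tolist()

    vector = np.random.rand(8).astype("float32")  # pretend model call
    vector = vector / np.linalg.norm(vector)      # L2-normalize, as above
    encoded = base64.b64encode(vector.tobytes()).decode("utf-8")
    redis_client.setex(key, 600, encoded)         # cache for 10 minutes
    return vector.tolist()


first = embed_query("hello")
assert embed_query("hello") == first  # second call served from the cache
```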
@@ -1,29 +0,0 @@
import enum
from typing import Any

from pydantic import BaseModel


class PromptMessageFileType(enum.Enum):
    IMAGE = "image"

    @staticmethod
    def value_of(value):
        for member in PromptMessageFileType:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class PromptMessageFile(BaseModel):
    type: PromptMessageFileType
    data: Any = None


class ImagePromptMessageFile(PromptMessageFile):
    class DETAIL(enum.Enum):
        LOW = "low"
        HIGH = "high"

    type: PromptMessageFileType = PromptMessageFileType.IMAGE
    detail: DETAIL = DETAIL.LOW
@@ -0,0 +1,19 @@
from .constants import FILE_MODEL_IDENTITY
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
from .models import (
    File,
    FileExtraConfig,
    ImageConfig,
)

__all__ = [
    "FileType",
    "FileExtraConfig",
    "FileTransferMethod",
    "FileBelongsTo",
    "File",
    "ImageConfig",
    "FileAttribute",
    "ArrayFileAttribute",
    "FILE_MODEL_IDENTITY",
]
1 api/core/file/constants.py Normal file
@@ -0,0 +1 @@
FILE_MODEL_IDENTITY = "__dify__file__"

55 api/core/file/enums.py Normal file
@@ -0,0 +1,55 @@
from enum import Enum


class FileType(str, Enum):
    IMAGE = "image"
    DOCUMENT = "document"
    AUDIO = "audio"
    VIDEO = "video"
    CUSTOM = "custom"

    @staticmethod
    def value_of(value):
        for member in FileType:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileTransferMethod(str, Enum):
    REMOTE_URL = "remote_url"
    LOCAL_FILE = "local_file"
    TOOL_FILE = "tool_file"

    @staticmethod
    def value_of(value):
        for member in FileTransferMethod:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileBelongsTo(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"

    @staticmethod
    def value_of(value):
        for member in FileBelongsTo:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileAttribute(str, Enum):
    TYPE = "type"
    SIZE = "size"
    NAME = "name"
    MIME_TYPE = "mime_type"
    TRANSFER_METHOD = "transfer_method"
    URL = "url"
    EXTENSION = "extension"


class ArrayFileAttribute(str, Enum):
    LENGTH = "length"
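Unlike the removed `enum.Enum` versions, all of these subclass `str`, so members compare equal to their raw values and serialize through `json.dumps` without custom encoders; `value_of` survives mainly for its explicit error message. A quick illustration:

```python
import json
from enum import Enum


class FileType(str, Enum):
    IMAGE = "image"
    DOCUMENT = "document"


# str subclassing: members compare equal to raw strings...
assert FileType.IMAGE == "image"
# ...and json.dumps accepts them with no custom encoder.
assert json.dumps({"type": FileType.IMAGE}) == '{"type": "image"}'
# Lookup by value also works via the standard Enum call syntax.
assert FileType("document") is FileType.DOCUMENT
```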
136 api/core/file/file_manager.py Normal file
@@ -0,0 +1,136 @@
import base64

from configs import dify_config
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import UploadFile

from . import helpers
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
from .tool_file_parser import ToolFileParser


def get_attr(*, file: "File", attr: "FileAttribute"):
    match attr:
        case FileAttribute.TYPE:
            return file.type.value
        case FileAttribute.SIZE:
            return file.size
        case FileAttribute.NAME:
            return file.filename
        case FileAttribute.MIME_TYPE:
            return file.mime_type
        case FileAttribute.TRANSFER_METHOD:
            return file.transfer_method.value
        case FileAttribute.URL:
            return file.remote_url
        case FileAttribute.EXTENSION:
            return file.extension
        case _:
            raise ValueError(f"Invalid file attribute: {attr}")


def to_prompt_message_content(file: "File", /):
    """
    Convert a File object to an ImagePromptMessageContent object.

    This function takes a File object and converts it to an ImagePromptMessageContent
    object, which can be used as a prompt for image-based AI models.

    Args:
        file (File): The File object to convert. Must be of type FileType.IMAGE.

    Returns:
        ImagePromptMessageContent: An object containing the image data and detail level.

    Raises:
        ValueError: If the file is not an image or if the file data is missing.

    Note:
        The detail level of the image prompt is determined by the file's extra_config.
        If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
    """
    if file.type != FileType.IMAGE:
        raise ValueError("Only image file can convert to prompt message content")

    url_or_b64_data = _get_url_or_b64_data(file=file)
    if url_or_b64_data is None:
        raise ValueError("Missing file data")

    # decide the detail of image prompt message content
    if file._extra_config and file._extra_config.image_config and file._extra_config.image_config.detail:
        detail = file._extra_config.image_config.detail
    else:
        detail = ImagePromptMessageContent.DETAIL.LOW

    return ImagePromptMessageContent(data=url_or_b64_data, detail=detail)


def download(*, upload_file_id: str, tenant_id: str):
    upload_file = (
        db.session.query(UploadFile).filter(UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id).first()
    )

    if not upload_file:
        raise ValueError("upload file not found")

    return _download(upload_file.key)


def _download(path: str, /):
    """
    Download and return the contents of a file as bytes.

    This function loads the file from storage and ensures it's in bytes format.

    Args:
        path (str): The path to the file in storage.

    Returns:
        bytes: The contents of the file as a bytes object.

    Raises:
        ValueError: If the loaded file is not a bytes object.
    """
    data = storage.load(path, stream=False)
    if not isinstance(data, bytes):
        raise ValueError(f"file {path} is not a bytes object")
    return data


def _get_base64(*, upload_file_id: str, tenant_id: str) -> str | None:
    upload_file = (
        db.session.query(UploadFile).filter(UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id).first()
    )

    if not upload_file:
        return None

    data = _download(upload_file.key)
    if data is None:
        return None

    encoded_string = base64.b64encode(data).decode("utf-8")
    return f"data:{upload_file.mime_type};base64,{encoded_string}"


def _get_url_or_b64_data(file: "File"):
    if file.type == FileType.IMAGE:
        if file.transfer_method == FileTransferMethod.REMOTE_URL:
            return file.remote_url
        elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
            if file.related_id is None:
                raise ValueError("Missing file related_id")

            if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
                return helpers.get_signed_image_url(upload_file_id=file.related_id)
            return _get_base64(upload_file_id=file.related_id, tenant_id=file.tenant_id)
        elif file.transfer_method == FileTransferMethod.TOOL_FILE:
            # add sign url
            if file.related_id is None or file.extension is None:
                raise ValueError("Missing file related_id or extension")
            return ToolFileParser.get_tool_file_manager().sign_file(
                tool_file_id=file.related_id, extension=file.extension
            )
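`file_manager` replaces `FileVar`'s instance methods with free functions, and `get_attr` dispatches on the `FileAttribute` enum with `match`/`case` (Python 3.10+). A cut-down, runnable sketch of that dispatch style (`DemoFile` is an illustrative stand-in for the real `File` model):

```python
from enum import Enum


class FileAttribute(str, Enum):
    NAME = "name"
    SIZE = "size"


class DemoFile:
    def __init__(self, filename: str, size: int):
        self.filename = filename
        self.size = size


def get_attr(*, file: DemoFile, attr: FileAttribute):
    # Structural pattern matching keeps each attribute case explicit
    # and turns an unknown attribute into a hard error.
    match attr:
        case FileAttribute.NAME:
            return file.filename
        case FileAttribute.SIZE:
            return file.size
        case _:
            raise ValueError(f"Invalid file attribute: {attr}")


f = DemoFile("report.pdf", 2048)
assert get_attr(file=f, attr=FileAttribute.NAME) == "report.pdf"
assert get_attr(file=f, attr=FileAttribute.SIZE) == 2048
```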
@@ -1,145 +0,0 @@
import enum
from typing import Any, Optional

from pydantic import BaseModel

from core.file.tool_file_parser import ToolFileParser
from core.file.upload_file_parser import UploadFileParser
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from extensions.ext_database import db


class FileExtraConfig(BaseModel):
    """
    File Upload Entity.
    """

    image_config: Optional[dict[str, Any]] = None


class FileType(enum.Enum):
    IMAGE = "image"

    @staticmethod
    def value_of(value):
        for member in FileType:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileTransferMethod(enum.Enum):
    REMOTE_URL = "remote_url"
    LOCAL_FILE = "local_file"
    TOOL_FILE = "tool_file"

    @staticmethod
    def value_of(value):
        for member in FileTransferMethod:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileBelongsTo(enum.Enum):
    USER = "user"
    ASSISTANT = "assistant"

    @staticmethod
    def value_of(value):
        for member in FileBelongsTo:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileVar(BaseModel):
    id: Optional[str] = None  # message file id
    tenant_id: str
    type: FileType
    transfer_method: FileTransferMethod
    url: Optional[str] = None  # remote url
    related_id: Optional[str] = None
    extra_config: Optional[FileExtraConfig] = None
    filename: Optional[str] = None
    extension: Optional[str] = None
    mime_type: Optional[str] = None

    def to_dict(self) -> dict:
        return {
            "__variant": self.__class__.__name__,
            "tenant_id": self.tenant_id,
            "type": self.type.value,
            "transfer_method": self.transfer_method.value,
            "url": self.preview_url,
            "remote_url": self.url,
            "related_id": self.related_id,
            "filename": self.filename,
            "extension": self.extension,
            "mime_type": self.mime_type,
        }

    def to_markdown(self) -> str:
        """
        Convert file to markdown
        :return:
        """
        preview_url = self.preview_url
        if self.type == FileType.IMAGE:
            text = f'![{self.filename or ""}]({preview_url})'
        else:
            text = f"[{self.filename or preview_url}]({preview_url})"

        return text

    @property
    def data(self) -> Optional[str]:
        """
        Get image data, file signed url or base64 data
        depending on config MULTIMODAL_SEND_IMAGE_FORMAT
        :return:
        """
        return self._get_data()

    @property
    def preview_url(self) -> Optional[str]:
        """
        Get signed preview url
        :return:
        """
        return self._get_data(force_url=True)

    @property
    def prompt_message_content(self) -> ImagePromptMessageContent:
        if self.type == FileType.IMAGE:
            image_config = self.extra_config.image_config

            return ImagePromptMessageContent(
                data=self.data,
                detail=ImagePromptMessageContent.DETAIL.HIGH
                if image_config.get("detail") == "high"
                else ImagePromptMessageContent.DETAIL.LOW,
            )

    def _get_data(self, force_url: bool = False) -> Optional[str]:
        from models.model import UploadFile

        if self.type == FileType.IMAGE:
            if self.transfer_method == FileTransferMethod.REMOTE_URL:
                return self.url
            elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
                upload_file = (
                    db.session.query(UploadFile)
                    .filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
                    .first()
                )

                return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
            elif self.transfer_method == FileTransferMethod.TOOL_FILE:
                extension = self.extension
                # add sign url
                return ToolFileParser.get_tool_file_manager().sign_file(
                    tool_file_id=self.related_id, extension=extension
                )

        return None
61 api/core/file/helpers.py Normal file
@@ -0,0 +1,61 @@
import base64
import hashlib
import hmac
import os
import time

from configs import dify_config


def get_signed_image_url(upload_file_id: str) -> str:
    url = f"{dify_config.FILES_URL}/files/{upload_file_id}/image-preview"

    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()
    key = dify_config.SECRET_KEY.encode()
    msg = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
    sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
    encoded_sign = base64.urlsafe_b64encode(sign).decode()

    return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"


def get_signed_file_url(upload_file_id: str) -> str:
    url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"

    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()
    key = dify_config.SECRET_KEY.encode()
    msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
    sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
    encoded_sign = base64.urlsafe_b64encode(sign).decode()

    return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"


def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
    secret_key = dify_config.SECRET_KEY.encode()
    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()

    # verify signature
    if sign != recalculated_encoded_sign:
        return False

    current_time = int(time.time())
    return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT


def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
    secret_key = dify_config.SECRET_KEY.encode()
    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()

    # verify signature
    if sign != recalculated_encoded_sign:
        return False

    current_time = int(time.time())
    return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
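The sign/verify pair above follows one scheme: the signature is urlsafe-base64(HMAC-SHA256(secret, "file-preview|id|timestamp|nonce")), verified by recomputing it and checking a freshness window against the timestamp. A runnable round-trip sketch with placeholder secret and timeout (Dify reads both from `dify_config`); the sketch also swaps the plain `!=` comparison for `hmac.compare_digest`, a standard hardening not present in the code above:

```python
import base64
import hashlib
import hmac
import os
import time

SECRET_KEY = b"demo-secret"  # placeholder; Dify uses dify_config.SECRET_KEY
FILES_ACCESS_TIMEOUT = 300   # placeholder window in seconds


def sign(upload_file_id: str) -> tuple[str, str, str]:
    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()
    msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
    digest = hmac.new(SECRET_KEY, msg.encode(), hashlib.sha256).digest()
    return timestamp, nonce, base64.urlsafe_b64encode(digest).decode()


def verify(upload_file_id: str, timestamp: str, nonce: str, signature: str) -> bool:
    msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
    expected = base64.urlsafe_b64encode(
        hmac.new(SECRET_KEY, msg.encode(), hashlib.sha256).digest()
    ).decode()
    if not hmac.compare_digest(signature, expected):  # constant-time compare
        return False
    return int(time.time()) - int(timestamp) <= FILES_ACCESS_TIMEOUT


ts, nonce, sig = sign("file-123")
assert verify("file-123", ts, nonce, sig)
assert not verify("file-456", ts, nonce, sig)  # tampered id fails
```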
@@ -1,243 +0,0 @@
import re
from collections.abc import Mapping, Sequence
from typing import Any, Union
from urllib.parse import parse_qs, urlparse

import requests

from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, MessageFile, UploadFile
from services.file_service import IMAGE_EXTENSIONS


class MessageFileParser:
    def __init__(self, tenant_id: str, app_id: str) -> None:
        self.tenant_id = tenant_id
        self.app_id = app_id

    def validate_and_transform_files_arg(
        self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
    ) -> list[FileVar]:
        """
        validate and transform files arg

        :param files:
        :param file_extra_config:
        :param user:
        :return:
        """
        for file in files:
            if not isinstance(file, dict):
                raise ValueError("Invalid file format, must be dict")
            if not file.get("type"):
                raise ValueError("Missing file type")
            FileType.value_of(file.get("type"))
            if not file.get("transfer_method"):
                raise ValueError("Missing file transfer method")
            FileTransferMethod.value_of(file.get("transfer_method"))
            if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
                if not file.get("url"):
                    raise ValueError("Missing file url")
                if not file.get("url").startswith("http"):
                    raise ValueError("Invalid file url")
            if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
                raise ValueError("Missing file upload_file_id")
            if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
                raise ValueError("Missing file tool_file_id")

        # transform files to file objs
        type_file_objs = self._to_file_objs(files, file_extra_config)

        # validate files
        new_files = []
        for file_type, file_objs in type_file_objs.items():
            if file_type == FileType.IMAGE:
                # parse and validate files
                image_config = file_extra_config.image_config

                # check if image file feature is enabled
                if not image_config:
                    continue

                # Validate number of files
                if len(files) > image_config["number_limits"]:
                    raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")

                for file_obj in file_objs:
                    # Validate transfer method
                    if file_obj.transfer_method.value not in image_config["transfer_methods"]:
                        raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")

                    # Validate file type
                    if file_obj.type != FileType.IMAGE:
                        raise ValueError(f"Invalid file type: {file_obj.type}")

                    if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
                        # check remote url valid and is image
                        result, error = self._check_image_remote_url(file_obj.url)
                        if result is False:
                            raise ValueError(error)
                    elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
                        # get upload file from upload_file_id
                        upload_file = (
                            db.session.query(UploadFile)
                            .filter(
                                UploadFile.id == file_obj.related_id,
                                UploadFile.tenant_id == self.tenant_id,
                                UploadFile.created_by == user.id,
                                UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
                                UploadFile.extension.in_(IMAGE_EXTENSIONS),
                            )
                            .first()
                        )

                        # check upload file is belong to tenant and user
                        if not upload_file:
                            raise ValueError("Invalid upload file")

                    new_files.append(file_obj)

        # return all file objs
        return new_files

    def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
        """
        transform message files

        :param files:
        :param file_extra_config:
        :return:
        """
        # transform files to file objs
        type_file_objs = self._to_file_objs(files, file_extra_config)

        # return all file objs
        return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]

    def _to_file_objs(
        self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
    ) -> dict[FileType, list[FileVar]]:
        """
        transform files to file objs

        :param files:
        :param file_extra_config:
        :return:
        """
        type_file_objs: dict[FileType, list[FileVar]] = {
            # Currently only support image
            FileType.IMAGE: []
        }

        if not files:
            return type_file_objs

        # group by file type and convert file args or message files to FileObj
        for file in files:
            if isinstance(file, MessageFile):
                if file.belongs_to == FileBelongsTo.ASSISTANT.value:
                    continue

            file_obj = self._to_file_obj(file, file_extra_config)
            if file_obj.type not in type_file_objs:
                continue

            type_file_objs[file_obj.type].append(file_obj)

        return type_file_objs

    def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
        """
        transform file to file obj

        :param file:
        :return:
        """
        if isinstance(file, dict):
            transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
            if transfer_method != FileTransferMethod.TOOL_FILE:
                return FileVar(
                    tenant_id=self.tenant_id,
                    type=FileType.value_of(file.get("type")),
                    transfer_method=transfer_method,
                    url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
                    related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
                    extra_config=file_extra_config,
                )
            return FileVar(
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.get("type")),
                transfer_method=transfer_method,
                url=None,
                related_id=file.get("tool_file_id"),
                extra_config=file_extra_config,
            )
        else:
            return FileVar(
                id=file.id,
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.type),
                transfer_method=FileTransferMethod.value_of(file.transfer_method),
                url=file.url,
                related_id=file.upload_file_id or None,
                extra_config=file_extra_config,
            )

    def _check_image_remote_url(self, url):
        try:
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
                " Chrome/91.0.4472.124 Safari/537.36"
            }

            def is_s3_presigned_url(url):
                try:
                    parsed_url = urlparse(url)
                    if "amazonaws.com" not in parsed_url.netloc:
                        return False
                    query_params = parse_qs(parsed_url.query)

                    def check_presign_v2(query_params):
                        required_params = ["Signature", "Expires"]
                        for param in required_params:
                            if param not in query_params:
                                return False
                        if not query_params["Expires"][0].isdigit():
                            return False
                        signature = query_params["Signature"][0]
                        if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
                            return False

                        return True

                    def check_presign_v4(query_params):
                        required_params = ["X-Amz-Signature", "X-Amz-Expires"]
                        for param in required_params:
                            if param not in query_params:
                                return False
                        if not query_params["X-Amz-Expires"][0].isdigit():
                            return False
                        signature = query_params["X-Amz-Signature"][0]
                        if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
                            return False

                        return True

                    return check_presign_v4(query_params) or check_presign_v2(query_params)
                except Exception:
                    return False

            if is_s3_presigned_url(url):
                response = requests.get(url, headers=headers, allow_redirects=True)
                if response.status_code in {200, 304}:
                    return True, ""

            response = requests.head(url, headers=headers, allow_redirects=True)
            if response.status_code in {200, 304}:
                return True, ""
            else:
                return False, "URL does not exist."
        except requests.RequestException as e:
            return False, f"Error checking URL: {e}"
140 api/core/file/models.py Normal file
@@ -0,0 +1,140 @@
from collections.abc import Mapping, Sequence
from typing import Optional

from pydantic import BaseModel, Field, model_validator

from core.model_runtime.entities.message_entities import ImagePromptMessageContent

from . import helpers
from .constants import FILE_MODEL_IDENTITY
from .enums import FileTransferMethod, FileType
from .tool_file_parser import ToolFileParser


class ImageConfig(BaseModel):
    """
    NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
    """

    number_limits: int = 0
    transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
    detail: ImagePromptMessageContent.DETAIL | None = None


class FileExtraConfig(BaseModel):
    """
    File Upload Entity.
    """

    image_config: Optional[ImageConfig] = None
    allowed_file_types: Sequence[FileType] = Field(default_factory=list)
    allowed_extensions: Sequence[str] = Field(default_factory=list)
    allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
    number_limits: int = 0


class File(BaseModel):
    dify_model_identity: str = FILE_MODEL_IDENTITY

    id: Optional[str] = None  # message file id
    tenant_id: str
    type: FileType
    transfer_method: FileTransferMethod
    remote_url: Optional[str] = None  # remote url
    related_id: Optional[str] = None
    filename: Optional[str] = None
    extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
    mime_type: Optional[str] = None
    size: int = -1
    _extra_config: FileExtraConfig | None = None

    def to_dict(self) -> Mapping[str, str | int | None]:
        data = self.model_dump(mode="json")
        return {
            **data,
            "url": self.generate_url(),
        }

    @property
    def markdown(self) -> str:
        url = self.generate_url()
        if self.type == FileType.IMAGE:
            text = f'![{self.filename or ""}]({url})'
        else:
            text = f"[{self.filename or url}]({url})"

        return text

    def generate_url(self) -> Optional[str]:
        if self.type == FileType.IMAGE:
            if self.transfer_method == FileTransferMethod.REMOTE_URL:
                return self.remote_url
            elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
                if self.related_id is None:
                    raise ValueError("Missing file related_id")
                return helpers.get_signed_image_url(upload_file_id=self.related_id)
            elif self.transfer_method == FileTransferMethod.TOOL_FILE:
                assert self.related_id is not None
                assert self.extension is not None
                return ToolFileParser.get_tool_file_manager().sign_file(
                    tool_file_id=self.related_id, extension=self.extension
                )
        else:
            if self.transfer_method == FileTransferMethod.REMOTE_URL:
                return self.remote_url
            elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
                if self.related_id is None:
                    raise ValueError("Missing file related_id")
                return helpers.get_signed_file_url(upload_file_id=self.related_id)
            elif self.transfer_method == FileTransferMethod.TOOL_FILE:
                assert self.related_id is not None
                assert self.extension is not None
                return ToolFileParser.get_tool_file_manager().sign_file(
                    tool_file_id=self.related_id, extension=self.extension
                )

    @model_validator(mode="after")
    def validate_after(self):
        match self.transfer_method:
            case FileTransferMethod.REMOTE_URL:
                if not self.remote_url:
                    raise ValueError("Missing file url")
                if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"):
                    raise ValueError("Invalid file url")
            case FileTransferMethod.LOCAL_FILE:
                if not self.related_id:
                    raise ValueError("Missing file related_id")
            case FileTransferMethod.TOOL_FILE:
                if not self.related_id:
                    raise ValueError("Missing file related_id")

        # Validate the extra config.
        if not self._extra_config:
            return self

        if self._extra_config.allowed_file_types:
            if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM:
                raise ValueError(f"Invalid file type: {self.type}")

        if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions:
            raise ValueError(f"Invalid file extension: {self.extension}")

        if (
            self._extra_config.allowed_upload_methods
            and self.transfer_method not in self._extra_config.allowed_upload_methods
        ):
            raise ValueError(f"Invalid transfer method: {self.transfer_method}")

        match self.type:
            case FileType.IMAGE:
                # NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
                if not self._extra_config.image_config:
                    return self
                # TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
                if (
                    self._extra_config.image_config.transfer_methods
                    and self.transfer_method not in self._extra_config.image_config.transfer_methods
                ):
                    raise ValueError(f"Invalid transfer method: {self.transfer_method}")

        return self
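The `@model_validator(mode="after")` hook runs after all fields are parsed, which is what lets `File` enforce cross-field rules such as transfer method vs. `related_id`/`remote_url` at construction time. A cut-down runnable sketch of that pattern (pydantic v2; `DemoFile` is illustrative, not the real model):

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel, model_validator


class TransferMethod(str, Enum):
    REMOTE_URL = "remote_url"
    LOCAL_FILE = "local_file"


class DemoFile(BaseModel):
    transfer_method: TransferMethod
    remote_url: Optional[str] = None
    related_id: Optional[str] = None

    @model_validator(mode="after")
    def validate_after(self):
        # Cross-field rules run only after every field is parsed.
        match self.transfer_method:
            case TransferMethod.REMOTE_URL:
                if not self.remote_url or not self.remote_url.startswith("http"):
                    raise ValueError("Invalid file url")
            case TransferMethod.LOCAL_FILE:
                if not self.related_id:
                    raise ValueError("Missing file related_id")
        return self


DemoFile(transfer_method=TransferMethod.REMOTE_URL, remote_url="https://example.com/a.png")
try:
    DemoFile(transfer_method=TransferMethod.LOCAL_FILE)
except ValueError:
    pass  # missing related_id is rejected at construction time
```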
@@ -1,4 +1,9 @@
tool_file_manager = {"manager": None}
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from core.tools.tool_file_manager import ToolFileManager

tool_file_manager: dict[str, Any] = {"manager": None}


class ToolFileParser:
@@ -1,79 +0,0 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from typing import Optional

from configs import dify_config
from extensions.ext_storage import storage

IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])


class UploadFileParser:
    @classmethod
    def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
        if not upload_file:
            return None

        if upload_file.extension not in IMAGE_EXTENSIONS:
            return None

        if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
            return cls.get_signed_temp_image_url(upload_file.id)
        else:
            # get image file base64
            try:
                data = storage.load(upload_file.key)
            except FileNotFoundError:
                logging.error(f"File not found: {upload_file.key}")
                return None

            encoded_string = base64.b64encode(data).decode("utf-8")
            return f"data:{upload_file.mime_type};base64,{encoded_string}"

    @classmethod
    def get_signed_temp_image_url(cls, upload_file_id) -> str:
        """
        get signed url from upload file

        :param upload_file: UploadFile object
        :return:
        """
        base_url = dify_config.FILES_URL
        image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"

        timestamp = str(int(time.time()))
        nonce = os.urandom(16).hex()
        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode()
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        encoded_sign = base64.urlsafe_b64encode(sign).decode()

        return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"

    @classmethod
    def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
        """
        verify signature

        :param upload_file_id: file id
        :param timestamp: timestamp
        :param nonce: nonce
        :param sign: signature
        :return:
        """
        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode()
        recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()

        # verify signature
        if sign != recalculated_encoded_sign:
            return False

        current_time = int(time.time())
        return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
@@ -13,8 +13,11 @@ SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "")
SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))

proxies = (
    {"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL}
proxy_mounts = (
    {
        "http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL),
        "https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL),
    }
    if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL
    else None
)
@@ -33,11 +36,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    while retries <= max_retries:
        try:
            if SSRF_PROXY_ALL_URL:
                response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs)
            elif proxies:
                response = httpx.request(method=method, url=url, proxies=proxies, **kwargs)
                with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client:
                    response = client.request(method=method, url=url, **kwargs)
            elif proxy_mounts:
                with httpx.Client(mounts=proxy_mounts) as client:
                    response = client.request(method=method, url=url, **kwargs)
            else:
                response = httpx.request(method=method, url=url, **kwargs)
                with httpx.Client() as client:
                    response = client.request(method=method, url=url, **kwargs)

            if response.status_code not in STATUS_FORCELIST:
                return response
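Newer httpx releases deprecate and then remove the `proxies=` argument, which is why per-scheme proxying moves to `HTTPTransport` instances mounted on a `Client`; the `with` block also guarantees pooled connections are closed. A minimal sketch of the mounts pattern (the proxy URLs are placeholders):

```python
import httpx

SSRF_PROXY_HTTP_URL = "http://localhost:3128"   # placeholder proxy endpoints
SSRF_PROXY_HTTPS_URL = "http://localhost:3129"

# Route each URL scheme through its own transport (and thus its own proxy).
proxy_mounts = {
    "http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL),
    "https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL),
}


def make_request(method: str, url: str, **kwargs) -> httpx.Response:
    # The context manager closes the client (and its connections) on exit.
    with httpx.Client(mounts=proxy_mounts) as client:
        return client.request(method=method, url=url, **kwargs)
```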
@@ -1,18 +1,20 @@
from typing import Optional

from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file.message_file_parser import MessageFileParser
from core.file import FileType, file_manager
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    PromptMessageRole,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowRun

@@ -60,18 +62,17 @@ class TokenBufferMemory:
        thread_messages = extract_thread_messages(messages)

        # for newly created message, its answer is temporarily empty, we don't need to add it to memory
        if thread_messages and not thread_messages[-1].answer:
            thread_messages.pop()
        if thread_messages and not thread_messages[0].answer:
            thread_messages.pop(0)

        messages = list(reversed(thread_messages))

        message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id)
        prompt_messages = []
        for message in messages:
            files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
            if files:
                file_extra_config = None
                if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
                if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
                    file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
                else:
                    if message.workflow_run_id:
@@ -84,17 +85,23 @@ class TokenBufferMemory:
                        workflow_run.workflow.features_dict, is_vision=False
                    )

                if file_extra_config:
                    file_objs = message_file_parser.transform_message_files(files, file_extra_config)
                if file_extra_config and app_record:
                    file_objs = file_factory.build_from_message_files(
                        message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
                    )
                else:
                    file_objs = []

                if not file_objs:
                    prompt_messages.append(UserPromptMessage(content=message.query))
                else:
                    prompt_message_contents = [TextPromptMessageContent(data=message.query)]
                    prompt_message_contents: list[PromptMessageContent] = []
                    prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                    for file_obj in file_objs:
                        prompt_message_contents.append(file_obj.prompt_message_content)
                        if file_obj.type != FileType.IMAGE:
                            continue
                        prompt_message = file_manager.to_prompt_message_content(file_obj)
                        prompt_message_contents.append(prompt_message)

                    prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
            else:
@@ -1,7 +1,7 @@
import logging
import os
from collections.abc import Callable, Generator, Sequence
from typing import IO, Optional, Union, cast
from collections.abc import Callable, Generator, Iterable, Sequence
from typing import IO, Any, Optional, Union, cast

from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@@ -274,7 +274,7 @@ class ModelInstance:
            user=user,
        )

    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
        """
        Invoke large language tts model

@@ -298,7 +298,7 @@ class ModelInstance:
            voice=voice,
        )

    def _round_robin_invoke(self, function: Callable, *args, **kwargs):
    def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
        """
        Round-robin invoke
        :param function: function to invoke
@@ -0,0 +1,36 @@
from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from .message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    PromptMessageRole,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from .model_entities import ModelPropertyKey

__all__ = [
    "ImagePromptMessageContent",
    "PromptMessage",
    "PromptMessageRole",
    "LLMUsage",
    "ModelPropertyKey",
    "AssistantPromptMessage",
    "PromptMessage",
    "PromptMessageContent",
    "PromptMessageRole",
    "SystemPromptMessage",
    "TextPromptMessageContent",
    "UserPromptMessage",
    "PromptMessageTool",
    "ToolPromptMessage",
    "PromptMessageContentType",
    "LLMResult",
    "LLMResultChunk",
    "LLMResultChunkDelta",
]
@@ -79,7 +79,7 @@ class ImagePromptMessageContent(PromptMessageContent):
    Model class for image prompt message content.
    """

    class DETAIL(Enum):
    class DETAIL(str, Enum):
        LOW = "low"
        HIGH = "high"
@@ -1,5 +1,4 @@
import logging
import os
import re
import time
from abc import abstractmethod
@@ -8,6 +7,7 @@ from typing import Optional, Union

from pydantic import ConfigDict

from configs import dify_config
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.callbacks.logging_callback import LoggingCallback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -77,7 +77,7 @@ class LargeLanguageModel(AIModel):

        callbacks = callbacks or []

        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
        if dify_config.DEBUG:
            callbacks.append(LoggingCallback())

        # trigger before invoke callbacks
@@ -1,6 +1,7 @@
import logging
import re
from abc import abstractmethod
from collections.abc import Iterable
from typing import Any, Optional

from pydantic import ConfigDict
@@ -22,8 +23,14 @@ class TTSModel(AIModel):
    model_config = ConfigDict(protected_namespaces=())

    def invoke(
        self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
    ):
        self,
        model: str,
        tenant_id: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: Optional[str] = None,
    ) -> Iterable[bytes]:
        """
        Invoke large language model

@@ -50,8 +57,14 @@ class TTSModel(AIModel):

    @abstractmethod
    def _invoke(
        self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
    ):
        self,
        model: str,
        tenant_id: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: Optional[str] = None,
    ) -> Iterable[bytes]:
        """
        Invoke large language model

@@ -68,25 +81,25 @@ class TTSModel(AIModel):

    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        """
        Get voice for given tts model voices
        Retrieves the list of voices supported by a given text-to-speech (TTS) model.

        :param language: tts language
        :param model: model name
        :param credentials: model credentials
        :return: voices lists
        :param language: The language for which the voices are requested.
        :param model: The name of the TTS model.
        :param credentials: The credentials required to access the TTS model.
        :return: A list of voices supported by the TTS model.
        """
        model_schema = self.get_model_schema(model, credentials)

        if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
            voices = model_schema.model_properties[ModelPropertyKey.VOICES]
            if language:
                return [
                    {"name": d["name"], "value": d["mode"]}
                    for d in voices
                    if language and language in d.get("language")
                ]
            else:
                return [{"name": d["name"], "value": d["mode"]} for d in voices]
        if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
            raise ValueError("this model does not support voice")

        voices = model_schema.model_properties[ModelPropertyKey.VOICES]
        if language:
            return [
                {"name": d["name"], "value": d["mode"]} for d in voices if language and language in d.get("language")
            ]
        else:
            return [{"name": d["name"], "value": d["mode"]} for d in voices]

    def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
        """
@@ -111,8 +124,10 @@ class TTSModel(AIModel):
        """
        model_schema = self.get_model_schema(model, credentials)

        if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
            return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
        if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
            raise ValueError("this model does not support audio type")

        return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]

    def _get_model_word_limit(self, model: str, credentials: dict) -> int:
        """
@@ -121,8 +136,10 @@ class TTSModel(AIModel):
        """
        model_schema = self.get_model_schema(model, credentials)

        if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
            return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
        if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
            raise ValueError("this model does not support word limit")

        return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]

    def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
        """
@@ -131,8 +148,10 @@ class TTSModel(AIModel):
        """
        model_schema = self.get_model_schema(model, credentials)

        if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
            return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
        if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
            raise ValueError("this model does not support max workers")

        return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]

    @staticmethod
    def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
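All four getters above are inverted from "if present, return" into guard clauses that raise when the property is missing, so a misconfigured model fails loudly instead of silently returning `None`. The shape, in isolation:

```python
def _get_property(model_properties: dict, key: str, what: str):
    # Guard clause: reject the bad case up front, then return unconditionally.
    if key not in model_properties:
        raise ValueError(f"this model does not support {what}")
    return model_properties[key]


props = {"word_limit": 2000}
assert _get_property(props, "word_limit", "word limit") == 2000
try:
    _get_property(props, "audio_type", "audio type")
except ValueError as e:
    assert "does not support" in str(e)
```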
@@ -1,6 +1,10 @@
from collections.abc import Generator
from typing import Optional, Union

from zhipuai import ZhipuAI
from zhipuai.types.chat.chat_completion import Completion
from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
@@ -16,9 +20,6 @@ from core.model_runtime.entities.message_entities import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.utils import helper

GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
@@ -1,13 +1,14 @@
import time
from typing import Optional

from zhipuai import ZhipuAI

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI


class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
@@ -1,15 +0,0 @@
from .__version__ import __version__
from ._client import ZhipuAI
from .core import (
    APIAuthenticationError,
    APIConnectionError,
    APIInternalError,
    APIReachLimitError,
    APIRequestFailedError,
    APIResponseError,
    APIResponseValidationError,
    APIServerFlowExceedError,
    APIStatusError,
    APITimeoutError,
    ZhipuAIError,
)
@@ -1 +0,0 @@
__version__ = "v2.1.0"
@@ -1,82 +0,0 @@
from __future__ import annotations

import os
from collections.abc import Mapping
from typing import Union

import httpx
from httpx import Timeout
from typing_extensions import override

from . import api_resource
from .core import NOT_GIVEN, ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient, NotGiven, ZhipuAIError, _jwt_token


class ZhipuAI(HttpClient):
    chat: api_resource.chat.Chat
    api_key: str
    _disable_token_cache: bool = True

    def __init__(
        self,
        *,
        api_key: str | None = None,
        base_url: str | httpx.URL | None = None,
        timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
        max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
        http_client: httpx.Client | None = None,
        custom_headers: Mapping[str, str] | None = None,
        disable_token_cache: bool = True,
        _strict_response_validation: bool = False,
    ) -> None:
        if api_key is None:
            api_key = os.environ.get("ZHIPUAI_API_KEY")
        if api_key is None:
            raise ZhipuAIError("No api_key provided; pass it as an argument or set it via environment variable")
        self.api_key = api_key
        self._disable_token_cache = disable_token_cache

        if base_url is None:
            base_url = os.environ.get("ZHIPUAI_BASE_URL")
        if base_url is None:
            base_url = "https://open.bigmodel.cn/api/paas/v4"
        from .__version__ import __version__

        super().__init__(
            version=__version__,
            base_url=base_url,
            max_retries=max_retries,
            timeout=timeout,
            custom_httpx_client=http_client,
            custom_headers=custom_headers,
            _strict_response_validation=_strict_response_validation,
        )
        self.chat = api_resource.chat.Chat(self)
        self.images = api_resource.images.Images(self)
        self.embeddings = api_resource.embeddings.Embeddings(self)
        self.files = api_resource.files.Files(self)
        self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
        self.batches = api_resource.Batches(self)
        self.knowledge = api_resource.Knowledge(self)
        self.tools = api_resource.Tools(self)
        self.videos = api_resource.Videos(self)
        self.assistant = api_resource.Assistant(self)

    @property
    @override
    def auth_headers(self) -> dict[str, str]:
        api_key = self.api_key
        if self._disable_token_cache:
            return {"Authorization": f"Bearer {api_key}"}
        else:
            return {"Authorization": f"Bearer {_jwt_token.generate_token(api_key)}"}

    def __del__(self) -> None:
        if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"):
            # if the '__init__' method raised an error, self would not have client attr
            return

        if self._has_custom_http_client:
            return

        self.close()
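
One detail worth noting in the deleted client: auth_headers switches between sending the raw key as a bearer token and a short-lived JWT, depending on disable_token_cache. A standalone sketch of that JWT generation, assuming PyJWT and the token layout commonly documented for open.bigmodel.cn ("<id>.<secret>" key format; api_key/exp/timestamp claims in milliseconds) — illustrative, not the vendored _jwt_token module itself:

    import time
    import jwt  # PyJWT; assumption: this mirrors what _jwt_token.generate_token does

    def generate_token(api_key: str, ttl_seconds: int = 180) -> str:
        key_id, secret = api_key.split(".")  # zhipuai keys look like "<id>.<secret>" (assumed)
        now_ms = int(time.time() * 1000)
        payload = {"api_key": key_id, "exp": now_ms + ttl_seconds * 1000, "timestamp": now_ms}
        return jwt.encode(payload, secret, algorithm="HS256", headers={"alg": "HS256", "sign_type": "SIGN"})
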
@@ -1,34 +0,0 @@
from .assistant import (
    Assistant,
)
from .batches import Batches
from .chat import (
    AsyncCompletions,
    Chat,
    Completions,
)
from .embeddings import Embeddings
from .files import Files, FilesWithRawResponse
from .fine_tuning import FineTuning
from .images import Images
from .knowledge import Knowledge
from .tools import Tools
from .videos import (
    Videos,
)

__all__ = [
    "Videos",
    "AsyncCompletions",
    "Chat",
    "Completions",
    "Images",
    "Embeddings",
    "Files",
    "FilesWithRawResponse",
    "FineTuning",
    "Batches",
    "Knowledge",
    "Tools",
    "Assistant",
]
@@ -1,3 +0,0 @@
from .assistant import Assistant

__all__ = ["Assistant"]
@@ -1,122 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import httpx

from ...core import (
    NOT_GIVEN,
    BaseAPI,
    Body,
    Headers,
    NotGiven,
    StreamResponse,
    deepcopy_minimal,
    make_request_options,
    maybe_transform,
)
from ...types.assistant import AssistantCompletion
from ...types.assistant.assistant_conversation_resp import ConversationUsageListResp
from ...types.assistant.assistant_support_resp import AssistantSupportResp

if TYPE_CHECKING:
    from ..._client import ZhipuAI

from ...types.assistant import assistant_conversation_params, assistant_create_params

__all__ = ["Assistant"]


class Assistant(BaseAPI):
    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def conversation(
        self,
        assistant_id: str,
        model: str,
        messages: list[assistant_create_params.ConversationMessage],
        *,
        stream: bool = True,
        conversation_id: Optional[str] = None,
        attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None,
        metadata: dict | None = None,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> StreamResponse[AssistantCompletion]:
        body = deepcopy_minimal(
            {
                "assistant_id": assistant_id,
                "model": model,
                "messages": messages,
                "stream": stream,
                "conversation_id": conversation_id,
                "attachments": attachments,
                "metadata": metadata,
                "request_id": request_id,
                "user_id": user_id,
            }
        )
        return self._post(
            "/assistant",
            body=maybe_transform(body, assistant_create_params.AssistantParameters),
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=AssistantCompletion,
            stream=stream or True,
            stream_cls=StreamResponse[AssistantCompletion],
        )

    def query_support(
        self,
        *,
        assistant_id_list: Optional[list[str]] = None,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> AssistantSupportResp:
        body = deepcopy_minimal(
            {
                "assistant_id_list": assistant_id_list,
                "request_id": request_id,
                "user_id": user_id,
            }
        )
        return self._post(
            "/assistant/list",
            body=body,
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=AssistantSupportResp,
        )

    def query_conversation_usage(
        self,
        assistant_id: str,
        page: int = 1,
        page_size: int = 10,
        *,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> ConversationUsageListResp:
        body = deepcopy_minimal(
            {
                "assistant_id": assistant_id,
                "page": page,
                "page_size": page_size,
                "request_id": request_id,
                "user_id": user_id,
            }
        )
        return self._post(
            "/assistant/conversation/list",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            body=maybe_transform(body, assistant_conversation_params.ConversationParameters),
            cast_type=ConversationUsageListResp,
        )
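
Since conversation always posts with stream=stream or True, it returns a StreamResponse to iterate. A hedged sketch of consuming it through the official zhipuai package (assistant_id, model name, and the message content shape are assumptions for illustration):

    from zhipuai import ZhipuAI

    client = ZhipuAI(api_key="YOUR_API_KEY")  # placeholder
    stream = client.assistant.conversation(
        assistant_id="ASSISTANT_ID",  # placeholder
        model="glm-4-assistant",      # illustrative model name
        messages=[{"role": "user", "content": [{"type": "text", "text": "Hi"}]}],  # assumed shape
        stream=True,
    )
    for event in stream:  # each event is an AssistantCompletion chunk
        print(event)
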
@@ -1,146 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, Optional

import httpx

from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options, maybe_transform
from ..core.pagination import SyncCursorPage
from ..types import batch_create_params, batch_list_params
from ..types.batch import Batch

if TYPE_CHECKING:
    from .._client import ZhipuAI


class Batches(BaseAPI):
    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def create(
        self,
        *,
        completion_window: str | None = None,
        endpoint: Literal["/v1/chat/completions", "/v1/embeddings"],
        input_file_id: str,
        metadata: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
        auto_delete_input_file: bool = True,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Batch:
        return self._post(
            "/batches",
            body=maybe_transform(
                {
                    "completion_window": completion_window,
                    "endpoint": endpoint,
                    "input_file_id": input_file_id,
                    "metadata": metadata,
                    "auto_delete_input_file": auto_delete_input_file,
                },
                batch_create_params.BatchCreateParams,
            ),
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=Batch,
        )

    def retrieve(
        self,
        batch_id: str,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Batch:
        """
        Retrieves a batch.

        Args:
          extra_headers: Send extra headers

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds
        """
        if not batch_id:
            raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
        return self._get(
            f"/batches/{batch_id}",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=Batch,
        )

    def list(
        self,
        *,
        after: str | NotGiven = NOT_GIVEN,
        limit: int | NotGiven = NOT_GIVEN,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> SyncCursorPage[Batch]:
        """List your organization's batches.

        Args:
          after: A cursor for use in pagination.

              `after` is an object ID that defines your place
              in the list. For instance, if you make a list request and receive 100 objects,
              ending with obj_foo, your subsequent call can include after=obj_foo in order to
              fetch the next page of the list.

          limit: A limit on the number of objects to be returned. Limit can range between 1 and
              100, and the default is 20.

          extra_headers: Send extra headers

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds
        """
        return self._get_api_list(
            "/batches",
            page=SyncCursorPage[Batch],
            options=make_request_options(
                extra_headers=extra_headers,
                extra_body=extra_body,
                timeout=timeout,
                query=maybe_transform(
                    {
                        "after": after,
                        "limit": limit,
                    },
                    batch_list_params.BatchListParams,
                ),
            ),
            model=Batch,
        )

    def cancel(
        self,
        batch_id: str,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Batch:
        """
        Cancels an in-progress batch.

        Args:
          batch_id: The ID of the batch to cancel.
          extra_headers: Send extra headers

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds

        """
        if not batch_id:
            raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
        return self._post(
            f"/batches/{batch_id}/cancel",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=Batch,
        )
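
SyncCursorPage drives cursor pagination through the after/limit query parameters documented above. A hedged sketch of walking every batch by hand, assuming each page exposes a .data list and the last object's id is the next cursor (as the docstring describes):

    from zhipuai import ZhipuAI

    client = ZhipuAI(api_key="YOUR_API_KEY")  # placeholder
    after = None
    while True:
        kwargs = {"after": after} if after else {}
        page = client.batches.list(limit=100, **kwargs)
        if not page.data:  # assumption: pages expose .data
            break
        for batch in page.data:
            print(batch.id, batch.status)
        after = page.data[-1].id  # cursor = last object ID, per the docstring
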
@@ -1,5 +0,0 @@
from .async_completions import AsyncCompletions
from .chat import Chat
from .completions import Completions

__all__ = ["AsyncCompletions", "Chat", "Completions"]
@@ -1,115 +0,0 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Literal, Optional, Union

import httpx

from ...core import (
    NOT_GIVEN,
    BaseAPI,
    Body,
    Headers,
    NotGiven,
    drop_prefix_image_data,
    make_request_options,
    maybe_transform,
)
from ...types.chat.async_chat_completion import AsyncCompletion, AsyncTaskStatus
from ...types.chat.code_geex import code_geex_params
from ...types.sensitive_word_check import SensitiveWordCheckRequest

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._client import ZhipuAI


class AsyncCompletions(BaseAPI):
    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def create(
        self,
        *,
        model: str,
        request_id: Optional[str] | NotGiven = NOT_GIVEN,
        user_id: Optional[str] | NotGiven = NOT_GIVEN,
        do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
        temperature: Optional[float] | NotGiven = NOT_GIVEN,
        top_p: Optional[float] | NotGiven = NOT_GIVEN,
        max_tokens: int | NotGiven = NOT_GIVEN,
        seed: int | NotGiven = NOT_GIVEN,
        messages: Union[str, list[str], list[int], list[list[int]], None],
        stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
        sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
        tools: Optional[object] | NotGiven = NOT_GIVEN,
        tool_choice: str | NotGiven = NOT_GIVEN,
        meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
        extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> AsyncTaskStatus:
        _cast_type = AsyncTaskStatus
        logger.debug(f"temperature:{temperature}, top_p:{top_p}")
        if temperature is not None and temperature != NOT_GIVEN:
            if temperature <= 0:
                do_sample = False
                temperature = 0.01
                # logger.warning("temperature: valid range is the open interval (0.0, 1.0); do_sample rewritten to false (top_p and temperature take no effect)")  # noqa: E501
            if temperature >= 1:
                temperature = 0.99
                # logger.warning("temperature: valid range is the open interval (0.0, 1.0)")
        if top_p is not None and top_p != NOT_GIVEN:
            if top_p >= 1:
                top_p = 0.99
                # logger.warning("top_p: valid range is the open interval (0.0, 1.0); cannot equal 0 or 1")
            if top_p <= 0:
                top_p = 0.01
                # logger.warning("top_p: valid range is the open interval (0.0, 1.0); cannot equal 0 or 1")

        logger.debug(f"temperature:{temperature}, top_p:{top_p}")
        if isinstance(messages, list):
            for item in messages:
                if item.get("content"):
                    item["content"] = drop_prefix_image_data(item["content"])

        body = {
            "model": model,
            "request_id": request_id,
            "user_id": user_id,
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": do_sample,
            "max_tokens": max_tokens,
            "seed": seed,
            "messages": messages,
            "stop": stop,
            "sensitive_word_check": sensitive_word_check,
            "tools": tools,
            "tool_choice": tool_choice,
            "meta": meta,
            "extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
        }
        return self._post(
            "/async/chat/completions",
            body=body,
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=_cast_type,
            stream=False,
        )

    def retrieve_completion_result(
        self,
        id: str,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Union[AsyncCompletion, AsyncTaskStatus]:
        _cast_type = Union[AsyncCompletion, AsyncTaskStatus]
        return self._get(
            path=f"/async-result/{id}",
            cast_type=_cast_type,
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
        )
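
create returns an AsyncTaskStatus immediately; the finished completion has to be polled via retrieve_completion_result. A hedged sketch of that poll loop (the "PROCESSING" status string is an assumption about the API's task_status values):

    import time

    from zhipuai import ZhipuAI

    client = ZhipuAI(api_key="YOUR_API_KEY")  # placeholder
    task = client.chat.asyncCompletions.create(
        model="glm-4",  # illustrative model name
        messages=[{"role": "user", "content": "Hello"}],
    )
    while True:
        result = client.chat.asyncCompletions.retrieve_completion_result(id=task.id)
        if getattr(result, "task_status", None) != "PROCESSING":  # assumed status value
            break
        time.sleep(2)
    print(result)
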
@@ -1,18 +0,0 @@
from typing import TYPE_CHECKING

from ...core import BaseAPI, cached_property
from .async_completions import AsyncCompletions
from .completions import Completions

if TYPE_CHECKING:
    pass


class Chat(BaseAPI):
    @cached_property
    def completions(self) -> Completions:
        return Completions(self._client)

    @cached_property
    def asyncCompletions(self) -> AsyncCompletions:  # noqa: N802
        return AsyncCompletions(self._client)
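
Chat builds its sub-resources lazily through cached_property, so Completions and AsyncCompletions are constructed once, on first access, and then reused. A tiny standalone illustration of that semantics using functools.cached_property (the class here is a stand-in, not the SDK's):

    from functools import cached_property

    class Chat:
        def __init__(self, client):
            self._client = client

        @cached_property
        def completions(self):
            print("constructed once")
            return object()  # stand-in for Completions(self._client)

    chat = Chat(client=None)
    chat.completions  # prints "constructed once"
    chat.completions  # cached; the factory does not run again
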
@@ -1,108 +0,0 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Literal, Optional, Union

import httpx

from ...core import (
    NOT_GIVEN,
    BaseAPI,
    Body,
    Headers,
    NotGiven,
    StreamResponse,
    deepcopy_minimal,
    drop_prefix_image_data,
    make_request_options,
    maybe_transform,
)
from ...types.chat.chat_completion import Completion
from ...types.chat.chat_completion_chunk import ChatCompletionChunk
from ...types.chat.code_geex import code_geex_params
from ...types.sensitive_word_check import SensitiveWordCheckRequest

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._client import ZhipuAI


class Completions(BaseAPI):
    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def create(
        self,
        *,
        model: str,
        request_id: Optional[str] | NotGiven = NOT_GIVEN,
        user_id: Optional[str] | NotGiven = NOT_GIVEN,
        do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
        stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
        temperature: Optional[float] | NotGiven = NOT_GIVEN,
        top_p: Optional[float] | NotGiven = NOT_GIVEN,
        max_tokens: int | NotGiven = NOT_GIVEN,
        seed: int | NotGiven = NOT_GIVEN,
        messages: Union[str, list[str], list[int], object, None],
        stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN,
        sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
        tools: Optional[object] | NotGiven = NOT_GIVEN,
        tool_choice: str | NotGiven = NOT_GIVEN,
        meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
        extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Completion | StreamResponse[ChatCompletionChunk]:
        logger.debug(f"temperature:{temperature}, top_p:{top_p}")
        if temperature is not None and temperature != NOT_GIVEN:
            if temperature <= 0:
                do_sample = False
                temperature = 0.01
                # logger.warning("temperature: valid range is the open interval (0.0, 1.0); do_sample rewritten to false (top_p and temperature take no effect)")  # noqa: E501
            if temperature >= 1:
                temperature = 0.99
                # logger.warning("temperature: valid range is the open interval (0.0, 1.0)")
        if top_p is not None and top_p != NOT_GIVEN:
            if top_p >= 1:
                top_p = 0.99
                # logger.warning("top_p: valid range is the open interval (0.0, 1.0); cannot equal 0 or 1")
            if top_p <= 0:
                top_p = 0.01
                # logger.warning("top_p: valid range is the open interval (0.0, 1.0); cannot equal 0 or 1")

        logger.debug(f"temperature:{temperature}, top_p:{top_p}")
        if isinstance(messages, list):
            for item in messages:
                if item.get("content"):
                    item["content"] = drop_prefix_image_data(item["content"])

        body = deepcopy_minimal(
            {
                "model": model,
                "request_id": request_id,
                "user_id": user_id,
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": do_sample,
                "max_tokens": max_tokens,
                "seed": seed,
                "messages": messages,
                "stop": stop,
                "sensitive_word_check": sensitive_word_check,
                "stream": stream,
                "tools": tools,
                "tool_choice": tool_choice,
                "meta": meta,
                "extra": maybe_transform(extra, code_geex_params.CodeGeexExtra),
            }
        )
        return self._post(
            "/chat/completions",
            body=body,
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=Completion,
            stream=stream or False,
            stream_cls=StreamResponse[ChatCompletionChunk],
        )
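
Both create methods clamp temperature and top_p into the open interval (0.0, 1.0) before sending, and force do_sample off when temperature hits zero. A standalone restatement of that clamping rule, behaviorally equivalent to the checks above (function name is illustrative):

    def clamp_sampling(temperature, top_p, do_sample):
        # Mirrors the deleted SDK's checks: values are forced into (0.0, 1.0).
        if temperature is not None:
            if temperature <= 0:
                do_sample, temperature = False, 0.01
            elif temperature >= 1:
                temperature = 0.99
        if top_p is not None:
            if top_p >= 1:
                top_p = 0.99
            elif top_p <= 0:
                top_p = 0.01
        return temperature, top_p, do_sample

    assert clamp_sampling(0.0, 1.0, True) == (0.01, 0.99, False)
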
@@ -1,50 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Union

import httpx

from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options
from ..types.embeddings import EmbeddingsResponded

if TYPE_CHECKING:
    from .._client import ZhipuAI


class Embeddings(BaseAPI):
    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def create(
        self,
        *,
        input: Union[str, list[str], list[int], list[list[int]]],
        model: Union[str],
        dimensions: Union[int] | NotGiven = NOT_GIVEN,
        encoding_format: str | NotGiven = NOT_GIVEN,
        user: str | NotGiven = NOT_GIVEN,
        request_id: Optional[str] | NotGiven = NOT_GIVEN,
        sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        disable_strict_validation: Optional[bool] | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> EmbeddingsResponded:
        _cast_type = EmbeddingsResponded
        if disable_strict_validation:
            _cast_type = object
        return self._post(
            "/embeddings",
            body={
                "input": input,
                "model": model,
                "dimensions": dimensions,
                "encoding_format": encoding_format,
                "user": user,
                "request_id": request_id,
                "sensitive_word_check": sensitive_word_check,
            },
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=_cast_type,
            stream=False,
        )
@@ -1,194 +0,0 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import TYPE_CHECKING, Literal, Optional, cast

import httpx

from ..core import (
    NOT_GIVEN,
    BaseAPI,
    Body,
    FileTypes,
    Headers,
    NotGiven,
    _legacy_binary_response,
    _legacy_response,
    deepcopy_minimal,
    extract_files,
    make_request_options,
    maybe_transform,
)
from ..types.files import FileDeleted, FileObject, ListOfFileObject, UploadDetail, file_create_params

if TYPE_CHECKING:
    from .._client import ZhipuAI

__all__ = ["Files", "FilesWithRawResponse"]


class Files(BaseAPI):
    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def create(
        self,
        *,
        file: Optional[FileTypes] = None,
        upload_detail: Optional[list[UploadDetail]] = None,
        purpose: Literal["fine-tune", "retrieval", "batch"],
        knowledge_id: Optional[str] = None,
        sentence_size: Optional[int] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> FileObject:
        if not file and not upload_detail:
            raise ValueError("At least one of `file` and `upload_detail` must be provided.")
        body = deepcopy_minimal(
            {
                "file": file,
                "upload_detail": upload_detail,
                "purpose": purpose,
                "knowledge_id": knowledge_id,
                "sentence_size": sentence_size,
            }
        )
        files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
        if files:
            # It should be noted that the actual Content-Type header that will be
            # sent to the server will contain a `boundary` parameter, e.g.
            # multipart/form-data; boundary=---abc--
            extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
        return self._post(
            "/files",
            body=maybe_transform(body, file_create_params.FileCreateParams),
            files=files,
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=FileObject,
        )

    # def retrieve(
    #     self,
    #     file_id: str,
    #     *,
    #     extra_headers: Headers | None = None,
    #     extra_body: Body | None = None,
    #     timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    # ) -> FileObject:
    #     """
    #     Returns information about a specific file.
    #
    #     Args:
    #       file_id: The ID of the file to retrieve information about
    #       extra_headers: Send extra headers
    #
    #       extra_body: Add additional JSON properties to the request
    #
    #       timeout: Override the client-level default timeout for this request, in seconds
    #     """
    #     if not file_id:
    #         raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
    #     return self._get(
    #         f"/files/{file_id}",
    #         options=make_request_options(
    #             extra_headers=extra_headers, extra_body=extra_body, timeout=timeout
    #         ),
    #         cast_type=FileObject,
    #     )

    def list(
        self,
        *,
        purpose: str | NotGiven = NOT_GIVEN,
        limit: int | NotGiven = NOT_GIVEN,
        after: str | NotGiven = NOT_GIVEN,
        order: str | NotGiven = NOT_GIVEN,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> ListOfFileObject:
        return self._get(
            "/files",
            cast_type=ListOfFileObject,
            options=make_request_options(
                extra_headers=extra_headers,
                extra_body=extra_body,
                timeout=timeout,
                query={
                    "purpose": purpose,
                    "limit": limit,
                    "after": after,
                    "order": order,
                },
            ),
        )

    def delete(
        self,
        file_id: str,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> FileDeleted:
        """
        Delete a file.

        Args:
          file_id: The ID of the file to delete
          extra_headers: Send extra headers

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds
        """
        if not file_id:
            raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
        return self._delete(
            f"/files/{file_id}",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=FileDeleted,
        )

    def content(
        self,
        file_id: str,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> _legacy_response.HttpxBinaryResponseContent:
        """
        Returns the contents of the specified file.

        Args:
          extra_headers: Send extra headers

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds
        """
        if not file_id:
            raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}")
        extra_headers = {"Accept": "application/binary", **(extra_headers or {})}
        return self._get(
            f"/files/{file_id}/content",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=_legacy_binary_response.HttpxBinaryResponseContent,
        )


class FilesWithRawResponse:
    def __init__(self, files: Files) -> None:
        self._files = files

        self.create = _legacy_response.to_raw_response_wrapper(
            files.create,
        )
        self.list = _legacy_response.to_raw_response_wrapper(
            files.list,
        )
        self.content = _legacy_response.to_raw_response_wrapper(
            files.content,
        )
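
For reference, the files and batches resources above combine in the usual batch workflow: upload a JSONL request file with purpose="batch", then submit it. A hedged sketch against the official zhipuai package (the file path is illustrative, and the "24h" completion window is an assumption mirroring OpenAI-style batch APIs):

    from zhipuai import ZhipuAI

    client = ZhipuAI(api_key="YOUR_API_KEY")  # placeholder
    uploaded = client.files.create(file=open("requests.jsonl", "rb"), purpose="batch")
    batch = client.batches.create(
        input_file_id=uploaded.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",  # assumed value
    )
    print(batch.id)
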
Some files were not shown because too many files have changed in this diff.