Compare commits


14 Commits

SHA1 Message Date
37f7d5732a external knowledge api 2024-09-18 15:29:30 +08:00
dcb033d221 Merge branch 'main' into feat/external-knowledge
# Conflicts:
#	api/core/rag/datasource/retrieval_service.py
#	api/models/dataset.py
#	api/services/dataset_service.py
2024-09-18 14:40:43 +08:00
9f894bb3b3 external knowledge api 2024-09-18 14:36:51 +08:00
89e81873c4 merge error 2024-09-13 09:49:24 +08:00
9ca0e56a8a external dataset binding 2024-09-11 16:59:19 +08:00
e7c77d961b Merge branch 'main' into feat/external-knowledge
# Conflicts:
#	api/controllers/console/auth/data_source_oauth.py
2024-09-09 15:54:43 +08:00
a63e15081f update nltk version 2024-08-23 16:43:47 +08:00
0724640bbb fix rerank mode is none 2024-08-22 15:36:47 +08:00
cb70e12827 fix rerank mode is none 2024-08-22 15:33:43 +08:00
067b956b2c merge migration 2024-08-21 16:25:18 +08:00
e7762b731c external knowledge 2024-08-20 16:18:35 +08:00
f6c8390b0b external knowledge 2024-08-20 12:47:51 +08:00
4fd57929df Merge branch 'main' into feat/external-knowledge 2024-08-20 12:46:37 +08:00
517cdb2ca4 add external knowledge 2024-08-20 11:13:29 +08:00
1469 changed files with 13704 additions and 45506 deletions

View File

@@ -125,7 +125,7 @@ jobs:
with:
images: ${{ env[matrix.image_name_env] }}
tags: |
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-') }}
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}

View File

@@ -1,46 +0,0 @@
name: Web Tests
on:
pull_request:
branches:
- main
paths:
- web/**
concurrency:
group: web-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: Web Tests
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./web
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
with:
files: web/**
- name: Setup Node.js
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 20
cache: yarn
cache-dependency-path: ./web/package.json
- name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn install --frozen-lockfile
- name: Run tests
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn test

View File

@@ -36,7 +36,7 @@
| Features flagged as high priority by a team member | High Priority |
| Popular feature requests from the [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium Priority |
| Non-core features and minor improvements | Low Priority |
| Valuable but not urgent | Future Feature |
| Valuable but not urgent | Future Feature |
### Anything else (e.g. bug reports, performance optimizations, typo fixes):
* Start coding right away.
@@ -138,7 +138,7 @@ Dify's backend is written in Python and uses [Flask](https://flask.palletsproject
├── models // Describes data models and the shapes of API responses
├── public // Meta resources such as the favicon
├── service // Defines the shapes of API actions
├── test
├── test
├── types // Descriptions of function parameters and return values
└── utils // Shared utility functions
```

View File

@@ -162,8 +162,6 @@ PGVECTOR_PORT=5433
PGVECTOR_USER=postgres
PGVECTOR_PASSWORD=postgres
PGVECTOR_DATABASE=postgres
PGVECTOR_MIN_CONNECTION=1
PGVECTOR_MAX_CONNECTION=5
# Tidb Vector configuration
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
@@ -201,8 +199,6 @@ OPENSEARCH_SECURE=true
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Model Configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64
@@ -277,7 +273,6 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
MAX_VARIABLE_SIZE=204800
# App configuration
APP_MAX_EXECUTION_TIME=1200

View File

@@ -1,15 +1,8 @@
{
"version": "0.2.0",
"compounds": [
{
"name": "Launch Flask and Celery",
"configurations": ["Python: Flask", "Python: Celery"]
}
],
"configurations": [
{
"name": "Python: Flask",
"consoleName": "Flask",
"type": "debugpy",
"request": "launch",
"python": "${workspaceFolder}/.venv/bin/python",
@@ -24,12 +17,12 @@
},
"args": [
"run",
"--host=0.0.0.0",
"--port=5001"
]
},
{
"name": "Python: Celery",
"consoleName": "Celery",
"type": "debugpy",
"request": "launch",
"python": "${workspaceFolder}/.venv/bin/python",
@@ -52,10 +45,10 @@
"-c",
"1",
"--loglevel",
"DEBUG",
"info",
"-Q",
"dataset,generation,mail,ops_trace,app_deletion"
]
}
},
]
}
}

View File

@@ -65,12 +65,14 @@
8. Start Dify [web](../web) service.
9. Setup your application by visiting `http://localhost:3000`...
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
10. If you need to debug local async processing, please start the worker service.
```bash
poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
```
The started celery app handles the async tasks, e.g. dataset importing and documents indexing.
## Testing
1. Install dependencies for both the backend and the test environment

View File

@@ -53,9 +53,11 @@ from services.account_service import AccountService
warnings.simplefilter("ignore", ResourceWarning)
os.environ["TZ"] = "UTC"
# windows platform not support tzset
if hasattr(time, "tzset"):
# fix windows platform
if os.name == "nt":
os.system('tzutil /s "UTC"')
else:
os.environ["TZ"] = "UTC"
time.tzset()
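For context, the cross-platform pattern this hunk introduces reads cleanly as a standalone snippet; a minimal sketch (assuming the `tzutil` CLI is available on Windows, as the code above implies):

```python
import os
import time

# Force the process timezone to UTC on any platform.
if os.name == "nt":
    # Windows has no time.tzset(); fall back to the tzutil CLI.
    os.system('tzutil /s "UTC"')
else:
    os.environ["TZ"] = "UTC"
    time.tzset()  # re-read TZ so subsequent time calls use UTC
```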
@@ -117,7 +119,7 @@ def create_app() -> Flask:
logging.basicConfig(
level=app.config.get("LOG_LEVEL"),
format=app.config["LOG_FORMAT"],
format=app.config.get("LOG_FORMAT"),
datefmt=app.config.get("LOG_DATEFORMAT"),
handlers=log_handlers,
force=True,
@@ -134,7 +136,6 @@ def create_app() -> Flask:
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
for handler in logging.root.handlers:
assert handler.formatter
handler.formatter.converter = time_converter
initialize_extensions(app)
register_blueprints(app)

View File

@@ -19,7 +19,7 @@ from extensions.ext_redis import redis_client
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
from models import Tenant
from models.account import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
@@ -28,28 +28,28 @@ from services.account_service import RegisterService, TenantService
@click.command("reset-password", help="Reset the account password.")
@click.option("--email", prompt=True, help="Account email to reset password for")
@click.option("--new-password", prompt=True, help="New password")
@click.option("--password-confirm", prompt=True, help="Confirm new password")
@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset")
@click.option("--new-password", prompt=True, help="the new password.")
@click.option("--password-confirm", prompt=True, help="the new password confirm.")
def reset_password(email, new_password, password_confirm):
"""
Reset password of owner account
Only available in SELF_HOSTED mode
"""
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
click.echo(click.style("sorry. The two passwords do not match.", fg="red"))
return
account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(click.style("Invalid password. Must match {}".format(password_pattern), fg="red"))
click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red"))
return
# generate password salt
@@ -62,37 +62,37 @@ def reset_password(email, new_password, password_confirm):
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
click.echo(click.style("Password reset successfully.", fg="green"))
click.echo(click.style("Congratulations! Password has been reset.", fg="green"))
@click.command("reset-email", help="Reset the account email.")
@click.option("--email", prompt=True, help="Current account email")
@click.option("--new-email", prompt=True, help="New email")
@click.option("--email-confirm", prompt=True, help="Confirm new email")
@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset")
@click.option("--new-email", prompt=True, help="the new email.")
@click.option("--email-confirm", prompt=True, help="the new email confirm.")
def reset_email(email, new_email, email_confirm):
"""
Replace account email
:return:
"""
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red"))
return
account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red"))
return
try:
email_validate(new_email)
except:
click.echo(click.style("Invalid email: {}".format(new_email), fg="red"))
click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red"))
return
account.email = new_email
db.session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
click.echo(click.style("Congratulations!, email has been reset.", fg="green"))
@click.command(
@@ -104,7 +104,7 @@ def reset_email(email, new_email, email_confirm):
)
@click.confirmation_option(
prompt=click.style(
"Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red"
"Are you sure you want to reset encrypt key pair? this operation cannot be rolled back!", fg="red"
)
)
def reset_encrypt_key_pair():
@@ -114,13 +114,13 @@ def reset_encrypt_key_pair():
Only support SELF_HOSTED mode.
"""
if dify_config.EDITION != "SELF_HOSTED":
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red"))
return
tenants = db.session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red"))
return
tenant.encrypt_public_key = generate_key_pair(tenant.id)
@@ -137,7 +137,7 @@ def reset_encrypt_key_pair():
)
@click.command("vdb-migrate", help="Migrate vector db.")
@click.command("vdb-migrate", help="migrate vector db.")
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
def vdb_migrate(scope: str):
if scope in {"knowledge", "all"}:
@@ -150,7 +150,7 @@ def migrate_annotation_vector_database():
"""
Migrate annotation datas to target vector database .
"""
click.echo(click.style("Starting annotation data migration.", fg="green"))
click.echo(click.style("Start migrate annotation data.", fg="green"))
create_count = 0
skipped_count = 0
total_count = 0
@@ -174,14 +174,14 @@ def migrate_annotation_vector_database():
f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
)
try:
click.echo("Creating app annotation index: {}".format(app.id))
click.echo("Create app annotation index: {}".format(app.id))
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
)
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo("App annotation setting disabled: {}".format(app.id))
click.echo("App annotation setting is disabled: {}".format(app.id))
continue
# get dataset_collection_binding info
dataset_collection_binding = (
@@ -190,7 +190,7 @@ def migrate_annotation_vector_database():
.first()
)
if not dataset_collection_binding:
click.echo("App annotation collection binding not found: {}".format(app.id))
click.echo("App annotation collection binding is not exist: {}".format(app.id))
continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
dataset = Dataset(
@@ -211,11 +211,11 @@ def migrate_annotation_vector_database():
documents.append(document)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
click.echo(f"Migrating annotations for app: {app.id}.")
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
try:
vector.delete()
click.echo(click.style(f"Deleted vector index for app {app.id}.", fg="green"))
click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green"))
except Exception as e:
click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
raise e
@@ -223,12 +223,12 @@ def migrate_annotation_vector_database():
try:
click.echo(
click.style(
f"Creating vector index with {len(documents)} annotations for app {app.id}.",
f"Start to created vector index with {len(documents)} annotations for app {app.id}.",
fg="green",
)
)
vector.create(documents)
click.echo(click.style(f"Created vector index for app {app.id}.", fg="green"))
click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green"))
except Exception as e:
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
raise e
@@ -237,14 +237,14 @@ def migrate_annotation_vector_database():
except Exception as e:
click.echo(
click.style(
"Error creating app annotation index: {} {}".format(e.__class__.__name__, str(e)), fg="red"
"Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red"
)
)
continue
click.echo(
click.style(
f"Migration complete. Created {create_count} app annotation indexes. Skipped {skipped_count} apps.",
f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.",
fg="green",
)
)
@@ -254,7 +254,7 @@ def migrate_knowledge_vector_database():
"""
Migrate vector database datas to target vector database .
"""
click.echo(click.style("Starting vector database migration.", fg="green"))
click.echo(click.style("Start migrate vector db.", fg="green"))
create_count = 0
skipped_count = 0
total_count = 0
@@ -278,7 +278,7 @@ def migrate_knowledge_vector_database():
f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
)
try:
click.echo("Creating dataset vector database index: {}".format(dataset.id))
click.echo("Create dataset vdb index: {}".format(dataset.id))
if dataset.index_struct_dict:
if dataset.index_struct_dict["type"] == vector_type:
skipped_count = skipped_count + 1
@ -299,7 +299,7 @@ def migrate_knowledge_vector_database():
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError("Dataset Collection Binding not found")
raise ValueError("Dataset Collection Bindings is not exist!")
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
@@ -351,12 +351,14 @@ def migrate_knowledge_vector_database():
raise ValueError(f"Vector store {vector_type} is not supported.")
vector = Vector(dataset)
click.echo(f"Migrating dataset {dataset.id}.")
click.echo(f"Start to migrate dataset {dataset.id}.")
try:
vector.delete()
click.echo(
click.style(f"Deleted vector index {collection_name} for dataset {dataset.id}.", fg="green")
click.style(
f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green"
)
)
except Exception as e:
click.echo(
@@ -408,13 +410,15 @@ def migrate_knowledge_vector_database():
try:
click.echo(
click.style(
f"Creating vector index with {len(documents)} documents of {segments_count}"
f"Start to created vector index with {len(documents)} documents of {segments_count}"
f" segments for dataset {dataset.id}.",
fg="green",
)
)
vector.create(documents)
click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green"))
click.echo(
click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green")
)
except Exception as e:
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
raise e
@@ -425,13 +429,13 @@ def migrate_knowledge_vector_database():
except Exception as e:
db.session.rollback()
click.echo(
click.style("Error creating dataset index: {} {}".format(e.__class__.__name__, str(e)), fg="red")
click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red")
)
continue
click.echo(
click.style(
f"Migration complete. Created {create_count} dataset indexes. Skipped {skipped_count} datasets.", fg="green"
f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green"
)
)
@@ -441,7 +445,7 @@ def convert_to_agent_apps():
"""
Convert Agent Assistant to Agent App.
"""
click.echo(click.style("Starting convert to agent apps.", fg="green"))
click.echo(click.style("Start convert to agent apps.", fg="green"))
proceeded_app_ids = []
@@ -449,14 +453,14 @@ def convert_to_agent_apps():
# fetch first 1000 apps
sql_query = """SELECT a.id AS id FROM apps a
INNER JOIN app_model_configs am ON a.app_model_config_id=am.id
WHERE a.mode = 'chat'
AND am.agent_mode is not null
WHERE a.mode = 'chat'
AND am.agent_mode is not null
AND (
am.agent_mode like '%"strategy": "function_call"%'
am.agent_mode like '%"strategy": "function_call"%'
OR am.agent_mode like '%"strategy": "react"%'
)
)
AND (
am.agent_mode like '{"enabled": true%'
am.agent_mode like '{"enabled": true%'
OR am.agent_mode like '{"max_iteration": %'
) ORDER BY a.created_at DESC LIMIT 1000
"""
@@ -492,23 +496,23 @@ def convert_to_agent_apps():
except Exception as e:
click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
click.echo(click.style("Conversion complete. Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
@click.command("add-qdrant-doc-id-index", help="Add Qdrant doc_id index.")
@click.option("--field", default="metadata.doc_id", prompt=False, help="Index field , default is metadata.doc_id.")
@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.")
@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.")
def add_qdrant_doc_id_index(field: str):
click.echo(click.style("Starting Qdrant doc_id index creation.", fg="green"))
click.echo(click.style("Start add qdrant doc_id index.", fg="green"))
vector_type = dify_config.VECTOR_STORE
if vector_type != "qdrant":
click.echo(click.style("This command only supports Qdrant vector store.", fg="red"))
click.echo(click.style("Sorry, only support qdrant vector store.", fg="red"))
return
create_count = 0
try:
bindings = db.session.query(DatasetCollectionBinding).all()
if not bindings:
click.echo(click.style("No dataset collection bindings found.", fg="red"))
click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red"))
return
import qdrant_client
from qdrant_client.http.exceptions import UnexpectedResponse
@@ -518,7 +522,7 @@ def add_qdrant_doc_id_index(field: str):
for binding in bindings:
if dify_config.QDRANT_URL is None:
raise ValueError("Qdrant URL is required.")
raise ValueError("Qdrant url is required.")
qdrant_config = QdrantConfig(
endpoint=dify_config.QDRANT_URL,
api_key=dify_config.QDRANT_API_KEY,
@@ -535,39 +539,41 @@ def add_qdrant_doc_id_index(field: str):
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
click.echo(click.style(f"Collection not found: {binding.collection_name}.", fg="red"))
click.echo(
click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red")
)
continue
# Some other error occurred, so re-raise the exception
else:
click.echo(
click.style(
f"Failed to create Qdrant index for collection: {binding.collection_name}.", fg="red"
f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red"
)
)
except Exception as e:
click.echo(click.style("Failed to create Qdrant client.", fg="red"))
click.echo(click.style("Failed to create qdrant client.", fg="red"))
click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green"))
click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green"))
@click.command("create-tenant", help="Create account and tenant.")
@click.option("--email", prompt=True, help="Tenant account email.")
@click.option("--name", prompt=True, help="Workspace name.")
@click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
@click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
"""
Create tenant account
"""
if not email:
click.echo(click.style("Email is required.", fg="red"))
click.echo(click.style("Sorry, email is required.", fg="red"))
return
# Create account
email = email.strip()
if "@" not in email:
click.echo(click.style("Invalid email address.", fg="red"))
click.echo(click.style("Sorry, invalid email address.", fg="red"))
return
account_name = email.split("@")[0]
@@ -587,19 +593,19 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
click.echo(
click.style(
"Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
"Congratulations! Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
fg="green",
)
)
@click.command("upgrade-db", help="Upgrade the database")
@click.command("upgrade-db", help="upgrade the database")
def upgrade_db():
click.echo("Preparing database migration...")
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
if lock.acquire(blocking=False):
try:
click.echo(click.style("Starting database migration.", fg="green"))
click.echo(click.style("Start database migration.", fg="green"))
# run db migration
import flask_migrate
@@ -609,7 +615,7 @@ def upgrade_db():
click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e:
logging.exception(f"Database migration failed: {e}")
logging.exception(f"Database migration failed, error: {e}")
finally:
lock.release()
else:
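The locking pattern in `upgrade_db` above is the interesting part: a non-blocking Redis lock keeps concurrent instances from running the migration twice. A minimal self-contained sketch (using a plain `redis.Redis` client as a stand-in for `extensions.ext_redis.redis_client`):

```python
import flask_migrate
from redis import Redis

redis_client = Redis()  # stand-in for extensions.ext_redis.redis_client

# Acquire the lock without blocking; losers simply skip the migration.
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
if lock.acquire(blocking=False):
    try:
        flask_migrate.upgrade()  # requires an active Flask app context
    finally:
        lock.release()
else:
    print("Database migration skipped: another instance holds the lock.")
```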
@@ -621,7 +627,7 @@ def fix_app_site_missing():
"""
Fix app related site missing issue.
"""
click.echo(click.style("Starting fix for missing app-related sites.", fg="green"))
click.echo(click.style("Start fix app related site missing issue.", fg="green"))
failed_app_ids = []
while True:
@@ -644,22 +650,22 @@ where sites.id is null limit 1000"""
if tenant:
accounts = tenant.get_accounts()
if not accounts:
print("Fix failed for app {}".format(app.id))
print("Fix app {} failed.".format(app.id))
continue
account = accounts[0]
print("Fixing missing site for app {}".format(app.id))
print("Fix app {} related site missing issue.".format(app.id))
app_was_created.send(app, account=account)
except Exception as e:
failed_app_ids.append(app_id)
click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red"))
logging.exception(f"Fix app related site missing issue failed, error: {e}")
continue
if not processed_count:
break
click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green"))
def register_commands(app):

View File

@@ -4,30 +4,30 @@ from pydantic_settings import BaseSettings
class DeploymentConfig(BaseSettings):
"""
Configuration settings for application deployment
Deployment configs
"""
APPLICATION_NAME: str = Field(
description="Name of the application, used for identification and logging purposes",
description="application name",
default="langgenius/dify",
)
DEBUG: bool = Field(
description="Enable debug mode for additional logging and development features",
description="whether to enable debug mode.",
default=False,
)
TESTING: bool = Field(
description="Enable testing mode for running automated tests",
description="",
default=False,
)
EDITION: str = Field(
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
description="deployment edition",
default="SELF_HOSTED",
)
DEPLOY_ENV: str = Field(
description="Deployment environment (e.g., 'PRODUCTION', 'DEVELOPMENT'), default to PRODUCTION",
description="deployment environment, default to PRODUCTION.",
default="PRODUCTION",
)
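All of the config diffs that follow share one mechanism: fields on a `pydantic_settings.BaseSettings` subclass double as environment variables, with `Field()` carrying the default and description. A minimal sketch of how such a class resolves values:

```python
import os

from pydantic import Field
from pydantic_settings import BaseSettings

class DeploymentConfig(BaseSettings):
    DEBUG: bool = Field(
        description="whether to enable debug mode.",
        default=False,
    )

os.environ["DEBUG"] = "true"
print(DeploymentConfig().DEBUG)  # True -- read from the environment, not the default
```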

View File

@@ -4,17 +4,17 @@ from pydantic_settings import BaseSettings
class EnterpriseFeatureConfig(BaseSettings):
"""
Configuration for enterprise-level features.
Enterprise feature configs.
**Before using, please contact business@dify.ai by email to inquire about licensing matters.**
"""
ENTERPRISE_ENABLED: bool = Field(
description="Enable or disable enterprise-level features."
description="whether to enable enterprise features."
"Before using, please contact business@dify.ai by email to inquire about licensing matters.",
default=False,
)
CAN_REPLACE_LOGO: bool = Field(
description="Allow customization of the enterprise logo.",
description="whether to allow replacing enterprise logo.",
default=False,
)

View File

@@ -6,31 +6,30 @@ from pydantic_settings import BaseSettings
class NotionConfig(BaseSettings):
"""
Configuration settings for Notion integration
Notion integration configs
"""
NOTION_CLIENT_ID: Optional[str] = Field(
description="Client ID for Notion API authentication. Required for OAuth 2.0 flow.",
description="Notion client ID",
default=None,
)
NOTION_CLIENT_SECRET: Optional[str] = Field(
description="Client secret for Notion API authentication. Required for OAuth 2.0 flow.",
description="Notion client secret key",
default=None,
)
NOTION_INTEGRATION_TYPE: Optional[str] = Field(
description="Type of Notion integration."
" Set to 'internal' for internal integrations, or None for public integrations.",
description="Notion integration type, default to None, available values: internal.",
default=None,
)
NOTION_INTERNAL_SECRET: Optional[str] = Field(
description="Secret key for internal Notion integrations. Required when NOTION_INTEGRATION_TYPE is 'internal'.",
description="Notion internal secret key",
default=None,
)
NOTION_INTEGRATION_TOKEN: Optional[str] = Field(
description="Integration token for Notion API access. Used for direct API calls without OAuth flow.",
description="Notion integration token",
default=None,
)

View File

@@ -6,23 +6,20 @@ from pydantic_settings import BaseSettings
class SentryConfig(BaseSettings):
"""
Configuration settings for Sentry error tracking and performance monitoring
Sentry configs
"""
SENTRY_DSN: Optional[str] = Field(
description="Sentry Data Source Name (DSN)."
" This is the unique identifier of your Sentry project, used to send events to the correct project.",
description="Sentry DSN",
default=None,
)
SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field(
description="Sample rate for Sentry performance monitoring traces."
" Value between 0.0 and 1.0, where 1.0 means 100% of traces are sent to Sentry.",
description="Sentry trace sample rate",
default=1.0,
)
SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field(
description="Sample rate for Sentry profiling."
" Value between 0.0 and 1.0, where 1.0 means 100% of profiles are sent to Sentry.",
description="Sentry profiles sample rate",
default=1.0,
)

View File

@@ -1,4 +1,4 @@
from typing import Annotated, Literal, Optional
from typing import Annotated, Optional
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
@@ -8,143 +8,145 @@ from configs.feature.hosted_service import HostedServiceConfig
class SecurityConfig(BaseSettings):
"""
Security-related configurations for the application
Secret Key configs
"""
SECRET_KEY: str = Field(
description="Secret key for secure session cookie signing."
SECRET_KEY: Optional[str] = Field(
description="Your App secret key will be used for securely signing the session cookie"
"Make sure you are changing this key for your deployment with a strong key."
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
default="",
"You can generate a strong key using `openssl rand -base64 42`."
"Alternatively you can set it with `SECRET_KEY` environment variable.",
default=None,
)
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
description="Duration in hours for which a password reset token remains valid",
description="Expiry time in hours for reset token",
default=24,
)
class AppExecutionConfig(BaseSettings):
"""
Configuration parameters for application execution
App Execution configs
"""
APP_MAX_EXECUTION_TIME: PositiveInt = Field(
description="Maximum allowed execution time for the application in seconds",
description="execution timeout in seconds for app execution",
default=1200,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Maximum number of concurrent active requests per app (0 for unlimited)",
description="max active request per app, 0 means unlimited",
default=0,
)
class CodeExecutionSandboxConfig(BaseSettings):
"""
Configuration for the code execution sandbox environment
Code Execution Sandbox configs
"""
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
description="URL endpoint for the code execution service",
description="endpoint URL of code execution service",
default="http://sandbox:8194",
)
CODE_EXECUTION_API_KEY: str = Field(
description="API key for accessing the code execution service",
description="API key for code execution service",
default="dify-sandbox",
)
CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field(
description="Connection timeout in seconds for code execution requests",
description="connect timeout in seconds for code execution request",
default=10.0,
)
CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field(
description="Read timeout in seconds for code execution requests",
description="read timeout in seconds for code execution request",
default=60.0,
)
CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field(
description="Write timeout in seconds for code execution request",
description="write timeout in seconds for code execution request",
default=10.0,
)
CODE_MAX_NUMBER: PositiveInt = Field(
description="Maximum allowed numeric value in code execution",
description="max depth for code execution",
default=9223372036854775807,
)
CODE_MIN_NUMBER: NegativeInt = Field(
description="Minimum allowed numeric value in code execution",
description="",
default=-9223372036854775807,
)
CODE_MAX_DEPTH: PositiveInt = Field(
description="Maximum allowed depth for nested structures in code execution",
description="max depth for code execution",
default=5,
)
CODE_MAX_PRECISION: PositiveInt = Field(
description="mMaximum number of decimal places for floating-point numbers in code execution",
description="max precision digits for float type in code execution",
default=20,
)
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="Maximum allowed length for strings in code execution",
description="max string length for code execution",
default=80000,
)
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
description="Maximum allowed length for string arrays in code execution",
description="",
default=30,
)
CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field(
description="Maximum allowed length for object arrays in code execution",
description="",
default=30,
)
CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field(
description="Maximum allowed length for numeric arrays in code execution",
description="",
default=1000,
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
Module URL configs
"""
CONSOLE_API_URL: str = Field(
description="Base URL for the console API,"
"used for login authentication callback or notion integration callbacks",
description="The backend URL prefix of the console API."
"used to concatenate the login authorization callback or notion integration callback.",
default="",
)
CONSOLE_WEB_URL: str = Field(
description="Base URL for the console web interface," "used for frontend references and CORS configuration",
description="The front-end URL prefix of the console web."
"used to concatenate some front-end addresses and for CORS configuration use.",
default="",
)
SERVICE_API_URL: str = Field(
description="Base URL for the service API, displayed to users for API access",
description="Service API Url prefix. used to display Service API Base Url to the front-end.",
default="",
)
APP_WEB_URL: str = Field(
description="Base URL for the web application, used for frontend references",
description="WebApp Url prefix. used to display WebAPP API Base Url to the front-end.",
default="",
)
class FileAccessConfig(BaseSettings):
"""
Configuration for file access and handling
File Access configs
"""
FILES_URL: str = Field(
description="Base URL for file preview or download,"
" used for frontend display and multi-model inputs"
description="File preview or download Url prefix."
" used to display File preview or download Url to the front-end or as Multi-model inputs;"
"Url is signed and has expiration time.",
validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"),
alias_priority=1,
@@ -152,59 +154,49 @@ class FileAccessConfig(BaseSettings):
)
FILES_ACCESS_TIMEOUT: int = Field(
description="Expiration time in seconds for file access URLs",
description="timeout in seconds for file accessing",
default=300,
)
class FileUploadConfig(BaseSettings):
"""
Configuration for file upload limitations
File Uploading configs
"""
UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed file size for uploads in megabytes",
description="size limit in Megabytes for uploading files",
default=15,
)
UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field(
description="Maximum number of files allowed in a single upload batch",
description="batch size limit for uploading files",
default=5,
)
UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed image file size for uploads in megabytes",
description="image file size limit in Megabytes for uploading files",
default=10,
)
UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="video file size limit in Megabytes for uploading files",
default=100,
)
UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="audio file size limit in Megabytes for uploading files",
default=50,
)
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
description="Maximum number of files allowed in a batch upload operation",
description="", # todo: to be clarified
default=20,
)
class HttpConfig(BaseSettings):
"""
HTTP-related configurations for the application
HTTP configs
"""
API_COMPRESSION_ENABLED: bool = Field(
description="Enable or disable gzip compression for HTTP responses",
description="whether to enable HTTP response compression of gzip",
default=False,
)
inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
description="Comma-separated list of allowed origins for CORS in the console",
description="",
validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"),
default="",
)
@@ -226,361 +218,359 @@ class HttpConfig(BaseSettings):
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="Maximum connection timeout in seconds for HTTP requests")
PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
] = 10
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
PositiveInt, Field(ge=60, description="Maximum read timeout in seconds for HTTP requests")
PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
] = 60
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="Maximum write timeout in seconds for HTTP requests")
PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
] = 20
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
description="Maximum allowed size in bytes for binary data in HTTP requests",
description="",
default=10 * 1024 * 1024,
)
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field(
description="Maximum allowed size in bytes for text data in HTTP requests",
description="",
default=1 * 1024 * 1024,
)
SSRF_PROXY_HTTP_URL: Optional[str] = Field(
description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)",
description="HTTP URL for SSRF proxy",
default=None,
)
SSRF_PROXY_HTTPS_URL: Optional[str] = Field(
description="Proxy URL for HTTPS requests to prevent Server-Side Request Forgery (SSRF)",
description="HTTPS URL for SSRF proxy",
default=None,
)
class InnerAPIConfig(BaseSettings):
"""
Configuration for internal API functionality
Inner API configs
"""
INNER_API: bool = Field(
description="Enable or disable the internal API",
description="whether to enable the inner API",
default=False,
)
INNER_API_KEY: Optional[str] = Field(
description="API key for accessing the internal API",
description="The inner API key is used to authenticate the inner API",
default=None,
)
class LoggingConfig(BaseSettings):
"""
Configuration for application logging
Logging configs
"""
LOG_LEVEL: str = Field(
description="Logging level, default to INFO. Set to ERROR for production environments.",
description="Log output level, default to INFO. It is recommended to set it to ERROR for production.",
default="INFO",
)
LOG_FILE: Optional[str] = Field(
description="File path for log output.",
description="logging output file path",
default=None,
)
LOG_FORMAT: str = Field(
description="Format string for log messages",
description="log format",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
)
LOG_DATEFORMAT: Optional[str] = Field(
description="Date format string for log timestamps",
description="log date format",
default=None,
)
LOG_TZ: Optional[str] = Field(
description="Timezone for log timestamps (e.g., 'America/New_York')",
description="specify log timezone, eg: America/New_York",
default=None,
)
class ModelLoadBalanceConfig(BaseSettings):
"""
Configuration for model load balancing
Model load balance configs
"""
MODEL_LB_ENABLED: bool = Field(
description="Enable or disable load balancing for models",
description="whether to enable model load balancing",
default=False,
)
class BillingConfig(BaseSettings):
"""
Configuration for platform billing features
Platform Billing Configurations
"""
BILLING_ENABLED: bool = Field(
description="Enable or disable billing functionality",
description="whether to enable billing",
default=False,
)
class UpdateConfig(BaseSettings):
"""
Configuration for application update checks
Update configs
"""
CHECK_UPDATE_URL: str = Field(
description="URL to check for application updates",
description="url for checking updates",
default="https://updates.dify.ai",
)
class WorkflowConfig(BaseSettings):
"""
Configuration for workflow execution
Workflow feature configs
"""
WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field(
description="Maximum number of steps allowed in a single workflow execution",
description="max execution steps in single workflow execution",
default=500,
)
WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field(
description="Maximum execution time in seconds for a single workflow",
description="max execution time in seconds in single workflow execution",
default=1200,
)
WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field(
description="Maximum allowed depth for nested workflow calls",
description="max depth of calling in single workflow execution",
default=5,
)
MAX_VARIABLE_SIZE: PositiveInt = Field(
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
default=200 * 1024,
description="The maximum size in bytes of a variable. default to 5KB.",
default=5 * 1024,
)
class OAuthConfig(BaseSettings):
"""
Configuration for OAuth authentication
oauth configs
"""
OAUTH_REDIRECT_PATH: str = Field(
description="Redirect path for OAuth authentication callbacks",
description="redirect path for OAuth",
default="/console/api/oauth/authorize",
)
GITHUB_CLIENT_ID: Optional[str] = Field(
description="GitHub OAuth client secret",
description="GitHub client id for OAuth",
default=None,
)
GITHUB_CLIENT_SECRET: Optional[str] = Field(
description="GitHub OAuth client secret",
description="GitHub client secret key for OAuth",
default=None,
)
GOOGLE_CLIENT_ID: Optional[str] = Field(
description="Google OAuth client ID",
description="Google client id for OAuth",
default=None,
)
GOOGLE_CLIENT_SECRET: Optional[str] = Field(
description="Google OAuth client secret",
description="Google client secret key for OAuth",
default=None,
)
class ModerationConfig(BaseSettings):
"""
Configuration for content moderation
Moderation in app configs.
"""
MODERATION_BUFFER_SIZE: PositiveInt = Field(
description="Size of the buffer for content moderation processing",
description="buffer size for moderation",
default=300,
)
class ToolConfig(BaseSettings):
"""
Configuration for tool management
Tool configs
"""
TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field(
description="Maximum age in seconds for caching tool icons",
description="max age in seconds for tool icon caching",
default=3600,
)
class MailConfig(BaseSettings):
"""
Configuration for email services
Mail Configurations
"""
MAIL_TYPE: Optional[str] = Field(
description="Email service provider type ('smtp' or 'resend'), default to None.",
description="Mail provider type name, default to None, available values are `smtp` and `resend`.",
default=None,
)
MAIL_DEFAULT_SEND_FROM: Optional[str] = Field(
description="Default email address to use as the sender",
description="default email address for sending from ",
default=None,
)
RESEND_API_KEY: Optional[str] = Field(
description="API key for Resend email service",
description="API key for Resend",
default=None,
)
RESEND_API_URL: Optional[str] = Field(
description="API URL for Resend email service",
description="API URL for Resend",
default=None,
)
SMTP_SERVER: Optional[str] = Field(
description="SMTP server hostname",
description="smtp server host",
default=None,
)
SMTP_PORT: Optional[int] = Field(
description="SMTP server port number",
description="smtp server port",
default=465,
)
SMTP_USERNAME: Optional[str] = Field(
description="Username for SMTP authentication",
description="smtp server username",
default=None,
)
SMTP_PASSWORD: Optional[str] = Field(
description="Password for SMTP authentication",
description="smtp server password",
default=None,
)
SMTP_USE_TLS: bool = Field(
description="Enable TLS encryption for SMTP connections",
description="whether to use TLS connection to smtp server",
default=False,
)
SMTP_OPPORTUNISTIC_TLS: bool = Field(
description="Enable opportunistic TLS for SMTP connections",
description="whether to use opportunistic TLS connection to smtp server",
default=False,
)
class RagEtlConfig(BaseSettings):
"""
Configuration for RAG ETL processes
RAG ETL Configurations.
"""
# TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config
ETL_TYPE: str = Field(
description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ",
default="dify",
)
KEYWORD_DATA_SOURCE_TYPE: str = Field(
description="Data source type for keyword extraction"
" ('database' or other supported types), default to 'database'",
description="source type for keyword data, default to `database`, available values are `database` .",
default="database",
)
UNSTRUCTURED_API_URL: Optional[str] = Field(
description="API URL for Unstructured.io service",
description="API URL for Unstructured",
default=None,
)
UNSTRUCTURED_API_KEY: Optional[str] = Field(
description="API key for Unstructured.io service",
description="API key for Unstructured",
default=None,
)
class DataSetConfig(BaseSettings):
"""
Configuration for dataset management
Dataset configs
"""
CLEAN_DAY_SETTING: PositiveInt = Field(
description="Interval in days for dataset cleanup operations",
description="interval in days for cleaning up dataset",
default=30,
)
DATASET_OPERATOR_ENABLED: bool = Field(
description="Enable or disable dataset operator functionality",
description="whether to enable dataset operator",
default=False,
)
class WorkspaceConfig(BaseSettings):
"""
Configuration for workspace management
Workspace configs
"""
INVITE_EXPIRY_HOURS: PositiveInt = Field(
description="Expiration time in hours for workspace invitation links",
description="workspaces invitation expiration in hours",
default=72,
)
class IndexingConfig(BaseSettings):
"""
Configuration for indexing operations
Indexing configs.
"""
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field(
description="Maximum token length for text segmentation during indexing",
description="max segmentation token length for indexing",
default=1000,
)
class ImageFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
description="multi model send image format, support base64, url, default is base64",
default="base64",
)
class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field(
description="Interval in days for Celery Beat scheduler execution, default to 1 day",
description="the time of the celery scheduler, default to 1 day",
default=1,
)
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description="Comma-separated list of pinned model providers",
description="The heads of model providers",
default="",
)
POSITION_PROVIDER_INCLUDES: str = Field(
description="Comma-separated list of included model providers",
description="The included model providers",
default="",
)
POSITION_PROVIDER_EXCLUDES: str = Field(
description="Comma-separated list of excluded model providers",
description="The excluded model providers",
default="",
)
POSITION_TOOL_PINS: str = Field(
description="Comma-separated list of pinned tools",
description="The heads of tools",
default="",
)
POSITION_TOOL_INCLUDES: str = Field(
description="Comma-separated list of included tools",
description="The included tools",
default="",
)
POSITION_TOOL_EXCLUDES: str = Field(
description="Comma-separated list of excluded tools",
description="The excluded tools",
default="",
)

View File

@@ -6,31 +6,31 @@ from pydantic_settings import BaseSettings
class HostedOpenAiConfig(BaseSettings):
"""
Configuration for hosted OpenAI service
Hosted OpenAI service config
"""
HOSTED_OPENAI_API_KEY: Optional[str] = Field(
description="API key for hosted OpenAI service",
description="",
default=None,
)
HOSTED_OPENAI_API_BASE: Optional[str] = Field(
description="Base URL for hosted OpenAI API",
description="",
default=None,
)
HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field(
description="Organization ID for hosted OpenAI service",
description="",
default=None,
)
HOSTED_OPENAI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted OpenAI service",
description="",
default=False,
)
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
description="",
default="gpt-3.5-turbo,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-instruct,"
@@ -42,17 +42,17 @@ class HostedOpenAiConfig(BaseSettings):
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted OpenAI service usage",
description="",
default=200,
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted OpenAI service",
description="",
default=False,
)
HOSTED_OPENAI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
description="",
default="gpt-4,"
"gpt-4-turbo-preview,"
"gpt-4-turbo-2024-04-09,"
@@ -71,122 +71,124 @@ class HostedOpenAiConfig(BaseSettings):
class HostedAzureOpenAiConfig(BaseSettings):
"""
Configuration for hosted Azure OpenAI service
Hosted OpenAI service config
"""
HOSTED_AZURE_OPENAI_ENABLED: bool = Field(
description="Enable hosted Azure OpenAI service",
description="",
default=False,
)
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
description="API key for hosted Azure OpenAI service",
description="",
default=None,
)
HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field(
description="Base URL for hosted Azure OpenAI API",
description="",
default=None,
)
HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Azure OpenAI service usage",
description="",
default=200,
)
class HostedAnthropicConfig(BaseSettings):
"""
Configuration for hosted Anthropic service
Hosted Azure OpenAI service config
"""
HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field(
description="Base URL for hosted Anthropic API",
description="",
default=None,
)
HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field(
description="API key for hosted Anthropic service",
description="",
default=None,
)
HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Anthropic service",
description="",
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Anthropic service usage",
description="",
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
description="",
default=False,
)
class HostedMinmaxConfig(BaseSettings):
"""
Configuration for hosted Minmax service
Hosted Minmax service config
"""
HOSTED_MINIMAX_ENABLED: bool = Field(
description="Enable hosted Minmax service",
description="",
default=False,
)
class HostedSparkConfig(BaseSettings):
"""
Configuration for hosted Spark service
Hosted Spark service config
"""
HOSTED_SPARK_ENABLED: bool = Field(
description="Enable hosted Spark service",
description="",
default=False,
)
class HostedZhipuAIConfig(BaseSettings):
"""
Configuration for hosted ZhipuAI service
Hosted Minmax service config
"""
HOSTED_ZHIPUAI_ENABLED: bool = Field(
description="Enable hosted ZhipuAI service",
description="",
default=False,
)
class HostedModerationConfig(BaseSettings):
"""
Configuration for hosted Moderation service
Hosted Moderation service config
"""
HOSTED_MODERATION_ENABLED: bool = Field(
description="Enable hosted Moderation service",
description="",
default=False,
)
HOSTED_MODERATION_PROVIDERS: str = Field(
description="Comma-separated list of moderation providers",
description="",
default="",
)
class HostedFetchAppTemplateConfig(BaseSettings):
"""
Configuration for fetching app templates
Hosted Moderation service config
"""
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
description="Mode for fetching app templates: remote, db, or builtin" " default to remote,",
description="the mode for fetching app templates,"
" default to remote,"
" available values: remote, db, builtin",
default="remote",
)
HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field(
description="Domain for fetching remote app templates",
description="the domain for fetching remote app templates",
default="https://tmpl.dify.ai",
)

View File

@@ -31,71 +31,70 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseSettings):
STORAGE_TYPE: str = Field(
description="Type of storage to use."
" Options: 'local', 's3', 'azure-blob', 'aliyun-oss', 'google-storage'. Default is 'local'.",
description="storage type,"
" default to `local`,"
" available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.",
default="local",
)
STORAGE_LOCAL_PATH: str = Field(
description="Path for local storage when STORAGE_TYPE is set to 'local'.",
description="local storage path",
default="storage",
)
class VectorStoreConfig(BaseSettings):
VECTOR_STORE: Optional[str] = Field(
description="Type of vector store to use for efficient similarity search."
" Set to None if not using a vector store.",
description="vector store type",
default=None,
)
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(
description="Method for keyword extraction and storage."
" Default is 'jieba', a Chinese text segmentation library.",
description="keyword store type",
default="jieba",
)
class DatabaseConfig:
DB_HOST: str = Field(
description="Hostname or IP address of the database server.",
description="db host",
default="localhost",
)
DB_PORT: PositiveInt = Field(
description="Port number for database connection.",
description="db port",
default=5432,
)
DB_USERNAME: str = Field(
description="Username for database authentication.",
description="db username",
default="postgres",
)
DB_PASSWORD: str = Field(
description="Password for database authentication.",
description="db password",
default="",
)
DB_DATABASE: str = Field(
description="Name of the database to connect to.",
description="db database",
default="dify",
)
DB_CHARSET: str = Field(
description="Character set for database connection.",
description="db charset",
default="",
)
DB_EXTRAS: str = Field(
description="Additional database connection parameters. Example: 'keepalives_idle=60&keepalives=1'",
description="db extras options. Example: keepalives_idle=60&keepalives=1",
default="",
)
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
description="Database URI scheme for SQLAlchemy connection.",
description="db uri scheme",
default="postgresql",
)
@@ -113,27 +112,27 @@ class DatabaseConfig:
)
SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
description="Maximum number of database connections in the pool.",
description="pool size of SqlAlchemy",
default=30,
)
SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
description="Maximum number of connections that can be created beyond the pool_size.",
description="max overflows for SqlAlchemy",
default=10,
)
SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
description="Number of seconds after which a connection is automatically recycled.",
description="SqlAlchemy pool recycle",
default=3600,
)
SQLALCHEMY_POOL_PRE_PING: bool = Field(
description="If True, enables connection pool pre-ping feature to check connections.",
description="whether to enable pool pre-ping in SqlAlchemy",
default=False,
)
SQLALCHEMY_ECHO: bool | str = Field(
description="If True, SQLAlchemy will log all SQL statements.",
description="whether to enable SqlAlchemy echo",
default=False,
)
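As a hedged illustration of where the pool settings above typically end up (the URI assembly from the `DB_*` values is an assumption, with illustrative credentials):

```python
from sqlalchemy import create_engine

# Connection URI assembled from the DB_* settings above (values illustrative).
engine = create_engine(
    "postgresql://postgres:password@localhost:5432/dify",
    pool_size=30,         # SQLALCHEMY_POOL_SIZE
    max_overflow=10,      # SQLALCHEMY_MAX_OVERFLOW
    pool_recycle=3600,    # SQLALCHEMY_POOL_RECYCLE
    pool_pre_ping=False,  # SQLALCHEMY_POOL_PRE_PING
    echo=False,           # SQLALCHEMY_ECHO
)
```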
@@ -151,27 +150,27 @@ class DatabaseConfig:
class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field(
description="Backend for Celery task results. Options: 'database', 'redis'.",
description="Celery backend, available values are `database`, `redis`",
default="database",
)
CELERY_BROKER_URL: Optional[str] = Field(
description="URL of the message broker for Celery tasks.",
description="CELERY_BROKER_URL",
default=None,
)
CELERY_USE_SENTINEL: Optional[bool] = Field(
description="Whether to use Redis Sentinel for high availability.",
description="Whether to use Redis Sentinel mode",
default=False,
)
CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
description="Name of the Redis Sentinel master.",
description="Redis Sentinel master name",
default=None,
)
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Timeout for Redis Sentinel socket operations in seconds.",
description="Redis Sentinel socket timeout",
default=0.1,
)

View File

@@ -6,65 +6,65 @@ from pydantic_settings import BaseSettings
class RedisConfig(BaseSettings):
"""
Configuration settings for Redis connection
Redis configs
"""
REDIS_HOST: str = Field(
description="Hostname or IP address of the Redis server",
description="Redis host",
default="localhost",
)
REDIS_PORT: PositiveInt = Field(
description="Port number on which the Redis server is listening",
description="Redis port",
default=6379,
)
REDIS_USERNAME: Optional[str] = Field(
description="Username for Redis authentication (if required)",
description="Redis username",
default=None,
)
REDIS_PASSWORD: Optional[str] = Field(
description="Password for Redis authentication (if required)",
description="Redis password",
default=None,
)
REDIS_DB: NonNegativeInt = Field(
description="Redis database number to use (0-15)",
description="Redis database id, default to 0",
default=0,
)
REDIS_USE_SSL: bool = Field(
description="Enable SSL/TLS for the Redis connection",
description="whether to use SSL for Redis connection",
default=False,
)
REDIS_USE_SENTINEL: Optional[bool] = Field(
description="Enable Redis Sentinel mode for high availability",
description="Whether to use Redis Sentinel mode",
default=False,
)
REDIS_SENTINELS: Optional[str] = Field(
description="Comma-separated list of Redis Sentinel nodes (host:port)",
description="Redis Sentinel nodes",
default=None,
)
REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
description="Name of the Redis Sentinel service to monitor",
description="Redis Sentinel service name",
default=None,
)
REDIS_SENTINEL_USERNAME: Optional[str] = Field(
description="Username for Redis Sentinel authentication (if required)",
description="Redis Sentinel username",
default=None,
)
REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
description="Password for Redis Sentinel authentication (if required)",
description="Redis Sentinel password",
default=None,
)
REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Socket timeout in seconds for Redis Sentinel connections",
description="Redis Sentinel socket timeout",
default=0.1,
)
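
A minimal sketch of a client built from these fields, using the standard redis-py API; the branch on REDIS_USE_SENTINEL mirrors what the settings imply, though the project's own factory may differ.

import redis
from redis.sentinel import Sentinel

def redis_client_from(c):
    # c is a RedisConfig instance
    if c.REDIS_USE_SENTINEL:
        nodes = [(h, int(p)) for h, p in
                 (n.split(":") for n in (c.REDIS_SENTINELS or "").split(","))]
        sentinel = Sentinel(
            nodes,
            socket_timeout=c.REDIS_SENTINEL_SOCKET_TIMEOUT,
            sentinel_kwargs={
                "username": c.REDIS_SENTINEL_USERNAME,
                "password": c.REDIS_SENTINEL_PASSWORD,
            },
        )
        return sentinel.master_for(c.REDIS_SENTINEL_SERVICE_NAME,
                                   db=c.REDIS_DB, password=c.REDIS_PASSWORD)
    return redis.Redis(host=c.REDIS_HOST, port=c.REDIS_PORT,
                       username=c.REDIS_USERNAME, password=c.REDIS_PASSWORD,
                       db=c.REDIS_DB, ssl=c.REDIS_USE_SSL)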

View File

@ -6,40 +6,40 @@ from pydantic_settings import BaseSettings
class AliyunOSSStorageConfig(BaseSettings):
"""
Configuration settings for Aliyun Object Storage Service (OSS)
Aliyun storage configs
"""
ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Aliyun OSS bucket to store and retrieve objects",
description="Aliyun OSS bucket name",
default=None,
)
ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field(
description="Access key ID for authenticating with Aliyun OSS",
description="Aliyun OSS access key",
default=None,
)
ALIYUN_OSS_SECRET_KEY: Optional[str] = Field(
description="Secret access key for authenticating with Aliyun OSS",
description="Aliyun OSS secret key",
default=None,
)
ALIYUN_OSS_ENDPOINT: Optional[str] = Field(
description="URL of the Aliyun OSS endpoint for your chosen region",
description="Aliyun OSS endpoint URL",
default=None,
)
ALIYUN_OSS_REGION: Optional[str] = Field(
description="Aliyun OSS region where your bucket is located (e.g., 'oss-cn-hangzhou')",
description="Aliyun OSS region",
default=None,
)
ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field(
description="Version of the authentication protocol to use with Aliyun OSS (e.g., 'v4')",
description="Aliyun OSS authentication version",
default=None,
)
ALIYUN_OSS_PATH: Optional[str] = Field(
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
description="Aliyun OSS path",
default=None,
)

View File

@ -6,40 +6,40 @@ from pydantic_settings import BaseSettings
class S3StorageConfig(BaseSettings):
"""
Configuration settings for S3-compatible object storage
S3 storage configs
"""
S3_ENDPOINT: Optional[str] = Field(
description="URL of the S3-compatible storage endpoint (e.g., 'https://s3.amazonaws.com')",
description="S3 storage endpoint",
default=None,
)
S3_REGION: Optional[str] = Field(
description="Region where the S3 bucket is located (e.g., 'us-east-1')",
description="S3 storage region",
default=None,
)
S3_BUCKET_NAME: Optional[str] = Field(
description="Name of the S3 bucket to store and retrieve objects",
description="S3 storage bucket name",
default=None,
)
S3_ACCESS_KEY: Optional[str] = Field(
description="Access key ID for authenticating with the S3 service",
description="S3 storage access key",
default=None,
)
S3_SECRET_KEY: Optional[str] = Field(
description="Secret access key for authenticating with the S3 service",
description="S3 storage secret key",
default=None,
)
S3_ADDRESS_STYLE: str = Field(
description="S3 addressing style: 'auto', 'path', or 'virtual'",
description="S3 storage address style",
default="auto",
)
S3_USE_AWS_MANAGED_IAM: bool = Field(
description="Use AWS managed IAM roles for authentication instead of access/secret keys",
description="whether to use aws managed IAM for S3",
default=False,
)
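
For orientation, a minimal boto3 sketch showing how these fields map to an S3 client; the managed-IAM branch relies on the ambient role instead of static keys.

import boto3
from botocore.client import Config

def s3_client_from(c):
    # c is an S3StorageConfig instance
    if c.S3_USE_AWS_MANAGED_IAM:
        # credentials are resolved from the instance/task role
        return boto3.client("s3", region_name=c.S3_REGION)
    return boto3.client(
        "s3",
        endpoint_url=c.S3_ENDPOINT,
        region_name=c.S3_REGION,
        aws_access_key_id=c.S3_ACCESS_KEY,
        aws_secret_access_key=c.S3_SECRET_KEY,
        config=Config(s3={"addressing_style": c.S3_ADDRESS_STYLE}),
    )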

View File

@ -6,25 +6,25 @@ from pydantic_settings import BaseSettings
class AzureBlobStorageConfig(BaseSettings):
"""
Configuration settings for Azure Blob Storage
Azure Blob storage configs
"""
AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field(
description="Name of the Azure Storage account (e.g., 'mystorageaccount')",
description="Azure Blob account name",
default=None,
)
AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field(
description="Access key for authenticating with the Azure Storage account",
description="Azure Blob account key",
default=None,
)
AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field(
description="Name of the Azure Blob container to store and retrieve objects",
description="Azure Blob container name",
default=None,
)
AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field(
description="URL of the Azure Blob storage endpoint (e.g., 'https://mystorageaccount.blob.core.windows.net')",
description="Azure Blob account URL",
default=None,
)

View File

@ -6,15 +6,15 @@ from pydantic_settings import BaseSettings
class GoogleCloudStorageConfig(BaseSettings):
"""
Configuration settings for Google Cloud Storage
Google Cloud storage configs
"""
GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field(
description="Name of the Google Cloud Storage bucket to store and retrieve objects (e.g., 'my-gcs-bucket')",
description="Google Cloud storage bucket name",
default=None,
)
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field(
description="Base64-encoded JSON key file for Google Cloud service account authentication",
description="Google Cloud storage service account json base64",
default=None,
)

View File

@ -5,25 +5,25 @@ from pydantic import BaseModel, Field
class HuaweiCloudOBSStorageConfig(BaseModel):
"""
Configuration settings for Huawei Cloud Object Storage Service (OBS)
Huawei Cloud OBS storage configs
"""
HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Huawei Cloud OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')",
description="Huawei Cloud OBS bucket name",
default=None,
)
HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
description="Access Key ID for authenticating with Huawei Cloud OBS",
description="Huawei Cloud OBS Access key",
default=None,
)
HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
description="Secret Access Key for authenticating with Huawei Cloud OBS",
description="Huawei Cloud OBS Secret key",
default=None,
)
HUAWEI_OBS_SERVER: Optional[str] = Field(
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
description="Huawei Cloud OBS server URL",
default=None,
)

View File

@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class OCIStorageConfig(BaseSettings):
"""
Configuration settings for Oracle Cloud Infrastructure (OCI) Object Storage
OCI storage configs
"""
OCI_ENDPOINT: Optional[str] = Field(
description="URL of the OCI Object Storage endpoint (e.g., 'https://objectstorage.us-phoenix-1.oraclecloud.com')",
description="OCI storage endpoint",
default=None,
)
OCI_REGION: Optional[str] = Field(
description="OCI region where the bucket is located (e.g., 'us-phoenix-1')",
description="OCI storage region",
default=None,
)
OCI_BUCKET_NAME: Optional[str] = Field(
description="Name of the OCI Object Storage bucket to store and retrieve objects (e.g., 'my-oci-bucket')",
description="OCI storage bucket name",
default=None,
)
OCI_ACCESS_KEY: Optional[str] = Field(
description="Access key (also known as API key) for authenticating with OCI Object Storage",
description="OCI storage access key",
default=None,
)
OCI_SECRET_KEY: Optional[str] = Field(
description="Secret key associated with the access key for authenticating with OCI Object Storage",
description="OCI storage secret key",
default=None,
)

View File

@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class TencentCloudCOSStorageConfig(BaseSettings):
"""
Configuration settings for Tencent Cloud Object Storage (COS)
Tencent Cloud COS storage configs
"""
TENCENT_COS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Tencent Cloud COS bucket to store and retrieve objects",
description="Tencent Cloud COS bucket name",
default=None,
)
TENCENT_COS_REGION: Optional[str] = Field(
description="Tencent Cloud region where the COS bucket is located (e.g., 'ap-guangzhou')",
description="Tencent Cloud COS region",
default=None,
)
TENCENT_COS_SECRET_ID: Optional[str] = Field(
description="SecretId for authenticating with Tencent Cloud COS (part of API credentials)",
description="Tencent Cloud COS secret id",
default=None,
)
TENCENT_COS_SECRET_KEY: Optional[str] = Field(
description="SecretKey for authenticating with Tencent Cloud COS (part of API credentials)",
description="Tencent Cloud COS secret key",
default=None,
)
TENCENT_COS_SCHEME: Optional[str] = Field(
description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
description="Tencent Cloud COS scheme",
default=None,
)

View File

@ -5,30 +5,30 @@ from pydantic import BaseModel, Field
class VolcengineTOSStorageConfig(BaseModel):
"""
Configuration settings for Volcengine Tinder Object Storage (TOS)
Volcengine tos storage configs
"""
VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
description="Name of the Volcengine TOS bucket to store and retrieve objects (e.g., 'my-tos-bucket')",
description="Volcengine TOS Bucket Name",
default=None,
)
VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
description="Access Key ID for authenticating with Volcengine TOS",
description="Volcengine TOS Access Key",
default=None,
)
VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
description="Secret Access Key for authenticating with Volcengine TOS",
description="Volcengine TOS Secret Key",
default=None,
)
VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
description="URL of the Volcengine TOS endpoint (e.g., 'https://tos-cn-beijing.volces.com')",
description="Volcengine TOS Endpoint URL",
default=None,
)
VOLCENGINE_TOS_REGION: Optional[str] = Field(
description="Volcengine region where the TOS bucket is located (e.g., 'cn-beijing')",
description="Volcengine TOS Region",
default=None,
)

View File

@ -5,38 +5,33 @@ from pydantic import BaseModel, Field
class AnalyticdbConfig(BaseModel):
"""
Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
Configuration for connecting to AnalyticDB.
Refer to the following documentation for details on obtaining credentials:
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
"""
ANALYTICDB_KEY_ID: Optional[str] = Field(
default=None, description="The Access Key ID provided by Alibaba Cloud for API authentication."
default=None, description="The Access Key ID provided by Alibaba Cloud for authentication."
)
ANALYTICDB_KEY_SECRET: Optional[str] = Field(
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure API access."
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access."
)
ANALYTICDB_REGION_ID: Optional[str] = Field(
default=None,
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou', 'ap-southeast-1').",
default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
)
ANALYTICDB_INSTANCE_ID: Optional[str] = Field(
default=None,
description="The unique identifier of the AnalyticDB instance you want to connect to.",
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..",
)
ANALYTICDB_ACCOUNT: Optional[str] = Field(
default=None,
description="The account name used to log in to the AnalyticDB instance"
" (usually the initial account created with the instance).",
default=None, description="The account name used to log in to the AnalyticDB instance."
)
ANALYTICDB_PASSWORD: Optional[str] = Field(
default=None, description="The password associated with the AnalyticDB account for database authentication."
default=None, description="The password associated with the AnalyticDB account for authentication."
)
ANALYTICDB_NAMESPACE: Optional[str] = Field(
default=None, description="The namespace within AnalyticDB for schema isolation (if using namespace feature)."
default=None, description="The namespace within AnalyticDB for schema isolation."
)
ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field(
default=None,
description="The password for accessing the specified namespace within the AnalyticDB instance"
" (if namespace feature is enabled).",
default=None, description="The password for accessing the specified namespace within the AnalyticDB instance."
)

View File

@ -6,35 +6,35 @@ from pydantic_settings import BaseSettings
class ChromaConfig(BaseSettings):
"""
Configuration settings for Chroma vector database
Chroma configs
"""
CHROMA_HOST: Optional[str] = Field(
description="Hostname or IP address of the Chroma server (e.g., 'localhost' or '192.168.1.100')",
description="Chroma host",
default=None,
)
CHROMA_PORT: PositiveInt = Field(
description="Port number on which the Chroma server is listening (default is 8000)",
description="Chroma port",
default=8000,
)
CHROMA_TENANT: Optional[str] = Field(
description="Tenant identifier for multi-tenancy support in Chroma",
description="Chroma database",
default=None,
)
CHROMA_DATABASE: Optional[str] = Field(
description="Name of the Chroma database to connect to",
description="Chroma database",
default=None,
)
CHROMA_AUTH_PROVIDER: Optional[str] = Field(
description="Authentication provider for Chroma (e.g., 'basic', 'token', or a custom provider)",
description="Chroma authentication provider",
default=None,
)
CHROMA_AUTH_CREDENTIALS: Optional[str] = Field(
description="Authentication credentials for Chroma (format depends on the auth provider)",
description="Chroma authentication credentials",
default=None,
)

View File

@ -6,25 +6,25 @@ from pydantic_settings import BaseSettings
class ElasticsearchConfig(BaseSettings):
"""
Configuration settings for Elasticsearch
Elasticsearch configs
"""
ELASTICSEARCH_HOST: Optional[str] = Field(
description="Hostname or IP address of the Elasticsearch server (e.g., 'localhost' or '192.168.1.100')",
description="Elasticsearch host",
default="127.0.0.1",
)
ELASTICSEARCH_PORT: PositiveInt = Field(
description="Port number on which the Elasticsearch server is listening (default is 9200)",
description="Elasticsearch port",
default=9200,
)
ELASTICSEARCH_USERNAME: Optional[str] = Field(
description="Username for authenticating with Elasticsearch (default is 'elastic')",
description="Elasticsearch username",
default="elastic",
)
ELASTICSEARCH_PASSWORD: Optional[str] = Field(
description="Password for authenticating with Elasticsearch (default is 'elastic')",
description="Elasticsearch password",
default="elastic",
)
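
A minimal sketch with the elasticsearch-py client; basic_auth is the 8.x keyword (older client versions spelled it http_auth).

from elasticsearch import Elasticsearch

def es_client_from(c):
    # c is an ElasticsearchConfig instance
    return Elasticsearch(
        f"http://{c.ELASTICSEARCH_HOST}:{c.ELASTICSEARCH_PORT}",
        basic_auth=(c.ELASTICSEARCH_USERNAME, c.ELASTICSEARCH_PASSWORD),
    )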

View File

@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class MilvusConfig(BaseSettings):
"""
Configuration settings for Milvus vector database
Milvus configs
"""
MILVUS_URI: Optional[str] = Field(
description="URI for connecting to the Milvus server (e.g., 'http://localhost:19530' or 'https://milvus-instance.example.com:19530')",
description="Milvus uri",
default="http://127.0.0.1:19530",
)
MILVUS_TOKEN: Optional[str] = Field(
description="Authentication token for Milvus, if token-based authentication is enabled",
description="Milvus token",
default=None,
)
MILVUS_USER: Optional[str] = Field(
description="Username for authenticating with Milvus, if username/password authentication is enabled",
description="Milvus user",
default=None,
)
MILVUS_PASSWORD: Optional[str] = Field(
description="Password for authenticating with Milvus, if username/password authentication is enabled",
description="Milvus password",
default=None,
)
MILVUS_DATABASE: str = Field(
description="Name of the Milvus database to connect to (default is 'default')",
description="Milvus database, default to `default`",
default="default",
)
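
A minimal pymilvus sketch; the assumption that a non-empty token takes precedence over user/password mirrors common Milvus setups and is not confirmed by this diff.

from pymilvus import MilvusClient

def milvus_client_from(c):
    # c is a MilvusConfig instance
    return MilvusClient(
        uri=c.MILVUS_URI,
        token=c.MILVUS_TOKEN or "",
        user=c.MILVUS_USER or "",
        password=c.MILVUS_PASSWORD or "",
        db_name=c.MILVUS_DATABASE,
    )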

View File

@ -3,35 +3,35 @@ from pydantic import BaseModel, Field, PositiveInt
class MyScaleConfig(BaseModel):
"""
Configuration settings for MyScale vector database
MyScale configs
"""
MYSCALE_HOST: str = Field(
description="Hostname or IP address of the MyScale server (e.g., 'localhost' or 'myscale.example.com')",
description="MyScale host",
default="localhost",
)
MYSCALE_PORT: PositiveInt = Field(
description="Port number on which the MyScale server is listening (default is 8123)",
description="MyScale port",
default=8123,
)
MYSCALE_USER: str = Field(
description="Username for authenticating with MyScale (default is 'default')",
description="MyScale user",
default="default",
)
MYSCALE_PASSWORD: str = Field(
description="Password for authenticating with MyScale (default is an empty string)",
description="MyScale password",
default="",
)
MYSCALE_DATABASE: str = Field(
description="Name of the MyScale database to connect to (default is 'default')",
description="MyScale database name",
default="default",
)
MYSCALE_FTS_PARAMS: str = Field(
description="Additional parameters for MyScale Full Text Search index)",
description="MyScale fts index parameters",
default="",
)

View File

@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class OpenSearchConfig(BaseSettings):
"""
Configuration settings for OpenSearch
OpenSearch configs
"""
OPENSEARCH_HOST: Optional[str] = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
description="OpenSearch host",
default=None,
)
OPENSEARCH_PORT: PositiveInt = Field(
description="Port number on which the OpenSearch server is listening (default is 9200)",
description="OpenSearch port",
default=9200,
)
OPENSEARCH_USER: Optional[str] = Field(
description="Username for authenticating with OpenSearch",
description="OpenSearch user",
default=None,
)
OPENSEARCH_PASSWORD: Optional[str] = Field(
description="Password for authenticating with OpenSearch",
description="OpenSearch password",
default=None,
)
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
description="whether to use SSL connection for OpenSearch",
default=False,
)
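
A minimal sketch with the opensearch-py client, showing where each of these fields lands.

from opensearchpy import OpenSearch

def opensearch_client_from(c):
    # c is an OpenSearchConfig instance
    return OpenSearch(
        hosts=[{"host": c.OPENSEARCH_HOST, "port": c.OPENSEARCH_PORT}],
        http_auth=(c.OPENSEARCH_USER, c.OPENSEARCH_PASSWORD),
        use_ssl=c.OPENSEARCH_SECURE,
    )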

View File

@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class OracleConfig(BaseSettings):
"""
Configuration settings for Oracle database
ORACLE configs
"""
ORACLE_HOST: Optional[str] = Field(
description="Hostname or IP address of the Oracle database server (e.g., 'localhost' or 'oracle.example.com')",
description="ORACLE host",
default=None,
)
ORACLE_PORT: Optional[PositiveInt] = Field(
description="Port number on which the Oracle database server is listening (default is 1521)",
description="ORACLE port",
default=1521,
)
ORACLE_USER: Optional[str] = Field(
description="Username for authenticating with the Oracle database",
description="ORACLE user",
default=None,
)
ORACLE_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the Oracle database",
description="ORACLE password",
default=None,
)
ORACLE_DATABASE: Optional[str] = Field(
description="Name of the Oracle database or service to connect to (e.g., 'ORCL' or 'pdborcl')",
description="ORACLE database",
default=None,
)

View File

@ -6,40 +6,30 @@ from pydantic_settings import BaseSettings
class PGVectorConfig(BaseSettings):
"""
Configuration settings for PGVector (PostgreSQL with vector extension)
PGVector configs
"""
PGVECTOR_HOST: Optional[str] = Field(
description="Hostname or IP address of the PostgreSQL server with PGVector extension (e.g., 'localhost')",
description="PGVector host",
default=None,
)
PGVECTOR_PORT: Optional[PositiveInt] = Field(
description="Port number on which the PostgreSQL server is listening (default is 5433)",
description="PGVector port",
default=5433,
)
PGVECTOR_USER: Optional[str] = Field(
description="Username for authenticating with the PostgreSQL database",
description="PGVector user",
default=None,
)
PGVECTOR_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the PostgreSQL database",
description="PGVector password",
default=None,
)
PGVECTOR_DATABASE: Optional[str] = Field(
description="Name of the PostgreSQL database to connect to",
description="PGVector database",
default=None,
)
PGVECTOR_MIN_CONNECTION: PositiveInt = Field(
description="Min connection of the PostgreSQL database",
default=1,
)
PGVECTOR_MAX_CONNECTION: PositiveInt = Field(
description="Max connection of the PostgreSQL database",
default=5,
)
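
The MIN/MAX connection bounds only matter if the vector store keeps a pool; a minimal psycopg2 sketch of that idea (assumed wiring, the project may pool differently):

from psycopg2 import pool

def pgvector_pool_from(c):
    # c is a PGVectorConfig instance; min/max bound the pool size
    return pool.SimpleConnectionPool(
        c.PGVECTOR_MIN_CONNECTION,
        c.PGVECTOR_MAX_CONNECTION,
        host=c.PGVECTOR_HOST,
        port=c.PGVECTOR_PORT,
        user=c.PGVECTOR_USER,
        password=c.PGVECTOR_PASSWORD,
        dbname=c.PGVECTOR_DATABASE,
    )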

View File

@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class PGVectoRSConfig(BaseSettings):
"""
Configuration settings for PGVecto.RS (Rust-based vector extension for PostgreSQL)
PGVectoRS configs
"""
PGVECTO_RS_HOST: Optional[str] = Field(
description="Hostname or IP address of the PostgreSQL server with PGVecto.RS extension (e.g., 'localhost')",
description="PGVectoRS host",
default=None,
)
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
description="PGVectoRS port",
default=5431,
)
PGVECTO_RS_USER: Optional[str] = Field(
description="Username for authenticating with the PostgreSQL database using PGVecto.RS",
description="PGVectoRS user",
default=None,
)
PGVECTO_RS_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the PostgreSQL database using PGVecto.RS",
description="PGVectoRS password",
default=None,
)
PGVECTO_RS_DATABASE: Optional[str] = Field(
description="Name of the PostgreSQL database with PGVecto.RS extension to connect to",
description="PGVectoRS database",
default=None,
)

View File

@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class QdrantConfig(BaseSettings):
"""
Configuration settings for Qdrant vector database
Qdrant configs
"""
QDRANT_URL: Optional[str] = Field(
description="URL of the Qdrant server (e.g., 'http://localhost:6333' or 'https://qdrant.example.com')",
description="Qdrant url",
default=None,
)
QDRANT_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Qdrant server",
description="Qdrant api key",
default=None,
)
QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
description="Timeout in seconds for Qdrant client operations (default is 20 seconds)",
description="Qdrant client timeout in seconds",
default=20,
)
QDRANT_GRPC_ENABLED: bool = Field(
description="Whether to enable gRPC support for Qdrant connection (True for gRPC, False for HTTP)",
description="whether enable grpc support for Qdrant connection",
default=False,
)
QDRANT_GRPC_PORT: PositiveInt = Field(
description="Port number for gRPC connection to Qdrant server (default is 6334)",
description="Qdrant grpc port",
default=6334,
)
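
A minimal qdrant-client sketch; prefer_grpc switches the transport, and the gRPC port only takes effect when it is enabled.

from qdrant_client import QdrantClient

def qdrant_client_from(c):
    # c is a QdrantConfig instance
    return QdrantClient(
        url=c.QDRANT_URL,
        api_key=c.QDRANT_API_KEY,
        timeout=c.QDRANT_CLIENT_TIMEOUT,
        prefer_grpc=c.QDRANT_GRPC_ENABLED,
        grpc_port=c.QDRANT_GRPC_PORT,
    )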

View File

@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class RelytConfig(BaseSettings):
"""
Configuration settings for Relyt database
Relyt configs
"""
RELYT_HOST: Optional[str] = Field(
description="Hostname or IP address of the Relyt server (e.g., 'localhost' or 'relyt.example.com')",
description="Relyt host",
default=None,
)
RELYT_PORT: PositiveInt = Field(
description="Port number on which the Relyt server is listening (default is 9200)",
description="Relyt port",
default=9200,
)
RELYT_USER: Optional[str] = Field(
description="Username for authenticating with the Relyt database",
description="Relyt user",
default=None,
)
RELYT_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the Relyt database",
description="Relyt password",
default=None,
)
RELYT_DATABASE: Optional[str] = Field(
description="Name of the Relyt database to connect to (default is 'default')",
description="Relyt database",
default="default",
)

View File

@ -6,45 +6,45 @@ from pydantic_settings import BaseSettings
class TencentVectorDBConfig(BaseSettings):
"""
Configuration settings for Tencent Vector Database
Tencent Vector configs
"""
TENCENT_VECTOR_DB_URL: Optional[str] = Field(
description="URL of the Tencent Vector Database service (e.g., 'https://vectordb.tencentcloudapi.com')",
description="Tencent Vector URL",
default=None,
)
TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Tencent Vector Database service",
description="Tencent Vector API key",
default=None,
)
TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field(
description="Timeout in seconds for Tencent Vector Database operations (default is 30 seconds)",
description="Tencent Vector timeout in seconds",
default=30,
)
TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
description="Username for authenticating with the Tencent Vector Database (if required)",
description="Tencent Vector username",
default=None,
)
TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the Tencent Vector Database (if required)",
description="Tencent Vector password",
default=None,
)
TENCENT_VECTOR_DB_SHARD: PositiveInt = Field(
description="Number of shards for the Tencent Vector Database (default is 1)",
description="Tencent Vector sharding number",
default=1,
)
TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description="Number of replicas for the Tencent Vector Database (default is 2)",
description="Tencent Vector replicas",
default=2,
)
TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
description="Name of the specific Tencent Vector Database to connect to",
description="Tencent Vector Database",
default=None,
)

View File

@ -6,30 +6,30 @@ from pydantic_settings import BaseSettings
class TiDBVectorConfig(BaseSettings):
"""
Configuration settings for TiDB Vector database
TiDB Vector configs
"""
TIDB_VECTOR_HOST: Optional[str] = Field(
description="Hostname or IP address of the TiDB Vector server (e.g., 'localhost' or 'tidb.example.com')",
description="TiDB Vector host",
default=None,
)
TIDB_VECTOR_PORT: Optional[PositiveInt] = Field(
description="Port number on which the TiDB Vector server is listening (default is 4000)",
description="TiDB Vector port",
default=4000,
)
TIDB_VECTOR_USER: Optional[str] = Field(
description="Username for authenticating with the TiDB Vector database",
description="TiDB Vector user",
default=None,
)
TIDB_VECTOR_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the TiDB Vector database",
description="TiDB Vector password",
default=None,
)
TIDB_VECTOR_DATABASE: Optional[str] = Field(
description="Name of the TiDB Vector database to connect to",
description="TiDB Vector database",
default=None,
)

View File

@ -6,25 +6,25 @@ from pydantic_settings import BaseSettings
class WeaviateConfig(BaseSettings):
"""
Configuration settings for Weaviate vector database
Weaviate configs
"""
WEAVIATE_ENDPOINT: Optional[str] = Field(
description="URL of the Weaviate server (e.g., 'http://localhost:8080' or 'https://weaviate.example.com')",
description="Weaviate endpoint URL",
default=None,
)
WEAVIATE_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Weaviate server",
description="Weaviate API key",
default=None,
)
WEAVIATE_GRPC_ENABLED: bool = Field(
description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)",
description="whether to enable gRPC for Weaviate connection",
default=True,
)
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Number of objects to be processed in a single batch operation (default is 100)",
description="Weaviate batch size",
default=100,
)
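
A minimal sketch using the v3 weaviate-client API (the project may pin a different client generation, in which case the auth helper differs).

import weaviate

def weaviate_client_from(c):
    # c is a WeaviateConfig instance
    client = weaviate.Client(
        url=c.WEAVIATE_ENDPOINT,
        auth_client_secret=weaviate.AuthApiKey(api_key=c.WEAVIATE_API_KEY),
    )
    client.batch.configure(batch_size=c.WEAVIATE_BATCH_SIZE)
    return client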

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.10.0-beta1",
default="0.8.2",
)
COMMIT_SHA: str = Field(

View File

@ -1,21 +1 @@
from configs import dify_config
HIDDEN_VALUE = "[__HIDDEN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000"
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub"))
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])

View File

@ -1,9 +1,7 @@
from contextvars import ContextVar
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")

View File

@ -37,16 +37,7 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p
from .billing import billing
# Import datasets controllers
from .datasets import (
data_source,
datasets,
datasets_document,
datasets_segments,
external,
file,
hit_testing,
website,
)
from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website
# Import explore controllers
from .explore import (

View File

@ -109,7 +109,6 @@ class ChatMessageApi(Resource):
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args()

View File

@ -22,8 +22,7 @@ from fields.conversation_fields import (
)
from libs.helper import DatetimeString
from libs.login import login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
class CompletionConversationApi(Resource):

View File

@ -105,6 +105,8 @@ class ChatMessageListApi(Resource):
if rest_count > 0:
has_more = True
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)

View File

@ -12,7 +12,7 @@ from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from libs.login import login_required
from models import Site
from models.model import Site
def parse_app_site_args():

View File

@ -13,14 +13,14 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from factories import variable_factory
from core.app.segments import factory
from core.errors.error import AppInvokeQuotaExceededError
from fields.workflow_fields import workflow_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models import App
from models.model import AppMode
from models.model import App, AppMode
from services.app_dsl_service import AppDslService
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
@ -101,13 +101,9 @@ class DraftWorkflowApi(Resource):
try:
environment_variables_list = args.get("environment_variables") or []
environment_variables = [
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
]
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
conversation_variables_list = args.get("conversation_variables") or []
conversation_variables = [
variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
]
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args["graph"],
@ -170,8 +166,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
parser.add_argument("query", type=str, required=True, location="json", default="")
parser.add_argument("files", type=list, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
args = parser.parse_args()
try:
@ -277,15 +271,17 @@ class DraftWorkflowRunApi(Resource):
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
response = AppGenerateService.generate(
app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
)
return helper.compact_generate_response(response)
return helper.compact_generate_response(response)
except (ValueError, AppInvokeQuotaExceededError) as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowTaskStopApi(Resource):

View File

@ -7,8 +7,7 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required
from models import App
from models.model import AppMode
from models.model import App, AppMode
from services.workflow_app_service import WorkflowAppService

View File

@ -13,8 +13,7 @@ from fields.workflow_run_fields import (
)
from libs.helper import uuid_value
from libs.login import login_required
from models import App
from models.model import AppMode
from models.model import App, AppMode
from services.workflow_run_service import WorkflowRunService

View File

@ -10,11 +10,11 @@ from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from enums import WorkflowRunTriggeredFrom
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
from models.model import AppMode
from models.workflow import WorkflowRunTriggeredFrom
class WorkflowDailyRunsStatistic(Resource):

View File

@ -5,8 +5,7 @@ from typing import Optional, Union
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models import App
from models.model import AppMode
from models.model import App, AppMode
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):

View File

@ -15,7 +15,7 @@ from controllers.console.setup import setup_required
from extensions.ext_database import db
from libs.helper import email as email_validate
from libs.password import hash_password, valid_password
from models import Account
from models.account import Account
from services.account_service import AccountService
from services.errors.account import RateLimitExceededError

View File

@ -9,7 +9,7 @@ from controllers.console import api
from controllers.console.setup import setup_required
from libs.helper import email, get_remote_ip
from libs.password import valid_password
from models import Account
from models.account import Account
from services.account_service import AccountService, TenantService

View File

@ -11,8 +11,7 @@ from constants.languages import languages
from extensions.ext_database import db
from libs.helper import get_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
from .. import api

View File

@ -15,7 +15,8 @@ from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.login import login_required
from models import DataSourceOauthBinding, Document
from models.dataset import Document
from models.source import DataSourceOauthBinding
from services.dataset_service import DatasetService, DocumentService
from tasks.document_indexing_sync_task import document_indexing_sync_task

View File

@ -24,8 +24,8 @@ from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from libs.login import login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.model import ApiToken, UploadFile
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -49,7 +49,7 @@ class DatasetListApi(Resource):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
# provider = request.args.get("provider", default="vendor")
provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
@ -57,7 +57,7 @@ class DatasetListApi(Resource):
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
page, limit, current_user.current_tenant_id, current_user, search, tag_ids
page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids
)
# check embedding setting
@ -111,7 +111,7 @@ class DatasetListApi(Resource):
help="Invalid indexing technique.",
)
parser.add_argument(
"external_knowledge_api_id",
"external_api_template_id",
type=str,
nullable=True,
required=False,
@ -144,7 +144,7 @@ class DatasetListApi(Resource):
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
external_api_template_id=args["external_api_template_id"],
external_knowledge_id=args["external_knowledge_id"],
)
except services.errors.dataset.DatasetNameDuplicateError:
@ -234,33 +234,6 @@ class DatasetApi(Resource):
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
parser.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
parser.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
parser.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
args = parser.parse_args()
data = request.get_json()
@ -613,10 +586,10 @@ class DatasetRetrievalSettingApi(Resource):
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.PGVECTOR
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
@ -627,7 +600,6 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.PGVECTOR
):
return {
"retrieval_method": [

View File

@ -46,7 +46,8 @@ from fields.document_fields import (
document_with_segments_fields,
)
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task

View File

@ -24,7 +24,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields
from libs.login import login_required
from models import DocumentSegment
from models.dataset import DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task

View File

@ -1,18 +1,17 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
def _validate_name(name):
@ -22,7 +21,7 @@ def _validate_name(name):
def _validate_description_length(description):
if description and len(description) > 400:
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@ -36,12 +35,12 @@ class ExternalApiTemplateListApi(Resource):
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
api_templates, total = ExternalDatasetService.get_external_api_templates(
page, limit, current_user.current_tenant_id, search
)
response = {
"data": [item.to_dict() for item in external_knowledge_apis],
"has_more": len(external_knowledge_apis) == limit,
"data": [item.to_dict() for item in api_templates],
"has_more": len(api_templates) == limit,
"limit": limit,
"total": total,
"page": page,
@ -60,6 +59,13 @@ class ExternalApiTemplateListApi(Resource):
help="Name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
nullable=False,
required=True,
help="Description is required. Description must be between 1 to 400 characters.",
type=_validate_description_length,
)
parser.add_argument(
"settings",
type=dict,
@ -76,32 +82,32 @@ class ExternalApiTemplateListApi(Resource):
raise Forbidden()
try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
api_template = ExternalDatasetService.create_api_template(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return external_knowledge_api.to_dict(), 201
return api_template.to_dict(), 201
class ExternalApiTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id):
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
if external_knowledge_api is None:
def get(self, api_template_id):
api_template_id = str(api_template_id)
api_template = ExternalDatasetService.get_api_template(api_template_id)
if api_template is None:
raise NotFound("API template not found.")
return external_knowledge_api.to_dict(), 200
return api_template.to_dict(), 200
@setup_required
@login_required
@account_initialization_required
def patch(self, external_knowledge_api_id):
external_knowledge_api_id = str(external_knowledge_api_id)
def patch(self, api_template_id):
api_template_id = str(api_template_id)
parser = reqparse.RequestParser()
parser.add_argument(
@ -111,6 +117,13 @@ class ExternalApiTemplateApi(Resource):
help="type is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
nullable=False,
required=True,
help="description is required. Description must be between 1 to 400 characters.",
type=_validate_description_length,
)
parser.add_argument(
"settings",
type=dict,
@ -121,40 +134,80 @@ class ExternalApiTemplateApi(Resource):
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"])
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
api_template = ExternalDatasetService.update_api_template(
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id,
api_template_id=api_template_id,
args=args,
)
return external_knowledge_api.to_dict(), 200
return api_template.to_dict(), 200
@setup_required
@login_required
@account_initialization_required
def delete(self, external_knowledge_api_id):
external_knowledge_api_id = str(external_knowledge_api_id)
def delete(self, api_template_id):
api_template_id = str(api_template_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor or current_user.is_dataset_operator:
raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
return {"result": "success"}, 200
ExternalDatasetService.delete_api_template(current_user.current_tenant_id, api_template_id)
return {"result": "success"}, 204
class ExternalApiUseCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id):
external_knowledge_api_id = str(external_knowledge_api_id)
def get(self, api_template_id):
api_template_id = str(api_template_id)
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
external_knowledge_api_id
external_api_template_is_using = ExternalDatasetService.external_api_template_use_check(api_template_id)
return {"is_using": external_api_template_is_using}, 200
class ExternalDatasetInitApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("api_template_id", type=str, required=True, nullable=True, location="json")
# parser.add_argument('name', nullable=False, required=True,
# help='name is required. Name must be between 1 to 100 characters.',
# type=_validate_name)
# parser.add_argument('description', type=str, required=True, nullable=True, location='json')
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
parser.add_argument("process_parameter", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
# validate args
ExternalDatasetService.document_create_args_validate(
current_user.current_tenant_id, args["api_template_id"], args["process_parameter"]
)
return {"is_using": external_knowledge_api_is_using, "count": count}, 200
try:
dataset, documents, batch = ExternalDatasetService.init_external_dataset(
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
args=args,
)
except Exception as ex:
raise ProviderNotInitializeError(ex.description)
response = {"dataset": dataset, "documents": documents, "batch": batch}
return response
class ExternalDatasetCreateApi(Resource):
@ -167,7 +220,7 @@ class ExternalDatasetCreateApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("external_api_template_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"name",
@ -176,8 +229,7 @@ class ExternalDatasetCreateApi(Resource):
help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
parser.add_argument("description", type=str, required=True, nullable=True, location="json")
args = parser.parse_args()
@ -197,43 +249,6 @@ class ExternalDatasetCreateApi(Resource):
return marshal(dataset, dataset_detail_fields), 201
class ExternalKnowledgeHitTestingApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
try:
response = HitTestingService.external_retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
external_retrieval_model=args["external_retrieval_model"],
)
return response
except Exception as e:
raise InternalServerError(str(e))
api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing")
api.add_resource(ExternalDatasetCreateApi, "/datasets/external")
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api")
api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-api-template")
api.add_resource(ExternalApiTemplateApi, "/datasets/external-api-template/<uuid:api_template_id>")
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-api-template/<uuid:api_template_id>/use-check")
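
For context on the renamed routes, a hedged example of how a client call would look against the template endpoint; the base URL, console prefix, token, and settings payload are placeholders, not values taken from this diff.

import requests

resp = requests.post(
    "http://localhost:5001/console/api/datasets/external-api-template",
    headers={"Authorization": "Bearer <console-session-token>"},
    json={
        "name": "my-knowledge-api",
        "description": "demo template",
        "settings": {"endpoint": "https://example.com/retrieval", "api_key": "placeholder"},
    },
)
print(resp.status_code, resp.json())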

View File

@ -1,12 +1,9 @@
import urllib.parse
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal_with
import services
from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
from controllers.console import api
from controllers.console.datasets.error import (
FileTooLargeError,
@ -16,10 +13,9 @@ from controllers.console.datasets.error import (
)
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.helper import ssrf_proxy
from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from services.file_service import FileService
from services.file_service import ALLOWED_EXTENSIONS, UNSTRUCTURED_ALLOWED_EXTENSIONS, FileService
PREVIEW_WORDS_LIMIT = 3000
@ -55,7 +51,7 @@ class FileApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
try:
upload_file = FileService.upload_file(file=file, user=current_user)
upload_file = FileService.upload_file(file, current_user)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
@ -79,24 +75,11 @@ class FileSupportTypeApi(Resource):
@login_required
@account_initialization_required
def get(self):
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
class RemoteFileInfoApi(Resource):
@marshal_with(remote_file_info_fields)
def get(self, url):
decoded_url = urllib.parse.unquote(url)
try:
response = ssrf_proxy.head(decoded_url)
return {
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(response.headers.get("Content-Length", 0)),
}
except Exception as e:
return {"error": str(e)}, 400
etl_type = dify_config.ETL_TYPE
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
return {"allowed_extensions": allowed_extensions}
api.add_resource(FileApi, "/files/upload")
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
api.add_resource(FileSupportTypeApi, "/files/support-type")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")

View File

@ -47,7 +47,7 @@ class HitTestingApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrival_model", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
@ -58,7 +58,7 @@ class HitTestingApi(Resource):
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
external_retrieval_model=args["external_retrival_model"],
limit=10,
)

View File

@ -0,0 +1,49 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from services.external_knowledge_service import ExternalDatasetService
class TestExternalApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"top_k",
nullable=False,
required=True,
type=int,
)
parser.add_argument(
"score_threshold",
nullable=False,
required=True,
type=float,
)
args = parser.parse_args()
result = ExternalDatasetService.test_external_knowledge_retrival(
args["top_k"], args["score_threshold"]
)
response = {
"data": [item.to_dict() for item in api_templates],
"has_more": len(api_templates) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
api.add_resource(TestExternalApi, "/dify/external-knowledge/retrival-documents")

View File

@ -14,9 +14,7 @@ class WebsiteCrawlApi(Resource):
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"provider", type=str, choices=["firecrawl", "jinareader"], required=True, nullable=True, location="json"
)
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json")
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args()
@ -35,7 +33,7 @@ class WebsiteCrawlStatusApi(Resource):
@account_initialization_required
def get(self, job_id: str):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, choices=["firecrawl", "jinareader"], required=True, location="args")
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args")
args = parser.parse_args()
# get crawl status
try:

View File

@ -100,7 +100,6 @@ class ChatApi(InstalledAppResource):
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()

View File

@ -11,7 +11,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.login import login_required
from models import App, InstalledApp, RecommendedApp
from models.model import App, InstalledApp, RecommendedApp
from services.account_service import TenantService

View File

@ -51,7 +51,7 @@ class MessageListApi(InstalledAppResource):
try:
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import login_required
from models import InstalledApp
from models.model import InstalledApp
def installed_app_required(view=None):

View File

@ -38,52 +38,11 @@ class VersionApi(Resource):
return result
content = json.loads(response.content)
if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"):
result["version"] = content["version"]
result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"]
result["can_auto_update"] = content["canAutoUpdate"]
result["version"] = content["version"]
result["release_date"] = content["releaseDate"]
result["release_notes"] = content["releaseNotes"]
result["can_auto_update"] = content["canAutoUpdate"]
return result
def _has_new_version(*, latest_version: str, current_version: str) -> bool:
def parse_version(version: str) -> tuple:
# Split version into parts and pre-release suffix if any
parts = version.split("-")
version_parts = parts[0].split(".")
pre_release = parts[1] if len(parts) > 1 else None
# Validate version format
if len(version_parts) != 3:
raise ValueError(f"Invalid version format: {version}")
try:
# Convert version parts to integers
major, minor, patch = map(int, version_parts)
return (major, minor, patch, pre_release)
except ValueError:
raise ValueError(f"Invalid version format: {version}")
latest = parse_version(latest_version)
current = parse_version(current_version)
# Compare major, minor, and patch versions
for latest_part, current_part in zip(latest[:3], current[:3]):
if latest_part > current_part:
return True
elif latest_part < current_part:
return False
# If versions are equal, check pre-release suffixes
if latest[3] is None and current[3] is not None:
return True
elif latest[3] is not None and current[3] is None:
return False
elif latest[3] is not None and current[3] is not None:
# Simple string comparison for pre-release versions
return latest[3] > current[3]
return False
api.add_resource(VersionApi, "/version")
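
A quick sanity check of the _has_new_version helper shown in this hunk; each assertion follows directly from its comparison rules (numeric parts first, then a plain release outranks its own pre-release):

assert _has_new_version(latest_version="1.2.4", current_version="1.2.3")
assert not _has_new_version(latest_version="1.2.3", current_version="1.2.4")
assert not _has_new_version(latest_version="1.2.3", current_version="1.2.3")
# a plain release is newer than its pre-release counterpart...
assert _has_new_version(latest_version="1.2.3", current_version="1.2.3-beta1")
# ...but a pre-release is not newer than the plain release
assert not _has_new_version(latest_version="1.2.3-beta1", current_version="1.2.3")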

View File

@ -20,7 +20,7 @@ from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone
from libs.login import login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError

View File

@ -72,9 +72,8 @@ class DefaultModelApi(Resource):
provider=model_setting["provider"],
model=model_setting["model"],
)
except Exception as ex:
logging.exception(f"{model_setting['model_type']} save error: {ex}")
raise ex
except Exception:
logging.warning(f"{model_setting['model_type']} save error")
return {"result": "success"}

View File

@ -360,15 +360,16 @@ class ToolWorkflowProviderCreateApi(Resource):
args = reqparser.parse_args()
return WorkflowToolManageService.create_workflow_tool(
user_id=user_id,
tenant_id=tenant_id,
workflow_app_id=args["workflow_app_id"],
name=args["name"],
label=args["label"],
icon=args["icon"],
description=args["description"],
parameters=args["parameters"],
privacy_policy=args["privacy_policy"],
user_id,
tenant_id,
args["workflow_app_id"],
args["name"],
args["label"],
args["icon"],
args["description"],
args["parameters"],
args["privacy_policy"],
args.get("labels", []),
)

View File

@ -198,7 +198,7 @@ class WebappLogoWorkspaceApi(Resource):
raise UnsupportedFileTypeError()
try:
upload_file = FileService.upload_file(file=file, user=current_user)
upload_file = FileService.upload_file(file, current_user, True)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)

View File

@ -21,36 +21,7 @@ class ImagePreviewApi(Resource):
return {"content": "Invalid request."}, 400
try:
generator, mimetype = FileService.get_image_preview(
file_id=file_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
class FilePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
sign = request.args.get("sign")
if not timestamp or not nonce or not sign:
return {"content": "Invalid request."}, 400
try:
generator, mimetype = FileService.get_signed_file_preview(
file_id=file_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
)
generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
@ -78,7 +49,6 @@ class WorkspaceWebappLogoApi(Resource):
api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/file-preview")
api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")

View File

@ -54,7 +54,6 @@ class MessageListApi(Resource):
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),

View File

@ -28,11 +28,11 @@ class DatasetListApi(DatasetApiResource):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
# provider = request.args.get("provider", default="vendor")
provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)
datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
@ -83,7 +83,7 @@ class DatasetListApi(DatasetApiResource):
nullable=False,
)
parser.add_argument(
"external_knowledge_api_id",
"external_api_template_id",
type=str,
nullable=True,
required=False,
@ -112,7 +112,7 @@ class DatasetListApi(DatasetApiResource):
account=current_user,
permission=args["permission"],
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
external_api_template_id=args["external_api_template_id"],
external_knowledge_id=args["external_knowledge_id"],
)
except services.errors.dataset.DatasetNameDuplicateError:

View File

@ -96,7 +96,6 @@ class ChatApi(WebApiResource):
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
args = parser.parse_args()

View File

@ -1,5 +1,3 @@
import urllib.parse
from flask import request
from flask_restful import marshal_with
@ -7,8 +5,7 @@ import services
from controllers.web import api
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
from controllers.web.wraps import WebApiResource
from core.helper import ssrf_proxy
from fields.file_fields import file_fields, remote_file_info_fields
from fields.file_fields import file_fields
from services.file_service import FileService
@ -34,19 +31,4 @@ class FileApi(WebApiResource):
return upload_file, 201
class RemoteFileInfoApi(WebApiResource):
@marshal_with(remote_file_info_fields)
def get(self, url):
decoded_url = urllib.parse.unquote(url)
try:
response = ssrf_proxy.head(decoded_url)
return {
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(response.headers.get("Content-Length", 0)),
}
except Exception as e:
return {"error": str(e)}, 400
api.add_resource(FileApi, "/files/upload")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")

View File

@ -22,7 +22,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields
from fields.raws import FilesContainedField
from libs import helper
from libs.helper import TimestampField, uuid_value
from models.model import AppMode
@ -58,8 +57,7 @@ class MessageListApi(WebApiResource):
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": FilesContainedField,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
@ -91,7 +89,7 @@ class MessageListApi(WebApiResource):
try:
return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -16,14 +16,13 @@ from core.app.entities.app_invoke_entities import (
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import file_manager
from core.file.message_file_parser import MessageFileParser
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
LLMUsage,
PromptMessage,
PromptMessageContent,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
@ -33,7 +32,6 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.entities.tool_entities import (
ToolParameter,
ToolRuntimeVariablePool,
@ -41,8 +39,8 @@ from core.tools.entities.tool_entities import (
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
@ -67,6 +65,23 @@ class BaseAgentRunner(AppRunner):
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None,
) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param application_generate_entity: application generate entity
:param conversation: conversation
:param app_config: app generate entity
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param memory: memory
:param prompt_messages: prompt messages
:param variables_pool: variables pool
:param db_variables: db variables
:param model_instance: model instance
"""
self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity
self.conversation = conversation
@ -164,7 +179,7 @@ class BaseAgentRunner(AppRunner):
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = parameter.type.as_normal_type()
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
@ -249,7 +264,7 @@ class BaseAgentRunner(AppRunner):
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = parameter.type.as_normal_type()
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
@ -426,12 +441,10 @@ class BaseAgentRunner(AppRunner):
.filter(
Message.conversation_id == self.message.conversation_id,
)
.order_by(Message.created_at.desc())
.order_by(Message.created_at.asc())
.all()
)
messages = list(reversed(extract_thread_messages(messages)))
for message in messages:
if message.id == self.message.id:
continue
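
The hunk above switches the history query to newest-first and filters it through extract_thread_messages before reversing. A minimal sketch of what such a helper could look like, offered as an assumption for illustration rather than the actual core.prompt.utils implementation:

def extract_thread_messages(messages):
    # walk parent_message_id links from the newest message so that only the
    # active branch of a regenerated conversation is kept
    thread = []
    next_id = None
    for message in messages:  # expected newest-first, as in the query above
        if message.parent_message_id is None:
            # root of the thread (e.g. the very first message): keep it and stop
            thread.append(message)
            break
        if next_id is None or message.id == next_id:
            thread.append(message)
            next_id = message.parent_message_id
    return thread  # still newest-first; the caller reverses it for prompt order
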
@ -495,24 +508,26 @@ class BaseAgentRunner(AppRunner):
return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
message_file_parser = MessageFileParser(
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
)
files = message.message_files
if files:
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
)
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
else:
file_objs = []
if not file_objs:
return UserPromptMessage(content=message.query)
else:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
for file_obj in file_objs:
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
prompt_message_contents.append(file_obj.prompt_message_content)
return UserPromptMessage(content=prompt_message_contents)
else:

View File

@ -1,11 +1,9 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.file import file_manager
from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContent,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
@ -34,10 +32,9 @@ class CotChatAgentRunner(CotAgentRunner):
Organize user query
"""
if self.files:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_message_contents = [TextPromptMessageContent(data=query)]
for file_obj in self.files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:

View File

@ -7,15 +7,10 @@ from typing import Any, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
@ -395,10 +390,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
Organize user query
"""
if self.files:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_message_contents = [TextPromptMessageContent(data=query)]
for file_obj in self.files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:

View File

@ -53,11 +53,12 @@ class BasicVariablesConfigManager:
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description", ""),
description=variable.get("description"),
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options", []),
options=variable.get("options"),
default=variable.get("default"),
)
)

View File

@ -1,12 +1,11 @@
from collections.abc import Sequence
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel
from core.file import FileExtraConfig, FileTransferMethod, FileType
from core.file.file_obj import FileExtraConfig
from core.model_runtime.entities.message_entities import PromptMessageRole
from models.model import AppMode
from models import AppMode
class ModelConfigEntity(BaseModel):
@ -70,7 +69,7 @@ class PromptTemplateEntity(BaseModel):
ADVANCED = "advanced"
@classmethod
def value_of(cls, value: str):
def value_of(cls, value: str) -> "PromptType":
"""
Get value of given mode.
@ -94,8 +93,6 @@ class VariableEntityType(str, Enum):
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
class VariableEntity(BaseModel):
@ -105,14 +102,13 @@ class VariableEntity(BaseModel):
variable: str
label: str
description: str = ""
description: Optional[str] = None
type: VariableEntityType
required: bool = False
max_length: Optional[int] = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
options: Optional[list[str]] = None
default: Optional[str] = None
hint: Optional[str] = None
class ExternalDataVariableEntity(BaseModel):
@ -140,7 +136,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
MULTIPLE = "multiple"
@classmethod
def value_of(cls, value: str):
def value_of(cls, value: str) -> "RetrieveStrategy":
"""
Get value of given mode.

View File

@ -1,13 +1,12 @@
from collections.abc import Mapping
from typing import Any
from typing import Any, Optional
from core.file.models import FileExtraConfig
from models import FileUploadConfig
from core.file.file_obj import FileExtraConfig
class FileUploadConfigManager:
@classmethod
def convert(cls, config: Mapping[str, Any], is_vision: bool = True):
def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
"""
Convert app model config to a file upload config entity
@ -16,18 +15,19 @@ class FileUploadConfigManager:
"""
file_upload_dict = config.get("file_upload")
if file_upload_dict:
if file_upload_dict.get("enabled"):
data = {
"image_config": {
"number_limits": file_upload_dict["number_limits"],
"transfer_methods": file_upload_dict["allowed_file_upload_methods"],
if file_upload_dict.get("image"):
if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
image_config = {
"number_limits": file_upload_dict["image"]["number_limits"],
"transfer_methods": file_upload_dict["image"]["transfer_methods"],
}
}
if is_vision:
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
if is_vision:
image_config["detail"] = file_upload_dict["image"]["detail"]
return FileExtraConfig.model_validate(data)
return FileExtraConfig(image_config=image_config)
return None
@classmethod
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
@ -39,7 +39,29 @@ class FileUploadConfigManager:
"""
if not config.get("file_upload"):
config["file_upload"] = {}
else:
FileUploadConfig.model_validate(config["file_upload"])
if not isinstance(config["file_upload"], dict):
raise ValueError("file_upload must be of dict type")
# check image config
if not config["file_upload"].get("image"):
config["file_upload"]["image"] = {"enabled": False}
if config["file_upload"]["image"]["enabled"]:
number_limits = config["file_upload"]["image"]["number_limits"]
if number_limits < 1 or number_limits > 6:
raise ValueError("number_limits must be in [1, 6]")
if is_vision:
detail = config["file_upload"]["image"]["detail"]
if detail not in {"high", "low"}:
raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type")
for method in transfer_methods:
if method not in {"remote_url", "local_file"}:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
return config, ["file_upload"]
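
For orientation, an example file_upload section that satisfies every image rule spelled out in the validation above (the values themselves are illustrative):

config = {
    "file_upload": {
        "image": {
            "enabled": True,
            "number_limits": 3,  # must fall within [1, 6]
            "detail": "low",  # only checked when is_vision=True; must be "high" or "low"
            "transfer_methods": ["remote_url", "local_file"],
        }
    }
}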

View File

@ -17,6 +17,6 @@ class WorkflowVariablesConfigManager:
# variables
for variable in user_input_form:
variables.append(VariableEntity.model_validate(variable))
variables.append(VariableEntity(**variable))
return variables

View File

@ -20,11 +20,10 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from enums import CreatedByRole
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
@ -96,16 +95,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
@ -113,9 +106,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
)
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id, user_id)
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
@ -126,12 +118,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id"),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,

View File

@ -1,26 +1,30 @@
import logging
import os
from collections.abc import Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
)
from core.moderation.base import ModerationError
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from enums import UserFrom
from extensions.ext_database import db
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, WorkflowType
@ -40,6 +44,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation: Conversation,
message: Message,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
"""
super().__init__(queue_manager)
self.application_generate_entity = application_generate_entity
@ -47,6 +57,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self.message = message
def run(self) -> None:
"""
Run application
:return:
"""
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
@ -67,7 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id = self.application_generate_entity.user_id
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
workflow_callbacks.append(WorkflowLoggingCallback())
if self.application_generate_entity.single_iteration_run:
@ -184,6 +198,15 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query: str,
message_id: str,
) -> bool:
"""
Handle input moderation
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
:param query: query
:param message_id: message id
:return:
"""
try:
# process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs(
@ -203,6 +226,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def handle_annotation_reply(
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
) -> bool:
"""
Handle annotation reply
:param app_record: app record
:param message: message
:param query: query
:param app_generate_entity: application generate entity
"""
# annotation reply
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
message=message,
@ -224,6 +255,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param text: text
:return:
"""
self._publish_event(QueueTextChunkEvent(text=text))

View File

@ -1,7 +1,7 @@
import json
import logging
import time
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@ -49,7 +49,6 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from enums.workflow_nodes import NodeType
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
@ -113,7 +112,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._task_state = WorkflowTaskState()
self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
def process(self):
"""
@ -233,8 +231,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
except Exception as e:
logger.error(e)
break
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,
@ -292,10 +289,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
@ -362,7 +355,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
outputs=json.dumps(event.outputs) if event.outputs else None,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
@ -534,7 +527,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
del extras["metadata"]["annotation_reply"]
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
)
def _handle_output_moderation_chunk(self, text: str) -> bool:

View File

@ -17,12 +17,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from enums import CreatedByRole
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser
from models.account import Account
from models.model import App, EndUser
logger = logging.getLogger(__name__)
@ -49,12 +49,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True,
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
) -> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@ -102,19 +97,12 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args.get("files") or []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
@ -127,7 +115,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
)
# get tracing instance
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id, user_id)
# init application generate entity
application_generate_entity = AgentChatAppGenerateEntity(
@ -135,12 +124,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id"),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,

View File

@ -75,10 +75,10 @@ class AppGenerateResponseConverter(ABC):
:return:
"""
# show_retrieve_source
updated_resources = []
if "retriever_resources" in metadata:
metadata["retriever_resources"] = []
for resource in metadata["retriever_resources"]:
updated_resources.append(
metadata["retriever_resources"].append(
{
"segment_id": resource["segment_id"],
"position": resource["position"],
@ -87,7 +87,6 @@ class AppGenerateResponseConverter(ABC):
"content": resource["content"],
}
)
metadata["retriever_resources"] = updated_resources
# show annotation reply
if "annotation_reply" in metadata:

View File

@ -1,92 +1,35 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional
from typing import Any, Optional
from core.app.app_config.entities import VariableEntityType
from core.file import File, FileExtraConfig
from factories import file_factory
if TYPE_CHECKING:
from core.app.app_config.entities import AppConfig, VariableEntity
from enums import CreatedByRole
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
class BaseAppGenerator:
def _prepare_user_inputs(
self,
*,
user_inputs: Optional[Mapping[str, Any]],
app_config: "AppConfig",
user_id: str,
role: "CreatedByRole",
) -> Mapping[str, Any]:
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
# Convert files in inputs to File
entity_dictionary = {item.variable: item for item in app_config.variables}
# Convert single file to File
files_inputs = {
k: file_factory.build_from_mapping(
mapping=v,
tenant_id=app_config.tenant_id,
user_id=user_id,
role=role,
config=FileExtraConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
}
# Convert list of files to File
file_list_inputs = {
k: file_factory.build_from_mappings(
mappings=v,
tenant_id=app_config.tenant_id,
user_id=user_id,
role=role,
config=FileExtraConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
)
for k, v in user_inputs.items()
if isinstance(v, list)
# Ensure skip List<File>
and all(isinstance(item, dict) for item in v)
and entity_dictionary[k].type == VariableEntityType.FILE_LIST
}
# Merge all inputs
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs
# Check if all files are converted to File
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
raise ValueError("Invalid input type")
if any(
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
):
raise ValueError("Invalid input type")
return user_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.variable)
if not user_input_value:
if var.required:
raise ValueError(f"{var.variable} is required in input form")
else:
return None
if var.type in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
} and not isinstance(user_input_value, str):
if var.required and not user_input_value:
raise ValueError(f"{var.variable} is required in input form")
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ""
if (
var.type
in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
}
and user_input_value
and not isinstance(user_input_value, str)
):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
@ -98,24 +41,12 @@ class BaseAppGenerator:
except ValueError:
raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT:
options = var.options
options = var.options or []
if user_input_value not in options:
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
if var.max_length and len(user_input_value) > var.max_length:
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
elif var.type == VariableEntityType.FILE:
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
raise ValueError(f"{var.variable} in input form must be a file")
elif var.type == VariableEntityType.FILE_LIST:
if not (
isinstance(user_input_value, list)
and (
all(isinstance(item, dict) for item in user_input_value)
or all(isinstance(item, File) for item in user_input_value)
)
):
raise ValueError(f"{var.variable} in input form must be a list of files")
return user_input_value
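
The NUMBER branch above is only partially visible (a comment that parsing "may raise ValueError", then an except clause). A standalone sketch of the likely coercion, hedged as an assumption about the elided try block:

def coerce_number(value: str):
    # try the stricter int parse first, then fall back to float
    try:
        return int(value)
    except ValueError:
        try:
            return float(value)
        except ValueError:
            raise ValueError("input in form must be a valid number")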

View File

@ -27,7 +27,7 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation
if TYPE_CHECKING:
from core.file.models import File
from core.file.file_obj import FileVar
class AppRunner:
@ -37,7 +37,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["File"],
files: list["FileVar"],
query: Optional[str] = None,
) -> int:
"""
@ -137,7 +137,7 @@ class AppRunner:
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["File"],
files: list["FileVar"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
@ -309,7 +309,7 @@ class AppRunner:
if not prompt_messages:
prompt_messages = result.prompt_messages
if result.delta.usage:
if not usage and result.delta.usage:
usage = result.delta.usage
if not usage:

View File

@ -17,11 +17,10 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from enums import CreatedByRole
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models.model import App, EndUser
@ -100,19 +99,12 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
@ -125,7 +117,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
)
# get tracing instance
trace_manager = TraceQueueManager(app_id=app_model.id)
trace_manager = TraceQueueManager(app_model.id)
# init application generate entity
application_generate_entity = ChatAppGenerateEntity(
@ -133,17 +125,14 @@ class ChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id"),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
stream=stream,
)
# init generate records

View File

@ -17,12 +17,12 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from enums import CreatedByRole
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Message
from models.account import Account
from models.model import App, EndUser, Message
from services.errors.app import MoreLikeThisDisabledError
from services.errors.message import MessageNotExistsError
@ -88,19 +88,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
tenant_id=app_model.tenant_id, config=args.get("model_config")
)
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
@ -110,7 +103,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id)
# init application generate entity
@ -118,7 +110,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()),
app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config),
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
inputs=self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
user_id=user.id,
@ -259,16 +251,10 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["model"] = model_dict
# parse files
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_mappings(
mappings=message.files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
else:
file_objs = []

View File

@ -26,7 +26,7 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from models import Account
from models.account import Account
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
@ -218,7 +218,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
provider_response_latency=0,
total_price=0,
currency="USD",
@ -238,7 +237,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
type=file.type.value,
transfer_method=file.transfer_method.value,
belongs_to="user",
url=file.remote_url,
url=file.url,
upload_file_id=file.related_id,
created_by_role=("account" if account_id else "end_user"),
created_by=account_id or end_user_id,

View File

@ -3,7 +3,7 @@ import logging
import os
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from collections.abc import Generator
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
@ -20,12 +20,13 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from enums import CreatedByRole
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Workflow
from models.account import Account
from models.model import App, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -62,45 +63,48 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
files: Sequence[Mapping[str, Any]] = args.get("files") or []
"""
Generate App response.
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
:param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
"""
inputs = args["inputs"]
# parse files
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
system_files = file_factory.build_from_mappings(
mappings=files,
tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config,
)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow,
)
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance
trace_manager = TraceQueueManager(
app_id=app_model.id,
user_id=user.id if isinstance(user, Account) else user.session_id,
)
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id, user_id)
inputs: Mapping[str, Any] = args["inputs"]
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
files=system_files,
inputs=self._get_cleaned_inputs(inputs, app_config),
files=file_objs,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,

View File

@ -1,19 +1,20 @@
import logging
import os
from typing import Optional, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from enums import UserFrom
from extensions.ext_database import db
from models.model import App, EndUser
from models.workflow import WorkflowType
@ -70,7 +71,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
db.session.close()
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested

View File

@ -1,3 +1,4 @@
import json
import logging
import time
from collections.abc import Generator
@ -211,8 +212,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
except Exception as e:
logger.error(e)
break
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,
@ -327,7 +327,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
outputs=json.dumps(event.outputs)
if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
else None,
conversation_id=None,
trace_manager=trace_manager,
)

View File

@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@ -44,7 +45,6 @@ from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.workflow_entry import WorkflowEntry
from enums import NodeType
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow

View File

@ -1,6 +1,7 @@
from typing import Optional
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
@ -19,8 +20,6 @@ from core.workflow.graph_engine.entities.event import (
ParallelBranchRunSucceededEvent,
)
from .base_workflow_callback import WorkflowCallback
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",

Some files were not shown because too many files have changed in this diff.