mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
This commit is contained in:
@ -105,14 +105,14 @@ class AccountService:
|
||||
return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"
|
||||
|
||||
@staticmethod
|
||||
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
|
||||
def _store_refresh_token(refresh_token: str, account_id: str):
|
||||
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
|
||||
redis_client.setex(
|
||||
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
|
||||
def _delete_refresh_token(refresh_token: str, account_id: str):
|
||||
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
|
||||
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
|
||||
|
||||
@ -214,6 +214,7 @@ class AccountService:
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
return account
|
||||
|
||||
@ -311,12 +312,12 @@ class AccountService:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def delete_account(account: Account) -> None:
|
||||
def delete_account(account: Account):
|
||||
"""Delete account. This method only adds a task to the queue for deletion."""
|
||||
delete_account_task.delay(account.id)
|
||||
|
||||
@staticmethod
|
||||
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
|
||||
def link_account_integrate(provider: str, open_id: str, account: Account):
|
||||
"""Link account integrate"""
|
||||
try:
|
||||
# Query whether there is an existing binding record for the same provider
|
||||
@ -343,7 +344,7 @@ class AccountService:
|
||||
raise LinkAccountIntegrateError("Failed to link account.") from e
|
||||
|
||||
@staticmethod
|
||||
def close_account(account: Account) -> None:
|
||||
def close_account(account: Account):
|
||||
"""Close account"""
|
||||
account.status = AccountStatus.CLOSED.value
|
||||
db.session.commit()
|
||||
@ -351,6 +352,7 @@ class AccountService:
|
||||
@staticmethod
|
||||
def update_account(account, **kwargs):
|
||||
"""Update account fields"""
|
||||
account = db.session.merge(account)
|
||||
for field, value in kwargs.items():
|
||||
if hasattr(account, field):
|
||||
setattr(account, field, value)
|
||||
@ -372,7 +374,7 @@ class AccountService:
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def update_login_info(account: Account, *, ip_address: str) -> None:
|
||||
def update_login_info(account: Account, *, ip_address: str):
|
||||
"""Update last login time and ip"""
|
||||
account.last_login_at = naive_utc_now()
|
||||
account.last_login_ip = ip_address
|
||||
@ -396,7 +398,7 @@ class AccountService:
|
||||
return TokenPair(access_token=access_token, refresh_token=refresh_token)
|
||||
|
||||
@staticmethod
|
||||
def logout(*, account: Account) -> None:
|
||||
def logout(*, account: Account):
|
||||
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
|
||||
if refresh_token:
|
||||
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
|
||||
@ -703,7 +705,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_login_error_rate_limit(email: str) -> None:
|
||||
def add_login_error_rate_limit(email: str):
|
||||
key = f"login_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
@ -732,7 +734,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_forgot_password_error_rate_limit(email: str) -> None:
|
||||
def add_forgot_password_error_rate_limit(email: str):
|
||||
key = f"forgot_password_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
@ -761,7 +763,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_change_email_error_rate_limit(email: str) -> None:
|
||||
def add_change_email_error_rate_limit(email: str):
|
||||
key = f"change_email_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
@ -789,7 +791,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_owner_transfer_error_rate_limit(email: str) -> None:
|
||||
def add_owner_transfer_error_rate_limit(email: str):
|
||||
key = f"owner_transfer_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
@ -968,7 +970,7 @@ class TenantService:
|
||||
return tenant
|
||||
|
||||
@staticmethod
|
||||
def switch_tenant(account: Account, tenant_id: Optional[str] = None) -> None:
|
||||
def switch_tenant(account: Account, tenant_id: Optional[str] = None):
|
||||
"""Switch the current workspace for the account"""
|
||||
|
||||
# Ensure tenant_id is provided
|
||||
@ -1065,7 +1067,7 @@ class TenantService:
|
||||
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
|
||||
|
||||
@staticmethod
|
||||
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
|
||||
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
|
||||
"""Check member permission"""
|
||||
perms = {
|
||||
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
|
||||
@ -1085,7 +1087,7 @@ class TenantService:
|
||||
raise NoPermissionError(f"No permission to {action} member.")
|
||||
|
||||
@staticmethod
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account):
|
||||
"""Remove member from tenant"""
|
||||
if operator.id == account.id:
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
@ -1100,7 +1102,7 @@ class TenantService:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
|
||||
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
|
||||
"""Update member role"""
|
||||
TenantService.check_member_permission(tenant, operator, member, "update")
|
||||
|
||||
@ -1127,7 +1129,7 @@ class TenantService:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_custom_config(tenant_id: str) -> dict:
|
||||
def get_custom_config(tenant_id: str):
|
||||
tenant = db.get_or_404(Tenant, tenant_id)
|
||||
|
||||
return tenant.custom_config_dict
|
||||
@ -1148,7 +1150,7 @@ class RegisterService:
|
||||
return f"member_invite:token:{token}"
|
||||
|
||||
@classmethod
|
||||
def setup(cls, email: str, name: str, password: str, ip_address: str) -> None:
|
||||
def setup(cls, email: str, name: str, password: str, ip_address: str):
|
||||
"""
|
||||
Setup dify
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ from models.model import AppMode
|
||||
|
||||
class AdvancedPromptTemplateService:
|
||||
@classmethod
|
||||
def get_prompt(cls, args: dict) -> dict:
|
||||
def get_prompt(cls, args: dict):
|
||||
app_mode = args["app_mode"]
|
||||
model_mode = args["model_mode"]
|
||||
model_name = args["model_name"]
|
||||
@ -29,7 +29,7 @@ class AdvancedPromptTemplateService:
|
||||
return cls.get_common_prompt(app_mode, model_mode, has_context)
|
||||
|
||||
@classmethod
|
||||
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
|
||||
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
|
||||
context_prompt = copy.deepcopy(CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
@ -52,7 +52,7 @@ class AdvancedPromptTemplateService:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str):
|
||||
if has_context == "true":
|
||||
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
|
||||
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
|
||||
@ -61,7 +61,7 @@ class AdvancedPromptTemplateService:
|
||||
return prompt_template
|
||||
|
||||
@classmethod
|
||||
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str):
|
||||
if has_context == "true":
|
||||
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
|
||||
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
|
||||
@ -70,7 +70,7 @@ class AdvancedPromptTemplateService:
|
||||
return prompt_template
|
||||
|
||||
@classmethod
|
||||
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
|
||||
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
|
||||
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
import threading
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytz
|
||||
from flask_login import current_user
|
||||
|
||||
import contexts
|
||||
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
|
||||
@ -10,13 +9,14 @@ from core.plugin.impl.agent import PluginAgentClient
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
|
||||
|
||||
|
||||
class AgentService:
|
||||
@classmethod
|
||||
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
|
||||
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str):
|
||||
"""
|
||||
Service to get agent logs
|
||||
"""
|
||||
@ -61,14 +61,15 @@ class AgentService:
|
||||
executor = executor.name
|
||||
else:
|
||||
executor = "Unknown"
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.timezone is not None
|
||||
timezone = pytz.timezone(current_user.timezone)
|
||||
|
||||
app_model_config = app_model.app_model_config
|
||||
if not app_model_config:
|
||||
raise ValueError("App model config not found")
|
||||
|
||||
result = {
|
||||
result: dict[str, Any] = {
|
||||
"meta": {
|
||||
"status": "success",
|
||||
"executor": executor,
|
||||
|
||||
@ -2,7 +2,6 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import or_, select
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -10,6 +9,8 @@ from werkzeug.exceptions import NotFound
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
|
||||
from services.feature_service import FeatureService
|
||||
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
|
||||
@ -24,6 +25,7 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -62,6 +64,7 @@ class AppAnnotationService:
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
|
||||
assert current_user.current_tenant_id is not None
|
||||
if annotation_setting:
|
||||
add_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
@ -73,7 +76,7 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
|
||||
def enable_app_annotation(cls, args: dict, app_id: str):
|
||||
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(enable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
@ -84,6 +87,8 @@ class AppAnnotationService:
|
||||
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
|
||||
# send batch add segments task
|
||||
redis_client.setnx(enable_app_annotation_job_key, "waiting")
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
enable_annotation_reply_task.delay(
|
||||
str(job_id),
|
||||
app_id,
|
||||
@ -96,7 +101,9 @@ class AppAnnotationService:
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def disable_app_annotation(cls, app_id: str) -> dict:
|
||||
def disable_app_annotation(cls, app_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(disable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
@ -113,6 +120,8 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -145,6 +154,8 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def export_annotation_list_by_app_id(cls, app_id: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -164,6 +175,8 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -193,6 +206,8 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -230,6 +245,8 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def delete_app_annotation(cls, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -269,6 +286,8 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -315,8 +334,10 @@ class AppAnnotationService:
|
||||
return {"deleted_count": deleted_count}
|
||||
|
||||
@classmethod
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage):
|
||||
# get app info
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
@ -355,6 +376,8 @@ class AppAnnotationService:
|
||||
|
||||
@classmethod
|
||||
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
@ -425,6 +448,8 @@ class AppAnnotationService:
|
||||
|
||||
@classmethod
|
||||
def get_app_annotation_setting_by_app_id(cls, app_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
@ -451,6 +476,8 @@ class AppAnnotationService:
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
@ -490,7 +517,9 @@ class AppAnnotationService:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def clear_all_annotations(cls, app_id: str) -> dict:
|
||||
def clear_all_annotations(cls, app_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
|
||||
@ -30,7 +30,7 @@ class APIBasedExtensionService:
|
||||
return extension_data
|
||||
|
||||
@staticmethod
|
||||
def delete(extension_data: APIBasedExtension) -> None:
|
||||
def delete(extension_data: APIBasedExtension):
|
||||
db.session.delete(extension_data)
|
||||
db.session.commit()
|
||||
|
||||
@ -51,7 +51,7 @@ class APIBasedExtensionService:
|
||||
return extension
|
||||
|
||||
@classmethod
|
||||
def _validation(cls, extension_data: APIBasedExtension) -> None:
|
||||
def _validation(cls, extension_data: APIBasedExtension):
|
||||
# name
|
||||
if not extension_data.name:
|
||||
raise ValueError("name must not be empty")
|
||||
@ -95,7 +95,7 @@ class APIBasedExtensionService:
|
||||
cls._ping_connection(extension_data)
|
||||
|
||||
@staticmethod
|
||||
def _ping_connection(extension_data: APIBasedExtension) -> None:
|
||||
def _ping_connection(extension_data: APIBasedExtension):
|
||||
try:
|
||||
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
|
||||
resp = client.request(point=APIBasedExtensionPoint.PING, params={})
|
||||
|
||||
@ -566,7 +566,7 @@ class AppDslService:
|
||||
@classmethod
|
||||
def _append_workflow_export_data(
|
||||
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: Optional[str] = None
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Append workflow export data
|
||||
:param export_data: export data
|
||||
@ -608,7 +608,7 @@ class AppDslService:
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
|
||||
def _append_model_config_export_data(cls, export_data: dict, app_model: App):
|
||||
"""
|
||||
Append model config export data
|
||||
:param export_data: export data
|
||||
|
||||
@ -6,7 +6,7 @@ from models.model import AppMode
|
||||
|
||||
class AppModelConfigService:
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
|
||||
if app_mode == AppMode.CHAT:
|
||||
return ChatAppConfigManager.config_validate(tenant_id, config)
|
||||
elif app_mode == AppMode.AGENT_CHAT:
|
||||
|
||||
@ -2,7 +2,6 @@ import json
|
||||
import logging
|
||||
from typing import Optional, TypedDict, cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
|
||||
from configs import dify_config
|
||||
@ -17,6 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode, AppModelConfig, Site
|
||||
from models.tools import ApiToolProvider
|
||||
@ -168,9 +168,13 @@ class AppService:
|
||||
"""
|
||||
Get App
|
||||
"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# get original app model config
|
||||
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
|
||||
model_config = app.app_model_config
|
||||
if not model_config:
|
||||
return app
|
||||
agent_mode = model_config.agent_mode_dict
|
||||
# decrypt agent tool parameters if it's secret-input
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
@ -205,7 +209,8 @@ class AppService:
|
||||
pass
|
||||
|
||||
# override agent mode
|
||||
model_config.agent_mode = json.dumps(agent_mode)
|
||||
if model_config:
|
||||
model_config.agent_mode = json.dumps(agent_mode)
|
||||
|
||||
class ModifiedApp(App):
|
||||
"""
|
||||
@ -239,6 +244,7 @@ class AppService:
|
||||
:param args: request args
|
||||
:return: App instance
|
||||
"""
|
||||
assert current_user is not None
|
||||
app.name = args["name"]
|
||||
app.description = args["description"]
|
||||
app.icon_type = args["icon_type"]
|
||||
@ -259,6 +265,7 @@ class AppService:
|
||||
:param name: new name
|
||||
:return: App instance
|
||||
"""
|
||||
assert current_user is not None
|
||||
app.name = name
|
||||
app.updated_by = current_user.id
|
||||
app.updated_at = naive_utc_now()
|
||||
@ -274,6 +281,7 @@ class AppService:
|
||||
:param icon_background: new icon_background
|
||||
:return: App instance
|
||||
"""
|
||||
assert current_user is not None
|
||||
app.icon = icon
|
||||
app.icon_background = icon_background
|
||||
app.updated_by = current_user.id
|
||||
@ -291,7 +299,7 @@ class AppService:
|
||||
"""
|
||||
if enable_site == app.enable_site:
|
||||
return app
|
||||
|
||||
assert current_user is not None
|
||||
app.enable_site = enable_site
|
||||
app.updated_by = current_user.id
|
||||
app.updated_at = naive_utc_now()
|
||||
@ -308,6 +316,7 @@ class AppService:
|
||||
"""
|
||||
if enable_api == app.enable_api:
|
||||
return app
|
||||
assert current_user is not None
|
||||
|
||||
app.enable_api = enable_api
|
||||
app.updated_by = current_user.id
|
||||
@ -316,7 +325,7 @@ class AppService:
|
||||
|
||||
return app
|
||||
|
||||
def delete_app(self, app: App) -> None:
|
||||
def delete_app(self, app: App):
|
||||
"""
|
||||
Delete app
|
||||
:param app: App instance
|
||||
@ -331,7 +340,7 @@ class AppService:
|
||||
# Trigger asynchronous deletion of app and related data
|
||||
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
|
||||
|
||||
def get_app_meta(self, app_model: App) -> dict:
|
||||
def get_app_meta(self, app_model: App):
|
||||
"""
|
||||
Get app meta info
|
||||
:param app_model: app model
|
||||
|
||||
@ -12,7 +12,7 @@ from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from models.enums import MessageStatus
|
||||
from models.model import App, AppMode, AppModelConfig, Message
|
||||
from models.model import App, AppMode, Message
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
@ -40,7 +40,9 @@ class AudioService:
|
||||
if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
else:
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
app_model_config = app_model.app_model_config
|
||||
if not app_model_config:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
|
||||
if not app_model_config.speech_to_text_dict["enabled"]:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
|
||||
@ -8,7 +8,7 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
class ApiKeyAuthService:
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str) -> list:
|
||||
def get_provider_auth_list(tenant_id: str):
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
|
||||
|
||||
@ -70,7 +70,7 @@ class BillingService:
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def is_tenant_owner_or_admin(current_user):
|
||||
def is_tenant_owner_or_admin(current_user: Account):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
join: Optional[TenantAccountJoin] = (
|
||||
|
||||
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ClearFreePlanTenantExpiredLogs:
|
||||
@classmethod
|
||||
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None:
|
||||
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]):
|
||||
"""
|
||||
Clean up message-related tables to avoid data redundancy.
|
||||
This method cleans up tables that have foreign key relationships with Message.
|
||||
@ -353,7 +353,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=10)
|
||||
|
||||
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
|
||||
def process_tenant(flask_app: Flask, tenant_id: str):
|
||||
try:
|
||||
if (
|
||||
not dify_config.BILLING_ENABLED
|
||||
|
||||
@ -3,7 +3,7 @@ from extensions.ext_code_based_extension import code_based_extension
|
||||
|
||||
class CodeBasedExtensionService:
|
||||
@staticmethod
|
||||
def get_code_based_extension(module: str) -> list[dict]:
|
||||
def get_code_based_extension(module: str):
|
||||
module_extensions = code_based_extension.module_extensions(module)
|
||||
return [
|
||||
{
|
||||
|
||||
@ -250,7 +250,7 @@ class ConversationService:
|
||||
variable_id: str,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
new_value: Any,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Update a conversation variable's value.
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ import uuid
|
||||
from collections import Counter
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from flask_login import current_user
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exists, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -27,6 +27,7 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
@ -566,8 +567,11 @@ class DatasetService:
|
||||
data: Update data dictionary
|
||||
filtered_data: Filtered update data to modify
|
||||
"""
|
||||
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=data["embedding_model_provider"],
|
||||
@ -679,8 +683,12 @@ class DatasetService:
|
||||
data: Update data dictionary
|
||||
filtered_data: Filtered update data to modify
|
||||
"""
|
||||
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
|
||||
|
||||
model_manager = ModelManager()
|
||||
try:
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=data["embedding_model_provider"],
|
||||
@ -909,7 +917,9 @@ class DatasetService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
|
||||
def get_dataset_auto_disable_logs(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
|
||||
return {
|
||||
@ -1114,6 +1124,8 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
@ -1163,7 +1175,7 @@ class DocumentService:
|
||||
file_ids = [
|
||||
document.data_source_info_dict.get("upload_file_id", "")
|
||||
for document in documents
|
||||
if document.data_source_type == "upload_file"
|
||||
if document.data_source_type == "upload_file" and document.data_source_info_dict
|
||||
]
|
||||
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
|
||||
|
||||
@ -1173,6 +1185,8 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found.")
|
||||
@ -1202,6 +1216,7 @@ class DocumentService:
|
||||
if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
|
||||
raise DocumentIndexingError()
|
||||
# update document to be paused
|
||||
assert current_user is not None
|
||||
document.is_paused = True
|
||||
document.paused_by = current_user.id
|
||||
document.paused_at = naive_utc_now()
|
||||
@ -1257,8 +1272,9 @@ class DocumentService:
|
||||
# sync document indexing
|
||||
document.indexing_status = "waiting"
|
||||
data_source_info = document.data_source_info_dict
|
||||
data_source_info["mode"] = "scrape"
|
||||
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
|
||||
if data_source_info:
|
||||
data_source_info["mode"] = "scrape"
|
||||
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@ -1287,6 +1303,9 @@ class DocumentService:
|
||||
# check doc_form
|
||||
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
|
||||
# check document limit
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
@ -1887,6 +1906,8 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_documents_count():
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
documents_count = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
@ -1907,6 +1928,8 @@ class DocumentService:
|
||||
dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
created_from: str = "web",
|
||||
):
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
document = DocumentService.get_document(dataset.id, document_data.original_document_id)
|
||||
if document is None:
|
||||
@ -1963,6 +1986,20 @@ class DocumentService:
|
||||
notion_info_list = document_data.data_source.info_list.notion_info_list
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info.workspace_id
|
||||
data_source_binding = (
|
||||
db.session.query(DataSourceOauthBinding)
|
||||
.where(
|
||||
sa.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not data_source_binding:
|
||||
raise ValueError("Data source binding not found.")
|
||||
for page in notion_info.pages:
|
||||
data_source_info = {
|
||||
"credential_id": notion_info.credential_id,
|
||||
@ -2014,6 +2051,9 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
@ -2453,6 +2493,9 @@ class SegmentService:
|
||||
|
||||
@classmethod
|
||||
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
content = args["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
@ -2515,6 +2558,9 @@ class SegmentService:
|
||||
|
||||
@classmethod
|
||||
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
lock_name = f"multi_add_segment_lock_document_id_{document.id}"
|
||||
increment_word_count = 0
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -2597,9 +2643,10 @@ class SegmentService:
|
||||
return segment_data_list
|
||||
|
||||
@classmethod
|
||||
def update_segment(
|
||||
cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset
|
||||
) -> DocumentSegment:
|
||||
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
@ -2793,6 +2840,7 @@ class SegmentService:
|
||||
|
||||
@classmethod
|
||||
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
|
||||
assert isinstance(current_user, Account)
|
||||
segments = (
|
||||
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
|
||||
.where(
|
||||
@ -2825,6 +2873,8 @@ class SegmentService:
|
||||
def update_segments_status(
|
||||
cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document
|
||||
):
|
||||
assert current_user is not None
|
||||
|
||||
# Check if segment_ids is not empty to avoid WHERE false condition
|
||||
if not segment_ids or len(segment_ids) == 0:
|
||||
return
|
||||
@ -2887,6 +2937,8 @@ class SegmentService:
|
||||
def create_child_chunk(
|
||||
cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset
|
||||
) -> ChildChunk:
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
lock_name = f"add_child_lock_{segment.id}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
index_node_id = str(uuid.uuid4())
|
||||
@ -2934,6 +2986,8 @@ class SegmentService:
|
||||
document: Document,
|
||||
dataset: Dataset,
|
||||
) -> list[ChildChunk]:
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
@ -3008,6 +3062,8 @@ class SegmentService:
|
||||
document: Document,
|
||||
dataset: Dataset,
|
||||
) -> ChildChunk:
|
||||
assert current_user is not None
|
||||
|
||||
try:
|
||||
child_chunk.content = content
|
||||
child_chunk.word_count = len(content)
|
||||
@ -3038,6 +3094,8 @@ class SegmentService:
|
||||
def get_child_chunks(
|
||||
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
|
||||
):
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
query = (
|
||||
select(ChildChunk)
|
||||
.filter_by(
|
||||
|
||||
@ -83,7 +83,7 @@ class ProviderResponse(BaseModel):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (
|
||||
@ -113,7 +113,7 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
status: CustomConfigurationStatus
|
||||
models: list[ProviderModelWithStatusEntity]
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (
|
||||
@ -137,7 +137,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
|
||||
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (
|
||||
@ -174,7 +174,7 @@ class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
|
||||
|
||||
provider: SimpleProviderEntityResponse
|
||||
|
||||
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
|
||||
def __init__(self, tenant_id: str, model: ModelWithProviderEntity):
|
||||
dump_model = model.model_dump()
|
||||
dump_model["provider"]["tenant_id"] = tenant_id
|
||||
super().__init__(**dump_model)
|
||||
|
||||
@ -6,7 +6,7 @@ class InvokeError(Exception):
|
||||
|
||||
description: Optional[str] = None
|
||||
|
||||
def __init__(self, description: Optional[str] = None) -> None:
|
||||
def __init__(self, description: Optional[str] = None):
|
||||
self.description = description
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@ -114,8 +114,9 @@ class ExternalDatasetService:
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
raise ValueError("api template not found")
|
||||
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
|
||||
args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key")
|
||||
settings = args.get("settings")
|
||||
if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict:
|
||||
settings["api_key"] = external_knowledge_api.settings_dict.get("api_key")
|
||||
|
||||
external_knowledge_api.name = args.get("name")
|
||||
external_knowledge_api.description = args.get("description", "")
|
||||
@ -277,7 +278,7 @@ class ExternalDatasetService:
|
||||
query: str,
|
||||
external_retrieval_parameters: dict,
|
||||
metadata_condition: Optional[MetadataCondition] = None,
|
||||
) -> list:
|
||||
):
|
||||
external_knowledge_binding = (
|
||||
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
|
||||
)
|
||||
|
||||
@ -3,7 +3,6 @@ import os
|
||||
import uuid
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -20,6 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_tenant_id
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import EndUser, UploadFile
|
||||
@ -120,7 +120,11 @@ class FileService:
|
||||
|
||||
return file_size <= file_size_limit
|
||||
|
||||
def upload_text(self, text: str, text_name: str) -> UploadFile:
|
||||
@staticmethod
|
||||
def upload_text(text: str, text_name: str) -> UploadFile:
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
if len(text_name) > 200:
|
||||
text_name = text_name[:200]
|
||||
# user uuid as file name
|
||||
|
||||
@ -33,7 +33,7 @@ class HitTestingService:
|
||||
retrieval_model: Any, # FIXME drop this any
|
||||
external_retrieval_model: dict,
|
||||
limit: int = 10,
|
||||
) -> dict:
|
||||
):
|
||||
start = time.perf_counter()
|
||||
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
@ -98,7 +98,7 @@ class HitTestingService:
|
||||
account: Account,
|
||||
external_retrieval_model: dict,
|
||||
metadata_filtering_conditions: dict,
|
||||
) -> dict:
|
||||
):
|
||||
if dataset.provider != "external":
|
||||
return {
|
||||
"query": {"content": query},
|
||||
|
||||
@ -25,10 +25,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelLoadBalancingService:
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
enable model load balancing.
|
||||
|
||||
@ -49,7 +49,7 @@ class ModelLoadBalancingService:
|
||||
# Enable model load balancing
|
||||
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
disable model load balancing.
|
||||
|
||||
@ -295,7 +295,7 @@ class ModelLoadBalancingService:
|
||||
|
||||
def update_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Update load balancing configurations.
|
||||
:param tenant_id: workspace id
|
||||
@ -478,7 +478,7 @@ class ModelLoadBalancingService:
|
||||
model_type: str,
|
||||
credentials: dict,
|
||||
config_id: Optional[str] = None,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Validate load balancing credentials.
|
||||
:param tenant_id: workspace id
|
||||
@ -537,7 +537,7 @@ class ModelLoadBalancingService:
|
||||
credentials: dict,
|
||||
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
|
||||
validate: bool = True,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param tenant_id: workspace id
|
||||
@ -605,7 +605,7 @@ class ModelLoadBalancingService:
|
||||
else:
|
||||
raise ValueError("No credential schema found")
|
||||
|
||||
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
|
||||
def _clear_credentials_cache(self, tenant_id: str, config_id: str):
|
||||
"""
|
||||
Clear credentials cache.
|
||||
:param tenant_id: workspace id
|
||||
|
||||
@ -26,7 +26,7 @@ class ModelProviderService:
|
||||
Model Provider Service
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
def _get_provider_configuration(self, tenant_id: str, provider: str):
|
||||
@ -142,7 +142,7 @@ class ModelProviderService:
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
|
||||
|
||||
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
|
||||
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
|
||||
"""
|
||||
validate provider credentials before saving.
|
||||
|
||||
@ -193,7 +193,7 @@ class ModelProviderService:
|
||||
credential_name=credential_name,
|
||||
)
|
||||
|
||||
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
|
||||
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
|
||||
"""
|
||||
remove a saved provider credential (by credential_id).
|
||||
:param tenant_id: workspace id
|
||||
@ -204,7 +204,7 @@ class ModelProviderService:
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.delete_provider_credential(credential_id=credential_id)
|
||||
|
||||
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
|
||||
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
|
||||
"""
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
@ -232,9 +232,7 @@ class ModelProviderService:
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def validate_model_credentials(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
|
||||
) -> None:
|
||||
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
|
||||
"""
|
||||
validate model credentials.
|
||||
|
||||
@ -303,9 +301,7 @@ class ModelProviderService:
|
||||
credential_name=credential_name,
|
||||
)
|
||||
|
||||
def remove_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
) -> None:
|
||||
def remove_model_credential(self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str):
|
||||
"""
|
||||
remove model credentials.
|
||||
|
||||
@ -323,7 +319,7 @@ class ModelProviderService:
|
||||
|
||||
def switch_active_custom_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
switch model credentials.
|
||||
|
||||
@ -341,7 +337,7 @@ class ModelProviderService:
|
||||
|
||||
def add_model_credential_to_model_list(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
add model credentials to model list.
|
||||
|
||||
@ -357,7 +353,7 @@ class ModelProviderService:
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
|
||||
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str):
|
||||
"""
|
||||
remove model credentials.
|
||||
|
||||
@ -485,7 +481,7 @@ class ModelProviderService:
|
||||
logger.debug("get_default_model_of_model_type error: %s", e)
|
||||
return None
|
||||
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str):
|
||||
"""
|
||||
update default model of model type.
|
||||
|
||||
@ -517,7 +513,7 @@ class ModelProviderService:
|
||||
|
||||
return byte_data, mime_type
|
||||
|
||||
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
|
||||
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str):
|
||||
"""
|
||||
switch preferred provider.
|
||||
|
||||
@ -534,7 +530,7 @@ class ModelProviderService:
|
||||
# Switch preferred provider type
|
||||
provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
|
||||
|
||||
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
enable model.
|
||||
|
||||
@ -547,7 +543,7 @@ class ModelProviderService:
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
disable model.
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class PluginDataMigration:
|
||||
@classmethod
|
||||
def migrate(cls) -> None:
|
||||
def migrate(cls):
|
||||
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
|
||||
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
|
||||
@ -26,7 +26,7 @@ class PluginDataMigration:
|
||||
cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
|
||||
|
||||
@classmethod
|
||||
def migrate_datasets(cls) -> None:
|
||||
def migrate_datasets(cls):
|
||||
table_name = "datasets"
|
||||
provider_column_name = "embedding_model_provider"
|
||||
|
||||
@ -126,9 +126,7 @@ limit 1000"""
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def migrate_db_records(
|
||||
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
|
||||
) -> None:
|
||||
def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]):
|
||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||
|
||||
processed_count = 0
|
||||
|
||||
@ -35,7 +35,7 @@ excluded_providers = ["time", "audio", "code", "webscraper"]
|
||||
|
||||
class PluginMigration:
|
||||
@classmethod
|
||||
def extract_plugins(cls, filepath: str, workers: int) -> None:
|
||||
def extract_plugins(cls, filepath: str, workers: int):
|
||||
"""
|
||||
Migrate plugin.
|
||||
"""
|
||||
@ -57,7 +57,7 @@ class PluginMigration:
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=workers)
|
||||
|
||||
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
|
||||
def process_tenant(flask_app: Flask, tenant_id: str):
|
||||
with flask_app.app_context():
|
||||
nonlocal handled_tenant_count
|
||||
try:
|
||||
@ -293,7 +293,7 @@ class PluginMigration:
|
||||
return plugin_manifest[0].latest_package_identifier
|
||||
|
||||
@classmethod
|
||||
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
|
||||
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str):
|
||||
"""
|
||||
Extract unique plugins.
|
||||
"""
|
||||
@ -330,7 +330,7 @@ class PluginMigration:
|
||||
return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
|
||||
|
||||
@classmethod
|
||||
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
|
||||
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100):
|
||||
"""
|
||||
Install plugins.
|
||||
"""
|
||||
@ -350,7 +350,7 @@ class PluginMigration:
|
||||
if response.get("failed"):
|
||||
plugin_install_failed.extend(response.get("failed", []))
|
||||
|
||||
def install(tenant_id: str, plugin_ids: list[str]) -> None:
|
||||
def install(tenant_id: str, plugin_ids: list[str]):
|
||||
logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
|
||||
# fetch plugin already installed
|
||||
installed_plugins = manager.list_plugins(tenant_id)
|
||||
|
||||
@ -19,7 +19,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
def get_type(self) -> str:
|
||||
return RecommendAppType.BUILDIN
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
result = self.fetch_recommended_apps_from_builtin(language)
|
||||
return result
|
||||
|
||||
@ -28,7 +28,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _get_builtin_data(cls) -> dict:
|
||||
def _get_builtin_data(cls):
|
||||
"""
|
||||
Get builtin data.
|
||||
:return:
|
||||
@ -44,7 +44,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return cls.builtin_data or {}
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
|
||||
def fetch_recommended_apps_from_builtin(cls, language: str):
|
||||
"""
|
||||
Fetch recommended apps from builtin.
|
||||
:param language: language
|
||||
|
||||
@ -13,7 +13,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
Retrieval recommended app from database
|
||||
"""
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
result = self.fetch_recommended_apps_from_db(language)
|
||||
return result
|
||||
|
||||
@ -25,7 +25,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return RecommendAppType.DATABASE
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_db(cls, language: str) -> dict:
|
||||
def fetch_recommended_apps_from_db(cls, language: str):
|
||||
"""
|
||||
Fetch recommended apps from db.
|
||||
:param language: language
|
||||
|
||||
@ -5,7 +5,7 @@ class RecommendAppRetrievalBase(ABC):
|
||||
"""Interface for recommend app retrieval."""
|
||||
|
||||
@abstractmethod
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -24,7 +24,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id)
|
||||
return result
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
try:
|
||||
result = self.fetch_recommended_apps_from_dify_official(language)
|
||||
except Exception as e:
|
||||
@ -51,7 +51,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
|
||||
def fetch_recommended_apps_from_dify_official(cls, language: str):
|
||||
"""
|
||||
Fetch recommended apps from dify official.
|
||||
:param language: language
|
||||
|
||||
@ -6,7 +6,7 @@ from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFa
|
||||
|
||||
class RecommendedAppService:
|
||||
@classmethod
|
||||
def get_recommended_apps_and_categories(cls, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(cls, language: str):
|
||||
"""
|
||||
Get recommended apps and categories.
|
||||
:param language: language
|
||||
|
||||
@ -12,7 +12,7 @@ from models.model import App, Tag, TagBinding
|
||||
|
||||
class TagService:
|
||||
@staticmethod
|
||||
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
|
||||
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None):
|
||||
query = (
|
||||
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
|
||||
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
@ -25,7 +25,7 @@ class TagService:
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
|
||||
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list):
|
||||
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||
if not tag_ids or len(tag_ids) == 0:
|
||||
return []
|
||||
@ -51,7 +51,7 @@ class TagService:
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
|
||||
if not tag_type or not tag_name:
|
||||
return []
|
||||
tags = (
|
||||
@ -64,7 +64,7 @@ class TagService:
|
||||
return tags
|
||||
|
||||
@staticmethod
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
|
||||
@ -551,7 +551,7 @@ class BuiltinToolManageService:
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name,
|
||||
name_func=lambda x: x.entity.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
|
||||
@ -226,7 +226,7 @@ class MCPToolManageService:
|
||||
def update_mcp_provider_credentials(
|
||||
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
|
||||
):
|
||||
provider_controller = MCPToolProviderController._from_db(mcp_provider)
|
||||
provider_controller = MCPToolProviderController.from_db(mcp_provider)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=mcp_provider.tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
|
||||
|
||||
@ -37,7 +37,7 @@ class WorkflowToolManageService:
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
@ -103,7 +103,7 @@ class WorkflowToolManageService:
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Update a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -217,7 +217,7 @@ class WorkflowToolManageService:
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
|
||||
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
|
||||
"""
|
||||
Delete a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -233,7 +233,7 @@ class WorkflowToolManageService:
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -249,7 +249,7 @@ class WorkflowToolManageService:
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -265,7 +265,7 @@ class WorkflowToolManageService:
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:db_tool: the database tool
|
||||
|
||||
@ -145,7 +145,7 @@ class WebsiteService:
|
||||
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict) -> None:
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
"""Validate arguments for document creation."""
|
||||
try:
|
||||
WebsiteCrawlApiRequest.from_args(args)
|
||||
|
||||
@ -217,7 +217,7 @@ class WorkflowConverter:
|
||||
|
||||
return app_config
|
||||
|
||||
def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
|
||||
def _convert_to_start_node(self, variables: list[VariableEntity]):
|
||||
"""
|
||||
Convert to Start Node
|
||||
:param variables: list of variables
|
||||
@ -384,7 +384,7 @@ class WorkflowConverter:
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileUploadConfig] = None,
|
||||
external_data_variable_node_mapping: dict[str, str] | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Convert to LLM Node
|
||||
:param original_app_mode: original app mode
|
||||
@ -550,7 +550,7 @@ class WorkflowConverter:
|
||||
|
||||
return template
|
||||
|
||||
def _convert_to_end_node(self) -> dict:
|
||||
def _convert_to_end_node(self):
|
||||
"""
|
||||
Convert to End Node
|
||||
:return:
|
||||
@ -566,7 +566,7 @@ class WorkflowConverter:
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_to_answer_node(self) -> dict:
|
||||
def _convert_to_answer_node(self):
|
||||
"""
|
||||
Convert to Answer Node
|
||||
:return:
|
||||
@ -578,7 +578,7 @@ class WorkflowConverter:
|
||||
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
|
||||
}
|
||||
|
||||
def _create_edge(self, source: str, target: str) -> dict:
|
||||
def _create_edge(self, source: str, target: str):
|
||||
"""
|
||||
Create Edge
|
||||
:param source: source node id
|
||||
@ -587,7 +587,7 @@ class WorkflowConverter:
|
||||
"""
|
||||
return {"id": f"{source}-{target}", "source": source, "target": target}
|
||||
|
||||
def _append_node(self, graph: dict, node: dict) -> dict:
|
||||
def _append_node(self, graph: dict, node: dict):
|
||||
"""
|
||||
Append Node to Graph
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ class WorkflowAppService:
|
||||
limit: int = 20,
|
||||
created_by_end_user_session_id: str | None = None,
|
||||
created_by_account: str | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Get paginate workflow app logs using SQLAlchemy 2.0 style
|
||||
:param session: SQLAlchemy session
|
||||
|
||||
@ -85,7 +85,7 @@ class DraftVarLoader(VariableLoader):
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
fallback_variables: Sequence[Variable] | None = None,
|
||||
) -> None:
|
||||
):
|
||||
self._engine = engine
|
||||
self._app_id = app_id
|
||||
self._tenant_id = tenant_id
|
||||
@ -185,7 +185,7 @@ class DraftVarLoader(VariableLoader):
|
||||
class WorkflowDraftVariableService:
|
||||
_session: Session
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
def __init__(self, session: Session):
|
||||
"""
|
||||
Initialize the WorkflowDraftVariableService with a SQLAlchemy session.
|
||||
|
||||
@ -602,7 +602,7 @@ def _batch_upsert_draft_variable(
|
||||
session: Session,
|
||||
draft_vars: Sequence[WorkflowDraftVariable],
|
||||
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
|
||||
) -> None:
|
||||
):
|
||||
if not draft_vars:
|
||||
return None
|
||||
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
|
||||
|
||||
@ -619,7 +619,7 @@ class WorkflowService:
|
||||
|
||||
return new_app
|
||||
|
||||
def validate_features_structure(self, app_model: App, features: dict) -> dict:
|
||||
def validate_features_structure(self, app_model: App, features: dict):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
return AdvancedChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
|
||||
Reference in New Issue
Block a user