Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

This commit is contained in:
-LAN-
2025-09-08 14:30:43 +08:00
828 changed files with 7240 additions and 2951 deletions

View File

@ -105,14 +105,14 @@ class AccountService:
return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"
@staticmethod
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
def _store_refresh_token(refresh_token: str, account_id: str):
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
redis_client.setex(
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
)
@staticmethod
def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
def _delete_refresh_token(refresh_token: str, account_id: str):
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
@ -214,6 +214,7 @@ class AccountService:
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.add(account)
db.session.commit()
return account
@ -311,12 +312,12 @@ class AccountService:
return True
@staticmethod
def delete_account(account: Account) -> None:
def delete_account(account: Account):
"""Delete account. This method only adds a task to the queue for deletion."""
delete_account_task.delay(account.id)
@staticmethod
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
def link_account_integrate(provider: str, open_id: str, account: Account):
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
@ -343,7 +344,7 @@ class AccountService:
raise LinkAccountIntegrateError("Failed to link account.") from e
@staticmethod
def close_account(account: Account) -> None:
def close_account(account: Account):
"""Close account"""
account.status = AccountStatus.CLOSED.value
db.session.commit()
@ -351,6 +352,7 @@ class AccountService:
@staticmethod
def update_account(account, **kwargs):
"""Update account fields"""
account = db.session.merge(account)
for field, value in kwargs.items():
if hasattr(account, field):
setattr(account, field, value)
@ -372,7 +374,7 @@ class AccountService:
return account
@staticmethod
def update_login_info(account: Account, *, ip_address: str) -> None:
def update_login_info(account: Account, *, ip_address: str):
"""Update last login time and ip"""
account.last_login_at = naive_utc_now()
account.last_login_ip = ip_address
@ -396,7 +398,7 @@ class AccountService:
return TokenPair(access_token=access_token, refresh_token=refresh_token)
@staticmethod
def logout(*, account: Account) -> None:
def logout(*, account: Account):
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
if refresh_token:
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
@ -703,7 +705,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_login_error_rate_limit(email: str) -> None:
def add_login_error_rate_limit(email: str):
key = f"login_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@ -732,7 +734,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_forgot_password_error_rate_limit(email: str) -> None:
def add_forgot_password_error_rate_limit(email: str):
key = f"forgot_password_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@ -761,7 +763,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_change_email_error_rate_limit(email: str) -> None:
def add_change_email_error_rate_limit(email: str):
key = f"change_email_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@ -789,7 +791,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_owner_transfer_error_rate_limit(email: str) -> None:
def add_owner_transfer_error_rate_limit(email: str):
key = f"owner_transfer_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@ -968,7 +970,7 @@ class TenantService:
return tenant
@staticmethod
def switch_tenant(account: Account, tenant_id: Optional[str] = None) -> None:
def switch_tenant(account: Account, tenant_id: Optional[str] = None):
"""Switch the current workspace for the account"""
# Ensure tenant_id is provided
@ -1065,7 +1067,7 @@ class TenantService:
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
@staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
"""Check member permission"""
perms = {
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
@ -1085,7 +1087,7 @@ class TenantService:
raise NoPermissionError(f"No permission to {action} member.")
@staticmethod
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account):
"""Remove member from tenant"""
if operator.id == account.id:
raise CannotOperateSelfError("Cannot operate self.")
@ -1100,7 +1102,7 @@ class TenantService:
db.session.commit()
@staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update")
@ -1127,7 +1129,7 @@ class TenantService:
db.session.commit()
@staticmethod
def get_custom_config(tenant_id: str) -> dict:
def get_custom_config(tenant_id: str):
tenant = db.get_or_404(Tenant, tenant_id)
return tenant.custom_config_dict
@ -1148,7 +1150,7 @@ class RegisterService:
return f"member_invite:token:{token}"
@classmethod
def setup(cls, email: str, name: str, password: str, ip_address: str) -> None:
def setup(cls, email: str, name: str, password: str, ip_address: str):
"""
Setup dify

View File

@ -17,7 +17,7 @@ from models.model import AppMode
class AdvancedPromptTemplateService:
@classmethod
def get_prompt(cls, args: dict) -> dict:
def get_prompt(cls, args: dict):
app_mode = args["app_mode"]
model_mode = args["model_mode"]
model_name = args["model_name"]
@ -29,7 +29,7 @@ class AdvancedPromptTemplateService:
return cls.get_common_prompt(app_mode, model_mode, has_context)
@classmethod
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
context_prompt = copy.deepcopy(CONTEXT)
if app_mode == AppMode.CHAT.value:
@ -52,7 +52,7 @@ class AdvancedPromptTemplateService:
return {}
@classmethod
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str):
if has_context == "true":
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
@ -61,7 +61,7 @@ class AdvancedPromptTemplateService:
return prompt_template
@classmethod
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str):
if has_context == "true":
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
@ -70,7 +70,7 @@ class AdvancedPromptTemplateService:
return prompt_template
@classmethod
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
if app_mode == AppMode.CHAT.value:

View File

@ -1,8 +1,7 @@
import threading
from typing import Optional
from typing import Any, Optional
import pytz
from flask_login import current_user
import contexts
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
@ -10,13 +9,14 @@ from core.plugin.impl.agent import PluginAgentClient
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
class AgentService:
@classmethod
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str):
"""
Service to get agent logs
"""
@ -61,14 +61,15 @@ class AgentService:
executor = executor.name
else:
executor = "Unknown"
assert isinstance(current_user, Account)
assert current_user.timezone is not None
timezone = pytz.timezone(current_user.timezone)
app_model_config = app_model.app_model_config
if not app_model_config:
raise ValueError("App model config not found")
result = {
result: dict[str, Any] = {
"meta": {
"status": "success",
"executor": executor,

View File

@ -2,7 +2,6 @@ import uuid
from typing import Optional
import pandas as pd
from flask_login import current_user
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@ -10,6 +9,8 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
from services.feature_service import FeatureService
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
@ -24,6 +25,7 @@ class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
assert isinstance(current_user, Account)
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -62,6 +64,7 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled , add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
assert current_user.current_tenant_id is not None
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
@ -73,7 +76,7 @@ class AppAnnotationService:
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
def enable_app_annotation(cls, args: dict, app_id: str):
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@ -84,6 +87,8 @@ class AppAnnotationService:
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
# send batch add segments task
redis_client.setnx(enable_app_annotation_job_key, "waiting")
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
enable_annotation_reply_task.delay(
str(job_id),
app_id,
@ -96,7 +101,9 @@ class AppAnnotationService:
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def disable_app_annotation(cls, app_id: str) -> dict:
def disable_app_annotation(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key)
if cache_result is not None:
@ -113,6 +120,8 @@ class AppAnnotationService:
@classmethod
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -145,6 +154,8 @@ class AppAnnotationService:
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -164,6 +175,8 @@ class AppAnnotationService:
@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -193,6 +206,8 @@ class AppAnnotationService:
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -230,6 +245,8 @@ class AppAnnotationService:
@classmethod
def delete_app_annotation(cls, app_id: str, annotation_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -269,6 +286,8 @@ class AppAnnotationService:
@classmethod
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -315,8 +334,10 @@ class AppAnnotationService:
return {"deleted_count": deleted_count}
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
def batch_import_app_annotations(cls, app_id, file: FileStorage):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -355,6 +376,8 @@ class AppAnnotationService:
@classmethod
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info
app = (
db.session.query(App)
@ -425,6 +448,8 @@ class AppAnnotationService:
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info
app = (
db.session.query(App)
@ -451,6 +476,8 @@ class AppAnnotationService:
@classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info
app = (
db.session.query(App)
@ -490,7 +517,9 @@ class AppAnnotationService:
}
@classmethod
def clear_all_annotations(cls, app_id: str) -> dict:
def clear_all_annotations(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")

View File

@ -30,7 +30,7 @@ class APIBasedExtensionService:
return extension_data
@staticmethod
def delete(extension_data: APIBasedExtension) -> None:
def delete(extension_data: APIBasedExtension):
db.session.delete(extension_data)
db.session.commit()
@ -51,7 +51,7 @@ class APIBasedExtensionService:
return extension
@classmethod
def _validation(cls, extension_data: APIBasedExtension) -> None:
def _validation(cls, extension_data: APIBasedExtension):
# name
if not extension_data.name:
raise ValueError("name must not be empty")
@ -95,7 +95,7 @@ class APIBasedExtensionService:
cls._ping_connection(extension_data)
@staticmethod
def _ping_connection(extension_data: APIBasedExtension) -> None:
def _ping_connection(extension_data: APIBasedExtension):
try:
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
resp = client.request(point=APIBasedExtensionPoint.PING, params={})

View File

@ -566,7 +566,7 @@ class AppDslService:
@classmethod
def _append_workflow_export_data(
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: Optional[str] = None
) -> None:
):
"""
Append workflow export data
:param export_data: export data
@ -608,7 +608,7 @@ class AppDslService:
]
@classmethod
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
def _append_model_config_export_data(cls, export_data: dict, app_model: App):
"""
Append model config export data
:param export_data: export data

View File

@ -6,7 +6,7 @@ from models.model import AppMode
class AppModelConfigService:
@classmethod
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
if app_mode == AppMode.CHAT:
return ChatAppConfigManager.config_validate(tenant_id, config)
elif app_mode == AppMode.AGENT_CHAT:

View File

@ -2,7 +2,6 @@ import json
import logging
from typing import Optional, TypedDict, cast
from flask_login import current_user
from flask_sqlalchemy.pagination import Pagination
from configs import dify_config
@ -17,6 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account
from models.model import App, AppMode, AppModelConfig, Site
from models.tools import ApiToolProvider
@ -168,9 +168,13 @@ class AppService:
"""
Get App
"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get original app model config
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
model_config = app.app_model_config
if not model_config:
return app
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get("tools") or []:
@ -205,7 +209,8 @@ class AppService:
pass
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
if model_config:
model_config.agent_mode = json.dumps(agent_mode)
class ModifiedApp(App):
"""
@ -239,6 +244,7 @@ class AppService:
:param args: request args
:return: App instance
"""
assert current_user is not None
app.name = args["name"]
app.description = args["description"]
app.icon_type = args["icon_type"]
@ -259,6 +265,7 @@ class AppService:
:param name: new name
:return: App instance
"""
assert current_user is not None
app.name = name
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
@ -274,6 +281,7 @@ class AppService:
:param icon_background: new icon_background
:return: App instance
"""
assert current_user is not None
app.icon = icon
app.icon_background = icon_background
app.updated_by = current_user.id
@ -291,7 +299,7 @@ class AppService:
"""
if enable_site == app.enable_site:
return app
assert current_user is not None
app.enable_site = enable_site
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
@ -308,6 +316,7 @@ class AppService:
"""
if enable_api == app.enable_api:
return app
assert current_user is not None
app.enable_api = enable_api
app.updated_by = current_user.id
@ -316,7 +325,7 @@ class AppService:
return app
def delete_app(self, app: App) -> None:
def delete_app(self, app: App):
"""
Delete app
:param app: App instance
@ -331,7 +340,7 @@ class AppService:
# Trigger asynchronous deletion of app and related data
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
def get_app_meta(self, app_model: App) -> dict:
def get_app_meta(self, app_model: App):
"""
Get app meta info
:param app_model: app model

View File

@ -12,7 +12,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.enums import MessageStatus
from models.model import App, AppMode, AppModelConfig, Message
from models.model import App, AppMode, Message
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
@ -40,7 +40,9 @@ class AudioService:
if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
raise ValueError("Speech to text is not enabled")
else:
app_model_config: AppModelConfig = app_model.app_model_config
app_model_config = app_model.app_model_config
if not app_model_config:
raise ValueError("Speech to text is not enabled")
if not app_model_config.speech_to_text_dict["enabled"]:
raise ValueError("Speech to text is not enabled")

View File

@ -8,7 +8,7 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str) -> list:
def get_provider_auth_list(tenant_id: str):
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))

View File

@ -70,7 +70,7 @@ class BillingService:
return response.json()
@staticmethod
def is_tenant_owner_or_admin(current_user):
def is_tenant_owner_or_admin(current_user: Account):
tenant_id = current_user.current_tenant_id
join: Optional[TenantAccountJoin] = (

View File

@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class ClearFreePlanTenantExpiredLogs:
@classmethod
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None:
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]):
"""
Clean up message-related tables to avoid data redundancy.
This method cleans up tables that have foreign key relationships with Message.
@ -353,7 +353,7 @@ class ClearFreePlanTenantExpiredLogs:
thread_pool = ThreadPoolExecutor(max_workers=10)
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
def process_tenant(flask_app: Flask, tenant_id: str):
try:
if (
not dify_config.BILLING_ENABLED

View File

@ -3,7 +3,7 @@ from extensions.ext_code_based_extension import code_based_extension
class CodeBasedExtensionService:
@staticmethod
def get_code_based_extension(module: str) -> list[dict]:
def get_code_based_extension(module: str):
module_extensions = code_based_extension.module_extensions(module)
return [
{

View File

@ -250,7 +250,7 @@ class ConversationService:
variable_id: str,
user: Optional[Union[Account, EndUser]],
new_value: Any,
) -> dict:
):
"""
Update a conversation variable's value.

View File

@ -8,7 +8,7 @@ import uuid
from collections import Counter
from typing import Any, Literal, Optional
from flask_login import current_user
import sqlalchemy as sa
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -27,6 +27,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -566,8 +567,11 @@ class DatasetService:
data: Update data dictionary
filtered_data: Filtered update data to modify
"""
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
try:
model_manager = ModelManager()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
@ -679,8 +683,12 @@ class DatasetService:
data: Update data dictionary
filtered_data: Filtered update data to modify
"""
# assert isinstance(current_user, Account) and current_user.current_tenant_id is not None
model_manager = ModelManager()
try:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data["embedding_model_provider"],
@ -909,7 +917,9 @@ class DatasetService:
)
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
def get_dataset_auto_disable_logs(dataset_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {
@ -1114,6 +1124,8 @@ class DocumentService:
@staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
assert isinstance(current_user, Account)
documents = (
db.session.query(Document)
.where(
@ -1163,7 +1175,7 @@ class DocumentService:
file_ids = [
document.data_source_info_dict.get("upload_file_id", "")
for document in documents
if document.data_source_type == "upload_file"
if document.data_source_type == "upload_file" and document.data_source_info_dict
]
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
@ -1173,6 +1185,8 @@ class DocumentService:
@staticmethod
def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
assert isinstance(current_user, Account)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise ValueError("Dataset not found.")
@ -1202,6 +1216,7 @@ class DocumentService:
if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}:
raise DocumentIndexingError()
# update document to be paused
assert current_user is not None
document.is_paused = True
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
@ -1257,8 +1272,9 @@ class DocumentService:
# sync document indexing
document.indexing_status = "waiting"
data_source_info = document.data_source_info_dict
data_source_info["mode"] = "scrape"
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
if data_source_info:
data_source_info["mode"] = "scrape"
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
db.session.add(document)
db.session.commit()
@ -1287,6 +1303,9 @@ class DocumentService:
# check doc_form
DatasetService.check_doc_form(dataset, knowledge_config.doc_form)
# check document limit
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
@ -1887,6 +1906,8 @@ class DocumentService:
@staticmethod
def get_tenant_documents_count():
assert isinstance(current_user, Account)
documents_count = (
db.session.query(Document)
.where(
@ -1907,6 +1928,8 @@ class DocumentService:
dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = "web",
):
assert isinstance(current_user, Account)
DatasetService.check_dataset_model_setting(dataset)
document = DocumentService.get_document(dataset.id, document_data.original_document_id)
if document is None:
@ -1963,6 +1986,20 @@ class DocumentService:
notion_info_list = document_data.data_source.info_list.notion_info_list
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
sa.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
.first()
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
for page in notion_info.pages:
data_source_info = {
"credential_id": notion_info.credential_id,
@ -2014,6 +2051,9 @@ class DocumentService:
@staticmethod
def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
@ -2453,6 +2493,9 @@ class SegmentService:
@classmethod
def create_segment(cls, args: dict, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
content = args["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
@ -2515,6 +2558,9 @@ class SegmentService:
@classmethod
def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
lock_name = f"multi_add_segment_lock_document_id_{document.id}"
increment_word_count = 0
with redis_client.lock(lock_name, timeout=600):
@ -2597,9 +2643,10 @@ class SegmentService:
return segment_data_list
@classmethod
def update_segment(
cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset
) -> DocumentSegment:
def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
indexing_cache_key = f"segment_{segment.id}_indexing"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
@ -2793,6 +2840,7 @@ class SegmentService:
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
segments = (
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
.where(
@ -2825,6 +2873,8 @@ class SegmentService:
def update_segments_status(
cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document
):
assert current_user is not None
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
@ -2887,6 +2937,8 @@ class SegmentService:
def create_child_chunk(
cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset
) -> ChildChunk:
assert isinstance(current_user, Account)
lock_name = f"add_child_lock_{segment.id}"
with redis_client.lock(lock_name, timeout=20):
index_node_id = str(uuid.uuid4())
@ -2934,6 +2986,8 @@ class SegmentService:
document: Document,
dataset: Dataset,
) -> list[ChildChunk]:
assert isinstance(current_user, Account)
child_chunks = (
db.session.query(ChildChunk)
.where(
@ -3008,6 +3062,8 @@ class SegmentService:
document: Document,
dataset: Dataset,
) -> ChildChunk:
assert current_user is not None
try:
child_chunk.content = content
child_chunk.word_count = len(content)
@ -3038,6 +3094,8 @@ class SegmentService:
def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
):
assert isinstance(current_user, Account)
query = (
select(ChildChunk)
.filter_by(

View File

@ -83,7 +83,7 @@ class ProviderResponse(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data) -> None:
def __init__(self, **data):
super().__init__(**data)
url_prefix = (
@ -113,7 +113,7 @@ class ProviderWithModelsResponse(BaseModel):
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
def __init__(self, **data) -> None:
def __init__(self, **data):
super().__init__(**data)
url_prefix = (
@ -137,7 +137,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
tenant_id: str
def __init__(self, **data) -> None:
def __init__(self, **data):
super().__init__(**data)
url_prefix = (
@ -174,7 +174,7 @@ class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
provider: SimpleProviderEntityResponse
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
def __init__(self, tenant_id: str, model: ModelWithProviderEntity):
dump_model = model.model_dump()
dump_model["provider"]["tenant_id"] = tenant_id
super().__init__(**dump_model)

View File

@ -6,7 +6,7 @@ class InvokeError(Exception):
description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None:
def __init__(self, description: Optional[str] = None):
self.description = description
def __str__(self):

View File

@ -114,8 +114,9 @@ class ExternalDatasetService:
)
if external_knowledge_api is None:
raise ValueError("api template not found")
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key")
settings = args.get("settings")
if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict:
settings["api_key"] = external_knowledge_api.settings_dict.get("api_key")
external_knowledge_api.name = args.get("name")
external_knowledge_api.description = args.get("description", "")
@ -277,7 +278,7 @@ class ExternalDatasetService:
query: str,
external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None,
) -> list:
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)

View File

@ -3,7 +3,6 @@ import os
import uuid
from typing import Any, Literal, Union
from flask_login import current_user
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
@ -20,6 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from libs.login import current_user
from models.account import Account
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
@ -120,7 +120,11 @@ class FileService:
return file_size <= file_size_limit
def upload_text(self, text: str, text_name: str) -> UploadFile:
@staticmethod
def upload_text(text: str, text_name: str) -> UploadFile:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
if len(text_name) > 200:
text_name = text_name[:200]
# user uuid as file name

View File

@ -33,7 +33,7 @@ class HitTestingService:
retrieval_model: Any, # FIXME drop this any
external_retrieval_model: dict,
limit: int = 10,
) -> dict:
):
start = time.perf_counter()
# get retrieval model , if the model is not setting , using default
@ -98,7 +98,7 @@ class HitTestingService:
account: Account,
external_retrieval_model: dict,
metadata_filtering_conditions: dict,
) -> dict:
):
if dataset.provider != "external":
return {
"query": {"content": query},

View File

@ -25,10 +25,10 @@ logger = logging.getLogger(__name__)
class ModelLoadBalancingService:
def __init__(self) -> None:
def __init__(self):
self.provider_manager = ProviderManager()
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
enable model load balancing.
@ -49,7 +49,7 @@ class ModelLoadBalancingService:
# Enable model load balancing
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
disable model load balancing.
@ -295,7 +295,7 @@ class ModelLoadBalancingService:
def update_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str
) -> None:
):
"""
Update load balancing configurations.
:param tenant_id: workspace id
@ -478,7 +478,7 @@ class ModelLoadBalancingService:
model_type: str,
credentials: dict,
config_id: Optional[str] = None,
) -> None:
):
"""
Validate load balancing credentials.
:param tenant_id: workspace id
@ -537,7 +537,7 @@ class ModelLoadBalancingService:
credentials: dict,
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
validate: bool = True,
) -> dict:
):
"""
Validate custom credentials.
:param tenant_id: workspace id
@ -605,7 +605,7 @@ class ModelLoadBalancingService:
else:
raise ValueError("No credential schema found")
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
def _clear_credentials_cache(self, tenant_id: str, config_id: str):
"""
Clear credentials cache.
:param tenant_id: workspace id

View File

@ -26,7 +26,7 @@ class ModelProviderService:
Model Provider Service
"""
def __init__(self) -> None:
def __init__(self):
self.provider_manager = ProviderManager()
def _get_provider_configuration(self, tenant_id: str, provider: str):
@ -142,7 +142,7 @@ class ModelProviderService:
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
"""
validate provider credentials before saving.
@ -193,7 +193,7 @@ class ModelProviderService:
credential_name=credential_name,
)
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
"""
remove a saved provider credential (by credential_id).
:param tenant_id: workspace id
@ -204,7 +204,7 @@ class ModelProviderService:
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.delete_provider_credential(credential_id=credential_id)
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
"""
:param tenant_id: workspace id
:param provider: provider name
@ -232,9 +232,7 @@ class ModelProviderService:
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def validate_model_credentials(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
) -> None:
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
"""
validate model credentials.
@ -303,9 +301,7 @@ class ModelProviderService:
credential_name=credential_name,
)
def remove_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
) -> None:
def remove_model_credential(self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str):
"""
remove model credentials.
@ -323,7 +319,7 @@ class ModelProviderService:
def switch_active_custom_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
) -> None:
):
"""
switch model credentials.
@ -341,7 +337,7 @@ class ModelProviderService:
def add_model_credential_to_model_list(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
) -> None:
):
"""
add model credentials to model list.
@ -357,7 +353,7 @@ class ModelProviderService:
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str):
"""
remove model credentials.
@ -485,7 +481,7 @@ class ModelProviderService:
logger.debug("get_default_model_of_model_type error: %s", e)
return None
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str):
"""
update default model of model type.
@ -517,7 +513,7 @@ class ModelProviderService:
return byte_data, mime_type
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str):
"""
switch preferred provider.
@ -534,7 +530,7 @@ class ModelProviderService:
# Switch preferred provider type
provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
enable model.
@ -547,7 +543,7 @@ class ModelProviderService:
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
disable model.

View File

@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
class PluginDataMigration:
@classmethod
def migrate(cls) -> None:
def migrate(cls):
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
@ -26,7 +26,7 @@ class PluginDataMigration:
cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
@classmethod
def migrate_datasets(cls) -> None:
def migrate_datasets(cls):
table_name = "datasets"
provider_column_name = "embedding_model_provider"
@ -126,9 +126,7 @@ limit 1000"""
)
@classmethod
def migrate_db_records(
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
) -> None:
def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]):
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
processed_count = 0

View File

@ -35,7 +35,7 @@ excluded_providers = ["time", "audio", "code", "webscraper"]
class PluginMigration:
@classmethod
def extract_plugins(cls, filepath: str, workers: int) -> None:
def extract_plugins(cls, filepath: str, workers: int):
"""
Migrate plugin.
"""
@ -57,7 +57,7 @@ class PluginMigration:
thread_pool = ThreadPoolExecutor(max_workers=workers)
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
def process_tenant(flask_app: Flask, tenant_id: str):
with flask_app.app_context():
nonlocal handled_tenant_count
try:
@ -293,7 +293,7 @@ class PluginMigration:
return plugin_manifest[0].latest_package_identifier
@classmethod
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str):
"""
Extract unique plugins.
"""
@ -330,7 +330,7 @@ class PluginMigration:
return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
@classmethod
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100):
"""
Install plugins.
"""
@ -350,7 +350,7 @@ class PluginMigration:
if response.get("failed"):
plugin_install_failed.extend(response.get("failed", []))
def install(tenant_id: str, plugin_ids: list[str]) -> None:
def install(tenant_id: str, plugin_ids: list[str]):
logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
# fetch plugin already installed
installed_plugins = manager.list_plugins(tenant_id)

View File

@ -19,7 +19,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
def get_type(self) -> str:
return RecommendAppType.BUILDIN
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
result = self.fetch_recommended_apps_from_builtin(language)
return result
@ -28,7 +28,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return result
@classmethod
def _get_builtin_data(cls) -> dict:
def _get_builtin_data(cls):
"""
Get builtin data.
:return:
@ -44,7 +44,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return cls.builtin_data or {}
@classmethod
def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
def fetch_recommended_apps_from_builtin(cls, language: str):
"""
Fetch recommended apps from builtin.
:param language: language

View File

@ -13,7 +13,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
Retrieval recommended app from database
"""
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
result = self.fetch_recommended_apps_from_db(language)
return result
@ -25,7 +25,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
return RecommendAppType.DATABASE
@classmethod
def fetch_recommended_apps_from_db(cls, language: str) -> dict:
def fetch_recommended_apps_from_db(cls, language: str):
"""
Fetch recommended apps from db.
:param language: language

View File

@ -5,7 +5,7 @@ class RecommendAppRetrievalBase(ABC):
"""Interface for recommend app retrieval."""
@abstractmethod
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
raise NotImplementedError
@abstractmethod

View File

@ -24,7 +24,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id)
return result
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
try:
result = self.fetch_recommended_apps_from_dify_official(language)
except Exception as e:
@ -51,7 +51,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
return data
@classmethod
def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
def fetch_recommended_apps_from_dify_official(cls, language: str):
"""
Fetch recommended apps from dify official.
:param language: language

View File

@ -6,7 +6,7 @@ from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFa
class RecommendedAppService:
@classmethod
def get_recommended_apps_and_categories(cls, language: str) -> dict:
def get_recommended_apps_and_categories(cls, language: str):
"""
Get recommended apps and categories.
:param language: language

View File

@ -12,7 +12,7 @@ from models.model import App, Tag, TagBinding
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None):
query = (
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
@ -25,7 +25,7 @@ class TagService:
return results
@staticmethod
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list):
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
@ -51,7 +51,7 @@ class TagService:
return results
@staticmethod
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
if not tag_type or not tag_name:
return []
tags = (
@ -64,7 +64,7 @@ class TagService:
return tags
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)

View File

@ -551,7 +551,7 @@ class BuiltinToolManageService:
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller,
name_func=lambda x: x.identity.name,
name_func=lambda x: x.entity.identity.name,
):
continue

View File

@ -226,7 +226,7 @@ class MCPToolManageService:
def update_mcp_provider_credentials(
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
):
provider_controller = MCPToolProviderController._from_db(mcp_provider)
provider_controller = MCPToolProviderController.from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]

View File

@ -37,7 +37,7 @@ class WorkflowToolManageService:
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: list[str] | None = None,
) -> dict:
):
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
@ -103,7 +103,7 @@ class WorkflowToolManageService:
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: list[str] | None = None,
) -> dict:
):
"""
Update a workflow tool.
:param user_id: the user id
@ -217,7 +217,7 @@ class WorkflowToolManageService:
return result
@classmethod
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Delete a workflow tool.
:param user_id: the user id
@ -233,7 +233,7 @@ class WorkflowToolManageService:
return {"result": "success"}
@classmethod
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Get a workflow tool.
:param user_id: the user id
@ -249,7 +249,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
"""
Get a workflow tool.
:param user_id: the user id
@ -265,7 +265,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
"""
Get a workflow tool.
:db_tool: the database tool

View File

@ -145,7 +145,7 @@ class WebsiteService:
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
@classmethod
def document_create_args_validate(cls, args: dict) -> None:
def document_create_args_validate(cls, args: dict):
"""Validate arguments for document creation."""
try:
WebsiteCrawlApiRequest.from_args(args)

View File

@ -217,7 +217,7 @@ class WorkflowConverter:
return app_config
def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
def _convert_to_start_node(self, variables: list[VariableEntity]):
"""
Convert to Start Node
:param variables: list of variables
@ -384,7 +384,7 @@ class WorkflowConverter:
prompt_template: PromptTemplateEntity,
file_upload: Optional[FileUploadConfig] = None,
external_data_variable_node_mapping: dict[str, str] | None = None,
) -> dict:
):
"""
Convert to LLM Node
:param original_app_mode: original app mode
@ -550,7 +550,7 @@ class WorkflowConverter:
return template
def _convert_to_end_node(self) -> dict:
def _convert_to_end_node(self):
"""
Convert to End Node
:return:
@ -566,7 +566,7 @@ class WorkflowConverter:
},
}
def _convert_to_answer_node(self) -> dict:
def _convert_to_answer_node(self):
"""
Convert to Answer Node
:return:
@ -578,7 +578,7 @@ class WorkflowConverter:
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
}
def _create_edge(self, source: str, target: str) -> dict:
def _create_edge(self, source: str, target: str):
"""
Create Edge
:param source: source node id
@ -587,7 +587,7 @@ class WorkflowConverter:
"""
return {"id": f"{source}-{target}", "source": source, "target": target}
def _append_node(self, graph: dict, node: dict) -> dict:
def _append_node(self, graph: dict, node: dict):
"""
Append Node to Graph

View File

@ -23,7 +23,7 @@ class WorkflowAppService:
limit: int = 20,
created_by_end_user_session_id: str | None = None,
created_by_account: str | None = None,
) -> dict:
):
"""
Get paginate workflow app logs using SQLAlchemy 2.0 style
:param session: SQLAlchemy session

View File

@ -85,7 +85,7 @@ class DraftVarLoader(VariableLoader):
app_id: str,
tenant_id: str,
fallback_variables: Sequence[Variable] | None = None,
) -> None:
):
self._engine = engine
self._app_id = app_id
self._tenant_id = tenant_id
@ -185,7 +185,7 @@ class DraftVarLoader(VariableLoader):
class WorkflowDraftVariableService:
_session: Session
def __init__(self, session: Session) -> None:
def __init__(self, session: Session):
"""
Initialize the WorkflowDraftVariableService with a SQLAlchemy session.
@ -602,7 +602,7 @@ def _batch_upsert_draft_variable(
session: Session,
draft_vars: Sequence[WorkflowDraftVariable],
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
) -> None:
):
if not draft_vars:
return None
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:

View File

@ -619,7 +619,7 @@ class WorkflowService:
return new_app
def validate_features_structure(self, app_model: App, features: dict) -> dict:
def validate_features_structure(self, app_model: App, features: dict):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
return AdvancedChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True