Merge branch 'main' into feat/mcp-06-18

This commit is contained in:
Novice
2025-09-16 17:09:46 +08:00
705 changed files with 18417 additions and 4880 deletions

View File

@ -5,7 +5,7 @@ import secrets
import uuid
from datetime import UTC, datetime, timedelta
from hashlib import sha256
from typing import Any, Optional, cast
from typing import Any, cast
from pydantic import BaseModel
from sqlalchemy import func
@ -37,7 +37,6 @@ from services.billing_service import BillingService
from services.errors.account import (
AccountAlreadyInTenantError,
AccountLoginError,
AccountNotFoundError,
AccountNotLinkTenantError,
AccountPasswordError,
AccountRegisterError,
@ -65,7 +64,11 @@ from tasks.mail_owner_transfer_task import (
send_old_owner_transfer_notify_email_task,
send_owner_transfer_confirm_task,
)
from tasks.mail_reset_password_task import send_reset_password_mail_task
from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist
from tasks.mail_reset_password_task import (
send_reset_password_mail_task,
send_reset_password_mail_task_when_account_not_exist,
)
logger = logging.getLogger(__name__)
@ -82,8 +85,9 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
prefix="email_code_login_rate_limit", max_attempts=3, time_window=300 * 1
)
email_code_account_deletion_rate_limiter = RateLimiter(
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
@ -95,6 +99,7 @@ class AccountService:
FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5
CHANGE_EMAIL_MAX_ERROR_LIMITS = 5
OWNER_TRANSFER_MAX_ERROR_LIMITS = 5
EMAIL_REGISTER_MAX_ERROR_LIMITS = 5
@staticmethod
def _get_refresh_token_key(refresh_token: str) -> str:
@ -166,12 +171,12 @@ class AccountService:
return token
@staticmethod
def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
"""authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first()
if not account:
raise AccountNotFoundError()
raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.BANNED.value:
raise AccountLoginError("Account is banned.")
@ -223,9 +228,9 @@ class AccountService:
email: str,
name: str,
interface_language: str,
password: Optional[str] = None,
password: str | None = None,
interface_theme: str = "light",
is_setup: Optional[bool] = False,
is_setup: bool | None = False,
) -> Account:
"""create account"""
if not FeatureService.get_system_features().is_allow_register and not is_setup:
@ -271,7 +276,7 @@ class AccountService:
@staticmethod
def create_account_and_tenant(
email: str, name: str, interface_language: str, password: Optional[str] = None
email: str, name: str, interface_language: str, password: str | None = None
) -> Account:
"""create account"""
account = AccountService.create_account(
@ -296,7 +301,9 @@ class AccountService:
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError
raise EmailCodeAccountDeletionRateLimitExceededError()
raise EmailCodeAccountDeletionRateLimitExceededError(
int(cls.email_code_account_deletion_rate_limiter.time_window / 60)
)
send_account_deletion_verification_code.delay(to=email, code=code)
@ -323,7 +330,7 @@ class AccountService:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = (
account_integrate: AccountIntegrate | None = (
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
)
@ -384,7 +391,7 @@ class AccountService:
db.session.commit()
@staticmethod
def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
def login(account: Account, *, ip_address: str | None = None) -> TokenPair:
if ip_address:
AccountService.update_login_info(account=account, ip_address=ip_address)
@ -432,9 +439,10 @@ class AccountService:
@classmethod
def send_reset_password_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
is_allow_register: bool = False,
):
account_email = account.email if account else email
if account_email is None:
@ -443,26 +451,67 @@ class AccountService:
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import PasswordResetRateLimitExceededError
raise PasswordResetRateLimitExceededError()
raise PasswordResetRateLimitExceededError(int(cls.reset_password_rate_limiter.time_window / 60))
code, token = cls.generate_reset_password_token(account_email, account)
send_reset_password_mail_task.delay(
language=language,
to=account_email,
code=code,
)
if account:
send_reset_password_mail_task.delay(
language=language,
to=account_email,
code=code,
)
else:
send_reset_password_mail_task_when_account_not_exist.delay(
language=language,
to=account_email,
is_allow_register=is_allow_register,
)
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
return token
@classmethod
def send_email_register_email(
cls,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
):
account_email = account.email if account else email
if account_email is None:
raise ValueError("Email must be provided.")
if cls.email_register_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailRegisterRateLimitExceededError
raise EmailRegisterRateLimitExceededError(int(cls.email_register_rate_limiter.time_window / 60))
code, token = cls.generate_email_register_token(account_email)
if account:
send_email_register_mail_task_when_account_exist.delay(
language=language,
to=account_email,
account_name=account.name,
)
else:
send_email_register_mail_task.delay(
language=language,
to=account_email,
code=code,
)
cls.email_register_rate_limiter.increment_rate_limit(account_email)
return token
@classmethod
def send_change_email_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
old_email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
old_email: str | None = None,
language: str = "en-US",
phase: Optional[str] = None,
phase: str | None = None,
):
account_email = account.email if account else email
if account_email is None:
@ -473,7 +522,7 @@ class AccountService:
if cls.change_email_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailChangeRateLimitExceededError
raise EmailChangeRateLimitExceededError()
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
@ -489,8 +538,8 @@ class AccountService:
@classmethod
def send_change_email_completed_notify_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
):
account_email = account.email if account else email
@ -505,10 +554,10 @@ class AccountService:
@classmethod
def send_owner_transfer_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
workspace_name: Optional[str] = "",
workspace_name: str | None = "",
):
account_email = account.email if account else email
if account_email is None:
@ -517,7 +566,7 @@ class AccountService:
if cls.owner_transfer_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import OwnerTransferRateLimitExceededError
raise OwnerTransferRateLimitExceededError()
raise OwnerTransferRateLimitExceededError(int(cls.owner_transfer_rate_limiter.time_window / 60))
code, token = cls.generate_owner_transfer_token(account_email, account)
workspace_name = workspace_name or ""
@ -534,10 +583,10 @@ class AccountService:
@classmethod
def send_old_owner_transfer_notify_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
workspace_name: Optional[str] = "",
workspace_name: str | None = "",
new_owner_email: str = "",
):
account_email = account.email if account else email
@ -555,10 +604,10 @@ class AccountService:
@classmethod
def send_new_owner_transfer_notify_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
workspace_name: Optional[str] = "",
workspace_name: str | None = "",
):
account_email = account.email if account else email
if account_email is None:
@ -575,8 +624,8 @@ class AccountService:
def generate_reset_password_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
account: Account | None = None,
code: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@ -587,13 +636,26 @@ class AccountService:
)
return code, token
@classmethod
def generate_email_register_token(
cls,
email: str,
code: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
additional_data["code"] = code
token = TokenManager.generate_token(email=email, token_type="email_register", additional_data=additional_data)
return code, token
@classmethod
def generate_change_email_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
old_email: Optional[str] = None,
account: Account | None = None,
code: str | None = None,
old_email: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@ -609,8 +671,8 @@ class AccountService:
def generate_owner_transfer_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
account: Account | None = None,
code: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@ -625,6 +687,10 @@ class AccountService:
def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, "reset_password")
@classmethod
def revoke_email_register_token(cls, token: str):
TokenManager.revoke_token(token, "email_register")
@classmethod
def revoke_change_email_token(cls, token: str):
TokenManager.revoke_token(token, "change_email")
@ -634,22 +700,26 @@ class AccountService:
TokenManager.revoke_token(token, "owner_transfer")
@classmethod
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_reset_password_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "reset_password")
@classmethod
def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_email_register_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_register")
@classmethod
def get_change_email_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "change_email")
@classmethod
def get_owner_transfer_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_owner_transfer_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "owner_transfer")
@classmethod
def send_email_code_login_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
):
email = account.email if account else email
@ -658,7 +728,7 @@ class AccountService:
if cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
raise EmailCodeLoginRateLimitExceededError()
raise EmailCodeLoginRateLimitExceededError(int(cls.email_code_login_rate_limiter.time_window / 60))
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
token = TokenManager.generate_token(
@ -673,7 +743,7 @@ class AccountService:
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@classmethod
@ -744,6 +814,16 @@ class AccountService:
count = int(count) + 1
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
@staticmethod
@redis_fallback(default_return=None)
def add_email_register_error_rate_limit(email: str) -> None:
key = f"email_register_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
count = 0
count = int(count) + 1
redis_client.setex(key, dify_config.EMAIL_REGISTER_LOCKOUT_DURATION, count)
@staticmethod
@redis_fallback(default_return=False)
def is_forgot_password_error_rate_limit(email: str) -> bool:
@ -763,6 +843,24 @@ class AccountService:
key = f"forgot_password_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
@redis_fallback(default_return=False)
def is_email_register_error_rate_limit(email: str) -> bool:
key = f"email_register_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
return False
count = int(count)
if count > AccountService.EMAIL_REGISTER_MAX_ERROR_LIMITS:
return True
return False
@staticmethod
@redis_fallback(default_return=None)
def reset_email_register_error_rate_limit(email: str):
key = f"email_register_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
@redis_fallback(default_return=None)
def add_change_email_error_rate_limit(email: str):
@ -867,7 +965,7 @@ class AccountService:
class TenantService:
@staticmethod
def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant:
def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant:
"""Create tenant"""
if (
not FeatureService.get_system_features().is_allow_create_workspace
@ -898,9 +996,7 @@ class TenantService:
return tenant
@staticmethod
def create_owner_tenant_if_not_exist(
account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False
):
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
"""Check if user have a workspace or not"""
available_ta = (
db.session.query(TenantAccountJoin)
@ -972,7 +1068,7 @@ class TenantService:
return tenant
@staticmethod
def switch_tenant(account: Account, tenant_id: Optional[str] = None):
def switch_tenant(account: Account, tenant_id: str | None = None):
"""Switch the current workspace for the account"""
# Ensure tenant_id is provided
@ -1054,7 +1150,7 @@ class TenantService:
)
@staticmethod
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountRole]:
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
"""Get the role of the current account for a given tenant"""
join = (
db.session.query(TenantAccountJoin)
@ -1194,13 +1290,13 @@ class RegisterService:
cls,
email,
name,
password: Optional[str] = None,
open_id: Optional[str] = None,
provider: Optional[str] = None,
language: Optional[str] = None,
status: Optional[AccountStatus] = None,
is_setup: Optional[bool] = False,
create_workspace_required: Optional[bool] = True,
password: str | None = None,
open_id: str | None = None,
provider: str | None = None,
language: str | None = None,
status: AccountStatus | None = None,
is_setup: bool | None = False,
create_workspace_required: bool | None = True,
) -> Account:
db.session.begin_nested()
"""Register account"""
@ -1317,9 +1413,7 @@ class RegisterService:
redis_client.delete(cls._get_invitation_token_key(token))
@classmethod
def get_invitation_if_token_valid(
cls, workspace_id: Optional[str], email: str, token: str
) -> Optional[dict[str, Any]]:
def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None:
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None
@ -1358,8 +1452,8 @@ class RegisterService:
@classmethod
def get_invitation_by_token(
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
) -> Optional[dict[str, str]]:
cls, token: str, workspace_id: str | None = None, email: str | None = None
) -> dict[str, str] | None:
if workspace_id is not None and email is not None:
email_hash = sha256(email.encode()).hexdigest()
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"

View File

@ -32,14 +32,14 @@ class AdvancedPromptTemplateService:
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
context_prompt = copy.deepcopy(CONTEXT)
if app_mode == AppMode.CHAT.value:
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == AppMode.COMPLETION.value:
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
@ -73,7 +73,7 @@ class AdvancedPromptTemplateService:
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
if app_mode == AppMode.CHAT.value:
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
@ -82,7 +82,7 @@ class AdvancedPromptTemplateService:
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
elif app_mode == AppMode.COMPLETION.value:
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),

View File

@ -1,5 +1,5 @@
import threading
from typing import Any, Optional
from typing import Any
import pytz
@ -35,7 +35,7 @@ class AgentService:
if not conversation:
raise ValueError(f"Conversation not found: {conversation_id}")
message: Optional[Message] = (
message: Message | None = (
db.session.query(Message)
.where(
Message.id == message_id,

View File

@ -1,5 +1,4 @@
import uuid
from typing import Optional
import pandas as pd
from sqlalchemy import or_, select
@ -42,7 +41,7 @@ class AppAnnotationService:
if not message:
raise NotFound("Message Not Exists.")
annotation: Optional[MessageAnnotation] = message.annotation
annotation: MessageAnnotation | None = message.annotation
# save the message annotation
if annotation:
annotation.content = args["answer"]

View File

@ -4,7 +4,6 @@ import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Optional
from urllib.parse import urlparse
from uuid import uuid4
@ -61,8 +60,8 @@ class ImportStatus(StrEnum):
class Import(BaseModel):
id: str
status: ImportStatus
app_id: Optional[str] = None
app_mode: Optional[str] = None
app_id: str | None = None
app_mode: str | None = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
error: str = ""
@ -121,14 +120,14 @@ class AppDslService:
*,
account: Account,
import_mode: str,
yaml_content: Optional[str] = None,
yaml_url: Optional[str] = None,
name: Optional[str] = None,
description: Optional[str] = None,
icon_type: Optional[str] = None,
icon: Optional[str] = None,
icon_background: Optional[str] = None,
app_id: Optional[str] = None,
yaml_content: str | None = None,
yaml_url: str | None = None,
name: str | None = None,
description: str | None = None,
icon_type: str | None = None,
icon: str | None = None,
icon_background: str | None = None,
app_id: str | None = None,
) -> Import:
"""Import an app from YAML content or URL."""
import_id = str(uuid.uuid4())
@ -407,15 +406,15 @@ class AppDslService:
def _create_or_update_app(
self,
*,
app: Optional[App],
app: App | None,
data: dict,
account: Account,
name: Optional[str] = None,
description: Optional[str] = None,
icon_type: Optional[str] = None,
icon: Optional[str] = None,
icon_background: Optional[str] = None,
dependencies: Optional[list[PluginDependency]] = None,
name: str | None = None,
description: str | None = None,
icon_type: str | None = None,
icon: str | None = None,
icon_background: str | None = None,
dependencies: list[PluginDependency] | None = None,
) -> App:
"""Create a new app or update an existing one."""
app_data = data.get("app", {})
@ -533,7 +532,7 @@ class AppDslService:
return app
@classmethod
def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: Optional[str] = None) -> str:
def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: str | None = None) -> str:
"""
Export app
:param app_model: App instance
@ -566,7 +565,7 @@ class AppDslService:
@classmethod
def _append_workflow_export_data(
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: Optional[str] = None
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None
):
"""
Append workflow export data

View File

@ -1,6 +1,6 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union
from typing import Any, Union
from openai._exceptions import RateLimitError
@ -60,7 +60,7 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
if app_model.mode == AppMode.COMPLETION.value:
if app_model.mode == AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(
CompletionAppGenerator().generate(
@ -69,7 +69,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
return rate_limit.generate(
AgentChatAppGenerator.convert_to_event_stream(
AgentChatAppGenerator().generate(
@ -78,7 +78,7 @@ class AppGenerateService:
),
request_id,
)
elif app_model.mode == AppMode.CHAT.value:
elif app_model.mode == AppMode.CHAT:
return rate_limit.generate(
ChatAppGenerator.convert_to_event_stream(
ChatAppGenerator().generate(
@ -87,7 +87,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
elif app_model.mode == AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
@ -103,7 +103,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
@ -155,14 +155,14 @@ class AppGenerateService:
@classmethod
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
@ -174,14 +174,14 @@ class AppGenerateService:
@classmethod
def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
@ -214,7 +214,7 @@ class AppGenerateService:
)
@classmethod
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow:
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: str | None = None) -> Workflow:
"""
Get workflow
:param app_model: app model

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Optional, TypedDict, cast
from typing import TypedDict, cast
from flask_sqlalchemy.pagination import Pagination
@ -40,15 +40,15 @@ class AppService:
filters = [App.tenant_id == tenant_id, App.is_universal == False]
if args["mode"] == "workflow":
filters.append(App.mode == AppMode.WORKFLOW.value)
filters.append(App.mode == AppMode.WORKFLOW)
elif args["mode"] == "completion":
filters.append(App.mode == AppMode.COMPLETION.value)
filters.append(App.mode == AppMode.COMPLETION)
elif args["mode"] == "chat":
filters.append(App.mode == AppMode.CHAT.value)
filters.append(App.mode == AppMode.CHAT)
elif args["mode"] == "advanced-chat":
filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
filters.append(App.mode == AppMode.ADVANCED_CHAT)
elif args["mode"] == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT.value)
filters.append(App.mode == AppMode.AGENT_CHAT)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
@ -171,7 +171,7 @@ class AppService:
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get original app model config
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
if app.mode == AppMode.AGENT_CHAT or app.is_agent:
model_config = app.app_model_config
if not model_config:
return app
@ -370,7 +370,7 @@ class AppService:
}
)
else:
app_model_config: Optional[AppModelConfig] = app_model.app_model_config
app_model_config: AppModelConfig | None = app_model.app_model_config
if not app_model_config:
return meta
@ -393,7 +393,7 @@ class AppService:
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
elif provider_type == "api":
try:
provider: Optional[ApiToolProvider] = (
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first()
)
if provider is None:

View File

@ -2,7 +2,6 @@ import io
import logging
import uuid
from collections.abc import Generator
from typing import Optional
from flask import Response, stream_with_context
from werkzeug.datastructures import FileStorage
@ -30,8 +29,8 @@ logger = logging.getLogger(__name__)
class AudioService:
@classmethod
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: str | None = None):
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise ValueError("Speech to text is not enabled")
@ -77,18 +76,18 @@ class AudioService:
def transcript_tts(
cls,
app_model: App,
text: Optional[str] = None,
voice: Optional[str] = None,
end_user: Optional[str] = None,
message_id: Optional[str] = None,
text: str | None = None,
voice: str | None = None,
end_user: str | None = None,
message_id: str | None = None,
is_draft: bool = False,
):
from app import app
def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False):
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
with app.app_context():
if voice is None:
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if is_draft:
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
else:

View File

@ -1,5 +1,5 @@
import os
from typing import Literal, Optional
from typing import Literal
import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
@ -73,7 +73,7 @@ class BillingService:
def is_tenant_owner_or_admin(current_user: Account):
tenant_id = current_user.current_tenant_id
join: Optional[TenantAccountJoin] = (
join: TenantAccountJoin | None = (
db.session.query(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.first()

View File

@ -1,7 +1,7 @@
import contextlib
import logging
from collections.abc import Callable, Sequence
from typing import Any, Optional, Union
from typing import Any, Union
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session
@ -36,12 +36,12 @@ class ConversationService:
*,
session: Session,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
invoke_from: InvokeFrom,
include_ids: Optional[Sequence[str]] = None,
exclude_ids: Optional[Sequence[str]] = None,
include_ids: Sequence[str] | None = None,
exclude_ids: Sequence[str] | None = None,
sort_by: str = "-updated_at",
) -> InfiniteScrollPagination:
if not user:
@ -118,7 +118,7 @@ class ConversationService:
cls,
app_model: App,
conversation_id: str,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
name: str,
auto_generate: bool,
):
@ -158,7 +158,7 @@ class ConversationService:
return conversation
@classmethod
def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
conversation = (
db.session.query(Conversation)
.where(
@ -178,7 +178,7 @@ class ConversationService:
return conversation
@classmethod
def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
try:
logger.info(
"Initiating conversation deletion for app_name %s, conversation_id: %s",
@ -200,9 +200,9 @@ class ConversationService:
cls,
app_model: App,
conversation_id: str,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
limit: int,
last_id: Optional[str],
last_id: str | None,
) -> InfiniteScrollPagination:
conversation = cls.get_conversation(app_model, conversation_id, user)
@ -222,8 +222,8 @@ class ConversationService:
# Filter for variables created after the last_id
stmt = stmt.where(ConversationVariable.created_at > last_variable.created_at)
# Apply limit to query
query_stmt = stmt.limit(limit) # Get one extra to check if there are more
# Apply limit to query: fetch one extra row to determine has_more
query_stmt = stmt.limit(limit + 1)
rows = session.scalars(query_stmt).all()
has_more = False
@ -248,7 +248,7 @@ class ConversationService:
app_model: App,
conversation_id: str,
variable_id: str,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
new_value: Any,
):
"""

View File

@ -7,7 +7,7 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, Optional
from typing import Any, Literal
import sqlalchemy as sa
from sqlalchemy import exists, func, select
@ -185,16 +185,16 @@ class DatasetService:
def create_empty_dataset(
tenant_id: str,
name: str,
description: Optional[str],
indexing_technique: Optional[str],
description: str | None,
indexing_technique: str | None,
account: Account,
permission: Optional[str] = None,
permission: str | None = None,
provider: str = "vendor",
external_knowledge_api_id: Optional[str] = None,
external_knowledge_id: Optional[str] = None,
embedding_model_provider: Optional[str] = None,
embedding_model_name: Optional[str] = None,
retrieval_model: Optional[RetrievalModel] = None,
external_knowledge_api_id: str | None = None,
external_knowledge_id: str | None = None,
embedding_model_provider: str | None = None,
embedding_model_name: str | None = None,
retrieval_model: RetrievalModel | None = None,
):
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
@ -257,8 +257,8 @@ class DatasetService:
return dataset
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
def get_dataset(dataset_id) -> Dataset | None:
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
return dataset
@staticmethod
@ -694,7 +694,7 @@ class DatasetService:
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None):
if not dataset:
raise ValueError("Dataset not found")
@ -868,7 +868,7 @@ class DocumentService:
}
@staticmethod
def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
def get_document(dataset_id: str, document_id: str | None = None) -> Document | None:
if document_id:
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
@ -878,7 +878,7 @@ class DocumentService:
return None
@staticmethod
def get_document_by_id(document_id: str) -> Optional[Document]:
def get_document_by_id(document_id: str) -> Document | None:
document = db.session.query(Document).where(Document.id == document_id).first()
return document
@ -1004,7 +1004,7 @@ class DocumentService:
if dataset.built_in_field_enabled:
if document.doc_metadata:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata[BuiltInField.document_name.value] = name
doc_metadata[BuiltInField.document_name] = name
document.doc_metadata = doc_metadata
document.name = name
@ -1099,7 +1099,7 @@ class DocumentService:
dataset: Dataset,
knowledge_config: KnowledgeConfig,
account: Account | Any,
dataset_process_rule: Optional[DatasetProcessRule] = None,
dataset_process_rule: DatasetProcessRule | None = None,
created_from: str = "web",
) -> tuple[list[Document], str]:
# check doc_form
@ -1463,7 +1463,7 @@ class DocumentService:
dataset: Dataset,
document_data: KnowledgeConfig,
account: Account,
dataset_process_rule: Optional[DatasetProcessRule] = None,
dataset_process_rule: DatasetProcessRule | None = None,
created_from: str = "web",
):
assert isinstance(current_user, Account)
@ -2365,7 +2365,22 @@ class SegmentService:
if segment.enabled:
# send delete segment index task
redis_client.setex(indexing_cache_key, 600, 1)
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)
# Get child chunk IDs before parent segment is deleted
child_node_ids = []
if segment.index_node_id:
child_chunks = (
db.session.query(ChildChunk.index_node_id)
.where(
ChildChunk.segment_id == segment.id,
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
db.session.delete(segment)
# update document word count
assert document.word_count is not None
@ -2375,9 +2390,13 @@ class SegmentService:
@classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
assert isinstance(current_user, Account)
segments = (
db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
assert current_user is not None
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
segments_info = (
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count)
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
@ -2387,18 +2406,36 @@ class SegmentService:
.all()
)
if not segments:
if not segments_info:
return
index_node_ids = [seg.index_node_id for seg in segments]
total_words = sum(seg.word_count for seg in segments)
index_node_ids = [info[0] for info in segments_info]
segment_db_ids = [info[1] for info in segments_info]
total_words = sum(info[2] for info in segments_info if info[2] is not None)
# Get child chunk IDs before parent segments are deleted
child_node_ids = []
if index_node_ids:
child_chunks = (
db.session.query(ChildChunk.index_node_id)
.where(
ChildChunk.segment_id.in_(segment_db_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
# Start async cleanup with both parent and child node IDs
if index_node_ids or child_node_ids:
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
document.word_count = (
document.word_count - total_words if document.word_count and document.word_count > total_words else 0
)
db.session.add(document)
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
# Delete database records
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
db.session.commit()
@ -2618,7 +2655,7 @@ class SegmentService:
@classmethod
def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: str | None = None
):
assert isinstance(current_user, Account)
@ -2637,7 +2674,7 @@ class SegmentService:
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None:
"""Get a child chunk by its ID."""
result = (
db.session.query(ChildChunk)
@ -2674,7 +2711,7 @@ class SegmentService:
return paginated_segments.items, paginated_segments.total
@classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None:
"""Get a segment by its ID."""
result = (
db.session.query(DocumentSegment)
@ -2738,11 +2775,7 @@ class DatasetPermissionService:
).where(DatasetPermission.dataset_id == dataset_id)
).all()
user_list = []
for user in user_list_query:
user_list.append(user.account_id)
return user_list
return user_list_query
@classmethod
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):

View File

@ -9,9 +9,9 @@ from services.errors.base import BaseServiceError
logger = logging.getLogger(__name__)
class PluginCredentialType(enum.IntEnum):
MODEL = enum.auto()
TOOL = enum.auto()
class PluginCredentialType(enum.Enum):
MODEL = 0 # must be 0 for API contract compatibility
TOOL = 1 # must be 1 for API contract compatibility
def to_number(self):
return self.value

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional, Union
from typing import Literal, Union
from pydantic import BaseModel
@ -11,7 +11,7 @@ class AuthorizationConfig(BaseModel):
class Authorization(BaseModel):
type: Literal["no-auth", "api-key"]
config: Optional[AuthorizationConfig] = None
config: AuthorizationConfig | None = None
class ProcessStatusSetting(BaseModel):
@ -22,5 +22,5 @@ class ProcessStatusSetting(BaseModel):
class ExternalKnowledgeApiSetting(BaseModel):
url: str
request_method: str
headers: Optional[dict] = None
params: Optional[dict] = None
headers: dict | None = None
params: dict | None = None

View File

@ -1,5 +1,5 @@
from enum import StrEnum
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel
@ -11,14 +11,14 @@ class ParentMode(StrEnum):
class NotionIcon(BaseModel):
type: str
url: Optional[str] = None
emoji: Optional[str] = None
url: str | None = None
emoji: str | None = None
class NotionPage(BaseModel):
page_id: str
page_name: str
page_icon: Optional[NotionIcon] = None
page_icon: NotionIcon | None = None
type: str
@ -40,9 +40,9 @@ class FileInfo(BaseModel):
class InfoList(BaseModel):
data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
notion_info_list: Optional[list[NotionInfo]] = None
file_info_list: Optional[FileInfo] = None
website_info_list: Optional[WebsiteInfo] = None
notion_info_list: list[NotionInfo] | None = None
file_info_list: FileInfo | None = None
website_info_list: WebsiteInfo | None = None
class DataSource(BaseModel):
@ -61,20 +61,20 @@ class Segmentation(BaseModel):
class Rule(BaseModel):
pre_processing_rules: Optional[list[PreProcessingRule]] = None
segmentation: Optional[Segmentation] = None
parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
subchunk_segmentation: Optional[Segmentation] = None
pre_processing_rules: list[PreProcessingRule] | None = None
segmentation: Segmentation | None = None
parent_mode: Literal["full-doc", "paragraph"] | None = None
subchunk_segmentation: Segmentation | None = None
class ProcessRule(BaseModel):
mode: Literal["automatic", "custom", "hierarchical"]
rules: Optional[Rule] = None
rules: Rule | None = None
class RerankingModel(BaseModel):
reranking_provider_name: Optional[str] = None
reranking_model_name: Optional[str] = None
reranking_provider_name: str | None = None
reranking_model_name: str | None = None
class WeightVectorSetting(BaseModel):
@ -88,20 +88,20 @@ class WeightKeywordSetting(BaseModel):
class WeightModel(BaseModel):
weight_type: Optional[Literal["semantic_first", "keyword_first", "customized"]] = None
vector_setting: Optional[WeightVectorSetting] = None
keyword_setting: Optional[WeightKeywordSetting] = None
weight_type: Literal["semantic_first", "keyword_first", "customized"] | None = None
vector_setting: WeightVectorSetting | None = None
keyword_setting: WeightKeywordSetting | None = None
class RetrievalModel(BaseModel):
search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
reranking_enable: bool
reranking_model: Optional[RerankingModel] = None
reranking_mode: Optional[str] = None
reranking_model: RerankingModel | None = None
reranking_mode: str | None = None
top_k: int
score_threshold_enabled: bool
score_threshold: Optional[float] = None
weights: Optional[WeightModel] = None
score_threshold: float | None = None
weights: WeightModel | None = None
class MetaDataConfig(BaseModel):
@ -110,29 +110,29 @@ class MetaDataConfig(BaseModel):
class KnowledgeConfig(BaseModel):
original_document_id: Optional[str] = None
original_document_id: str | None = None
duplicate: bool = True
indexing_technique: Literal["high_quality", "economy"]
data_source: Optional[DataSource] = None
process_rule: Optional[ProcessRule] = None
retrieval_model: Optional[RetrievalModel] = None
data_source: DataSource | None = None
process_rule: ProcessRule | None = None
retrieval_model: RetrievalModel | None = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: Optional[str] = None
embedding_model_provider: Optional[str] = None
name: Optional[str] = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
name: str | None = None
class SegmentUpdateArgs(BaseModel):
content: Optional[str] = None
answer: Optional[str] = None
keywords: Optional[list[str]] = None
content: str | None = None
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
enabled: Optional[bool] = None
enabled: bool | None = None
class ChildChunkUpdateArgs(BaseModel):
id: Optional[str] = None
id: str | None = None
content: str
@ -143,13 +143,13 @@ class MetadataArgs(BaseModel):
class MetadataUpdateArgs(BaseModel):
name: str
value: Optional[str | int | float] = None
value: str | int | float | None = None
class MetadataDetail(BaseModel):
id: str
name: str
value: Optional[str | int | float] = None
value: str | int | float | None = None
class DocumentMetadataOperation(BaseModel):

View File

@ -1,5 +1,4 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, ConfigDict
@ -42,11 +41,11 @@ class CustomConfigurationResponse(BaseModel):
"""
status: CustomConfigurationStatus
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_credentials: Optional[list[CredentialConfiguration]] = None
custom_models: Optional[list[CustomModelConfiguration]] = None
can_added_models: Optional[list[UnaddedModelConfiguration]] = None
current_credential_id: str | None = None
current_credential_name: str | None = None
available_credentials: list[CredentialConfiguration] | None = None
custom_models: list[CustomModelConfiguration] | None = None
can_added_models: list[UnaddedModelConfiguration] | None = None
class SystemConfigurationResponse(BaseModel):
@ -55,7 +54,7 @@ class SystemConfigurationResponse(BaseModel):
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
current_quota_type: ProviderQuotaType | None = None
quota_configurations: list[QuotaConfiguration] = []
@ -67,15 +66,15 @@ class ProviderResponse(BaseModel):
tenant_id: str
provider: str
label: I18nObject
description: Optional[I18nObject] = None
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: list[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None
provider_credential_schema: ProviderCredentialSchema | None = None
model_credential_schema: ModelCredentialSchema | None = None
preferred_provider_type: ProviderType
custom_configuration: CustomConfigurationResponse
system_configuration: SystemConfigurationResponse
@ -108,8 +107,8 @@ class ProviderWithModelsResponse(BaseModel):
tenant_id: str
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]

View File

@ -1,6 +1,3 @@
from typing import Optional
class BaseServiceError(ValueError):
def __init__(self, description: Optional[str] = None):
def __init__(self, description: str | None = None):
self.description = description

View File

@ -1,12 +1,9 @@
from typing import Optional
class InvokeError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
description: str | None = None
def __init__(self, description: Optional[str] = None):
def __init__(self, description: str | None = None):
self.description = description
def __str__(self):

View File

@ -1,6 +1,6 @@
import json
from copy import deepcopy
from typing import Any, Optional, Union, cast
from typing import Any, Union, cast
from urllib.parse import urlparse
import httpx
@ -100,7 +100,7 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
)
if external_knowledge_api is None:
@ -109,7 +109,7 @@ class ExternalDatasetService:
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
@ -151,7 +151,7 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
external_knowledge_binding: ExternalKnowledgeBindings | None = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
@ -203,7 +203,7 @@ class ExternalDatasetService:
return response
@staticmethod
def assembling_headers(authorization: Authorization, headers: Optional[dict] = None) -> dict[str, Any]:
def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]:
authorization = deepcopy(authorization)
if headers:
headers = deepcopy(headers)
@ -277,7 +277,7 @@ class ExternalDatasetService:
dataset_id: str,
query: str,
external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None,
metadata_condition: MetadataCondition | None = None,
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()

View File

@ -1,5 +1,5 @@
import json
from typing import Optional, Union
from typing import Union
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -29,9 +29,9 @@ class MessageService:
def pagination_by_first_id(
cls,
app_model: App,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
conversation_id: str,
first_id: Optional[str],
first_id: str | None,
limit: int,
order: str = "asc",
) -> InfiniteScrollPagination:
@ -91,11 +91,11 @@ class MessageService:
def pagination_by_last_id(
cls,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
conversation_id: Optional[str] = None,
include_ids: Optional[list] = None,
conversation_id: str | None = None,
include_ids: list | None = None,
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
@ -145,9 +145,9 @@ class MessageService:
*,
app_model: App,
message_id: str,
user: Optional[Union[Account, EndUser]],
rating: Optional[str],
content: Optional[str],
user: Union[Account, EndUser] | None,
rating: str | None,
content: str | None,
):
if not user:
raise ValueError("user cannot be None")
@ -196,7 +196,7 @@ class MessageService:
return [record.to_dict() for record in feedbacks]
@classmethod
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
message = (
db.session.query(Message)
.where(
@ -216,7 +216,7 @@ class MessageService:
@classmethod
def get_suggested_questions_after_answer(
cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom
cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom
) -> list[Message]:
if not user:
raise ValueError("user cannot be None")
@ -229,7 +229,7 @@ class MessageService:
model_manager = ModelManager()
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow_service = WorkflowService()
if invoke_from == InvokeFrom.DEBUGGER:
workflow = workflow_service.get_draft_workflow(app_model=app_model)

View File

@ -1,6 +1,5 @@
import copy
import logging
from typing import Optional
from flask_login import current_user
@ -131,11 +130,11 @@ class MetadataService:
@staticmethod
def get_built_in_fields():
return [
{"name": BuiltInField.document_name.value, "type": "string"},
{"name": BuiltInField.uploader.value, "type": "string"},
{"name": BuiltInField.upload_date.value, "type": "time"},
{"name": BuiltInField.last_update_date.value, "type": "time"},
{"name": BuiltInField.source.value, "type": "string"},
{"name": BuiltInField.document_name, "type": "string"},
{"name": BuiltInField.uploader, "type": "string"},
{"name": BuiltInField.upload_date, "type": "time"},
{"name": BuiltInField.last_update_date, "type": "time"},
{"name": BuiltInField.source, "type": "string"},
]
@staticmethod
@ -153,11 +152,11 @@ class MetadataService:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata[BuiltInField.document_name.value] = document.name
doc_metadata[BuiltInField.uploader.value] = document.uploader
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
doc_metadata[BuiltInField.document_name] = document.name
doc_metadata[BuiltInField.uploader] = document.uploader
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
document.doc_metadata = doc_metadata
db.session.add(document)
dataset.built_in_field_enabled = True
@ -183,11 +182,11 @@ class MetadataService:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata.pop(BuiltInField.document_name.value, None)
doc_metadata.pop(BuiltInField.uploader.value, None)
doc_metadata.pop(BuiltInField.upload_date.value, None)
doc_metadata.pop(BuiltInField.last_update_date.value, None)
doc_metadata.pop(BuiltInField.source.value, None)
doc_metadata.pop(BuiltInField.document_name, None)
doc_metadata.pop(BuiltInField.uploader, None)
doc_metadata.pop(BuiltInField.upload_date, None)
doc_metadata.pop(BuiltInField.last_update_date, None)
doc_metadata.pop(BuiltInField.source, None)
document.doc_metadata = doc_metadata
db.session.add(document)
document_ids.append(document.id)
@ -211,11 +210,11 @@ class MetadataService:
for metadata_value in operation.metadata_list:
doc_metadata[metadata_value.name] = metadata_value.value
if dataset.built_in_field_enabled:
doc_metadata[BuiltInField.document_name.value] = document.name
doc_metadata[BuiltInField.uploader.value] = document.uploader
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
doc_metadata[BuiltInField.document_name] = document.name
doc_metadata[BuiltInField.uploader] = document.uploader
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()
@ -237,7 +236,7 @@ class MetadataService:
redis_client.delete(lock_key)
@staticmethod
def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
def knowledge_base_metadata_lock_check(dataset_id: str | None, document_id: str | None):
if dataset_id:
lock_key = f"dataset_metadata_lock_{dataset_id}"
if redis_client.get(lock_key):

View File

@ -1,7 +1,7 @@
import json
import logging
from json import JSONDecodeError
from typing import Optional, Union
from typing import Union
from sqlalchemy import or_, select
@ -211,7 +211,7 @@ class ModelLoadBalancingService:
def get_load_balancing_config(
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
) -> Optional[dict]:
) -> dict | None:
"""
Get load balancing configuration.
:param tenant_id: workspace id
@ -478,7 +478,7 @@ class ModelLoadBalancingService:
model: str,
model_type: str,
credentials: dict,
config_id: Optional[str] = None,
config_id: str | None = None,
):
"""
Validate load balancing credentials.
@ -536,7 +536,7 @@ class ModelLoadBalancingService:
model_type: ModelType,
model: str,
credentials: dict,
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
load_balancing_model_config: LoadBalancingModelConfig | None = None,
validate: bool = True,
):
"""

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
@ -52,7 +51,7 @@ class ModelProviderService:
return provider_configuration
def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
def get_provider_list(self, tenant_id: str, model_type: str | None = None) -> list[ProviderResponse]:
"""
get provider list.
@ -128,9 +127,7 @@ class ModelProviderService:
for model in provider_configurations.get_models(provider=provider)
]
def get_provider_credential(
self, tenant_id: str, provider: str, credential_id: Optional[str] = None
) -> Optional[dict]:
def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
"""
get provider credentials.
@ -216,7 +213,7 @@ class ModelProviderService:
def get_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
) -> Optional[dict]:
) -> dict | None:
"""
Retrieve model-specific credentials.
@ -449,7 +446,7 @@ class ModelProviderService:
return model_schema.parameter_rules if model_schema else []
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> DefaultModelResponse | None:
"""
get default model of model type.
@ -498,7 +495,7 @@ class ModelProviderService:
def get_model_provider_icon(
self, tenant_id: str, provider: str, icon_type: str, lang: str
) -> tuple[Optional[bytes], Optional[str]]:
) -> tuple[bytes | None, str | None]:
"""
get model provider icon.

View File

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
@ -15,7 +15,7 @@ class OpsService:
:param tracing_provider: tracing provider
:return:
"""
trace_config_data: Optional[TraceAppConfig] = (
trace_config_data: TraceAppConfig | None = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
@ -153,7 +153,7 @@ class OpsService:
project_url = None
# check if trace config already exists
trace_config_data: Optional[TraceAppConfig] = (
trace_config_data: TraceAppConfig | None = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()

View File

@ -5,7 +5,7 @@ import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Optional
from typing import Any
from uuid import uuid4
import click
@ -256,7 +256,7 @@ class PluginMigration:
return []
agent_app_model_config_ids = [
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
]
rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
@ -281,7 +281,7 @@ class PluginMigration:
return result
@classmethod
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> str | None:
"""
Fetch plugin unique identifier using plugin id.
"""

View File

@ -1,7 +1,6 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from typing import Optional
from pydantic import BaseModel
@ -46,11 +45,11 @@ class PluginService:
REDIS_TTL = 60 * 5 # 5 minutes
@staticmethod
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
"""
Fetch the latest plugin version
"""
result: dict[str, Optional[PluginService.LatestPluginCache]] = {}
result: dict[str, PluginService.LatestPluginCache | None] = {}
try:
cache_not_exists = []
@ -109,7 +108,7 @@ class PluginService:
raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only")
@staticmethod
def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]):
def _check_plugin_installation_scope(plugin_verification: PluginVerification | None):
"""
Check the plugin installation scope
"""
@ -144,7 +143,7 @@ class PluginService:
return manager.get_debugging_key(tenant_id)
@staticmethod
def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
"""
List the latest versions of the plugins
"""

View File

@ -1,7 +1,6 @@
import json
from os import path
from pathlib import Path
from typing import Optional
from flask import current_app
@ -14,7 +13,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
Retrieval recommended app from buildin, the location is constants/recommended_apps.json
"""
builtin_data: Optional[dict] = None
builtin_data: dict | None = None
def get_type(self) -> str:
return RecommendAppType.BUILDIN
@ -54,7 +53,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return builtin_data.get("recommended_apps", {}).get(language, {})
@classmethod
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict | None:
"""
Fetch recommended app detail from builtin.
:param app_id: App ID

View File

@ -1,5 +1,3 @@
from typing import Optional
from sqlalchemy import select
from constants.languages import languages
@ -72,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
@classmethod
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]:
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None:
"""
Fetch recommended app detail from db.
:param app_id: App ID

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
import requests
@ -36,7 +35,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
return RecommendAppType.REMOTE
@classmethod
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]:
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict | None:
"""
Fetch recommended app detail from dify official.
:param app_id: App ID

View File

@ -1,5 +1,3 @@
from typing import Optional
from configs import dify_config
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
@ -25,7 +23,7 @@ class RecommendedAppService:
return result
@classmethod
def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
def get_recommend_app_detail(cls, app_id: str) -> dict | None:
"""
Get recommend app detail.
:param app_id: app id

View File

@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Union
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
@ -11,7 +11,7 @@ from services.message_service import MessageService
class SavedMessageService:
@classmethod
def pagination_by_last_id(
cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
cls, app_model: App, user: Union[Account, EndUser] | None, last_id: str | None, limit: int
) -> InfiniteScrollPagination:
if not user:
raise ValueError("User is required")
@ -32,7 +32,7 @@ class SavedMessageService:
)
@classmethod
def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
if not user:
return
saved_message = (
@ -62,7 +62,7 @@ class SavedMessageService:
db.session.commit()
@classmethod
def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
if not user:
return
saved_message = (

View File

@ -1,5 +1,4 @@
import uuid
from typing import Optional
from flask_login import current_user
from sqlalchemy import func, select
@ -12,7 +11,7 @@ from models.model import App, Tag, TagBinding
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None):
def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
query = (
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)

View File

@ -3,7 +3,7 @@ import logging
import re
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional
from typing import Any
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
@ -604,7 +604,7 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
@ -665,8 +665,8 @@ class BuiltinToolManageService:
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: Optional[dict] = None,
enable_oauth_custom_client: Optional[bool] = None,
client_params: dict | None = None,
enable_oauth_custom_client: bool | None = None,
):
"""
setup oauth custom client

View File

@ -173,12 +173,15 @@ class MCPToolManageService:
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
if headers is not None:
mcp_provider.encrypted_headers = (
self._prepare_encrypted_headers(headers, tenant_id) if headers else None
)
if headers:
# Build headers preserving unchanged masked values
final_headers = self._merge_headers_with_masked(incoming_headers=headers, mcp_provider=mcp_provider)
encrypted_headers_dict = self._prepare_encrypted_headers(final_headers, tenant_id)
mcp_provider.encrypted_headers = encrypted_headers_dict
else:
# Clear headers if empty dict passed
mcp_provider.encrypted_headers = None
self._session.commit()
except IntegrityError as e:
self._session.rollback()
self._handle_integrity_error(e, name, server_url, server_identifier)
@ -357,7 +360,7 @@ class MCPToolManageService:
def _prepare_auth_headers(self, provider_entity: MCPProviderEntity) -> dict[str, str]:
"""Prepare headers with OAuth token if available."""
headers = provider_entity.headers.copy() if provider_entity.headers else {}
headers = provider_entity.decrypt_headers()
tokens = provider_entity.retrieve_tokens()
if tokens:
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
@ -436,3 +439,25 @@ class MCPToolManageService:
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
def _merge_headers_with_masked(
self, incoming_headers: dict[str, str], mcp_provider: MCPToolProvider
) -> dict[str, str]:
"""Merge incoming headers with existing ones, preserving unchanged masked values.
Args:
incoming_headers: Headers from frontend (may contain masked values)
mcp_provider: The MCP provider instance
Returns:
Final headers dict with proper values (original for unchanged masked, new for changed)
"""
mcp_provider_entity = mcp_provider.to_entity()
existing_decrypted = mcp_provider_entity.decrypt_headers()
existing_masked = mcp_provider_entity.masked_headers()
return {
key: (str(existing_decrypted[key]) if key in existing_masked and value == existing_masked[key] else value)
for key, value in incoming_headers.items()
if key in existing_decrypted or value != existing_masked.get(key)
}

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
@ -10,7 +9,7 @@ logger = logging.getLogger(__name__)
class ToolCommonService:
@staticmethod
def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None):
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None):
"""
list tool providers

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Optional, Union, cast
from typing import Any, Union, cast
from yarl import URL
@ -94,7 +94,7 @@ class ToolTransformService:
def builtin_provider_to_user_provider(
cls,
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
db_provider: Optional[BuiltinToolProvider],
db_provider: BuiltinToolProvider | None,
decrypt_credentials: bool = True,
) -> ToolProviderApiEntity:
"""

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
@ -19,7 +18,7 @@ logger = logging.getLogger(__name__)
class VectorService:
@classmethod
def create_segments_vector(
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents: list[Document] = []
@ -79,7 +78,7 @@ class VectorService:
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
@classmethod
def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
# update segment index task
# format new index

View File

@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -19,11 +19,11 @@ class WebConversationService:
*,
session: Session,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
invoke_from: InvokeFrom,
pinned: Optional[bool] = None,
pinned: bool | None = None,
sort_by="-updated_at",
) -> InfiniteScrollPagination:
if not user:
@ -60,7 +60,7 @@ class WebConversationService:
)
@classmethod
def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
if not user:
return
pinned_conversation = (
@ -92,7 +92,7 @@ class WebConversationService:
db.session.commit()
@classmethod
def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
if not user:
return
pinned_conversation = (

View File

@ -1,7 +1,7 @@
import enum
import secrets
from datetime import UTC, datetime, timedelta
from typing import Any, Optional
from typing import Any
from werkzeug.exceptions import NotFound, Unauthorized
@ -63,7 +63,7 @@ class WebAppAuthService:
@classmethod
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: str = "en-US"
cls, account: Account | None = None, email: str | None = None, language: str = "en-US"
):
email = account.email if account else email
if email is None:
@ -82,7 +82,7 @@ class WebAppAuthService:
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@classmethod
@ -130,7 +130,7 @@ class WebAppAuthService:
@classmethod
def is_app_require_permission_check(
cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None
cls, app_code: str | None = None, app_id: str | None = None, access_mode: str | None = None
) -> bool:
"""
Check if the app requires permission check based on its access mode.

View File

@ -1,7 +1,7 @@
import datetime
import json
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any
import requests
from flask_login import current_user
@ -21,9 +21,9 @@ class CrawlOptions:
limit: int = 1
crawl_sub_pages: bool = False
only_main_content: bool = False
includes: Optional[str] = None
excludes: Optional[str] = None
max_depth: Optional[int] = None
includes: str | None = None
excludes: str | None = None
max_depth: int | None = None
use_sitemap: bool = True
def get_include_paths(self) -> list[str]:

View File

@ -1,5 +1,5 @@
import json
from typing import Any, Optional
from typing import Any
from core.app.app_config.entities import (
DatasetEntity,
@ -65,7 +65,7 @@ class WorkflowConverter:
new_app = App()
new_app.tenant_id = app_model.tenant_id
new_app.name = name or app_model.name + "(workflow)"
new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW
new_app.icon_type = icon_type or app_model.icon_type
new_app.icon = icon or app_model.icon
new_app.icon_background = icon_background or app_model.icon_background
@ -203,7 +203,7 @@ class WorkflowConverter:
app_mode_enum = AppMode.value_of(app_model.mode)
app_config: EasyUIBasedAppConfig
if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT.value
app_model.mode = AppMode.AGENT_CHAT
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
@ -279,7 +279,7 @@ class WorkflowConverter:
"app_id": app_model.id,
"tool_variable": tool_variable,
"inputs": inputs,
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT else "",
},
}
@ -327,7 +327,7 @@ class WorkflowConverter:
def _convert_to_knowledge_retrieval_node(
self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity
) -> Optional[dict]:
) -> dict | None:
"""
Convert datasets to Knowledge Retrieval Node
:param new_app_mode: new app mode
@ -383,7 +383,7 @@ class WorkflowConverter:
graph: dict,
model_config: ModelConfigEntity,
prompt_template: PromptTemplateEntity,
file_upload: Optional[FileUploadConfig] = None,
file_upload: FileUploadConfig | None = None,
external_data_variable_node_mapping: dict[str, str] | None = None,
):
"""
@ -403,7 +403,7 @@ class WorkflowConverter:
)
role_prefix = None
prompts: Optional[Any] = None
prompts: Any | None = None
# Chat Model
if model_config.mode == LLMMode.CHAT.value:
@ -618,7 +618,7 @@ class WorkflowConverter:
:param app_model: App instance
:return: AppMode
"""
if app_model.mode == AppMode.COMPLETION.value:
if app_model.mode == AppMode.COMPLETION:
return AppMode.WORKFLOW
else:
return AppMode.ADVANCED_CHAT

View File

@ -1,6 +1,5 @@
import threading
from collections.abc import Sequence
from typing import Optional
from sqlalchemy.orm import sessionmaker
@ -80,7 +79,7 @@ class WorkflowRunService:
last_id=last_id,
)
def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None:
"""
Get workflow run detail

View File

@ -2,7 +2,7 @@ import json
import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from uuid import uuid4
from sqlalchemy import exists, select
@ -88,7 +88,7 @@ class WorkflowService:
)
return db.session.execute(stmt).scalar_one()
def get_draft_workflow(self, app_model: App, workflow_id: Optional[str] = None) -> Optional[Workflow]:
def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
"""
Get draft workflow
"""
@ -108,7 +108,7 @@ class WorkflowService:
# return draft workflow
return workflow
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
"""
fetch published workflow by workflow_id
"""
@ -130,7 +130,7 @@ class WorkflowService:
)
return workflow
def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
def get_published_workflow(self, app_model: App) -> Workflow | None:
"""
Get published workflow
"""
@ -195,7 +195,7 @@ class WorkflowService:
app_model: App,
graph: dict,
features: dict,
unique_hash: Optional[str],
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
@ -375,13 +375,14 @@ class WorkflowService:
def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None:
"""
Validate that an LLM model configuration can fetch valid credentials.
Validate that an LLM model configuration can fetch valid credentials and has active status.
This method attempts to get the model instance and validates that:
1. The provider exists and is configured
2. The model exists in the provider
3. Credentials can be fetched for the model
4. The credentials pass policy compliance checks
5. The model status is ACTIVE (not NO_CONFIGURE, DISABLED, etc.)
:param tenant_id: The tenant ID
:param provider: The provider name
@ -391,6 +392,7 @@ class WorkflowService:
try:
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
# Get model instance to validate provider+model combination
model_manager = ModelManager()
@ -402,6 +404,22 @@ class WorkflowService:
# via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance()
# If it fails, an exception will be raised
# Additionally, check the model status to ensure it's ACTIVE
provider_manager = ProviderManager()
provider_configurations = provider_manager.get_configurations(tenant_id)
models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM)
target_model = None
for model in models:
if model.model == model_name and model.provider.provider == provider:
target_model = model
break
if target_model:
target_model.raise_for_status()
else:
raise ValueError(f"Model {model_name} not found for provider {provider}")
except Exception as e:
raise ValueError(
f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}"
@ -561,7 +579,7 @@ class WorkflowService:
return default_block_configs
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> dict | None:
"""
Get default config of node.
:param node_type: node type
@ -828,7 +846,7 @@ class WorkflowService:
# chatbot convert to workflow mode
workflow_converter = WorkflowConverter()
if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}:
if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}:
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
# convert to workflow
@ -844,11 +862,11 @@ class WorkflowService:
return new_app
def validate_features_structure(self, app_model: App, features: dict):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
return AdvancedChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
return WorkflowAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
@ -857,7 +875,7 @@ class WorkflowService:
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
) -> Optional[Workflow]:
) -> Workflow | None:
"""
Update workflow attributes