Merge remote-tracking branch 'origin/main' into feat/trigger

yessenia
2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

View File

@ -5,7 +5,7 @@ import secrets
import uuid
from datetime import UTC, datetime, timedelta
from hashlib import sha256
from typing import Any, Optional, cast
from typing import Any, cast
from pydantic import BaseModel
from sqlalchemy import func
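The Optional import is dropped because the rest of this file now uses PEP 604 unions; a minimal illustration of the equivalence (Python 3.10+), shaped like the authenticate signature changed further down:

from typing import Optional

def authenticate_old(email: str, password: str, invite_token: Optional[str] = None): ...
def authenticate_new(email: str, password: str, invite_token: str | None = None): ...
# Optional[str] and str | None describe the same type; the union form needs no typing import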
@ -37,7 +37,6 @@ from services.billing_service import BillingService
from services.errors.account import (
AccountAlreadyInTenantError,
AccountLoginError,
AccountNotFoundError,
AccountNotLinkTenantError,
AccountPasswordError,
AccountRegisterError,
@ -65,7 +64,11 @@ from tasks.mail_owner_transfer_task import (
send_old_owner_transfer_notify_email_task,
send_owner_transfer_confirm_task,
)
from tasks.mail_reset_password_task import send_reset_password_mail_task
from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist
from tasks.mail_reset_password_task import (
send_reset_password_mail_task,
send_reset_password_mail_task_when_account_not_exist,
)
logger = logging.getLogger(__name__)
@ -82,8 +85,9 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
prefix="email_code_login_rate_limit", max_attempts=3, time_window=300 * 1
)
email_code_account_deletion_rate_limiter = RateLimiter(
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
@ -95,6 +99,7 @@ class AccountService:
FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5
CHANGE_EMAIL_MAX_ERROR_LIMITS = 5
OWNER_TRANSFER_MAX_ERROR_LIMITS = 5
EMAIL_REGISTER_MAX_ERROR_LIMITS = 5
@staticmethod
def _get_refresh_token_key(refresh_token: str) -> str:
@ -105,14 +110,14 @@ class AccountService:
return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"
@staticmethod
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
def _store_refresh_token(refresh_token: str, account_id: str):
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
redis_client.setex(
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
)
@staticmethod
def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
def _delete_refresh_token(refresh_token: str, account_id: str):
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
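Writing the token under two keys gives a lookup in both directions: token to account id when a refresh request arrives, and account id to current token so logout can revoke it; a minimal sketch of both lookups, reusing the private helpers above (refresh_token and account as in the surrounding methods):

# resolve an incoming refresh token to its account (refresh path)
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
# find and revoke the account's current token (logout path)
current = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
if current:
    AccountService._delete_refresh_token(current.decode("utf-8"), account.id)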
@ -145,7 +150,10 @@ class AccountService:
if naive_utc_now() - account.last_active_at > timedelta(minutes=10):
account.last_active_at = naive_utc_now()
db.session.commit()
# NOTE: make sure account is accessible outside of a db session
# This ensures that it will work correctly after upgrading to Flask version 3.1.2
db.session.refresh(account)
db.session.close()
return account
@staticmethod
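The refresh-before-close pair keeps the ORM instance usable after the session is gone; a minimal sketch of why, assuming SQLAlchemy's default expire_on_commit=True:

db.session.commit()          # expire_on_commit=True marks loaded attributes as expired
db.session.refresh(account)  # re-load attributes while the session is still open
db.session.close()           # detach the instance from the session
_ = account.email            # still readable; without refresh() this access could raise DetachedInstanceError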
@ -163,12 +171,12 @@ class AccountService:
return token
@staticmethod
def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
"""authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first()
if not account:
raise AccountNotFoundError()
raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.BANNED.value:
raise AccountLoginError("Account is banned.")
@ -211,6 +219,7 @@ class AccountService:
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.add(account)
db.session.commit()
return account
@ -219,9 +228,9 @@ class AccountService:
email: str,
name: str,
interface_language: str,
password: Optional[str] = None,
password: str | None = None,
interface_theme: str = "light",
is_setup: Optional[bool] = False,
is_setup: bool | None = False,
) -> Account:
"""create account"""
if not FeatureService.get_system_features().is_allow_register and not is_setup:
@ -242,6 +251,8 @@ class AccountService:
account.name = name
if password:
valid_password(password)
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
@ -265,7 +276,7 @@ class AccountService:
@staticmethod
def create_account_and_tenant(
email: str, name: str, interface_language: str, password: Optional[str] = None
email: str, name: str, interface_language: str, password: str | None = None
) -> Account:
"""create account"""
account = AccountService.create_account(
@ -290,7 +301,9 @@ class AccountService:
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError
raise EmailCodeAccountDeletionRateLimitExceededError()
raise EmailCodeAccountDeletionRateLimitExceededError(
int(cls.email_code_account_deletion_rate_limiter.time_window / 60)
)
send_account_deletion_verification_code.delay(to=email, code=code)
@ -308,16 +321,16 @@ class AccountService:
return True
@staticmethod
def delete_account(account: Account) -> None:
def delete_account(account: Account):
"""Delete account. This method only adds a task to the queue for deletion."""
delete_account_task.delay(account.id)
@staticmethod
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
def link_account_integrate(provider: str, open_id: str, account: Account):
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = (
account_integrate: AccountIntegrate | None = (
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
)
@ -340,7 +353,7 @@ class AccountService:
raise LinkAccountIntegrateError("Failed to link account.") from e
@staticmethod
def close_account(account: Account) -> None:
def close_account(account: Account):
"""Close account"""
account.status = AccountStatus.CLOSED.value
db.session.commit()
@ -348,6 +361,7 @@ class AccountService:
@staticmethod
def update_account(account, **kwargs):
"""Update account fields"""
account = db.session.merge(account)
for field, value in kwargs.items():
if hasattr(account, field):
setattr(account, field, value)
@ -369,7 +383,7 @@ class AccountService:
return account
@staticmethod
def update_login_info(account: Account, *, ip_address: str) -> None:
def update_login_info(account: Account, *, ip_address: str):
"""Update last login time and ip"""
account.last_login_at = naive_utc_now()
account.last_login_ip = ip_address
@ -377,7 +391,7 @@ class AccountService:
db.session.commit()
@staticmethod
def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
def login(account: Account, *, ip_address: str | None = None) -> TokenPair:
if ip_address:
AccountService.update_login_info(account=account, ip_address=ip_address)
@ -393,7 +407,7 @@ class AccountService:
return TokenPair(access_token=access_token, refresh_token=refresh_token)
@staticmethod
def logout(*, account: Account) -> None:
def logout(*, account: Account):
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
if refresh_token:
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
@ -425,9 +439,10 @@ class AccountService:
@classmethod
def send_reset_password_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
is_allow_register: bool = False,
):
account_email = account.email if account else email
if account_email is None:
@ -436,26 +451,67 @@ class AccountService:
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import PasswordResetRateLimitExceededError
raise PasswordResetRateLimitExceededError()
raise PasswordResetRateLimitExceededError(int(cls.reset_password_rate_limiter.time_window / 60))
code, token = cls.generate_reset_password_token(account_email, account)
send_reset_password_mail_task.delay(
language=language,
to=account_email,
code=code,
)
if account:
send_reset_password_mail_task.delay(
language=language,
to=account_email,
code=code,
)
else:
send_reset_password_mail_task_when_account_not_exist.delay(
language=language,
to=account_email,
is_allow_register=is_allow_register,
)
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
return token
@classmethod
def send_email_register_email(
cls,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
):
account_email = account.email if account else email
if account_email is None:
raise ValueError("Email must be provided.")
if cls.email_register_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailRegisterRateLimitExceededError
raise EmailRegisterRateLimitExceededError(int(cls.email_register_rate_limiter.time_window / 60))
code, token = cls.generate_email_register_token(account_email)
if account:
send_email_register_mail_task_when_account_exist.delay(
language=language,
to=account_email,
account_name=account.name,
)
else:
send_email_register_mail_task.delay(
language=language,
to=account_email,
code=code,
)
cls.email_register_rate_limiter.increment_rate_limit(account_email)
return token
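The register-mail flow mirrors reset-password: existing accounts get a notification, new addresses get a code, and the returned token carries that code; a minimal usage sketch, assuming the token payload exposes the code as set in generate_email_register_token below:

token = AccountService.send_email_register_email(email="new.user@example.com", language="en-US")
# later, when the user submits the 6-digit code from the mail:
submitted_code = "123456"  # value entered by the user (illustrative)
data = AccountService.get_email_register_data(token)
if data is None or data.get("code") != submitted_code:
    AccountService.add_email_register_error_rate_limit("new.user@example.com")
else:
    AccountService.revoke_email_register_token(token)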
@classmethod
def send_change_email_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
old_email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
old_email: str | None = None,
language: str = "en-US",
phase: Optional[str] = None,
phase: str | None = None,
):
account_email = account.email if account else email
if account_email is None:
@ -466,7 +522,7 @@ class AccountService:
if cls.change_email_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailChangeRateLimitExceededError
raise EmailChangeRateLimitExceededError()
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
@ -482,8 +538,8 @@ class AccountService:
@classmethod
def send_change_email_completed_notify_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
):
account_email = account.email if account else email
@ -498,10 +554,10 @@ class AccountService:
@classmethod
def send_owner_transfer_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
workspace_name: Optional[str] = "",
workspace_name: str | None = "",
):
account_email = account.email if account else email
if account_email is None:
@ -510,7 +566,7 @@ class AccountService:
if cls.owner_transfer_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import OwnerTransferRateLimitExceededError
raise OwnerTransferRateLimitExceededError()
raise OwnerTransferRateLimitExceededError(int(cls.owner_transfer_rate_limiter.time_window / 60))
code, token = cls.generate_owner_transfer_token(account_email, account)
workspace_name = workspace_name or ""
@ -527,10 +583,10 @@ class AccountService:
@classmethod
def send_old_owner_transfer_notify_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
workspace_name: Optional[str] = "",
workspace_name: str | None = "",
new_owner_email: str = "",
):
account_email = account.email if account else email
@ -548,10 +604,10 @@ class AccountService:
@classmethod
def send_new_owner_transfer_notify_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
workspace_name: Optional[str] = "",
workspace_name: str | None = "",
):
account_email = account.email if account else email
if account_email is None:
@ -568,8 +624,8 @@ class AccountService:
def generate_reset_password_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
account: Account | None = None,
code: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@ -580,13 +636,26 @@ class AccountService:
)
return code, token
@classmethod
def generate_email_register_token(
cls,
email: str,
code: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
additional_data["code"] = code
token = TokenManager.generate_token(email=email, token_type="email_register", additional_data=additional_data)
return code, token
@classmethod
def generate_change_email_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
old_email: Optional[str] = None,
account: Account | None = None,
code: str | None = None,
old_email: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@ -602,8 +671,8 @@ class AccountService:
def generate_owner_transfer_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
account: Account | None = None,
code: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@ -618,6 +687,10 @@ class AccountService:
def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, "reset_password")
@classmethod
def revoke_email_register_token(cls, token: str):
TokenManager.revoke_token(token, "email_register")
@classmethod
def revoke_change_email_token(cls, token: str):
TokenManager.revoke_token(token, "change_email")
@ -627,22 +700,26 @@ class AccountService:
TokenManager.revoke_token(token, "owner_transfer")
@classmethod
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_reset_password_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "reset_password")
@classmethod
def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_email_register_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_register")
@classmethod
def get_change_email_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "change_email")
@classmethod
def get_owner_transfer_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_owner_transfer_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "owner_transfer")
@classmethod
def send_email_code_login_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
):
email = account.email if account else email
@ -651,7 +728,7 @@ class AccountService:
if cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
raise EmailCodeLoginRateLimitExceededError()
raise EmailCodeLoginRateLimitExceededError(int(cls.email_code_login_rate_limiter.time_window / 60))
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
token = TokenManager.generate_token(
@ -666,7 +743,7 @@ class AccountService:
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@classmethod
@ -700,7 +777,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_login_error_rate_limit(email: str) -> None:
def add_login_error_rate_limit(email: str):
key = f"login_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@ -729,7 +806,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_forgot_password_error_rate_limit(email: str) -> None:
def add_forgot_password_error_rate_limit(email: str):
key = f"forgot_password_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@ -737,6 +814,16 @@ class AccountService:
count = int(count) + 1
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
@staticmethod
@redis_fallback(default_return=None)
def add_email_register_error_rate_limit(email: str) -> None:
key = f"email_register_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
count = 0
count = int(count) + 1
redis_client.setex(key, dify_config.EMAIL_REGISTER_LOCKOUT_DURATION, count)
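This counter pairs with is_email_register_error_rate_limit and reset_email_register_error_rate_limit below, mirroring the forgot-password counter: bump on each wrong code, block once EMAIL_REGISTER_MAX_ERROR_LIMITS is exceeded, clear on success; a minimal sketch of the intended call order (TooManyAttemptsError is a hypothetical placeholder for the real controller-level error):

if AccountService.is_email_register_error_rate_limit(email):
    raise TooManyAttemptsError()  # hypothetical; the real error type lives in the controller layer
if submitted_code != expected_code:
    AccountService.add_email_register_error_rate_limit(email)
else:
    AccountService.reset_email_register_error_rate_limit(email)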
@staticmethod
@redis_fallback(default_return=False)
def is_forgot_password_error_rate_limit(email: str) -> bool:
@ -756,9 +843,27 @@ class AccountService:
key = f"forgot_password_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
@redis_fallback(default_return=False)
def is_email_register_error_rate_limit(email: str) -> bool:
key = f"email_register_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
return False
count = int(count)
if count > AccountService.EMAIL_REGISTER_MAX_ERROR_LIMITS:
return True
return False
@staticmethod
@redis_fallback(default_return=None)
def add_change_email_error_rate_limit(email: str) -> None:
def reset_email_register_error_rate_limit(email: str):
key = f"email_register_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
@redis_fallback(default_return=None)
def add_change_email_error_rate_limit(email: str):
key = f"change_email_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@ -786,7 +891,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_owner_transfer_error_rate_limit(email: str) -> None:
def add_owner_transfer_error_rate_limit(email: str):
key = f"owner_transfer_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@ -860,7 +965,7 @@ class AccountService:
class TenantService:
@staticmethod
def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant:
def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant:
"""Create tenant"""
if (
not FeatureService.get_system_features().is_allow_create_workspace
@ -891,9 +996,7 @@ class TenantService:
return tenant
@staticmethod
def create_owner_tenant_if_not_exist(
account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False
):
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
"""Check if user have a workspace or not"""
available_ta = (
db.session.query(TenantAccountJoin)
@ -938,6 +1041,8 @@ class TenantService:
db.session.add(ta)
db.session.commit()
if dify_config.BILLING_ENABLED:
BillingService.clean_billing_info_cache(tenant.id)
return ta
@staticmethod
@ -965,7 +1070,7 @@ class TenantService:
return tenant
@staticmethod
def switch_tenant(account: Account, tenant_id: Optional[str] = None) -> None:
def switch_tenant(account: Account, tenant_id: str | None = None):
"""Switch the current workspace for the account"""
# Ensure tenant_id is provided
@ -1047,7 +1152,7 @@ class TenantService:
)
@staticmethod
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountRole]:
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
"""Get the role of the current account for a given tenant"""
join = (
db.session.query(TenantAccountJoin)
@ -1062,7 +1167,7 @@ class TenantService:
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
@staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
"""Check member permission"""
perms = {
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
@ -1082,7 +1187,7 @@ class TenantService:
raise NoPermissionError(f"No permission to {action} member.")
@staticmethod
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account):
"""Remove member from tenant"""
if operator.id == account.id:
raise CannotOperateSelfError("Cannot operate self.")
@ -1096,8 +1201,11 @@ class TenantService:
db.session.delete(ta)
db.session.commit()
if dify_config.BILLING_ENABLED:
BillingService.clean_billing_info_cache(tenant.id)
@staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update")
@ -1124,7 +1232,7 @@ class TenantService:
db.session.commit()
@staticmethod
def get_custom_config(tenant_id: str) -> dict:
def get_custom_config(tenant_id: str):
tenant = db.get_or_404(Tenant, tenant_id)
return tenant.custom_config_dict
@ -1145,7 +1253,7 @@ class RegisterService:
return f"member_invite:token:{token}"
@classmethod
def setup(cls, email: str, name: str, password: str, ip_address: str) -> None:
def setup(cls, email: str, name: str, password: str, ip_address: str):
"""
Setup dify
@ -1187,13 +1295,13 @@ class RegisterService:
cls,
email,
name,
password: Optional[str] = None,
open_id: Optional[str] = None,
provider: Optional[str] = None,
language: Optional[str] = None,
status: Optional[AccountStatus] = None,
is_setup: Optional[bool] = False,
create_workspace_required: Optional[bool] = True,
password: str | None = None,
open_id: str | None = None,
provider: str | None = None,
language: str | None = None,
status: AccountStatus | None = None,
is_setup: bool | None = False,
create_workspace_required: bool | None = True,
) -> Account:
db.session.begin_nested()
"""Register account"""
@ -1310,10 +1418,8 @@ class RegisterService:
redis_client.delete(cls._get_invitation_token_key(token))
@classmethod
def get_invitation_if_token_valid(
cls, workspace_id: Optional[str], email: str, token: str
) -> Optional[dict[str, Any]]:
invitation_data = cls._get_invitation_by_token(token, workspace_id, email)
def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None:
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None
@ -1350,9 +1456,9 @@ class RegisterService:
}
@classmethod
def _get_invitation_by_token(
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
) -> Optional[dict[str, str]]:
def get_invitation_by_token(
cls, token: str, workspace_id: str | None = None, email: str | None = None
) -> dict[str, str] | None:
if workspace_id is not None and email is not None:
email_hash = sha256(email.encode()).hexdigest()
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"

View File

@ -17,7 +17,7 @@ from models.model import AppMode
class AdvancedPromptTemplateService:
@classmethod
def get_prompt(cls, args: dict) -> dict:
def get_prompt(cls, args: dict):
app_mode = args["app_mode"]
model_mode = args["model_mode"]
model_name = args["model_name"]
@ -29,17 +29,17 @@ class AdvancedPromptTemplateService:
return cls.get_common_prompt(app_mode, model_mode, has_context)
@classmethod
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
context_prompt = copy.deepcopy(CONTEXT)
if app_mode == AppMode.CHAT.value:
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
)
elif model_mode == "chat":
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == AppMode.COMPLETION.value:
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
@ -52,7 +52,7 @@ class AdvancedPromptTemplateService:
return {}
@classmethod
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str):
if has_context == "true":
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
@ -61,7 +61,7 @@ class AdvancedPromptTemplateService:
return prompt_template
@classmethod
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str):
if has_context == "true":
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
@ -70,10 +70,10 @@ class AdvancedPromptTemplateService:
return prompt_template
@classmethod
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
if app_mode == AppMode.CHAT.value:
if app_mode == AppMode.CHAT:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
@ -82,7 +82,7 @@ class AdvancedPromptTemplateService:
return cls.get_chat_prompt(
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
)
elif app_mode == AppMode.COMPLETION.value:
elif app_mode == AppMode.COMPLETION:
if model_mode == "completion":
return cls.get_completion_prompt(
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),

View File

@ -1,8 +1,7 @@
import threading
from typing import Optional
from typing import Any
import pytz
from flask_login import current_user
import contexts
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
@ -10,13 +9,14 @@ from core.plugin.impl.agent import PluginAgentClient
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
class AgentService:
@classmethod
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str):
"""
Service to get agent logs
"""
@ -35,7 +35,7 @@ class AgentService:
if not conversation:
raise ValueError(f"Conversation not found: {conversation_id}")
message: Optional[Message] = (
message: Message | None = (
db.session.query(Message)
.where(
Message.id == message_id,
@ -61,14 +61,15 @@ class AgentService:
executor = executor.name
else:
executor = "Unknown"
assert isinstance(current_user, Account)
assert current_user.timezone is not None
timezone = pytz.timezone(current_user.timezone)
app_model_config = app_model.app_model_config
if not app_model_config:
raise ValueError("App model config not found")
result = {
result: dict[str, Any] = {
"meta": {
"status": "success",
"executor": executor,

View File

@ -1,8 +1,6 @@
import uuid
from typing import Optional
import pandas as pd
from flask_login import current_user
from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@ -10,6 +8,8 @@ from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account
from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation
from services.feature_service import FeatureService
from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task
@ -24,6 +24,7 @@ class AppAnnotationService:
@classmethod
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
assert isinstance(current_user, Account)
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -40,7 +41,7 @@ class AppAnnotationService:
if not message:
raise NotFound("Message Not Exists.")
annotation: Optional[MessageAnnotation] = message.annotation
annotation: MessageAnnotation | None = message.annotation
# save the message annotation
if annotation:
annotation.content = args["answer"]
@ -62,6 +63,7 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled, add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
assert current_user.current_tenant_id is not None
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
@ -73,7 +75,7 @@ class AppAnnotationService:
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
def enable_app_annotation(cls, args: dict, app_id: str):
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@ -84,6 +86,8 @@ class AppAnnotationService:
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
# send batch add segments task
redis_client.setnx(enable_app_annotation_job_key, "waiting")
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
enable_annotation_reply_task.delay(
str(job_id),
app_id,
@ -96,7 +100,9 @@ class AppAnnotationService:
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def disable_app_annotation(cls, app_id: str) -> dict:
def disable_app_annotation(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key)
if cache_result is not None:
@ -113,6 +119,8 @@ class AppAnnotationService:
@classmethod
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -145,6 +153,8 @@ class AppAnnotationService:
@classmethod
def export_annotation_list_by_app_id(cls, app_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -164,6 +174,8 @@ class AppAnnotationService:
@classmethod
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -193,6 +205,8 @@ class AppAnnotationService:
@classmethod
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -230,6 +244,8 @@ class AppAnnotationService:
@classmethod
def delete_app_annotation(cls, app_id: str, annotation_id: str):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -246,11 +262,9 @@ class AppAnnotationService:
db.session.delete(annotation)
annotation_hit_histories = (
db.session.query(AppAnnotationHitHistory)
.where(AppAnnotationHitHistory.annotation_id == annotation_id)
.all()
)
annotation_hit_histories = db.session.scalars(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
).all()
if annotation_hit_histories:
for annotation_hit_history in annotation_hit_histories:
db.session.delete(annotation_hit_history)
@ -269,6 +283,8 @@ class AppAnnotationService:
@classmethod
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -315,8 +331,10 @@ class AppAnnotationService:
return {"deleted_count": deleted_count}
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
def batch_import_app_annotations(cls, app_id, file: FileStorage):
# get app info
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
@ -328,9 +346,9 @@ class AppAnnotationService:
try:
# Skip the first row
df = pd.read_csv(file, dtype=str)
df = pd.read_csv(file.stream, dtype=str)
result = []
for index, row in df.iterrows():
for _, row in df.iterrows():
content = {"question": row.iloc[0], "answer": row.iloc[1]}
result.append(content)
if len(result) == 0:
@ -355,6 +373,8 @@ class AppAnnotationService:
@classmethod
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info
app = (
db.session.query(App)
@ -425,6 +445,8 @@ class AppAnnotationService:
@classmethod
def get_app_annotation_setting_by_app_id(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info
app = (
db.session.query(App)
@ -438,19 +460,29 @@ class AppAnnotationService:
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
if collection_binding_detail:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
else:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {},
}
return {"enabled": False}
@classmethod
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get app info
app = (
db.session.query(App)
@ -479,18 +511,28 @@ class AppAnnotationService:
collection_binding_detail = annotation_setting.collection_binding_detail
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
if collection_binding_detail:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {
"embedding_provider_name": collection_binding_detail.provider_name,
"embedding_model_name": collection_binding_detail.model_name,
},
}
else:
return {
"id": annotation_setting.id,
"enabled": True,
"score_threshold": annotation_setting.score_threshold,
"embedding_model": {},
}
@classmethod
def clear_all_annotations(cls, app_id: str) -> dict:
def clear_all_annotations(cls, app_id: str):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")

View File

@ -30,7 +30,7 @@ class APIBasedExtensionService:
return extension_data
@staticmethod
def delete(extension_data: APIBasedExtension) -> None:
def delete(extension_data: APIBasedExtension):
db.session.delete(extension_data)
db.session.commit()
@ -51,7 +51,7 @@ class APIBasedExtensionService:
return extension
@classmethod
def _validation(cls, extension_data: APIBasedExtension) -> None:
def _validation(cls, extension_data: APIBasedExtension):
# name
if not extension_data.name:
raise ValueError("name must not be empty")
@ -95,7 +95,7 @@ class APIBasedExtensionService:
cls._ping_connection(extension_data)
@staticmethod
def _ping_connection(extension_data: APIBasedExtension) -> None:
def _ping_connection(extension_data: APIBasedExtension):
try:
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
resp = client.request(point=APIBasedExtensionPoint.PING, params={})

View File

@ -4,7 +4,6 @@ import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Optional
from urllib.parse import urlparse
from uuid import uuid4
@ -17,10 +16,11 @@ from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.helper import ssrf_proxy
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import PluginDependency
from core.workflow.nodes.enums import NodeType
from core.workflow.enums import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
@ -61,8 +61,8 @@ class ImportStatus(StrEnum):
class Import(BaseModel):
id: str
status: ImportStatus
app_id: Optional[str] = None
app_mode: Optional[str] = None
app_id: str | None = None
app_mode: str | None = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
error: str = ""
@ -99,17 +99,17 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus:
class PendingData(BaseModel):
import_mode: str
yaml_content: str
name: str | None
description: str | None
icon_type: str | None
icon: str | None
icon_background: str | None
app_id: str | None
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
app_id: str | None = None
class CheckDependenciesPendingData(BaseModel):
dependencies: list[PluginDependency]
app_id: str | None
app_id: str | None = None
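The added "= None" defaults matter under Pydantic v2: a field annotated "str | None" with no default is still required, so PendingData previously had to be constructed with every field passed explicitly; a minimal illustration:

from pydantic import BaseModel

class Example(BaseModel):
    name: str | None                # required: may be None, but must be provided
    description: str | None = None  # optional: defaults to None

Example(name=None)  # ok
# Example()         # raises a validation error: 'name' is a required field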
class AppDslService:
@ -121,14 +121,14 @@ class AppDslService:
*,
account: Account,
import_mode: str,
yaml_content: Optional[str] = None,
yaml_url: Optional[str] = None,
name: Optional[str] = None,
description: Optional[str] = None,
icon_type: Optional[str] = None,
icon: Optional[str] = None,
icon_background: Optional[str] = None,
app_id: Optional[str] = None,
yaml_content: str | None = None,
yaml_url: str | None = None,
name: str | None = None,
description: str | None = None,
icon_type: str | None = None,
icon: str | None = None,
icon_background: str | None = None,
app_id: str | None = None,
) -> Import:
"""Import an app from YAML content or URL."""
import_id = str(uuid.uuid4())
@ -407,15 +407,15 @@ class AppDslService:
def _create_or_update_app(
self,
*,
app: Optional[App],
app: App | None,
data: dict,
account: Account,
name: Optional[str] = None,
description: Optional[str] = None,
icon_type: Optional[str] = None,
icon: Optional[str] = None,
icon_background: Optional[str] = None,
dependencies: Optional[list[PluginDependency]] = None,
name: str | None = None,
description: str | None = None,
icon_type: str | None = None,
icon: str | None = None,
icon_background: str | None = None,
dependencies: list[PluginDependency] | None = None,
) -> App:
"""Create a new app or update an existing one."""
app_data = data.get("app", {})
@ -533,7 +533,7 @@ class AppDslService:
return app
@classmethod
def export_dsl(cls, app_model: App, include_secret: bool = False) -> str:
def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: str | None = None) -> str:
"""
Export app
:param app_model: App instance
@ -557,7 +557,7 @@ class AppDslService:
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
cls._append_workflow_export_data(
export_data=export_data, app_model=app_model, include_secret=include_secret
export_data=export_data, app_model=app_model, include_secret=include_secret, workflow_id=workflow_id
)
else:
cls._append_model_config_export_data(export_data, app_model)
@ -565,14 +565,16 @@ class AppDslService:
return yaml.dump(export_data, allow_unicode=True) # type: ignore
@classmethod
def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:
def _append_workflow_export_data(
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None
):
"""
Append workflow export data
:param export_data: export data
:param app_model: App instance
"""
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model)
workflow = workflow_service.get_draft_workflow(app_model, workflow_id)
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")
@ -614,7 +616,7 @@ class AppDslService:
]
@classmethod
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
def _append_model_config_export_data(cls, export_data: dict, app_model: App):
"""
Append model config export data
:param export_data: export data
@ -792,7 +794,10 @@ class AppDslService:
@classmethod
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
"""Encrypt dataset_id using AES-CBC mode"""
"""Encrypt dataset_id using AES-CBC mode or return plain text based on configuration"""
if not dify_config.DSL_EXPORT_ENCRYPT_DATASET_ID:
return dataset_id
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
@ -801,12 +806,34 @@ class AppDslService:
@classmethod
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
"""AES decryption"""
"""AES decryption with fallback to plain text UUID"""
# First, check if it's already a plain UUID (not encrypted)
if cls._is_valid_uuid(encrypted_data):
return encrypted_data
# If it's not a UUID, try to decrypt it
try:
key = cls._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
return pt.decode()
decrypted_text = pt.decode()
# Validate that the decrypted result is a valid UUID
if cls._is_valid_uuid(decrypted_text):
return decrypted_text
else:
# If decrypted result is not a valid UUID, it's probably not our encrypted data
return None
except Exception:
# If decryption fails completely, return None
return None
@staticmethod
def _is_valid_uuid(value: str) -> bool:
"""Check if string is a valid UUID format"""
try:
uuid.UUID(value)
return True
except (ValueError, TypeError):
return False
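decrypt_dataset_id now accepts either form: a plain UUID (exports made with DSL_EXPORT_ENCRYPT_DATASET_ID disabled) passes straight through, anything else is treated as AES-CBC ciphertext and only returned when it decrypts to a valid UUID; a minimal roundtrip sketch using the class's own helpers (tenant_id is illustrative):

import uuid

tenant_id = str(uuid.uuid4())
dataset_id = str(uuid.uuid4())
exported = AppDslService.encrypt_dataset_id(dataset_id, tenant_id)  # ciphertext, or the plain id when encryption is off
assert AppDslService.decrypt_dataset_id(exported, tenant_id) == dataset_id
assert AppDslService.decrypt_dataset_id("not-a-uuid-or-ciphertext", tenant_id) is None  # garbage comes back as None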

View File

@ -1,6 +1,6 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union
from typing import Any, Union
from openai._exceptions import RateLimitError
@ -60,7 +60,7 @@ class AppGenerateService:
request_id = RateLimit.gen_request_key()
try:
request_id = rate_limit.enter(request_id)
if app_model.mode == AppMode.COMPLETION.value:
if app_model.mode == AppMode.COMPLETION:
return rate_limit.generate(
CompletionAppGenerator.convert_to_event_stream(
CompletionAppGenerator().generate(
@ -69,7 +69,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
return rate_limit.generate(
AgentChatAppGenerator.convert_to_event_stream(
AgentChatAppGenerator().generate(
@ -78,7 +78,7 @@ class AppGenerateService:
),
request_id,
)
elif app_model.mode == AppMode.CHAT.value:
elif app_model.mode == AppMode.CHAT:
return rate_limit.generate(
ChatAppGenerator.convert_to_event_stream(
ChatAppGenerator().generate(
@ -87,7 +87,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
elif app_model.mode == AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
@ -103,7 +103,7 @@ class AppGenerateService:
),
request_id=request_id,
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
return rate_limit.generate(
@ -116,7 +116,6 @@ class AppGenerateService:
invoke_from=invoke_from,
streaming=streaming,
call_depth=0,
workflow_thread_pool_id=None,
),
),
request_id,
@ -155,14 +154,14 @@ class AppGenerateService:
@classmethod
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
@ -174,14 +173,14 @@ class AppGenerateService:
@classmethod
def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
AdvancedChatAppGenerator().single_loop_generate(
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
@ -214,7 +213,7 @@ class AppGenerateService:
)
@classmethod
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow:
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: str | None = None) -> Workflow:
"""
Get workflow
:param app_model: app model
@ -227,7 +226,7 @@ class AppGenerateService:
# If workflow_id is specified, get the specific workflow version
if workflow_id:
try:
workflow_uuid = uuid.UUID(workflow_id)
_ = uuid.UUID(workflow_id)
except ValueError:
raise WorkflowIdFormatError(f"Invalid workflow_id format: '{workflow_id}'. ")
workflow = workflow_service.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id)

View File

@ -6,7 +6,7 @@ from models.model import AppMode
class AppModelConfigService:
@classmethod
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
if app_mode == AppMode.CHAT:
return ChatAppConfigManager.config_validate(tenant_id, config)
elif app_mode == AppMode.AGENT_CHAT:

View File

@ -1,8 +1,7 @@
import json
import logging
from typing import Optional, TypedDict, cast
from typing import TypedDict, cast
from flask_login import current_user
from flask_sqlalchemy.pagination import Pagination
from configs import dify_config
@ -17,9 +16,11 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models.account import Account
from models.model import App, AppMode, AppModelConfig, Site
from models.tools import ApiToolProvider
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.tag_service import TagService
@ -40,15 +41,15 @@ class AppService:
filters = [App.tenant_id == tenant_id, App.is_universal == False]
if args["mode"] == "workflow":
filters.append(App.mode == AppMode.WORKFLOW.value)
filters.append(App.mode == AppMode.WORKFLOW)
elif args["mode"] == "completion":
filters.append(App.mode == AppMode.COMPLETION.value)
filters.append(App.mode == AppMode.COMPLETION)
elif args["mode"] == "chat":
filters.append(App.mode == AppMode.CHAT.value)
filters.append(App.mode == AppMode.CHAT)
elif args["mode"] == "advanced-chat":
filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
filters.append(App.mode == AppMode.ADVANCED_CHAT)
elif args["mode"] == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT.value)
filters.append(App.mode == AppMode.AGENT_CHAT)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
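Dropping ".value" in these filters works because AppMode is a str-backed enum, so a member compares equal to the string stored in the column; a minimal sketch, assuming AppMode subclasses StrEnum as elsewhere in this codebase:

from enum import StrEnum

class AppMode(StrEnum):  # simplified stand-in for models.model.AppMode
    CHAT = "chat"
    WORKFLOW = "workflow"

assert AppMode.CHAT == "chat"             # member == raw string value
assert str(AppMode.WORKFLOW) == "workflow"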
@ -96,7 +97,7 @@ class AppService:
)
except (ProviderTokenNotInitError, LLMBadRequestError):
model_instance = None
except Exception as e:
except Exception:
logger.exception("Get default model instance failed, tenant_id: %s", tenant_id)
model_instance = None
@ -162,15 +163,22 @@ class AppService:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(app.id, "private")
if dify_config.BILLING_ENABLED:
BillingService.clean_billing_info_cache(app.tenant_id)
return app
def get_app(self, app: App) -> App:
"""
Get App
"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# get original app model config
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
if app.mode == AppMode.AGENT_CHAT or app.is_agent:
model_config = app.app_model_config
if not model_config:
return app
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get("tools") or []:
@ -201,11 +209,12 @@ class AppService:
# override tool parameters
tool["tool_parameters"] = masked_parameter
except Exception as e:
except Exception:
pass
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
if model_config:
model_config.agent_mode = json.dumps(agent_mode)
class ModifiedApp(App):
"""
@ -239,6 +248,7 @@ class AppService:
:param args: request args
:return: App instance
"""
assert current_user is not None
app.name = args["name"]
app.description = args["description"]
app.icon_type = args["icon_type"]
@ -259,6 +269,7 @@ class AppService:
:param name: new name
:return: App instance
"""
assert current_user is not None
app.name = name
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
@ -274,6 +285,7 @@ class AppService:
:param icon_background: new icon_background
:return: App instance
"""
assert current_user is not None
app.icon = icon
app.icon_background = icon_background
app.updated_by = current_user.id
@ -291,7 +303,7 @@ class AppService:
"""
if enable_site == app.enable_site:
return app
assert current_user is not None
app.enable_site = enable_site
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
@ -308,6 +320,7 @@ class AppService:
"""
if enable_api == app.enable_api:
return app
assert current_user is not None
app.enable_api = enable_api
app.updated_by = current_user.id
@ -316,7 +329,7 @@ class AppService:
return app
def delete_app(self, app: App) -> None:
def delete_app(self, app: App):
"""
Delete app
:param app: App instance
@ -328,10 +341,13 @@ class AppService:
if FeatureService.get_system_features().webapp_auth.enabled:
EnterpriseService.WebAppAuth.cleanup_webapp(app.id)
if dify_config.BILLING_ENABLED:
BillingService.clean_billing_info_cache(app.tenant_id)
# Trigger asynchronous deletion of app and related data
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
def get_app_meta(self, app_model: App) -> dict:
def get_app_meta(self, app_model: App):
"""
Get app meta info
:param app_model: app model
@ -361,7 +377,7 @@ class AppService:
}
)
else:
app_model_config: Optional[AppModelConfig] = app_model.app_model_config
app_model_config: AppModelConfig | None = app_model.app_model_config
if not app_model_config:
return meta
@ -384,7 +400,7 @@ class AppService:
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
elif provider_type == "api":
try:
provider: Optional[ApiToolProvider] = (
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first()
)
if provider is None:

View File

@ -2,7 +2,6 @@ import io
import logging
import uuid
from collections.abc import Generator
from typing import Optional
from flask import Response, stream_with_context
from werkzeug.datastructures import FileStorage
@ -12,7 +11,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.enums import MessageStatus
from models.model import App, AppMode, AppModelConfig, Message
from models.model import App, AppMode, Message
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
@ -30,8 +29,8 @@ logger = logging.getLogger(__name__)
class AudioService:
@classmethod
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: str | None = None):
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise ValueError("Speech to text is not enabled")
@ -40,7 +39,9 @@ class AudioService:
if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
raise ValueError("Speech to text is not enabled")
else:
app_model_config: AppModelConfig = app_model.app_model_config
app_model_config = app_model.app_model_config
if not app_model_config:
raise ValueError("Speech to text is not enabled")
if not app_model_config.speech_to_text_dict["enabled"]:
raise ValueError("Speech to text is not enabled")
@ -75,18 +76,18 @@ class AudioService:
def transcript_tts(
cls,
app_model: App,
text: Optional[str] = None,
voice: Optional[str] = None,
end_user: Optional[str] = None,
message_id: Optional[str] = None,
text: str | None = None,
voice: str | None = None,
end_user: str | None = None,
message_id: str | None = None,
is_draft: bool = False,
):
from app import app
def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False):
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
with app.app_context():
if voice is None:
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if is_draft:
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
else:

View File

@ -1,5 +1,7 @@
import json
from sqlalchemy import select
from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
@ -8,12 +10,12 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str) -> list:
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
.all()
)
def get_provider_auth_list(tenant_id: str):
data_source_api_key_bindings = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
)
).all()
return data_source_api_key_bindings
@staticmethod

View File

@ -1,6 +1,6 @@
import json
import requests
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -36,7 +36,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:

View File

@ -1,6 +1,6 @@
import json
import requests
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:

View File

@ -1,6 +1,6 @@
import json
import requests
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data)
return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:

View File

@ -1,7 +1,7 @@
import json
from urllib.parse import urljoin
import requests
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -31,7 +31,7 @@ class WatercrawlAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "X-API-KEY": self.api_key}
def _get_request(self, url, headers):
return requests.get(url, headers=headers)
return httpx.get(url, headers=headers)
def _handle_error(self, response):
if response.status_code in {402, 409, 500}:

View File

@ -1,10 +1,11 @@
import os
from typing import Literal, Optional
from typing import Literal
import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import RateLimiter
from models.account import Account, TenantAccountJoin, TenantAccountRole
@ -70,10 +71,10 @@ class BillingService:
return response.json()
@staticmethod
def is_tenant_owner_or_admin(current_user):
def is_tenant_owner_or_admin(current_user: Account):
tenant_id = current_user.current_tenant_id
join: Optional[TenantAccountJoin] = (
join: TenantAccountJoin | None = (
db.session.query(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.first()
@ -173,3 +174,7 @@ class BillingService:
res = cls._send_request("POST", "/compliance/download", json=json)
cls.compliance_download_rate_limiter.increment_rate_limit(limiter_key)
return res
@classmethod
def clean_billing_info_cache(cls, tenant_id: str):
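# drop the tenant's cached billing info; the cache is assumed to be repopulated on the next billing-info fetch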
redis_client.delete(f"tenant:{tenant_id}:billing_info")

View File

@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
@ -34,7 +35,7 @@ logger = logging.getLogger(__name__)
class ClearFreePlanTenantExpiredLogs:
@classmethod
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None:
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]):
"""
Clean up message-related tables to avoid data redundancy.
This method cleans up tables that have foreign key relationships with Message.
@ -115,7 +116,7 @@ class ClearFreePlanTenantExpiredLogs:
@classmethod
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
with flask_app.app_context():
apps = db.session.query(App).where(App.tenant_id == tenant_id).all()
apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
app_ids = [app.id for app in apps]
while True:
with Session(db.engine).no_autoflush as session:
@ -353,7 +354,7 @@ class ClearFreePlanTenantExpiredLogs:
thread_pool = ThreadPoolExecutor(max_workers=10)
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
def process_tenant(flask_app: Flask, tenant_id: str):
try:
if (
not dify_config.BILLING_ENABLED
@ -407,6 +408,7 @@ class ClearFreePlanTenantExpiredLogs:
datetime.timedelta(hours=1),
]
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)

View File

@ -3,7 +3,7 @@ from extensions.ext_code_based_extension import code_based_extension
class CodeBasedExtensionService:
@staticmethod
def get_code_based_extension(module: str) -> list[dict]:
def get_code_based_extension(module: str):
module_extensions = code_based_extension.module_extensions(module)
return [
{

View File

@ -1,7 +1,7 @@
import contextlib
import logging
from collections.abc import Callable, Sequence
from typing import Any, Optional, Union
from typing import Any, Union
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session
@ -36,12 +36,12 @@ class ConversationService:
*,
session: Session,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
invoke_from: InvokeFrom,
include_ids: Optional[Sequence[str]] = None,
exclude_ids: Optional[Sequence[str]] = None,
include_ids: Sequence[str] | None = None,
exclude_ids: Sequence[str] | None = None,
sort_by: str = "-updated_at",
) -> InfiniteScrollPagination:
if not user:
@ -118,7 +118,7 @@ class ConversationService:
cls,
app_model: App,
conversation_id: str,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
name: str,
auto_generate: bool,
):
@ -158,7 +158,7 @@ class ConversationService:
return conversation
@classmethod
def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
conversation = (
db.session.query(Conversation)
.where(
@ -178,7 +178,7 @@ class ConversationService:
return conversation
@classmethod
def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
try:
logger.info(
"Initiating conversation deletion for app_name %s, conversation_id: %s",
@ -200,9 +200,9 @@ class ConversationService:
cls,
app_model: App,
conversation_id: str,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
limit: int,
last_id: Optional[str],
last_id: str | None,
) -> InfiniteScrollPagination:
conversation = cls.get_conversation(app_model, conversation_id, user)
@ -222,8 +222,8 @@ class ConversationService:
# Filter for variables created after the last_id
stmt = stmt.where(ConversationVariable.created_at > last_variable.created_at)
# Apply limit to query
query_stmt = stmt.limit(limit) # Get one extra to check if there are more
# Apply limit to query: fetch one extra row to determine has_more
query_stmt = stmt.limit(limit + 1)
rows = session.scalars(query_stmt).all()
has_more = False
@ -248,9 +248,9 @@ class ConversationService:
app_model: App,
conversation_id: str,
variable_id: str,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
new_value: Any,
) -> dict:
):
"""
Update a conversation variable's value.

File diff suppressed because it is too large

View File

@ -0,0 +1,975 @@
import logging
import time
from collections.abc import Mapping
from typing import Any
from flask_login import current_user
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper import encrypter
from core.helper.name_generator import generate_incremental_name
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.model_runtime.entities.provider_entities import FormType
from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
from models.provider_ids import DatasourceProviderID
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
class DatasourceProviderService:
"""
Datasource Provider Service
"""
def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager()
def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
"""
remove oauth custom client params
"""
with Session(db.engine) as session:
session.query(DatasourceOauthTenantParamConfig).filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
).delete()
session.commit()
def decrypt_datasource_provider_credentials(
self,
tenant_id: str,
datasource_provider: DatasourceProvider,
plugin_id: str,
provider: str,
) -> dict[str, Any]:
encrypted_credentials = datasource_provider.encrypted_credentials
credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
decrypted_credentials = encrypted_credentials.copy()
for key, value in decrypted_credentials.items():
if key in credential_secret_variables:
decrypted_credentials[key] = encrypter.decrypt_token(tenant_id, value)
return decrypted_credentials
def encrypt_datasource_provider_credentials(
self,
tenant_id: str,
provider: str,
plugin_id: str,
raw_credentials: Mapping[str, Any],
datasource_provider: DatasourceProvider,
) -> dict[str, Any]:
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
encrypted_credentials = dict(raw_credentials)
for key, value in encrypted_credentials.items():
if key in provider_credential_secret_variables:
encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
return encrypted_credentials
def get_datasource_credentials(
self,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str | None = None,
) -> dict[str, Any]:
"""
get credential by id
"""
with Session(db.engine) as session:
if credential_id:
datasource_provider = (
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
)
else:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
)
if not datasource_provider:
return {}
# refresh the credentials
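# expires_at == -1 means the credential never expires; otherwise refresh it via the plugin's OAuth handler 60 seconds before expiry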
if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()):
decrypted_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
provider_name = datasource_provider_id.provider_name
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
f"{datasource_provider_id}/datasource/callback"
)
system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
refreshed_credentials = OAuthHandler().refresh_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
plugin_id=datasource_provider_id.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
tenant_id=tenant_id,
raw_credentials=refreshed_credentials.credentials,
provider=provider,
plugin_id=plugin_id,
datasource_provider=datasource_provider,
)
datasource_provider.expires_at = refreshed_credentials.expires_at
session.commit()
return self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
def get_all_datasource_credentials_by_provider(
self,
tenant_id: str,
provider: str,
plugin_id: str,
) -> list[dict[str, Any]]:
"""
get all datasource credentials by provider
"""
with Session(db.engine) as session:
datasource_providers = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.all()
)
if not datasource_providers:
return []
# refresh the credentials
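# unconditionally refresh every stored OAuth credential for this provider; all changes are committed once after the loop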
real_credentials_list = []
for datasource_provider in datasource_providers:
decrypted_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
datasource_provider_id = DatasourceProviderID(f"{plugin_id}/{provider}")
provider_name = datasource_provider_id.provider_name
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
f"{datasource_provider_id}/datasource/callback"
)
system_credentials = self.get_oauth_client(tenant_id, datasource_provider_id)
refreshed_credentials = OAuthHandler().refresh_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
plugin_id=datasource_provider_id.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
datasource_provider.encrypted_credentials = self.encrypt_datasource_provider_credentials(
tenant_id=tenant_id,
raw_credentials=refreshed_credentials.credentials,
provider=provider,
plugin_id=plugin_id,
datasource_provider=datasource_provider,
)
datasource_provider.expires_at = refreshed_credentials.expires_at
real_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
plugin_id=plugin_id,
provider=provider,
)
real_credentials_list.append(real_credentials)
session.commit()
return real_credentials_list
def update_datasource_provider_name(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
):
"""
update datasource provider name
"""
with Session(db.engine) as session:
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if target_provider is None:
raise ValueError("provider not found")
if target_provider.name == name:
return
# check whether the name already exists
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=name,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.count()
> 0
):
raise ValueError("Authorization name is already exists")
target_provider.name = name
session.commit()
return
def set_default_datasource_provider(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str
):
"""
set default datasource provider
"""
with Session(db.engine) as session:
# get provider
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=target_provider.provider,
plugin_id=target_provider.plugin_id,
is_default=True,
).update({"is_default": False})
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
def setup_oauth_custom_client_params(
self,
tenant_id: str,
datasource_provider_id: DatasourceProviderID,
client_params: dict | None,
enabled: bool | None,
):
"""
setup oauth custom client params
"""
if client_params is None and enabled is None:
return
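# upsert the tenant-scoped OAuth client config; [__HIDDEN__] placeholders keep the previously stored secret values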
with Session(db.engine) as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if not tenant_oauth_client_params:
tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
client_params={},
enabled=False,
)
session.add(tenant_oauth_client_params)
if client_params is not None:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
original_params = (
encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {}
)
new_params: dict = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
tenant_oauth_client_params.client_params = encrypter.encrypt(new_params)
if enabled is not None:
tenant_oauth_client_params.enabled = enabled
session.commit()
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
"""
check if system oauth params exist
"""
with Session(db.engine).no_autoflush as session:
return (
session.query(DatasourceOauthParamConfig)
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
.first()
is not None
)
def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
"""
check whether tenant oauth params are enabled
"""
return (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
enabled=True,
)
.count()
> 0
)
def get_tenant_oauth_client(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
) -> dict[str, Any] | None:
"""
get tenant oauth client
"""
tenant_oauth_client_params = (
db.session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
if mask:
return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
else:
return encrypter.decrypt(tenant_oauth_client_params.client_params)
return None
def get_oauth_encrypter(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
"""
get oauth encrypter
"""
datasource_provider = self.provider_manager.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
)
if not datasource_provider.declaration.oauth_schema:
raise ValueError("Datasource provider oauth schema not found")
client_schema = datasource_provider.declaration.oauth_schema.client_schema
return create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in client_schema],
cache=NoOpProviderCredentialCache(),
)
def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
"""
get oauth client
"""
provider = datasource_provider_id.provider_name
plugin_id = datasource_provider_id.plugin_id
with Session(db.engine).no_autoflush as session:
# get tenant oauth client params
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
enabled=True,
)
.first()
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
return encrypter.decrypt(tenant_oauth_client_params.client_params)
provider_controller = self.provider_manager.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
)
is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
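# only verified plugins may fall back to the system-wide OAuth client params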
if is_verified:
# fallback to system oauth client params
oauth_client_params = (
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if oauth_client_params:
return oauth_client_params.system_credentials
raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
@staticmethod
def generate_next_datasource_provider_name(
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
) -> str:
db_providers = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
.all()
)
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
)
def reauthorize_datasource_oauth_provider(
self,
name: str | None,
tenant_id: str,
provider_id: DatasourceProviderID,
avatar_url: str | None,
expire_at: int,
credentials: dict,
credential_id: str,
) -> None:
"""
update datasource oauth provider
"""
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}"
with redis_client.lock(lock, timeout=20):
target_provider = (
session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
)
if target_provider is None:
raise ValueError("provider not found")
db_provider_name = name
if not db_provider_name:
db_provider_name = target_provider.name
else:
name_conflict = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=CredentialType.OAUTH2.value,
)
.count()
)
if name_conflict > 0:
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
],
db_provider_name,
)
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(tenant_id, value)
target_provider.expires_at = expire_at
target_provider.encrypted_credentials = credentials
target_provider.avatar_url = avatar_url or target_provider.avatar_url
session.commit()
def add_datasource_oauth_provider(
self,
name: str | None,
tenant_id: str,
provider_id: DatasourceProviderID,
avatar_url: str | None,
expire_at: int,
credentials: dict,
) -> None:
"""
add datasource oauth provider
"""
credential_type = CredentialType.OAUTH2
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
with redis_client.lock(lock, timeout=60):
db_provider_name = name
if not db_provider_name:
db_provider_name = self.generate_next_datasource_provider_name(
session=session,
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=credential_type,
)
else:
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
)
.count()
> 0
):
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
],
db_provider_name,
)
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
encrypted_credentials=credentials,
avatar_url=avatar_url or "default",
expires_at=expire_at,
)
session.add(datasource_provider)
session.commit()
def add_datasource_api_key_provider(
self,
name: str | None,
tenant_id: str,
provider_id: DatasourceProviderID,
credentials: dict,
) -> None:
"""
add an API key datasource provider after validating its credentials.
:param name: optional display name for the credential
:param tenant_id: workspace id
:param provider_id: datasource provider id
:param credentials: raw credentials to validate and store
"""
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
with redis_client.lock(lock, timeout=20):
db_provider_name = name or self.generate_next_datasource_provider_name(
session=session,
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=CredentialType.API_KEY,
)
# check whether the name already exists
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
.count()
> 0
):
raise ValueError("Authorization name is already exists")
try:
self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
provider=provider_name,
plugin_id=plugin_id,
credentials=credentials,
)
except Exception as e:
raise ValueError(f"Failed to validate credentials: {str(e)}")
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if [__HIDDEN__] is sent in a secret input, it will be the same as the original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_name,
plugin_id=plugin_id,
auth_type=CredentialType.API_KEY.value,
encrypted_credentials=credentials,
)
session.add(datasource_provider)
session.commit()
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
"""
Extract secret input form variables.
:param tenant_id: workspace id
:param provider_id: provider id in the form "<plugin_id>/<provider>"
:param credential_type: credential type (api-key or oauth2)
:return: names of the secret-input form variables
"""
datasource_provider = self.provider_manager.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=provider_id
)
credential_form_schemas = []
if credential_type == CredentialType.API_KEY:
credential_form_schemas = list(datasource_provider.declaration.credentials_schema)
elif credential_type == CredentialType.OAUTH2:
if not datasource_provider.declaration.oauth_schema:
raise ValueError("Datasource provider oauth schema not found")
credential_form_schemas = list(datasource_provider.declaration.oauth_schema.credentials_schema)
else:
raise ValueError(f"Invalid credential type: {credential_type}")
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
secret_input_form_variables.append(credential_form_schema.name)
return secret_input_form_variables
def list_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
"""
list datasource credentials with obfuscated sensitive fields.
:param tenant_id: workspace id
:param provider: provider name
:param plugin_id: plugin id
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []
default_provider = (
db.session.query(DatasourceProvider.id)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
)
default_provider_id = default_provider.id if default_provider else None
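# when no credential is explicitly marked as default, the oldest one (created_at asc) is treated as the default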
for datasource_provider in datasource_providers:
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
# Obfuscate provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.obfuscated_token(value)
copy_credentials_list.append(
{
"credential": copy_credentials,
"type": datasource_provider.auth_type,
"name": datasource_provider.name,
"avatar_url": datasource_provider.avatar_url,
"id": datasource_provider.id,
"is_default": default_provider_id and datasource_provider.id == default_provider_id,
}
)
return copy_credentials_list
def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
"""
get credentials for all installed datasource plugins.
:return:
"""
# get all plugin providers
manager = PluginDatasourceManager()
datasources = manager.fetch_installed_datasource_providers(tenant_id)
datasource_credentials = []
for datasource in datasources:
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
credentials = self.list_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
)
datasource_credentials.append(
{
"provider": datasource.provider,
"plugin_id": datasource.plugin_id,
"plugin_unique_identifier": datasource.plugin_unique_identifier,
"icon": datasource.declaration.identity.icon,
"name": datasource.declaration.identity.name.split("/")[-1],
"label": datasource.declaration.identity.label.model_dump(),
"description": datasource.declaration.identity.description.model_dump(),
"author": datasource.declaration.identity.author,
"credentials_list": credentials,
"credential_schema": [
credential.model_dump() for credential in datasource.declaration.credentials_schema
],
"oauth_schema": {
"client_schema": [
client_schema.model_dump()
for client_schema in datasource.declaration.oauth_schema.client_schema
],
"credentials_schema": [
credential_schema.model_dump()
for credential_schema in datasource.declaration.oauth_schema.credentials_schema
],
"oauth_custom_client_params": self.get_tenant_oauth_client(
tenant_id, datasource_provider_id, mask=True
),
"is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
tenant_id, datasource_provider_id
),
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
"redirect_uri": redirect_uri,
}
if datasource.declaration.oauth_schema
else None,
}
)
return datasource_credentials
def get_hard_code_datasource_credentials(self, tenant_id: str) -> list[dict]:
"""
get datasource credentials for a hard-coded set of datasource plugins.
:return:
"""
# get all plugin providers
manager = PluginDatasourceManager()
datasources = manager.fetch_installed_datasource_providers(tenant_id)
datasource_credentials = []
for datasource in datasources:
if datasource.plugin_id in [
"langgenius/firecrawl_datasource",
"langgenius/notion_datasource",
"langgenius/jina_datasource",
]:
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
credentials = self.list_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
redirect_uri = "{}/console/api/oauth/plugin/{}/datasource/callback".format(
dify_config.CONSOLE_API_URL, datasource_provider_id
)
datasource_credentials.append(
{
"provider": datasource.provider,
"plugin_id": datasource.plugin_id,
"plugin_unique_identifier": datasource.plugin_unique_identifier,
"icon": datasource.declaration.identity.icon,
"name": datasource.declaration.identity.name.split("/")[-1],
"label": datasource.declaration.identity.label.model_dump(),
"description": datasource.declaration.identity.description.model_dump(),
"author": datasource.declaration.identity.author,
"credentials_list": credentials,
"credential_schema": [
credential.model_dump() for credential in datasource.declaration.credentials_schema
],
"oauth_schema": {
"client_schema": [
client_schema.model_dump()
for client_schema in datasource.declaration.oauth_schema.client_schema
],
"credentials_schema": [
credential_schema.model_dump()
for credential_schema in datasource.declaration.oauth_schema.credentials_schema
],
"oauth_custom_client_params": self.get_tenant_oauth_client(
tenant_id, datasource_provider_id, mask=True
),
"is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
tenant_id, datasource_provider_id
),
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
"redirect_uri": redirect_uri,
}
if datasource.declaration.oauth_schema
else None,
}
)
return datasource_credentials
def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
"""
get decrypted datasource credentials.
:param tenant_id: workspace id
:param provider: provider name
:param plugin_id: plugin id
:return:
"""
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
)
.all()
)
if not datasource_providers:
return []
copy_credentials_list = []
for datasource_provider in datasource_providers:
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
# Decrypt provider credentials
copy_credentials = encrypted_credentials.copy()
for key, value in copy_credentials.items():
if key in credential_secret_variables:
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
copy_credentials_list.append(
{
"credentials": copy_credentials,
"type": datasource_provider.auth_type,
}
)
return copy_credentials_list
def update_datasource_credentials(
self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
) -> None:
"""
update datasource credentials.
"""
with Session(db.engine) as session:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
)
if not datasource_provider:
raise ValueError("Datasource provider not found")
# update name
if name and name != datasource_provider.name:
if (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
.count()
> 0
):
raise ValueError("Authorization name is already exists")
datasource_provider.name = name
# update credentials
if credentials:
secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
original_credentials = {
key: value if key not in secret_variables else encrypter.decrypt_token(tenant_id, value)
for key, value in datasource_provider.encrypted_credentials.items()
}
new_credentials = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
try:
self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
provider=provider,
plugin_id=plugin_id,
credentials=new_credentials,
)
except Exception as e:
raise ValueError(f"Failed to validate credentials: {str(e)}")
encrypted_credentials = {}
for key, value in new_credentials.items():
if key in secret_variables:
encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
else:
encrypted_credentials[key] = value
datasource_provider.encrypted_credentials = encrypted_credentials
session.commit()
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
"""
remove datasource credentials.
:param tenant_id: workspace id
:param auth_id: credential id
:param provider: provider name
:param plugin_id: plugin id
:return:
"""
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
)
if datasource_provider:
db.session.delete(datasource_provider)
db.session.commit()

View File

@ -3,18 +3,30 @@ import os
import requests
class EnterpriseRequest:
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
class BaseRequest:
proxies = {
"http": "",
"https": "",
}
base_url = ""
secret_key = ""
secret_key_header = ""
@classmethod
def send_request(cls, method, endpoint, json=None, params=None):
headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key}
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies)
return response.json()
class EnterpriseRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
secret_key_header = "Enterprise-Api-Secret-Key"
class EnterprisePluginManagerRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL")
secret_key = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY", "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY")
secret_key_header = "Plugin-Manager-Inner-Api-Secret-Key"

View File

@ -0,0 +1,57 @@
import enum
import logging
from pydantic import BaseModel
from services.enterprise.base import EnterprisePluginManagerRequest
from services.errors.base import BaseServiceError
logger = logging.getLogger(__name__)
class PluginCredentialType(enum.Enum):
MODEL = 0 # must be 0 for API contract compatibility
TOOL = 1 # must be 1 for API contract compatibility
def to_number(self):
return self.value
class CheckCredentialPolicyComplianceRequest(BaseModel):
dify_credential_id: str
provider: str
credential_type: PluginCredentialType
def model_dump(self, **kwargs):
data = super().model_dump(**kwargs)
data["credential_type"] = self.credential_type.to_number()
return data
class CredentialPolicyViolationError(BaseServiceError):
pass
class PluginManagerService:
@classmethod
def check_credential_policy_compliance(cls, body: CheckCredentialPolicyComplianceRequest):
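# any transport or response-format error is wrapped in CredentialPolicyViolationError,
# so callers can treat every failure as a policy denial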
try:
ret = EnterprisePluginManagerRequest.send_request(
"POST", "/check-credential-policy-compliance", json=body.model_dump()
)
if not isinstance(ret, dict) or "result" not in ret:
raise ValueError("Invalid response format from plugin manager API")
except Exception as e:
raise CredentialPolicyViolationError(
f"error occurred while checking credential policy compliance: {e}"
) from e
if not ret.get("result", False):
raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials")
logger.debug(
"Credential policy compliance checked for %s with credential %s, result: %s",
body.provider,
body.dify_credential_id,
ret.get("result", False),
)
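# Illustrative call site (the provider value below is hypothetical, not part of this change):
#   PluginManagerService.check_credential_policy_compliance(
#       CheckCredentialPolicyComplianceRequest(
#           dify_credential_id=credential_id,
#           provider="some_plugin/some_provider",
#           credential_type=PluginCredentialType.MODEL,
#       )
#   )
# A CredentialPolicyViolationError means the credential must not be used.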

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional, Union
from typing import Literal, Union
from pydantic import BaseModel
@ -11,7 +11,7 @@ class AuthorizationConfig(BaseModel):
class Authorization(BaseModel):
type: Literal["no-auth", "api-key"]
config: Optional[AuthorizationConfig] = None
config: AuthorizationConfig | None = None
class ProcessStatusSetting(BaseModel):
@ -22,5 +22,5 @@ class ProcessStatusSetting(BaseModel):
class ExternalKnowledgeApiSetting(BaseModel):
url: str
request_method: str
headers: Optional[dict] = None
params: Optional[dict] = None
headers: dict | None = None
params: dict | None = None

View File

@ -1,5 +1,5 @@
from enum import StrEnum
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel
@ -11,18 +11,19 @@ class ParentMode(StrEnum):
class NotionIcon(BaseModel):
type: str
url: Optional[str] = None
emoji: Optional[str] = None
url: str | None = None
emoji: str | None = None
class NotionPage(BaseModel):
page_id: str
page_name: str
page_icon: Optional[NotionIcon] = None
page_icon: NotionIcon | None = None
type: str
class NotionInfo(BaseModel):
credential_id: str
workspace_id: str
pages: list[NotionPage]
@ -40,9 +41,9 @@ class FileInfo(BaseModel):
class InfoList(BaseModel):
data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
notion_info_list: Optional[list[NotionInfo]] = None
file_info_list: Optional[FileInfo] = None
website_info_list: Optional[WebsiteInfo] = None
notion_info_list: list[NotionInfo] | None = None
file_info_list: FileInfo | None = None
website_info_list: WebsiteInfo | None = None
class DataSource(BaseModel):
@ -61,20 +62,20 @@ class Segmentation(BaseModel):
class Rule(BaseModel):
pre_processing_rules: Optional[list[PreProcessingRule]] = None
segmentation: Optional[Segmentation] = None
parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
subchunk_segmentation: Optional[Segmentation] = None
pre_processing_rules: list[PreProcessingRule] | None = None
segmentation: Segmentation | None = None
parent_mode: Literal["full-doc", "paragraph"] | None = None
subchunk_segmentation: Segmentation | None = None
class ProcessRule(BaseModel):
mode: Literal["automatic", "custom", "hierarchical"]
rules: Optional[Rule] = None
rules: Rule | None = None
class RerankingModel(BaseModel):
reranking_provider_name: Optional[str] = None
reranking_model_name: Optional[str] = None
reranking_provider_name: str | None = None
reranking_model_name: str | None = None
class WeightVectorSetting(BaseModel):
@ -88,20 +89,20 @@ class WeightKeywordSetting(BaseModel):
class WeightModel(BaseModel):
weight_type: Optional[Literal["semantic_first", "keyword_first", "customized"]] = None
vector_setting: Optional[WeightVectorSetting] = None
keyword_setting: Optional[WeightKeywordSetting] = None
weight_type: Literal["semantic_first", "keyword_first", "customized"] | None = None
vector_setting: WeightVectorSetting | None = None
keyword_setting: WeightKeywordSetting | None = None
class RetrievalModel(BaseModel):
search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
reranking_enable: bool
reranking_model: Optional[RerankingModel] = None
reranking_mode: Optional[str] = None
reranking_model: RerankingModel | None = None
reranking_mode: str | None = None
top_k: int
score_threshold_enabled: bool
score_threshold: Optional[float] = None
weights: Optional[WeightModel] = None
score_threshold: float | None = None
weights: WeightModel | None = None
class MetaDataConfig(BaseModel):
@ -110,29 +111,29 @@ class MetaDataConfig(BaseModel):
class KnowledgeConfig(BaseModel):
original_document_id: Optional[str] = None
original_document_id: str | None = None
duplicate: bool = True
indexing_technique: Literal["high_quality", "economy"]
data_source: Optional[DataSource] = None
process_rule: Optional[ProcessRule] = None
retrieval_model: Optional[RetrievalModel] = None
data_source: DataSource | None = None
process_rule: ProcessRule | None = None
retrieval_model: RetrievalModel | None = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: Optional[str] = None
embedding_model_provider: Optional[str] = None
name: Optional[str] = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
name: str | None = None
class SegmentUpdateArgs(BaseModel):
content: Optional[str] = None
answer: Optional[str] = None
keywords: Optional[list[str]] = None
content: str | None = None
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
enabled: Optional[bool] = None
enabled: bool | None = None
class ChildChunkUpdateArgs(BaseModel):
id: Optional[str] = None
id: str | None = None
content: str
@ -143,13 +144,13 @@ class MetadataArgs(BaseModel):
class MetadataUpdateArgs(BaseModel):
name: str
value: Optional[str | int | float] = None
value: str | int | float | None = None
class MetadataDetail(BaseModel):
id: str
name: str
value: Optional[str | int | float] = None
value: str | int | float | None = None
class DocumentMetadataOperation(BaseModel):

View File

@ -0,0 +1,130 @@
from typing import Literal
from pydantic import BaseModel, field_validator
class IconInfo(BaseModel):
icon: str
icon_background: str | None = None
icon_type: str | None = None
icon_url: str | None = None
class PipelineTemplateInfoEntity(BaseModel):
name: str
description: str
icon_info: IconInfo
class RagPipelineDatasetCreateEntity(BaseModel):
name: str
description: str
icon_info: IconInfo
permission: str
partial_member_list: list[str] | None = None
yaml_content: str | None = None
class RerankingModelConfig(BaseModel):
"""
Reranking Model Config.
"""
reranking_provider_name: str | None = ""
reranking_model_name: str | None = ""
class VectorSetting(BaseModel):
"""
Vector Setting.
"""
vector_weight: float
embedding_provider_name: str
embedding_model_name: str
class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""
keyword_weight: float
class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""
vector_setting: VectorSetting | None
keyword_setting: KeywordSetting | None
class EmbeddingSetting(BaseModel):
"""
Embedding Setting.
"""
embedding_provider_name: str
embedding_model_name: str
class EconomySetting(BaseModel):
"""
Economy Setting.
"""
keyword_number: int
class RetrievalSetting(BaseModel):
"""
Retrieval Setting.
"""
search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"]
top_k: int
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
reranking_mode: str | None = "reranking_model"
reranking_enable: bool | None = True
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""
indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting
class KnowledgeConfiguration(BaseModel):
"""
Knowledge Base Configuration.
"""
chunk_structure: str
indexing_technique: Literal["high_quality", "economy"]
embedding_model_provider: str = ""
embedding_model: str = ""
keyword_number: int | None = 10
retrieval_model: RetrievalSetting
@field_validator("embedding_model_provider", mode="before")
@classmethod
def validate_embedding_model_provider(cls, v):
if v is None:
return ""
return v
@field_validator("embedding_model", mode="before")
@classmethod
def validate_embedding_model(cls, v):
if v is None:
return ""
return v

View File

@ -1,5 +1,4 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, ConfigDict
@ -13,6 +12,7 @@ from core.entities.provider_entities import (
CustomModelConfiguration,
ProviderQuotaType,
QuotaConfiguration,
UnaddedModelConfiguration,
)
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import ModelType
@ -41,10 +41,11 @@ class CustomConfigurationResponse(BaseModel):
"""
status: CustomConfigurationStatus
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_credentials: Optional[list[CredentialConfiguration]] = None
custom_models: Optional[list[CustomModelConfiguration]] = None
current_credential_id: str | None = None
current_credential_name: str | None = None
available_credentials: list[CredentialConfiguration] | None = None
custom_models: list[CustomModelConfiguration] | None = None
can_added_models: list[UnaddedModelConfiguration] | None = None
class SystemConfigurationResponse(BaseModel):
@ -53,7 +54,7 @@ class SystemConfigurationResponse(BaseModel):
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
current_quota_type: ProviderQuotaType | None = None
quota_configurations: list[QuotaConfiguration] = []
@ -65,15 +66,15 @@ class ProviderResponse(BaseModel):
tenant_id: str
provider: str
label: I18nObject
description: Optional[I18nObject] = None
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: list[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None
provider_credential_schema: ProviderCredentialSchema | None = None
model_credential_schema: ModelCredentialSchema | None = None
preferred_provider_type: ProviderType
custom_configuration: CustomConfigurationResponse
system_configuration: SystemConfigurationResponse
@ -81,7 +82,7 @@ class ProviderResponse(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data) -> None:
def __init__(self, **data):
super().__init__(**data)
url_prefix = (
@ -106,12 +107,12 @@ class ProviderWithModelsResponse(BaseModel):
tenant_id: str
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
def __init__(self, **data) -> None:
def __init__(self, **data):
super().__init__(**data)
url_prefix = (
@ -135,7 +136,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
tenant_id: str
def __init__(self, **data) -> None:
def __init__(self, **data):
super().__init__(**data)
url_prefix = (
@ -172,7 +173,7 @@ class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
provider: SimpleProviderEntityResponse
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
def __init__(self, tenant_id: str, model: ModelWithProviderEntity):
dump_model = model.model_dump()
dump_model["provider"]["tenant_id"] = tenant_id
super().__init__(**dump_model)

View File

@ -1,6 +1,3 @@
from typing import Optional
class BaseServiceError(ValueError):
def __init__(self, description: Optional[str] = None):
def __init__(self, description: str | None = None):
self.description = description

View File

@ -1,12 +1,9 @@
from typing import Optional
class InvokeError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
description: str | None = None
def __init__(self, description: Optional[str] = None) -> None:
def __init__(self, description: str | None = None):
self.description = description
def __str__(self):

View File

@ -1,6 +1,6 @@
import json
from copy import deepcopy
from typing import Any, Optional, Union, cast
from typing import Any, Union, cast
from urllib.parse import urlparse
import httpx
@ -89,7 +89,7 @@ class ExternalDatasetService:
raise ValueError(f"invalid endpoint: {endpoint}")
try:
response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"})
except Exception as e:
except Exception:
raise ValueError(f"failed to connect to the endpoint: {endpoint}")
if response.status_code == 502:
raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}")
@ -100,7 +100,7 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
)
if external_knowledge_api is None:
@ -109,13 +109,14 @@ class ExternalDatasetService:
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
raise ValueError("api template not found")
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key")
settings = args.get("settings")
if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict:
settings["api_key"] = external_knowledge_api.settings_dict.get("api_key")
external_knowledge_api.name = args.get("name")
external_knowledge_api.description = args.get("description", "")
@ -150,7 +151,7 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
external_knowledge_binding: ExternalKnowledgeBindings | None = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
@ -180,7 +181,7 @@ class ExternalDatasetService:
do http request depending on api bundle
"""
kwargs = {
kwargs: dict[str, Any] = {
"url": settings.url,
"headers": settings.headers,
"follow_redirects": True,
@ -202,7 +203,7 @@ class ExternalDatasetService:
return response
@staticmethod
def assembling_headers(authorization: Authorization, headers: Optional[dict] = None) -> dict[str, Any]:
def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]:
authorization = deepcopy(authorization)
if headers:
headers = deepcopy(headers)
@ -229,7 +230,7 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting:
return ExternalKnowledgeApiSetting.parse_obj(settings)
return ExternalKnowledgeApiSetting.model_validate(settings)
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
@ -276,8 +277,8 @@ class ExternalDatasetService:
dataset_id: str,
query: str,
external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None,
) -> list:
metadata_condition: MetadataCondition | None = None,
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)

View File

@ -88,6 +88,10 @@ class WebAppAuthModel(BaseModel):
allow_email_password_login: bool = False
class KnowledgePipeline(BaseModel):
publish_enabled: bool = False
class PluginInstallationScope(StrEnum):
NONE = "none"
OFFICIAL_ONLY = "official_only"
@ -126,6 +130,7 @@ class FeatureModel(BaseModel):
is_allow_transfer_workspace: bool = True
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
knowledge_pipeline: KnowledgePipeline = KnowledgePipeline()
class KnowledgeRateLimitModel(BaseModel):
@ -134,6 +139,10 @@ class KnowledgeRateLimitModel(BaseModel):
subscription_plan: str = ""
class PluginManagerModel(BaseModel):
enabled: bool = False
class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ""
@ -150,6 +159,7 @@ class SystemFeatureModel(BaseModel):
webapp_auth: WebAppAuthModel = WebAppAuthModel()
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
class FeatureService:
@ -188,6 +198,7 @@ class FeatureService:
system_features.branding.enabled = True
system_features.webapp_auth.enabled = True
system_features.enable_change_email = False
system_features.plugin_manager.enabled = True
cls._fulfill_params_from_enterprise(system_features)
if dify_config.MARKETPLACE_ENABLED:
@ -265,6 +276,9 @@ class FeatureService:
if "knowledge_rate_limit" in billing_info:
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
if "knowledge_pipeline_publish_enabled" in billing_info:
features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]
@classmethod
def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):
enterprise_info = EnterpriseService.get_info()

View File

@ -1,9 +1,10 @@
import hashlib
import os
import uuid
from typing import Any, Literal, Union
from typing import Literal, Union
from flask_login import current_user
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from configs import dify_config
@ -15,7 +16,6 @@ from constants import (
)
from core.file import helpers as file_helpers
from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
@ -29,13 +29,23 @@ PREVIEW_WORDS_LIMIT = 3000
class FileService:
@staticmethod
_session_maker: sessionmaker
def __init__(self, session_factory: sessionmaker | Engine | None = None):
if isinstance(session_factory, Engine):
self._session_maker = sessionmaker(bind=session_factory)
elif isinstance(session_factory, sessionmaker):
self._session_maker = session_factory
else:
raise AssertionError("must be a sessionmaker or an Engine.")
def upload_file(
self,
*,
filename: str,
content: bytes,
mimetype: str,
user: Union[Account, EndUser, Any],
user: Union[Account, EndUser],
source: Literal["datasets"] | None = None,
source_url: str = "",
) -> UploadFile:
@ -85,14 +95,14 @@ class FileService:
hash=hashlib.sha3_256(content).hexdigest(),
source_url=source_url,
)
db.session.add(upload_file)
db.session.commit()
# The `UploadFile` ID is generated within its constructor, so flushing to retrieve the ID is unnecessary.
# We can directly generate the `source_url` here before committing.
if not upload_file.source_url:
upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
db.session.add(upload_file)
db.session.commit()
with self._session_maker(expire_on_commit=False) as session:
session.add(upload_file)
session.commit()
return upload_file
@ -109,42 +119,42 @@ class FileService:
return file_size <= file_size_limit
@staticmethod
def upload_text(text: str, text_name: str) -> UploadFile:
def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile:
if len(text_name) > 200:
text_name = text_name[:200]
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"
file_key = "upload_files/" + tenant_id + "/" + file_uuid + ".txt"
# save file to storage
storage.save(file_key, text.encode("utf-8"))
# save file to db
upload_file = UploadFile(
tenant_id=current_user.current_tenant_id,
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=text_name,
size=len(text),
extension="txt",
mime_type="text/plain",
created_by=current_user.id,
created_by=user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=current_user.id,
used_by=user_id,
used_at=naive_utc_now(),
)
db.session.add(upload_file)
db.session.commit()
with self._session_maker(expire_on_commit=False) as session:
session.add(upload_file)
session.commit()
return upload_file
@staticmethod
def get_file_preview(file_id: str):
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
def get_file_preview(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found")
@ -159,15 +169,14 @@ class FileService:
return text
@staticmethod
def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str):
def get_image_preview(self, file_id: str, timestamp: str, nonce: str, sign: str):
result = file_helpers.verify_image_signature(
upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign
)
if not result:
raise NotFound("File not found or signature is invalid")
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -181,13 +190,13 @@ class FileService:
return generator, upload_file.mime_type
@staticmethod
def get_file_generator_by_file_id(file_id: str, timestamp: str, nonce: str, sign: str):
def get_file_generator_by_file_id(self, file_id: str, timestamp: str, nonce: str, sign: str):
result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign)
if not result:
raise NotFound("File not found or signature is invalid")
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -196,9 +205,9 @@ class FileService:
return generator, upload_file
@staticmethod
def get_public_image_preview(file_id: str):
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
def get_public_image_preview(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -211,3 +220,23 @@ class FileService:
generator = storage.load(upload_file.key)
return generator, upload_file.mime_type
def get_file_content(self, file_id: str) -> str:
with self._session_maker(expire_on_commit=False) as session:
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found")
content = storage.load(upload_file.key)
return content.decode("utf-8")
def delete_file(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
return
storage.delete(upload_file.key)
session.delete(upload_file)
session.commit()
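
Editor's note: a minimal usage sketch (not part of the diff) showing how the refactored FileService can be wired to the shared SQLAlchemy engine and how the new explicit user_id/tenant_id signature of upload_text is called; the import paths and example IDs are assumptions.

from extensions.ext_database import db
from services.file_service import FileService

file_service = FileService(session_factory=db.engine)  # an Engine is wrapped in a sessionmaker internally
upload_file = file_service.upload_text(
    text="hello world",
    text_name="greeting.txt",
    user_id="11111111-1111-1111-1111-111111111111",    # hypothetical account id
    tenant_id="22222222-2222-2222-2222-222222222222",  # hypothetical workspace id
)
print(upload_file.id, upload_file.key)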

View File

@ -33,7 +33,7 @@ class HitTestingService:
retrieval_model: Any, # FIXME drop this any
external_retrieval_model: dict,
limit: int = 10,
) -> dict:
):
start = time.perf_counter()
# get the retrieval model; if it is not set, use the default
@ -98,7 +98,7 @@ class HitTestingService:
account: Account,
external_retrieval_model: dict,
metadata_filtering_conditions: dict,
) -> dict:
):
if dataset.provider != "external":
return {
"query": {"content": query},

View File

@ -1,5 +1,5 @@
import json
from typing import Optional, Union
from typing import Union
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -29,9 +29,9 @@ class MessageService:
def pagination_by_first_id(
cls,
app_model: App,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
conversation_id: str,
first_id: Optional[str],
first_id: str | None,
limit: int,
order: str = "asc",
) -> InfiniteScrollPagination:
@ -91,11 +91,11 @@ class MessageService:
def pagination_by_last_id(
cls,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
conversation_id: Optional[str] = None,
include_ids: Optional[list] = None,
conversation_id: str | None = None,
include_ids: list | None = None,
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
@ -145,9 +145,9 @@ class MessageService:
*,
app_model: App,
message_id: str,
user: Optional[Union[Account, EndUser]],
rating: Optional[str],
content: Optional[str],
user: Union[Account, EndUser] | None,
rating: str | None,
content: str | None,
):
if not user:
raise ValueError("user cannot be None")
@ -196,7 +196,7 @@ class MessageService:
return [record.to_dict() for record in feedbacks]
@classmethod
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
message = (
db.session.query(Message)
.where(
@ -216,8 +216,8 @@ class MessageService:
@classmethod
def get_suggested_questions_after_answer(
cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom
) -> list[Message]:
cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom
) -> list[str]:
if not user:
raise ValueError("user cannot be None")
@ -229,7 +229,7 @@ class MessageService:
model_manager = ModelManager()
if app_model.mode == AppMode.ADVANCED_CHAT.value:
if app_model.mode == AppMode.ADVANCED_CHAT:
workflow_service = WorkflowService()
if invoke_from == InvokeFrom.DEBUGGER:
workflow = workflow_service.get_draft_workflow(app_model=app_model)
@ -241,6 +241,9 @@ class MessageService:
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
if not app_config.additional_features:
raise ValueError("Additional features not found")
if not app_config.additional_features.suggested_questions_after_answer:
raise SuggestedQuestionsAfterAnswerDisabledError()
@ -285,7 +288,7 @@ class MessageService:
)
with measure_time() as timer:
questions: list[Message] = LLMGenerator.generate_suggested_questions_after_answer(
questions: list[str] = LLMGenerator.generate_suggested_questions_after_answer(
tenant_id=app_model.tenant_id, histories=histories
)

View File

@ -1,6 +1,5 @@
import copy
import logging
from typing import Optional
from flask_login import current_user
@ -131,11 +130,11 @@ class MetadataService:
@staticmethod
def get_built_in_fields():
return [
{"name": BuiltInField.document_name.value, "type": "string"},
{"name": BuiltInField.uploader.value, "type": "string"},
{"name": BuiltInField.upload_date.value, "type": "time"},
{"name": BuiltInField.last_update_date.value, "type": "time"},
{"name": BuiltInField.source.value, "type": "string"},
{"name": BuiltInField.document_name, "type": "string"},
{"name": BuiltInField.uploader, "type": "string"},
{"name": BuiltInField.upload_date, "type": "time"},
{"name": BuiltInField.last_update_date, "type": "time"},
{"name": BuiltInField.source, "type": "string"},
]
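
Editor's note: a small illustrative check (not part of the diff). Dropping `.value` here assumes BuiltInField is a str-based enum, so the member itself already behaves as its string value when used as a dict key or compared.

from enum import StrEnum

class _BuiltInFieldExample(StrEnum):  # stand-in for BuiltInField, for illustration only
    document_name = "document_name"

assert _BuiltInFieldExample.document_name == "document_name"
assert {"name": _BuiltInFieldExample.document_name}["name"] == "document_name"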
@staticmethod
@ -153,11 +152,11 @@ class MetadataService:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata[BuiltInField.document_name.value] = document.name
doc_metadata[BuiltInField.uploader.value] = document.uploader
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
doc_metadata[BuiltInField.document_name] = document.name
doc_metadata[BuiltInField.uploader] = document.uploader
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
document.doc_metadata = doc_metadata
db.session.add(document)
dataset.built_in_field_enabled = True
@ -183,11 +182,11 @@ class MetadataService:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata.pop(BuiltInField.document_name.value, None)
doc_metadata.pop(BuiltInField.uploader.value, None)
doc_metadata.pop(BuiltInField.upload_date.value, None)
doc_metadata.pop(BuiltInField.last_update_date.value, None)
doc_metadata.pop(BuiltInField.source.value, None)
doc_metadata.pop(BuiltInField.document_name, None)
doc_metadata.pop(BuiltInField.uploader, None)
doc_metadata.pop(BuiltInField.upload_date, None)
doc_metadata.pop(BuiltInField.last_update_date, None)
doc_metadata.pop(BuiltInField.source, None)
document.doc_metadata = doc_metadata
db.session.add(document)
document_ids.append(document.id)
@ -211,11 +210,11 @@ class MetadataService:
for metadata_value in operation.metadata_list:
doc_metadata[metadata_value.name] = metadata_value.value
if dataset.built_in_field_enabled:
doc_metadata[BuiltInField.document_name.value] = document.name
doc_metadata[BuiltInField.uploader.value] = document.uploader
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
doc_metadata[BuiltInField.document_name] = document.name
doc_metadata[BuiltInField.uploader] = document.uploader
doc_metadata[BuiltInField.upload_date] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()
@ -237,7 +236,7 @@ class MetadataService:
redis_client.delete(lock_key)
@staticmethod
def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
def knowledge_base_metadata_lock_check(dataset_id: str | None, document_id: str | None):
if dataset_id:
lock_key = f"dataset_metadata_lock_{dataset_id}"
if redis_client.get(lock_key):

View File

@ -1,7 +1,9 @@
import json
import logging
from json import JSONDecodeError
from typing import Optional, Union
from typing import Union
from sqlalchemy import or_, select
from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
@ -23,10 +25,10 @@ logger = logging.getLogger(__name__)
class ModelLoadBalancingService:
def __init__(self) -> None:
def __init__(self):
self.provider_manager = ProviderManager()
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
enable model load balancing.
@ -47,7 +49,7 @@ class ModelLoadBalancingService:
# Enable model load balancing
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
disable model load balancing.
@ -69,7 +71,7 @@ class ModelLoadBalancingService:
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def get_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
) -> tuple[bool, list[dict]]:
"""
Get load balancing configurations.
@ -100,6 +102,11 @@ class ModelLoadBalancingService:
if provider_model_setting and provider_model_setting.load_balancing_enabled:
is_load_balancing_enabled = True
if config_from == "predefined-model":
credential_source_type = "provider"
else:
credential_source_type = "custom_model"
# Get load balancing configurations
load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
@ -108,6 +115,10 @@ class ModelLoadBalancingService:
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
or_(
LoadBalancingModelConfig.credential_source_type == credential_source_type,
LoadBalancingModelConfig.credential_source_type.is_(None),
),
)
.order_by(LoadBalancingModelConfig.created_at)
.all()
@ -154,7 +165,7 @@ class ModelLoadBalancingService:
try:
if load_balancing_config.encrypted_config:
credentials = json.loads(load_balancing_config.encrypted_config)
credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config)
else:
credentials = {}
except JSONDecodeError:
@ -169,9 +180,13 @@ class ModelLoadBalancingService:
for variable in credential_secret_variables:
if variable in credentials:
try:
credentials[variable] = encrypter.decrypt_token_with_decoding(
credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
)
token_value = credentials.get(variable)
if isinstance(token_value, str):
credentials[variable] = encrypter.decrypt_token_with_decoding(
token_value,
decoding_rsa_key,
decoding_cipher_rsa,
)
except ValueError:
pass
@ -196,7 +211,7 @@ class ModelLoadBalancingService:
def get_load_balancing_config(
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
) -> Optional[dict]:
) -> dict | None:
"""
Get load balancing configuration.
:param tenant_id: workspace id
@ -282,7 +297,7 @@ class ModelLoadBalancingService:
def update_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str
) -> None:
):
"""
Update load balancing configurations.
:param tenant_id: workspace id
@ -307,16 +322,14 @@ class ModelLoadBalancingService:
if not isinstance(configs, list):
raise ValueError("Invalid load balancing configs")
current_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.where(
current_load_balancing_configs = db.session.scalars(
select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.all()
)
).all()
# id as key, config as value
current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
@ -332,8 +345,9 @@ class ModelLoadBalancingService:
credential_id = config.get("credential_id")
enabled = config.get("enabled")
credential_record: ProviderCredential | ProviderModelCredential | None = None
if credential_id:
credential_record: ProviderCredential | ProviderModelCredential | None = None
if config_from == "predefined-model":
credential_record = (
db.session.query(ProviderCredential)
@ -403,7 +417,7 @@ class ModelLoadBalancingService:
self._clear_credentials_cache(tenant_id, config_id)
else:
# create load balancing config
if name in {"__inherit__", "__delete__"}:
if name == "__inherit__":
raise ValueError("Invalid load balancing config name")
if credential_id:
@ -464,8 +478,8 @@ class ModelLoadBalancingService:
model: str,
model_type: str,
credentials: dict,
config_id: Optional[str] = None,
) -> None:
config_id: str | None = None,
):
"""
Validate load balancing credentials.
:param tenant_id: workspace id
@ -522,9 +536,9 @@ class ModelLoadBalancingService:
model_type: ModelType,
model: str,
credentials: dict,
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
load_balancing_model_config: LoadBalancingModelConfig | None = None,
validate: bool = True,
) -> dict:
):
"""
Validate custom credentials.
:param tenant_id: workspace id
@ -592,7 +606,7 @@ class ModelLoadBalancingService:
else:
raise ValueError("No credential schema found")
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
def _clear_credentials_cache(self, tenant_id: str, config_id: str):
"""
Clear credentials cache.
:param tenant_id: workspace id

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
@ -26,7 +25,7 @@ class ModelProviderService:
Model Provider Service
"""
def __init__(self) -> None:
def __init__(self):
self.provider_manager = ProviderManager()
def _get_provider_configuration(self, tenant_id: str, provider: str):
@ -52,7 +51,7 @@ class ModelProviderService:
return provider_configuration
def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
def get_provider_list(self, tenant_id: str, model_type: str | None = None) -> list[ProviderResponse]:
"""
get provider list.
@ -72,6 +71,7 @@ class ModelProviderService:
provider_config = provider_configuration.custom_configuration.provider
model_config = provider_configuration.custom_configuration.models
can_added_models = provider_configuration.custom_configuration.can_added_models
provider_response = ProviderResponse(
tenant_id=tenant_id,
@ -95,6 +95,7 @@ class ModelProviderService:
current_credential_name=getattr(provider_config, "current_credential_name", None),
available_credentials=getattr(provider_config, "available_credentials", []),
custom_models=model_config,
can_added_models=can_added_models,
),
system_configuration=SystemConfigurationResponse(
enabled=provider_configuration.system_configuration.enabled,
@ -126,9 +127,7 @@ class ModelProviderService:
for model in provider_configurations.get_models(provider=provider)
]
def get_provider_credential(
self, tenant_id: str, provider: str, credential_id: Optional[str] = None
) -> Optional[dict]:
def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
"""
get provider credentials.
@ -140,7 +139,7 @@ class ModelProviderService:
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
"""
validate provider credentials before saving.
@ -152,7 +151,7 @@ class ModelProviderService:
provider_configuration.validate_provider_credentials(credentials)
def create_provider_credential(
self, tenant_id: str, provider: str, credentials: dict, credential_name: str
self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
) -> None:
"""
Create and save new provider credentials.
@ -172,7 +171,7 @@ class ModelProviderService:
provider: str,
credentials: dict,
credential_id: str,
credential_name: str,
credential_name: str | None,
) -> None:
"""
update a saved provider credential (by credential_id).
@ -191,7 +190,7 @@ class ModelProviderService:
credential_name=credential_name,
)
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
"""
remove a saved provider credential (by credential_id).
:param tenant_id: workspace id
@ -202,7 +201,7 @@ class ModelProviderService:
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.delete_provider_credential(credential_id=credential_id)
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
"""
:param tenant_id: workspace id
:param provider: provider name
@ -214,7 +213,7 @@ class ModelProviderService:
def get_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
) -> Optional[dict]:
) -> dict | None:
"""
Retrieve model-specific credentials.
@ -230,9 +229,7 @@ class ModelProviderService:
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def validate_model_credentials(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
) -> None:
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
"""
validate model credentials.
@ -249,7 +246,7 @@ class ModelProviderService:
)
def create_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
) -> None:
"""
create and save model credentials.
@ -278,7 +275,7 @@ class ModelProviderService:
model: str,
credentials: dict,
credential_id: str,
credential_name: str,
credential_name: str | None,
) -> None:
"""
update model credentials.
@ -301,9 +298,7 @@ class ModelProviderService:
credential_name=credential_name,
)
def remove_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
) -> None:
def remove_model_credential(self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str):
"""
remove model credentials.
@ -321,7 +316,7 @@ class ModelProviderService:
def switch_active_custom_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
) -> None:
):
"""
switch model credentials.
@ -339,7 +334,7 @@ class ModelProviderService:
def add_model_credential_to_model_list(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
) -> None:
):
"""
add model credentials to model list.
@ -355,7 +350,7 @@ class ModelProviderService:
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str):
"""
remove model credentials.
@ -451,7 +446,7 @@ class ModelProviderService:
return model_schema.parameter_rules if model_schema else []
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> DefaultModelResponse | None:
"""
get default model of model type.
@ -483,7 +478,7 @@ class ModelProviderService:
logger.debug("get_default_model_of_model_type error: %s", e)
return None
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str):
"""
update default model of model type.
@ -500,7 +495,7 @@ class ModelProviderService:
def get_model_provider_icon(
self, tenant_id: str, provider: str, icon_type: str, lang: str
) -> tuple[Optional[bytes], Optional[str]]:
) -> tuple[bytes | None, str | None]:
"""
get model provider icon.
@ -515,7 +510,7 @@ class ModelProviderService:
return byte_data, mime_type
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str):
"""
switch preferred provider.
@ -532,7 +527,7 @@ class ModelProviderService:
# Switch preferred provider type
provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
enable model.
@ -545,7 +540,7 @@ class ModelProviderService:
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
disable model.

View File

@ -1,6 +1,6 @@
import os
import requests
import httpx
class OperationService:
@ -12,7 +12,7 @@ class OperationService:
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers)
response = httpx.request(method, url, json=json, params=params, headers=headers)
return response.json()

View File

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
@ -15,7 +15,7 @@ class OpsService:
:param tracing_provider: tracing provider
:return:
"""
trace_config_data: Optional[TraceAppConfig] = (
trace_config_data: TraceAppConfig | None = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
@ -153,7 +153,7 @@ class OpsService:
project_url = None
# check if trace config already exists
trace_config_data: Optional[TraceAppConfig] = (
trace_config_data: TraceAppConfig | None = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()

View File

@ -4,15 +4,15 @@ import logging
import click
import sqlalchemy as sa
from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
from models.engine import db
from extensions.ext_database import db
from models.provider_ids import GenericProviderID, ModelProviderID, ToolProviderID
logger = logging.getLogger(__name__)
class PluginDataMigration:
@classmethod
def migrate(cls) -> None:
def migrate(cls):
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
@ -26,7 +26,7 @@ class PluginDataMigration:
cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
@classmethod
def migrate_datasets(cls) -> None:
def migrate_datasets(cls):
table_name = "datasets"
provider_column_name = "embedding_model_provider"
@ -46,7 +46,11 @@ limit 1000"""
record_id = str(i.id)
provider_name = str(i.provider_name)
retrieval_model = i.retrieval_model
print(type(retrieval_model))
logger.debug(
"Processing dataset %s with retrieval model of type %s",
record_id,
type(retrieval_model),
)
if record_id in failed_ids:
continue
@ -126,9 +130,7 @@ limit 1000"""
)
@classmethod
def migrate_db_records(
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
) -> None:
def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]):
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
processed_count = 0
@ -175,7 +177,7 @@ limit 1000"""
# update jina to langgenius/jina_tool/jina etc.
updated_value = provider_cls(provider_name).to_string()
batch_updates.append((updated_value, record_id))
except Exception as e:
except Exception:
failed_ids.append(record_id)
click.echo(
click.style(

View File

@ -1,7 +1,13 @@
import re
from configs import dify_config
from core.helper import marketplace
from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID
from core.plugin.entities.plugin import PluginDependency, PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from models.provider_ids import ModelProviderID, ToolProviderID
# Compile regex pattern for version extraction at module level for better performance
_VERSION_REGEX = re.compile(r":(?P<version>[0-9]+(?:\.[0-9]+){2}(?:[+-][0-9A-Za-z.-]+)?)(?:@|$)")
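
Editor's note: an illustrative sketch (not part of the diff) of what _VERSION_REGEX extracts from a marketplace plugin unique identifier; the identifier value below is hypothetical.

identifier = "langgenius/jina_tool:0.0.5@0123456789abcdef0123456789abcdef"
match = _VERSION_REGEX.search(identifier)
if match:
    print(match.group("version"))  # -> 0.0.5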
class DependenciesAnalysisService:
@ -48,6 +54,13 @@ class DependenciesAnalysisService:
for dependency in dependencies:
unique_identifier = dependency.value.plugin_unique_identifier
if unique_identifier in missing_plugin_unique_identifiers:
# Extract version for Marketplace dependencies
if dependency.type == PluginDependency.Type.Marketplace:
version_match = _VERSION_REGEX.search(unique_identifier)
if version_match:
dependency.value.version = version_match.group("version")
# Create and append the dependency (same for all types)
leaked_dependencies.append(
PluginDependency(
type=dependency.type,

View File

@ -11,7 +11,14 @@ class OAuthProxyService(BasePluginClient):
__KEY_PREFIX__ = "oauth_proxy_context:"
@staticmethod
def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str, extra_data: dict = {}):
def create_proxy_context(
user_id: str,
tenant_id: str,
plugin_id: str,
provider: str,
extra_data: dict = {},
credential_id: str | None = None,
):
"""
Create a proxy context for an OAuth 2.0 authorization request.
@ -32,6 +39,8 @@ class OAuthProxyService(BasePluginClient):
"tenant_id": tenant_id,
"provider": provider,
}
if credential_id:
data["credential_id"] = credential_id
redis_client.setex(
f"{OAuthProxyService.__KEY_PREFIX__}{context_id}",
OAuthProxyService.__MAX_AGE__,

View File

@ -5,7 +5,7 @@ import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Optional
from typing import Any
from uuid import uuid4
import click
@ -16,15 +16,17 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.helper import marketplace
from core.plugin.entities.plugin import ModelProviderID, PluginInstallationSource, ToolProviderID
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolProviderType
from extensions.ext_database import db
from models.account import Tenant
from models.engine import db
from models.model import App, AppMode, AppModelConfig
from models.provider_ids import ModelProviderID, ToolProviderID
from models.tools import BuiltinToolProvider
from models.workflow import Workflow
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
@ -33,7 +35,7 @@ excluded_providers = ["time", "audio", "code", "webscraper"]
class PluginMigration:
@classmethod
def extract_plugins(cls, filepath: str, workers: int) -> None:
def extract_plugins(cls, filepath: str, workers: int):
"""
Migrate plugin.
"""
@ -55,7 +57,7 @@ class PluginMigration:
thread_pool = ThreadPoolExecutor(max_workers=workers)
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
def process_tenant(flask_app: Flask, tenant_id: str):
with flask_app.app_context():
nonlocal handled_tenant_count
try:
@ -99,6 +101,7 @@ class PluginMigration:
datetime.timedelta(hours=1),
]
tenant_count = 0
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
@ -255,7 +258,7 @@ class PluginMigration:
return []
agent_app_model_config_ids = [
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT
]
rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
@ -280,7 +283,7 @@ class PluginMigration:
return result
@classmethod
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> str | None:
"""
Fetch plugin unique identifier using plugin id.
"""
@ -291,7 +294,7 @@ class PluginMigration:
return plugin_manifest[0].latest_package_identifier
@classmethod
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str):
"""
Extract unique plugins.
"""
@ -328,7 +331,7 @@ class PluginMigration:
return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
@classmethod
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100):
"""
Install plugins.
"""
@ -348,7 +351,7 @@ class PluginMigration:
if response.get("failed"):
plugin_install_failed.extend(response.get("failed", []))
def install(tenant_id: str, plugin_ids: list[str]) -> None:
def install(tenant_id: str, plugin_ids: list[str]):
logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
# fetch plugin already installed
installed_plugins = manager.list_plugins(tenant_id)
@ -420,6 +423,94 @@ class PluginMigration:
)
)
@classmethod
def install_rag_pipeline_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
"""
Install rag pipeline plugins.
"""
manager = PluginInstaller()
plugins = cls.extract_unique_plugins(extracted_plugins)
plugin_install_failed = []
# use a fake tenant id to install all the plugins
fake_tenant_id = uuid4().hex
logger.info("Installing %s plugin instances for fake tenant %s", len(plugins["plugins"]), fake_tenant_id)
thread_pool = ThreadPoolExecutor(max_workers=workers)
response = cls.handle_plugin_instance_install(fake_tenant_id, plugins["plugins"])
if response.get("failed"):
plugin_install_failed.extend(response.get("failed", []))
def install(
tenant_id: str, plugin_ids: dict[str, str], total_success_tenant: int, total_failed_tenant: int
) -> None:
logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
try:
# fetch plugin already installed
installed_plugins = manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
# at most 64 plugins one batch
for i in range(0, len(plugin_ids), 64):
batch_plugin_ids = list(plugin_ids.keys())[i : i + 64]
batch_plugin_identifiers = [
plugin_ids[plugin_id]
for plugin_id in batch_plugin_ids
if plugin_id not in installed_plugins_ids and plugin_id in plugin_ids
]
PluginService.install_from_marketplace_pkg(tenant_id, batch_plugin_identifiers)
total_success_tenant += 1
except Exception:
logger.exception("Failed to install plugins for tenant %s", tenant_id)
total_failed_tenant += 1
page = 1
total_success_tenant = 0
total_failed_tenant = 0
while True:
# paginate
tenants = db.paginate(db.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
if tenants.items is None or len(tenants.items) == 0:
break
for tenant in tenants:
tenant_id = tenant.id
# get plugin unique identifier
thread_pool.submit(
install,
tenant_id,
plugins.get("plugins", {}),
total_success_tenant,
total_failed_tenant,
)
page += 1
thread_pool.shutdown(wait=True)
# uninstall all the plugins for fake tenant
try:
installation = manager.list_plugins(fake_tenant_id)
while installation:
for plugin in installation:
manager.uninstall(fake_tenant_id, plugin.installation_id)
installation = manager.list_plugins(fake_tenant_id)
except Exception:
logger.exception("Failed to get installation for tenant %s", fake_tenant_id)
Path(output_file).write_text(
json.dumps(
{
"total_success_tenant": total_success_tenant,
"total_failed_tenant": total_failed_tenant,
"plugin_install_failed": plugin_install_failed,
}
)
)
@classmethod
def handle_plugin_instance_install(
cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]

View File

@ -1,7 +1,6 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from typing import Optional
from pydantic import BaseModel
from yarl import URL
@ -12,7 +11,6 @@ from core.helper.download import download_with_size_limit
from core.helper.marketplace import download_plugin_pkg
from core.plugin.entities.bundle import PluginBundleDependency
from core.plugin.entities.plugin import (
GenericProviderID,
PluginDeclaration,
PluginEntity,
PluginInstallation,
@ -28,6 +26,7 @@ from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_redis import redis_client
from models.provider_ids import GenericProviderID
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope
@ -47,11 +46,11 @@ class PluginService:
REDIS_TTL = 60 * 5 # 5 minutes
@staticmethod
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
"""
Fetch the latest plugin version
"""
result: dict[str, Optional[PluginService.LatestPluginCache]] = {}
result: dict[str, PluginService.LatestPluginCache | None] = {}
try:
cache_not_exists = []
@ -110,7 +109,7 @@ class PluginService:
raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only")
@staticmethod
def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]):
def _check_plugin_installation_scope(plugin_verification: PluginVerification | None):
"""
Check the plugin installation scope
"""
@ -145,7 +144,7 @@ class PluginService:
return manager.get_debugging_key(tenant_id)
@staticmethod
def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
"""
List the latest versions of the plugins
"""

View File

@ -0,0 +1,22 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel
class DatasourceNodeRunApiEntity(BaseModel):
pipeline_id: str
node_id: str
inputs: dict[str, Any]
datasource_type: str
credential_id: str | None = None
is_published: bool
class PipelineRunApiEntity(BaseModel):
inputs: Mapping[str, Any]
datasource_type: str
datasource_info_list: list[Mapping[str, Any]]
start_node_id: str
is_published: bool
response_mode: str
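
Editor's note: an illustrative construction of the new request entities (not part of the diff); all field values are hypothetical.

run_args = PipelineRunApiEntity(
    inputs={"query": "index this document"},
    datasource_type="upload_file",                    # hypothetical datasource type
    datasource_info_list=[{"file_id": "file-uuid"}],
    start_node_id="datasource-node-1",
    is_published=True,
    response_mode="streaming",
)
print(run_args.model_dump())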

View File

@ -0,0 +1,115 @@
from collections.abc import Mapping
from typing import Any, Union
from configs import dify_config
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.dataset import Document, Pipeline
from models.model import Account, App, EndUser
from models.workflow import Workflow
from services.rag_pipeline.rag_pipeline import RagPipelineService
class PipelineGenerateService:
@classmethod
def generate(
cls,
pipeline: Pipeline,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
"""
Pipeline Content Generate
:param pipeline: pipeline
:param user: user
:param args: args
:param invoke_from: invoke from
:param streaming: streaming
:return:
"""
try:
workflow = cls._get_workflow(pipeline, invoke_from)
if original_document_id := args.get("original_document_id"):
# update document status to waiting
cls.update_document_status(original_document_id)
return PipelineGenerator.convert_to_event_stream(
PipelineGenerator().generate(
pipeline=pipeline,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
call_depth=0,
workflow_thread_pool_id=None,
),
)
except Exception:
raise
@staticmethod
def _get_max_active_requests(app_model: App) -> int:
max_active_requests = app_model.max_active_requests
if max_active_requests is None:
max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
return max_active_requests
@classmethod
def generate_single_iteration(
cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True
):
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
return PipelineGenerator.convert_to_event_stream(
PipelineGenerator().single_iteration_generate(
pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
@classmethod
def generate_single_loop(cls, pipeline: Pipeline, user: Account, node_id: str, args: Any, streaming: bool = True):
workflow = cls._get_workflow(pipeline, InvokeFrom.DEBUGGER)
return PipelineGenerator.convert_to_event_stream(
PipelineGenerator().single_loop_generate(
pipeline=pipeline, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming
)
)
@classmethod
def _get_workflow(cls, pipeline: Pipeline, invoke_from: InvokeFrom) -> Workflow:
"""
Get workflow
:param pipeline: pipeline
:param invoke_from: invoke from
:return:
"""
rag_pipeline_service = RagPipelineService()
if invoke_from == InvokeFrom.DEBUGGER:
# fetch draft workflow by app_model
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not initialized")
else:
# fetch published workflow by app_model
workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not published")
return workflow
@classmethod
def update_document_status(cls, document_id: str):
"""
Update document status to waiting
:param document_id: document id
"""
document = db.session.query(Document).where(Document.id == document_id).first()
if document:
document.indexing_status = "waiting"
db.session.add(document)
db.session.commit()
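
Editor's note: a minimal invocation sketch (not part of the diff). It assumes `pipeline` and `account` are already loaded and that the args mapping mirrors PipelineRunApiEntity; the values are hypothetical.

from core.app.entities.app_invoke_entities import InvokeFrom

event_stream = PipelineGenerateService.generate(
    pipeline=pipeline,
    user=account,
    args={
        "inputs": {},
        "datasource_type": "upload_file",
        "datasource_info_list": [{"file_id": "file-uuid"}],
        "start_node_id": "datasource-node-1",
    },
    invoke_from=InvokeFrom.DEBUGGER,  # DEBUGGER uses the draft workflow; other InvokeFrom values use the published one
    streaming=True,
)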

View File

@ -0,0 +1,63 @@
import json
from os import path
from pathlib import Path
from flask import current_app
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json
"""
builtin_data: dict | None = None
def get_type(self) -> str:
return PipelineTemplateType.BUILTIN
def get_pipeline_templates(self, language: str) -> dict:
result = self.fetch_pipeline_templates_from_builtin(language)
return result
def get_pipeline_template_detail(self, template_id: str):
result = self.fetch_pipeline_template_detail_from_builtin(template_id)
return result
@classmethod
def _get_builtin_data(cls) -> dict:
"""
Get builtin data.
:return:
"""
if cls.builtin_data:
return cls.builtin_data
root_path = current_app.root_path
cls.builtin_data = json.loads(
Path(path.join(root_path, "constants", "pipeline_templates.json")).read_text(encoding="utf-8")
)
return cls.builtin_data or {}
@classmethod
def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict:
"""
Fetch pipeline templates from builtin.
:param language: language
:return:
"""
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(language, {})
@classmethod
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from builtin.
:param template_id: Template ID
:return:
"""
builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(template_id)
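
Editor's note: a small usage sketch (not part of the diff). The listing path assumes constants/pipeline_templates.json nests templates by language under a top-level "pipeline_templates" key; the language code and template id below are hypothetical.

retrieval = BuiltInPipelineTemplateRetrieval()
templates = retrieval.get_pipeline_templates("en-US")            # templates defined for that language
detail = retrieval.get_pipeline_template_detail("template-id")   # hypothetical template id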

View File

@ -0,0 +1,81 @@
import yaml
from flask_login import current_user
from extensions.ext_database import db
from models.dataset import PipelineCustomizedTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval recommended app from database
"""
def get_pipeline_templates(self, language: str) -> dict:
result = self.fetch_pipeline_templates_from_customized(
tenant_id=current_user.current_tenant_id, language=language
)
return result
def get_pipeline_template_detail(self, template_id: str):
result = self.fetch_pipeline_template_detail_from_db(template_id)
return result
def get_type(self) -> str:
return PipelineTemplateType.CUSTOMIZED
@classmethod
def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict:
"""
Fetch pipeline templates from db.
:param tenant_id: tenant id
:param language: language
:return:
"""
pipeline_customized_templates = (
db.session.query(PipelineCustomizedTemplate)
.where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
.order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
.all()
)
recommended_pipelines_results = []
for pipeline_customized_template in pipeline_customized_templates:
recommended_pipeline_result = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,
"description": pipeline_customized_template.description,
"icon": pipeline_customized_template.icon,
"position": pipeline_customized_template.position,
"chunk_structure": pipeline_customized_template.chunk_structure,
}
recommended_pipelines_results.append(recommended_pipeline_result)
return {"pipeline_templates": recommended_pipelines_results}
@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from db.
:param template_id: Template ID
:return:
"""
pipeline_template = (
db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)
if not pipeline_template:
return None
dsl_data = yaml.safe_load(pipeline_template.yaml_content)
graph_data = dsl_data.get("workflow", {}).get("graph", {})
return {
"id": pipeline_template.id,
"name": pipeline_template.name,
"icon_info": pipeline_template.icon,
"description": pipeline_template.description,
"chunk_structure": pipeline_template.chunk_structure,
"export_data": pipeline_template.yaml_content,
"graph": graph_data,
"created_by": pipeline_template.created_user_name,
}

View File

@ -0,0 +1,78 @@
import yaml
from extensions.ext_database import db
from models.dataset import PipelineBuiltInTemplate
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval pipeline template from database
"""
def get_pipeline_templates(self, language: str) -> dict:
result = self.fetch_pipeline_templates_from_db(language)
return result
def get_pipeline_template_detail(self, template_id: str):
result = self.fetch_pipeline_template_detail_from_db(template_id)
return result
def get_type(self) -> str:
return PipelineTemplateType.DATABASE
@classmethod
def fetch_pipeline_templates_from_db(cls, language: str) -> dict:
"""
Fetch pipeline templates from db.
:param language: language
:return:
"""
pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all()
)
recommended_pipelines_results = []
for pipeline_built_in_template in pipeline_built_in_templates:
recommended_pipeline_result = {
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,
"description": pipeline_built_in_template.description,
"icon": pipeline_built_in_template.icon,
"copyright": pipeline_built_in_template.copyright,
"privacy_policy": pipeline_built_in_template.privacy_policy,
"position": pipeline_built_in_template.position,
"chunk_structure": pipeline_built_in_template.chunk_structure,
}
recommended_pipelines_results.append(recommended_pipeline_result)
return {"pipeline_templates": recommended_pipelines_results}
@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from db.
:param template_id: Template ID
:return:
"""
# is in public recommended list
pipeline_template = (
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first()
)
if not pipeline_template:
return None
dsl_data = yaml.safe_load(pipeline_template.yaml_content)
graph_data = dsl_data.get("workflow", {}).get("graph", {})
return {
"id": pipeline_template.id,
"name": pipeline_template.name,
"icon_info": pipeline_template.icon,
"description": pipeline_template.description,
"chunk_structure": pipeline_template.chunk_structure,
"export_data": pipeline_template.yaml_content,
"graph": graph_data,
"created_by": pipeline_template.created_user_name,
}

View File

@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
class PipelineTemplateRetrievalBase(ABC):
"""Interface for pipeline template retrieval."""
@abstractmethod
def get_pipeline_templates(self, language: str) -> dict:
raise NotImplementedError
@abstractmethod
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
raise NotImplementedError
@abstractmethod
def get_type(self) -> str:
raise NotImplementedError

View File

@ -0,0 +1,26 @@
from services.rag_pipeline.pipeline_template.built_in.built_in_retrieval import BuiltInPipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.customized.customized_retrieval import CustomizedPipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
from services.rag_pipeline.pipeline_template.remote.remote_retrieval import RemotePipelineTemplateRetrieval
class PipelineTemplateRetrievalFactory:
@staticmethod
def get_pipeline_template_factory(mode: str) -> type[PipelineTemplateRetrievalBase]:
match mode:
case PipelineTemplateType.REMOTE:
return RemotePipelineTemplateRetrieval
case PipelineTemplateType.CUSTOMIZED:
return CustomizedPipelineTemplateRetrieval
case PipelineTemplateType.DATABASE:
return DatabasePipelineTemplateRetrieval
case PipelineTemplateType.BUILTIN:
return BuiltInPipelineTemplateRetrieval
case _:
raise ValueError(f"invalid fetch recommended apps mode: {mode}")
@staticmethod
def get_built_in_pipeline_template_retrieval():
return BuiltInPipelineTemplateRetrieval
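
Editor's note: an illustrative sketch (not part of the diff) of resolving a retrieval backend through the factory; the mode and language values are hypothetical.

retrieval_cls = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(PipelineTemplateType.DATABASE)
templates = retrieval_cls().get_pipeline_templates("en-US")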

View File

@ -0,0 +1,8 @@
from enum import StrEnum
class PipelineTemplateType(StrEnum):
REMOTE = "remote"
DATABASE = "database"
CUSTOMIZED = "customized"
BUILTIN = "builtin"

View File

@ -0,0 +1,67 @@
import logging
import requests
from configs import dify_config
from services.rag_pipeline.pipeline_template.database.database_retrieval import DatabasePipelineTemplateRetrieval
from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
logger = logging.getLogger(__name__)
class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval recommended app from dify official
"""
def get_pipeline_template_detail(self, template_id: str):
try:
result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
except Exception as e:
logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e)
result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
return result
def get_pipeline_templates(self, language: str) -> dict:
try:
result = self.fetch_pipeline_templates_from_dify_official(language)
except Exception as e:
logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e)
result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language)
return result
def get_type(self) -> str:
return PipelineTemplateType.REMOTE
@classmethod
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from dify official.
:param template_id: Template ID
:return:
"""
domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN
url = f"{domain}/pipeline-templates/{template_id}"
response = requests.get(url, timeout=(3, 10))
if response.status_code != 200:
return None
data: dict = response.json()
return data
@classmethod
def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict:
"""
Fetch pipeline templates from dify official.
:param language: language
:return:
"""
domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN
url = f"{domain}/pipeline-templates?language={language}"
response = requests.get(url, timeout=(3, 10))
if response.status_code != 200:
raise ValueError(f"fetch pipeline templates failed, status code: {response.status_code}")
result: dict = response.json()
return result

File diff suppressed because it is too large

View File

@ -0,0 +1,944 @@
import base64
import hashlib
import json
import logging
import uuid
from collections.abc import Mapping
from datetime import UTC, datetime
from enum import StrEnum
from typing import cast
from urllib.parse import urlparse
from uuid import uuid4
import yaml # type: ignore
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from flask_login import current_user
from packaging import version
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.helper import ssrf_proxy
from core.helper.name_generator import generate_incremental_name
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import PluginDependency
from core.workflow.enums import NodeType
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from extensions.ext_redis import redis_client
from factories import variable_factory
from models import Account
from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import (
IconInfo,
KnowledgeConfiguration,
RagPipelineDatasetCreateEntity,
)
from services.plugin.dependencies_analysis import DependenciesAnalysisService
logger = logging.getLogger(__name__)
IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.1.0"
class ImportMode(StrEnum):
YAML_CONTENT = "yaml-content"
YAML_URL = "yaml-url"
class ImportStatus(StrEnum):
COMPLETED = "completed"
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
PENDING = "pending"
FAILED = "failed"
class RagPipelineImportInfo(BaseModel):
id: str
status: ImportStatus
pipeline_id: str | None = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
error: str = ""
dataset_id: str | None = None
class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison"""
try:
current_ver = version.parse(CURRENT_DSL_VERSION)
imported_ver = version.parse(imported_version)
except version.InvalidVersion:
return ImportStatus.FAILED
# If imported version is newer than current, always return PENDING
if imported_ver > current_ver:
return ImportStatus.PENDING
# If imported version is older than current's major, return PENDING
if imported_ver.major < current_ver.major:
return ImportStatus.PENDING
# If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS
if imported_ver.minor < current_ver.minor:
return ImportStatus.COMPLETED_WITH_WARNINGS
# Otherwise (same major and minor), return COMPLETED
return ImportStatus.COMPLETED
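# Illustrative sketch (not part of the service): how the version comparison above maps
# onto import statuses, assuming CURRENT_DSL_VERSION stays at "0.1.0".
def _demo_check_version_compatibility() -> None:
    assert _check_version_compatibility("0.2.0") == ImportStatus.PENDING  # newer than current
    assert _check_version_compatibility("0.0.9") == ImportStatus.COMPLETED_WITH_WARNINGS  # older minor version
    assert _check_version_compatibility("0.1.0") == ImportStatus.COMPLETED  # same version
    assert _check_version_compatibility("not-a-version") == ImportStatus.FAILED  # unparsable version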
class RagPipelinePendingData(BaseModel):
import_mode: str
yaml_content: str
pipeline_id: str | None
class CheckDependenciesPendingData(BaseModel):
dependencies: list[PluginDependency]
pipeline_id: str | None
class RagPipelineDslService:
def __init__(self, session: Session):
self._session = session
def import_rag_pipeline(
self,
*,
account: Account,
import_mode: str,
yaml_content: str | None = None,
yaml_url: str | None = None,
pipeline_id: str | None = None,
dataset: Dataset | None = None,
dataset_name: str | None = None,
icon_info: IconInfo | None = None,
) -> RagPipelineImportInfo:
"""Import an app from YAML content or URL."""
import_id = str(uuid.uuid4())
# Validate import mode
try:
mode = ImportMode(import_mode)
except ValueError:
raise ValueError(f"Invalid import_mode: {import_mode}")
# Get YAML content
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_url is required when import_mode is yaml-url",
)
try:
parsed_url = urlparse(yaml_url)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status()
content = response.content.decode()
if len(content) > DSL_MAX_SIZE:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="File size exceeds the limit of 10MB",
)
if not content:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Empty content from url",
)
except Exception as e:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Error fetching YAML from URL: {str(e)}",
)
elif mode == ImportMode.YAML_CONTENT:
if not yaml_content:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="yaml_content is required when import_mode is yaml-content",
)
content = yaml_content
# Process YAML content
try:
# Parse YAML to validate format
data = yaml.safe_load(content)
if not isinstance(data, dict):
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid YAML format: content must be a mapping",
)
# Validate and fix DSL version
if not data.get("version"):
data["version"] = "0.1.0"
if not data.get("kind") or data.get("kind") != "rag_pipeline":
data["kind"] = "rag_pipeline"
imported_version = data.get("version", "0.1.0")
# version must be a string; YAML may parse bare numbers such as 0.1 as floats
if not isinstance(imported_version, str):
raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}")
status = _check_version_compatibility(imported_version)
# Extract pipeline data
pipeline_data = data.get("rag_pipeline")
if not pipeline_data:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Missing rag_pipeline data in YAML content",
)
# If pipeline_id is provided, check that the pipeline exists
pipeline = None
if pipeline_id:
stmt = select(Pipeline).where(
Pipeline.id == pipeline_id,
Pipeline.tenant_id == account.current_tenant_id,
)
pipeline = self._session.scalar(stmt)
if not pipeline:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Pipeline not found",
)
dataset = pipeline.retrieve_dataset(session=self._session)
# If major version mismatch, store import info in Redis
if status == ImportStatus.PENDING:
pending_data = RagPipelinePendingData(
import_mode=import_mode,
yaml_content=content,
pipeline_id=pipeline_id,
)
redis_client.setex(
f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
IMPORT_INFO_REDIS_EXPIRY,
pending_data.model_dump_json(),
)
return RagPipelineImportInfo(
id=import_id,
status=status,
pipeline_id=pipeline_id,
imported_dsl_version=imported_version,
)
# Extract dependencies
dependencies = data.get("dependencies", [])
check_dependencies_pending_data = None
if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
# Create or update pipeline
pipeline = self._create_or_update_pipeline(
pipeline=pipeline,
data=data,
account=account,
dependencies=check_dependencies_pending_data,
)
# create dataset
name = pipeline.name or "Untitled"
description = pipeline.description
if icon_info:
icon_type = icon_info.icon_type
icon = icon_info.icon
icon_background = icon_info.icon_background
icon_url = icon_info.icon_url
else:
icon_type = data.get("rag_pipeline", {}).get("icon_type")
icon = data.get("rag_pipeline", {}).get("icon")
icon_background = data.get("rag_pipeline", {}).get("icon_background")
icon_url = data.get("rag_pipeline", {}).get("icon_url")
workflow = data.get("workflow", {})
graph = workflow.get("graph", {})
nodes = graph.get("nodes", [])
dataset_id = None
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
if (
dataset
and pipeline.is_published
and dataset.chunk_structure != knowledge_configuration.chunk_structure
):
raise ValueError("Chunk structure is not compatible with the published pipeline")
if not dataset:
datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
names = [dataset.name for dataset in datasets]
generate_name = generate_incremental_name(names, name)
dataset = Dataset(
tenant_id=account.current_tenant_id,
name=generate_name,
description=description,
icon_info={
"icon_type": icon_type,
"icon": icon,
"icon_background": icon_background,
"icon_url": icon_url,
},
indexing_technique=knowledge_configuration.indexing_technique,
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode="rag_pipeline",
chunk_structure=knowledge_configuration.chunk_structure,
)
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
.first()
)
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=knowledge_configuration.embedding_model_provider,
model_name=knowledge_configuration.embedding_model,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
)
self._session.add(dataset_collection_binding)
self._session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id
self._session.add(dataset)
self._session.commit()
dataset_id = dataset.id
if not dataset_id:
raise ValueError("DSL is not valid, please check the Knowledge Index node.")
return RagPipelineImportInfo(
id=import_id,
status=status,
pipeline_id=pipeline.id,
dataset_id=dataset_id,
imported_dsl_version=imported_version,
)
except yaml.YAMLError as e:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=f"Invalid YAML format: {str(e)}",
)
except Exception as e:
logger.exception("Failed to import app")
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
def confirm_import(self, *, import_id: str, account: Account) -> RagPipelineImportInfo:
"""
Confirm an import that requires confirmation
"""
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
pending_data = redis_client.get(redis_key)
if not pending_data:
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Import information expired or does not exist",
)
try:
if not isinstance(pending_data, str | bytes):
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error="Invalid import information",
)
pending_data = RagPipelinePendingData.model_validate_json(pending_data)
data = yaml.safe_load(pending_data.yaml_content)
pipeline = None
if pending_data.pipeline_id:
stmt = select(Pipeline).where(
Pipeline.id == pending_data.pipeline_id,
Pipeline.tenant_id == account.current_tenant_id,
)
pipeline = self._session.scalar(stmt)
# Create or update pipeline
pipeline = self._create_or_update_pipeline(
pipeline=pipeline,
data=data,
account=account,
)
dataset = pipeline.retrieve_dataset(session=self._session)
# create dataset
name = pipeline.name
description = pipeline.description
icon_type = data.get("rag_pipeline", {}).get("icon_type")
icon = data.get("rag_pipeline", {}).get("icon")
icon_background = data.get("rag_pipeline", {}).get("icon_background")
icon_url = data.get("rag_pipeline", {}).get("icon_url")
workflow = data.get("workflow", {})
graph = workflow.get("graph", {})
nodes = graph.get("nodes", [])
dataset_id = None
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
if not dataset:
dataset = Dataset(
tenant_id=account.current_tenant_id,
name=name,
description=description,
icon_info={
"icon_type": icon_type,
"icon": icon,
"icon_background": icon_background,
"icon_url": icon_url,
},
indexing_technique=knowledge_configuration.indexing_technique,
created_by=account.id,
retrieval_model=knowledge_configuration.retrieval_model.model_dump(),
runtime_mode="rag_pipeline",
chunk_structure=knowledge_configuration.chunk_structure,
)
else:
dataset.indexing_technique = knowledge_configuration.indexing_technique
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
dataset.runtime_mode = "rag_pipeline"
dataset.chunk_structure = knowledge_configuration.chunk_structure
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.where(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
DatasetCollectionBinding.type == "dataset",
)
.order_by(DatasetCollectionBinding.created_at)
.first()
)
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=knowledge_configuration.embedding_model_provider,
model_name=knowledge_configuration.embedding_model,
collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())),
type="dataset",
)
self._session.add(dataset_collection_binding)
self._session.commit()
dataset_collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding_id
dataset.embedding_model = knowledge_configuration.embedding_model
dataset.embedding_model_provider = knowledge_configuration.embedding_model_provider
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.pipeline_id = pipeline.id
self._session.add(dataset)
self._session.commit()
dataset_id = dataset.id
if not dataset_id:
raise ValueError("DSL is not valid, please check the Knowledge Index node.")
# Delete import info from Redis
redis_client.delete(redis_key)
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.COMPLETED,
pipeline_id=pipeline.id,
dataset_id=dataset_id,
current_dsl_version=CURRENT_DSL_VERSION,
imported_dsl_version=data.get("version", "0.1.0"),
)
except Exception as e:
logger.exception("Error confirming import")
return RagPipelineImportInfo(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)
def check_dependencies(
self,
*,
pipeline: Pipeline,
) -> CheckDependenciesResult:
"""Check dependencies"""
# Get dependencies from Redis
redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}"
dependencies = redis_client.get(redis_key)
if not dependencies:
return CheckDependenciesResult()
# Extract dependencies
dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)
# Get leaked dependencies
leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
tenant_id=pipeline.tenant_id, dependencies=dependencies.dependencies
)
return CheckDependenciesResult(
leaked_dependencies=leaked_dependencies,
)
def _create_or_update_pipeline(
self,
*,
pipeline: Pipeline | None,
data: dict,
account: Account,
dependencies: list[PluginDependency] | None = None,
) -> Pipeline:
"""Create a new app or update an existing one."""
if not account.current_tenant_id:
raise ValueError("Tenant id is required")
pipeline_data = data.get("rag_pipeline", {})
# Extract and validate the workflow definition
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for rag pipeline")
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (
decrypted_id := self.decrypt_dataset_id(
encrypted_data=dataset_id,
tenant_id=account.current_tenant_id,
)
)
]
if pipeline:
# Update existing pipeline
pipeline.name = pipeline_data.get("name", pipeline.name)
pipeline.description = pipeline_data.get("description", pipeline.description)
pipeline.updated_by = account.id
else:
if account.current_tenant_id is None:
raise ValueError("Current tenant is not set")
# Create new pipeline
pipeline = Pipeline()
pipeline.id = str(uuid4())
pipeline.tenant_id = account.current_tenant_id
pipeline.name = pipeline_data.get("name", "")
pipeline.description = pipeline_data.get("description", "")
pipeline.created_by = account.id
pipeline.updated_by = account.id
self._session.add(pipeline)
self._session.commit()
# save dependencies
if dependencies:
redis_client.setex(
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}",
IMPORT_INFO_REDIS_EXPIRY,
CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
)
workflow = (
self._session.query(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
)
# create draft workflow if not found
if not workflow:
workflow = Workflow(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
features="{}",
type=WorkflowType.RAG_PIPELINE.value,
version="draft",
graph=json.dumps(graph),
created_by=account.id,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables_list,
)
self._session.add(workflow)
self._session.flush()
pipeline.workflow_id = workflow.id
else:
workflow.graph = json.dumps(graph)
workflow.updated_by = account.id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
workflow.rag_pipeline_variables = rag_pipeline_variables_list
# commit db session changes
self._session.commit()
return pipeline
def export_rag_pipeline_dsl(self, pipeline: Pipeline, include_secret: bool = False) -> str:
"""
Export pipeline
:param pipeline: Pipeline instance
:param include_secret: Whether include secret variable
:return:
"""
dataset = pipeline.retrieve_dataset(session=self._session)
if not dataset:
raise ValueError("Missing dataset for rag pipeline")
icon_info = dataset.icon_info
export_data = {
"version": CURRENT_DSL_VERSION,
"kind": "rag_pipeline",
"rag_pipeline": {
"name": dataset.name,
"icon": icon_info.get("icon", "📙") if icon_info else "📙",
"icon_type": icon_info.get("icon_type", "emoji") if icon_info else "emoji",
"icon_background": icon_info.get("icon_background", "#FFEAD5") if icon_info else "#FFEAD5",
"icon_url": icon_info.get("icon_url") if icon_info else None,
"description": pipeline.description,
},
}
self._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)
return yaml.dump(export_data, allow_unicode=True) # type: ignore
def _append_workflow_export_data(self, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
"""
Append workflow export data
:param export_data: export data
:param pipeline: Pipeline instance
"""
workflow = (
self._session.query(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
)
if not workflow:
raise ValueError("Missing draft workflow configuration, please check.")
workflow_dict = workflow.to_dict(include_secret=include_secret)
for node in workflow_dict.get("graph", {}).get("nodes", []):
node_data = node.get("data", {})
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node_data.get("dataset_ids", [])
node["data"]["dataset_ids"] = [
self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
for dataset_id in dataset_ids
]
# filter credential id from tool node
if not include_secret and data_type == NodeType.TOOL.value:
node_data.pop("credential_id", None)
# filter credential id from agent node
if not include_secret and data_type == NodeType.AGENT.value:
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
tool.pop("credential_id", None)
export_data["workflow"] = workflow_dict
dependencies = self._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
for d in DependenciesAnalysisService.generate_dependencies(
tenant_id=pipeline.tenant_id, dependencies=dependencies
)
]
def _extract_dependencies_from_workflow(self, workflow: Workflow) -> list[str]:
"""
Extract dependencies from workflow
:param workflow: Workflow instance
:return: dependencies list format like ["langgenius/google"]
"""
graph = workflow.graph_dict
dependencies = self._extract_dependencies_from_workflow_graph(graph)
return dependencies
def _extract_dependencies_from_workflow_graph(self, graph: Mapping) -> list[str]:
"""
Extract dependencies from workflow graph
:param graph: Workflow graph
:return: dependencies list format like ["langgenius/google"]
"""
dependencies = []
for node in graph.get("nodes", []):
try:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL.value:
tool_entity = ToolNodeData(**node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.DATASOURCE.value:
datasource_entity = DatasourceNodeData(**node["data"])
if datasource_entity.provider_type != "local_file":
dependencies.append(datasource_entity.plugin_id)
case NodeType.LLM.value:
llm_entity = LLMNodeData(**node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER.value:
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR.value:
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_INDEX.value:
knowledge_index_entity = KnowledgeConfiguration(**node["data"])
if knowledge_index_entity.indexing_technique == "high_quality":
if knowledge_index_entity.embedding_model_provider:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
knowledge_index_entity.embedding_model_provider
),
)
if knowledge_index_entity.retrieval_model.reranking_mode == "reranking_model":
if knowledge_index_entity.retrieval_model.reranking_enable:
if (
knowledge_index_entity.retrieval_model.reranking_model
and knowledge_index_entity.retrieval_model.reranking_mode == "reranking_model"
):
if knowledge_index_entity.retrieval_model.reranking_model.reranking_provider_name:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
knowledge_index_entity.retrieval_model.reranking_model.reranking_provider_name
),
)
case NodeType.KNOWLEDGE_RETRIEVAL.value:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:
if (
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
== "reranking_model"
):
if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider
),
)
elif (
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
== "weighted_score"
):
if knowledge_retrieval_entity.multiple_retrieval_config.weights:
vector_setting = (
knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting
)
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
vector_setting.embedding_provider_name
),
)
elif knowledge_retrieval_entity.retrieval_mode == "single":
model_config = knowledge_retrieval_entity.single_retrieval_config
if model_config:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
model_config.model.provider
),
)
case _:
# TODO: Handle default case or unknown node types
pass
except Exception as e:
logger.exception("Error extracting node dependency", exc_info=e)
return dependencies
@classmethod
def _extract_dependencies_from_model_config(cls, model_config: Mapping) -> list[str]:
"""
Extract dependencies from model config
:param model_config: model config dict
:return: dependencies list format like ["langgenius/google"]
"""
dependencies = []
try:
# completion model
model_dict = model_config.get("model", {})
if model_dict:
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", ""))
)
# reranking model
dataset_configs = model_config.get("dataset_configs", {})
if dataset_configs:
for dataset_config in dataset_configs.get("datasets", {}).get("datasets", []):
if dataset_config.get("reranking_model"):
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
dataset_config.get("reranking_model", {})
.get("reranking_provider_name", {})
.get("provider")
)
)
# tools
agent_configs = model_config.get("agent_mode", {})
if agent_configs:
for agent_config in agent_configs.get("tools", []):
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(agent_config.get("provider_id"))
)
except Exception as e:
logger.exception("Error extracting model config dependency", exc_info=e)
return dependencies
@classmethod
def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
"""
Returns the leaked dependencies in current workspace
"""
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
if not dependencies:
return []
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
def _generate_aes_key(self, tenant_id: str) -> bytes:
"""Generate AES key based on tenant_id"""
return hashlib.sha256(tenant_id.encode()).digest()
def encrypt_dataset_id(self, dataset_id: str, tenant_id: str) -> str:
"""Encrypt dataset_id using AES-CBC mode"""
key = self._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
return base64.b64encode(ct_bytes).decode()
def decrypt_dataset_id(self, encrypted_data: str, tenant_id: str) -> str | None:
"""AES decryption"""
try:
key = self._generate_aes_key(tenant_id)
iv = key[:16]
cipher = AES.new(key, AES.MODE_CBC, iv)
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
return pt.decode()
except Exception:
return None
def create_rag_pipeline_dataset(
self,
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
if rag_pipeline_dataset_create_entity.name:
# check if dataset name already exists
if (
self._session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
else:
# generate a random name as Untitled 1 2 3 ...
datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all()
names = [dataset.name for dataset in datasets]
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
names,
"Untitled",
)
account = cast(Account, current_user)
rag_pipeline_import_info: RagPipelineImportInfo = self.import_rag_pipeline(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=rag_pipeline_dataset_create_entity.yaml_content,
dataset=None,
dataset_name=rag_pipeline_dataset_create_entity.name,
icon_info=rag_pipeline_dataset_create_entity.icon_info,
)
return {
"id": rag_pipeline_import_info.id,
"dataset_id": rag_pipeline_import_info.dataset_id,
"pipeline_id": rag_pipeline_import_info.pipeline_id,
"status": rag_pipeline_import_info.status,
"imported_dsl_version": rag_pipeline_import_info.imported_dsl_version,
"current_dsl_version": rag_pipeline_import_info.current_dsl_version,
"error": rag_pipeline_import_info.error,
}
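# Illustrative sketch (not part of the service): dataset IDs embedded in an exported DSL
# are encrypted with a tenant-derived AES key, so a round trip within the same tenant
# recovers the original ID. Assumes `session` is an open SQLAlchemy Session and the IDs
# below are placeholders.
def _example_dataset_id_round_trip(session: Session) -> None:
    service = RagPipelineDslService(session)
    token = service.encrypt_dataset_id(dataset_id="example-dataset-id", tenant_id="example-tenant-id")
    assert service.decrypt_dataset_id(encrypted_data=token, tenant_id="example-tenant-id") == "example-dataset-id"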

View File

@ -0,0 +1,23 @@
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
from core.plugin.impl.datasource import PluginDatasourceManager
from services.datasource_provider_service import DatasourceProviderService
class RagPipelineManageService:
@staticmethod
def list_rag_pipeline_datasources(tenant_id: str) -> list[PluginDatasourceProviderEntity]:
"""
list rag pipeline datasources
"""
# get all builtin providers
manager = PluginDatasourceManager()
datasources = manager.fetch_datasource_providers(tenant_id)
for datasource in datasources:
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
if credentials:
datasource.is_authorized = True
return datasources

View File

@ -0,0 +1,386 @@
import json
import logging
from datetime import UTC, datetime
from pathlib import Path
from uuid import uuid4
import yaml
from flask_login import current_user
from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from factories import variable_factory
from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline
from models.model import UploadFile
from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
class RagPipelineTransformService:
def transform_dataset(self, dataset_id: str):
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline":
return {
"pipeline_id": dataset.pipeline_id,
"dataset_id": dataset_id,
"status": "success",
}
if dataset.provider != "vendor":
raise ValueError("External dataset is not supported")
datasource_type = dataset.data_source_type
indexing_technique = dataset.indexing_technique
if not datasource_type and not indexing_technique:
return self._transform_to_empty_pipeline(dataset)
doc_form = dataset.doc_form
if not doc_form:
return self._transform_to_empty_pipeline(dataset)
retrieval_model = dataset.retrieval_model
pipeline_yaml = self._get_transform_yaml(doc_form, datasource_type, indexing_technique)
# install any missing plugin dependencies
self._deal_dependencies(pipeline_yaml, dataset.tenant_id)
# Extract workflow data
workflow_data = pipeline_yaml.get("workflow")
if not workflow_data:
raise ValueError("Missing workflow data for rag pipeline")
graph = workflow_data.get("graph", {})
nodes = graph.get("nodes", [])
new_nodes = []
for node in nodes:
if (
node.get("data", {}).get("type") == "datasource"
and node.get("data", {}).get("provider_type") == "local_file"
):
node = self._deal_file_extensions(node)
if node.get("data", {}).get("type") == "knowledge-index":
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
new_nodes.append(node)
if new_nodes:
graph["nodes"] = new_nodes
workflow_data["graph"] = graph
pipeline_yaml["workflow"] = workflow_data
# create pipeline
pipeline = self._create_pipeline(pipeline_yaml)
# save chunk structure to dataset
if doc_form == "hierarchical_model":
dataset.chunk_structure = "hierarchical_model"
elif doc_form == "text_model":
dataset.chunk_structure = "text_model"
else:
raise ValueError("Unsupported doc form")
dataset.runtime_mode = "rag_pipeline"
dataset.pipeline_id = pipeline.id
# migrate existing document data
self._deal_document_data(dataset)
db.session.commit()
return {
"pipeline_id": pipeline.id,
"dataset_id": dataset_id,
"status": "success",
}
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
pipeline_yaml = {}
if doc_form == "text_model":
match datasource_type:
case "upload_file":
if indexing_technique == "high_quality":
# get graph from transform.file-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
# get graph from transform.file-general-economy.yml
with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "notion_import":
if indexing_technique == "high_quality":
# get graph from transform.notion-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
# get graph from transform.notion-general-economy.yml
with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "website_crawl":
if indexing_technique == "high_quality":
# get graph from transform.website-crawl-general-high-quality.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f:
pipeline_yaml = yaml.safe_load(f)
if indexing_technique == "economy":
# get graph from transform.website-crawl-general-economy.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-general-economy.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case _:
raise ValueError("Unsupported datasource type")
elif doc_form == "hierarchical_model":
match datasource_type:
case "upload_file":
# get graph from transform.file-parentchild.yml
with open(f"{Path(__file__).parent}/transform/file-parentchild.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "notion_import":
# get graph from transform.notion-parentchild.yml
with open(f"{Path(__file__).parent}/transform/notion-parentchild.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case "website_crawl":
# get graph from transform.website-crawl-parentchild.yml
with open(f"{Path(__file__).parent}/transform/website-crawl-parentchild.yml") as f:
pipeline_yaml = yaml.safe_load(f)
case _:
raise ValueError("Unsupported datasource type")
else:
raise ValueError("Unsupported doc form")
return pipeline_yaml
def _deal_file_extensions(self, node: dict):
file_extensions = node.get("data", {}).get("fileExtensions", [])
if not file_extensions:
return node
file_extensions = [file_extension.lower() for file_extension in file_extensions]
node["data"]["fileExtensions"] = DOCUMENT_EXTENSIONS
return node
def _deal_knowledge_index(
self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)
if indexing_technique == "high_quality":
knowledge_configuration.embedding_model = dataset.embedding_model
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
if retrieval_model:
retrieval_setting = RetrievalSetting(**retrieval_model)
if indexing_technique == "economy":
retrieval_setting.search_method = "keyword_search"
knowledge_configuration.retrieval_model = retrieval_setting
else:
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
knowledge_configuration_dict.update(knowledge_configuration.model_dump())
node["data"] = knowledge_configuration_dict
return node
def _create_pipeline(
self,
data: dict,
) -> Pipeline:
"""Create a new app or update an existing one."""
pipeline_data = data.get("rag_pipeline", {})
# Extract and validate the workflow definition
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for rag pipeline")
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
graph = workflow_data.get("graph", {})
# Create new pipeline
pipeline = Pipeline()
pipeline.id = str(uuid4())
pipeline.tenant_id = current_user.current_tenant_id
pipeline.name = pipeline_data.get("name", "")
pipeline.description = pipeline_data.get("description", "")
pipeline.created_by = current_user.id
pipeline.updated_by = current_user.id
pipeline.is_published = True
pipeline.is_public = True
db.session.add(pipeline)
db.session.flush()
# create draft workflow
draft_workflow = Workflow(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
features="{}",
type=WorkflowType.RAG_PIPELINE.value,
version="draft",
graph=json.dumps(graph),
created_by=current_user.id,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables_list,
)
published_workflow = Workflow(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
features="{}",
type=WorkflowType.RAG_PIPELINE.value,
version=str(datetime.now(UTC).replace(tzinfo=None)),
graph=json.dumps(graph),
created_by=current_user.id,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
rag_pipeline_variables=rag_pipeline_variables_list,
)
db.session.add(draft_workflow)
db.session.add(published_workflow)
db.session.flush()
pipeline.workflow_id = published_workflow.id
db.session.add(pipeline)
return pipeline
def _deal_dependencies(self, pipeline_yaml: dict, tenant_id: str):
installer_manager = PluginInstaller()
installed_plugins = installer_manager.list_plugins(tenant_id)
plugin_migration = PluginMigration()
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
dependencies = pipeline_yaml.get("dependencies", [])
need_install_plugin_unique_identifiers = []
for dependency in dependencies:
if dependency.get("type") == "marketplace":
plugin_unique_identifier = dependency.get("value", {}).get("plugin_unique_identifier")
plugin_id = plugin_unique_identifier.split(":")[0]
if plugin_id not in installed_plugins_ids:
plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(plugin_id) # type: ignore
if plugin_unique_identifier:
need_install_plugin_unique_identifiers.append(plugin_unique_identifier)
if need_install_plugin_unique_identifiers:
logger.debug("Installing missing pipeline plugins %s", need_install_plugin_unique_identifiers)
PluginService.install_from_marketplace_pkg(tenant_id, need_install_plugin_unique_identifiers)
def _transform_to_empty_pipeline(self, dataset: Dataset):
pipeline = Pipeline(
tenant_id=dataset.tenant_id,
name=dataset.name,
description=dataset.description,
created_by=current_user.id,
)
db.session.add(pipeline)
db.session.flush()
dataset.pipeline_id = pipeline.id
dataset.runtime_mode = "rag_pipeline"
dataset.updated_by = current_user.id
dataset.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.add(dataset)
db.session.commit()
return {
"pipeline_id": pipeline.id,
"dataset_id": dataset.id,
"status": "success",
}
def _deal_document_data(self, dataset: Dataset):
file_node_id = "1752479895761"
notion_node_id = "1752489759475"
jina_node_id = "1752491761974"
firecrawl_node_id = "1752565402678"
documents = db.session.query(Document).where(Document.dataset_id == dataset.id).all()
for document in documents:
data_source_info_dict = document.data_source_info_dict
if not data_source_info_dict:
continue
if document.data_source_type == "upload_file":
document.data_source_type = "local_file"
file_id = data_source_info_dict.get("upload_file_id")
if file_id:
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
data_source_info = json.dumps(
{
"real_file_id": file_id,
"name": file.name,
"size": file.size,
"extension": file.extension,
"mime_type": file.mime_type,
"url": "",
"transfer_method": "local_file",
}
)
document.data_source_info = data_source_info
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document.id,
pipeline_id=dataset.pipeline_id,
datasource_type="local_file",
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,
created_at=document.created_at,
datasource_node_id=file_node_id,
)
db.session.add(document)
db.session.add(document_pipeline_execution_log)
elif document.data_source_type == "notion_import":
document.data_source_type = "online_document"
data_source_info = json.dumps(
{
"workspace_id": data_source_info_dict.get("notion_workspace_id"),
"page": {
"page_id": data_source_info_dict.get("notion_page_id"),
"page_name": document.name,
"page_icon": data_source_info_dict.get("notion_page_icon"),
"type": data_source_info_dict.get("type"),
"last_edited_time": data_source_info_dict.get("last_edited_time"),
"parent_id": None,
},
}
)
document.data_source_info = data_source_info
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document.id,
pipeline_id=dataset.pipeline_id,
datasource_type="online_document",
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,
created_at=document.created_at,
datasource_node_id=notion_node_id,
)
db.session.add(document)
db.session.add(document_pipeline_execution_log)
elif document.data_source_type == "website_crawl":
document.data_source_type = "website_crawl"
data_source_info = json.dumps(
{
"source_url": data_source_info_dict.get("url"),
"content": "",
"title": document.name,
"description": "",
}
)
document.data_source_info = data_source_info
if data_source_info_dict.get("provider") == "firecrawl":
datasource_node_id = firecrawl_node_id
elif data_source_info_dict.get("provider") == "jinareader":
datasource_node_id = jina_node_id
else:
continue
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document.id,
pipeline_id=dataset.pipeline_id,
datasource_type="website_crawl",
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,
created_at=document.created_at,
datasource_node_id=datasource_node_id,
)
db.session.add(document)
db.session.add(document_pipeline_execution_log)
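# Illustrative usage sketch (not part of the service): convert a legacy dataset into the
# rag-pipeline runtime. Assumes a Flask application context with `current_user` bound and
# an existing dataset ID.
def _example_transform(dataset_id: str) -> dict:
    # Expected result shape: {"pipeline_id": "...", "dataset_id": dataset_id, "status": "success"}
    return RagPipelineTransformService().transform_dataset(dataset_id)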

View File

@ -0,0 +1,709 @@
dependencies:
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/general_chunker:0.0.1@e3da408b7277866404c3f884d599261f9d0b9003ea4ef7eb3b64489bdf39d18b
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/dify_extractor:0.0.1@50103421d4e002f059b662d21ad2d7a1cf34869abdbe320299d7e382516ebb1c
kind: rag_pipeline
rag_pipeline:
description: ''
icon: 📙
icon_background: ''
icon_type: emoji
name: file-general-economy
version: 0.1.0
workflow:
conversation_variables: []
environment_variables: []
features: {}
graph:
edges:
- data:
isInIteration: false
isInLoop: false
sourceType: datasource
targetType: if-else
id: 1752479895761-source-1752481129417-target
source: '1752479895761'
sourceHandle: source
target: '1752481129417'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: if-else
targetType: tool
id: 1752481129417-24e47cad-f1e2-4f74-9884-3f49d5bb37b7-1752480460682-target
source: '1752481129417'
sourceHandle: 24e47cad-f1e2-4f74-9884-3f49d5bb37b7
target: '1752480460682'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: if-else
targetType: document-extractor
id: 1752481129417-false-1752481112180-target
source: '1752481129417'
sourceHandle: 'false'
target: '1752481112180'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: tool
targetType: variable-aggregator
id: 1752480460682-source-1752482022496-target
source: '1752480460682'
sourceHandle: source
target: '1752482022496'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: document-extractor
targetType: variable-aggregator
id: 1752481112180-source-1752482022496-target
source: '1752481112180'
sourceHandle: source
target: '1752482022496'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: variable-aggregator
targetType: tool
id: 1752482022496-source-1752482151668-target
source: '1752482022496'
sourceHandle: source
target: '1752482151668'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: tool
targetType: knowledge-index
id: 1752482151668-source-1752477924228-target
source: '1752482151668'
sourceHandle: source
target: '1752477924228'
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
chunk_structure: text_model
embedding_model: text-embedding-ada-002
embedding_model_provider: langgenius/openai/openai
index_chunk_variable_selector:
- '1752482151668'
- result
indexing_technique: economy
keyword_number: 10
retrieval_model:
score_threshold: 0.5
score_threshold_enabled: false
search_method: keyword_search
top_k: 3
vector_setting:
embedding_model_name: text-embedding-ada-002
embedding_provider_name: langgenius/openai/openai
selected: true
title: Knowledge Base
type: knowledge-index
height: 114
id: '1752477924228'
position:
x: 1076.4656678451215
y: 281.3910724383104
positionAbsolute:
x: 1076.4656678451215
y: 281.3910724383104
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
datasource_configurations: {}
datasource_label: File
datasource_name: upload-file
datasource_parameters: {}
fileExtensions:
- txt
- markdown
- mdx
- pdf
- html
- xlsx
- xls
- vtt
- properties
- doc
- docx
- csv
- eml
- msg
- pptx
- xml
- epub
- ppt
- md
plugin_id: langgenius/file
provider_name: file
provider_type: local_file
selected: false
title: File
type: datasource
height: 52
id: '1752479895761'
position:
x: -839.8603427660498
y: 251.3910724383104
positionAbsolute:
x: -839.8603427660498
y: 251.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
output_schema:
properties:
documents:
description: the documents extracted from the file
items:
type: object
type: array
images:
description: The images extracted from the file
items:
type: object
type: array
type: object
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: the file to be parsed(support pdf, ppt, pptx, doc, docx, png, jpg,
jpeg)
ja_JP: 解析するファイル(pdf, ppt, pptx, doc, docx, png, jpg, jpegをサポート)
pt_BR: o arquivo a ser analisado (suporta pdf, ppt, pptx, doc, docx, png,
jpg, jpeg)
zh_Hans: 用于解析的文件(支持 pdf, ppt, pptx, doc, docx, png, jpg, jpeg)
label:
en_US: file
ja_JP: ファイル
pt_BR: arquivo
zh_Hans: file
llm_description: the file to be parsed (support pdf, ppt, pptx, doc, docx,
png, jpg, jpeg)
max: null
min: null
name: file
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: file
params:
file: ''
provider_id: langgenius/dify_extractor/dify_extractor
provider_name: langgenius/dify_extractor/dify_extractor
provider_type: builtin
selected: false
title: Dify Extractor
tool_configurations: {}
tool_description: Dify Extractor
tool_label: Dify Extractor
tool_name: dify_extractor
tool_parameters:
file:
type: variable
value:
- '1752479895761'
- file
type: tool
height: 52
id: '1752480460682'
position:
x: -108.28652292656551
y: 281.3910724383104
positionAbsolute:
x: -108.28652292656551
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_array_file: false
selected: false
title: Document Extractor
type: document-extractor
variable_selector:
- '1752479895761'
- file
height: 90
id: '1752481112180'
position:
x: -108.28652292656551
y: 390.6576481692478
positionAbsolute:
x: -108.28652292656551
y: 390.6576481692478
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
cases:
- case_id: 24e47cad-f1e2-4f74-9884-3f49d5bb37b7
conditions:
- comparison_operator: is
id: 9da88d93-3ff6-463f-abfd-6bcafbf2554d
value: .xlsx
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: d0e88f5e-dfe3-4bae-af0c-dbec267500de
value: .xls
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: a957e91e-1ed7-4c6b-9c80-2f0948858f1d
value: .md
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 870c3c39-8d3f-474a-ab8b-9c0ccf53db73
value: .markdown
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: f9541513-1e71-4dc1-9db5-35dc84a39e3c
value: .mdx
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 4c7f455b-ac20-40ca-9495-6cc44ffcb35d
value: .html
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 2e12d9c7-8057-4a09-8851-f9fd1d0718d1
value: .htm
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 73a995a9-d8b9-4aef-89f7-306e2ddcbce2
value: .docx
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 8a2e8772-0426-458b-a1f9-9eaaec0f27c8
value: .csv
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: aa2cb6b6-a2fc-462a-a9f5-c9c3f33a1602
value: .txt
varType: file
variable_selector:
- '1752479895761'
- file
- extension
id: 24e47cad-f1e2-4f74-9884-3f49d5bb37b7
logical_operator: or
selected: false
title: IF/ELSE
type: if-else
height: 358
id: '1752481129417'
position:
x: -489.57009543377865
y: 251.3910724383104
positionAbsolute:
x: -489.57009543377865
y: 251.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
advanced_settings:
group_enabled: false
groups:
- groupId: f4cf07b4-914d-4544-8ef8-0c5d9e4f21a7
group_name: Group1
output_type: string
variables:
- - '1752481112180'
- text
- - '1752480460682'
- text
output_type: string
selected: false
title: Variable Aggregator
type: variable-aggregator
variables:
- - '1752481112180'
- text
- - '1752480460682'
- text
height: 129
id: '1752482022496'
position:
x: 319.441649575055
y: 281.3910724383104
positionAbsolute:
x: 319.441649575055
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
output_schema:
properties:
result:
description: The result of the general chunk tool.
properties:
general_chunks:
items:
description: The chunk of the text.
type: string
type: array
type: object
type: object
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: The text you want to chunk.
ja_JP: チャンク化したいテキスト。
pt_BR: O texto que você deseja dividir.
zh_Hans: 你想要分块的文本。
label:
en_US: Input Variable
ja_JP: 入力変数
pt_BR: Variável de entrada
zh_Hans: 输入变量
llm_description: The text you want to chunk.
max: null
min: null
name: input_variable
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: null
form: llm
human_description:
en_US: The delimiter of the chunks.
ja_JP: チャンクの区切り記号。
pt_BR: O delimitador dos blocos.
zh_Hans: 块的分隔符。
label:
en_US: Delimiter
ja_JP: 区切り記号
pt_BR: Delimitador
zh_Hans: 分隔符
llm_description: The delimiter of the chunks, the format of the delimiter
must be a string.
max: null
min: null
name: delimiter
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: null
form: llm
human_description:
en_US: The maximum chunk length.
ja_JP: 最大長のチャンク。
pt_BR: O comprimento máximo do bloco
zh_Hans: 最大块的长度。
label:
en_US: Maximum Chunk Length
ja_JP: チャンク最大長
pt_BR: O comprimento máximo do bloco
zh_Hans: 最大块的长度
llm_description: The maximum chunk length, the format of the chunk size
must be an integer.
max: null
min: null
name: max_chunk_length
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: number
- auto_generate: null
default: null
form: llm
human_description:
en_US: The chunk overlap length.
ja_JP: チャンクの重複長
pt_BR: O comprimento de sobreposição dos fragmentos
zh_Hans: 块的重叠长度。
label:
en_US: Chunk Overlap Length
ja_JP: チャンク重複長
pt_BR: Comprimento de sobreposição do bloco
zh_Hans: 块的重叠长度
llm_description: The chunk overlap length, the format of the chunk overlap
length must be an integer.
max: null
min: null
name: chunk_overlap_length
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: number
- auto_generate: null
default: null
form: llm
human_description:
en_US: Replace consecutive spaces, newlines and tabs
ja_JP: 連続のスペース、改行、またはタブを置換する
pt_BR: Substituir espaços consecutivos, novas linhas e tabulações
zh_Hans: 替换连续的空格、换行符和制表符
label:
en_US: Replace Consecutive Spaces, Newlines and Tabs
ja_JP: 連続のスペース、改行、またはタブを置換する
pt_BR: Substituir espaços consecutivos, novas linhas e tabulações
zh_Hans: 替换连续的空格、换行符和制表符
llm_description: Replace consecutive spaces, newlines and tabs, the format
of the replace must be a boolean.
max: null
min: null
name: replace_consecutive_spaces_newlines_tabs
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
- auto_generate: null
default: null
form: llm
human_description:
en_US: Delete all URLs and email addresses
ja_JP: すべてのURLとメールアドレスを削除する
pt_BR: Excluir todos os URLs e endereços de e-mail
zh_Hans: 删除所有URL和电子邮件地址
label:
en_US: Delete All URLs and Email Addresses
ja_JP: すべてのURLとメールアドレスを削除する
pt_BR: Excluir todos os URLs e endereços de e-mail
zh_Hans: 删除所有URL和电子邮件地址
llm_description: Delete all URLs and email addresses, the format of the
delete must be a boolean.
max: null
min: null
name: delete_all_urls_and_email_addresses
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
params:
chunk_overlap_length: ''
delete_all_urls_and_email_addresses: ''
delimiter: ''
input_variable: ''
max_chunk_length: ''
replace_consecutive_spaces_newlines_tabs: ''
provider_id: langgenius/general_chunker/general_chunker
provider_name: langgenius/general_chunker/general_chunker
provider_type: builtin
selected: false
title: General Chunker
tool_configurations: {}
tool_description: A tool for general text chunking mode, the chunks retrieved and recalled are the same.
tool_label: General Chunker
tool_name: general_chunker
tool_parameters:
chunk_overlap_length:
type: variable
value:
- rag
- shared
- chunk_overlap
delete_all_urls_and_email_addresses:
type: mixed
value: '{{#rag.shared.delete_urls_email#}}'
delimiter:
type: mixed
value: '{{#rag.shared.delimiter#}}'
input_variable:
type: mixed
value: '{{#1752482022496.output#}}'
max_chunk_length:
type: variable
value:
- rag
- shared
- max_chunk_length
replace_consecutive_spaces_newlines_tabs:
type: mixed
value: '{{#rag.shared.replace_consecutive_spaces#}}'
type: tool
height: 52
id: '1752482151668'
position:
x: 693.5300771507484
y: 281.3910724383104
positionAbsolute:
x: 693.5300771507484
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: 701.4999626224237
y: 128.33739021504016
zoom: 0.48941689643726966
rag_pipeline_variables:
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: \n\n
label: Delimiter
max_length: 100
options: []
placeholder: null
required: true
tooltips: A delimiter is the character used to separate text. \n\n is recommended
for splitting the original document into large parent chunks. You can also use
special delimiters defined by yourself.
type: text-input
unit: null
variable: delimiter
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Maximum chunk length
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: characters
variable: max_chunk_length
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Chunk overlap
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: number
unit: characters
variable: chunk_overlap
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Replace consecutive spaces, newlines and tabs
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: replace_consecutive_spaces
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Delete all URLs and email addresses
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: delete_urls_email

View File

@ -0,0 +1,709 @@
dependencies:
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/general_chunker:0.0.1@e3da408b7277866404c3f884d599261f9d0b9003ea4ef7eb3b64489bdf39d18b
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/dify_extractor:0.0.1@50103421d4e002f059b662d21ad2d7a1cf34869abdbe320299d7e382516ebb1c
kind: rag_pipeline
rag_pipeline:
description: ''
icon: 📙
icon_background: '#FFF4ED'
icon_type: emoji
name: file-general-high-quality
version: 0.1.0
workflow:
conversation_variables: []
environment_variables: []
features: {}
graph:
edges:
- data:
isInIteration: false
isInLoop: false
sourceType: datasource
targetType: if-else
id: 1752479895761-source-1752481129417-target
source: '1752479895761'
sourceHandle: source
target: '1752481129417'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: if-else
targetType: tool
id: 1752481129417-24e47cad-f1e2-4f74-9884-3f49d5bb37b7-1752480460682-target
source: '1752481129417'
sourceHandle: 24e47cad-f1e2-4f74-9884-3f49d5bb37b7
target: '1752480460682'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: if-else
targetType: document-extractor
id: 1752481129417-false-1752481112180-target
source: '1752481129417'
sourceHandle: 'false'
target: '1752481112180'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: tool
targetType: variable-aggregator
id: 1752480460682-source-1752482022496-target
source: '1752480460682'
sourceHandle: source
target: '1752482022496'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: document-extractor
targetType: variable-aggregator
id: 1752481112180-source-1752482022496-target
source: '1752481112180'
sourceHandle: source
target: '1752482022496'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: variable-aggregator
targetType: tool
id: 1752482022496-source-1752482151668-target
source: '1752482022496'
sourceHandle: source
target: '1752482151668'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: tool
targetType: knowledge-index
id: 1752482151668-source-1752477924228-target
source: '1752482151668'
sourceHandle: source
target: '1752477924228'
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
chunk_structure: text_model
embedding_model: text-embedding-ada-002
embedding_model_provider: langgenius/openai/openai
index_chunk_variable_selector:
- '1752482151668'
- result
indexing_technique: high_quality
keyword_number: 10
retrieval_model:
score_threshold: 0.5
score_threshold_enabled: false
search_method: semantic_search
top_k: 3
vector_setting:
embedding_model_name: text-embedding-ada-002
embedding_provider_name: langgenius/openai/openai
selected: false
title: Knowledge Base
type: knowledge-index
height: 114
id: '1752477924228'
position:
x: 1076.4656678451215
y: 281.3910724383104
positionAbsolute:
x: 1076.4656678451215
y: 281.3910724383104
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
datasource_configurations: {}
datasource_label: File
datasource_name: upload-file
datasource_parameters: {}
fileExtensions:
- txt
- markdown
- mdx
- pdf
- html
- xlsx
- xls
- vtt
- properties
- doc
- docx
- csv
- eml
- msg
- pptx
- xml
- epub
- ppt
- md
plugin_id: langgenius/file
provider_name: file
provider_type: local_file
selected: false
title: File
type: datasource
height: 52
id: '1752479895761'
position:
x: -839.8603427660498
y: 251.3910724383104
positionAbsolute:
x: -839.8603427660498
y: 251.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
output_schema:
properties:
documents:
description: the documents extracted from the file
items:
type: object
type: array
images:
description: The images extracted from the file
items:
type: object
type: array
type: object
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
            en_US: the file to be parsed (supports pdf, ppt, pptx, doc, docx, png, jpg,
jpeg)
ja_JP: 解析するファイル(pdf, ppt, pptx, doc, docx, png, jpg, jpegをサポート)
pt_BR: o arquivo a ser analisado (suporta pdf, ppt, pptx, doc, docx, png,
jpg, jpeg)
zh_Hans: 用于解析的文件(支持 pdf, ppt, pptx, doc, docx, png, jpg, jpeg)
label:
en_US: file
ja_JP: ファイル
pt_BR: arquivo
zh_Hans: file
          llm_description: the file to be parsed (supports pdf, ppt, pptx, doc, docx,
png, jpg, jpeg)
max: null
min: null
name: file
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: file
params:
file: ''
provider_id: langgenius/dify_extractor/dify_extractor
provider_name: langgenius/dify_extractor/dify_extractor
provider_type: builtin
selected: false
title: Dify Extractor
tool_configurations: {}
tool_description: Dify Extractor
tool_label: Dify Extractor
tool_name: dify_extractor
tool_parameters:
file:
type: variable
value:
- '1752479895761'
- file
type: tool
height: 52
id: '1752480460682'
position:
x: -108.28652292656551
y: 281.3910724383104
positionAbsolute:
x: -108.28652292656551
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_array_file: false
selected: false
        title: Document Extractor
type: document-extractor
variable_selector:
- '1752479895761'
- file
height: 90
id: '1752481112180'
position:
x: -108.28652292656551
y: 390.6576481692478
positionAbsolute:
x: -108.28652292656551
y: 390.6576481692478
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
cases:
- case_id: 24e47cad-f1e2-4f74-9884-3f49d5bb37b7
conditions:
- comparison_operator: is
id: 9da88d93-3ff6-463f-abfd-6bcafbf2554d
value: .xlsx
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: d0e88f5e-dfe3-4bae-af0c-dbec267500de
value: .xls
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: a957e91e-1ed7-4c6b-9c80-2f0948858f1d
value: .md
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 870c3c39-8d3f-474a-ab8b-9c0ccf53db73
value: .markdown
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: f9541513-1e71-4dc1-9db5-35dc84a39e3c
value: .mdx
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 4c7f455b-ac20-40ca-9495-6cc44ffcb35d
value: .html
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 2e12d9c7-8057-4a09-8851-f9fd1d0718d1
value: .htm
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 73a995a9-d8b9-4aef-89f7-306e2ddcbce2
value: .docx
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 8a2e8772-0426-458b-a1f9-9eaaec0f27c8
value: .csv
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: aa2cb6b6-a2fc-462a-a9f5-c9c3f33a1602
value: .txt
varType: file
variable_selector:
- '1752479895761'
- file
- extension
id: 24e47cad-f1e2-4f74-9884-3f49d5bb37b7
logical_operator: or
selected: false
title: IF/ELSE
type: if-else
height: 358
id: '1752481129417'
position:
x: -489.57009543377865
y: 251.3910724383104
positionAbsolute:
x: -489.57009543377865
y: 251.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
advanced_settings:
group_enabled: false
groups:
- groupId: f4cf07b4-914d-4544-8ef8-0c5d9e4f21a7
group_name: Group1
output_type: string
variables:
- - '1752481112180'
- text
- - '1752480460682'
- text
output_type: string
selected: false
title: Variable Aggregator
type: variable-aggregator
variables:
- - '1752481112180'
- text
- - '1752480460682'
- text
height: 129
id: '1752482022496'
position:
x: 319.441649575055
y: 281.3910724383104
positionAbsolute:
x: 319.441649575055
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
output_schema:
properties:
result:
description: The result of the general chunk tool.
properties:
general_chunks:
items:
description: The chunk of the text.
type: string
type: array
type: object
type: object
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: The text you want to chunk.
ja_JP: チャンク化したいテキスト。
pt_BR: O texto que você deseja dividir.
zh_Hans: 你想要分块的文本。
label:
en_US: Input Variable
ja_JP: 入力変数
pt_BR: Variável de entrada
zh_Hans: 输入变量
llm_description: The text you want to chunk.
max: null
min: null
name: input_variable
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: null
form: llm
human_description:
en_US: The delimiter of the chunks.
ja_JP: チャンクの区切り記号。
pt_BR: O delimitador dos pedaços.
zh_Hans: 块的分隔符。
label:
en_US: Delimiter
ja_JP: 区切り記号
pt_BR: Delimitador
zh_Hans: 分隔符
llm_description: The delimiter of the chunks, the format of the delimiter
must be a string.
max: null
min: null
name: delimiter
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: null
form: llm
human_description:
en_US: The maximum chunk length.
            ja_JP: チャンクの最大長。
pt_BR: O comprimento máximo do bloco
zh_Hans: 最大块的长度。
label:
en_US: Maximum Chunk Length
ja_JP: チャンク最大長
pt_BR: O comprimento máximo do bloco
zh_Hans: 最大块的长度
llm_description: The maximum chunk length, the format of the chunk size
must be an integer.
max: null
min: null
name: max_chunk_length
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: number
- auto_generate: null
default: null
form: llm
human_description:
en_US: The chunk overlap length.
ja_JP: チャンクの重複長
            pt_BR: O comprimento de sobreposição dos fragmentos
zh_Hans: 块的重叠长度。
label:
en_US: Chunk Overlap Length
ja_JP: チャンク重複長
            pt_BR: Comprimento de sobreposição do bloco
zh_Hans: 块的重叠长度
llm_description: The chunk overlap length, the format of the chunk overlap
length must be an integer.
max: null
min: null
name: chunk_overlap_length
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: number
- auto_generate: null
default: null
form: llm
human_description:
en_US: Replace consecutive spaces, newlines and tabs
            ja_JP: 連続のスペース、改行、またはタブを置換する
            pt_BR: Substituir espaços consecutivos, novas linhas e tabulações
zh_Hans: 替换连续的空格、换行符和制表符
label:
en_US: Replace Consecutive Spaces, Newlines and Tabs
            ja_JP: 連続のスペース、改行、またはタブを置換する
            pt_BR: Substituir espaços consecutivos, novas linhas e tabulações
zh_Hans: 替换连续的空格、换行符和制表符
llm_description: Replace consecutive spaces, newlines and tabs, the format
of the replace must be a boolean.
max: null
min: null
name: replace_consecutive_spaces_newlines_tabs
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
- auto_generate: null
default: null
form: llm
human_description:
en_US: Delete all URLs and email addresses
ja_JP: すべてのURLとメールアドレスを削除する
            pt_BR: Excluir todos os URLs e endereços de e-mail
zh_Hans: 删除所有URL和电子邮件地址
label:
en_US: Delete All URLs and Email Addresses
ja_JP: すべてのURLとメールアドレスを削除する
            pt_BR: Excluir todos os URLs e endereços de e-mail
zh_Hans: 删除所有URL和电子邮件地址
llm_description: Delete all URLs and email addresses, the format of the
delete must be a boolean.
max: null
min: null
name: delete_all_urls_and_email_addresses
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
params:
chunk_overlap_length: ''
delete_all_urls_and_email_addresses: ''
delimiter: ''
input_variable: ''
max_chunk_length: ''
replace_consecutive_spaces_newlines_tabs: ''
provider_id: langgenius/general_chunker/general_chunker
provider_name: langgenius/general_chunker/general_chunker
provider_type: builtin
selected: false
title: General Chunker
tool_configurations: {}
        tool_description: A tool for general text chunking mode; the chunks retrieved and recalled are the same.
tool_label: General Chunker
tool_name: general_chunker
tool_parameters:
chunk_overlap_length:
type: variable
value:
- rag
- shared
- chunk_overlap
delete_all_urls_and_email_addresses:
type: mixed
value: '{{#rag.shared.delete_urls_email#}}'
delimiter:
type: mixed
value: '{{#rag.shared.delimiter#}}'
input_variable:
type: mixed
value: '{{#1752482022496.output#}}'
max_chunk_length:
type: variable
value:
- rag
- shared
- max_chunk_length
replace_consecutive_spaces_newlines_tabs:
type: mixed
value: '{{#rag.shared.replace_consecutive_spaces#}}'
type: tool
height: 52
id: '1752482151668'
position:
x: 693.5300771507484
y: 281.3910724383104
positionAbsolute:
x: 693.5300771507484
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: 701.4999626224237
y: 128.33739021504016
zoom: 0.48941689643726966
rag_pipeline_variables:
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: \n\n
label: Delimiter
max_length: 100
options: []
placeholder: null
required: true
tooltips: A delimiter is the character used to separate text. \n\n is recommended
for splitting the original document into large parent chunks. You can also use
special delimiters defined by yourself.
type: text-input
unit: null
variable: delimiter
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Maximum chunk length
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: characters
variable: max_chunk_length
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Chunk overlap
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: number
unit: characters
variable: chunk_overlap
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Replace consecutive spaces, newlines and tabs
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: replace_consecutive_spaces
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Delete all URLs and email addresses
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: delete_urls_email

View File

@ -0,0 +1,814 @@
dependencies:
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/parentchild_chunker:0.0.1@b1a28a27e33fec442ce494da2a7814edd7eb9d646c81f38bccfcf1133d486e40
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/dify_extractor:0.0.1@50103421d4e002f059b662d21ad2d7a1cf34869abdbe320299d7e382516ebb1c
kind: rag_pipeline
rag_pipeline:
description: ''
icon: 📙
icon_background: '#FFF4ED'
icon_type: emoji
name: file-parentchild
version: 0.1.0
workflow:
conversation_variables: []
environment_variables: []
features: {}
graph:
edges:
- data:
isInIteration: false
isInLoop: false
sourceType: datasource
targetType: if-else
id: 1752479895761-source-1752481129417-target
source: '1752479895761'
sourceHandle: source
target: '1752481129417'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: if-else
targetType: tool
id: 1752481129417-24e47cad-f1e2-4f74-9884-3f49d5bb37b7-1752480460682-target
source: '1752481129417'
sourceHandle: 24e47cad-f1e2-4f74-9884-3f49d5bb37b7
target: '1752480460682'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: if-else
targetType: document-extractor
id: 1752481129417-false-1752481112180-target
source: '1752481129417'
sourceHandle: 'false'
target: '1752481112180'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: tool
targetType: variable-aggregator
id: 1752480460682-source-1752482022496-target
source: '1752480460682'
sourceHandle: source
target: '1752482022496'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: document-extractor
targetType: variable-aggregator
id: 1752481112180-source-1752482022496-target
source: '1752481112180'
sourceHandle: source
target: '1752482022496'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: variable-aggregator
targetType: tool
id: 1752482022496-source-1752575473519-target
source: '1752482022496'
sourceHandle: source
target: '1752575473519'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: tool
targetType: knowledge-index
id: 1752575473519-source-1752477924228-target
source: '1752575473519'
sourceHandle: source
target: '1752477924228'
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
chunk_structure: hierarchical_model
embedding_model: text-embedding-ada-002
embedding_model_provider: langgenius/openai/openai
index_chunk_variable_selector:
- '1752575473519'
- result
indexing_technique: high_quality
keyword_number: 10
retrieval_model:
score_threshold: 0.5
score_threshold_enabled: false
search_method: semantic_search
top_k: 3
vector_setting:
embedding_model_name: text-embedding-ada-002
embedding_provider_name: langgenius/openai/openai
selected: false
title: Knowledge Base
type: knowledge-index
height: 114
id: '1752477924228'
position:
x: 994.3774545394483
y: 281.3910724383104
positionAbsolute:
x: 994.3774545394483
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
datasource_configurations: {}
datasource_label: File
datasource_name: upload-file
datasource_parameters: {}
fileExtensions:
- txt
- markdown
- mdx
- pdf
- html
- xlsx
- xls
- vtt
- properties
- doc
- docx
- csv
- eml
- msg
- pptx
- xml
- epub
- ppt
- md
plugin_id: langgenius/file
provider_name: file
provider_type: local_file
selected: false
title: File
type: datasource
height: 52
id: '1752479895761'
position:
x: -839.8603427660498
y: 251.3910724383104
positionAbsolute:
x: -839.8603427660498
y: 251.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
output_schema:
properties:
documents:
description: the documents extracted from the file
items:
type: object
type: array
images:
description: The images extracted from the file
items:
type: object
type: array
type: object
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
            en_US: the file to be parsed (supports pdf, ppt, pptx, doc, docx, png, jpg,
jpeg)
ja_JP: 解析するファイル(pdf, ppt, pptx, doc, docx, png, jpg, jpegをサポート)
pt_BR: o arquivo a ser analisado (suporta pdf, ppt, pptx, doc, docx, png,
jpg, jpeg)
zh_Hans: 用于解析的文件(支持 pdf, ppt, pptx, doc, docx, png, jpg, jpeg)
label:
en_US: file
ja_JP: ファイル
pt_BR: arquivo
zh_Hans: file
          llm_description: the file to be parsed (supports pdf, ppt, pptx, doc, docx,
png, jpg, jpeg)
max: null
min: null
name: file
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: file
params:
file: ''
provider_id: langgenius/dify_extractor/dify_extractor
provider_name: langgenius/dify_extractor/dify_extractor
provider_type: builtin
selected: false
title: Dify Extractor
tool_configurations: {}
tool_description: Dify Extractor
tool_label: Dify Extractor
tool_name: dify_extractor
tool_parameters:
file:
type: variable
value:
- '1752479895761'
- file
type: tool
height: 52
id: '1752480460682'
position:
x: -108.28652292656551
y: 281.3910724383104
positionAbsolute:
x: -108.28652292656551
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_array_file: false
selected: false
        title: Document Extractor
type: document-extractor
variable_selector:
- '1752479895761'
- file
height: 90
id: '1752481112180'
position:
x: -108.28652292656551
y: 390.6576481692478
positionAbsolute:
x: -108.28652292656551
y: 390.6576481692478
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
cases:
- case_id: 24e47cad-f1e2-4f74-9884-3f49d5bb37b7
conditions:
- comparison_operator: is
id: 9da88d93-3ff6-463f-abfd-6bcafbf2554d
value: .xlsx
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: d0e88f5e-dfe3-4bae-af0c-dbec267500de
value: .xls
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: a957e91e-1ed7-4c6b-9c80-2f0948858f1d
value: .md
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 870c3c39-8d3f-474a-ab8b-9c0ccf53db73
value: .markdown
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: f9541513-1e71-4dc1-9db5-35dc84a39e3c
value: .mdx
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 4c7f455b-ac20-40ca-9495-6cc44ffcb35d
value: .html
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 2e12d9c7-8057-4a09-8851-f9fd1d0718d1
value: .htm
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 73a995a9-d8b9-4aef-89f7-306e2ddcbce2
value: .docx
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: 8a2e8772-0426-458b-a1f9-9eaaec0f27c8
value: .csv
varType: file
variable_selector:
- '1752479895761'
- file
- extension
- comparison_operator: is
id: aa2cb6b6-a2fc-462a-a9f5-c9c3f33a1602
value: .txt
varType: file
variable_selector:
- '1752479895761'
- file
- extension
id: 24e47cad-f1e2-4f74-9884-3f49d5bb37b7
logical_operator: or
selected: false
title: IF/ELSE
type: if-else
height: 358
id: '1752481129417'
position:
x: -512.2335487893622
y: 251.3910724383104
positionAbsolute:
x: -512.2335487893622
y: 251.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
advanced_settings:
group_enabled: false
groups:
- groupId: f4cf07b4-914d-4544-8ef8-0c5d9e4f21a7
group_name: Group1
output_type: string
variables:
- - '1752481112180'
- text
- - '1752480460682'
- text
output_type: string
selected: false
title: Variable Aggregator
type: variable-aggregator
variables:
- - '1752481112180'
- text
- - '1752480460682'
- text
height: 129
id: '1752482022496'
position:
x: 319.441649575055
y: 281.3910724383104
positionAbsolute:
x: 319.441649575055
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
output_schema:
properties:
result:
description: Parent child chunks result
items:
type: object
type: array
type: object
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: The text you want to chunk.
ja_JP: チャンク化したいテキスト。
pt_BR: O texto que você deseja dividir.
zh_Hans: 你想要分块的文本。
label:
en_US: Input text
ja_JP: 入力テキスト
pt_BR: Texto de entrada
zh_Hans: 输入文本
llm_description: The text you want to chunk.
max: null
min: null
name: input_text
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: 1024
form: llm
human_description:
en_US: Maximum length for chunking
ja_JP: チャンク分割の最大長
pt_BR: Comprimento máximo para divisão
zh_Hans: 用于分块的最大长度
label:
en_US: Maximum Length
ja_JP: 最大長
pt_BR: Comprimento Máximo
zh_Hans: 最大长度
llm_description: Maximum length allowed per chunk
max: null
min: null
name: max_length
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: number
- auto_generate: null
default: '
'
form: llm
human_description:
en_US: Separator used for chunking
ja_JP: チャンク分割に使用する区切り文字
pt_BR: Separador usado para divisão
zh_Hans: 用于分块的分隔符
label:
en_US: Chunk Separator
ja_JP: チャンク区切り文字
pt_BR: Separador de Divisão
zh_Hans: 分块分隔符
llm_description: The separator used to split chunks
max: null
min: null
name: separator
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: string
- auto_generate: null
default: 512
form: llm
human_description:
en_US: Maximum length for subchunking
ja_JP: サブチャンク分割の最大長
pt_BR: Comprimento máximo para subdivisão
zh_Hans: 用于子分块的最大长度
label:
en_US: Subchunk Maximum Length
ja_JP: サブチャンク最大長
pt_BR: Comprimento Máximo de Subdivisão
zh_Hans: 子分块最大长度
llm_description: Maximum length allowed per subchunk
max: null
min: null
name: subchunk_max_length
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: number
- auto_generate: null
default: '. '
form: llm
human_description:
en_US: Separator used for subchunking
ja_JP: サブチャンク分割に使用する区切り文字
pt_BR: Separador usado para subdivisão
zh_Hans: 用于子分块的分隔符
label:
en_US: Subchunk Separator
ja_JP: サブチャンキング用セパレーター
pt_BR: Separador de Subdivisão
zh_Hans: 子分块分隔符
llm_description: The separator used to split subchunks
max: null
min: null
name: subchunk_separator
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: string
- auto_generate: null
default: paragraph
form: llm
human_description:
en_US: Split text into paragraphs based on separator and maximum chunk
length, using split text as parent block or entire document as parent
block and directly retrieve.
ja_JP: セパレーターと最大チャンク長に基づいてテキストを段落に分割し、分割されたテキスト
を親ブロックとして使用するか、文書全体を親ブロックとして使用して直接取得します。
pt_BR: Dividir texto em parágrafos com base no separador e no comprimento
máximo do bloco, usando o texto dividido como bloco pai ou documento
completo como bloco pai e diretamente recuperá-lo.
zh_Hans: 根据分隔符和最大块长度将文本拆分为段落,使用拆分文本作为检索的父块或整个文档用作父块并直接检索。
label:
en_US: Parent Mode
ja_JP: 親子モード
pt_BR: Modo Pai
zh_Hans: 父块模式
llm_description: Split text into paragraphs based on separator and maximum
chunk length, using split text as parent block or entire document as parent
block and directly retrieve.
max: null
min: null
name: parent_mode
options:
- icon: ''
label:
en_US: Paragraph
ja_JP: 段落
pt_BR: Parágrafo
zh_Hans: 段落
value: paragraph
- icon: ''
label:
en_US: Full Document
ja_JP: 全文
pt_BR: Documento Completo
zh_Hans: 全文
value: full_doc
placeholder: null
precision: null
required: true
scope: null
template: null
type: select
- auto_generate: null
default: 0
form: llm
human_description:
en_US: Whether to remove extra spaces in the text
ja_JP: テキスト内の余分なスペースを削除するかどうか
pt_BR: Se deve remover espaços extras no texto
zh_Hans: 是否移除文本中的多余空格
label:
en_US: Remove Extra Spaces
ja_JP: 余分なスペースを削除
pt_BR: Remover Espaços Extras
zh_Hans: 移除多余空格
llm_description: Whether to remove extra spaces in the text
max: null
min: null
name: remove_extra_spaces
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
- auto_generate: null
default: 0
form: llm
human_description:
en_US: Whether to remove URLs and emails in the text
ja_JP: テキスト内のURLやメールアドレスを削除するかどうか
pt_BR: Se deve remover URLs e e-mails no texto
zh_Hans: 是否移除文本中的URL和电子邮件地址
label:
en_US: Remove URLs and Emails
ja_JP: URLとメールアドレスを削除
pt_BR: Remover URLs e E-mails
zh_Hans: 移除URL和电子邮件地址
llm_description: Whether to remove URLs and emails in the text
max: null
min: null
name: remove_urls_emails
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
params:
input_text: ''
max_length: ''
parent_mode: ''
remove_extra_spaces: ''
remove_urls_emails: ''
separator: ''
subchunk_max_length: ''
subchunk_separator: ''
provider_id: langgenius/parentchild_chunker/parentchild_chunker
provider_name: langgenius/parentchild_chunker/parentchild_chunker
provider_type: builtin
selected: false
title: Parent-child Chunker
tool_configurations: {}
tool_description: Parent-child Chunk Structure
tool_label: Parent-child Chunker
tool_name: parentchild_chunker
tool_parameters:
input_text:
type: mixed
value: '{{#1752482022496.output#}}'
max_length:
type: variable
value:
- rag
- shared
- max_chunk_length
parent_mode:
type: variable
value:
- rag
- shared
- parent_mode
remove_extra_spaces:
type: mixed
value: '{{#rag.shared.replace_consecutive_spaces#}}'
remove_urls_emails:
type: mixed
value: '{{#rag.shared.delete_urls_email#}}'
separator:
type: mixed
value: '{{#rag.shared.delimiter#}}'
subchunk_max_length:
type: variable
value:
- rag
- shared
- child_max_chunk_length
subchunk_separator:
type: mixed
value: '{{#rag.shared.child_delimiter#}}'
type: tool
height: 52
id: '1752575473519'
position:
x: 637.9241611063885
y: 281.3910724383104
positionAbsolute:
x: 637.9241611063885
y: 281.3910724383104
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: 948.6766333808323
y: -102.06757184183238
zoom: 0.8375774577380971
rag_pipeline_variables:
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: \n\n
label: Delimiter
max_length: 256
options: []
placeholder: null
required: true
tooltips: A delimiter is the character used to separate text. \n\n is recommended
for splitting the original document into large parent chunks. You can also use
special delimiters defined by yourself.
type: text-input
unit: null
variable: delimiter
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: 1024
label: Maximum chunk length
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: characters
variable: max_chunk_length
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: \n
label: Child delimiter
max_length: 256
options: []
placeholder: null
required: true
tooltips: A delimiter is the character used to separate text. \n\n is recommended
for splitting the original document into large parent chunks. You can also use
special delimiters defined by yourself.
type: text-input
unit: null
variable: child_delimiter
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: 512
label: Child max chunk length
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: characters
variable: child_max_chunk_length
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: paragraph
label: Parent mode
max_length: 48
options:
- full_doc
- paragraph
placeholder: null
required: true
tooltips: null
type: select
unit: null
variable: parent_mode
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Replace consecutive spaces, newlines and tabs
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: replace_consecutive_spaces
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Delete all URLs and email addresses
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: delete_urls_email

View File

@ -0,0 +1,400 @@
dependencies:
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/general_chunker:0.0.1@e3da408b7277866404c3f884d599261f9d0b9003ea4ef7eb3b64489bdf39d18b
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/notion_datasource:0.0.1@2dd49c2c3ffff976be8d22efb1ac0f63522a8d0f24ef8c44729d0a50a94ec039
kind: rag_pipeline
rag_pipeline:
description: ''
icon: 📙
icon_background: ''
icon_type: emoji
name: notion-general-economy
version: 0.1.0
workflow:
conversation_variables: []
environment_variables: []
features: {}
graph:
edges:
- data:
isInIteration: false
isInLoop: false
sourceType: tool
targetType: knowledge-index
id: 1752482151668-source-1752477924228-target
source: '1752482151668'
sourceHandle: source
target: '1752477924228'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: datasource
targetType: tool
id: 1752489759475-source-1752482151668-target
source: '1752489759475'
sourceHandle: source
target: '1752482151668'
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
chunk_structure: text_model
embedding_model: text-embedding-ada-002
embedding_model_provider: langgenius/openai/openai
index_chunk_variable_selector:
- '1752482151668'
- result
indexing_technique: economy
keyword_number: 10
retrieval_model:
score_threshold: 0.5
score_threshold_enabled: false
search_method: keyword_search
top_k: 3
vector_setting:
embedding_model_name: text-embedding-ada-002
embedding_provider_name: langgenius/openai/openai
selected: true
title: Knowledge Base
type: knowledge-index
height: 114
id: '1752477924228'
position:
x: 1444.5503479271906
y: 281.3910724383104
positionAbsolute:
x: 1444.5503479271906
y: 281.3910724383104
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
output_schema:
properties:
result:
description: The result of the general chunk tool.
properties:
general_chunks:
items:
description: The chunk of the text.
type: string
type: array
type: object
type: object
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: The text you want to chunk.
ja_JP: チャンク化したいテキスト。
pt_BR: O texto que você deseja dividir.
zh_Hans: 你想要分块的文本。
label:
en_US: Input Variable
ja_JP: 入力変数
pt_BR: Variável de entrada
zh_Hans: 输入变量
llm_description: The text you want to chunk.
max: null
min: null
name: input_variable
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: null
form: llm
human_description:
en_US: The delimiter of the chunks.
ja_JP: チャンクの区切り記号。
pt_BR: O delimitador dos pedaços.
zh_Hans: 块的分隔符。
label:
en_US: Delimiter
ja_JP: 区切り記号
pt_BR: Delimitador
zh_Hans: 分隔符
llm_description: The delimiter of the chunks, the format of the delimiter
must be a string.
max: null
min: null
name: delimiter
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: null
form: llm
human_description:
en_US: The maximum chunk length.
            ja_JP: チャンクの最大長。
pt_BR: O comprimento máximo do bloco
zh_Hans: 最大块的长度。
label:
en_US: Maximum Chunk Length
ja_JP: チャンク最大長
pt_BR: O comprimento máximo do bloco
zh_Hans: 最大块的长度
llm_description: The maximum chunk length, the format of the chunk size
must be an integer.
max: null
min: null
name: max_chunk_length
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: number
- auto_generate: null
default: null
form: llm
human_description:
en_US: The chunk overlap length.
ja_JP: チャンクの重複長
            pt_BR: O comprimento de sobreposição dos fragmentos
zh_Hans: 块的重叠长度。
label:
en_US: Chunk Overlap Length
ja_JP: チャンク重複長
            pt_BR: Comprimento de sobreposição do bloco
zh_Hans: 块的重叠长度
llm_description: The chunk overlap length, the format of the chunk overlap
length must be an integer.
max: null
min: null
name: chunk_overlap_length
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: number
- auto_generate: null
default: null
form: llm
human_description:
en_US: Replace consecutive spaces, newlines and tabs
            ja_JP: 連続のスペース、改行、またはタブを置換する
            pt_BR: Substituir espaços consecutivos, novas linhas e tabulações
zh_Hans: 替换连续的空格、换行符和制表符
label:
en_US: Replace Consecutive Spaces, Newlines and Tabs
            ja_JP: 連続のスペース、改行、またはタブを置換する
            pt_BR: Substituir espaços consecutivos, novas linhas e tabulações
zh_Hans: 替换连续的空格、换行符和制表符
llm_description: Replace consecutive spaces, newlines and tabs, the format
of the replace must be a boolean.
max: null
min: null
name: replace_consecutive_spaces_newlines_tabs
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
- auto_generate: null
default: null
form: llm
human_description:
en_US: Delete all URLs and email addresses
ja_JP: すべてのURLとメールアドレスを削除する
            pt_BR: Excluir todos os URLs e endereços de e-mail
zh_Hans: 删除所有URL和电子邮件地址
label:
en_US: Delete All URLs and Email Addresses
ja_JP: すべてのURLとメールアドレスを削除する
            pt_BR: Excluir todos os URLs e endereços de e-mail
zh_Hans: 删除所有URL和电子邮件地址
llm_description: Delete all URLs and email addresses, the format of the
delete must be a boolean.
max: null
min: null
name: delete_all_urls_and_email_addresses
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
params:
chunk_overlap_length: ''
delete_all_urls_and_email_addresses: ''
delimiter: ''
input_variable: ''
max_chunk_length: ''
replace_consecutive_spaces_newlines_tabs: ''
provider_id: langgenius/general_chunker/general_chunker
provider_name: langgenius/general_chunker/general_chunker
provider_type: builtin
selected: false
title: General Chunker
tool_configurations: {}
        tool_description: A tool for general text chunking mode; the chunks retrieved and recalled are the same.
tool_label: General Chunker
tool_name: general_chunker
tool_parameters:
chunk_overlap_length:
type: variable
value:
- rag
- shared
- chunk_overlap
delete_all_urls_and_email_addresses:
type: mixed
value: '{{#rag.shared.delete_urls_email#}}'
delimiter:
type: mixed
value: '{{#rag.shared.delimiter#}}'
input_variable:
type: mixed
value: '{{#1752489759475.content#}}'
max_chunk_length:
type: variable
value:
- rag
- shared
- max_chunk_length
replace_consecutive_spaces_newlines_tabs:
type: mixed
value: '{{#rag.shared.replace_consecutive_spaces#}}'
type: tool
height: 52
id: '1752482151668'
position:
x: 1063.6922916384628
y: 281.3910724383104
positionAbsolute:
x: 1063.6922916384628
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
datasource_configurations: {}
        datasource_label: Notion Data Source
datasource_name: notion_datasource
datasource_parameters: {}
plugin_id: langgenius/notion_datasource
provider_name: notion_datasource
provider_type: online_document
selected: false
        title: Notion Data Source
type: datasource
height: 52
id: '1752489759475'
position:
x: 736.9082104000458
y: 281.3910724383104
positionAbsolute:
x: 736.9082104000458
y: 281.3910724383104
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: -838.569649323166
y: -168.94656489167426
zoom: 1.286925643857699
rag_pipeline_variables:
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: \n\n
label: Delimiter
max_length: 100
options: []
placeholder: null
required: true
tooltips: A delimiter is the character used to separate text. \n\n is recommended
for splitting the original document into large parent chunks. You can also use
special delimiters defined by yourself.
type: text-input
unit: null
variable: delimiter
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Maximum chunk length
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: characters
variable: max_chunk_length
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Chunk overlap
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: number
unit: characters
variable: chunk_overlap
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Replace consecutive spaces, newlines and tabs
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: replace_consecutive_spaces
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Delete all URLs and email addresses
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: delete_urls_email

View File

@ -0,0 +1,400 @@
dependencies:
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/general_chunker:0.0.1@e3da408b7277866404c3f884d599261f9d0b9003ea4ef7eb3b64489bdf39d18b
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/notion_datasource:0.0.1@2dd49c2c3ffff976be8d22efb1ac0f63522a8d0f24ef8c44729d0a50a94ec039
kind: rag_pipeline
rag_pipeline:
description: ''
icon: 📙
icon_background: '#FFF4ED'
icon_type: emoji
name: notion-general-high-quality
version: 0.1.0
workflow:
conversation_variables: []
environment_variables: []
features: {}
graph:
edges:
- data:
isInIteration: false
isInLoop: false
sourceType: tool
targetType: knowledge-index
id: 1752482151668-source-1752477924228-target
source: '1752482151668'
sourceHandle: source
target: '1752477924228'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: datasource
targetType: tool
id: 1752489759475-source-1752482151668-target
source: '1752489759475'
sourceHandle: source
target: '1752482151668'
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
chunk_structure: text_model
embedding_model: text-embedding-ada-002
embedding_model_provider: langgenius/openai/openai
index_chunk_variable_selector:
- '1752482151668'
- result
indexing_technique: high_quality
keyword_number: 10
retrieval_model:
score_threshold: 0.5
score_threshold_enabled: false
search_method: semantic_search
top_k: 3
vector_setting:
embedding_model_name: text-embedding-ada-002
embedding_provider_name: langgenius/openai/openai
selected: true
title: Knowledge Base
type: knowledge-index
height: 114
id: '1752477924228'
position:
x: 1444.5503479271906
y: 281.3910724383104
positionAbsolute:
x: 1444.5503479271906
y: 281.3910724383104
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
output_schema:
properties:
result:
description: The result of the general chunk tool.
properties:
general_chunks:
items:
description: The chunk of the text.
type: string
type: array
type: object
type: object
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: The text you want to chunk.
ja_JP: チャンク化したいテキスト。
pt_BR: O texto que você deseja dividir.
zh_Hans: 你想要分块的文本。
label:
en_US: Input Variable
ja_JP: 入力変数
pt_BR: Variável de entrada
zh_Hans: 输入变量
llm_description: The text you want to chunk.
max: null
min: null
name: input_variable
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: null
form: llm
human_description:
en_US: The delimiter of the chunks.
ja_JP: チャンクの区切り記号。
pt_BR: O delimitador dos pedaços.
zh_Hans: 块的分隔符。
label:
en_US: Delimiter
ja_JP: 区切り記号
pt_BR: Delimitador
zh_Hans: 分隔符
llm_description: The delimiter of the chunks, the format of the delimiter
must be a string.
max: null
min: null
name: delimiter
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: null
form: llm
human_description:
en_US: The maximum chunk length.
            ja_JP: チャンクの最大長。
pt_BR: O comprimento máximo do bloco
zh_Hans: 最大块的长度。
label:
en_US: Maximum Chunk Length
ja_JP: チャンク最大長
pt_BR: O comprimento máximo do bloco
zh_Hans: 最大块的长度
llm_description: The maximum chunk length, the format of the chunk size
must be an integer.
max: null
min: null
name: max_chunk_length
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: number
- auto_generate: null
default: null
form: llm
human_description:
en_US: The chunk overlap length.
ja_JP: チャンクの重複長
            pt_BR: O comprimento de sobreposição dos fragmentos
zh_Hans: 块的重叠长度。
label:
en_US: Chunk Overlap Length
ja_JP: チャンク重複長
            pt_BR: Comprimento de sobreposição do bloco
zh_Hans: 块的重叠长度
llm_description: The chunk overlap length, the format of the chunk overlap
length must be an integer.
max: null
min: null
name: chunk_overlap_length
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: number
- auto_generate: null
default: null
form: llm
human_description:
en_US: Replace consecutive spaces, newlines and tabs
            ja_JP: 連続のスペース、改行、またはタブを置換する
            pt_BR: Substituir espaços consecutivos, novas linhas e tabulações
zh_Hans: 替换连续的空格、换行符和制表符
label:
en_US: Replace Consecutive Spaces, Newlines and Tabs
            ja_JP: 連続のスペース、改行、またはタブを置換する
            pt_BR: Substituir espaços consecutivos, novas linhas e tabulações
zh_Hans: 替换连续的空格、换行符和制表符
llm_description: Replace consecutive spaces, newlines and tabs, the format
of the replace must be a boolean.
max: null
min: null
name: replace_consecutive_spaces_newlines_tabs
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
- auto_generate: null
default: null
form: llm
human_description:
en_US: Delete all URLs and email addresses
ja_JP: すべてのURLとメールアドレスを削除する
            pt_BR: Excluir todos os URLs e endereços de e-mail
zh_Hans: 删除所有URL和电子邮件地址
label:
en_US: Delete All URLs and Email Addresses
ja_JP: すべてのURLとメールアドレスを削除する
            pt_BR: Excluir todos os URLs e endereços de e-mail
zh_Hans: 删除所有URL和电子邮件地址
llm_description: Delete all URLs and email addresses, the format of the
delete must be a boolean.
max: null
min: null
name: delete_all_urls_and_email_addresses
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
params:
chunk_overlap_length: ''
delete_all_urls_and_email_addresses: ''
delimiter: ''
input_variable: ''
max_chunk_length: ''
replace_consecutive_spaces_newlines_tabs: ''
provider_id: langgenius/general_chunker/general_chunker
provider_name: langgenius/general_chunker/general_chunker
provider_type: builtin
selected: false
title: General Chunker
tool_configurations: {}
        tool_description: A tool for general text chunking mode; the chunks retrieved and recalled are the same.
tool_label: General Chunker
tool_name: general_chunker
tool_parameters:
chunk_overlap_length:
type: variable
value:
- rag
- shared
- chunk_overlap
delete_all_urls_and_email_addresses:
type: mixed
value: '{{#rag.shared.delete_urls_email#}}'
delimiter:
type: mixed
value: '{{#rag.shared.delimiter#}}'
input_variable:
type: mixed
value: '{{#1752489759475.content#}}'
max_chunk_length:
type: variable
value:
- rag
- shared
- max_chunk_length
replace_consecutive_spaces_newlines_tabs:
type: mixed
value: '{{#rag.shared.replace_consecutive_spaces#}}'
type: tool
height: 52
id: '1752482151668'
position:
x: 1063.6922916384628
y: 281.3910724383104
positionAbsolute:
x: 1063.6922916384628
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
datasource_configurations: {}
        datasource_label: Notion Data Source
datasource_name: notion_datasource
datasource_parameters: {}
plugin_id: langgenius/notion_datasource
provider_name: notion_datasource
provider_type: online_document
selected: false
        title: Notion Data Source
type: datasource
height: 52
id: '1752489759475'
position:
x: 736.9082104000458
y: 281.3910724383104
positionAbsolute:
x: 736.9082104000458
y: 281.3910724383104
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: -838.569649323166
y: -168.94656489167426
zoom: 1.286925643857699
rag_pipeline_variables:
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: \n\n
label: Delimiter
max_length: 100
options: []
placeholder: null
required: true
tooltips: A delimiter is the character used to separate text. \n\n is recommended
for splitting the original document into large parent chunks. You can also use
special delimiters defined by yourself.
type: text-input
unit: null
variable: delimiter
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Maximum chunk length
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: characters
variable: max_chunk_length
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Chunk overlap
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: number
unit: characters
variable: chunk_overlap
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Replace consecutive spaces, newlines and tabs
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: replace_consecutive_spaces
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Delete all URLs and email addresses
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: delete_urls_email

View File

@ -0,0 +1,506 @@
dependencies:
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/parentchild_chunker:0.0.1@b1a28a27e33fec442ce494da2a7814edd7eb9d646c81f38bccfcf1133d486e40
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/notion_datasource:0.0.1@2dd49c2c3ffff976be8d22efb1ac0f63522a8d0f24ef8c44729d0a50a94ec039
kind: rag_pipeline
rag_pipeline:
description: ''
icon: 📙
icon_background: ''
icon_type: emoji
name: notion-parentchild
version: 0.1.0
workflow:
conversation_variables: []
environment_variables: []
features: {}
graph:
edges:
- data:
isInIteration: false
isInLoop: false
sourceType: datasource
targetType: tool
id: 1752489759475-source-1752490343805-target
source: '1752489759475'
sourceHandle: source
target: '1752490343805'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: tool
targetType: knowledge-index
id: 1752490343805-source-1752477924228-target
source: '1752490343805'
sourceHandle: source
target: '1752477924228'
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
chunk_structure: hierarchical_model
embedding_model: text-embedding-ada-002
embedding_model_provider: langgenius/openai/openai
index_chunk_variable_selector:
- '1752490343805'
- result
indexing_technique: high_quality
keyword_number: 10
retrieval_model:
score_threshold: 0.5
score_threshold_enabled: false
search_method: semantic_search
top_k: 3
vector_setting:
embedding_model_name: text-embedding-ada-002
embedding_provider_name: langgenius/openai/openai
selected: false
title: Knowledge Base
type: knowledge-index
height: 114
id: '1752477924228'
position:
x: 1486.2052698032674
y: 281.3910724383104
positionAbsolute:
x: 1486.2052698032674
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
datasource_configurations: {}
        datasource_label: Notion Data Source
datasource_name: notion_datasource
datasource_parameters: {}
plugin_id: langgenius/notion_datasource
provider_name: notion_datasource
provider_type: online_document
selected: false
        title: Notion Data Source
type: datasource
height: 52
id: '1752489759475'
position:
x: 736.9082104000458
y: 281.3910724383104
positionAbsolute:
x: 736.9082104000458
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
output_schema:
properties:
result:
description: Parent child chunks result
items:
type: object
type: array
type: object
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: The text you want to chunk.
ja_JP: チャンク化したいテキスト。
pt_BR: O texto que você deseja dividir.
zh_Hans: 你想要分块的文本。
label:
en_US: Input text
ja_JP: 入力テキスト
pt_BR: Texto de entrada
zh_Hans: 输入文本
llm_description: The text you want to chunk.
max: null
min: null
name: input_text
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: 1024
form: llm
human_description:
en_US: Maximum length for chunking
ja_JP: チャンク分割の最大長
pt_BR: Comprimento máximo para divisão
zh_Hans: 用于分块的最大长度
label:
en_US: Maximum Length
ja_JP: 最大長
pt_BR: Comprimento Máximo
zh_Hans: 最大长度
llm_description: Maximum length allowed per chunk
max: null
min: null
name: max_length
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: number
- auto_generate: null
default: '
'
form: llm
human_description:
en_US: Separator used for chunking
ja_JP: チャンク分割に使用する区切り文字
pt_BR: Separador usado para divisão
zh_Hans: 用于分块的分隔符
label:
en_US: Chunk Separator
ja_JP: チャンク区切り文字
pt_BR: Separador de Divisão
zh_Hans: 分块分隔符
llm_description: The separator used to split chunks
max: null
min: null
name: separator
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: string
- auto_generate: null
default: 512
form: llm
human_description:
en_US: Maximum length for subchunking
ja_JP: サブチャンク分割の最大長
pt_BR: Comprimento máximo para subdivisão
zh_Hans: 用于子分块的最大长度
label:
en_US: Subchunk Maximum Length
ja_JP: サブチャンク最大長
pt_BR: Comprimento Máximo de Subdivisão
zh_Hans: 子分块最大长度
llm_description: Maximum length allowed per subchunk
max: null
min: null
name: subchunk_max_length
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: number
- auto_generate: null
default: '. '
form: llm
human_description:
en_US: Separator used for subchunking
ja_JP: サブチャンク分割に使用する区切り文字
pt_BR: Separador usado para subdivisão
zh_Hans: 用于子分块的分隔符
label:
en_US: Subchunk Separator
ja_JP: サブチャンキング用セパレーター
pt_BR: Separador de Subdivisão
zh_Hans: 子分块分隔符
llm_description: The separator used to split subchunks
max: null
min: null
name: subchunk_separator
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: string
- auto_generate: null
default: paragraph
form: llm
human_description:
en_US: Split text into paragraphs based on separator and maximum chunk
length, using split text as parent block or entire document as parent
block and directly retrieve.
ja_JP: セパレーターと最大チャンク長に基づいてテキストを段落に分割し、分割されたテキスト
を親ブロックとして使用するか、文書全体を親ブロックとして使用して直接取得します。
pt_BR: Dividir texto em parágrafos com base no separador e no comprimento
máximo do bloco, usando o texto dividido como bloco pai ou documento
completo como bloco pai e diretamente recuperá-lo.
zh_Hans: 根据分隔符和最大块长度将文本拆分为段落,使用拆分文本作为检索的父块或整个文档用作父块并直接检索。
label:
en_US: Parent Mode
ja_JP: 親子モード
pt_BR: Modo Pai
zh_Hans: 父块模式
llm_description: Split text into paragraphs based on separator and maximum
chunk length, using split text as parent block or entire document as parent
block and directly retrieve.
max: null
min: null
name: parent_mode
options:
- icon: ''
label:
en_US: Paragraph
ja_JP: 段落
pt_BR: Parágrafo
zh_Hans: 段落
value: paragraph
- icon: ''
label:
en_US: Full Document
ja_JP: 全文
pt_BR: Documento Completo
zh_Hans: 全文
value: full_doc
placeholder: null
precision: null
required: true
scope: null
template: null
type: select
- auto_generate: null
default: 0
form: llm
human_description:
en_US: Whether to remove extra spaces in the text
ja_JP: テキスト内の余分なスペースを削除するかどうか
pt_BR: Se deve remover espaços extras no texto
zh_Hans: 是否移除文本中的多余空格
label:
en_US: Remove Extra Spaces
ja_JP: 余分なスペースを削除
pt_BR: Remover Espaços Extras
zh_Hans: 移除多余空格
llm_description: Whether to remove extra spaces in the text
max: null
min: null
name: remove_extra_spaces
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
- auto_generate: null
default: 0
form: llm
human_description:
en_US: Whether to remove URLs and emails in the text
ja_JP: テキスト内のURLやメールアドレスを削除するかどうか
pt_BR: Se deve remover URLs e e-mails no texto
zh_Hans: 是否移除文本中的URL和电子邮件地址
label:
en_US: Remove URLs and Emails
ja_JP: URLとメールアドレスを削除
pt_BR: Remover URLs e E-mails
zh_Hans: 移除URL和电子邮件地址
llm_description: Whether to remove URLs and emails in the text
max: null
min: null
name: remove_urls_emails
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
params:
input_text: ''
max_length: ''
parent_mode: ''
remove_extra_spaces: ''
remove_urls_emails: ''
separator: ''
subchunk_max_length: ''
subchunk_separator: ''
provider_id: langgenius/parentchild_chunker/parentchild_chunker
provider_name: langgenius/parentchild_chunker/parentchild_chunker
provider_type: builtin
selected: true
title: Parent-child Chunker
tool_configurations: {}
tool_description: Parent-child Chunk Structure
tool_label: Parent-child Chunker
tool_name: parentchild_chunker
tool_parameters:
input_text:
type: mixed
value: '{{#1752489759475.content#}}'
max_length:
type: variable
value:
- rag
- shared
- max_chunk_length
parent_mode:
type: variable
value:
- rag
- shared
- parent_mode
remove_extra_spaces:
type: mixed
value: '{{#rag.shared.replace_consecutive_spaces#}}'
remove_urls_emails:
type: mixed
value: '{{#rag.shared.delete_urls_email#}}'
separator:
type: mixed
value: '{{#rag.shared.delimiter#}}'
subchunk_max_length:
type: variable
value:
- rag
- shared
- child_max_chunk_length
subchunk_separator:
type: mixed
value: '{{#rag.shared.child_delimiter#}}'
type: tool
height: 52
id: '1752490343805'
position:
x: 1077.0240183162543
y: 281.3910724383104
positionAbsolute:
x: 1077.0240183162543
y: 281.3910724383104
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: -487.2912544090391
y: -54.7029301848807
zoom: 0.9994011715768695
rag_pipeline_variables:
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: \n\n
label: Delimiter
max_length: 100
options: []
placeholder: null
required: true
tooltips: A delimiter is the character used to separate text. \n\n is recommended
for splitting the original document into large parent chunks. You can also use
special delimiters defined by yourself.
type: text-input
unit: null
variable: delimiter
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: 1024
label: Maximum chunk length
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: characters
variable: max_chunk_length
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: \n
label: Child delimiter
max_length: 199
options: []
placeholder: null
required: true
tooltips: A delimiter is the character used to separate text. \n is recommended
for splitting parent chunks into smaller child chunks. You can also use
special delimiters defined by yourself.
type: text-input
unit: null
variable: child_delimiter
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: 512
label: Child max chunk length
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: characters
variable: child_max_chunk_length
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: paragraph
label: Parent mode
max_length: 48
options:
- full_doc
- paragraph
placeholder: null
required: true
tooltips: null
type: select
unit: null
variable: parent_mode
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Replace consecutive spaces, newlines and tabs
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: replace_consecutive_spaces
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Delete all URLs and email addresses
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: delete_urls_email
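For orientation, the Parent-child Chunker node above feeds the shared delimiter, max_chunk_length, child_delimiter and child_max_chunk_length pipeline variables into the langgenius/parentchild_chunker tool. The sketch below only approximates what those parameters control, assuming a simple split-and-cap strategy; the actual plugin may pack, clean and index chunks differently.

# Hedged sketch: approximate the effect of the Parent-child Chunker parameters
# (separator, max_length, subchunk_separator, subchunk_max_length).
# Not the plugin's real implementation.
def split_capped(text: str, separator: str, max_length: int) -> list[str]:
    chunks: list[str] = []
    for piece in text.split(separator):
        piece = piece.strip()
        while len(piece) > max_length:      # hard-wrap oversized pieces
            chunks.append(piece[:max_length])
            piece = piece[max_length:]
        if piece:
            chunks.append(piece)
    return chunks

def parent_child_chunks(text: str, separator: str = "\n\n", max_length: int = 1024,
                        subchunk_separator: str = "\n", subchunk_max_length: int = 512):
    # Child chunks are what gets embedded and matched at query time; the parent
    # chunk is what gets returned as context once one of its children matches.
    return [
        {"parent": parent,
         "children": split_capped(parent, subchunk_separator, subchunk_max_length)}
        for parent in split_capped(text, separator, max_length)
    ]

print(parent_child_chunks("Intro paragraph.\n\nSecond paragraph.\nSecond line.",
                          max_length=64, subchunk_max_length=32))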
View File
@ -0,0 +1,674 @@
dependencies:
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/general_chunker:0.0.1@e3da408b7277866404c3f884d599261f9d0b9003ea4ef7eb3b64489bdf39d18b
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/firecrawl_datasource:0.0.1@f7aed0a26df0e5f4b9555371b5c9fa6db3c7dcf6a46dd1583245697bd90a539a
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/jina_datasource:0.0.1@cf23afb2c3eeccc5a187763a1947f583f0bb10aa56461e512ac4141bf930d608
kind: rag_pipeline
rag_pipeline:
description: ''
icon: 📙
icon_background: ''
icon_type: emoji
name: website-crawl-general-economy
version: 0.1.0
workflow:
conversation_variables: []
environment_variables: []
features: {}
graph:
edges:
- data:
isInIteration: false
isInLoop: false
sourceType: datasource
targetType: variable-aggregator
id: 1752491761974-source-1752565435219-target
source: '1752491761974'
sourceHandle: source
target: '1752565435219'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: datasource
targetType: variable-aggregator
id: 1752565402678-source-1752565435219-target
source: '1752565402678'
sourceHandle: source
target: '1752565435219'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: variable-aggregator
targetType: tool
id: 1752565435219-source-1752569675978-target
source: '1752565435219'
sourceHandle: source
target: '1752569675978'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: tool
targetType: knowledge-index
id: 1752569675978-source-1752477924228-target
source: '1752569675978'
sourceHandle: source
target: '1752477924228'
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
chunk_structure: text_model
embedding_model: text-embedding-ada-002
embedding_model_provider: langgenius/openai/openai
index_chunk_variable_selector:
- '1752569675978'
- result
indexing_technique: economy
keyword_number: 10
retrieval_model:
score_threshold: 0.5
score_threshold_enabled: false
search_method: keyword_search
top_k: 3
vector_setting:
embedding_model_name: text-embedding-ada-002
embedding_provider_name: langgenius/openai/openai
selected: true
title: Knowledge Base
type: knowledge-index
height: 114
id: '1752477924228'
position:
x: 2140.4053851189346
y: 281.3910724383104
positionAbsolute:
x: 2140.4053851189346
y: 281.3910724383104
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
datasource_configurations: {}
datasource_label: Jina Reader
datasource_name: jina_reader
datasource_parameters:
crawl_sub_pages:
type: mixed
value: '{{#rag.1752491761974.jina_crawl_sub_pages#}}'
limit:
type: variable
value:
- rag
- '1752491761974'
- jina_limit
url:
type: mixed
value: '{{#rag.1752491761974.jina_url#}}'
use_sitemap:
type: mixed
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
plugin_id: langgenius/jina_datasource
provider_name: jina
provider_type: website_crawl
selected: false
title: Jina Reader
type: datasource
height: 52
id: '1752491761974'
position:
x: 1067.7526055798794
y: 281.3910724383104
positionAbsolute:
x: 1067.7526055798794
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
datasource_configurations: {}
datasource_label: Firecrawl
datasource_name: crawl
datasource_parameters:
crawl_subpages:
type: mixed
value: '{{#rag.1752565402678.firecrawl_crawl_sub_pages#}}'
exclude_paths:
type: mixed
value: '{{#rag.1752565402678.firecrawl_exclude_paths#}}'
include_paths:
type: mixed
value: '{{#rag.1752565402678.firecrawl_include_only_paths#}}'
limit:
type: variable
value:
- rag
- '1752565402678'
- firecrawl_limit
max_depth:
type: variable
value:
- rag
- '1752565402678'
- firecrawl_max_depth
only_main_content:
type: mixed
value: '{{#rag.1752565402678.firecrawl_extract_main_content#}}'
url:
type: mixed
value: '{{#rag.1752565402678.firecrawl_url#}}'
plugin_id: langgenius/firecrawl_datasource
provider_name: firecrawl
provider_type: website_crawl
selected: false
title: Firecrawl
type: datasource
height: 52
id: '1752565402678'
position:
x: 1067.7526055798794
y: 417.32608398342404
positionAbsolute:
x: 1067.7526055798794
y: 417.32608398342404
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
output_type: string
selected: false
title: Variable Aggregator
type: variable-aggregator
variables:
- - '1752491761974'
- content
- - '1752565402678'
- content
height: 129
id: '1752565435219'
position:
x: 1505.4306671642219
y: 281.3910724383104
positionAbsolute:
x: 1505.4306671642219
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
output_schema:
properties:
result:
description: The result of the general chunk tool.
properties:
general_chunks:
items:
description: The chunk of the text.
type: string
type: array
type: object
type: object
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: The text you want to chunk.
ja_JP: チャンク化したいテキスト。
pt_BR: O texto que você deseja dividir.
zh_Hans: 你想要分块的文本。
label:
en_US: Input Variable
ja_JP: 入力変数
pt_BR: Variável de entrada
zh_Hans: 输入变量
llm_description: The text you want to chunk.
max: null
min: null
name: input_variable
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: null
form: llm
human_description:
en_US: The delimiter of the chunks.
ja_JP: チャンクの区切り記号。
pt_BR: O delimitador dos pedaços.
zh_Hans: 块的分隔符。
label:
en_US: Delimiter
ja_JP: 区切り記号
pt_BR: Delimitador
zh_Hans: 分隔符
llm_description: The delimiter of the chunks, the format of the delimiter
must be a string.
max: null
min: null
name: delimiter
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: null
form: llm
human_description:
en_US: The maximum chunk length.
ja_JP: 最大長のチャンク。
pt_BR: O comprimento máximo do bloco
zh_Hans: 最大块的长度。
label:
en_US: Maximum Chunk Length
ja_JP: チャンク最大長
pt_BR: O comprimento máximo do bloco
zh_Hans: 最大块的长度
llm_description: The maximum chunk length, the format of the chunk size
must be an integer.
max: null
min: null
name: max_chunk_length
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: number
- auto_generate: null
default: null
form: llm
human_description:
en_US: The chunk overlap length.
ja_JP: チャンクの重複長
pt_BR: The chunk overlap length.
zh_Hans: 块的重叠长度。
label:
en_US: Chunk Overlap Length
ja_JP: チャンク重複長
pt_BR: Chunk Overlap Length
zh_Hans: 块的重叠长度
llm_description: The chunk overlap length, the format of the chunk overlap
length must be an integer.
max: null
min: null
name: chunk_overlap_length
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: number
- auto_generate: null
default: null
form: llm
human_description:
en_US: Replace consecutive spaces, newlines and tabs
ja_JP: 連続のスペース、改行、またはタブを置換する
pt_BR: Replace consecutive spaces, newlines and tabs
zh_Hans: 替换连续的空格、换行符和制表符
label:
en_US: Replace Consecutive Spaces, Newlines and Tabs
ja_JP: 連続のスペース、改行、またはタブを置換する
pt_BR: Replace Consecutive Spaces, Newlines and Tabs
zh_Hans: 替换连续的空格、换行符和制表符
llm_description: Replace consecutive spaces, newlines and tabs, the format
of the replace must be a boolean.
max: null
min: null
name: replace_consecutive_spaces_newlines_tabs
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
- auto_generate: null
default: null
form: llm
human_description:
en_US: Delete all URLs and email addresses
ja_JP: すべてのURLとメールアドレスを削除する
pt_BR: Delete all URLs and email addresses
zh_Hans: 删除所有URL和电子邮件地址
label:
en_US: Delete All URLs and Email Addresses
ja_JP: すべてのURLとメールアドレスを削除する
pt_BR: Delete All URLs and Email Addresses
zh_Hans: 删除所有URL和电子邮件地址
llm_description: Delete all URLs and email addresses, the format of the
delete must be a boolean.
max: null
min: null
name: delete_all_urls_and_email_addresses
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
params:
chunk_overlap_length: ''
delete_all_urls_and_email_addresses: ''
delimiter: ''
input_variable: ''
max_chunk_length: ''
replace_consecutive_spaces_newlines_tabs: ''
provider_id: langgenius/general_chunker/general_chunker
provider_name: langgenius/general_chunker/general_chunker
provider_type: builtin
selected: false
title: General Chunker
tool_configurations: {}
tool_description: A tool for general text chunking mode; the chunks retrieved and recalled are the same.
tool_label: General Chunker
tool_name: general_chunker
tool_parameters:
chunk_overlap_length:
type: variable
value:
- rag
- shared
- chunk_overlap
delete_all_urls_and_email_addresses:
type: mixed
value: '{{#rag.shared.delete_urls_email#}}'
delimiter:
type: mixed
value: '{{#rag.shared.delimiter#}}'
input_variable:
type: mixed
value: '{{#1752565435219.output#}}'
max_chunk_length:
type: variable
value:
- rag
- shared
- max_chunk_length
replace_consecutive_spaces_newlines_tabs:
type: mixed
value: '{{#rag.shared.replace_consecutive_spaces#}}'
type: tool
height: 52
id: '1752569675978'
position:
x: 1807.4306671642219
y: 281.3910724383104
positionAbsolute:
x: 1807.4306671642219
y: 281.3910724383104
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: -707.721097109337
y: -93.07807382100896
zoom: 0.9350632198875476
rag_pipeline_variables:
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752491761974'
default_value: null
label: URL
max_length: 256
options: []
placeholder: https://docs.dify.ai/en/
required: true
tooltips: null
type: text-input
unit: null
variable: jina_url
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752491761974'
default_value: 10
label: Limit
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: null
variable: jina_limit
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752491761974'
default_value: null
label: Crawl sub-pages
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: jina_crawl_sub_pages
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752491761974'
default_value: null
label: Use sitemap
max_length: 48
options: []
placeholder: null
required: false
tooltips: Follow the sitemap to crawl the site. If not, Jina Reader will crawl
iteratively based on page relevance, yielding fewer but higher-quality pages.
type: checkbox
unit: null
variable: jina_use_sitemap
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: URL
max_length: 256
options: []
placeholder: https://docs.dify.ai/en/
required: true
tooltips: null
type: text-input
unit: null
variable: firecrawl_url
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: true
label: Crawl sub-pages
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: firecrawl_crawl_sub_pages
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: 10
label: Limit
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: null
variable: firecrawl_limit
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: Max depth
max_length: 48
options: []
placeholder: ''
required: false
tooltips: Maximum depth to crawl relative to the entered URL. Depth 0 just scrapes
the page of the entered url, depth 1 scrapes the url and everything after enteredURL
+ one /, and so on.
type: number
unit: null
variable: firecrawl_max_depth
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: Exclude paths
max_length: 256
options: []
placeholder: blog/*, /about/*
required: false
tooltips: null
type: text-input
unit: null
variable: firecrawl_exclude_paths
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: Include only paths
max_length: 256
options: []
placeholder: articles/*
required: false
tooltips: null
type: text-input
unit: null
variable: firecrawl_include_only_paths
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: firecrawl_extract_main_content
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: firecrawl_extract_main_content
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: \n\n
label: Delimiter
max_length: 100
options: []
placeholder: null
required: true
tooltips: A delimiter is the character used to separate text. \n\n is recommended
for splitting the original document into large parent chunks. You can also use
special delimiters defined by yourself.
type: text-input
unit: null
variable: delimiter
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: 1024
label: Maximum chunk length
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: characters
variable: max_chunk_length
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: 50
label: Chunk overlap
max_length: 48
options: []
placeholder: null
required: false
tooltips: Setting the chunk overlap can maintain the semantic relevance between
them, enhancing the retrieval effect. It is recommended to set 10%-25% of the
maximum chunk size.
type: number
unit: characters
variable: chunk_overlap
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Replace consecutive spaces, newlines and tabs
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: replace_consecutive_spaces
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Delete all URLs and email addresses
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: delete_urls_email
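The General Chunker node in this pipeline consumes the shared delimiter, max_chunk_length and chunk_overlap variables defined above. As a hedged illustration only (the langgenius/general_chunker plugin's real behaviour may differ), the overlap setting means each chunk after the first begins with the tail of its predecessor:

# Illustrative sketch only: greedy packing of delimiter-separated pieces into
# chunks of at most max_chunk_length characters, plus a chunk_overlap tail
# carried over between neighbouring chunks. Not the plugin's actual code.
def general_chunks(text: str, delimiter: str = "\n\n",
                   max_chunk_length: int = 1024, chunk_overlap: int = 50) -> list[str]:
    pieces = [p.strip() for p in text.split(delimiter) if p.strip()]
    packed: list[str] = []
    current = ""
    for piece in pieces:
        candidate = f"{current}{delimiter}{piece}" if current else piece
        if len(candidate) <= max_chunk_length:
            current = candidate
        else:
            if current:
                packed.append(current)
            current = piece[:max_chunk_length]  # truncate a single oversized piece
    if current:
        packed.append(current)
    # prepend the tail of the previous chunk so neighbours share some context
    return [
        (packed[i - 1][-chunk_overlap:] if i and chunk_overlap else "") + chunk
        for i, chunk in enumerate(packed)
    ]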
View File
@ -0,0 +1,674 @@
dependencies:
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/general_chunker:0.0.1@e3da408b7277866404c3f884d599261f9d0b9003ea4ef7eb3b64489bdf39d18b
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/firecrawl_datasource:0.0.1@f7aed0a26df0e5f4b9555371b5c9fa6db3c7dcf6a46dd1583245697bd90a539a
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/jina_datasource:0.0.1@cf23afb2c3eeccc5a187763a1947f583f0bb10aa56461e512ac4141bf930d608
kind: rag_pipeline
rag_pipeline:
description: ''
icon: 📙
icon_background: '#FFF4ED'
icon_type: emoji
name: website-crawl-general-high-quality
version: 0.1.0
workflow:
conversation_variables: []
environment_variables: []
features: {}
graph:
edges:
- data:
isInIteration: false
isInLoop: false
sourceType: datasource
targetType: variable-aggregator
id: 1752491761974-source-1752565435219-target
source: '1752491761974'
sourceHandle: source
target: '1752565435219'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: datasource
targetType: variable-aggregator
id: 1752565402678-source-1752565435219-target
source: '1752565402678'
sourceHandle: source
target: '1752565435219'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: variable-aggregator
targetType: tool
id: 1752565435219-source-1752569675978-target
source: '1752565435219'
sourceHandle: source
target: '1752569675978'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: tool
targetType: knowledge-index
id: 1752569675978-source-1752477924228-target
source: '1752569675978'
sourceHandle: source
target: '1752477924228'
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
chunk_structure: text_model
embedding_model: text-embedding-ada-002
embedding_model_provider: langgenius/openai/openai
index_chunk_variable_selector:
- '1752569675978'
- result
indexing_technique: high_quality
keyword_number: 10
retrieval_model:
score_threshold: 0.5
score_threshold_enabled: false
search_method: semantic_search
top_k: 3
vector_setting:
embedding_model_name: text-embedding-ada-002
embedding_provider_name: langgenius/openai/openai
selected: false
title: Knowledge Base
type: knowledge-index
height: 114
id: '1752477924228'
position:
x: 2140.4053851189346
y: 281.3910724383104
positionAbsolute:
x: 2140.4053851189346
y: 281.3910724383104
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
datasource_configurations: {}
datasource_label: Jina Reader
datasource_name: jina_reader
datasource_parameters:
crawl_sub_pages:
type: mixed
value: '{{#rag.1752491761974.jina_crawl_sub_pages#}}'
limit:
type: variable
value:
- rag
- '1752491761974'
- jina_limit
url:
type: mixed
value: '{{#rag.1752491761974.jina_url#}}'
use_sitemap:
type: mixed
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
plugin_id: langgenius/jina_datasource
provider_name: jina
provider_type: website_crawl
selected: false
title: Jina Reader
type: datasource
height: 52
id: '1752491761974'
position:
x: 1067.7526055798794
y: 281.3910724383104
positionAbsolute:
x: 1067.7526055798794
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
datasource_configurations: {}
datasource_label: Firecrawl
datasource_name: crawl
datasource_parameters:
crawl_subpages:
type: mixed
value: '{{#rag.1752565402678.firecrawl_crawl_sub_pages#}}'
exclude_paths:
type: mixed
value: '{{#rag.1752565402678.firecrawl_exclude_paths#}}'
include_paths:
type: mixed
value: '{{#rag.1752565402678.firecrawl_include_only_paths#}}'
limit:
type: variable
value:
- rag
- '1752565402678'
- firecrawl_limit
max_depth:
type: variable
value:
- rag
- '1752565402678'
- firecrawl_max_depth
only_main_content:
type: mixed
value: '{{#rag.1752565402678.firecrawl_extract_main_content#}}'
url:
type: mixed
value: '{{#rag.1752565402678.firecrawl_url#}}'
plugin_id: langgenius/firecrawl_datasource
provider_name: firecrawl
provider_type: website_crawl
selected: false
title: Firecrawl
type: datasource
height: 52
id: '1752565402678'
position:
x: 1067.7526055798794
y: 417.32608398342404
positionAbsolute:
x: 1067.7526055798794
y: 417.32608398342404
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
output_type: string
selected: false
title: Variable Aggregator
type: variable-aggregator
variables:
- - '1752491761974'
- content
- - '1752565402678'
- content
height: 129
id: '1752565435219'
position:
x: 1505.4306671642219
y: 281.3910724383104
positionAbsolute:
x: 1505.4306671642219
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
output_schema:
properties:
result:
description: The result of the general chunk tool.
properties:
general_chunks:
items:
description: The chunk of the text.
type: string
type: array
type: object
type: object
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: The text you want to chunk.
ja_JP: チャンク化したいテキスト。
pt_BR: O texto que você deseja dividir.
zh_Hans: 你想要分块的文本。
label:
en_US: Input Variable
ja_JP: 入力変数
pt_BR: Variável de entrada
zh_Hans: 输入变量
llm_description: The text you want to chunk.
max: null
min: null
name: input_variable
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: null
form: llm
human_description:
en_US: The delimiter of the chunks.
ja_JP: チャンクの区切り記号。
pt_BR: O delimitador dos pedaços.
zh_Hans: 块的分隔符。
label:
en_US: Delimiter
ja_JP: 区切り記号
pt_BR: Delimitador
zh_Hans: 分隔符
llm_description: The delimiter of the chunks, the format of the delimiter
must be a string.
max: null
min: null
name: delimiter
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: null
form: llm
human_description:
en_US: The maximum chunk length.
ja_JP: 最大長のチャンク。
pt_BR: O comprimento máximo do bloco
zh_Hans: 最大块的长度。
label:
en_US: Maximum Chunk Length
ja_JP: チャンク最大長
pt_BR: O comprimento máximo do bloco
zh_Hans: 最大块的长度
llm_description: The maximum chunk length, the format of the chunk size
must be an integer.
max: null
min: null
name: max_chunk_length
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: number
- auto_generate: null
default: null
form: llm
human_description:
en_US: The chunk overlap length.
ja_JP: チャンクの重複長。
pt_BR: The chunk overlap length.
zh_Hans: 块的重叠长度。
label:
en_US: Chunk Overlap Length
ja_JP: チャンク重複長
pt_BR: Chunk Overlap Length
zh_Hans: 块的重叠长度
llm_description: The chunk overlap length, the format of the chunk overlap
length must be an integer.
max: null
min: null
name: chunk_overlap_length
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: number
- auto_generate: null
default: null
form: llm
human_description:
en_US: Replace consecutive spaces, newlines and tabs
ja_JP: 連続のスペース、改行、またはタブを置換する
pt_BR: Replace consecutive spaces, newlines and tabs
zh_Hans: 替换连续的空格、换行符和制表符
label:
en_US: Replace Consecutive Spaces, Newlines and Tabs
ja_JP: 連続のスペース、改行、またはタブを置換する
pt_BR: Replace Consecutive Spaces, Newlines and Tabs
zh_Hans: 替换连续的空格、换行符和制表符
llm_description: Replace consecutive spaces, newlines and tabs, the format
of the replace must be a boolean.
max: null
min: null
name: replace_consecutive_spaces_newlines_tabs
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
- auto_generate: null
default: null
form: llm
human_description:
en_US: Delete all URLs and email addresses
ja_JP: すべてのURLとメールアドレスを削除する
pt_BR: Delete all URLs and email addresses
zh_Hans: 删除所有URL和电子邮件地址
label:
en_US: Delete All URLs and Email Addresses
ja_JP: すべてのURLとメールアドレスを削除する
pt_BR: Delete All URLs and Email Addresses
zh_Hans: 删除所有URL和电子邮件地址
llm_description: Delete all URLs and email addresses, the format of the
delete must be a boolean.
max: null
min: null
name: delete_all_urls_and_email_addresses
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
params:
chunk_overlap_length: ''
delete_all_urls_and_email_addresses: ''
delimiter: ''
input_variable: ''
max_chunk_length: ''
replace_consecutive_spaces_newlines_tabs: ''
provider_id: langgenius/general_chunker/general_chunker
provider_name: langgenius/general_chunker/general_chunker
provider_type: builtin
selected: false
title: General Chunker
tool_configurations: {}
tool_description: A tool for general text chunking mode; the chunks retrieved and recalled are the same.
tool_label: General Chunker
tool_name: general_chunker
tool_parameters:
chunk_overlap_length:
type: variable
value:
- rag
- shared
- chunk_overlap
delete_all_urls_and_email_addresses:
type: mixed
value: '{{#rag.shared.delete_urls_email#}}'
delimiter:
type: mixed
value: '{{#rag.shared.delimiter#}}'
input_variable:
type: mixed
value: '{{#1752565435219.output#}}'
max_chunk_length:
type: variable
value:
- rag
- shared
- max_chunk_length
replace_consecutive_spaces_newlines_tabs:
type: mixed
value: '{{#rag.shared.replace_consecutive_spaces#}}'
type: tool
height: 52
id: '1752569675978'
position:
x: 1807.4306671642219
y: 281.3910724383104
positionAbsolute:
x: 1807.4306671642219
y: 281.3910724383104
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: -707.721097109337
y: -93.07807382100896
zoom: 0.9350632198875476
rag_pipeline_variables:
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752491761974'
default_value: null
label: URL
max_length: 256
options: []
placeholder: https://docs.dify.ai/en/
required: true
tooltips: null
type: text-input
unit: null
variable: jina_url
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752491761974'
default_value: 10
label: Limit
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: null
variable: jina_limit
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752491761974'
default_value: null
label: Crawl sub-pages
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: jina_crawl_sub_pages
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752491761974'
default_value: null
label: Use sitemap
max_length: 48
options: []
placeholder: null
required: false
tooltips: Follow the sitemap to crawl the site. If not, Jina Reader will crawl
iteratively based on page relevance, yielding fewer but higher-quality pages.
type: checkbox
unit: null
variable: jina_use_sitemap
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: URL
max_length: 256
options: []
placeholder: https://docs.dify.ai/en/
required: true
tooltips: null
type: text-input
unit: null
variable: firecrawl_url
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: true
label: Crawl sub-pages
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: firecrawl_crawl_sub_pages
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: 10
label: Limit
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: null
variable: firecrawl_limit
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: Max depth
max_length: 48
options: []
placeholder: ''
required: false
tooltips: Maximum depth to crawl relative to the entered URL. Depth 0 just scrapes
the page of the entered url, depth 1 scrapes the url and everything after enteredURL
+ one /, and so on.
type: number
unit: null
variable: firecrawl_max_depth
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: Exclude paths
max_length: 256
options: []
placeholder: blog/*, /about/*
required: false
tooltips: null
type: text-input
unit: null
variable: firecrawl_exclude_paths
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: Include only paths
max_length: 256
options: []
placeholder: articles/*
required: false
tooltips: null
type: text-input
unit: null
variable: firecrawl_include_only_paths
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: firecrawl_extract_main_content
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: firecrawl_extract_main_content
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: \n\n
label: Delimiter
max_length: 100
options: []
placeholder: null
required: true
tooltips: A delimiter is the character used to separate text. \n\n is recommended
for splitting the original document into large parent chunks. You can also use
special delimiters defined by yourself.
type: text-input
unit: null
variable: delimiter
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: 1024
label: Maximum chunk length
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: characters
variable: max_chunk_length
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: 50
label: Chunk overlap
max_length: 48
options: []
placeholder: null
required: false
tooltips: Setting the chunk overlap can maintain the semantic relevance between
them, enhancing the retrieval effect. It is recommended to set 10%-25% of the
maximum chunk size.
type: number
unit: characters
variable: chunk_overlap
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Replace consecutive spaces, newlines and tabs
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: replace_consecutive_spaces
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Delete all URLs and email addresses
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: delete_urls_email
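Compared with the economy template above, this pipeline sets indexing_technique: high_quality and search_method: semantic_search, but the retrieval_model fields are read the same way. Below is a minimal sketch of how such settings are typically applied to already-scored chunks at query time; this is an illustrative assumption, not Dify's exact retrieval code, which also handles reranking, weights and multiple search methods.

# Hedged sketch of applying a retrieval_model block (top_k, score_threshold,
# score_threshold_enabled) to scored candidate chunks.
def apply_retrieval_model(scored_chunks: list[dict], top_k: int = 3,
                          score_threshold: float = 0.5,
                          score_threshold_enabled: bool = False) -> list[dict]:
    ranked = sorted(scored_chunks, key=lambda c: c["score"], reverse=True)
    if score_threshold_enabled:
        ranked = [c for c in ranked if c["score"] >= score_threshold]
    return ranked[:top_k]

# e.g. apply_retrieval_model([{"text": "a", "score": 0.9}, {"text": "b", "score": 0.2}])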
View File
@ -0,0 +1,779 @@
dependencies:
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/parentchild_chunker:0.0.1@b1a28a27e33fec442ce494da2a7814edd7eb9d646c81f38bccfcf1133d486e40
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/firecrawl_datasource:0.0.1@f7aed0a26df0e5f4b9555371b5c9fa6db3c7dcf6a46dd1583245697bd90a539a
- current_identifier: null
type: marketplace
value:
plugin_unique_identifier: langgenius/jina_datasource:0.0.1@cf23afb2c3eeccc5a187763a1947f583f0bb10aa56461e512ac4141bf930d608
kind: rag_pipeline
rag_pipeline:
description: ''
icon: 📙
icon_background: ''
icon_type: emoji
name: website-crawl-parentchild
version: 0.1.0
workflow:
conversation_variables: []
environment_variables: []
features: {}
graph:
edges:
- data:
isInLoop: false
sourceType: tool
targetType: knowledge-index
id: 1752490343805-source-1752477924228-target
source: '1752490343805'
sourceHandle: source
target: '1752477924228'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: datasource
targetType: variable-aggregator
id: 1752491761974-source-1752565435219-target
source: '1752491761974'
sourceHandle: source
target: '1752565435219'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: variable-aggregator
targetType: tool
id: 1752565435219-source-1752490343805-target
source: '1752565435219'
sourceHandle: source
target: '1752490343805'
targetHandle: target
type: custom
zIndex: 0
- data:
isInLoop: false
sourceType: datasource
targetType: variable-aggregator
id: 1752565402678-source-1752565435219-target
source: '1752565402678'
sourceHandle: source
target: '1752565435219'
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
chunk_structure: hierarchical_model
embedding_model: text-embedding-ada-002
embedding_model_provider: langgenius/openai/openai
index_chunk_variable_selector:
- '1752490343805'
- result
indexing_technique: high_quality
keyword_number: 10
retrieval_model:
score_threshold: 0.5
score_threshold_enabled: false
search_method: semantic_search
top_k: 3
vector_setting:
embedding_model_name: text-embedding-ada-002
embedding_provider_name: langgenius/openai/openai
selected: false
title: Knowledge Base
type: knowledge-index
height: 114
id: '1752477924228'
position:
x: 2215.5544306817387
y: 281.3910724383104
positionAbsolute:
x: 2215.5544306817387
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
output_schema:
properties:
result:
description: Parent child chunks result
items:
type: object
type: array
type: object
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: The text you want to chunk.
ja_JP: チャンク化したいテキスト。
pt_BR: O texto que você deseja dividir.
zh_Hans: 你想要分块的文本。
label:
en_US: Input text
ja_JP: 入力テキスト
pt_BR: Texto de entrada
zh_Hans: 输入文本
llm_description: The text you want to chunk.
max: null
min: null
name: input_text
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
- auto_generate: null
default: 1024
form: llm
human_description:
en_US: Maximum length for chunking
ja_JP: チャンク分割の最大長
pt_BR: Comprimento máximo para divisão
zh_Hans: 用于分块的最大长度
label:
en_US: Maximum Length
ja_JP: 最大長
pt_BR: Comprimento Máximo
zh_Hans: 最大长度
llm_description: Maximum length allowed per chunk
max: null
min: null
name: max_length
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: number
- auto_generate: null
default: '
'
form: llm
human_description:
en_US: Separator used for chunking
ja_JP: チャンク分割に使用する区切り文字
pt_BR: Separador usado para divisão
zh_Hans: 用于分块的分隔符
label:
en_US: Chunk Separator
ja_JP: チャンク区切り文字
pt_BR: Separador de Divisão
zh_Hans: 分块分隔符
llm_description: The separator used to split chunks
max: null
min: null
name: separator
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: string
- auto_generate: null
default: 512
form: llm
human_description:
en_US: Maximum length for subchunking
ja_JP: サブチャンク分割の最大長
pt_BR: Comprimento máximo para subdivisão
zh_Hans: 用于子分块的最大长度
label:
en_US: Subchunk Maximum Length
ja_JP: サブチャンク最大長
pt_BR: Comprimento Máximo de Subdivisão
zh_Hans: 子分块最大长度
llm_description: Maximum length allowed per subchunk
max: null
min: null
name: subchunk_max_length
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: number
- auto_generate: null
default: '. '
form: llm
human_description:
en_US: Separator used for subchunking
ja_JP: サブチャンク分割に使用する区切り文字
pt_BR: Separador usado para subdivisão
zh_Hans: 用于子分块的分隔符
label:
en_US: Subchunk Separator
ja_JP: サブチャンキング用セパレーター
pt_BR: Separador de Subdivisão
zh_Hans: 子分块分隔符
llm_description: The separator used to split subchunks
max: null
min: null
name: subchunk_separator
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: string
- auto_generate: null
default: paragraph
form: llm
human_description:
en_US: Split text into paragraphs based on separator and maximum chunk
length, using split text as parent block or entire document as parent
block and directly retrieve.
ja_JP: セパレーターと最大チャンク長に基づいてテキストを段落に分割し、分割されたテキスト
を親ブロックとして使用するか、文書全体を親ブロックとして使用して直接取得します。
pt_BR: Dividir texto em parágrafos com base no separador e no comprimento
máximo do bloco, usando o texto dividido como bloco pai ou documento
completo como bloco pai e diretamente recuperá-lo.
zh_Hans: 根据分隔符和最大块长度将文本拆分为段落,使用拆分文本作为检索的父块或整个文档用作父块并直接检索。
label:
en_US: Parent Mode
ja_JP: 親子モード
pt_BR: Modo Pai
zh_Hans: 父块模式
llm_description: Split text into paragraphs based on separator and maximum
chunk length, using split text as parent block or entire document as parent
block and directly retrieve.
max: null
min: null
name: parent_mode
options:
- icon: ''
label:
en_US: Paragraph
ja_JP: 段落
pt_BR: Parágrafo
zh_Hans: 段落
value: paragraph
- icon: ''
label:
en_US: Full Document
ja_JP: 全文
pt_BR: Documento Completo
zh_Hans: 全文
value: full_doc
placeholder: null
precision: null
required: true
scope: null
template: null
type: select
- auto_generate: null
default: 0
form: llm
human_description:
en_US: Whether to remove extra spaces in the text
ja_JP: テキスト内の余分なスペースを削除するかどうか
pt_BR: Se deve remover espaços extras no texto
zh_Hans: 是否移除文本中的多余空格
label:
en_US: Remove Extra Spaces
ja_JP: 余分なスペースを削除
pt_BR: Remover Espaços Extras
zh_Hans: 移除多余空格
llm_description: Whether to remove extra spaces in the text
max: null
min: null
name: remove_extra_spaces
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
- auto_generate: null
default: 0
form: llm
human_description:
en_US: Whether to remove URLs and emails in the text
ja_JP: テキスト内のURLやメールアドレスを削除するかどうか
pt_BR: Se deve remover URLs e e-mails no texto
zh_Hans: 是否移除文本中的URL和电子邮件地址
label:
en_US: Remove URLs and Emails
ja_JP: URLとメールアドレスを削除
pt_BR: Remover URLs e E-mails
zh_Hans: 移除URL和电子邮件地址
llm_description: Whether to remove URLs and emails in the text
max: null
min: null
name: remove_urls_emails
options: []
placeholder: null
precision: null
required: false
scope: null
template: null
type: boolean
params:
input_text: ''
max_length: ''
parent_mode: ''
remove_extra_spaces: ''
remove_urls_emails: ''
separator: ''
subchunk_max_length: ''
subchunk_separator: ''
provider_id: langgenius/parentchild_chunker/parentchild_chunker
provider_name: langgenius/parentchild_chunker/parentchild_chunker
provider_type: builtin
selected: true
title: Parent-child Chunker
tool_configurations: {}
tool_description: Parent-child Chunk Structure
tool_label: Parent-child Chunker
tool_name: parentchild_chunker
tool_parameters:
input_text:
type: mixed
value: '{{#1752565435219.output#}}'
max_length:
type: variable
value:
- rag
- shared
- max_chunk_length
parent_mode:
type: variable
value:
- rag
- shared
- parent_mode
remove_extra_spaces:
type: mixed
value: '{{#rag.shared.replace_consecutive_spaces#}}'
remove_urls_emails:
type: mixed
value: '{{#rag.shared.delete_urls_email#}}'
separator:
type: mixed
value: '{{#rag.shared.delimiter#}}'
subchunk_max_length:
type: variable
value:
- rag
- shared
- child_max_chunk_length
subchunk_separator:
type: mixed
value: '{{#rag.shared.child_delimiter#}}'
type: tool
height: 52
id: '1752490343805'
position:
x: 1853.5260563244174
y: 281.3910724383104
positionAbsolute:
x: 1853.5260563244174
y: 281.3910724383104
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
datasource_configurations: {}
datasource_label: Jina Reader
datasource_name: jina_reader
datasource_parameters:
crawl_sub_pages:
type: mixed
value: '{{#rag.1752491761974.jina_crawl_sub_pages#}}'
limit:
type: variable
value:
- rag
- '1752491761974'
- jina_limit
url:
type: mixed
value: '{{#rag.1752491761974.jina_url#}}'
use_sitemap:
type: mixed
value: '{{#rag.1752491761974.jina_use_sitemap#}}'
plugin_id: langgenius/jina_datasource
provider_name: jina
provider_type: website_crawl
selected: false
title: Jina Reader
type: datasource
height: 52
id: '1752491761974'
position:
x: 1067.7526055798794
y: 281.3910724383104
positionAbsolute:
x: 1067.7526055798794
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
datasource_configurations: {}
datasource_label: Firecrawl
datasource_name: crawl
datasource_parameters:
crawl_subpages:
type: mixed
value: '{{#rag.1752565402678.firecrawl_crawl_sub_pages#}}'
exclude_paths:
type: mixed
value: '{{#rag.1752565402678.firecrawl_exclude_paths#}}'
include_paths:
type: mixed
value: '{{#rag.1752565402678.firecrawl_include_only_paths#}}'
limit:
type: variable
value:
- rag
- '1752565402678'
- firecrawl_limit
max_depth:
type: variable
value:
- rag
- '1752565402678'
- firecrawl_max_depth
only_main_content:
type: mixed
value: '{{#rag.1752565402678.firecrawl_extract_main_content#}}'
url:
type: mixed
value: '{{#rag.1752565402678.firecrawl_url#}}'
plugin_id: langgenius/firecrawl_datasource
provider_name: firecrawl
provider_type: website_crawl
selected: false
title: Firecrawl
type: datasource
height: 52
id: '1752565402678'
position:
x: 1067.7526055798794
y: 417.32608398342404
positionAbsolute:
x: 1067.7526055798794
y: 417.32608398342404
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
output_type: string
selected: false
title: Variable Aggregator
type: variable-aggregator
variables:
- - '1752491761974'
- content
- - '1752565402678'
- content
height: 129
id: '1752565435219'
position:
x: 1505.4306671642219
y: 281.3910724383104
positionAbsolute:
x: 1505.4306671642219
y: 281.3910724383104
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: -826.1791044466438
y: -71.91725474841303
zoom: 0.9980166672552107
rag_pipeline_variables:
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752491761974'
default_value: null
label: URL
max_length: 256
options: []
placeholder: https://docs.dify.ai/en/
required: true
tooltips: null
type: text-input
unit: null
variable: jina_url
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752491761974'
default_value: 10
label: Limit
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: null
variable: jina_limit
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752491761974'
default_value: null
label: Crawl sub-pages
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: jina_crawl_sub_pages
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752491761974'
default_value: null
label: Use sitemap
max_length: 48
options: []
placeholder: null
required: false
tooltips: Follow the sitemap to crawl the site. If not, Jina Reader will crawl
iteratively based on page relevance, yielding fewer but higher-quality pages.
type: checkbox
unit: null
variable: jina_use_sitemap
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: URL
max_length: 256
options: []
placeholder: https://docs.dify.ai/en/
required: true
tooltips: null
type: text-input
unit: null
variable: firecrawl_url
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: true
label: Crawl sub-pages
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: firecrawl_crawl_sub_pages
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: 10
label: Limit
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: null
variable: firecrawl_limit
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: Max depth
max_length: 48
options: []
placeholder: ''
required: false
tooltips: Maximum depth to crawl relative to the entered URL. Depth 0 just scrapes
the page of the entered url, depth 1 scrapes the url and everything after enteredURL
+ one /, and so on.
type: number
unit: null
variable: firecrawl_max_depth
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: Exclude paths
max_length: 256
options: []
placeholder: blog/*, /about/*
required: false
tooltips: null
type: text-input
unit: null
variable: firecrawl_exclude_paths
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: Include only paths
max_length: 256
options: []
placeholder: articles/*
required: false
tooltips: null
type: text-input
unit: null
variable: firecrawl_include_only_paths
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: '1752565402678'
default_value: null
label: firecrawl_extract_main_content
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: firecrawl_extract_main_content
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: \n\n
label: Delimiter
max_length: 100
options: []
placeholder: null
required: true
tooltips: A delimiter is the character used to separate text. \n\n is recommended
for splitting the original document into large parent chunks. You can also use
special delimiters defined by yourself.
type: text-input
unit: null
variable: delimiter
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: 1024
label: Maximum chunk length
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: characters
variable: max_chunk_length
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: \n
label: Child delimiter
max_length: 199
options: []
placeholder: null
required: true
tooltips: A delimiter is the character used to separate text. \n is recommended
for splitting parent chunks into smaller child chunks. You can also use
special delimiters defined by yourself.
type: text-input
unit: null
variable: child_delimiter
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: 512
label: Child max chunk length
max_length: 48
options: []
placeholder: null
required: true
tooltips: null
type: number
unit: characters
variable: child_max_chunk_length
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: paragraph
label: Parent mode
max_length: 48
options:
- full_doc
- paragraph
placeholder: null
required: true
tooltips: null
type: select
unit: null
variable: parent_mode
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Replace consecutive spaces, newlines and tabs
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: replace_consecutive_spaces
- allow_file_extension: null
allow_file_upload_methods: null
allowed_file_types: null
belong_to_node_id: shared
default_value: null
label: Delete all URLs and email addresses
max_length: 48
options: []
placeholder: null
required: false
tooltips: null
type: checkbox
unit: null
variable: delete_urls_email
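All three website-crawl templates reference pipeline variables in two forms: "variable" parameters carry a selector path such as [rag, shared, max_chunk_length], while "mixed" parameters embed {{#node.variable#}} references inside a template string. A simplified resolver is sketched below purely for illustration; Dify's real variable pool handles more cases (files, nested objects, escaping).

import re

# Hedged sketch of resolving the two tool_parameter value styles used above.
def resolve(param: dict, variables: dict):
    if param["type"] == "variable":
        node, *path = param["value"]            # e.g. ["rag", "shared", "max_chunk_length"]
        value = variables[node]
        for key in path:
            value = value[key]
        return value
    if param["type"] == "mixed":
        def substitute(match: re.Match) -> str:
            node, *path = match.group(1).split(".")
            value = variables[node]
            for key in path:
                value = value[key]
            return str(value)
        return re.sub(r"\{\{#(.+?)#\}\}", substitute, param["value"])
    raise ValueError(f"unknown parameter type: {param['type']}")

variables = {"rag": {"shared": {"delimiter": "\\n\\n", "max_chunk_length": 1024}}}
print(resolve({"type": "variable", "value": ["rag", "shared", "max_chunk_length"]}, variables))
print(resolve({"type": "mixed", "value": "{{#rag.shared.delimiter#}}"}, variables))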
View File
@ -1,7 +1,6 @@
import json
from os import path
from pathlib import Path
from typing import Optional
from flask import current_app
@ -14,12 +13,12 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
Retrieval recommended app from buildin, the location is constants/recommended_apps.json
"""
builtin_data: Optional[dict] = None
builtin_data: dict | None = None
def get_type(self) -> str:
return RecommendAppType.BUILDIN
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
result = self.fetch_recommended_apps_from_builtin(language)
return result
@ -28,7 +27,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return result
@classmethod
def _get_builtin_data(cls) -> dict:
def _get_builtin_data(cls):
"""
Get builtin data.
:return:
@ -44,7 +43,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return cls.builtin_data or {}
@classmethod
def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
def fetch_recommended_apps_from_builtin(cls, language: str):
"""
Fetch recommended apps from builtin.
:param language: language
@ -54,7 +53,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return builtin_data.get("recommended_apps", {}).get(language, {})
@classmethod
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict | None:
"""
Fetch recommended app detail from builtin.
:param app_id: App ID
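The typing changes in this file swap Optional[dict] for the PEP 604 union spelling dict | None (Python 3.10+) and drop some redundant return annotations. The two spellings are interchangeable for type checkers and compare equal at runtime; a quick self-contained check:

from typing import Optional

# PEP 604 unions (X | None) compare equal to the older Optional[X] form.
assert (dict | None) == Optional[dict]

def fetch_detail(app_id: str) -> dict | None:   # same meaning as Optional[dict]
    return None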
View File
@ -1,4 +1,4 @@
from typing import Optional
from sqlalchemy import select
from constants.languages import languages
from extensions.ext_database import db
@ -13,7 +13,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
Retrieval recommended app from database
"""
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
result = self.fetch_recommended_apps_from_db(language)
return result
@ -25,24 +25,20 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
return RecommendAppType.DATABASE
@classmethod
def fetch_recommended_apps_from_db(cls, language: str) -> dict:
def fetch_recommended_apps_from_db(cls, language: str):
"""
Fetch recommended apps from db.
:param language: language
:return:
"""
recommended_apps = (
db.session.query(RecommendedApp)
.where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
.all()
)
recommended_apps = db.session.scalars(
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
).all()
if len(recommended_apps) == 0:
recommended_apps = (
db.session.query(RecommendedApp)
.where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
.all()
)
recommended_apps = db.session.scalars(
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
).all()
categories = set()
recommended_apps_result = []
@ -74,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
@classmethod
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]:
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None:
"""
Fetch recommended app detail from db.
:param app_id: App ID
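The query rewrites above move from the legacy Query API (db.session.query(...).where(...).all()) to the SQLAlchemy 2.0-style select()/scalars() pattern. A minimal standalone illustration follows; the in-memory SQLite setup and sample rows are hypothetical and only serve to show that both forms return the same model instances.

# Self-contained sketch of the Query-API -> select()/scalars() change.
from sqlalchemy import create_engine, select, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

class Base(DeclarativeBase):
    pass

class RecommendedApp(Base):
    __tablename__ = "recommended_apps"
    id: Mapped[int] = mapped_column(primary_key=True)
    language: Mapped[str] = mapped_column(String(16))
    is_listed: Mapped[bool] = mapped_column(default=True)

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([RecommendedApp(language="en-US"), RecommendedApp(language="ja-JP")])
    session.commit()

    # legacy style: the Query API returns model instances directly
    legacy = session.query(RecommendedApp).where(RecommendedApp.language == "en-US").all()

    # 2.0 style: build a Select, then let scalars() unwrap single-entity rows
    modern = session.scalars(
        select(RecommendedApp).where(RecommendedApp.language == "en-US")
    ).all()

    assert [a.id for a in legacy] == [a.id for a in modern]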
View File
@ -5,7 +5,7 @@ class RecommendAppRetrievalBase(ABC):
"""Interface for recommend app retrieval."""
@abstractmethod
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
raise NotImplementedError
@abstractmethod
View File
@ -1,5 +1,4 @@
import logging
from typing import Optional
import requests
@ -24,7 +23,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id)
return result
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
try:
result = self.fetch_recommended_apps_from_dify_official(language)
except Exception as e:
@ -36,7 +35,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
return RecommendAppType.REMOTE
@classmethod
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]:
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict | None:
"""
Fetch recommended app detail from dify official.
:param app_id: App ID
@ -51,7 +50,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
return data
@classmethod
def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
def fetch_recommended_apps_from_dify_official(cls, language: str):
"""
Fetch recommended apps from dify official.
:param language: language
View File
@ -1,12 +1,10 @@
from typing import Optional
from configs import dify_config
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
class RecommendedAppService:
@classmethod
def get_recommended_apps_and_categories(cls, language: str) -> dict:
def get_recommended_apps_and_categories(cls, language: str):
"""
Get recommended apps and categories.
:param language: language
@ -15,7 +13,7 @@ class RecommendedAppService:
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result = retrieval_instance.get_recommended_apps_and_categories(language)
if not result.get("recommended_apps") and language != "en-US":
if not result.get("recommended_apps"):
result = (
RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval().fetch_recommended_apps_from_builtin(
"en-US"
@ -25,7 +23,7 @@ class RecommendedAppService:
return result
@classmethod
def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
def get_recommend_app_detail(cls, app_id: str) -> dict | None:
"""
Get recommend app detail.
:param app_id: app id

View File

@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Union
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
@ -11,7 +11,7 @@ from services.message_service import MessageService
class SavedMessageService:
@classmethod
def pagination_by_last_id(
cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
cls, app_model: App, user: Union[Account, EndUser] | None, last_id: str | None, limit: int
) -> InfiniteScrollPagination:
if not user:
raise ValueError("User is required")
@ -32,7 +32,7 @@ class SavedMessageService:
)
@classmethod
def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
if not user:
return
saved_message = (
@ -62,7 +62,7 @@ class SavedMessageService:
db.session.commit()
@classmethod
def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
if not user:
return
saved_message = (

View File

@ -1,8 +1,7 @@
import uuid
from typing import Optional
from flask_login import current_user
from sqlalchemy import func
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
@ -12,7 +11,7 @@ from models.model import App, Tag, TagBinding
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
query = (
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
@ -25,46 +24,41 @@ class TagService:
return results
@staticmethod
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list):
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
tags = (
db.session.query(Tag)
.where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
)
tags = db.session.scalars(
select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
).all()
if not tags:
return []
tag_ids = [tag.id for tag in tags]
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
tag_bindings = (
db.session.query(TagBinding.target_id)
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.all()
)
if not tag_bindings:
return []
results = [tag_binding.target_id for tag_binding in tag_bindings]
return results
tag_bindings = db.session.scalars(
select(TagBinding.target_id).where(
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
)
).all()
return tag_bindings
@staticmethod
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
if not tag_type or not tag_name:
return []
tags = (
db.session.query(Tag)
.where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
tags = list(
db.session.scalars(
select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
).all()
)
if not tags:
return []
return tags
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
@ -117,7 +111,7 @@ class TagService:
raise NotFound("Tag not found")
db.session.delete(tag)
# delete tag binding
tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all()
tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping
from typing import Any, cast
from httpx import get
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
@ -443,9 +444,7 @@ class ApiToolManageService:
list api tools
"""
# get all api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
)
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
result: list[ToolProviderApiEntity] = []

View File

@ -1,18 +1,19 @@
import json
import logging
import re
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional
from typing import Any
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
# from core.plugin.entities.plugin import ToolProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -30,6 +31,7 @@ from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import ToolProviderID
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
@ -222,8 +224,8 @@ class BuiltinToolManageService:
"""
add builtin tool provider
"""
try:
with Session(db.engine) as session:
with Session(db.engine) as session:
try:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@ -282,9 +284,9 @@ class BuiltinToolManageService:
session.add(db_provider)
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod
@ -308,42 +310,20 @@ class BuiltinToolManageService:
def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try:
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.all()
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
# Get the default name pattern
default_pattern = f"{credential_type.get_name()}"
# Find all names that match the default pattern: "{default_pattern} {number}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for db_provider in db_providers:
if db_provider.name:
match = re.match(pattern, db_provider.name.strip())
if match:
numbers.append(int(match.group(1)))
# If no default pattern names found, start with 1
if not numbers:
return f"{default_pattern} 1"
# Find the next number
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning("Error generating next provider name for %s: %s", provider, str(e))
# fallback
return f"{credential_type.get_name()} 1"
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
)
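# Standalone, stdlib-only sketch of the incremental-naming behaviour relied on above.
# It mirrors the inline regex logic this hunk replaces; the real helper is
# generate_incremental_name from core.helper.name_generator, and the "API Key" base
# below is only an illustrative stand-in for credential_type.get_name().
import re

def _next_incremental_name(existing: list[str], base: str) -> str:
    pattern = re.compile(rf"^{re.escape(base)}\s+(\d+)$")
    numbers = [int(m.group(1)) for name in existing if (m := pattern.match(name.strip()))]
    return f"{base} {max(numbers) + 1 if numbers else 1}"

assert _next_incremental_name([], "API Key") == "API Key 1"
assert _next_incremental_name(["API Key 1", "API Key 3"], "API Key") == "API Key 4"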
@staticmethod
def get_builtin_tool_provider_credentials(
@ -570,7 +550,7 @@ class BuiltinToolManageService:
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller,
name_func=lambda x: x.identity.name,
name_func=lambda x: x.entity.identity.name,
):
continue
@ -601,7 +581,7 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
@ -662,8 +642,8 @@ class BuiltinToolManageService:
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: Optional[dict] = None,
enable_oauth_custom_client: Optional[bool] = None,
client_params: dict | None = None,
enable_oauth_custom_client: bool | None = None,
):
"""
setup oauth custom client

View File

@ -1,7 +1,7 @@
import hashlib
import json
from datetime import datetime
from typing import Any
from typing import Any, cast
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
@ -27,6 +27,36 @@ class MCPToolManageService:
Service class for managing mcp tools.
"""
@staticmethod
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
"""
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
Args:
headers: Dictionary of headers to encrypt
tenant_id: Tenant ID for encryption
Returns:
Dictionary with all headers encrypted
"""
if not headers:
return {}
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
return cast(dict[str, str], encrypter_instance.encrypt(headers))
@staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
res = (
@ -61,6 +91,7 @@ class MCPToolManageService:
server_identifier: str,
timeout: float,
sse_read_timeout: float,
headers: dict[str, str] | None = None,
) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
@ -83,6 +114,12 @@ class MCPToolManageService:
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
# Encrypt headers
encrypted_headers = None
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
encrypted_headers = json.dumps(encrypted_headers_dict)
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
name=name,
@ -95,6 +132,7 @@ class MCPToolManageService:
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
encrypted_headers=encrypted_headers,
)
db.session.add(mcp_tool)
db.session.commit()
@ -118,9 +156,21 @@ class MCPToolManageService:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = mcp_provider.decrypted_server_url
authed = mcp_provider.authed
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
try:
with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=authed,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError:
raise ValueError("Please auth the tool first")
@ -172,6 +222,7 @@ class MCPToolManageService:
server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
@ -207,6 +258,32 @@ class MCPToolManageService:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
if headers is not None:
# Merge masked headers from frontend with existing real values
if headers:
# existing decrypted and masked headers
existing_decrypted = mcp_provider.decrypted_headers
existing_masked = mcp_provider.masked_headers
# Build final headers: if value equals masked existing, keep original decrypted value
final_headers: dict[str, str] = {}
for key, incoming_value in headers.items():
if (
key in existing_masked
and key in existing_decrypted
and isinstance(incoming_value, str)
and incoming_value == existing_masked.get(key)
):
# unchanged, use original decrypted value
final_headers[key] = str(existing_decrypted[key])
else:
final_headers[key] = incoming_value
encrypted_headers_dict = MCPToolManageService._encrypt_headers(final_headers, tenant_id)
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
else:
# Explicitly clear headers if empty dict passed
mcp_provider.encrypted_headers = None
db.session.commit()
except IntegrityError as e:
db.session.rollback()
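# Standalone sketch of the merge rule applied in the headers branch above: values that
# come back from the client still in their masked form keep the stored plaintext, while
# anything else is treated as a newly supplied value. Plain dicts only; names illustrative.
def _merge_incoming_headers(
    incoming: dict[str, str],
    existing_masked: dict[str, str],
    existing_decrypted: dict[str, str],
) -> dict[str, str]:
    merged: dict[str, str] = {}
    for key, value in incoming.items():
        if key in existing_masked and key in existing_decrypted and value == existing_masked[key]:
            merged[key] = existing_decrypted[key]  # unchanged -> keep the original secret
        else:
            merged[key] = value  # new or edited header value
    return merged

assert _merge_incoming_headers(
    {"Authorization": "Bea******23", "X-Trace": "on"},
    existing_masked={"Authorization": "Bea******23"},
    existing_decrypted={"Authorization": "Bearer abc123"},
) == {"Authorization": "Bearer abc123", "X-Trace": "on"}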
@ -226,10 +303,10 @@ class MCPToolManageService:
def update_mcp_provider_credentials(
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
):
provider_controller = MCPToolProviderController._from_db(mcp_provider)
provider_controller = MCPToolProviderController.from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)
@ -242,6 +319,12 @@ class MCPToolManageService:
@classmethod
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
# Get the existing provider to access headers and timeout settings
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
try:
with MCPClient(
server_url,
@ -249,6 +332,9 @@ class MCPToolManageService:
tenant_id,
authed=False,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
return {

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
@ -10,7 +9,7 @@ logger = logging.getLogger(__name__)
class ToolCommonService:
@staticmethod
def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None):
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None):
"""
list tool providers

View File

@ -1,13 +1,14 @@
import json
import logging
from typing import Any, Optional, Union, cast
from collections.abc import Mapping
from typing import Any, Union
from yarl import URL
from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@ -32,7 +33,9 @@ logger = logging.getLogger(__name__)
class ToolTransformService:
@classmethod
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
def get_tool_provider_icon_url(
cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
) -> str | Mapping[str, str]:
"""
get tool provider icon url
"""
@ -45,7 +48,7 @@ class ToolTransformService:
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
try:
if isinstance(icon, str):
return cast(dict, json.loads(icon))
return json.loads(icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
@ -54,7 +57,7 @@ class ToolTransformService:
return ""
@staticmethod
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]):
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
"""
repack provider
@ -68,7 +71,9 @@ class ToolTransformService:
elif isinstance(provider, ToolProviderApiEntity):
if provider.plugin_id:
if isinstance(provider.icon, str):
provider.icon = PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon)
provider.icon = PluginService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.icon
)
if isinstance(provider.icon_dark, str) and provider.icon_dark:
provider.icon_dark = PluginService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.icon_dark
@ -81,12 +86,18 @@ class ToolTransformService:
provider.icon_dark = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon_dark
)
elif isinstance(provider, PluginDatasourceProviderEntity):
if provider.plugin_id:
if isinstance(provider.declaration.identity.icon, str):
provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.declaration.identity.icon
)
@classmethod
def builtin_provider_to_user_provider(
cls,
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
db_provider: Optional[BuiltinToolProvider],
db_provider: BuiltinToolProvider | None,
decrypt_credentials: bool = True,
) -> ToolProviderApiEntity:
"""
@ -98,7 +109,7 @@ class ToolTransformService:
name=provider_controller.entity.identity.name,
description=provider_controller.entity.identity.description,
icon=provider_controller.entity.identity.icon,
icon_dark=provider_controller.entity.identity.icon_dark,
icon_dark=provider_controller.entity.identity.icon_dark or "",
label=provider_controller.entity.identity.label,
type=ToolProviderType.BUILT_IN,
masked_credentials={},
@ -120,9 +131,10 @@ class ToolTransformService:
)
}
for name, value in schema.items():
if result.masked_credentials:
result.masked_credentials[name] = ""
masked_creds = {}
for name in schema:
masked_creds[name] = ""
result.masked_credentials = masked_creds
# check if the provider need credentials
if not provider_controller.need_credentials:
@ -200,7 +212,7 @@ class ToolTransformService:
name=provider_controller.entity.identity.name,
description=provider_controller.entity.identity.description,
icon=provider_controller.entity.identity.icon,
icon_dark=provider_controller.entity.identity.icon_dark,
icon_dark=provider_controller.entity.identity.icon_dark or "",
label=provider_controller.entity.identity.label,
type=ToolProviderType.WORKFLOW,
masked_credentials={},
@ -229,6 +241,10 @@ class ToolTransformService:
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US="", zh_Hans=""),
server_identifier=db_provider.server_identifier,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
masked_headers=db_provider.masked_headers,
original_headers=db_provider.decrypted_headers,
)
@staticmethod
@ -239,7 +255,7 @@ class ToolTransformService:
author=user.name if user else "Anonymous",
name=tool.name,
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
description=I18nObject(en_US=tool.description, zh_Hans=tool.description),
description=I18nObject(en_US=tool.description or "", zh_Hans=tool.description or ""),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
labels=[],
)
@ -309,7 +325,7 @@ class ToolTransformService:
@staticmethod
def convert_tool_entity_to_api_entity(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
tool: ApiToolBundle | WorkflowTool | Tool,
tenant_id: str,
labels: list[str] | None = None,
) -> ToolApiEntity:
@ -363,7 +379,7 @@ class ToolTransformService:
parameters=merged_parameters,
labels=labels or [],
)
elif isinstance(tool, ApiToolBundle):
else:
return ToolApiEntity(
author=tool.author,
name=tool.operation_id or "",
@ -372,9 +388,6 @@ class ToolTransformService:
parameters=tool.parameters,
labels=labels or [],
)
else:
# Handle WorkflowTool case
raise ValueError(f"Unsupported tool type: {type(tool)}")
@staticmethod
def convert_builtin_provider_to_credential_entity(

View File

@ -3,7 +3,7 @@ from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import or_
from sqlalchemy import or_, select
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
@ -37,7 +37,7 @@ class WorkflowToolManageService:
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: list[str] | None = None,
) -> dict:
):
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
@ -103,7 +103,7 @@ class WorkflowToolManageService:
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: list[str] | None = None,
) -> dict:
):
"""
Update a workflow tool.
:param user_id: the user id
@ -186,7 +186,9 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
db_tools = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
tools: list[WorkflowToolProviderController] = []
for provider in db_tools:
@ -217,7 +219,7 @@ class WorkflowToolManageService:
return result
@classmethod
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Delete a workflow tool.
:param user_id: the user id
@ -233,7 +235,7 @@ class WorkflowToolManageService:
return {"result": "success"}
@classmethod
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Get a workflow tool.
:param user_id: the user id
@ -249,7 +251,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
"""
Get a workflow tool.
:param user_id: the user id
@ -265,7 +267,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
"""
Get a workflow tool.
:db_tool: the database tool

View File

@ -0,0 +1,394 @@
import dataclasses
from collections.abc import Mapping
from typing import Any, Generic, TypeAlias, TypeVar, overload
from configs import dify_config
from core.file.models import File
from core.variables.segments import (
ArrayFileSegment,
ArraySegment,
BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from core.variables.utils import dumps_with_segments
_MAX_DEPTH = 100
class _QAKeys:
"""dict keys for _QAStructure"""
QA_CHUNKS = "qa_chunks"
QUESTION = "question"
ANSWER = "answer"
class _PCKeys:
"""dict keys for _ParentChildStructure"""
PARENT_MODE = "parent_mode"
PARENT_CHILD_CHUNKS = "parent_child_chunks"
PARENT_CONTENT = "parent_content"
CHILD_CONTENTS = "child_contents"
_T = TypeVar("_T")
@dataclasses.dataclass(frozen=True)
class _PartResult(Generic[_T]):
value: _T
value_size: int
truncated: bool
class MaxDepthExceededError(Exception):
pass
class UnknownTypeError(Exception):
pass
JSONTypes: TypeAlias = int | float | str | list | dict | None | bool
@dataclasses.dataclass(frozen=True)
class TruncationResult:
result: Segment
truncated: bool
class VariableTruncator:
"""
Handles variable truncation with structure-preserving strategies.
This class implements intelligent truncation that prioritizes maintaining data structure
integrity while ensuring the final size doesn't exceed specified limits.
Uses recursive size calculation to avoid repeated JSON serialization.
"""
def __init__(
self,
string_length_limit=5000,
array_element_limit: int = 20,
max_size_bytes: int = 1024_000,  # ~1 MB (1,024,000 bytes)
):
if string_length_limit <= 3:
raise ValueError("string_length_limit should be greater than 3.")
self._string_length_limit = string_length_limit
if array_element_limit <= 0:
raise ValueError("array_element_limit should be greater than 0.")
self._array_element_limit = array_element_limit
if max_size_bytes <= 0:
raise ValueError("max_size_bytes should be greater than 0.")
self._max_size_bytes = max_size_bytes
@classmethod
def default(cls) -> "VariableTruncator":
return VariableTruncator(
max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,
string_length_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH,
)
def truncate_variable_mapping(self, v: Mapping[str, Any]) -> tuple[Mapping[str, Any], bool]:
"""
`truncate_variable_mapping` is responsible for truncating variable mappings
generated during workflow execution, such as `inputs`, `process_data`, or `outputs`
of a WorkflowNodeExecution record. This ensures the mappings remain within the
specified size limits while preserving their structure.
"""
budget = self._max_size_bytes
is_truncated = False
truncated_mapping: dict[str, Any] = {}
length = len(v.items())
used_size = 0
for key, value in v.items():
used_size += self.calculate_json_size(key)
if used_size > budget:
truncated_mapping[key] = "..."
continue
value_budget = (budget - used_size) // (length - len(truncated_mapping))
if isinstance(value, Segment):
part_result = self._truncate_segment(value, value_budget)
else:
part_result = self._truncate_json_primitives(value, value_budget)
is_truncated = is_truncated or part_result.truncated
truncated_mapping[key] = part_result.value
used_size += part_result.value_size
return truncated_mapping, is_truncated
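# Toy, stdlib-only illustration of the even budget split used above: the remaining byte
# budget is divided across the keys that have not been emitted yet, so an early large
# value cannot consume the whole allowance. Numbers and key names below are made up.
_budget, _used = 1_000, 0
_keys = ["inputs", "process_data", "outputs"]
_granted: dict[str, int] = {}
for _key in _keys:
    _used += len(_key) + 2  # rough JSON size of the key itself
    _granted[_key] = (_budget - _used) // (len(_keys) - len(_granted))
    _used += min(_granted[_key], 250)  # pretend the value needed at most 250 bytes
# _granted now gives every key a non-zero share, e.g. {'inputs': 330, 'process_data': 364, 'outputs': 469}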
@staticmethod
def _segment_need_truncation(segment: Segment) -> bool:
if isinstance(
segment,
(NoneSegment, FloatSegment, IntegerSegment, FileSegment, BooleanSegment, ArrayFileSegment),
):
return False
return True
@staticmethod
def _json_value_needs_truncation(value: Any) -> bool:
if value is None:
return False
if isinstance(value, (bool, int, float)):
return False
return True
def truncate(self, segment: Segment) -> TruncationResult:
if isinstance(segment, StringSegment):
result = self._truncate_segment(segment, self._string_length_limit)
else:
result = self._truncate_segment(segment, self._max_size_bytes)
if result.value_size > self._max_size_bytes:
if isinstance(result.value, str):
result = self._truncate_string(result.value, self._max_size_bytes)
return TruncationResult(StringSegment(value=result.value), True)
# Apply final fallback - convert to JSON string and truncate
json_str = dumps_with_segments(result.value, ensure_ascii=False)
if len(json_str) > self._max_size_bytes:
json_str = json_str[: self._max_size_bytes] + "..."
return TruncationResult(result=StringSegment(value=json_str), truncated=True)
return TruncationResult(
result=segment.model_copy(update={"value": result.value.value}), truncated=result.truncated
)
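# Usage sketch, assuming the imports at the top of this module (StringSegment comes from
# core.variables.segments). A 50-character string is cut to the 10-character limit and
# reported as truncated; the limits below are deliberately tiny to make the effect visible.
_truncator = VariableTruncator(string_length_limit=10, array_element_limit=2, max_size_bytes=200)
_result = _truncator.truncate(StringSegment(value="x" * 50))
assert _result.truncated is True
assert _result.result.value == "xxxxx..."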
def _truncate_segment(self, segment: Segment, target_size: int) -> _PartResult[Segment]:
"""
Apply type-specific truncation to a segment.
Args:
segment: the Segment whose value should be truncated
target_size: byte budget for the truncated value
Returns:
_PartResult with the truncated segment, its estimated size, and the truncation flag
"""
if not VariableTruncator._segment_need_truncation(segment):
return _PartResult(segment, self.calculate_json_size(segment.value), False)
result: _PartResult[Any]
# Apply type-specific truncation with target size
if isinstance(segment, ArraySegment):
result = self._truncate_array(segment.value, target_size)
elif isinstance(segment, StringSegment):
result = self._truncate_string(segment.value, target_size)
elif isinstance(segment, ObjectSegment):
result = self._truncate_object(segment.value, target_size)
else:
raise AssertionError("this should be unreachable.")
return _PartResult(
value=segment.model_copy(update={"value": result.value}),
value_size=result.value_size,
truncated=result.truncated,
)
@staticmethod
def calculate_json_size(value: Any, depth=0) -> int:
"""Recursively calculate JSON size without serialization."""
if isinstance(value, Segment):
return VariableTruncator.calculate_json_size(value.value)
if depth > _MAX_DEPTH:
raise MaxDepthExceededError()
if isinstance(value, str):
# Ideally, the size of strings should be calculated based on their utf-8 encoded length.
# However, this adds complexity as we would need to compute encoded sizes consistently
# throughout the code. Therefore, we approximate the size using the string's length.
# Rough estimate: number of characters, plus 2 for quotes
return len(value) + 2
elif isinstance(value, (int, float)):
return len(str(value))
elif isinstance(value, bool):
return 4 if value else 5 # "true" or "false"
elif value is None:
return 4 # "null"
elif isinstance(value, list):
# Size = sum of elements + separators + brackets
total = 2 # "[]"
for i, item in enumerate(value):
if i > 0:
total += 1 # ","
total += VariableTruncator.calculate_json_size(item, depth=depth + 1)
return total
elif isinstance(value, dict):
# Size = sum of keys + values + separators + brackets
total = 2 # "{}"
for index, key in enumerate(value.keys()):
if index > 0:
total += 1 # ","
total += VariableTruncator.calculate_json_size(str(key), depth=depth + 1) # Key as string
total += 1 # ":"
total += VariableTruncator.calculate_json_size(value[key], depth=depth + 1)
return total
elif isinstance(value, File):
return VariableTruncator.calculate_json_size(value.model_dump(), depth=depth + 1)
else:
raise UnknownTypeError(f"got unknown type {type(value)}")
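# Quick stdlib hand-check of what the estimator above approximates: the length of
# compact JSON (no spaces). Both equalities can also be verified with the per-type rules.
import json
assert len(json.dumps({"a": "hi"}, separators=(",", ":"))) == 10       # 2 + 3 + 1 + 4
assert len(json.dumps(["ok", 12, None], separators=(",", ":"))) == 14  # 2 + 4 + 1 + 2 + 1 + 4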
def _truncate_string(self, value: str, target_size: int) -> _PartResult[str]:
if (size := self.calculate_json_size(value)) < target_size:
return _PartResult(value, size, False)
if target_size < 5:
return _PartResult("...", 5, True)
truncated_size = min(self._string_length_limit, target_size - 5)
truncated_value = value[:truncated_size] + "..."
return _PartResult(truncated_value, self.calculate_json_size(truncated_value), True)
def _truncate_array(self, value: list, target_size: int) -> _PartResult[list]:
"""
Truncate array with correct strategy:
1. First limit to 20 items
2. If still too large, truncate individual items
"""
truncated_value: list[Any] = []
truncated = False
used_size = self.calculate_json_size([])
target_length = self._array_element_limit
for i, item in enumerate(value):
if i >= target_length:
return _PartResult(truncated_value, used_size, True)
if i > 0:
used_size += 1 # Account for comma
if used_size > target_size:
break
part_result = self._truncate_json_primitives(item, target_size - used_size)
truncated_value.append(part_result.value)
used_size += part_result.value_size
truncated = part_result.truncated
return _PartResult(truncated_value, used_size, truncated)
@classmethod
def _maybe_qa_structure(cls, m: Mapping[str, Any]) -> bool:
qa_chunks = m.get(_QAKeys.QA_CHUNKS)
if qa_chunks is None:
return False
if not isinstance(qa_chunks, list):
return False
return True
@classmethod
def _maybe_parent_child_structure(cls, m: Mapping[str, Any]) -> bool:
parent_mode = m.get(_PCKeys.PARENT_MODE)
if parent_mode is None:
return False
if not isinstance(parent_mode, str):
return False
parent_child_chunks = m.get(_PCKeys.PARENT_CHILD_CHUNKS)
if parent_child_chunks is None:
return False
if not isinstance(parent_child_chunks, list):
return False
return True
def _truncate_object(self, mapping: Mapping[str, Any], target_size: int) -> _PartResult[Mapping[str, Any]]:
"""
Truncate object with key preservation priority.
Strategy:
1. Keep all keys, truncate values to fit within budget
2. If still too large, drop keys starting from the end
"""
if not mapping:
return _PartResult(mapping, self.calculate_json_size(mapping), False)
truncated_obj = {}
truncated = False
used_size = self.calculate_json_size({})
# Sort keys to ensure deterministic behavior
sorted_keys = sorted(mapping.keys())
for i, key in enumerate(sorted_keys):
if used_size > target_size:
# No more room for additional key-value pairs
truncated = True
break
pair_size = 0
if i > 0:
pair_size += 1 # Account for comma
# Calculate budget for this key-value pair
# do not try to truncate keys, as we want to keep the structure of the object.
key_size = self.calculate_json_size(key) + 1 # +1 for ":"
pair_size += key_size
remaining_pairs = len(sorted_keys) - i
value_budget = max(0, (target_size - pair_size - used_size) // remaining_pairs)
if value_budget <= 0:
truncated = True
break
# Truncate the value to fit within budget
value = mapping[key]
if isinstance(value, Segment):
value_result = self._truncate_segment(value, value_budget)
else:
value_result = self._truncate_json_primitives(mapping[key], value_budget)
truncated_obj[key] = value_result.value
pair_size += value_result.value_size
used_size += pair_size
if value_result.truncated:
truncated = True
return _PartResult(truncated_obj, used_size, truncated)
@overload
def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ...
@overload
def _truncate_json_primitives(self, val: list, target_size: int) -> _PartResult[list]: ...
@overload
def _truncate_json_primitives(self, val: dict, target_size: int) -> _PartResult[dict]: ...
@overload
def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ... # type: ignore
@overload
def _truncate_json_primitives(self, val: int, target_size: int) -> _PartResult[int]: ...
@overload
def _truncate_json_primitives(self, val: float, target_size: int) -> _PartResult[float]: ...
@overload
def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ...
def _truncate_json_primitives(
self, val: str | list | dict | bool | int | float | None, target_size: int
) -> _PartResult[Any]:
"""Truncate a value within an object to fit within budget."""
if isinstance(val, str):
return self._truncate_string(val, target_size)
elif isinstance(val, list):
return self._truncate_array(val, target_size)
elif isinstance(val, dict):
return self._truncate_object(val, target_size)
elif val is None or isinstance(val, (bool, int, float)):
return _PartResult(val, self.calculate_json_size(val), False)
else:
raise AssertionError("this statement should be unreachable.")

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
@ -19,7 +18,7 @@ logger = logging.getLogger(__name__)
class VectorService:
@classmethod
def create_segments_vector(
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents: list[Document] = []
@ -79,7 +78,7 @@ class VectorService:
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
@classmethod
def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
# update segment index task
# format new index

View File

@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -19,11 +19,11 @@ class WebConversationService:
*,
session: Session,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
invoke_from: InvokeFrom,
pinned: Optional[bool] = None,
pinned: bool | None = None,
sort_by="-updated_at",
) -> InfiniteScrollPagination:
if not user:
@ -60,7 +60,7 @@ class WebConversationService:
)
@classmethod
def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
if not user:
return
pinned_conversation = (
@ -92,7 +92,7 @@ class WebConversationService:
db.session.commit()
@classmethod
def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
if not user:
return
pinned_conversation = (

View File

@ -1,7 +1,7 @@
import enum
import secrets
from datetime import UTC, datetime, timedelta
from typing import Any, Optional
from typing import Any
from werkzeug.exceptions import NotFound, Unauthorized
@ -63,7 +63,7 @@ class WebAppAuthService:
@classmethod
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: str = "en-US"
cls, account: Account | None = None, email: str | None = None, language: str = "en-US"
):
email = account.email if account else email
if email is None:
@ -82,7 +82,7 @@ class WebAppAuthService:
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@classmethod
@ -130,7 +130,7 @@ class WebAppAuthService:
@classmethod
def is_app_require_permission_check(
cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None
cls, app_code: str | None = None, app_id: str | None = None, access_mode: str | None = None
) -> bool:
"""
Check if the app requires permission check based on its access mode.

View File

@ -1,9 +1,9 @@
import datetime
import json
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any
import requests
import httpx
from flask_login import current_user
from core.helper import encrypter
@ -11,7 +11,7 @@ from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
from core.rag.extractor.watercrawl.provider import WaterCrawlProvider
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService
from services.datasource_provider_service import DatasourceProviderService
@dataclass
@ -21,9 +21,9 @@ class CrawlOptions:
limit: int = 1
crawl_sub_pages: bool = False
only_main_content: bool = False
includes: Optional[str] = None
excludes: Optional[str] = None
max_depth: Optional[int] = None
includes: str | None = None
excludes: str | None = None
max_depth: int | None = None
use_sitemap: bool = True
def get_include_paths(self) -> list[str]:
@ -103,7 +103,6 @@ class WebsiteCrawlStatusApiRequest:
def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
if not provider:
raise ValueError("Provider is required")
if not job_id:
@ -116,12 +115,28 @@ class WebsiteService:
"""Service class for website crawling operations using different providers."""
@classmethod
def _get_credentials_and_config(cls, tenant_id: str, provider: str) -> tuple[dict, dict]:
def _get_credentials_and_config(cls, tenant_id: str, provider: str) -> tuple[Any, Any]:
"""Get and validate credentials for a provider."""
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
if not credentials or "config" not in credentials:
raise ValueError("No valid credentials found for the provider")
return credentials, credentials["config"]
if provider == "firecrawl":
plugin_id = "langgenius/firecrawl_datasource"
elif provider == "watercrawl":
plugin_id = "langgenius/watercrawl_datasource"
elif provider == "jinareader":
plugin_id = "langgenius/jina_datasource"
else:
raise ValueError("Invalid provider")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
)
if provider == "firecrawl":
return credential.get("firecrawl_api_key"), credential
elif provider in {"watercrawl", "jinareader"}:
return credential.get("api_key"), credential
else:
raise ValueError("Invalid provider")
@classmethod
def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
@ -132,7 +147,7 @@ class WebsiteService:
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
@classmethod
def document_create_args_validate(cls, args: dict) -> None:
def document_create_args_validate(cls, args: dict):
"""Validate arguments for document creation."""
try:
WebsiteCrawlApiRequest.from_args(args)
@ -144,8 +159,7 @@ class WebsiteService:
"""Crawl a URL using the specified provider with typed request."""
request = api_request.to_crawl_request()
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider)
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
api_key, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider)
if request.provider == "firecrawl":
return cls._crawl_with_firecrawl(request=request, api_key=api_key, config=config)
@ -202,15 +216,15 @@ class WebsiteService:
@classmethod
def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
if not request.options.crawl_sub_pages:
response = requests.get(
response = httpx.get(
f"https://r.jina.ai/{request.url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
raise ValueError("Failed to crawl:")
return {"status": "active", "data": response.json().get("data")}
else:
response = requests.post(
response = httpx.post(
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
json={
"url": request.url,
@ -235,8 +249,7 @@ class WebsiteService:
@classmethod
def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]:
"""Get crawl status using typed request."""
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider)
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
api_key, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider)
if api_request.provider == "firecrawl":
return cls._get_firecrawl_status(api_request.job_id, api_key, config)
@ -274,7 +287,7 @@ class WebsiteService:
@classmethod
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
response = requests.post(
response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
@ -290,7 +303,7 @@ class WebsiteService:
}
if crawl_status_data["status"] == "completed":
response = requests.post(
response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
@ -310,8 +323,7 @@ class WebsiteService:
@classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None:
_, config = cls._get_credentials_and_config(tenant_id, provider)
api_key = cls._get_decrypted_api_key(tenant_id, config)
api_key, config = cls._get_credentials_and_config(tenant_id, provider)
if provider == "firecrawl":
return cls._get_firecrawl_url_data(job_id, url, api_key, config)
@ -350,7 +362,7 @@ class WebsiteService:
@classmethod
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
if not job_id:
response = requests.get(
response = httpx.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
@ -359,7 +371,7 @@ class WebsiteService:
return dict(response.json().get("data", {}))
else:
# Get crawl status first
status_response = requests.post(
status_response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
@ -369,7 +381,7 @@ class WebsiteService:
raise ValueError("Crawl job is not completed")
# Get processed data
data_response = requests.post(
data_response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
@ -384,8 +396,7 @@ class WebsiteService:
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict[str, Any]:
request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content)
_, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider)
api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config)
api_key, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider)
if request.provider == "firecrawl":
return cls._scrape_with_firecrawl(request=request, api_key=api_key, config=config)

View File

@ -1,5 +1,5 @@
import json
from typing import Any, Optional
from typing import Any
from core.app.app_config.entities import (
DatasetEntity,
@ -18,6 +18,7 @@ from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import SimplePromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.nodes import NodeType
from events.app_event import app_was_created
from extensions.ext_database import db
@ -64,7 +65,7 @@ class WorkflowConverter:
new_app = App()
new_app.tenant_id = app_model.tenant_id
new_app.name = name or app_model.name + "(workflow)"
new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW
new_app.icon_type = icon_type or app_model.icon_type
new_app.icon = icon or app_model.icon
new_app.icon_background = icon_background or app_model.icon_background
@ -145,7 +146,7 @@ class WorkflowConverter:
graph=graph,
model_config=app_config.model,
prompt_template=app_config.prompt_template,
file_upload=app_config.additional_features.file_upload,
file_upload=app_config.additional_features.file_upload if app_config.additional_features else None,
external_data_variable_node_mapping=external_data_variable_node_mapping,
)
@ -202,7 +203,7 @@ class WorkflowConverter:
app_mode_enum = AppMode.value_of(app_model.mode)
app_config: EasyUIBasedAppConfig
if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent:
app_model.mode = AppMode.AGENT_CHAT.value
app_model.mode = AppMode.AGENT_CHAT
app_config = AgentChatAppConfigManager.get_app_config(
app_model=app_model, app_model_config=app_model_config
)
@ -217,7 +218,7 @@ class WorkflowConverter:
return app_config
def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
def _convert_to_start_node(self, variables: list[VariableEntity]):
"""
Convert to Start Node
:param variables: list of variables
@ -278,7 +279,7 @@ class WorkflowConverter:
"app_id": app_model.id,
"tool_variable": tool_variable,
"inputs": inputs,
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "",
"query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT else "",
},
}
@ -326,7 +327,7 @@ class WorkflowConverter:
def _convert_to_knowledge_retrieval_node(
self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity
) -> Optional[dict]:
) -> dict | None:
"""
Convert datasets to Knowledge Retrieval Node
:param new_app_mode: new app mode
@ -382,9 +383,9 @@ class WorkflowConverter:
graph: dict,
model_config: ModelConfigEntity,
prompt_template: PromptTemplateEntity,
file_upload: Optional[FileUploadConfig] = None,
file_upload: FileUploadConfig | None = None,
external_data_variable_node_mapping: dict[str, str] | None = None,
) -> dict:
):
"""
Convert to LLM Node
:param original_app_mode: original app mode
@ -402,7 +403,7 @@ class WorkflowConverter:
)
role_prefix = None
prompts: Optional[Any] = None
prompts: Any | None = None
# Chat Model
if model_config.mode == LLMMode.CHAT.value:
@ -420,7 +421,11 @@ class WorkflowConverter:
query_in_prompt=False,
)
template = prompt_template_config["prompt_template"].template
prompt_template_obj = prompt_template_config["prompt_template"]
if not isinstance(prompt_template_obj, PromptTemplateParser):
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
template = prompt_template_obj.template
if not template:
prompts = []
else:
@ -457,7 +462,11 @@ class WorkflowConverter:
query_in_prompt=False,
)
template = prompt_template_config["prompt_template"].template
prompt_template_obj = prompt_template_config["prompt_template"]
if not isinstance(prompt_template_obj, PromptTemplateParser):
raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}")
template = prompt_template_obj.template
template = self._replace_template_variables(
template=template,
variables=start_node["data"]["variables"],
@ -467,6 +476,9 @@ class WorkflowConverter:
prompts = {"text": template}
prompt_rules = prompt_template_config["prompt_rules"]
if not isinstance(prompt_rules, dict):
raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}")
role_prefix = {
"user": prompt_rules.get("human_prefix", "Human"),
"assistant": prompt_rules.get("assistant_prefix", "Assistant"),
@ -550,7 +562,7 @@ class WorkflowConverter:
return template
def _convert_to_end_node(self) -> dict:
def _convert_to_end_node(self):
"""
Convert to End Node
:return:
@ -566,7 +578,7 @@ class WorkflowConverter:
},
}
def _convert_to_answer_node(self) -> dict:
def _convert_to_answer_node(self):
"""
Convert to Answer Node
:return:
@ -578,7 +590,7 @@ class WorkflowConverter:
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
}
def _create_edge(self, source: str, target: str) -> dict:
def _create_edge(self, source: str, target: str):
"""
Create Edge
:param source: source node id
@ -587,7 +599,7 @@ class WorkflowConverter:
"""
return {"id": f"{source}-{target}", "source": source, "target": target}
def _append_node(self, graph: dict, node: dict) -> dict:
def _append_node(self, graph: dict, node: dict):
"""
Append Node to Graph
@ -606,7 +618,7 @@ class WorkflowConverter:
:param app_model: App instance
:return: AppMode
"""
if app_model.mode == AppMode.COMPLETION.value:
if app_model.mode == AppMode.COMPLETION:
return AppMode.WORKFLOW
else:
return AppMode.ADVANCED_CHAT

View File

@ -4,7 +4,7 @@ from datetime import datetime
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from core.workflow.enums import WorkflowExecutionStatus
from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
@ -23,7 +23,7 @@ class WorkflowAppService:
limit: int = 20,
created_by_end_user_session_id: str | None = None,
created_by_account: str | None = None,
) -> dict:
):
"""
Get paginate workflow app logs using SQLAlchemy 2.0 style
:param session: SQLAlchemy session

View File

@ -1,32 +1,44 @@
import dataclasses
import json
import logging
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from enum import StrEnum
from typing import Any, ClassVar
from sqlalchemy import Engine, orm
from sqlalchemy import Engine, orm, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.sql.expression import and_, or_
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.variables import Segment, StringSegment, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import ArrayFileSegment, FileSegment
from core.variables.segments import (
ArrayFileSegment,
FileSegment,
)
from core.variables.types import SegmentType
from core.variables.utils import dumps_with_segments
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables
from core.workflow.variable_loader import VariableLoader
from extensions.ext_storage import storage
from factories.file_factory import StorageKeyLoader
from factories.variable_factory import build_segment, segment_to_variable
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models import App, Conversation
from models.account import Account
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable
from repositories.factory import DifyAPIRepositoryFactory
from services.file_service import FileService
from services.variable_truncator import VariableTruncator
logger = logging.getLogger(__name__)
@ -37,6 +49,12 @@ class WorkflowDraftVariableList:
total: int | None = None
@dataclasses.dataclass(frozen=True)
class DraftVarFileDeletion:
draft_var_id: str
draft_var_file_id: str
class WorkflowDraftVariableError(Exception):
pass
@ -67,7 +85,7 @@ class DraftVarLoader(VariableLoader):
app_id: str,
tenant_id: str,
fallback_variables: Sequence[Variable] | None = None,
) -> None:
):
self._engine = engine
self._app_id = app_id
self._tenant_id = tenant_id
@ -87,7 +105,26 @@ class DraftVarLoader(VariableLoader):
srv = WorkflowDraftVariableService(session)
draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors)
# Important:
files: list[File] = []
# FileSegment and ArrayFileSegment are not subject to offloading, so their values
# can be safely accessed before any offloading logic is applied.
for draft_var in draft_vars:
value = draft_var.get_value()
if isinstance(value, FileSegment):
files.append(value.value)
elif isinstance(value, ArrayFileSegment):
files.extend(value.value)
with Session(bind=self._engine) as session:
storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id)
storage_key_loader.load_storage_keys(files)
offloaded_draft_vars = []
for draft_var in draft_vars:
if draft_var.is_truncated():
offloaded_draft_vars.append(draft_var)
continue
segment = draft_var.get_value()
variable = segment_to_variable(
segment=segment,
@ -99,25 +136,56 @@ class DraftVarLoader(VariableLoader):
selector_tuple = self._selector_to_tuple(variable.selector)
variable_by_selector[selector_tuple] = variable
# Important:
files: list[File] = []
for draft_var in draft_vars:
value = draft_var.get_value()
if isinstance(value, FileSegment):
files.append(value.value)
elif isinstance(value, ArrayFileSegment):
files.extend(value.value)
with Session(bind=self._engine) as session:
storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id)
storage_key_loader.load_storage_keys(files)
# Load offloaded variables using multithreading.
# This approach reduces loading time by querying external systems concurrently.
with ThreadPoolExecutor(max_workers=10) as executor:
offloaded_variables = executor.map(self._load_offloaded_variable, offloaded_draft_vars)
for selector, variable in offloaded_variables:
variable_by_selector[selector] = variable
return list(variable_by_selector.values())
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]:
# This logic is closely tied to `DraftVariableSaver._try_offload_large_variable`
# and must remain synchronized with it.
# Ideally, these should be co-located for better maintainability.
# However, due to the current code structure, this is not straightforward.
variable_file = draft_var.variable_file
assert variable_file is not None
upload_file = variable_file.upload_file
assert upload_file is not None
content = storage.load(upload_file.key)
if variable_file.value_type == SegmentType.STRING:
# The inferred type would be StringSegment, which is too narrow inside this function.
segment: Segment = StringSegment(value=content.decode())
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
return (draft_var.node_id, draft_var.name), variable
deserialized = json.loads(content)
segment = WorkflowDraftVariable.build_segment_with_type(variable_file.value_type, deserialized)
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
# No special handling is needed for ArrayFileSegment, since ArrayFileSegment values are never offloaded
return (draft_var.node_id, draft_var.name), variable
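# A minimal, stdlib-only sketch of the decode branch used above, assuming just the two
# payload forms shown in the diff (strings stored as text/plain, everything else as
# JSON); the function name is illustrative, not part of the service API.
import json

def decode_offloaded_payload(value_type: str, content: bytes):
    # Mirrors _load_offloaded_variable: strings come back verbatim, other types via JSON.
    if value_type == "string":
        return content.decode("utf-8")
    return json.loads(content)

assert decode_offloaded_payload("string", b"hello") == "hello"
assert decode_offloaded_payload("object", b'{"a": 1}') == {"a": 1}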
class WorkflowDraftVariableService:
_session: Session
def __init__(self, session: Session) -> None:
def __init__(self, session: Session):
"""
Initialize the WorkflowDraftVariableService with a SQLAlchemy session.
@ -138,13 +206,24 @@ class WorkflowDraftVariableService:
)
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
return self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable_id).first()
return (
self._session.query(WorkflowDraftVariable)
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.where(WorkflowDraftVariable.id == variable_id)
.first()
)
def get_draft_variables_by_selectors(
self,
app_id: str,
selectors: Sequence[list[str]],
) -> list[WorkflowDraftVariable]:
"""
Retrieve WorkflowDraftVariable instances based on app_id and selectors.
The returned WorkflowDraftVariable objects are guaranteed to have their
associated variable_file and variable_file.upload_file relationships preloaded.
"""
ors = []
for selector in selectors:
assert len(selector) >= SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
@ -159,7 +238,14 @@ class WorkflowDraftVariableService:
# combined using `UNION` to fetch all rows.
# Benchmarking indicates that both approaches yield comparable performance.
variables = (
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all()
self._session.query(WorkflowDraftVariable)
.options(
orm.selectinload(WorkflowDraftVariable.variable_file).selectinload(
WorkflowDraftVariableFile.upload_file
)
)
.where(WorkflowDraftVariable.app_id == app_id, or_(*ors))
.all()
)
return variables
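# A self-contained sketch of how (node_id, name) selectors become a single OR-of-ANDs
# query, assuming an in-memory SQLite table standing in for workflow_draft_variables;
# the table and column names below are illustrative only.
from sqlalchemy import Column, MetaData, String, Table, and_, create_engine, or_, select

metadata = MetaData()
draft_vars = Table(
    "draft_vars", metadata,
    Column("app_id", String), Column("node_id", String), Column("name", String),
)
engine = create_engine("sqlite://")
metadata.create_all(engine)

selectors = [["node_a", "x"], ["node_b", "y"]]
ors = [and_(draft_vars.c.node_id == s[0], draft_vars.c.name == s[1]) for s in selectors]
stmt = select(draft_vars).where(draft_vars.c.app_id == "app-1", or_(*ors))
with engine.connect() as conn:
    rows = conn.execute(stmt).all()  # empty table here; the point is the query shape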
@ -170,8 +256,10 @@ class WorkflowDraftVariableService:
if page == 1:
total = query.count()
variables = (
# Do not load the `value` field.
query.options(orm.defer(WorkflowDraftVariable.value))
# Do not load the `value` field
query.options(
orm.defer(WorkflowDraftVariable.value, raiseload=True),
)
.order_by(WorkflowDraftVariable.created_at.desc())
.limit(limit)
.offset((page - 1) * limit)
@ -186,7 +274,11 @@ class WorkflowDraftVariableService:
WorkflowDraftVariable.node_id == node_id,
)
query = self._session.query(WorkflowDraftVariable).where(*criteria)
variables = query.order_by(WorkflowDraftVariable.created_at.desc()).all()
variables = (
query.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.order_by(WorkflowDraftVariable.created_at.desc())
.all()
)
return WorkflowDraftVariableList(variables=variables)
def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
@ -210,6 +302,7 @@ class WorkflowDraftVariableService:
def _get_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
variable = (
self._session.query(WorkflowDraftVariable)
.options(orm.selectinload(WorkflowDraftVariable.variable_file))
.where(
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
@ -278,7 +371,7 @@ class WorkflowDraftVariableService:
self._session.flush()
return None
outputs_dict = node_exec.outputs_dict or {}
outputs_dict = node_exec.load_full_outputs(self._session, storage) or {}
# A sentinel value used to check for the absence of the output variable key.
absent = object()
@ -323,6 +416,49 @@ class WorkflowDraftVariableService:
return self._reset_node_var_or_sys_var(workflow, variable)
def delete_variable(self, variable: WorkflowDraftVariable):
if not variable.is_truncated():
self._session.delete(variable)
return
variable_query = (
select(WorkflowDraftVariable)
.options(
orm.selectinload(WorkflowDraftVariable.variable_file).selectinload(
WorkflowDraftVariableFile.upload_file
),
)
.where(WorkflowDraftVariable.id == variable.id)
)
variable_reloaded = self._session.execute(variable_query).scalars().first()
if variable_reloaded is None:
logger.warning("Associated WorkflowDraftVariable not found, draft_var_id=%s", variable.id)
self._session.delete(variable)
return
variable_file = variable_reloaded.variable_file
if variable_file is None:
logger.warning(
"Associated WorkflowDraftVariableFile not found, draft_var_id=%s, file_id=%s",
variable_reloaded.id,
variable_reloaded.file_id,
)
self._session.delete(variable)
return
upload_file = variable_file.upload_file
if upload_file is None:
logger.warning(
"Associated UploadFile not found, draft_var_id=%s, file_id=%s, upload_file_id=%s",
variable_reloaded.id,
variable_reloaded.file_id,
variable_file.upload_file_id,
)
self._session.delete(variable)
self._session.delete(variable_file)
return
storage.delete(upload_file.key)
self._session.delete(upload_file)
self._session.delete(variable_file)
self._session.delete(variable)
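# A toy sketch of the cleanup cascade above, assuming a dict stands in for object
# storage and a list records row deletions; the guard order (blob first, then
# upload_file, variable_file, and finally the variable row) mirrors the method,
# but every name here is illustrative.
storage_objects = {"vars/app-1/node-1-x.json": b"[1, 2, 3]"}
deleted_rows: list[str] = []

def delete_offloaded_variable(storage_key: str | None) -> None:
    # Missing links are tolerated: only delete what can still be resolved.
    if storage_key is not None and storage_key in storage_objects:
        del storage_objects[storage_key]  # remove the blob so no orphaned bytes remain
        deleted_rows.append("upload_file")
    deleted_rows.append("variable_file")
    deleted_rows.append("variable")

delete_offloaded_variable("vars/app-1/node-1-x.json")
assert deleted_rows == ["upload_file", "variable_file", "variable"]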
def delete_workflow_variables(self, app_id: str):
@ -332,6 +468,38 @@ class WorkflowDraftVariableService:
.delete(synchronize_session=False)
)
def delete_workflow_draft_variable_file(self, deletions: list[DraftVarFileDeletion]):
variable_files_query = (
select(WorkflowDraftVariableFile)
.options(orm.selectinload(WorkflowDraftVariableFile.upload_file))
.where(WorkflowDraftVariableFile.id.in_([i.draft_var_file_id for i in deletions]))
)
variable_files = self._session.execute(variable_files_query).scalars().all()
variable_files_by_id = {i.id: i for i in variable_files}
for i in deletions:
variable_file = variable_files_by_id.get(i.draft_var_file_id)
if variable_file is None:
logger.warning(
"Associated WorkflowDraftVariableFile not found, draft_var_id=%s, file_id=%s",
i.draft_var_id,
i.draft_var_file_id,
)
continue
upload_file = variable_file.upload_file
if upload_file is None:
logger.warning(
"Associated UploadFile not found, draft_var_id=%s, file_id=%s, upload_file_id=%s",
i.draft_var_id,
i.draft_var_file_id,
variable_file.upload_file_id,
)
self._session.delete(variable_file)
else:
storage.delete(upload_file.key)
self._session.delete(upload_file)
self._session.delete(variable_file)
def delete_node_variables(self, app_id: str, node_id: str):
return self._delete_node_variables(app_id, node_id)
@ -438,7 +606,7 @@ def _batch_upsert_draft_variable(
session: Session,
draft_vars: Sequence[WorkflowDraftVariable],
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
) -> None:
):
if not draft_vars:
return None
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
@ -476,6 +644,7 @@ def _batch_upsert_draft_variable(
"visible": stmt.excluded.visible,
"editable": stmt.excluded.editable,
"node_execution_id": stmt.excluded.node_execution_id,
"file_id": stmt.excluded.file_id,
},
)
elif policy == _UpsertPolicy.IGNORE:
@ -495,6 +664,7 @@ def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
"value_type": model.value_type,
"value": model.value,
"node_execution_id": model.node_execution_id,
"file_id": model.file_id,
}
if model.visible is not None:
d["visible"] = model.visible
@ -524,6 +694,28 @@ def _build_segment_for_serialized_values(v: Any) -> Segment:
return build_segment(WorkflowDraftVariable.rebuild_file_types(v))
def _make_filename_trans_table() -> dict[int, str]:
linux_chars = ["/", "\x00"]
windows_chars = [
"<",
">",
":",
'"',
"/",
"\\",
"|",
"?",
"*",
]
windows_chars.extend(chr(i) for i in range(32))
trans_table = dict.fromkeys(linux_chars + windows_chars, "_")
return str.maketrans(trans_table)
_FILENAME_TRANS_TABLE = _make_filename_trans_table()
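# A short usage sketch of the translation table built above: characters that are
# illegal in Linux or Windows filenames (plus ASCII control characters) all map to "_".
# The node id below is made up.
table = str.maketrans(dict.fromkeys(["/", "\x00", "<", ">", ":", '"', "\\", "|", "?", "*"]
                                    + [chr(i) for i in range(32)], "_"))
assert "if/else:node?1".translate(table) == "if_else_node_1"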
class DraftVariableSaver:
# _DUMMY_OUTPUT_IDENTITY is a placeholder output for workflow nodes.
# Its sole possible value is `None`.
@ -573,6 +765,7 @@ class DraftVariableSaver:
node_id: str,
node_type: NodeType,
node_execution_id: str,
user: Account,
enclosing_node_id: str | None = None,
):
# Important: `node_execution_id` parameter refers to the primary key (`id`) of the
@ -583,6 +776,7 @@ class DraftVariableSaver:
self._node_id = node_id
self._node_type = node_type
self._node_execution_id = node_execution_id
self._user = user
self._enclosing_node_id = enclosing_node_id
def _create_dummy_output_variable(self):
@ -692,17 +886,133 @@ class DraftVariableSaver:
else:
value_seg = _build_segment_for_serialized_values(value)
draft_vars.append(
WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
node_id=self._node_id,
self._create_draft_variable(
name=name,
node_execution_id=self._node_execution_id,
value=value_seg,
visible=self._should_variable_be_visible(self._node_id, self._node_type, name),
)
visible=True,
editable=True,
),
# WorkflowDraftVariable.new_node_variable(
# app_id=self._app_id,
# node_id=self._node_id,
# name=name,
# node_execution_id=self._node_execution_id,
# value=value_seg,
# visible=self._should_variable_be_visible(self._node_id, self._node_type, name),
# )
)
return draft_vars
def _generate_filename(self, name: str):
node_id_escaped = self._node_id.translate(_FILENAME_TRANS_TABLE)
return f"{node_id_escaped}-{name}"
def _try_offload_large_variable(
self,
name: str,
value_seg: Segment,
) -> tuple[Segment, WorkflowDraftVariableFile] | None:
# This logic is closely tied to `DraftVarLoader._load_offloaded_variable` and must remain
# synchronized with it.
# Ideally, these should be co-located for better maintainability.
# However, due to the current code structure, this is not straightforward.
truncator = VariableTruncator(
max_size_bytes=dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE,
array_element_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH,
string_length_limit=dify_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH,
)
truncation_result = truncator.truncate(value_seg)
if not truncation_result.truncated:
return None
original_length = None
if isinstance(value_seg.value, (list, dict)):
original_length = len(value_seg.value)
# Prepare content for storage
if isinstance(value_seg, StringSegment):
# For string types, store as plain text
original_content_serialized = value_seg.value
content_type = "text/plain"
filename = f"{self._generate_filename(name)}.txt"
else:
# For other types, store as JSON
original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False)
content_type = "application/json"
filename = f"{self._generate_filename(name)}.json"
original_size = len(original_content_serialized.encode("utf-8"))
bind = self._session.get_bind()
assert isinstance(bind, Engine)
file_srv = FileService(bind)
upload_file = file_srv.upload_file(
filename=filename,
content=original_content_serialized.encode(),
mimetype=content_type,
user=self._user,
)
# Create WorkflowDraftVariableFile record
variable_file = WorkflowDraftVariableFile(
id=uuidv7(),
upload_file_id=upload_file.id,
size=original_size,
length=original_length,
value_type=value_seg.value_type,
app_id=self._app_id,
tenant_id=self._user.current_tenant_id,
user_id=self._user.id,
)
engine = self._session.get_bind()
assert isinstance(engine, Engine)
with Session(bind=engine, expire_on_commit=False) as session:
session.add(variable_file)
session.commit()
return truncation_result.result, variable_file
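# A condensed, stdlib-only sketch of the decision made above, assuming a single byte
# threshold and str/dict payloads; VariableTruncator and the upload plumbing are
# replaced by a plain size check and an in-memory store, so all names are illustrative.
import json

OFFLOAD_THRESHOLD = 1024  # stands in for WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE
object_store: dict[str, bytes] = {}

def maybe_offload(name: str, value):
    if isinstance(value, str):
        payload, suffix = value.encode(), ".txt"  # strings are stored as text/plain
    else:
        payload, suffix = json.dumps(value).encode(), ".json"  # other types as JSON
    if len(payload) <= OFFLOAD_THRESHOLD:
        return value, None  # small enough: keep the value inline
    key = f"draft-vars/{name}{suffix}"
    object_store[key] = payload  # full payload goes to storage
    preview = value[:64] if isinstance(value, str) else {"truncated": True}
    return preview, key  # the DB row keeps only the preview plus a file reference

preview, key = maybe_offload("big", "x" * 5000)
assert key is not None and len(object_store[key]) == 5000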
def _create_draft_variable(
self,
*,
name: str,
value: Segment,
visible: bool = True,
editable: bool = True,
) -> WorkflowDraftVariable:
"""Create a draft variable with large variable handling and truncation."""
# Handle Segment values
offload_result = self._try_offload_large_variable(name, value)
if offload_result is None:
# Create the draft variable
draft_var = WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
node_id=self._node_id,
name=name,
node_execution_id=self._node_execution_id,
value=value,
visible=visible,
editable=editable,
)
return draft_var
else:
truncated, var_file = offload_result
# Create the draft variable
draft_var = WorkflowDraftVariable.new_node_variable(
app_id=self._app_id,
node_id=self._node_id,
name=name,
node_execution_id=self._node_execution_id,
value=truncated,
visible=visible,
editable=False,
file_id=var_file.id,
)
return draft_var
def save(
self,
process_data: Mapping[str, Any] | None = None,

View File

@ -1,6 +1,5 @@
import threading
from collections.abc import Sequence
from typing import Optional
from sqlalchemy.orm import sessionmaker
@ -80,7 +79,7 @@ class WorkflowRunService:
last_id=last_id,
)
def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None:
"""
Get workflow run detail

View File

@ -2,8 +2,7 @@ import json
import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, Optional, cast
from uuid import uuid4
from typing import Any, cast
from sqlalchemy import exists, select
from sqlalchemy.orm import Session, sessionmaker
@ -15,43 +14,33 @@ from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.entities import VariablePool, WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from extensions.ext_storage import storage
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.datetime_utils import naive_utc_now
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import (
Workflow,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
WorkflowType,
)
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
from repositories.factory import DifyAPIRepositoryFactory
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
from .workflow_draft_variable_service import (
DraftVariableSaver,
DraftVarLoader,
WorkflowDraftVariableService,
)
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService
class WorkflowService:
@ -96,10 +85,12 @@ class WorkflowService:
)
return db.session.execute(stmt).scalar_one()
def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
"""
Get the draft workflow; if workflow_id is provided, fetch that specific workflow instead
"""
if workflow_id:
return self.get_published_workflow_by_id(app_model, workflow_id)
# fetch draft workflow by app_model
workflow = (
db.session.query(Workflow)
@ -114,8 +105,10 @@ class WorkflowService:
# return draft workflow
return workflow
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
# fetch published workflow by workflow_id
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
"""
Fetch the published workflow by workflow_id
"""
workflow = (
db.session.query(Workflow)
.where(
@ -134,7 +127,7 @@ class WorkflowService:
)
return workflow
def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
def get_published_workflow(self, app_model: App) -> Workflow | None:
"""
Get published workflow
"""
@ -199,7 +192,7 @@ class WorkflowService:
app_model: App,
graph: dict,
features: dict,
unique_hash: Optional[str],
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
@ -267,6 +260,12 @@ class WorkflowService:
if not draft_workflow:
raise ValueError("No valid workflow found.")
# Validate credentials before publishing, for credential policy check
from services.feature_service import FeatureService
if FeatureService.get_system_features().plugin_manager.enabled:
self._validate_workflow_credentials(draft_workflow)
# create new workflow
workflow = Workflow.new(
tenant_id=app_model.tenant_id,
@ -274,12 +273,13 @@ class WorkflowService:
type=draft_workflow.type,
version=Workflow.version_from_datetime(naive_utc_now()),
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables,
marked_name=marked_name,
marked_comment=marked_comment,
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
features=draft_workflow.features,
)
# commit db session changes
@ -291,12 +291,285 @@ class WorkflowService:
# return new workflow
return workflow
def get_default_block_configs(self) -> list[dict]:
def _validate_workflow_credentials(self, workflow: Workflow) -> None:
"""
Validate all credentials in workflow nodes before publishing.
:param workflow: The workflow to validate
:raises ValueError: If any credentials violate policy compliance
"""
graph_dict = workflow.graph_dict
nodes = graph_dict.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
node_type = node_data.get("type")
node_id = node.get("id", "unknown")
try:
# Extract and validate credentials based on node type
if node_type == "tool":
credential_id = node_data.get("credential_id")
provider = node_data.get("provider_id")
if provider:
if credential_id:
# Check specific credential
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=credential_id,
provider=provider,
credential_type=PluginCredentialType.TOOL,
)
else:
# Check default workspace credential for this provider
self._check_default_tool_credential(workflow.tenant_id, provider)
elif node_type == "agent":
agent_params = node_data.get("agent_parameters", {})
model_config = agent_params.get("model", {}).get("value", {})
if model_config.get("provider") and model_config.get("model"):
self._validate_llm_model_config(
workflow.tenant_id, model_config["provider"], model_config["model"]
)
# Validate load balancing credentials for agent model if load balancing is enabled
agent_model_node_data = {"model": model_config}
self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id)
# Validate agent tools
tools = agent_params.get("tools", {}).get("value", [])
for tool in tools:
# Agent tools store provider in provider_name field
provider = tool.get("provider_name")
credential_id = tool.get("credential_id")
if provider:
if credential_id:
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL)
else:
self._check_default_tool_credential(workflow.tenant_id, provider)
elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]:
model_config = node_data.get("model", {})
provider = model_config.get("provider")
model_name = model_config.get("name")
if provider and model_name:
# Validate that the provider+model combination can fetch valid credentials
self._validate_llm_model_config(workflow.tenant_id, provider, model_name)
# Validate load balancing credentials if load balancing is enabled
self._validate_load_balancing_credentials(workflow, node_data, node_id)
else:
raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration")
except Exception as e:
if isinstance(e, ValueError):
raise e
else:
raise ValueError(f"Node {node_id} ({node_type}): {str(e)}")
def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None:
"""
Validate that an LLM model configuration can fetch valid credentials and has active status.
This method attempts to get the model instance and validates that:
1. The provider exists and is configured
2. The model exists in the provider
3. Credentials can be fetched for the model
4. The credentials pass policy compliance checks
5. The model status is ACTIVE (not NO_CONFIGURE, DISABLED, etc.)
:param tenant_id: The tenant ID
:param provider: The provider name
:param model_name: The model name
:raises ValueError: If the model configuration is invalid or credentials fail policy checks
"""
try:
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
# Get model instance to validate provider+model combination
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name
)
# The ModelInstance constructor will automatically check credential policy compliance
# via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance()
# If it fails, an exception will be raised
# Additionally, check the model status to ensure it's ACTIVE
provider_manager = ProviderManager()
provider_configurations = provider_manager.get_configurations(tenant_id)
models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM)
target_model = None
for model in models:
if model.model == model_name and model.provider.provider == provider:
target_model = model
break
if target_model:
target_model.raise_for_status()
else:
raise ValueError(f"Model {model_name} not found for provider {provider}")
except Exception as e:
raise ValueError(
f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}"
)
def _check_default_tool_credential(self, tenant_id: str, provider: str) -> None:
"""
Check credential policy compliance for the default workspace credential of a tool provider.
This method finds the default credential for the given provider and validates it.
Uses the same fallback logic as runtime to handle deauthorized credentials.
:param tenant_id: The tenant ID
:param provider: The tool provider name
:raises ValueError: If no default credential exists or if it fails policy compliance
"""
try:
from models.tools import BuiltinToolProvider
# Use the same fallback logic as runtime: get the first available credential
# ordered by is_default DESC, created_at ASC (same as tool_manager.py)
default_provider = (
db.session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if not default_provider:
# plugin does not require credentials, skip
return
# Check credential policy compliance using the default credential ID
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=default_provider.id,
provider=provider,
credential_type=PluginCredentialType.TOOL,
check_existence=False,
)
except Exception as e:
raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}")
def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None:
"""
Validate load balancing credentials for a workflow node.
:param workflow: The workflow being validated
:param node_data: The node data containing model configuration
:param node_id: The node ID for error reporting
:raises ValueError: If load balancing credentials violate policy compliance
"""
# Extract model configuration
model_config = node_data.get("model", {})
provider = model_config.get("provider")
model_name = model_config.get("name")
if not provider or not model_name:
return # No model config to validate
# Check if this model has load balancing enabled
if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name):
# Get all load balancing configurations for this model
load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name)
# Validate each load balancing configuration
try:
for config in load_balancing_configs:
if config.get("credential_id"):
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
config["credential_id"], provider, PluginCredentialType.MODEL
)
except Exception as e:
raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}")
def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool:
"""
Check if load balancing is enabled for a specific model.
:param tenant_id: The tenant ID
:param provider: The provider name
:param model_name: The model name
:return: True if load balancing is enabled, False otherwise
"""
try:
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
# Get provider configurations
provider_manager = ProviderManager()
provider_configurations = provider_manager.get_configurations(tenant_id)
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
return False
# Get provider model setting
provider_model_setting = provider_configuration.get_provider_model_setting(
model_type=ModelType.LLM,
model=model_name,
)
return provider_model_setting is not None and provider_model_setting.load_balancing_enabled
except Exception:
# If we can't determine the status, assume load balancing is not enabled
return False
def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]:
"""
Get all load balancing configurations for a model.
:param tenant_id: The tenant ID
:param provider: The provider name
:param model_name: The model name
:return: List of load balancing configuration dictionaries
"""
try:
from services.model_load_balancing_service import ModelLoadBalancingService
model_load_balancing_service = ModelLoadBalancingService()
_, configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=model_name,
model_type="llm", # Load balancing is primarily used for LLM models
config_from="predefined-model", # Check both predefined and custom models
)
_, custom_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
)
all_configs = configs + custom_configs
return [config for config in all_configs if config.get("credential_id")]
except Exception:
# If we can't get the configurations, return empty list
# This will prevent validation errors from breaking the workflow
return []
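# A tiny sketch of the merge above: predefined-model and custom-model configs are
# fetched separately, then only entries that carry a credential_id are kept. The two
# lists below are illustrative stand-ins for the service responses.
predefined = [{"name": "key-1", "credential_id": "cred-a"}, {"name": "inherited", "credential_id": None}]
custom = [{"name": "key-2", "credential_id": "cred-b"}]
merged = [c for c in predefined + custom if c.get("credential_id")]
assert [c["credential_id"] for c in merged] == ["cred-a", "cred-b"]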
def get_default_block_configs(self) -> Sequence[Mapping[str, object]]:
"""
Get default block configs
"""
# return default block config
default_block_configs = []
default_block_configs: list[Mapping[str, object]] = []
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
node_class = node_class_mapping[LATEST_VERSION]
default_config = node_class.get_default_config()
@ -305,7 +578,9 @@ class WorkflowService:
return default_block_configs
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
def get_default_block_config(
self, node_type: str, filters: Mapping[str, object] | None = None
) -> Mapping[str, object]:
"""
Get default config of node.
:param node_type: node type
@ -316,12 +591,12 @@ class WorkflowService:
# return default block config
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
return None
return {}
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
default_config = node_class.get_default_config(filters=filters)
if not default_config:
return None
return {}
return default_config
@ -403,7 +678,7 @@ class WorkflowService:
# run draft workflow node
start_at = time.perf_counter()
node_execution = self._handle_node_run_result(
node_execution = self._handle_single_step_result(
invoke_node_fn=lambda: run,
start_at=start_at,
node_id=node_id,
@ -425,6 +700,9 @@ class WorkflowService:
if workflow_node_execution is None:
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
with Session(db.engine) as session:
outputs = workflow_node_execution.load_full_outputs(session, storage)
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
session=session,
@ -433,8 +711,9 @@ class WorkflowService:
node_type=NodeType(workflow_node_execution.node_type),
enclosing_node_id=enclosing_node_id,
node_execution_id=node_execution.id,
user=account,
)
draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs)
draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs)
session.commit()
return workflow_node_execution
@ -448,7 +727,7 @@ class WorkflowService:
# run free workflow node
start_at = time.perf_counter()
node_execution = self._handle_node_run_result(
node_execution = self._handle_single_step_result(
invoke_node_fn=lambda: WorkflowEntry.run_free_node(
node_id=node_id,
node_data=node_data,
@ -462,103 +741,131 @@ class WorkflowService:
return node_execution
def _handle_node_run_result(
def _handle_single_step_result(
self,
invoke_node_fn: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]],
invoke_node_fn: Callable[[], tuple[Node, Generator[GraphNodeEventBase, None, None]]],
start_at: float,
node_id: str,
) -> WorkflowNodeExecution:
try:
node, node_events = invoke_node_fn()
"""
Handle single step execution and return WorkflowNodeExecution.
node_run_result: NodeRunResult | None = None
for event in node_events:
if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result
Args:
invoke_node_fn: Function to invoke node execution
start_at: Execution start time
node_id: ID of the node being executed
# sign output files
# node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
break
Returns:
WorkflowNodeExecution: The execution result
"""
node, node_run_result, run_succeeded, error = self._execute_node_safely(invoke_node_fn)
if not node_run_result:
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.continue_on_error:
node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
"metadata": {"error_strategy": node.error_strategy},
}
if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
**node.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
else:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
run_succeeded = node_run_result.status in (
WorkflowNodeExecutionStatus.SUCCEEDED,
WorkflowNodeExecutionStatus.EXCEPTION,
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node = e._node
run_succeeded = False
node_run_result = None
error = e._error
# Create a NodeExecution domain model
# Create base node execution
node_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id="", # This is a single-step execution, so no workflow ID
id=str(uuid.uuid4()),
workflow_id="", # Single-step execution has no workflow ID
index=1,
node_id=node_id,
node_type=node.type_,
node_type=node.node_type,
title=node.title,
elapsed_time=time.perf_counter() - start_at,
created_at=naive_utc_now(),
finished_at=naive_utc_now(),
)
# Populate execution result data
self._populate_execution_result(node_execution, node_run_result, run_succeeded, error)
return node_execution
def _execute_node_safely(
self, invoke_node_fn: Callable[[], tuple[Node, Generator[GraphNodeEventBase, None, None]]]
) -> tuple[Node, NodeRunResult | None, bool, str | None]:
"""
Execute node safely and handle errors according to error strategy.
Returns:
Tuple of (node, node_run_result, run_succeeded, error)
"""
try:
node, node_events = invoke_node_fn()
node_run_result = next(
(
event.node_run_result
for event in node_events
if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent))
),
None,
)
if not node_run_result:
raise ValueError("Node execution failed - no result returned")
# Apply error strategy if node failed
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.error_strategy:
node_run_result = self._apply_error_strategy(node, node_run_result)
run_succeeded = node_run_result.status in (
WorkflowNodeExecutionStatus.SUCCEEDED,
WorkflowNodeExecutionStatus.EXCEPTION,
)
error = node_run_result.error if not run_succeeded else None
return node, node_run_result, run_succeeded, error
except WorkflowNodeRunFailedError as e:
node = e.node
run_succeeded = False
node_run_result = None
error = e.error
return node, node_run_result, run_succeeded, error
def _apply_error_strategy(self, node: Node, node_run_result: NodeRunResult) -> NodeRunResult:
"""Apply error strategy when node execution fails."""
# TODO(Novice): Maybe we should apply the error strategy at the node level?
error_outputs = {
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
}
# Add default values if strategy is DEFAULT_VALUE
if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
error_outputs.update(node.default_value_dict)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
error=node_run_result.error,
inputs=node_run_result.inputs,
metadata={WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy},
outputs=error_outputs,
)
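# A compact sketch of the strategy handling above: the exception result always carries
# the error fields, and DEFAULT_VALUE additionally merges the node's configured
# defaults. The enum values and default dict below are illustrative stand-ins.
from enum import StrEnum

class ErrorStrategySketch(StrEnum):
    FAIL_BRANCH = "fail-branch"
    DEFAULT_VALUE = "default-value"

def build_error_outputs(strategy: ErrorStrategySketch, error: str, error_type: str, defaults: dict) -> dict:
    outputs = {"error_message": error, "error_type": error_type}
    if strategy is ErrorStrategySketch.DEFAULT_VALUE:
        outputs.update(defaults)  # configured defaults are merged alongside the error fields
    return outputs

assert build_error_outputs(ErrorStrategySketch.DEFAULT_VALUE, "boom", "ValueError", {"text": ""}) == {
    "error_message": "boom", "error_type": "ValueError", "text": "",
}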
def _populate_execution_result(
self,
node_execution: WorkflowNodeExecution,
node_run_result: NodeRunResult | None,
run_succeeded: bool,
error: str | None,
) -> None:
"""Populate node execution with result data."""
if run_succeeded and node_run_result:
# Set inputs, process_data, and outputs as dictionaries (not JSON strings)
inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
process_data = (
node_execution.inputs = (
WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
)
node_execution.process_data = (
WorkflowEntry.handle_special_values(node_run_result.process_data)
if node_run_result.process_data
else None
)
outputs = node_run_result.outputs
node_execution.inputs = inputs
node_execution.process_data = process_data
node_execution.outputs = outputs
node_execution.outputs = node_run_result.outputs
node_execution.metadata = node_run_result.metadata
# Map status from WorkflowNodeExecutionStatus to NodeExecutionStatus
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION
# Set status and error based on result
node_execution.status = node_run_result.status
if node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
node_execution.error = node_run_result.error
else:
# Set failed status and error
node_execution.status = WorkflowNodeExecutionStatus.FAILED
node_execution.error = error
return node_execution
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
"""
Basic mode of chatbot app(expert mode) to workflow
@ -572,7 +879,7 @@ class WorkflowService:
# chatbot convert to workflow mode
workflow_converter = WorkflowConverter()
if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}:
if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}:
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
# convert to workflow
@ -587,12 +894,12 @@ class WorkflowService:
return new_app
def validate_features_structure(self, app_model: App, features: dict) -> dict:
if app_model.mode == AppMode.ADVANCED_CHAT.value:
def validate_features_structure(self, app_model: App, features: dict):
if app_model.mode == AppMode.ADVANCED_CHAT:
return AdvancedChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
elif app_model.mode == AppMode.WORKFLOW.value:
elif app_model.mode == AppMode.WORKFLOW:
return WorkflowAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
)
@ -601,7 +908,7 @@ class WorkflowService:
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
) -> Optional[Workflow]:
) -> Workflow | None:
"""
Update workflow attributes
@ -702,7 +1009,7 @@ def _setup_variable_pool(
if workflow.type != WorkflowType.WORKFLOW.value:
system_variable.query = query
system_variable.conversation_id = conversation_id
system_variable.dialogue_count = 0
system_variable.dialogue_count = 1
else:
system_variable = SystemVariable.empty()

View File

@ -12,7 +12,7 @@ class WorkspaceService:
def get_tenant_info(cls, tenant: Tenant):
if not tenant:
return None
tenant_info = {
tenant_info: dict[str, object] = {
"id": tenant.id,
"name": tenant.name,
"plan": tenant.plan,