Merge branch 'feat/plugin-auto-upgrade' into deploy/dev

This commit is contained in:
Junyan Qin
2025-07-10 14:39:03 +08:00
80 changed files with 1507 additions and 716 deletions

View File

@ -451,6 +451,16 @@ APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
# Celery schedule tasks configuration
ENABLE_CLEAN_EMBEDDING_CACHE_TASK=false
ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=

View File

@ -779,6 +779,41 @@ class CeleryBeatConfig(BaseSettings):
)
class CeleryScheduleTasksConfig(BaseSettings):
    """Feature switches for Celery beat scheduled tasks.

    Each flag gates both the task module import and its ``beat_schedule`` entry
    (see ``ext_celery.init_app``). When adding a new scheduled task, add a
    matching switch here so deployments can disable it.
    """

    # Periodically clean the embedding cache (daily interval).
    ENABLE_CLEAN_EMBEDDING_CACHE_TASK: bool = Field(
        description="Enable clean embedding cache task",
        default=False,
    )
    # Periodically clean datasets that are no longer used.
    ENABLE_CLEAN_UNUSED_DATASETS_TASK: bool = Field(
        description="Enable clean unused datasets task",
        default=False,
    )
    # TiDB serverless provisioning job (hourly cron).
    ENABLE_CREATE_TIDB_SERVERLESS_TASK: bool = Field(
        description="Enable create tidb service job task",
        default=False,
    )
    # TiDB serverless status refresh job (every 10 minutes).
    ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: bool = Field(
        description="Enable update tidb service job status task",
        default=False,
    )
    # Periodic message cleanup (daily interval).
    ENABLE_CLEAN_MESSAGES: bool = Field(
        description="Enable clean messages task",
        default=False,
    )
    # Weekly mail notification about cleaned documents.
    ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
        description="Enable mail clean document notify task",
        default=False,
    )
    # Dataset queue monitoring task.
    ENABLE_DATASETS_QUEUE_MONITOR: bool = Field(
        description="Enable queue monitor task",
        default=False,
    )
    # Plugin auto-upgrade check (every 15 minutes); enabled by default.
    ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
        description="Enable check upgradable plugin task",
        default=True,
    )
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description="Comma-separated list of pinned model providers",
@ -907,5 +942,6 @@ class FeatureConfig(
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,
CeleryScheduleTasksConfig,
):
pass

View File

@ -533,7 +533,8 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
raise ValueError(e)
return jsonable_encoder({"options": options})
class PluginChangePreferencesApi(Resource):
@setup_required
@login_required
@ -597,29 +598,26 @@ class PluginFetchPreferencesApi(Resource):
tenant_id = current_user.current_tenant_id
permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = {
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
}
if not permission:
permission = {
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
}
else:
permission = {
"install_permission": permission.install_permission,
"debug_permission": permission.debug_permission,
}
if permission:
permission_dict["install_permission"] = permission.install_permission
permission_dict["debug_permission"] = permission.debug_permission
auto_upgrade = PluginAutoUpgradeService.get_strategy(tenant_id)
if not auto_upgrade:
auto_upgrade = {
"strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
"upgrade_time_of_day": 0,
"upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
"exclude_plugins": [],
"include_plugins": [],
}
else:
auto_upgrade = {
auto_upgrade_dict = {
"strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
"upgrade_time_of_day": 0,
"upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
"exclude_plugins": [],
"include_plugins": [],
}
if auto_upgrade:
auto_upgrade_dict = {
"strategy_setting": auto_upgrade.strategy_setting,
"upgrade_time_of_day": auto_upgrade.upgrade_time_of_day,
"upgrade_mode": auto_upgrade.upgrade_mode,
@ -627,7 +625,7 @@ class PluginFetchPreferencesApi(Resource):
"include_plugins": auto_upgrade.include_plugins,
}
return jsonable_encoder({"permission": permission, "auto_upgrade": auto_upgrade})
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
class PluginAutoUpgradeExcludePluginApi(Resource):

View File

@ -25,9 +25,29 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
def batch_fetch_plugin_manifests_ignore_deserialization_error(
    plugin_ids: list[str],
) -> Sequence[MarketplacePluginDeclaration]:
    """Fetch plugin manifests from the marketplace, dropping malformed entries.

    Unlike ``batch_fetch_plugin_manifests``, a single entry that fails to
    deserialize does not abort the whole batch; it is skipped and the rest are
    returned.

    Args:
        plugin_ids: marketplace plugin ids to fetch; an empty list short-circuits
            to an empty result without an HTTP request.

    Returns:
        The manifests that deserialized successfully (possibly fewer than
        ``len(plugin_ids)``).

    Raises:
        requests.HTTPError: when the marketplace request itself fails.
    """
    if not plugin_ids:
        return []

    url = str(marketplace_api_url / "api/v1/plugins/batch")
    response = requests.post(url, json={"plugin_ids": plugin_ids})
    response.raise_for_status()

    result: list[MarketplacePluginDeclaration] = []
    for plugin in response.json()["data"]["plugins"]:
        try:
            result.append(MarketplacePluginDeclaration(**plugin))
        except Exception:
            # Deliberate best-effort: skip manifests that no longer match the
            # expected schema instead of failing the entire batch.
            continue
    return result
def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})

View File

@ -47,6 +47,7 @@ class QdrantConfig(BaseModel):
grpc_port: int = 6334
prefer_grpc: bool = False
replication_factor: int = 1
write_consistency_factor: int = 1
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith("path:"):
@ -127,6 +128,7 @@ class QdrantVector(BaseVector):
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
replication_factor=self._client_config.replication_factor,
write_consistency_factor=self._client_config.write_consistency_factor,
)
# create group_id payload index

View File

@ -1,3 +1,5 @@
import logging
import time
from abc import ABC, abstractmethod
from typing import Any, Optional
@ -13,6 +15,8 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Whitelist
logger = logging.getLogger(__name__)
class AbstractVectorFactory(ABC):
@abstractmethod
@ -173,8 +177,20 @@ class Vector:
def create(self, texts: Optional[list] = None, **kwargs):
if texts:
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
start = time.time()
logger.info(f"start embedding {len(texts)} texts {start}")
batch_size = 1000
total_batches = len(texts) + batch_size - 1
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
batch_start = time.time()
logger.info(f"Processing batch {i // batch_size + 1}/{total_batches} ({len(batch)} texts)")
batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch])
logger.info(
f"Embedding batch {i // batch_size + 1}/{total_batches} took {time.time() - batch_start:.3f}s"
)
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
logger.info(f"Embedding {len(texts)} texts took {time.time() - start:.3f}s")
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get("duplicate_check", False):

View File

@ -232,14 +232,14 @@ class WorkflowLoggingCallback(WorkflowCallback):
Publish loop started
"""
self.print_text("\n[LoopRunStartedEvent]", color="blue")
self.print_text(f"Loop Node ID: {event.loop_id}", color="blue")
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None:
"""
Publish loop next
"""
self.print_text("\n[LoopRunNextEvent]", color="blue")
self.print_text(f"Loop Node ID: {event.loop_id}", color="blue")
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
self.print_text(f"Loop Index: {event.index}", color="blue")
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None:
@ -250,7 +250,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
"\n[LoopRunSucceededEvent]" if isinstance(event, LoopRunSucceededEvent) else "\n[LoopRunFailedEvent]",
color="blue",
)
self.print_text(f"Node ID: {event.loop_id}", color="blue")
self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
"""Print text with highlighting and no end characters."""

View File

@ -334,7 +334,7 @@ class Graph(BaseModel):
parallel = GraphParallel(
start_from_node_id=start_node_id,
parent_parallel_id=parent_parallel.id if parent_parallel else None,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
)
parallel_mapping[parallel.id] = parallel

View File

@ -285,7 +285,8 @@ class ToolNode(BaseNode[ToolNodeData]):
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
json.append(message.message.json_object)
if message.message.json_object is not None:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
@ -369,31 +370,31 @@ class ToolNode(BaseNode[ToolNodeData]):
agent_logs.append(agent_log)
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: dict[str, Any] = {}
if json:
if isinstance(json, list) and len(json) == 1:
# If json is a list with only one element, convert it to a dictionary
json_output = json[0] if isinstance(json[0], dict) else {"data": json[0]}
elif isinstance(json, list):
# If json is a list with multiple elements, create a dictionary containing all data
json_output = {"data": json}
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
# Add agent_logs to json output
json_output["agent_logs"] = [
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
for log in agent_logs
]
for log in agent_logs:
json_output.append(
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json:
json_output.extend(json)
else:
json_output.append({"data": []})
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,

View File

@ -64,55 +64,62 @@ def init_app(app: DifyApp) -> Celery:
celery_app.set_default()
app.extensions["celery"] = celery_app
imports = [
"schedule.clean_embedding_cache_task",
"schedule.clean_unused_datasets_task",
"schedule.create_tidb_serverless_task",
"schedule.update_tidb_serverless_status_task",
"schedule.clean_messages",
"schedule.mail_clean_document_notify_task",
"schedule.queue_monitor_task",
"schedule.check_upgradable_plugin_task",
]
imports = []
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
beat_schedule = {
"clean_embedding_cache_task": {
# if you add a new task, please add the switch to CeleryScheduleTasksConfig
beat_schedule = {}
if dify_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK:
imports.append("schedule.clean_embedding_cache_task")
beat_schedule["clean_embedding_cache_task"] = {
"task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
"schedule": timedelta(days=day),
},
"clean_unused_datasets_task": {
}
if dify_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK:
imports.append("schedule.clean_unused_datasets_task")
beat_schedule["clean_unused_datasets_task"] = {
"task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
"schedule": timedelta(days=day),
},
"create_tidb_serverless_task": {
}
if dify_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK:
imports.append("schedule.create_tidb_serverless_task")
beat_schedule["create_tidb_serverless_task"] = {
"task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task",
"schedule": crontab(minute="0", hour="*"),
},
"update_tidb_serverless_status_task": {
}
if dify_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK:
imports.append("schedule.update_tidb_serverless_status_task")
beat_schedule["update_tidb_serverless_status_task"] = {
"task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task",
"schedule": timedelta(minutes=10),
},
"clean_messages": {
}
if dify_config.ENABLE_CLEAN_MESSAGES:
imports.append("schedule.clean_messages")
beat_schedule["clean_messages"] = {
"task": "schedule.clean_messages.clean_messages",
"schedule": timedelta(days=day),
},
# every Monday
"mail_clean_document_notify_task": {
}
if dify_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:
imports.append("schedule.mail_clean_document_notify_task")
beat_schedule["mail_clean_document_notify_task"] = {
"task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
"schedule": crontab(minute="0", hour="10", day_of_week="1"),
},
"datasets-queue-monitor": {
}
if dify_config.ENABLE_DATASETS_QUEUE_MONITOR:
imports.append("schedule.queue_monitor_task")
beat_schedule["datasets-queue-monitor"] = {
"task": "schedule.queue_monitor_task.queue_monitor_task",
"schedule": timedelta(
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
),
},
# every 15 minutes
"check_upgradable_plugin_task": {
}
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:
imports.append("schedule.check_upgradable_plugin_task")
beat_schedule["check_upgradable_plugin_task"] = {
"task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task",
"schedule": crontab(minute="*/15"),
},
}
}
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app

View File

@ -1,6 +1,10 @@
import functools
import logging
from collections.abc import Callable
from typing import Any, Union
import redis
from redis import RedisError
from redis.cache import CacheConfig
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
@ -9,6 +13,8 @@ from redis.sentinel import Sentinel
from configs import dify_config
from dify_app import DifyApp
logger = logging.getLogger(__name__)
class RedisClientWrapper:
"""
@ -115,3 +121,25 @@ def init_app(app: DifyApp):
redis_client.initialize(redis.Redis(connection_pool=pool))
app.extensions["redis"] = redis_client
def redis_fallback(default_return: Any = None):
    """
    Decorator that handles Redis operation exceptions and returns a default
    value when Redis is unavailable.

    Args:
        default_return: The value to return when a Redis operation fails. Defaults to None.
    """

    def decorator(func: Callable):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except RedisError as e:
                # Lazy %-args: the message is only formatted if the warning is
                # actually emitted; exc_info keeps the full traceback.
                logger.warning("Redis operation failed in %s: %s", func.__name__, e, exc_info=True)
                return default_return

        return wrapper

    return decorator

View File

@ -0,0 +1,41 @@
"""empty message
Revision ID: 16081485540c
Revises: 58eb7bdb93fe
Create Date: 2025-05-15 16:35:39.113777
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '16081485540c'
down_revision = '58eb7bdb93fe'
branch_labels = None
depends_on = None
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    # One auto-upgrade strategy row per tenant (enforced by the unique
    # constraint on tenant_id). Server defaults mirror the model's safe
    # defaults: 'fix_only' strategy with 'exclude' mode.
    op.create_table('tenant_plugin_auto_upgrade_strategies',
        sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
        sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
        sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),
        sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
        sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
        sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
        sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
    )
    # ### end Alembic commands ###
# ### end Alembic commands ###
def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    # Reverse of upgrade(): drop the per-tenant auto-upgrade strategy table.
    op.drop_table('tenant_plugin_auto_upgrade_strategies')
    # ### end Alembic commands ###

View File

@ -1,29 +1,23 @@
import time
import traceback
import click
import app
from core.helper import marketplace
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from models.account import TenantPluginAutoUpgradeStrategy
from tasks.process_tenant_plugin_autoupgrade_check_task import process_tenant_plugin_autoupgrade_check_task
AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL = 15 * 60 # 15 minutes
RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
@app.celery.task(queue="plugin")
def check_upgradable_plugin_task():
click.echo(click.style("Start check upgradable plugin.", fg="green"))
start_at = time.perf_counter()
now_seconds_of_day = time.time() % 86400 # we assume the tz is UTC
now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC
click.echo(click.style("Now seconds of day: {}".format(now_seconds_of_day), fg="green"))
# get strategies that set to be performed in the next AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL
strategies = (
db.session.query(TenantPluginAutoUpgradeStrategy)
.filter(
@ -36,113 +30,15 @@ def check_upgradable_plugin_task():
.all()
)
manager = PluginInstaller()
for strategy in strategies:
try:
tenant_id = strategy.tenant_id
strategy_setting = strategy.strategy_setting
upgrade_mode = strategy.upgrade_mode
exclude_plugins = strategy.exclude_plugins
include_plugins = strategy.include_plugins
if strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED:
continue
# get plugins that need to be checked
plugin_ids: list[tuple[str, str, str]] = [] # plugin_id, version, unique_identifier
if upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL and include_plugins:
all_plugins = manager.list_plugins(tenant_id)
for plugin in all_plugins:
if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id in include_plugins:
plugin_ids.append((plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier))
elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
# get all plugins and remove the exclude plugins
all_plugins = manager.list_plugins(tenant_id)
plugin_ids = [
(plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
for plugin in all_plugins
if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id not in exclude_plugins
]
elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
all_plugins = manager.list_plugins(tenant_id)
plugin_ids = [
(plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
for plugin in all_plugins
if plugin.source == PluginInstallationSource.Marketplace
]
if not plugin_ids:
continue
plugin_ids_plain_list = [plugin_id for plugin_id, _, _ in plugin_ids]
click.echo(click.style("Fetching manifests for plugins: {}".format(plugin_ids_plain_list), fg="green"))
# fetch latest versions from marketplace
manifests = marketplace.batch_fetch_plugin_manifests(plugin_ids_plain_list)
for manifest in manifests:
for plugin_id, version, original_unique_identifier in plugin_ids:
if manifest.plugin_id != plugin_id:
continue
try:
current_version = version
latest_version = manifest.latest_version
# @yeuoly review here
def fix_only_checker(latest_version, current_version):
latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
current_version_tuple = tuple(int(val) for val in current_version.split("."))
if (
latest_version_tuple[0] == current_version_tuple[0]
and latest_version_tuple[1] == current_version_tuple[1]
):
return latest_version_tuple[2] != current_version_tuple[2]
return False
version_checker = {
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
current_version: latest_version != current_version,
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
}
if version_checker[strategy_setting](latest_version, current_version):
# execute upgrade
new_unique_identifier = manifest.latest_package_identifier
marketplace.record_install_plugin_event(new_unique_identifier)
click.echo(
click.style(
"Upgrade plugin: {} -> {}".format(
original_unique_identifier, new_unique_identifier
),
fg="green",
)
)
task_start_resp = manager.upgrade_plugin(
tenant_id,
original_unique_identifier,
new_unique_identifier,
PluginInstallationSource.Marketplace,
{
"plugin_unique_identifier": new_unique_identifier,
},
)
except Exception as e:
click.echo(click.style("Error when upgrading plugin: {}".format(e), fg="red"))
traceback.print_exc()
break
except Exception as e:
click.echo(click.style("Error when checking upgradable plugin: {}".format(e), fg="red"))
traceback.print_exc()
continue
process_tenant_plugin_autoupgrade_check_task.delay(
strategy.tenant_id,
strategy.strategy_setting,
strategy.upgrade_time_of_day,
strategy.upgrade_mode,
strategy.exclude_plugins,
strategy.include_plugins,
)
end_at = time.perf_counter()
click.echo(

View File

@ -16,7 +16,7 @@ from configs import dify_config
from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_redis import redis_client, redis_fallback
from libs.helper import RateLimiter, TokenManager
from libs.passport import PassportService
from libs.password import compare_password, hash_password, valid_password
@ -28,6 +28,7 @@ from models.account import (
Tenant,
TenantAccountJoin,
TenantAccountRole,
TenantPluginAutoUpgradeStrategy,
TenantStatus,
)
from models.model import DifySetup
@ -495,6 +496,7 @@ class AccountService:
return account
@staticmethod
@redis_fallback(default_return=None)
def add_login_error_rate_limit(email: str) -> None:
key = f"login_error_rate_limit:{email}"
count = redis_client.get(key)
@ -504,6 +506,7 @@ class AccountService:
redis_client.setex(key, dify_config.LOGIN_LOCKOUT_DURATION, count)
@staticmethod
@redis_fallback(default_return=False)
def is_login_error_rate_limit(email: str) -> bool:
key = f"login_error_rate_limit:{email}"
count = redis_client.get(key)
@ -516,11 +519,13 @@ class AccountService:
return False
@staticmethod
@redis_fallback(default_return=None)
def reset_login_error_rate_limit(email: str):
key = f"login_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
@redis_fallback(default_return=None)
def add_forgot_password_error_rate_limit(email: str) -> None:
key = f"forgot_password_error_rate_limit:{email}"
count = redis_client.get(key)
@ -530,6 +535,7 @@ class AccountService:
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
@staticmethod
@redis_fallback(default_return=False)
def is_forgot_password_error_rate_limit(email: str) -> bool:
key = f"forgot_password_error_rate_limit:{email}"
count = redis_client.get(key)
@ -542,11 +548,13 @@ class AccountService:
return False
@staticmethod
@redis_fallback(default_return=None)
def reset_forgot_password_error_rate_limit(email: str):
key = f"forgot_password_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
@redis_fallback(default_return=False)
def is_email_send_ip_limit(ip_address: str):
minute_key = f"email_send_ip_limit_minute:{ip_address}"
freeze_key = f"email_send_ip_limit_freeze:{ip_address}"
@ -604,6 +612,17 @@ class TenantService:
db.session.add(tenant)
db.session.commit()
plugin_upgrade_strategy = TenantPluginAutoUpgradeStrategy(
tenant_id=tenant.id,
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
upgrade_time_of_day=0,
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
exclude_plugins=[],
include_plugins=[],
)
db.session.add(plugin_upgrade_strategy)
db.session.commit()
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
return tenant

View File

@ -22,7 +22,7 @@ class PluginAutoUpgradeService:
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
exclude_plugins: list[str],
include_plugins: list[str],
) -> None:
) -> bool:
with Session(db.engine) as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)

View File

@ -72,6 +72,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
db.session.commit()
end_at = time.perf_counter()
logging.info(

View File

@ -0,0 +1,166 @@
import traceback
import typing
import click
from celery import shared_task # type: ignore
from core.helper import marketplace
from core.helper.marketplace import MarketplacePluginDeclaration
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from models.account import TenantPluginAutoUpgradeStrategy
RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
cached_plugin_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {}
def marketplace_batch_fetch_plugin_manifests(
    plugin_ids_plain_list: list[str],
) -> list[MarketplacePluginDeclaration]:
    """Fetch marketplace manifests for the given plugin ids with a process-level cache.

    Plugin ids the marketplace does not return are negative-cached as ``None``
    so they are not re-requested on every periodic check.

    Args:
        plugin_ids_plain_list: marketplace plugin ids to resolve.

    Returns:
        Manifests for the ids that resolved (missing ids are simply omitted).
    """
    missing_ids = [plugin_id for plugin_id in plugin_ids_plain_list if plugin_id not in cached_plugin_manifests]

    if missing_ids:
        manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(missing_ids)
        for manifest in manifests:
            cached_plugin_manifests[manifest.plugin_id] = manifest
        # Negative-cache every requested id that did NOT come back, to prevent
        # future refetches. The previous code only did this when the whole batch
        # was empty, so a partially-missing batch was refetched on every call.
        for plugin_id in missing_ids:
            cached_plugin_manifests.setdefault(plugin_id, None)

    result: list[MarketplacePluginDeclaration] = []
    for plugin_id in plugin_ids_plain_list:
        cached = cached_plugin_manifests.get(plugin_id)
        if cached is not None:
            result.append(cached)
    return result
@shared_task(queue="plugin")
def process_tenant_plugin_autoupgrade_check_task(
    tenant_id: str,
    strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
    upgrade_time_of_day: int,  # part of the task signature; not read in this body
    upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
    exclude_plugins: list[str],
    include_plugins: list[str],
):
    """Check one tenant's marketplace plugins and upgrade them per its strategy.

    Selects the tenant's marketplace-installed plugins according to
    ``upgrade_mode`` (PARTIAL/EXCLUDE/ALL), fetches the latest manifests, and
    triggers an upgrade for each plugin whose version check passes
    (LATEST: any difference; FIX_ONLY: patch-level bump only).
    All errors are reported via click and swallowed so the task never retries.
    """
    try:
        manager = PluginInstaller()

        click.echo(
            click.style(
                "Checking upgradable plugin for tenant: {}".format(tenant_id),
                fg="green",
            )
        )

        if strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED:
            return

        # Get the plugins that need to be checked.
        plugin_ids: list[tuple[str, str, str]] = []  # (plugin_id, version, unique_identifier)

        click.echo(click.style("Upgrade mode: {}".format(upgrade_mode), fg="green"))

        if upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL and include_plugins:
            all_plugins = manager.list_plugins(tenant_id)
            for plugin in all_plugins:
                if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id in include_plugins:
                    plugin_ids.append(
                        (
                            plugin.plugin_id,
                            plugin.version,
                            plugin.plugin_unique_identifier,
                        )
                    )
        elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
            # Get all plugins and remove the excluded ones.
            all_plugins = manager.list_plugins(tenant_id)
            plugin_ids = [
                (plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
                for plugin in all_plugins
                if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id not in exclude_plugins
            ]
        elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
            all_plugins = manager.list_plugins(tenant_id)
            plugin_ids = [
                (plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
                for plugin in all_plugins
                if plugin.source == PluginInstallationSource.Marketplace
            ]

        if not plugin_ids:
            return

        plugin_ids_plain_list = [plugin_id for plugin_id, _, _ in plugin_ids]

        manifests = marketplace_batch_fetch_plugin_manifests(plugin_ids_plain_list)
        if not manifests:
            return

        for manifest in manifests:
            for plugin_id, version, original_unique_identifier in plugin_ids:
                if manifest.plugin_id != plugin_id:
                    continue

                try:
                    current_version = version
                    latest_version = manifest.latest_version

                    # FIX_ONLY: upgrade only when major.minor match and the
                    # patch component differs.
                    # NOTE(review): defined inside the loop; could be hoisted.
                    def fix_only_checker(latest_version, current_version):
                        latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
                        current_version_tuple = tuple(int(val) for val in current_version.split("."))
                        if (
                            latest_version_tuple[0] == current_version_tuple[0]
                            and latest_version_tuple[1] == current_version_tuple[1]
                        ):
                            return latest_version_tuple[2] != current_version_tuple[2]
                        return False

                    version_checker = {
                        TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
                        current_version: latest_version != current_version,
                        TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
                    }

                    if version_checker[strategy_setting](latest_version, current_version):
                        # Execute the upgrade.
                        new_unique_identifier = manifest.latest_package_identifier
                        marketplace.record_install_plugin_event(new_unique_identifier)
                        click.echo(
                            click.style(
                                "Upgrade plugin: {} -> {}".format(original_unique_identifier, new_unique_identifier),
                                fg="green",
                            )
                        )
                        task_start_resp = manager.upgrade_plugin(
                            tenant_id,
                            original_unique_identifier,
                            new_unique_identifier,
                            PluginInstallationSource.Marketplace,
                            {
                                "plugin_unique_identifier": new_unique_identifier,
                            },
                        )
                except Exception as e:
                    click.echo(click.style("Error when upgrading plugin: {}".format(e), fg="red"))
                    traceback.print_exc()
                # Manifest matched this plugin id; move to the next manifest.
                break
    except Exception as e:
        click.echo(click.style("Error when checking upgradable plugin: {}".format(e), fg="red"))
        traceback.print_exc()
        return

View File

@ -4,7 +4,6 @@ import time
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)

View File

@ -1,7 +1,6 @@
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)

View File

@ -5,7 +5,6 @@ import psycopg2 # type: ignore
from core.rag.datasource.vdb.opengauss.opengauss import OpenGauss, OpenGaussConfig
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)

View File

@ -1,7 +1,6 @@
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVector, VastbaseVectorConfig
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)

View File

@ -1,5 +1,4 @@
import json
import os
import time
import uuid
from collections.abc import Generator
@ -113,8 +112,6 @@ def test_execute_llm(flask_req_ctx):
},
)
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,

View File

@ -0,0 +1,280 @@
import base64
import binascii
from unittest.mock import MagicMock, patch
import pytest
from core.helper.encrypter import (
batch_decrypt_token,
decrypt_token,
encrypt_token,
get_decrypt_decoding,
obfuscated_token,
)
from libs.rsa import PrivkeyNotFoundError
class TestObfuscatedToken:
    """Unit tests for ``obfuscated_token`` masking behavior."""

    @pytest.mark.parametrize(
        ("token", "expected"),
        [
            ("", ""),  # Empty token
            ("1234567", "*" * 20),  # Short token (<8 chars)
            ("12345678", "*" * 20),  # Boundary case (8 chars)
            ("123456789abcdef", "123456" + "*" * 12 + "ef"),  # Long token
            ("abc!@#$%^&*()def", "abc!@#" + "*" * 12 + "ef"),  # Special chars
        ],
    )
    def test_obfuscation_logic(self, token, expected):
        """Test core obfuscation logic for various token lengths"""
        assert obfuscated_token(token) == expected

    def test_sensitive_data_protection(self):
        """Ensure obfuscation never reveals full sensitive data"""
        token = "api_key_secret_12345"
        obfuscated = obfuscated_token(token)
        assert token not in obfuscated
        assert "*" * 12 in obfuscated
class TestEncryptToken:
    """Tests for ``encrypt_token`` (per-tenant RSA encrypt + base64 encode)."""

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_successful_encryption(self, mock_encrypt, mock_query):
        """The token is encrypted with the tenant's public key and base64-encoded."""
        tenant = MagicMock()
        tenant.encrypt_public_key = "mock_public_key"
        mock_query.return_value.filter.return_value.first.return_value = tenant
        mock_encrypt.return_value = b"encrypted_data"

        encrypted = encrypt_token("tenant-123", "test_token")

        assert encrypted == base64.b64encode(b"encrypted_data").decode()
        mock_encrypt.assert_called_with("test_token", "mock_public_key")

    @patch("models.engine.db.session.query")
    def test_tenant_not_found(self, mock_query):
        """A missing tenant row raises ValueError naming the tenant id."""
        mock_query.return_value.filter.return_value.first.return_value = None

        with pytest.raises(ValueError) as exc_info:
            encrypt_token("invalid-tenant", "test_token")
        assert "Tenant with id invalid-tenant not found" in str(exc_info.value)
class TestDecryptToken:
    """Tests for ``decrypt_token`` (base64 decode + per-tenant RSA decrypt)."""

    @patch("libs.rsa.decrypt")
    def test_successful_decryption(self, mock_decrypt):
        """Ciphertext is base64-decoded and handed to rsa.decrypt with the tenant id."""
        mock_decrypt.return_value = "decrypted_token"
        ciphertext = base64.b64encode(b"encrypted_data").decode()

        plaintext = decrypt_token("tenant-123", ciphertext)

        assert plaintext == "decrypted_token"
        mock_decrypt.assert_called_once_with(b"encrypted_data", "tenant-123")

    def test_invalid_base64(self):
        """Input that is not valid base64 surfaces as a binascii.Error."""
        with pytest.raises(binascii.Error):
            decrypt_token("tenant-123", "invalid_base64!!!")
class TestBatchDecryptToken:
    """Tests for ``batch_decrypt_token``."""

    @patch("libs.rsa.get_decrypt_decoding")
    @patch("libs.rsa.decrypt_token_with_decoding")
    def test_batch_decryption(self, mock_decrypt_with_decoding, mock_get_decoding):
        """Multiple tokens decrypt in order while the tenant key is fetched once."""
        mock_get_decoding.return_value = (MagicMock(), MagicMock())
        mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3"]

        ciphertexts = [base64.b64encode(f"encrypted{i}".encode()).decode() for i in range(1, 4)]

        assert batch_decrypt_token("tenant-123", ciphertexts) == ["token1", "token2", "token3"]
        # The expensive key load must happen exactly once for the whole batch.
        mock_get_decoding.assert_called_once_with("tenant-123")
class TestGetDecryptDecoding:
    """Tests for ``get_decrypt_decoding`` private-key loading."""

    @patch("extensions.ext_redis.redis_client.get")
    @patch("extensions.ext_storage.storage.load")
    def test_private_key_not_found(self, mock_storage_load, mock_redis_get):
        """A storage miss (with a cold redis cache) raises PrivkeyNotFoundError."""
        mock_redis_get.return_value = None  # cache miss forces the storage lookup
        mock_storage_load.side_effect = FileNotFoundError()

        with pytest.raises(PrivkeyNotFoundError) as exc_info:
            get_decrypt_decoding("tenant-123")
        assert "Private key not found, tenant_id: tenant-123" in str(exc_info.value)
class TestEncryptDecryptIntegration:
    """Round-trip consistency between ``encrypt_token`` and ``decrypt_token``."""

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    @patch("libs.rsa.decrypt")
    def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
        """Decrypting the output of encrypt_token recovers the original token."""
        tenant = MagicMock()
        tenant.encrypt_public_key = "mock_public_key"
        mock_query.return_value.filter.return_value.first.return_value = tenant

        plaintext = "test_token_123"
        mock_encrypt.return_value = b"encrypted_data"
        mock_decrypt.return_value = plaintext

        ciphertext = encrypt_token("tenant-123", plaintext)
        assert decrypt_token("tenant-123", ciphertext) == plaintext
class TestSecurity:
    """Critical security tests for the encryption system."""

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
        """A token encrypted for one tenant must not decrypt under another tenant."""
        tenant = MagicMock()
        tenant.encrypt_public_key = "tenant1_public_key"
        mock_query.return_value.filter.return_value.first.return_value = tenant
        mock_encrypt.return_value = b"encrypted_for_tenant1"

        ciphertext = encrypt_token("tenant-123", "sensitive_data")

        # Decrypting with a different tenant's key must fail.
        with patch("libs.rsa.decrypt") as mock_decrypt:
            mock_decrypt.side_effect = Exception("Invalid tenant key")
            with pytest.raises(Exception, match="Invalid tenant key"):
                decrypt_token("different-tenant", ciphertext)

    @patch("libs.rsa.decrypt")
    def test_tampered_ciphertext_rejection(self, mock_decrypt):
        """Flipping bits in the ciphertext must make decryption fail."""
        valid = base64.b64encode(b"valid_data").decode()
        corrupted = bytearray(base64.b64decode(valid))
        corrupted[0] ^= 0xFF  # flip every bit of the first byte
        tampered = base64.b64encode(bytes(corrupted)).decode()

        mock_decrypt.side_effect = Exception("Decryption error")
        with pytest.raises(Exception, match="Decryption error"):
            decrypt_token("tenant-123", tampered)

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_encryption_randomness(self, mock_encrypt, mock_query):
        """Encrypting the same plaintext repeatedly must not repeat ciphertext."""
        mock_query.return_value.filter.return_value.first.return_value = MagicMock(encrypt_public_key="key")
        mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]

        outputs = [encrypt_token("tenant-123", "token") for _ in range(3)]
        assert len(set(outputs)) == 3
class TestEdgeCases:
    """Security-focused edge-case tests."""

    def test_should_handle_empty_string_in_obfuscation(self):
        """An empty string (a valid str input) obfuscates to an empty string."""
        assert obfuscated_token("") == ""

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
        """An empty token still goes through the normal encryption path."""
        tenant = MagicMock()
        tenant.encrypt_public_key = "mock_public_key"
        mock_query.return_value.filter.return_value.first.return_value = tenant
        mock_encrypt.return_value = b"encrypted_empty"

        assert encrypt_token("tenant-123", "") == base64.b64encode(b"encrypted_empty").decode()
        mock_encrypt.assert_called_with("", "mock_public_key")

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
        """Tokens with null bytes, unicode, and control characters pass through verbatim."""
        tenant = MagicMock()
        tenant.encrypt_public_key = "mock_public_key"
        mock_query.return_value.filter.return_value.first.return_value = tenant
        mock_encrypt.return_value = b"encrypted_special"

        for token in (
            "token\x00with\x00null",  # Null bytes
            "token_with_emoji_😀🎉",  # Unicode emoji
            "token\nwith\nnewlines",  # Newlines
            "token\twith\ttabs",  # Tabs
            "token_with_中文字符",  # Chinese characters
        ):
            assert encrypt_token("tenant-123", token) == base64.b64encode(b"encrypted_special").decode()
            mock_encrypt.assert_called_with(token, "mock_public_key")

    @patch("models.engine.db.session.query")
    @patch("libs.rsa.encrypt")
    def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
        """An oversized plaintext propagates the RSA 'message too long' error."""
        tenant = MagicMock()
        tenant.encrypt_public_key = "mock_public_key"
        mock_query.return_value.filter.return_value.first.return_value = tenant
        # RSA 2048-bit can only encrypt ~245 bytes; the exact limit depends on padding.
        mock_encrypt.side_effect = ValueError("Message too long for RSA key size")

        with pytest.raises(ValueError, match="Message too long for RSA key size"):
            encrypt_token("tenant-123", "x" * 300)

    @patch("libs.rsa.get_decrypt_decoding")
    @patch("libs.rsa.decrypt_token_with_decoding")
    def test_batch_decrypt_loads_key_only_once(self, mock_decrypt_with_decoding, mock_get_decoding):
        """Batch decryption loads the tenant key once regardless of batch size."""
        mock_get_decoding.return_value = (MagicMock(), MagicMock())
        mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3", "token4", "token5"]

        ciphertexts = [base64.b64encode(f"encrypted{i}".encode()).decode() for i in range(5)]

        assert batch_decrypt_token("tenant-123", ciphertexts) == ["token1", "token2", "token3", "token4", "token5"]
        mock_get_decoding.assert_called_once_with("tenant-123")
        assert mock_decrypt_with_decoding.call_count == 5

View File

@ -0,0 +1,53 @@
from redis import RedisError
from extensions.ext_redis import redis_fallback
def test_redis_fallback_success():
    """The wrapped function's return value passes through when no error occurs."""

    @redis_fallback(default_return=None)
    def probe():
        return "success"

    assert probe() == "success"
def test_redis_fallback_error():
    """A RedisError inside the wrapped function yields the configured default."""

    @redis_fallback(default_return="fallback")
    def probe():
        raise RedisError("Redis error")

    assert probe() == "fallback"
def test_redis_fallback_none_default():
    """Without an explicit default, a RedisError yields None."""

    @redis_fallback()
    def probe():
        raise RedisError("Redis error")

    assert probe() is None
def test_redis_fallback_with_args():
    """Positional arguments are forwarded; the default is returned on failure."""

    @redis_fallback(default_return=0)
    def probe(x, y):
        raise RedisError("Redis error")

    assert probe(1, 2) == 0
def test_redis_fallback_with_kwargs():
    """Keyword arguments are forwarded; the default is returned on failure."""

    @redis_fallback(default_return={})
    def probe(x=None, y=None):
        raise RedisError("Redis error")

    assert probe(x=1, y=2) == {}
def test_redis_fallback_preserves_function_metadata():
    """The decorator must preserve __name__ and __doc__ (i.e. use functools.wraps)."""

    # The inner function's name and docstring are the values under test,
    # so they are asserted verbatim below.
    @redis_fallback(default_return=None)
    def test_func():
        """Test function docstring"""
        pass

    assert test_func.__name__ == "test_func"
    assert test_func.__doc__ == "Test function docstring"

View File

@ -14,9 +14,7 @@ from core.variables import (
ArrayStringVariable,
FloatVariable,
IntegerVariable,
ObjectSegment,
SecretVariable,
SegmentType,
StringVariable,
)
from core.variables.exc import VariableError
@ -418,8 +416,6 @@ def test_build_segment_file_array_with_different_file_types():
@st.composite
def _generate_file(draw) -> File:
file_id = draw(st.text(min_size=1, max_size=10))
tenant_id = draw(st.text(min_size=1, max_size=10))
file_type, mime_type, extension = draw(
st.sampled_from(
[

View File

@ -10,7 +10,6 @@ from core.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset, ExternalKnowledgeBindings
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
from tests.unit_tests.conftest import redis_mock
class DatasetUpdateTestDataFactory: