Merge remote main and resolve conflicts for summaryindex feature

- Resolved conflicts in 9 task files by adopting session_factory pattern from main
- Preserved all summaryindex functionality including enable/disable logic
- Updated all task files to use session_factory.create_session() instead of db.session
- Merged new features from main (FileService, DocumentBatchDownloadZipPayload, etc.)
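For reference, a minimal sketch of the pattern the task files were migrated to (hypothetical task; the import paths and the assumption that create_session() returns a context-managed SQLAlchemy Session are not confirmed by this diff):

    # Hypothetical example of the migration described above.
    from sqlalchemy import select

    from extensions.ext_database import session_factory  # assumed import path
    from models.dataset import Document


    def disable_summary_index_task(dataset_id: str) -> None:
        # Previously: documents = db.session.scalars(...).all()
        with session_factory.create_session() as session:  # assumed to be a context manager
            documents = session.scalars(
                select(Document).where(Document.dataset_id == dataset_id)
            ).all()
            ...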
FFXN
2026-01-21 16:03:54 +08:00
822 changed files with 60654 additions and 10354 deletions

View File

@ -1381,6 +1381,11 @@ class RegisterService:
normalized_email = email.lower()
"""Invite new member"""
# Check workspace permission for member invitations
from libs.workspace_permission import check_workspace_member_invite_permission
check_workspace_member_invite_permission(tenant.id)
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)

View File

@ -13,10 +13,11 @@ import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.file import helpers as file_helpers
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
@ -73,6 +74,7 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.file_service import FileService
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.tag_service import TagService
from services.vector_service import VectorService
@ -1246,6 +1248,7 @@ class DocumentService:
Document.archived.is_(True),
),
}
DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION = ".zip"
@classmethod
def normalize_display_status(cls, status: str | None) -> str | None:
@ -1372,6 +1375,143 @@ class DocumentService:
else:
return None
@staticmethod
def get_documents_by_ids(dataset_id: str, document_ids: Sequence[str]) -> Sequence[Document]:
"""Fetch documents for a dataset in a single batch query."""
if not document_ids:
return []
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
# Fetch all requested documents in one query to avoid N+1 lookups.
documents: Sequence[Document] = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id,
Document.id.in_(document_id_list),
)
).all()
return documents
@staticmethod
def get_document_download_url(document: Document) -> str:
"""
Return a signed download URL for an upload-file document.
"""
upload_file = DocumentService._get_upload_file_for_upload_file_document(document)
return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
@staticmethod
def prepare_document_batch_download_zip(
*,
dataset_id: str,
document_ids: Sequence[str],
tenant_id: str,
current_user: Account,
) -> tuple[list[UploadFile], str]:
"""
Resolve upload files for batch ZIP downloads and generate a client-visible filename.
"""
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except NoPermissionError as e:
raise Forbidden(str(e))
upload_files_by_document_id = DocumentService._get_upload_files_by_document_id_for_zip_download(
dataset_id=dataset_id,
document_ids=document_ids,
tenant_id=tenant_id,
)
upload_files = [upload_files_by_document_id[document_id] for document_id in document_ids]
download_name = DocumentService._generate_document_batch_download_zip_filename()
return upload_files, download_name
@staticmethod
def _generate_document_batch_download_zip_filename() -> str:
"""
Generate a random attachment filename for the batch download ZIP.
"""
return f"{uuid.uuid4().hex}{DocumentService.DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION}"
@staticmethod
def _get_upload_file_id_for_upload_file_document(
document: Document,
*,
invalid_source_message: str,
missing_file_message: str,
) -> str:
"""
Normalize and validate `Document -> UploadFile` linkage for download flows.
"""
if document.data_source_type != "upload_file":
raise NotFound(invalid_source_message)
data_source_info: dict[str, Any] = document.data_source_info_dict or {}
upload_file_id: str | None = data_source_info.get("upload_file_id")
if not upload_file_id:
raise NotFound(missing_file_message)
return str(upload_file_id)
@staticmethod
def _get_upload_file_for_upload_file_document(document: Document) -> UploadFile:
"""
Load the `UploadFile` row for an upload-file document.
"""
upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document(
document,
invalid_source_message="Document does not have an uploaded file to download.",
missing_file_message="Uploaded file not found.",
)
upload_files_by_id = FileService.get_upload_files_by_ids(document.tenant_id, [upload_file_id])
upload_file = upload_files_by_id.get(upload_file_id)
if not upload_file:
raise NotFound("Uploaded file not found.")
return upload_file
@staticmethod
def _get_upload_files_by_document_id_for_zip_download(
*,
dataset_id: str,
document_ids: Sequence[str],
tenant_id: str,
) -> dict[str, UploadFile]:
"""
Batch load upload files keyed by document id for ZIP downloads.
"""
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
documents = DocumentService.get_documents_by_ids(dataset_id, document_id_list)
documents_by_id: dict[str, Document] = {str(document.id): document for document in documents}
missing_document_ids: set[str] = set(document_id_list) - set(documents_by_id.keys())
if missing_document_ids:
raise NotFound("Document not found.")
upload_file_ids: list[str] = []
upload_file_ids_by_document_id: dict[str, str] = {}
for document_id, document in documents_by_id.items():
if document.tenant_id != tenant_id:
raise Forbidden("No permission.")
upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document(
document,
invalid_source_message="Only uploaded-file documents can be downloaded as ZIP.",
missing_file_message="Only uploaded-file documents can be downloaded as ZIP.",
)
upload_file_ids.append(upload_file_id)
upload_file_ids_by_document_id[document_id] = upload_file_id
upload_files_by_id = FileService.get_upload_files_by_ids(tenant_id, upload_file_ids)
missing_upload_file_ids: set[str] = set(upload_file_ids) - set(upload_files_by_id.keys())
if missing_upload_file_ids:
raise NotFound("Only uploaded-file documents can be downloaded as ZIP.")
return {
document_id: upload_files_by_id[upload_file_id]
for document_id, upload_file_id in upload_file_ids_by_document_id.items()
}
@staticmethod
def get_document_by_id(document_id: str) -> Document | None:
document = db.session.query(Document).where(Document.id == document_id).first()
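For a single document, a hedged sketch of how a controller could use the helpers above to hand the client a signed URL (hypothetical endpoint; only get_document_by_id and get_document_download_url come from this diff):

    # Hypothetical Flask view; routing, authentication, and dataset permission checks are elided.
    from flask import jsonify
    from werkzeug.exceptions import NotFound

    from services.dataset_service import DocumentService  # assumed module path


    def download_document(document_id: str):
        document = DocumentService.get_document_by_id(document_id)
        if not document:
            raise NotFound("Document not found.")
        # get_document_download_url itself raises NotFound for non-upload-file documents.
        return jsonify({"url": DocumentService.get_document_download_url(document)})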

View File

@ -13,6 +13,23 @@ class WebAppSettings(BaseModel):
)
class WorkspacePermission(BaseModel):
workspace_id: str = Field(
description="The ID of the workspace.",
alias="workspaceId",
)
allow_member_invite: bool = Field(
description="Whether to allow members to invite new members to the workspace.",
default=False,
alias="allowMemberInvite",
)
allow_owner_transfer: bool = Field(
description="Whether to allow owners to transfer ownership of the workspace.",
default=False,
alias="allowOwnerTransfer",
)
class EnterpriseService:
@classmethod
def get_info(cls):
@ -44,6 +61,16 @@ class EnterpriseService:
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
class WorkspacePermissionService:
@classmethod
def get_permission(cls, workspace_id: str):
if not workspace_id:
raise ValueError("workspace_id must be provided.")
data = EnterpriseRequest.send_request("GET", f"/workspaces/{workspace_id}/permission")
if not data or "permission" not in data:
raise ValueError("No data found.")
return WorkspacePermission.model_validate(data["permission"])
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
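The RegisterService hunk earlier in this diff calls check_workspace_member_invite_permission(tenant.id), but libs.workspace_permission itself is not included here; a minimal sketch of how that helper could sit on top of get_permission (hypothetical implementation, only the service call and the allow_member_invite field come from this diff):

    # Hypothetical helper; the real libs/workspace_permission module is not shown in this diff.
    from werkzeug.exceptions import Forbidden

    from services.enterprise.enterprise_service import WorkspacePermissionService  # assumed module path


    def check_workspace_member_invite_permission(workspace_id: str) -> None:
        permission = WorkspacePermissionService.get_permission(workspace_id)
        if not permission.allow_member_invite:
            raise Forbidden("Member invitation is disabled for this workspace.")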

View File

@ -0,0 +1,58 @@
import json
import logging
import uuid
from datetime import UTC, datetime
from redis import RedisError
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"
WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing"
class WorkspaceSyncService:
"""Service to publish workspace sync tasks to Redis queue for enterprise backend consumption"""
@staticmethod
def queue_credential_sync(workspace_id: str, *, source: str) -> bool:
"""
Queue a credential sync task for a newly created workspace.
This publishes a task to Redis that will be consumed by the enterprise backend
worker to sync credentials with the plugin-manager.
Args:
workspace_id: The workspace/tenant ID to sync credentials for
source: Source of the sync request (for debugging/tracking)
Returns:
bool: True if task was queued successfully, False otherwise
"""
try:
task = {
"task_id": str(uuid.uuid4()),
"workspace_id": workspace_id,
"retry_count": 0,
"created_at": datetime.now(UTC).isoformat(),
"source": source,
}
# Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
redis_client.lpush(WORKSPACE_SYNC_QUEUE, json.dumps(task))
logger.info(
"Queued credential sync task for workspace %s, task_id: %s, source: %s",
workspace_id,
task["task_id"],
source,
)
return True
except (RedisError, TypeError) as e:
logger.error("Failed to queue credential sync for workspace %s: %s", workspace_id, str(e), exc_info=True)
# Don't raise - we don't want to fail workspace creation if queueing fails
# The scheduled task will catch it later
return False
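On the consuming side, the comment above describes an LPUSH/RPOP protocol; a hedged sketch of the worker loop (hypothetical consumer, the enterprise backend's real worker is not part of this diff):

    # Hypothetical worker loop illustrating the LPUSH/RPOP protocol described above.
    import json
    import time

    from extensions.ext_redis import redis_client

    WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"  # same key as the producer above


    def consume_workspace_sync_tasks() -> None:
        while True:
            raw = redis_client.rpop(WORKSPACE_SYNC_QUEUE)  # consume from the tail of the queue
            if raw is None:
                time.sleep(1)  # queue is empty; back off briefly
                continue
            task = json.loads(raw)
            # ... sync credentials for task["workspace_id"] with the plugin-manager ...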

View File

@ -2,7 +2,11 @@ import base64
import hashlib
import os
import uuid
from collections.abc import Iterator, Sequence
from contextlib import contextmanager, suppress
from tempfile import NamedTemporaryFile
from typing import Literal, Union
from zipfile import ZIP_DEFLATED, ZipFile
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
@ -17,6 +21,7 @@ from constants import (
)
from core.file import helpers as file_helpers
from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
@ -167,6 +172,9 @@ class FileService:
return upload_file
def get_file_preview(self, file_id: str):
"""
Return a short text preview extracted from a document file.
"""
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
@ -253,3 +261,101 @@ class FileService:
return
storage.delete(upload_file.key)
session.delete(upload_file)
@staticmethod
def get_upload_files_by_ids(tenant_id: str, upload_file_ids: Sequence[str]) -> dict[str, UploadFile]:
"""
Fetch `UploadFile` rows for a tenant in a single batch query.
This is a generic `UploadFile` lookup helper (not dataset/document specific), so it lives in `FileService`.
"""
if not upload_file_ids:
return {}
# Normalize and deduplicate ids before using them in the IN clause.
upload_file_id_list: list[str] = [str(upload_file_id) for upload_file_id in upload_file_ids]
unique_upload_file_ids: list[str] = list(set(upload_file_id_list))
# Fetch upload files in one query for efficient batch access.
upload_files: Sequence[UploadFile] = db.session.scalars(
select(UploadFile).where(
UploadFile.tenant_id == tenant_id,
UploadFile.id.in_(unique_upload_file_ids),
)
).all()
return {str(upload_file.id): upload_file for upload_file in upload_files}
@staticmethod
def _sanitize_zip_entry_name(name: str) -> str:
"""
Sanitize a ZIP entry name to avoid path traversal and weird separators.
We keep this conservative: the upload flow already rejects `/` and `\\`, but older rows (or imported data)
could still contain unsafe names.
"""
# Drop any directory components and prevent empty names.
base = os.path.basename(name).strip() or "file"
# ZIP uses forward slashes as separators; remove any residual separator characters.
return base.replace("/", "_").replace("\\", "_")
@staticmethod
def _dedupe_zip_entry_name(original_name: str, used_names: set[str]) -> str:
"""
Return a unique ZIP entry name, inserting suffixes before the extension.
"""
# Keep the original name when it's not already used.
if original_name not in used_names:
return original_name
# Insert suffixes before the extension (e.g., "doc.txt" -> "doc (1).txt").
stem, extension = os.path.splitext(original_name)
suffix = 1
while True:
candidate = f"{stem} ({suffix}){extension}"
if candidate not in used_names:
return candidate
suffix += 1
@staticmethod
@contextmanager
def build_upload_files_zip_tempfile(
*,
upload_files: Sequence[UploadFile],
) -> Iterator[str]:
"""
Build a ZIP from `UploadFile`s and yield a tempfile path.
We yield a path (rather than an open file handle) to avoid "read of closed file" issues when Flask/Werkzeug
streams responses. The caller is expected to keep this context open until the response is fully sent, then
close it (e.g., via `response.call_on_close(...)`) to delete the tempfile.
"""
used_names: set[str] = set()
# Build a ZIP in a temp file and keep it on disk until the caller finishes streaming it.
tmp_path: str | None = None
try:
with NamedTemporaryFile(mode="w+b", suffix=".zip", delete=False) as tmp:
tmp_path = tmp.name
with ZipFile(tmp, mode="w", compression=ZIP_DEFLATED) as zf:
for upload_file in upload_files:
# Ensure the entry name is safe and unique.
safe_name = FileService._sanitize_zip_entry_name(upload_file.name)
arcname = FileService._dedupe_zip_entry_name(safe_name, used_names)
used_names.add(arcname)
# Stream file bytes from storage into the ZIP entry.
with zf.open(arcname, "w") as entry:
for chunk in storage.load(upload_file.key, stream=True):
entry.write(chunk)
# Flush so `send_file(path, ...)` can re-open it safely on all platforms.
tmp.flush()
assert tmp_path is not None
yield tmp_path
finally:
# Remove the temp file when the context is closed (typically after the response finishes streaming).
if tmp_path is not None:
with suppress(FileNotFoundError):
os.remove(tmp_path)
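Per the docstring above, a hedged sketch of a controller that keeps the tempfile alive while the response streams (hypothetical endpoint; the two service calls and the call_on_close pattern are taken from this diff, everything else is illustrative):

    # Hypothetical Flask view; request parsing and authentication are elided.
    from contextlib import ExitStack

    from flask import send_file

    from services.dataset_service import DocumentService  # assumed module path
    from services.file_service import FileService


    def batch_download_documents(dataset_id: str, document_ids: list[str], tenant_id: str, current_user):
        upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
            dataset_id=dataset_id,
            document_ids=document_ids,
            tenant_id=tenant_id,
            current_user=current_user,
        )
        stack = ExitStack()
        # Keep the ZIP tempfile on disk until the response has been fully sent.
        zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files))
        response = send_file(zip_path, as_attachment=True, download_name=download_name, mimetype="application/zip")
        response.call_on_close(stack.close)
        return response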

View File

@ -0,0 +1,216 @@
import datetime
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from configs import dify_config
from enums.cloud_plan import CloudPlan
from services.billing_service import BillingService, SubscriptionPlan
logger = logging.getLogger(__name__)
@dataclass
class SimpleMessage:
id: str
app_id: str
created_at: datetime.datetime
class MessagesCleanPolicy(ABC):
"""
Abstract base class for message cleanup policies.
A policy determines which messages from a batch should be deleted.
"""
@abstractmethod
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
"""
Filter messages and return IDs of messages that should be deleted.
Args:
messages: Batch of messages to evaluate
app_to_tenant: Mapping from app_id to tenant_id
Returns:
List of message IDs that should be deleted
"""
...
class BillingDisabledPolicy(MessagesCleanPolicy):
"""
Policy for community or enterprise edition (billing disabled).
No special filtering logic; all message IDs are returned.
"""
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
return [msg.id for msg in messages]
class BillingSandboxPolicy(MessagesCleanPolicy):
"""
Policy for sandbox plan tenants in cloud edition (billing enabled).
Filters messages based on sandbox plan expiration rules:
- Skip tenants in the whitelist
- Only delete messages from sandbox plan tenants
- Respect grace period after subscription expiration
- Safe default: if tenant mapping or plan is missing, do NOT delete
"""
def __init__(
self,
plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]],
graceful_period_days: int = 21,
tenant_whitelist: Sequence[str] | None = None,
current_timestamp: int | None = None,
) -> None:
self._graceful_period_days = graceful_period_days
self._tenant_whitelist: Sequence[str] = tenant_whitelist or []
self._plan_provider = plan_provider
self._current_timestamp = current_timestamp
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
"""
Filter messages based on sandbox plan expiration rules.
Args:
messages: Batch of messages to evaluate
app_to_tenant: Mapping from app_id to tenant_id
Returns:
List of message IDs that should be deleted
"""
if not messages or not app_to_tenant:
return []
# Get unique tenant_ids and fetch subscription plans
tenant_ids = list(set(app_to_tenant.values()))
tenant_plans = self._plan_provider(tenant_ids)
if not tenant_plans:
return []
# Apply sandbox deletion rules
return self._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
)
def _filter_expired_sandbox_messages(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
tenant_plans: dict[str, SubscriptionPlan],
) -> list[str]:
"""
Filter messages that should be deleted based on sandbox plan expiration.
A message should be deleted if:
1. It belongs to a sandbox tenant AND
2. Either:
a) The tenant has no previous subscription (expiration_date == -1), OR
b) The subscription expired more than graceful_period_days ago
Args:
messages: List of message objects with id and app_id attributes
app_to_tenant: Mapping from app_id to tenant_id
tenant_plans: Mapping from tenant_id to subscription plan info
Returns:
List of message IDs that should be deleted
"""
current_timestamp = self._current_timestamp
if current_timestamp is None:
current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
sandbox_message_ids: list[str] = []
graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60
for msg in messages:
# Get tenant_id for this message's app
tenant_id = app_to_tenant.get(msg.app_id)
if not tenant_id:
continue
# Skip tenant messages in whitelist
if tenant_id in self._tenant_whitelist:
continue
# Get subscription plan for this tenant
tenant_plan = tenant_plans.get(tenant_id)
if not tenant_plan:
continue
plan = str(tenant_plan["plan"])
expiration_date = int(tenant_plan["expiration_date"])
# Only process sandbox plans
if plan != CloudPlan.SANDBOX:
continue
# Case 1: No previous subscription (-1 means never had a paid subscription)
if expiration_date == -1:
sandbox_message_ids.append(msg.id)
continue
# Case 2: Subscription expired beyond grace period
if current_timestamp - expiration_date > graceful_period_seconds:
sandbox_message_ids.append(msg.id)
return sandbox_message_ids
def create_message_clean_policy(
graceful_period_days: int = 21,
current_timestamp: int | None = None,
) -> MessagesCleanPolicy:
"""
Factory function to create the appropriate message clean policy.
Determines which policy to use based on BILLING_ENABLED configuration:
- If BILLING_ENABLED is True: returns BillingSandboxPolicy
- If BILLING_ENABLED is False: returns BillingDisabledPolicy
Args:
graceful_period_days: Grace period in days after subscription expiration (default: 21)
current_timestamp: Current Unix timestamp for testing (default: None, uses current time)
"""
if not dify_config.BILLING_ENABLED:
logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy")
return BillingDisabledPolicy()
# Billing enabled - fetch whitelist from BillingService
tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist()
plan_provider = BillingService.get_plan_bulk_with_cache
logger.info(
"create_message_clean_policy: billing enabled, using BillingSandboxPolicy "
"(graceful_period_days=%s, whitelist=%s)",
graceful_period_days,
tenant_whitelist,
)
return BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=graceful_period_days,
tenant_whitelist=tenant_whitelist,
current_timestamp=current_timestamp,
)
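A short, hedged example of exercising BillingSandboxPolicy with a stub plan provider (test-style sketch; the plan dict keys "plan" and "expiration_date" match what the code above reads, and it is assumed that CloudPlan.SANDBOX compares equal to the stored plan string):

    # Illustrative test-style usage; assumes CloudPlan.SANDBOX behaves like a str enum value.
    import datetime

    from enums.cloud_plan import CloudPlan
    from services.retention.conversation.messages_clean_policy import BillingSandboxPolicy, SimpleMessage


    def stub_plan_provider(tenant_ids):
        # Tenant never had a paid subscription (expiration_date == -1), so its messages are eligible.
        return {"tenant-1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}


    message = SimpleMessage(id="m1", app_id="app-1", created_at=datetime.datetime.now(datetime.UTC))
    policy = BillingSandboxPolicy(plan_provider=stub_plan_provider, graceful_period_days=21)
    assert policy.filter_message_ids([message], {"app-1": "tenant-1"}) == ["m1"]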

View File

@ -0,0 +1,334 @@
import datetime
import logging
import random
from collections.abc import Sequence
from typing import cast
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.model import (
App,
AppAnnotationHitHistory,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.retention.conversation.messages_clean_policy import (
MessagesCleanPolicy,
SimpleMessage,
)
logger = logging.getLogger(__name__)
class MessagesCleanService:
"""
Service for cleaning expired messages based on retention policies.
With billing disabled (non-cloud editions), all messages in the time range are deleted.
With billing enabled, only sandbox-plan tenant messages are deleted (with whitelist and grace-period support).
"""
def __init__(
self,
policy: MessagesCleanPolicy,
end_before: datetime.datetime,
start_from: datetime.datetime | None = None,
batch_size: int = 1000,
dry_run: bool = False,
) -> None:
"""
Initialize the service with cleanup parameters.
Args:
policy: The policy that determines which messages to delete
end_before: End time (exclusive) of the range
start_from: Optional start time (inclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
"""
self._policy = policy
self._end_before = end_before
self._start_from = start_from
self._batch_size = batch_size
self._dry_run = dry_run
@classmethod
def from_time_range(
cls,
policy: MessagesCleanPolicy,
start_from: datetime.datetime,
end_before: datetime.datetime,
batch_size: int = 1000,
dry_run: bool = False,
) -> "MessagesCleanService":
"""
Create a service instance for cleaning messages within a specific time range.
Time range is [start_from, end_before).
Args:
policy: The policy that determines which messages to delete
start_from: Start time (inclusive) of the range
end_before: End time (exclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
MessagesCleanService instance
Raises:
ValueError: If start_from >= end_before or invalid parameters
"""
if start_from >= end_before:
raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
logger.info(
"clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s",
start_from,
end_before,
batch_size,
policy.__class__.__name__,
)
return cls(
policy=policy,
end_before=end_before,
start_from=start_from,
batch_size=batch_size,
dry_run=dry_run,
)
@classmethod
def from_days(
cls,
policy: MessagesCleanPolicy,
days: int = 30,
batch_size: int = 1000,
dry_run: bool = False,
) -> "MessagesCleanService":
"""
Create a service instance for cleaning messages older than specified days.
Args:
policy: The policy that determines which messages to delete
days: Number of days to look back from now
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
MessagesCleanService instance
Raises:
ValueError: If invalid parameters
"""
if days < 0:
raise ValueError(f"days ({days}) must be greater than or equal to 0")
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
end_before = datetime.datetime.now() - datetime.timedelta(days=days)
logger.info(
"clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
days,
end_before,
batch_size,
policy.__class__.__name__,
)
return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
def run(self) -> dict[str, int]:
"""
Execute the message cleanup operation.
Returns:
Dict with statistics: batches, filtered_messages, total_deleted
"""
return self._clean_messages_by_time_range()
def _clean_messages_by_time_range(self) -> dict[str, int]:
"""
Clean messages within a time range using cursor-based pagination.
Time range is [start_from, end_before)
Steps:
1. Iterate messages using cursor pagination (by created_at, id)
2. Query app_id -> tenant_id mapping
3. Delegate to policy to determine which messages to delete
4. Batch delete messages and their relations
Returns:
Dict with statistics: batches, filtered_messages, total_deleted
"""
stats = {
"batches": 0,
"total_messages": 0,
"filtered_messages": 0,
"total_deleted": 0,
}
# Cursor-based pagination using (created_at, id) to avoid infinite loops
# and ensure proper ordering with time-based filtering
_cursor: tuple[datetime.datetime, str] | None = None
logger.info(
"clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s",
self._dry_run,
self._start_from,
self._end_before,
)
while True:
stats["batches"] += 1
# Step 1: Fetch a batch of messages using cursor
with Session(db.engine, expire_on_commit=False) as session:
msg_stmt = (
select(Message.id, Message.app_id, Message.created_at)
.where(Message.created_at < self._end_before)
.order_by(Message.created_at, Message.id)
.limit(self._batch_size)
)
if self._start_from:
msg_stmt = msg_stmt.where(Message.created_at >= self._start_from)
# Apply cursor condition: (created_at, id) > (last_created_at, last_message_id)
# This translates to:
# created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id)
if _cursor:
# Continuing from previous batch
msg_stmt = msg_stmt.where(
(Message.created_at > _cursor[0])
| ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1]))
)
raw_messages = list(session.execute(msg_stmt).all())
messages = [
SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at)
for msg_id, app_id, msg_created_at in raw_messages
]
# Track total messages fetched across all batches
stats["total_messages"] += len(messages)
if not messages:
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
break
# Update cursor to the last message's (created_at, id)
_cursor = (messages[-1].created_at, messages[-1].id)
# Step 2: Extract app_ids and query tenant_ids
app_ids = list({msg.app_id for msg in messages})
if not app_ids:
logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"])
continue
app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids))
apps = list(session.execute(app_stmt).all())
if not apps:
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
continue
# Build app_id -> tenant_id mapping
app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps}
# Step 3: Delegate to policy to determine which messages to delete
message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant)
if not message_ids_to_delete:
logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
continue
stats["filtered_messages"] += len(message_ids_to_delete)
# Step 4: Batch delete messages and their relations
if not self._dry_run:
with Session(db.engine, expire_on_commit=False) as session:
# Delete related records first
self._batch_delete_message_relations(session, message_ids_to_delete)
# Delete messages
delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete))
delete_result = cast(CursorResult, session.execute(delete_stmt))
messages_deleted = delete_result.rowcount
session.commit()
stats["total_deleted"] += messages_deleted
logger.info(
"clean_messages (batch %s): processed %s messages, deleted %s messages",
stats["batches"],
len(messages),
messages_deleted,
)
else:
# Log random sample of message IDs that would be deleted (up to 10)
sample_size = min(10, len(message_ids_to_delete))
sampled_ids = random.sample(list(message_ids_to_delete), sample_size)
logger.info(
"clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:",
stats["batches"],
len(message_ids_to_delete),
sample_size,
)
for msg_id in sampled_ids:
logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
logger.info(
"clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
stats["batches"],
stats["total_messages"],
stats["filtered_messages"],
stats["total_deleted"],
)
return stats
@staticmethod
def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None:
"""
Batch delete all related records for given message IDs.
Args:
session: Database session
message_ids: List of message IDs to delete relations for
"""
if not message_ids:
return
# Delete all related records in batch
session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)))
session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids)))
session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids)))
session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids)))
session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids)))
session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids)))
session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids)))
session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids)))
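Putting the policy and the service together, a hedged sketch of how a scheduled cleanup job might drive them (hypothetical caller; only the factory function and the from_days constructor come from this diff):

    # Hypothetical scheduled-task body; scheduling and locking around it are elided.
    from services.retention.conversation.messages_clean_policy import create_message_clean_policy
    from services.retention.conversation.messages_clean_service import MessagesCleanService  # assumed module path

    policy = create_message_clean_policy(graceful_period_days=21)
    service = MessagesCleanService.from_days(policy=policy, days=30, batch_size=1000, dry_run=True)
    stats = service.run()
    # stats contains "batches", "total_messages", "filtered_messages", and "total_deleted".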

View File

@ -10,9 +10,7 @@ from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
from repositories.factory import DifyAPIRepositoryFactory
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.billing_service import BillingService, SubscriptionPlan
@ -92,9 +90,12 @@ class WorkflowRunCleanup:
paid_or_skipped = len(run_rows) - len(free_runs)
if not free_runs:
skipped_message = (
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
)
click.echo(
click.style(
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)",
skipped_message,
fg="yellow",
)
)
@ -255,21 +256,6 @@ class WorkflowRunCleanup:
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
return trigger_repo.count_by_run_ids(run_ids)
@staticmethod
def _build_run_contexts(
runs: Sequence[WorkflowRun],
) -> list[DifyAPISQLAlchemyWorkflowNodeExecutionRepository.RunContext]:
return [
{
"run_id": run.id,
"tenant_id": run.tenant_id,
"app_id": run.app_id,
"workflow_id": run.workflow_id,
"triggered_from": run.triggered_from,
}
for run in runs
]
@staticmethod
def _empty_related_counts() -> dict[str, int]:
return {
@ -293,9 +279,15 @@ class WorkflowRunCleanup:
)
def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
run_contexts = self._build_run_contexts(runs)
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.count_by_runs(session, run_contexts)
run_ids = [run.id for run in runs]
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
)
return repo.count_by_runs(session, run_ids)
def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
run_contexts = self._build_run_contexts(runs)
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.delete_by_runs(session, run_contexts)
run_ids = [run.id for run in runs]
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
)
return repo.delete_by_runs(session, run_ids)