mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 21:55:58 +08:00
feat: add export app messages
fix: tests feat: add filename validate
This commit is contained in:
@ -2684,7 +2684,11 @@ def clean_expired_messages(
|
||||
required=True,
|
||||
help="Upper bound (exclusive) for created_at.",
|
||||
)
|
||||
@click.option("--filename", required=True, help="Output filename (local path or cloud storage key).")
|
||||
@click.option(
|
||||
"--filename",
|
||||
required=True,
|
||||
help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
|
||||
)
|
||||
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
|
||||
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
|
||||
@ -2702,6 +2706,11 @@ def export_app_messages(
|
||||
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService
|
||||
|
||||
try:
|
||||
validated_filename = AppMessageExportService.validate_export_filename(filename)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(str(e), param_hint="--filename") from e
|
||||
|
||||
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
@ -2709,7 +2718,7 @@ def export_app_messages(
|
||||
service = AppMessageExportService(
|
||||
app_id=app_id,
|
||||
end_before=end_before,
|
||||
filename=filename,
|
||||
filename=validated_filename,
|
||||
start_from=start_from,
|
||||
batch_size=batch_size,
|
||||
use_cloud_storage=use_cloud_storage,
|
||||
|
||||
@ -15,6 +15,7 @@ import logging
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator, Iterable
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any, BinaryIO, cast
|
||||
|
||||
import orjson
|
||||
@ -29,6 +30,9 @@ from models.model import Message, MessageFeedback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_FILENAME_BASE_LENGTH = 1024
|
||||
FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz")
|
||||
|
||||
|
||||
class AppMessageExportFeedback(BaseModel):
|
||||
id: str
|
||||
@ -68,6 +72,49 @@ class AppMessageExportStats(BaseModel):
|
||||
|
||||
|
||||
class AppMessageExportService:
|
||||
@staticmethod
def validate_export_filename(filename: str) -> str:
    """Validate the export base filename and return it with surrounding whitespace stripped.

    The value is a relative, '/'-separated path without an output suffix;
    the ``.jsonl`` / ``.jsonl.gz`` suffix is appended by the service itself.

    Args:
        filename: Base filename as supplied by the caller.

    Returns:
        The stripped filename, otherwise unchanged.

    Raises:
        ValueError: If the value is empty, ends with a forbidden suffix
            (compared case-insensitively), is absolute, contains a
            backslash, contains an empty / '.' / '..' path segment, ends
            with '/', exceeds ``MAX_FILENAME_BASE_LENGTH`` characters, or
            contains a control character.
    """
    normalized = filename.strip()
    if not normalized:
        raise ValueError("--filename must not be empty.")

    if normalized.lower().endswith(FORBIDDEN_FILENAME_SUFFIXES):
        raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.")

    if normalized.startswith("/"):
        raise ValueError("--filename must be a relative path; absolute paths are not allowed.")

    if "\\" in normalized:
        raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.")

    if "//" in normalized:
        raise ValueError("--filename must not contain empty path segments ('//').")

    if len(normalized) > MAX_FILENAME_BASE_LENGTH:
        raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.")

    # Covers NUL (0), the other C0 control characters (< 32) and DEL (127).
    if any(ord(ch) < 32 or ord(ch) == 127 for ch in normalized):
        raise ValueError("--filename must not contain control characters or NUL.")

    # Inspect the raw segments instead of PurePosixPath(...).parts: pathlib
    # collapses "." segments and trailing slashes, which would let values
    # such as "a/./b" or "exports/" slip through unchanged.
    segments = normalized.split("/")

    # Leading '/' and '//' were rejected above, so an empty segment can only
    # be the last one, i.e. the value ends with '/' and names a directory.
    if not segments[-1]:
        raise ValueError("--filename must include a file name.")

    if any(segment in (".", "..") for segment in segments):
        raise ValueError("--filename must not contain '.' or '..' path segments.")

    return normalized
|
||||
|
||||
@property
|
||||
def output_gz_name(self) -> str:
|
||||
return f"{self._filename_base}.jsonl.gz"
|
||||
|
||||
@property
|
||||
def output_jsonl_name(self) -> str:
|
||||
return f"{self._filename_base}.jsonl"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
@ -85,7 +132,7 @@ class AppMessageExportService:
|
||||
self._app_id = app_id
|
||||
self._end_before = end_before
|
||||
self._start_from = start_from
|
||||
self._filename = filename
|
||||
self._filename_base = self.validate_export_filename(filename)
|
||||
self._batch_size = batch_size
|
||||
self._use_cloud_storage = use_cloud_storage
|
||||
self._dry_run = dry_run
|
||||
@ -94,12 +141,13 @@ class AppMessageExportService:
|
||||
stats = AppMessageExportStats()
|
||||
|
||||
logger.info(
|
||||
"export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s",
|
||||
"export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s",
|
||||
self._app_id,
|
||||
self._start_from,
|
||||
self._end_before,
|
||||
self._dry_run,
|
||||
self._use_cloud_storage,
|
||||
self.output_gz_name,
|
||||
)
|
||||
|
||||
if self._dry_run:
|
||||
@ -127,7 +175,9 @@ class AppMessageExportService:
|
||||
gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")
|
||||
|
||||
def _export_to_local(self, stats: AppMessageExportStats) -> None:
|
||||
with open(self._filename, "wb") as output_file:
|
||||
output_path = Path.cwd() / self.output_gz_name
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with output_path.open("wb") as output_file:
|
||||
self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)
|
||||
|
||||
def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
|
||||
@ -136,12 +186,10 @@ class AppMessageExportService:
|
||||
tmp.seek(0)
|
||||
data = tmp.read()
|
||||
|
||||
storage.save(self._filename, data)
|
||||
logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self._filename)
|
||||
storage.save(self.output_gz_name, data)
|
||||
logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name)
|
||||
|
||||
def _iter_records_with_stats(
|
||||
self, stats: AppMessageExportStats
|
||||
) -> Generator[AppMessageExportRecord, None, None]:
|
||||
def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]:
|
||||
for record in self.iter_records():
|
||||
self._update_stats(stats, record)
|
||||
yield record
|
||||
@ -233,9 +281,7 @@ class AppMessageExportService:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _build_record(
|
||||
row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]
|
||||
) -> AppMessageExportRecord:
|
||||
def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord:
|
||||
retriever_resources: list[Any] = []
|
||||
if row.message_metadata:
|
||||
try:
|
||||
|
||||
@ -205,7 +205,7 @@ class TestAppMessageExportServiceIntegration:
|
||||
app_id=app.id,
|
||||
start_from=base_time - datetime.timedelta(minutes=1),
|
||||
end_before=base_time + datetime.timedelta(minutes=10),
|
||||
filename="unused.jsonl.gz",
|
||||
filename="unused",
|
||||
batch_size=1,
|
||||
dry_run=True,
|
||||
)
|
||||
|
||||
43
api/tests/unit_tests/services/test_export_app_messages.py
Normal file
43
api/tests/unit_tests/services/test_export_app_messages.py
Normal file
@ -0,0 +1,43 @@
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService
|
||||
|
||||
|
||||
def test_validate_export_filename_accepts_relative_path():
    """A nested relative base name without a suffix is returned unchanged."""
    candidate = "exports/2026/test01"
    validated = AppMessageExportService.validate_export_filename(candidate)
    assert validated == "exports/2026/test01"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename",
    [
        # Empty or whitespace-only input.
        "",
        "   ",
        # Forbidden output suffixes (matched case-insensitively).
        "test01.jsonl.gz",
        "test01.jsonl",
        "test01.gz",
        "TEST01.JSONL.GZ",
        # Absolute paths and traversal segments.
        "/tmp/test01",
        "exports/../test01",
        "..",
        # Separator misuse.
        "exports\\test01",
        "exports//test01",
        # Control characters: NUL, TAB and DEL.
        "bad\x00name",
        "bad\tname",
        "bad\x7fname",
        # One character over the 1024-character limit.
        "a" * 1025,
    ],
)
def test_validate_export_filename_rejects_invalid_values(filename: str):
    """Every malformed base filename must be rejected with ``ValueError``."""
    with pytest.raises(ValueError):
        AppMessageExportService.validate_export_filename(filename)
|
||||
|
||||
|
||||
def test_service_derives_output_names_from_filename_base():
    """The service keeps the validated base name and derives both output names from it."""
    exporter = AppMessageExportService(
        app_id="736b9b03-20f2-4697-91da-8d00f6325900",
        start_from=None,
        end_before=datetime.datetime(2026, 3, 1),
        filename="exports/2026/test01",
        batch_size=1000,
        use_cloud_storage=True,
        dry_run=True,
    )

    # The base name is stored as given; each property only appends its suffix.
    assert exporter._filename_base == "exports/2026/test01"
    assert exporter.output_gz_name == "exports/2026/test01.jsonl.gz"
    assert exporter.output_jsonl_name == "exports/2026/test01.jsonl"
|
||||
Reference in New Issue
Block a user