mirror of
https://github.com/langgenius/dify.git
synced 2026-04-01 11:27:50 +08:00
Compare commits
4 Commits
deploy/cle
...
copilot/an
| Author | SHA1 | Date | |
|---|---|---|---|
| 85396a2af2 | |||
| 6c19e75969 | |||
| 9970f4449a | |||
| cbb19cce39 |
@ -133,7 +133,7 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||
### Custom configurations
|
||||
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
#### Customizing Suggested Questions
|
||||
|
||||
|
||||
@ -937,12 +937,6 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
|
||||
is_flag=True,
|
||||
help="Preview cleanup results without deleting any workflow run data.",
|
||||
)
|
||||
@click.option(
|
||||
"--task-label",
|
||||
default="daily",
|
||||
show_default=True,
|
||||
help="Stable label value used to distinguish multiple cleanup CronJobs in metrics.",
|
||||
)
|
||||
def clean_workflow_runs(
|
||||
before_days: int,
|
||||
batch_size: int,
|
||||
@ -951,13 +945,10 @@ def clean_workflow_runs(
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
dry_run: bool,
|
||||
task_label: str,
|
||||
):
|
||||
"""
|
||||
Clean workflow runs and related workflow data for free tenants.
|
||||
"""
|
||||
from extensions.otel.runtime import flush_telemetry
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
|
||||
@ -977,17 +968,13 @@ def clean_workflow_runs(
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
|
||||
|
||||
try:
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
).run()
|
||||
finally:
|
||||
flush_telemetry()
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
).run()
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
@ -2643,12 +2630,6 @@ def migrate_oss(
|
||||
help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
|
||||
)
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
|
||||
@click.option(
|
||||
"--task-label",
|
||||
default="daily",
|
||||
show_default=True,
|
||||
help="Stable label value used to distinguish multiple cleanup CronJobs in metrics.",
|
||||
)
|
||||
def clean_expired_messages(
|
||||
batch_size: int,
|
||||
graceful_period: int,
|
||||
@ -2657,13 +2638,10 @@ def clean_expired_messages(
|
||||
from_days_ago: int | None,
|
||||
before_days: int | None,
|
||||
dry_run: bool,
|
||||
task_label: str,
|
||||
):
|
||||
"""
|
||||
Clean expired messages and related data for tenants based on clean policy.
|
||||
"""
|
||||
from extensions.otel.runtime import flush_telemetry
|
||||
|
||||
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
|
||||
|
||||
start_at = time.perf_counter()
|
||||
@ -2713,7 +2691,6 @@ def clean_expired_messages(
|
||||
end_before=end_before,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
elif from_days_ago is None:
|
||||
assert before_days is not None
|
||||
@ -2722,7 +2699,6 @@ def clean_expired_messages(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
else:
|
||||
assert before_days is not None
|
||||
@ -2734,7 +2710,6 @@ def clean_expired_messages(
|
||||
end_before=now - datetime.timedelta(days=before_days),
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
@ -2760,8 +2735,6 @@ def clean_expired_messages(
|
||||
)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
flush_telemetry()
|
||||
|
||||
click.echo(click.style("messages cleanup completed.", fg="green"))
|
||||
|
||||
|
||||
@ -239,7 +239,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotCompletionAppError()
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Union
|
||||
|
||||
from celery.signals import worker_init
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.propagate import set_global_textmap
|
||||
from opentelemetry.propagators.b3 import B3Format
|
||||
from opentelemetry.propagators.composite import CompositePropagator
|
||||
@ -31,29 +31,9 @@ def setup_context_propagation() -> None:
|
||||
|
||||
|
||||
def shutdown_tracer() -> None:
|
||||
flush_telemetry()
|
||||
|
||||
|
||||
def flush_telemetry() -> None:
|
||||
"""
|
||||
Best-effort flush for telemetry providers.
|
||||
|
||||
This is mainly used by short-lived command processes (e.g. Kubernetes CronJob)
|
||||
so counters/histograms are exported before the process exits.
|
||||
"""
|
||||
provider = trace.get_tracer_provider()
|
||||
if hasattr(provider, "force_flush"):
|
||||
try:
|
||||
provider.force_flush()
|
||||
except Exception:
|
||||
logger.exception("otel: failed to flush trace provider")
|
||||
|
||||
metric_provider = metrics.get_meter_provider()
|
||||
if hasattr(metric_provider, "force_flush"):
|
||||
try:
|
||||
metric_provider.force_flush()
|
||||
except Exception:
|
||||
logger.exception("otel: failed to flush metric provider")
|
||||
provider.force_flush()
|
||||
|
||||
|
||||
def is_celery_worker():
|
||||
|
||||
@ -21,6 +21,10 @@ celery_redis = Redis(
|
||||
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
|
||||
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
|
||||
ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
|
||||
# Add conservative socket timeouts and health checks to avoid long-lived half-open sockets
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5,
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -3,6 +3,7 @@ import math
|
||||
import time
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
from celery import group
|
||||
from sqlalchemy import ColumnElement, and_, func, or_, select
|
||||
from sqlalchemy.engine.row import Row
|
||||
from sqlalchemy.orm import Session
|
||||
@ -85,20 +86,25 @@ def trigger_provider_refresh() -> None:
|
||||
lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
|
||||
acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
|
||||
|
||||
enqueued: int = 0
|
||||
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
|
||||
if not is_locked:
|
||||
continue
|
||||
trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
|
||||
enqueued += 1
|
||||
if not any(acquired):
|
||||
continue
|
||||
|
||||
jobs = [
|
||||
trigger_subscription_refresh.s(tenant_id=tenant_id, subscription_id=subscription_id)
|
||||
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired)
|
||||
if is_locked
|
||||
]
|
||||
result = group(jobs).apply_async()
|
||||
enqueued = len(jobs)
|
||||
|
||||
logger.info(
|
||||
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
|
||||
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d result=%s",
|
||||
page + 1,
|
||||
pages,
|
||||
len(subscriptions),
|
||||
sum(1 for x in acquired if x),
|
||||
enqueued,
|
||||
result,
|
||||
)
|
||||
|
||||
logger.info("Trigger refresh scan done: due=%d", total_due)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from celery import group, shared_task
|
||||
from celery import current_app, group, shared_task
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@ -29,31 +29,27 @@ def poll_workflow_schedules() -> None:
|
||||
with session_factory() as session:
|
||||
total_dispatched = 0
|
||||
|
||||
# Process in batches until we've handled all due schedules or hit the limit
|
||||
while True:
|
||||
due_schedules = _fetch_due_schedules(session)
|
||||
|
||||
if not due_schedules:
|
||||
break
|
||||
|
||||
dispatched_count = _process_schedules(session, due_schedules)
|
||||
total_dispatched += dispatched_count
|
||||
with current_app.producer_or_acquire() as producer: # type: ignore
|
||||
dispatched_count = _process_schedules(session, due_schedules, producer)
|
||||
total_dispatched += dispatched_count
|
||||
|
||||
logger.debug("Batch processed: %d dispatched", dispatched_count)
|
||||
|
||||
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
|
||||
if (
|
||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0
|
||||
and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK
|
||||
):
|
||||
logger.warning(
|
||||
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
|
||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
|
||||
)
|
||||
break
|
||||
logger.debug("Batch processed: %d dispatched", dispatched_count)
|
||||
|
||||
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
|
||||
if 0 < dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK <= total_dispatched:
|
||||
logger.warning(
|
||||
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
|
||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
|
||||
)
|
||||
break
|
||||
if total_dispatched > 0:
|
||||
logger.info("Total processed: %d dispatched", total_dispatched)
|
||||
logger.info("Total processed: %d workflow schedule(s) dispatched", total_dispatched)
|
||||
|
||||
|
||||
def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
|
||||
@ -90,7 +86,7 @@ def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
|
||||
return list(due_schedules)
|
||||
|
||||
|
||||
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> int:
|
||||
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan], producer=None) -> int:
|
||||
"""Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
|
||||
if not schedules:
|
||||
return 0
|
||||
@ -107,7 +103,7 @@ def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan])
|
||||
|
||||
if tasks_to_dispatch:
|
||||
job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
|
||||
job.apply_async()
|
||||
job.apply_async(producer=producer)
|
||||
|
||||
logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))
|
||||
|
||||
|
||||
@ -1,16 +1,16 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from typing import cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, select, tuple_
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import (
|
||||
@ -33,128 +33,6 @@ from services.retention.conversation.messages_clean_policy import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.metrics import Counter, Histogram
|
||||
|
||||
|
||||
class MessagesCleanupMetrics:
|
||||
"""
|
||||
Records low-cardinality OpenTelemetry metrics for expired message cleanup jobs.
|
||||
|
||||
We keep labels stable (dry_run/window_mode/task_label/status) so these metrics remain
|
||||
dashboard-friendly for long-running CronJob executions.
|
||||
"""
|
||||
|
||||
_job_runs_total: "Counter | None"
|
||||
_batches_total: "Counter | None"
|
||||
_messages_scanned_total: "Counter | None"
|
||||
_messages_filtered_total: "Counter | None"
|
||||
_messages_deleted_total: "Counter | None"
|
||||
_job_duration_seconds: "Histogram | None"
|
||||
_batch_duration_seconds: "Histogram | None"
|
||||
_base_attributes: dict[str, str]
|
||||
|
||||
def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None:
|
||||
self._job_runs_total = None
|
||||
self._batches_total = None
|
||||
self._messages_scanned_total = None
|
||||
self._messages_filtered_total = None
|
||||
self._messages_deleted_total = None
|
||||
self._job_duration_seconds = None
|
||||
self._batch_duration_seconds = None
|
||||
self._base_attributes = {
|
||||
"job_name": "messages_cleanup",
|
||||
"dry_run": str(dry_run).lower(),
|
||||
"window_mode": "between" if has_window else "before_cutoff",
|
||||
"task_label": task_label,
|
||||
}
|
||||
self._init_instruments()
|
||||
|
||||
def _init_instruments(self) -> None:
|
||||
try:
|
||||
from opentelemetry.metrics import get_meter
|
||||
|
||||
meter = get_meter("messages_cleanup", version=dify_config.project.version)
|
||||
self._job_runs_total = meter.create_counter(
|
||||
"messages_cleanup_jobs_total",
|
||||
description="Total number of expired message cleanup jobs by status.",
|
||||
unit="{job}",
|
||||
)
|
||||
self._batches_total = meter.create_counter(
|
||||
"messages_cleanup_batches_total",
|
||||
description="Total number of message cleanup batches processed.",
|
||||
unit="{batch}",
|
||||
)
|
||||
self._messages_scanned_total = meter.create_counter(
|
||||
"messages_cleanup_scanned_messages_total",
|
||||
description="Total messages scanned by cleanup jobs.",
|
||||
unit="{message}",
|
||||
)
|
||||
self._messages_filtered_total = meter.create_counter(
|
||||
"messages_cleanup_filtered_messages_total",
|
||||
description="Total messages selected by cleanup policy.",
|
||||
unit="{message}",
|
||||
)
|
||||
self._messages_deleted_total = meter.create_counter(
|
||||
"messages_cleanup_deleted_messages_total",
|
||||
description="Total messages deleted by cleanup jobs.",
|
||||
unit="{message}",
|
||||
)
|
||||
self._job_duration_seconds = meter.create_histogram(
|
||||
"messages_cleanup_job_duration_seconds",
|
||||
description="Duration of expired message cleanup jobs in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
self._batch_duration_seconds = meter.create_histogram(
|
||||
"messages_cleanup_batch_duration_seconds",
|
||||
description="Duration of expired message cleanup batch processing in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("messages_cleanup_metrics: failed to initialize instruments")
|
||||
|
||||
def _attrs(self, **extra: str) -> dict[str, str]:
|
||||
return {**self._base_attributes, **extra}
|
||||
|
||||
@staticmethod
|
||||
def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None:
|
||||
if not counter or value <= 0:
|
||||
return
|
||||
try:
|
||||
counter.add(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("messages_cleanup_metrics: failed to add counter value")
|
||||
|
||||
@staticmethod
|
||||
def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None:
|
||||
if not histogram:
|
||||
return
|
||||
try:
|
||||
histogram.record(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("messages_cleanup_metrics: failed to record histogram value")
|
||||
|
||||
def record_batch(
|
||||
self,
|
||||
*,
|
||||
scanned_messages: int,
|
||||
filtered_messages: int,
|
||||
deleted_messages: int,
|
||||
batch_duration_seconds: float,
|
||||
) -> None:
|
||||
attributes = self._attrs()
|
||||
self._add(self._batches_total, 1, attributes)
|
||||
self._add(self._messages_scanned_total, scanned_messages, attributes)
|
||||
self._add(self._messages_filtered_total, filtered_messages, attributes)
|
||||
self._add(self._messages_deleted_total, deleted_messages, attributes)
|
||||
self._record(self._batch_duration_seconds, batch_duration_seconds, attributes)
|
||||
|
||||
def record_completion(self, *, status: str, job_duration_seconds: float) -> None:
|
||||
attributes = self._attrs(status=status)
|
||||
self._add(self._job_runs_total, 1, attributes)
|
||||
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
|
||||
|
||||
|
||||
class MessagesCleanService:
|
||||
"""
|
||||
Service for cleaning expired messages based on retention policies.
|
||||
@ -170,7 +48,6 @@ class MessagesCleanService:
|
||||
start_from: datetime.datetime | None = None,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "daily",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the service with cleanup parameters.
|
||||
@ -181,20 +58,12 @@ class MessagesCleanService:
|
||||
start_from: Optional start time (inclusive) of the range
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
task_label: Stable task label to distinguish multiple cleanup CronJobs
|
||||
"""
|
||||
self._policy = policy
|
||||
self._end_before = end_before
|
||||
self._start_from = start_from
|
||||
self._batch_size = batch_size
|
||||
self._dry_run = dry_run
|
||||
normalized_task_label = task_label.strip()
|
||||
self._task_label = normalized_task_label or "daily"
|
||||
self._metrics = MessagesCleanupMetrics(
|
||||
dry_run=dry_run,
|
||||
has_window=bool(start_from),
|
||||
task_label=self._task_label,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_time_range(
|
||||
@ -204,7 +73,6 @@ class MessagesCleanService:
|
||||
end_before: datetime.datetime,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "daily",
|
||||
) -> "MessagesCleanService":
|
||||
"""
|
||||
Create a service instance for cleaning messages within a specific time range.
|
||||
@ -217,7 +85,6 @@ class MessagesCleanService:
|
||||
end_before: End time (exclusive) of the range
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
task_label: Stable task label to distinguish multiple cleanup CronJobs
|
||||
|
||||
Returns:
|
||||
MessagesCleanService instance
|
||||
@ -245,7 +112,6 @@ class MessagesCleanService:
|
||||
start_from=start_from,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -255,7 +121,6 @@ class MessagesCleanService:
|
||||
days: int = 30,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "daily",
|
||||
) -> "MessagesCleanService":
|
||||
"""
|
||||
Create a service instance for cleaning messages older than specified days.
|
||||
@ -265,7 +130,6 @@ class MessagesCleanService:
|
||||
days: Number of days to look back from now
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
task_label: Stable task label to distinguish multiple cleanup CronJobs
|
||||
|
||||
Returns:
|
||||
MessagesCleanService instance
|
||||
@ -289,14 +153,7 @@ class MessagesCleanService:
|
||||
policy.__class__.__name__,
|
||||
)
|
||||
|
||||
return cls(
|
||||
policy=policy,
|
||||
end_before=end_before,
|
||||
start_from=None,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
|
||||
|
||||
def run(self) -> dict[str, int]:
|
||||
"""
|
||||
@ -305,18 +162,7 @@ class MessagesCleanService:
|
||||
Returns:
|
||||
Dict with statistics: batches, filtered_messages, total_deleted
|
||||
"""
|
||||
status = "success"
|
||||
run_start = time.monotonic()
|
||||
try:
|
||||
return self._clean_messages_by_time_range()
|
||||
except Exception:
|
||||
status = "failed"
|
||||
raise
|
||||
finally:
|
||||
self._metrics.record_completion(
|
||||
status=status,
|
||||
job_duration_seconds=time.monotonic() - run_start,
|
||||
)
|
||||
return self._clean_messages_by_time_range()
|
||||
|
||||
def _clean_messages_by_time_range(self) -> dict[str, int]:
|
||||
"""
|
||||
@ -351,14 +197,11 @@ class MessagesCleanService:
|
||||
self._end_before,
|
||||
)
|
||||
|
||||
max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL
|
||||
max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200))
|
||||
|
||||
while True:
|
||||
stats["batches"] += 1
|
||||
batch_start = time.monotonic()
|
||||
batch_scanned_messages = 0
|
||||
batch_filtered_messages = 0
|
||||
batch_deleted_messages = 0
|
||||
|
||||
# Step 1: Fetch a batch of messages using cursor
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
@ -397,16 +240,9 @@ class MessagesCleanService:
|
||||
|
||||
# Track total messages fetched across all batches
|
||||
stats["total_messages"] += len(messages)
|
||||
batch_scanned_messages = len(messages)
|
||||
|
||||
if not messages:
|
||||
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
break
|
||||
|
||||
# Update cursor to the last message's (created_at, id)
|
||||
@ -432,12 +268,6 @@ class MessagesCleanService:
|
||||
|
||||
if not apps:
|
||||
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
# Build app_id -> tenant_id mapping
|
||||
@ -456,16 +286,9 @@ class MessagesCleanService:
|
||||
|
||||
if not message_ids_to_delete:
|
||||
logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
stats["filtered_messages"] += len(message_ids_to_delete)
|
||||
batch_filtered_messages = len(message_ids_to_delete)
|
||||
|
||||
# Step 4: Batch delete messages and their relations
|
||||
if not self._dry_run:
|
||||
@ -486,7 +309,6 @@ class MessagesCleanService:
|
||||
commit_ms = int((time.monotonic() - commit_start) * 1000)
|
||||
|
||||
stats["total_deleted"] += messages_deleted
|
||||
batch_deleted_messages = messages_deleted
|
||||
|
||||
logger.info(
|
||||
"clean_messages (batch %s): processed %s messages, deleted %s messages",
|
||||
@ -521,13 +343,6 @@ class MessagesCleanService:
|
||||
for msg_id in sampled_ids:
|
||||
logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
|
||||
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
|
||||
stats["batches"],
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import click
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@ -20,156 +20,6 @@ from services.billing_service import BillingService, SubscriptionPlan
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.metrics import Counter, Histogram
|
||||
|
||||
|
||||
class WorkflowRunCleanupMetrics:
|
||||
"""
|
||||
Records low-cardinality OpenTelemetry metrics for workflow run cleanup jobs.
|
||||
|
||||
Metrics are emitted with stable labels only (dry_run/window_mode/task_label/status)
|
||||
to keep dashboard and alert cardinality predictable in production clusters.
|
||||
"""
|
||||
|
||||
_job_runs_total: "Counter | None"
|
||||
_batches_total: "Counter | None"
|
||||
_runs_scanned_total: "Counter | None"
|
||||
_runs_targeted_total: "Counter | None"
|
||||
_runs_deleted_total: "Counter | None"
|
||||
_runs_skipped_total: "Counter | None"
|
||||
_related_records_total: "Counter | None"
|
||||
_job_duration_seconds: "Histogram | None"
|
||||
_batch_duration_seconds: "Histogram | None"
|
||||
_base_attributes: dict[str, str]
|
||||
|
||||
def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None:
|
||||
self._job_runs_total = None
|
||||
self._batches_total = None
|
||||
self._runs_scanned_total = None
|
||||
self._runs_targeted_total = None
|
||||
self._runs_deleted_total = None
|
||||
self._runs_skipped_total = None
|
||||
self._related_records_total = None
|
||||
self._job_duration_seconds = None
|
||||
self._batch_duration_seconds = None
|
||||
self._base_attributes = {
|
||||
"job_name": "workflow_run_cleanup",
|
||||
"dry_run": str(dry_run).lower(),
|
||||
"window_mode": "between" if has_window else "before_cutoff",
|
||||
"task_label": task_label,
|
||||
}
|
||||
self._init_instruments()
|
||||
|
||||
def _init_instruments(self) -> None:
|
||||
try:
|
||||
from opentelemetry.metrics import get_meter
|
||||
|
||||
meter = get_meter("workflow_run_cleanup", version=dify_config.project.version)
|
||||
self._job_runs_total = meter.create_counter(
|
||||
"workflow_run_cleanup_jobs_total",
|
||||
description="Total number of workflow run cleanup jobs by status.",
|
||||
unit="{job}",
|
||||
)
|
||||
self._batches_total = meter.create_counter(
|
||||
"workflow_run_cleanup_batches_total",
|
||||
description="Total number of processed cleanup batches.",
|
||||
unit="{batch}",
|
||||
)
|
||||
self._runs_scanned_total = meter.create_counter(
|
||||
"workflow_run_cleanup_scanned_runs_total",
|
||||
description="Total workflow runs scanned by cleanup jobs.",
|
||||
unit="{run}",
|
||||
)
|
||||
self._runs_targeted_total = meter.create_counter(
|
||||
"workflow_run_cleanup_targeted_runs_total",
|
||||
description="Total workflow runs targeted by cleanup policy.",
|
||||
unit="{run}",
|
||||
)
|
||||
self._runs_deleted_total = meter.create_counter(
|
||||
"workflow_run_cleanup_deleted_runs_total",
|
||||
description="Total workflow runs deleted by cleanup jobs.",
|
||||
unit="{run}",
|
||||
)
|
||||
self._runs_skipped_total = meter.create_counter(
|
||||
"workflow_run_cleanup_skipped_runs_total",
|
||||
description="Total workflow runs skipped because tenant is paid/unknown.",
|
||||
unit="{run}",
|
||||
)
|
||||
self._related_records_total = meter.create_counter(
|
||||
"workflow_run_cleanup_related_records_total",
|
||||
description="Total related records processed by cleanup jobs.",
|
||||
unit="{record}",
|
||||
)
|
||||
self._job_duration_seconds = meter.create_histogram(
|
||||
"workflow_run_cleanup_job_duration_seconds",
|
||||
description="Duration of workflow run cleanup jobs in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
self._batch_duration_seconds = meter.create_histogram(
|
||||
"workflow_run_cleanup_batch_duration_seconds",
|
||||
description="Duration of workflow run cleanup batch processing in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("workflow_run_cleanup_metrics: failed to initialize instruments")
|
||||
|
||||
def _attrs(self, **extra: str) -> dict[str, str]:
|
||||
return {**self._base_attributes, **extra}
|
||||
|
||||
@staticmethod
|
||||
def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None:
|
||||
if not counter or value <= 0:
|
||||
return
|
||||
try:
|
||||
counter.add(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("workflow_run_cleanup_metrics: failed to add counter value")
|
||||
|
||||
@staticmethod
|
||||
def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None:
|
||||
if not histogram:
|
||||
return
|
||||
try:
|
||||
histogram.record(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("workflow_run_cleanup_metrics: failed to record histogram value")
|
||||
|
||||
def record_batch(
|
||||
self,
|
||||
*,
|
||||
batch_rows: int,
|
||||
targeted_runs: int,
|
||||
skipped_runs: int,
|
||||
deleted_runs: int,
|
||||
related_counts: dict[str, int] | None,
|
||||
related_action: str | None,
|
||||
batch_duration_seconds: float,
|
||||
) -> None:
|
||||
attributes = self._attrs()
|
||||
self._add(self._batches_total, 1, attributes)
|
||||
self._add(self._runs_scanned_total, batch_rows, attributes)
|
||||
self._add(self._runs_targeted_total, targeted_runs, attributes)
|
||||
self._add(self._runs_skipped_total, skipped_runs, attributes)
|
||||
self._add(self._runs_deleted_total, deleted_runs, attributes)
|
||||
self._record(self._batch_duration_seconds, batch_duration_seconds, attributes)
|
||||
|
||||
if not related_counts or not related_action:
|
||||
return
|
||||
|
||||
for record_type, count in related_counts.items():
|
||||
self._add(
|
||||
self._related_records_total,
|
||||
count,
|
||||
self._attrs(action=related_action, record_type=record_type),
|
||||
)
|
||||
|
||||
def record_completion(self, *, status: str, job_duration_seconds: float) -> None:
|
||||
attributes = self._attrs(status=status)
|
||||
self._add(self._job_runs_total, 1, attributes)
|
||||
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
|
||||
|
||||
|
||||
class WorkflowRunCleanup:
|
||||
def __init__(
|
||||
self,
|
||||
@ -179,7 +29,6 @@ class WorkflowRunCleanup:
|
||||
end_before: datetime.datetime | None = None,
|
||||
workflow_run_repo: APIWorkflowRunRepository | None = None,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "daily",
|
||||
):
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise ValueError("start_from and end_before must be both set or both omitted.")
|
||||
@ -197,13 +46,6 @@ class WorkflowRunCleanup:
|
||||
self.batch_size = batch_size
|
||||
self._cleanup_whitelist: set[str] | None = None
|
||||
self.dry_run = dry_run
|
||||
normalized_task_label = task_label.strip()
|
||||
self.task_label = normalized_task_label or "daily"
|
||||
self._metrics = WorkflowRunCleanupMetrics(
|
||||
dry_run=dry_run,
|
||||
has_window=bool(start_from),
|
||||
task_label=self.task_label,
|
||||
)
|
||||
self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD
|
||||
self.workflow_run_repo: APIWorkflowRunRepository
|
||||
if workflow_run_repo:
|
||||
@ -232,193 +74,153 @@ class WorkflowRunCleanup:
|
||||
related_totals = self._empty_related_counts() if self.dry_run else None
|
||||
batch_index = 0
|
||||
last_seen: tuple[datetime.datetime, str] | None = None
|
||||
status = "success"
|
||||
run_start = time.monotonic()
|
||||
max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL
|
||||
|
||||
try:
|
||||
while True:
|
||||
batch_start = time.monotonic()
|
||||
max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200))
|
||||
|
||||
fetch_start = time.monotonic()
|
||||
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
|
||||
start_from=self.window_start,
|
||||
end_before=self.window_end,
|
||||
last_seen=last_seen,
|
||||
batch_size=self.batch_size,
|
||||
while True:
|
||||
batch_start = time.monotonic()
|
||||
|
||||
fetch_start = time.monotonic()
|
||||
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
|
||||
start_from=self.window_start,
|
||||
end_before=self.window_end,
|
||||
last_seen=last_seen,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
if not run_rows:
|
||||
logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1)
|
||||
break
|
||||
|
||||
batch_index += 1
|
||||
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): fetched %s rows in %sms",
|
||||
batch_index,
|
||||
len(run_rows),
|
||||
int((time.monotonic() - fetch_start) * 1000),
|
||||
)
|
||||
|
||||
tenant_ids = {row.tenant_id for row in run_rows}
|
||||
|
||||
filter_start = time.monotonic()
|
||||
free_tenants = self._filter_free_tenants(tenant_ids)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms",
|
||||
batch_index,
|
||||
len(free_tenants),
|
||||
len(tenant_ids),
|
||||
int((time.monotonic() - filter_start) * 1000),
|
||||
)
|
||||
|
||||
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
|
||||
paid_or_skipped = len(run_rows) - len(free_runs)
|
||||
|
||||
if not free_runs:
|
||||
skipped_message = (
|
||||
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
|
||||
)
|
||||
if not run_rows:
|
||||
logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1)
|
||||
break
|
||||
|
||||
batch_index += 1
|
||||
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): fetched %s rows in %sms",
|
||||
batch_index,
|
||||
len(run_rows),
|
||||
int((time.monotonic() - fetch_start) * 1000),
|
||||
)
|
||||
|
||||
tenant_ids = {row.tenant_id for row in run_rows}
|
||||
|
||||
filter_start = time.monotonic()
|
||||
free_tenants = self._filter_free_tenants(tenant_ids)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms",
|
||||
batch_index,
|
||||
len(free_tenants),
|
||||
len(tenant_ids),
|
||||
int((time.monotonic() - filter_start) * 1000),
|
||||
)
|
||||
|
||||
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
|
||||
paid_or_skipped = len(run_rows) - len(free_runs)
|
||||
|
||||
if not free_runs:
|
||||
skipped_message = (
|
||||
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
skipped_message,
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
self._metrics.record_batch(
|
||||
batch_rows=len(run_rows),
|
||||
targeted_runs=0,
|
||||
skipped_runs=paid_or_skipped,
|
||||
deleted_runs=0,
|
||||
related_counts=None,
|
||||
related_action=None,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
total_runs_targeted += len(free_runs)
|
||||
|
||||
if self.dry_run:
|
||||
count_start = time.monotonic()
|
||||
batch_counts = self.workflow_run_repo.count_runs_with_related(
|
||||
free_runs,
|
||||
count_node_executions=self._count_node_executions,
|
||||
count_trigger_logs=self._count_trigger_logs,
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms",
|
||||
batch_index,
|
||||
int((time.monotonic() - count_start) * 1000),
|
||||
)
|
||||
if related_totals is not None:
|
||||
for key in related_totals:
|
||||
related_totals[key] += batch_counts.get(key, 0)
|
||||
sample_ids = ", ".join(run.id for run in free_runs[:5])
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
|
||||
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): batch total %sms",
|
||||
batch_index,
|
||||
int((time.monotonic() - batch_start) * 1000),
|
||||
)
|
||||
self._metrics.record_batch(
|
||||
batch_rows=len(run_rows),
|
||||
targeted_runs=len(free_runs),
|
||||
skipped_runs=paid_or_skipped,
|
||||
deleted_runs=0,
|
||||
related_counts={key: batch_counts.get(key, 0) for key in self._empty_related_counts()},
|
||||
related_action="would_delete",
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
delete_start = time.monotonic()
|
||||
counts = self.workflow_run_repo.delete_runs_with_related(
|
||||
free_runs,
|
||||
delete_node_executions=self._delete_node_executions,
|
||||
delete_trigger_logs=self._delete_trigger_logs,
|
||||
)
|
||||
delete_ms = int((time.monotonic() - delete_start) * 1000)
|
||||
except Exception:
|
||||
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
|
||||
raise
|
||||
|
||||
total_runs_deleted += counts["runs"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
|
||||
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
|
||||
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
|
||||
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
|
||||
f"skipped {paid_or_skipped} paid/unknown",
|
||||
fg="green",
|
||||
skipped_message,
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
total_runs_targeted += len(free_runs)
|
||||
|
||||
if self.dry_run:
|
||||
count_start = time.monotonic()
|
||||
batch_counts = self.workflow_run_repo.count_runs_with_related(
|
||||
free_runs,
|
||||
count_node_executions=self._count_node_executions,
|
||||
count_trigger_logs=self._count_trigger_logs,
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms",
|
||||
batch_index,
|
||||
int((time.monotonic() - count_start) * 1000),
|
||||
)
|
||||
if related_totals is not None:
|
||||
for key in related_totals:
|
||||
related_totals[key] += batch_counts.get(key, 0)
|
||||
sample_ids = ", ".join(run.id for run in free_runs[:5])
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
|
||||
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): delete %sms, batch total %sms",
|
||||
"workflow_run_cleanup (batch #%s, dry_run): batch total %sms",
|
||||
batch_index,
|
||||
delete_ms,
|
||||
int((time.monotonic() - batch_start) * 1000),
|
||||
)
|
||||
self._metrics.record_batch(
|
||||
batch_rows=len(run_rows),
|
||||
targeted_runs=len(free_runs),
|
||||
skipped_runs=paid_or_skipped,
|
||||
deleted_runs=counts["runs"],
|
||||
related_counts={key: counts.get(key, 0) for key in self._empty_related_counts()},
|
||||
related_action="deleted",
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
continue
|
||||
|
||||
try:
|
||||
delete_start = time.monotonic()
|
||||
counts = self.workflow_run_repo.delete_runs_with_related(
|
||||
free_runs,
|
||||
delete_node_executions=self._delete_node_executions,
|
||||
delete_trigger_logs=self._delete_trigger_logs,
|
||||
)
|
||||
delete_ms = int((time.monotonic() - delete_start) * 1000)
|
||||
except Exception:
|
||||
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
|
||||
raise
|
||||
|
||||
# Random sleep between batches to avoid overwhelming the database
|
||||
sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311
|
||||
logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms)
|
||||
time.sleep(sleep_ms / 1000)
|
||||
|
||||
if self.dry_run:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
if related_totals is not None:
|
||||
summary_message = (
|
||||
f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
|
||||
)
|
||||
summary_color = "yellow"
|
||||
else:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
summary_color = "white"
|
||||
|
||||
click.echo(click.style(summary_message, fg=summary_color))
|
||||
except Exception:
|
||||
status = "failed"
|
||||
raise
|
||||
finally:
|
||||
self._metrics.record_completion(
|
||||
status=status,
|
||||
job_duration_seconds=time.monotonic() - run_start,
|
||||
total_runs_deleted += counts["runs"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
|
||||
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
|
||||
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
|
||||
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
|
||||
f"skipped {paid_or_skipped} paid/unknown",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): delete %sms, batch total %sms",
|
||||
batch_index,
|
||||
delete_ms,
|
||||
int((time.monotonic() - batch_start) * 1000),
|
||||
)
|
||||
|
||||
# Random sleep between batches to avoid overwhelming the database
|
||||
sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311
|
||||
logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms)
|
||||
time.sleep(sleep_ms / 1000)
|
||||
|
||||
if self.dry_run:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
if related_totals is not None:
|
||||
summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
|
||||
summary_color = "yellow"
|
||||
else:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}"
|
||||
)
|
||||
summary_color = "white"
|
||||
|
||||
click.echo(click.style(summary_message, fg=summary_color))
|
||||
|
||||
def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
|
||||
tenant_id_list = list(tenant_ids)
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from celery import current_app, shared_task
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
@ -19,6 +20,12 @@ from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CeleryTaskLike(Protocol):
|
||||
def delay(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def apply_async(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def document_indexing_task(dataset_id: str, document_ids: list):
|
||||
"""
|
||||
@ -179,8 +186,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
|
||||
|
||||
def _document_indexing_with_tenant_queue(
|
||||
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
|
||||
):
|
||||
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: CeleryTaskLike
|
||||
) -> None:
|
||||
try:
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
except Exception:
|
||||
@ -201,16 +208,20 @@ def _document_indexing_with_tenant_queue(
|
||||
logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
|
||||
|
||||
if next_tasks:
|
||||
for next_task in next_tasks:
|
||||
document_task = DocumentTask(**next_task)
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.delay( # type: ignore
|
||||
tenant_id=document_task.tenant_id,
|
||||
dataset_id=document_task.dataset_id,
|
||||
document_ids=document_task.document_ids,
|
||||
)
|
||||
with current_app.producer_or_acquire() as producer: # type: ignore
|
||||
for next_task in next_tasks:
|
||||
document_task = DocumentTask(**next_task)
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.apply_async(
|
||||
kwargs={
|
||||
"tenant_id": document_task.tenant_id,
|
||||
"dataset_id": document_task.dataset_id,
|
||||
"document_ids": document_task.document_ids,
|
||||
},
|
||||
producer=producer,
|
||||
)
|
||||
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
|
||||
@ -3,12 +3,13 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from celery import group, shared_task
|
||||
from flask import current_app, g
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@ -27,6 +28,11 @@ from services.file_service import FileService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def chunked(iterable: Sequence, size: int):
|
||||
it = iter(iterable)
|
||||
return iter(lambda: list(islice(it, size)), [])
|
||||
|
||||
|
||||
@shared_task(queue="pipeline")
|
||||
def rag_pipeline_run_task(
|
||||
rag_pipeline_invoke_entities_file_id: str,
|
||||
@ -83,16 +89,24 @@ def rag_pipeline_run_task(
|
||||
logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
|
||||
|
||||
if next_file_ids:
|
||||
for next_file_id in next_file_ids:
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
|
||||
if isinstance(next_file_id, bytes)
|
||||
else next_file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for batch in chunked(next_file_ids, 100):
|
||||
jobs = []
|
||||
for next_file_id in batch:
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
|
||||
file_id = (
|
||||
next_file_id.decode("utf-8") if isinstance(next_file_id, (bytes, bytearray)) else next_file_id
|
||||
)
|
||||
|
||||
jobs.append(
|
||||
rag_pipeline_run_task.s(
|
||||
rag_pipeline_invoke_entities_file_id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if jobs:
|
||||
group(jobs).apply_async()
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
|
||||
@ -322,11 +322,14 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
task_dispatch_spy.delay.assert_called_once_with(
|
||||
tenant_id=next_task["tenant_id"],
|
||||
dataset_id=next_task["dataset_id"],
|
||||
document_ids=next_task["document_ids"],
|
||||
)
|
||||
# apply_async is used by implementation; assert it was called once with expected kwargs
|
||||
assert task_dispatch_spy.apply_async.call_count == 1
|
||||
call_kwargs = task_dispatch_spy.apply_async.call_args.kwargs.get("kwargs", {})
|
||||
assert call_kwargs == {
|
||||
"tenant_id": next_task["tenant_id"],
|
||||
"dataset_id": next_task["dataset_id"],
|
||||
"document_ids": next_task["document_ids"],
|
||||
}
|
||||
set_waiting_spy.assert_called_once()
|
||||
delete_key_spy.assert_not_called()
|
||||
|
||||
@ -352,7 +355,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
task_dispatch_spy.delay.assert_not_called()
|
||||
task_dispatch_spy.apply_async.assert_not_called()
|
||||
delete_key_spy.assert_called_once()
|
||||
|
||||
def test_validation_failure_sets_error_status_when_vector_space_at_limit(
|
||||
@ -447,7 +450,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
task_dispatch_spy.delay.assert_called_once()
|
||||
task_dispatch_spy.apply_async.assert_called_once()
|
||||
|
||||
def test_sessions_close_on_successful_indexing(
|
||||
self,
|
||||
@ -534,7 +537,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
assert task_dispatch_spy.delay.call_count == concurrency_limit
|
||||
assert task_dispatch_spy.apply_async.call_count == concurrency_limit
|
||||
assert set_waiting_spy.call_count == concurrency_limit
|
||||
|
||||
def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies):
|
||||
@ -565,9 +568,10 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
assert task_dispatch_spy.delay.call_count == 3
|
||||
assert task_dispatch_spy.apply_async.call_count == 3
|
||||
for index, expected_task in enumerate(ordered_tasks):
|
||||
assert task_dispatch_spy.delay.call_args_list[index].kwargs["document_ids"] == expected_task["document_ids"]
|
||||
call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {})
|
||||
assert call_kwargs.get("document_ids") == expected_task["document_ids"]
|
||||
|
||||
def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies):
|
||||
"""Skip limit checks when billing feature is disabled."""
|
||||
|
||||
@ -762,11 +762,12 @@ class TestDocumentIndexingTasks:
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify task function was called for each waiting task
|
||||
assert mock_task_func.delay.call_count == 1
|
||||
assert mock_task_func.apply_async.call_count == 1
|
||||
|
||||
# Verify correct parameters for each call
|
||||
calls = mock_task_func.delay.call_args_list
|
||||
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
calls = mock_task_func.apply_async.call_args_list
|
||||
sent_kwargs = calls[0][1]["kwargs"]
|
||||
assert sent_kwargs == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
|
||||
# Verify queue is empty after processing (tasks were pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
|
||||
@ -830,11 +831,15 @@ class TestDocumentIndexingTasks:
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_task_func.delay.assert_called_once()
|
||||
mock_task_func.apply_async.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
call = mock_task_func.apply_async.call_args
|
||||
assert call[1]["kwargs"] == {
|
||||
"tenant_id": tenant_id,
|
||||
"dataset_id": dataset_id,
|
||||
"document_ids": ["waiting-doc-1"],
|
||||
}
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
@ -896,9 +901,13 @@ class TestDocumentIndexingTasks:
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_task_func.delay.assert_called_once()
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
|
||||
mock_task_func.apply_async.assert_called_once()
|
||||
call = mock_task_func.apply_async.call_args
|
||||
assert call[1]["kwargs"] == {
|
||||
"tenant_id": tenant1_id,
|
||||
"dataset_id": dataset1_id,
|
||||
"document_ids": ["tenant1-doc-1"],
|
||||
}
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@ -388,8 +388,10 @@ class TestRagPipelineRunTasks:
|
||||
# Set the task key to indicate there are waiting tasks (legacy behavior)
|
||||
redis_client.set(legacy_task_key, 1, ex=60 * 60)
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act: Execute the priority task with new code but legacy queue data
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
@ -398,13 +400,14 @@ class TestRagPipelineRunTasks:
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
# Verify waiting tasks were processed via group, pull 1 task a time by default
|
||||
assert mock_group.return_value.apply_async.called
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify that new code can process legacy queue entries
|
||||
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
|
||||
@ -446,8 +449,10 @@ class TestRagPipelineRunTasks:
|
||||
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
queue.push_tasks(waiting_file_ids)
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act: Execute the regular task
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
@ -456,13 +461,14 @@ class TestRagPipelineRunTasks:
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
# Verify waiting tasks were processed via group.apply_async
|
||||
assert mock_group.return_value.apply_async.called
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue still has remaining tasks (only 1 was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
@ -557,8 +563,10 @@ class TestRagPipelineRunTasks:
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act: Execute the regular task (should not raise exception)
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
@ -569,12 +577,13 @@ class TestRagPipelineRunTasks:
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_delay.assert_called_once()
|
||||
assert mock_group.return_value.apply_async.called
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
@ -684,8 +693,10 @@ class TestRagPipelineRunTasks:
|
||||
queue1.push_tasks([waiting_file_id1])
|
||||
queue2.push_tasks([waiting_file_id2])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act: Execute the regular task for tenant1 only
|
||||
rag_pipeline_run_task(file_id1, tenant1.id)
|
||||
|
||||
@ -694,11 +705,12 @@ class TestRagPipelineRunTasks:
|
||||
assert mock_file_service["delete_file"].call_count == 1
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_delay.assert_called_once()
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||
assert call_kwargs.get("tenant_id") == tenant1.id
|
||||
# Verify only tenant1's waiting task was processed (via group)
|
||||
assert mock_group.return_value.apply_async.called
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||
assert first_kwargs.get("tenant_id") == tenant1.id
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
@ -913,8 +925,10 @@ class TestRagPipelineRunTasks:
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act & Assert: Execute the regular task (should raise Exception)
|
||||
with pytest.raises(Exception, match="File not found"):
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
@ -924,12 +938,13 @@ class TestRagPipelineRunTasks:
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
# Verify waiting task was still processed despite file error
|
||||
mock_delay.assert_called_once()
|
||||
assert mock_group.return_value.apply_async.called
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
|
||||
@ -105,18 +105,26 @@ def app_model(
|
||||
|
||||
|
||||
class MockCeleryGroup:
|
||||
"""Mock for celery group() function that collects dispatched tasks."""
|
||||
"""Mock for celery group() function that collects dispatched tasks.
|
||||
|
||||
Matches the Celery group API loosely, accepting arbitrary kwargs on apply_async
|
||||
(e.g. producer) so production code can pass broker-related options without
|
||||
breaking tests.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.collected: list[dict[str, Any]] = []
|
||||
self._applied = False
|
||||
self.last_apply_async_kwargs: dict[str, Any] | None = None
|
||||
|
||||
def __call__(self, items: Any) -> MockCeleryGroup:
|
||||
self.collected = list(items)
|
||||
return self
|
||||
|
||||
def apply_async(self) -> None:
|
||||
def apply_async(self, **kwargs: Any) -> None:
|
||||
# Accept arbitrary kwargs like producer to be compatible with Celery
|
||||
self._applied = True
|
||||
self.last_apply_async_kwargs = kwargs
|
||||
|
||||
@property
|
||||
def applied(self) -> bool:
|
||||
|
||||
@ -38,7 +38,6 @@ def test_absolute_mode_calls_from_time_range():
|
||||
from_days_ago=None,
|
||||
before_days=None,
|
||||
dry_run=True,
|
||||
task_label="daily",
|
||||
)
|
||||
|
||||
mock_from_time_range.assert_called_once_with(
|
||||
@ -47,7 +46,6 @@ def test_absolute_mode_calls_from_time_range():
|
||||
end_before=end_before,
|
||||
batch_size=200,
|
||||
dry_run=True,
|
||||
task_label="daily",
|
||||
)
|
||||
mock_from_days.assert_not_called()
|
||||
|
||||
@ -69,7 +67,6 @@ def test_relative_mode_before_days_only_calls_from_days():
|
||||
from_days_ago=None,
|
||||
before_days=30,
|
||||
dry_run=False,
|
||||
task_label="daily",
|
||||
)
|
||||
|
||||
mock_from_days.assert_called_once_with(
|
||||
@ -77,7 +74,6 @@ def test_relative_mode_before_days_only_calls_from_days():
|
||||
days=30,
|
||||
batch_size=500,
|
||||
dry_run=False,
|
||||
task_label="daily",
|
||||
)
|
||||
mock_from_time_range.assert_not_called()
|
||||
|
||||
@ -101,7 +97,6 @@ def test_relative_mode_with_from_days_ago_calls_from_time_range():
|
||||
from_days_ago=60,
|
||||
before_days=30,
|
||||
dry_run=False,
|
||||
task_label="daily",
|
||||
)
|
||||
|
||||
mock_from_time_range.assert_called_once_with(
|
||||
@ -110,7 +105,6 @@ def test_relative_mode_with_from_days_ago_calls_from_time_range():
|
||||
end_before=fixed_now - datetime.timedelta(days=30),
|
||||
batch_size=1000,
|
||||
dry_run=False,
|
||||
task_label="daily",
|
||||
)
|
||||
mock_from_days.assert_not_called()
|
||||
|
||||
@ -184,5 +178,4 @@ def test_invalid_inputs_raise_usage_error(kwargs: dict, message: str):
|
||||
from_days_ago=kwargs["from_days_ago"],
|
||||
before_days=kwargs["before_days"],
|
||||
dry_run=False,
|
||||
task_label="daily",
|
||||
)
|
||||
|
||||
0
api/tests/unit_tests/controllers/web/__init__.py
Normal file
0
api/tests/unit_tests/controllers/web/__init__.py
Normal file
85
api/tests/unit_tests/controllers/web/conftest.py
Normal file
85
api/tests/unit_tests/controllers/web/conftest.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""Shared fixtures for controllers.web unit tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Minimal Flask app for request contexts."""
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
class FakeSession:
|
||||
"""Stand-in for db.session that returns pre-seeded objects by model class name."""
|
||||
|
||||
def __init__(self, mapping: dict[str, Any] | None = None):
|
||||
self._mapping: dict[str, Any] = mapping or {}
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model: type) -> FakeSession:
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
|
||||
def where(self, *_args: object, **_kwargs: object) -> FakeSession:
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
|
||||
|
||||
class FakeDB:
|
||||
"""Minimal db stub exposing engine and session."""
|
||||
|
||||
def __init__(self, session: FakeSession | None = None):
|
||||
self.session = session or FakeSession()
|
||||
self.engine = object()
|
||||
|
||||
|
||||
def make_app_model(
|
||||
*,
|
||||
app_id: str = "app-1",
|
||||
tenant_id: str = "tenant-1",
|
||||
mode: str = "chat",
|
||||
enable_site: bool = True,
|
||||
status: str = "normal",
|
||||
) -> SimpleNamespace:
|
||||
"""Build a fake App model with common defaults."""
|
||||
tenant = SimpleNamespace(
|
||||
id=tenant_id,
|
||||
status="normal",
|
||||
plan="basic",
|
||||
custom_config_dict={},
|
||||
)
|
||||
return SimpleNamespace(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
tenant=tenant,
|
||||
mode=mode,
|
||||
enable_site=enable_site,
|
||||
status=status,
|
||||
workflow=None,
|
||||
app_model_config=None,
|
||||
)
|
||||
|
||||
|
||||
def make_end_user(
|
||||
*,
|
||||
user_id: str = "end-user-1",
|
||||
session_id: str = "session-1",
|
||||
external_user_id: str = "ext-user-1",
|
||||
) -> SimpleNamespace:
|
||||
"""Build a fake EndUser model with common defaults."""
|
||||
return SimpleNamespace(
|
||||
id=user_id,
|
||||
session_id=session_id,
|
||||
external_user_id=external_user_id,
|
||||
)
|
||||
165
api/tests/unit_tests/controllers/web/test_app.py
Normal file
165
api/tests/unit_tests/controllers/web/test_app.py
Normal file
@ -0,0 +1,165 @@
|
||||
"""Unit tests for controllers.web.app endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.app import AppAccessMode, AppMeta, AppParameterApi, AppWebAuthPermission
|
||||
from controllers.web.error import AppUnavailableError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppParameterApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppParameterApi:
|
||||
def test_advanced_chat_mode_uses_workflow(self, app: Flask) -> None:
|
||||
features_dict = {"opening_statement": "Hello"}
|
||||
workflow = SimpleNamespace(
|
||||
features_dict=features_dict,
|
||||
user_input_form=lambda to_old_structure=False: [],
|
||||
)
|
||||
app_model = SimpleNamespace(mode="advanced-chat", workflow=workflow)
|
||||
|
||||
with (
|
||||
app.test_request_context("/parameters"),
|
||||
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
|
||||
patch("controllers.web.app.fields.Parameters") as mock_fields,
|
||||
):
|
||||
mock_fields.model_validate.return_value.model_dump.return_value = {"result": "ok"}
|
||||
result = AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[])
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
def test_workflow_mode_uses_workflow(self, app: Flask) -> None:
|
||||
features_dict = {}
|
||||
workflow = SimpleNamespace(
|
||||
features_dict=features_dict,
|
||||
user_input_form=lambda to_old_structure=False: [{"var": "x"}],
|
||||
)
|
||||
app_model = SimpleNamespace(mode="workflow", workflow=workflow)
|
||||
|
||||
with (
|
||||
app.test_request_context("/parameters"),
|
||||
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
|
||||
patch("controllers.web.app.fields.Parameters") as mock_fields,
|
||||
):
|
||||
mock_fields.model_validate.return_value.model_dump.return_value = {}
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[{"var": "x"}])
|
||||
|
||||
def test_advanced_chat_mode_no_workflow_raises(self, app: Flask) -> None:
|
||||
app_model = SimpleNamespace(mode="advanced-chat", workflow=None)
|
||||
with app.test_request_context("/parameters"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
def test_standard_mode_uses_app_model_config(self, app: Flask) -> None:
|
||||
config = SimpleNamespace(to_dict=lambda: {"user_input_form": [{"var": "y"}], "key": "val"})
|
||||
app_model = SimpleNamespace(mode="chat", app_model_config=config)
|
||||
|
||||
with (
|
||||
app.test_request_context("/parameters"),
|
||||
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
|
||||
patch("controllers.web.app.fields.Parameters") as mock_fields,
|
||||
):
|
||||
mock_fields.model_validate.return_value.model_dump.return_value = {}
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
call_kwargs = mock_params.call_args
|
||||
assert call_kwargs.kwargs["user_input_form"] == [{"var": "y"}]
|
||||
|
||||
def test_standard_mode_no_config_raises(self, app: Flask) -> None:
|
||||
app_model = SimpleNamespace(mode="chat", app_model_config=None)
|
||||
with app.test_request_context("/parameters"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppMeta
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppMeta:
|
||||
@patch("controllers.web.app.AppService")
|
||||
def test_get_returns_meta(self, mock_service_cls: MagicMock, app: Flask) -> None:
|
||||
mock_service_cls.return_value.get_app_meta.return_value = {"tool_icons": {}}
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context("/meta"):
|
||||
result = AppMeta().get(app_model, SimpleNamespace())
|
||||
|
||||
assert result == {"tool_icons": {}}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppAccessMode
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppAccessMode:
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_returns_public_when_webapp_auth_disabled(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
with app.test_request_context("/webapp/access-mode?appId=app-1"):
|
||||
result = AppAccessMode().get()
|
||||
|
||||
assert result == {"accessMode": "public"}
|
||||
|
||||
@patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_returns_access_mode_with_app_id(
|
||||
self, mock_features: MagicMock, mock_access: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
mock_access.return_value = SimpleNamespace(access_mode="internal")
|
||||
|
||||
with app.test_request_context("/webapp/access-mode?appId=app-1"):
|
||||
result = AppAccessMode().get()
|
||||
|
||||
assert result == {"accessMode": "internal"}
|
||||
mock_access.assert_called_once_with("app-1")
|
||||
|
||||
@patch("controllers.web.app.AppService.get_app_id_by_code", return_value="resolved-id")
|
||||
@patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_resolves_app_code_to_id(
|
||||
self, mock_features: MagicMock, mock_access: MagicMock, mock_resolve: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
mock_access.return_value = SimpleNamespace(access_mode="external")
|
||||
|
||||
with app.test_request_context("/webapp/access-mode?appCode=code1"):
|
||||
result = AppAccessMode().get()
|
||||
|
||||
mock_resolve.assert_called_once_with("code1")
|
||||
mock_access.assert_called_once_with("resolved-id")
|
||||
assert result == {"accessMode": "external"}
|
||||
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_raises_when_no_app_id_or_code(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
|
||||
with app.test_request_context("/webapp/access-mode"):
|
||||
with pytest.raises(ValueError, match="appId or appCode"):
|
||||
AppAccessMode().get()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppWebAuthPermission
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppWebAuthPermission:
|
||||
@patch("controllers.web.app.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_returns_true_when_no_permission_check_required(self, mock_check: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/webapp/permission?appId=app-1", headers={"X-App-Code": "code1"}):
|
||||
result = AppWebAuthPermission().get()
|
||||
|
||||
assert result == {"result": True}
|
||||
|
||||
def test_raises_when_missing_app_id(self, app: Flask) -> None:
|
||||
with app.test_request_context("/webapp/permission", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(ValueError, match="appId"):
|
||||
AppWebAuthPermission().get()
|
||||
135
api/tests/unit_tests/controllers/web/test_audio.py
Normal file
135
api/tests/unit_tests/controllers/web/test_audio.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""Unit tests for controllers.web.audio endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.audio import AudioApi, TextApi
|
||||
from controllers.web.error import (
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
def _app_model() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1", external_user_id="ext-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AudioApi (audio-to-text)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAudioApi:
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", return_value={"text": "hello"})
|
||||
def test_happy_path(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
data = {"file": (BytesIO(b"fake-audio"), "test.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
result = AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
assert result == {"text": "hello"}
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=NoAudioUploadedServiceError())
|
||||
def test_no_audio_uploaded(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b""), "empty.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(NoAudioUploadedError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=AudioTooLargeServiceError("too big"))
|
||||
def test_audio_too_large(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"big"), "big.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=UnsupportedAudioTypeServiceError())
|
||||
def test_unsupported_type(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"bad"), "bad.xyz")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(UnsupportedAudioTypeError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.audio.AudioService.transcript_asr",
|
||||
side_effect=ProviderNotSupportSpeechToTextServiceError(),
|
||||
)
|
||||
def test_provider_not_support(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderNotSupportSpeechToTextError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.audio.AudioService.transcript_asr",
|
||||
side_effect=ProviderTokenNotInitError(description="no token"),
|
||||
)
|
||||
def test_provider_not_init(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=QuotaExceededError())
|
||||
def test_quota_exceeded(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=ModelCurrentlyNotSupportError())
|
||||
def test_model_not_support(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TextApi (text-to-audio)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestTextApi:
|
||||
@patch("controllers.web.audio.AudioService.transcript_tts", return_value="audio-bytes")
|
||||
@patch("controllers.web.audio.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"text": "hello", "voice": "alloy"}
|
||||
|
||||
with app.test_request_context("/text-to-audio", method="POST"):
|
||||
result = TextApi().post(_app_model(), _end_user())
|
||||
|
||||
assert result == "audio-bytes"
|
||||
mock_tts.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"controllers.web.audio.AudioService.transcript_tts",
|
||||
side_effect=InvokeError(description="invoke failed"),
|
||||
)
|
||||
@patch("controllers.web.audio.web_ns")
|
||||
def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"text": "hello"}
|
||||
|
||||
with app.test_request_context("/text-to-audio", method="POST"):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
TextApi().post(_app_model(), _end_user())
|
||||
161
api/tests/unit_tests/controllers/web/test_completion.py
Normal file
161
api/tests/unit_tests/controllers/web/test_completion.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""Unit tests for controllers.web.completion endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
|
||||
from controllers.web.error import (
|
||||
CompletionRequestError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompletionApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestCompletionApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
CompletionApi().post(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "hi"})
|
||||
@patch("controllers.web.completion.AppGenerateService.generate")
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}, "query": "test"}
|
||||
mock_gen.return_value = "response-obj"
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
result = CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
assert result == {"answer": "hi"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=ProviderTokenNotInitError(description="not init"),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_provider_not_init_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=QuotaExceededError(),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_quota_exceeded_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_model_not_support_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompletionStopApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestCompletionStopApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
CompletionStopApi().post(_chat_app(), _end_user(), "task-1")
|
||||
|
||||
@patch("controllers.web.completion.AppTaskService.stop_task")
|
||||
def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
|
||||
result, status = CompletionStopApi().post(_completion_app(), _end_user(), "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert result == {"result": "success"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChatApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestChatApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/chat-messages", method="POST"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ChatApi().post(_completion_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "reply"})
|
||||
@patch("controllers.web.completion.AppGenerateService.generate")
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}, "query": "hi"}
|
||||
mock_gen.return_value = "response"
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST"):
|
||||
result = ChatApi().post(_chat_app(), _end_user())
|
||||
|
||||
assert result == {"answer": "reply"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=InvokeError(description="rate limit"),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}, "query": "x"}
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST"):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
ChatApi().post(_chat_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChatStopApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestChatStopApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ChatStopApi().post(_completion_app(), _end_user(), "task-1")
|
||||
|
||||
@patch("controllers.web.completion.AppTaskService.stop_task")
|
||||
def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
|
||||
result, status = ChatStopApi().post(_chat_app(), _end_user(), "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert result == {"result": "success"}
|
||||
183
api/tests/unit_tests/controllers/web/test_conversation.py
Normal file
183
api/tests/unit_tests/controllers/web/test_conversation.py
Normal file
@ -0,0 +1,183 @@
|
||||
"""Unit tests for controllers.web.conversation endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.conversation import (
|
||||
ConversationApi,
|
||||
ConversationListApi,
|
||||
ConversationPinApi,
|
||||
ConversationRenameApi,
|
||||
ConversationUnPinApi,
|
||||
)
|
||||
from controllers.web.error import NotChatAppError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationListApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationListApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/conversations"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationListApi().get(_completion_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
|
||||
@patch("controllers.web.conversation.db")
|
||||
def test_happy_path(self, mock_db: MagicMock, mock_paginate: MagicMock, app: Flask) -> None:
|
||||
conv_id = str(uuid4())
|
||||
conv = SimpleNamespace(
|
||||
id=conv_id,
|
||||
name="Test",
|
||||
inputs={},
|
||||
status="normal",
|
||||
introduction="",
|
||||
created_at=1700000000,
|
||||
updated_at=1700000000,
|
||||
)
|
||||
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[conv])
|
||||
mock_db.engine = "engine"
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/conversations?limit=20"),
|
||||
patch("controllers.web.conversation.Session", return_value=session_ctx),
|
||||
):
|
||||
result = ConversationListApi().get(_chat_app(), _end_user())
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationApi (delete)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationApi:
    """Tests for the DELETE /conversations/<id> endpoint."""

    def test_non_chat_mode_raises(self, app: Flask) -> None:
        """Completion apps have no conversations, so deletion is rejected."""
        with app.test_request_context(f"/conversations/{uuid4()}"), pytest.raises(NotChatAppError):
            ConversationApi().delete(_completion_app(), _end_user(), uuid4())

    @patch("controllers.web.conversation.ConversationService.delete")
    def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
        """A successful delete returns a success body with HTTP 204."""
        conversation_id = uuid4()
        with app.test_request_context(f"/conversations/{conversation_id}"):
            body, status_code = ConversationApi().delete(_chat_app(), _end_user(), conversation_id)

        assert status_code == 204
        assert body["result"] == "success"

    @patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
    def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
        """A missing conversation surfaces as HTTP 404."""
        conversation_id = uuid4()
        with (
            app.test_request_context(f"/conversations/{conversation_id}"),
            pytest.raises(NotFound, match="Conversation Not Exists"),
        ):
            ConversationApi().delete(_chat_app(), _end_user(), conversation_id)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationRenameApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationRenameApi:
    """Tests for the POST /conversations/<id>/name endpoint."""

    def test_non_chat_mode_raises(self, app: Flask) -> None:
        # Renaming is only meaningful for chat apps.
        with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}):
            with pytest.raises(NotChatAppError):
                ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())

    @patch("controllers.web.conversation.ConversationService.rename")
    @patch("controllers.web.conversation.web_ns")
    def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
        # The renamed conversation is marshalled back to the caller.
        c_id = uuid4()
        mock_ns.payload = {"name": "New Name", "auto_generate": False}
        conv = SimpleNamespace(
            id=str(c_id),
            name="New Name",
            inputs={},
            status="normal",
            introduction="",
            created_at=1700000000,
            updated_at=1700000000,
        )
        mock_rename.return_value = conv

        with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "New Name"}):
            result = ConversationRenameApi().post(_chat_app(), _end_user(), c_id)

        assert result["name"] == "New Name"

    @patch(
        "controllers.web.conversation.ConversationService.rename",
        side_effect=ConversationNotExistsError(),
    )
    @patch("controllers.web.conversation.web_ns")
    def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
        # A missing conversation surfaces as HTTP 404.
        c_id = uuid4()
        mock_ns.payload = {"name": "X", "auto_generate": False}

        with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "X"}):
            with pytest.raises(NotFound, match="Conversation Not Exists"):
                ConversationRenameApi().post(_chat_app(), _end_user(), c_id)
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationPinApi / ConversationUnPinApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationPinApi:
    """Tests for the PATCH /conversations/<id>/pin endpoint."""

    def test_non_chat_mode_raises(self, app: Flask) -> None:
        """Pinning is only available for chat apps."""
        with (
            app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"),
            pytest.raises(NotChatAppError),
        ):
            ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())

    @patch("controllers.web.conversation.WebConversationService.pin")
    def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
        """A successful pin returns a success body."""
        conversation_id = uuid4()
        with app.test_request_context(f"/conversations/{conversation_id}/pin", method="PATCH"):
            body = ConversationPinApi().patch(_chat_app(), _end_user(), conversation_id)

        assert body["result"] == "success"

    @patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
    def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
        """A missing conversation surfaces as HTTP 404."""
        conversation_id = uuid4()
        with (
            app.test_request_context(f"/conversations/{conversation_id}/pin", method="PATCH"),
            pytest.raises(NotFound),
        ):
            ConversationPinApi().patch(_chat_app(), _end_user(), conversation_id)
||||
|
||||
class TestConversationUnPinApi:
    """Tests for the PATCH /conversations/<id>/unpin endpoint."""

    def test_non_chat_mode_raises(self, app: Flask) -> None:
        """Unpinning is only available for chat apps."""
        with (
            app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"),
            pytest.raises(NotChatAppError),
        ):
            ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())

    @patch("controllers.web.conversation.WebConversationService.unpin")
    def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
        """A successful unpin returns a success body."""
        conversation_id = uuid4()
        with app.test_request_context(f"/conversations/{conversation_id}/unpin", method="PATCH"):
            body = ConversationUnPinApi().patch(_chat_app(), _end_user(), conversation_id)

        assert body["result"] == "success"
||||
75
api/tests/unit_tests/controllers/web/test_error.py
Normal file
75
api/tests/unit_tests/controllers/web/test_error.py
Normal file
@ -0,0 +1,75 @@
|
||||
"""Unit tests for controllers.web.error HTTP exception classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
InvalidArgumentError,
|
||||
InvokeRateLimitError,
|
||||
NoAudioUploadedError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
NotFoundError,
|
||||
NotWorkflowAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
WebAppAuthAccessDeniedError,
|
||||
WebAppAuthRequiredError,
|
||||
WebFormRateLimitExceededError,
|
||||
)
|
||||
|
||||
# (error class, expected ``error_code``, expected HTTP status) triples covering
# the public exception classes exported by controllers.web.error.
_ERROR_SPECS: list[tuple[type, str, int]] = [
    (AppUnavailableError, "app_unavailable", 400),
    (NotCompletionAppError, "not_completion_app", 400),
    (NotChatAppError, "not_chat_app", 400),
    (NotWorkflowAppError, "not_workflow_app", 400),
    (ConversationCompletedError, "conversation_completed", 400),
    (ProviderNotInitializeError, "provider_not_initialize", 400),
    (ProviderQuotaExceededError, "provider_quota_exceeded", 400),
    (ProviderModelCurrentlyNotSupportError, "model_currently_not_support", 400),
    (CompletionRequestError, "completion_request_error", 400),
    (AppMoreLikeThisDisabledError, "app_more_like_this_disabled", 403),
    (AppSuggestedQuestionsAfterAnswerDisabledError, "app_suggested_questions_after_answer_disabled", 403),
    (NoAudioUploadedError, "no_audio_uploaded", 400),
    (AudioTooLargeError, "audio_too_large", 413),
    (UnsupportedAudioTypeError, "unsupported_audio_type", 415),
    (ProviderNotSupportSpeechToTextError, "provider_not_support_speech_to_text", 400),
    (WebAppAuthRequiredError, "web_sso_auth_required", 401),
    (WebAppAuthAccessDeniedError, "web_app_access_denied", 401),
    (InvokeRateLimitError, "rate_limit_error", 429),
    (WebFormRateLimitExceededError, "web_form_rate_limit_exceeded", 429),
    (NotFoundError, "not_found", 404),
    (InvalidArgumentError, "invalid_param", 400),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("cls", "expected_code", "expected_status"),
    _ERROR_SPECS,
    # Use the class name as the test id so a failure names the offending class.
    ids=[cls.__name__ for cls, _, _ in _ERROR_SPECS],
)
def test_error_class_attributes(cls: type, expected_code: str, expected_status: int) -> None:
    """Each error class exposes the correct error_code and HTTP status code."""
    assert cls.error_code == expected_code
    assert cls.code == expected_status
||||
|
||||
|
||||
def test_error_classes_have_description() -> None:
    """Every error class has a description (string or None for generic errors)."""
    # NotFoundError and InvalidArgumentError use None description by design.
    no_description = {NotFoundError, InvalidArgumentError}
    described = [spec[0] for spec in _ERROR_SPECS if spec[0] not in no_description]
    for cls in described:
        assert isinstance(cls.description, str), f"{cls.__name__} missing description"
        assert len(cls.description) > 0, f"{cls.__name__} has empty description"
|
||||
38
api/tests/unit_tests/controllers/web/test_feature.py
Normal file
38
api/tests/unit_tests/controllers/web/test_feature.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""Unit tests for controllers.web.feature endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.feature import SystemFeatureApi
|
||||
|
||||
|
||||
class TestSystemFeatureApi:
    """Tests for the GET /system-features endpoint."""

    @patch("controllers.web.feature.FeatureService.get_system_features")
    def test_returns_system_features(self, mock_features: MagicMock, app: Flask) -> None:
        # The endpoint returns the model_dump() of the feature model verbatim.
        mock_model = MagicMock()
        mock_model.model_dump.return_value = {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}}
        mock_features.return_value = mock_model

        with app.test_request_context("/system-features"):
            result = SystemFeatureApi().get()

        assert result == {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}}
        mock_features.assert_called_once()

    @patch("controllers.web.feature.FeatureService.get_system_features")
    def test_unauthenticated_access(self, mock_features: MagicMock, app: Flask) -> None:
        """SystemFeatureApi is unauthenticated by design — no WebApiResource decorator."""
        mock_model = MagicMock()
        mock_model.model_dump.return_value = {}
        mock_features.return_value = mock_model

        # Verify it's a bare Resource, not WebApiResource
        from flask_restx import Resource

        from controllers.web.wraps import WebApiResource

        assert issubclass(SystemFeatureApi, Resource)
        assert not issubclass(SystemFeatureApi, WebApiResource)
||||
89
api/tests/unit_tests/controllers/web/test_files.py
Normal file
89
api/tests/unit_tests/controllers/web/test_files.py
Normal file
@ -0,0 +1,89 @@
|
||||
"""Unit tests for controllers.web.files endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
)
|
||||
from controllers.web.files import FileApi
|
||||
|
||||
|
||||
def _app_model() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
class TestFileApi:
    """Tests for the POST /files/upload endpoint."""

    def test_no_file_uploaded(self, app: Flask) -> None:
        # Multipart request with no parts at all.
        with app.test_request_context("/files/upload", method="POST", content_type="multipart/form-data"):
            with pytest.raises(NoFileUploadedError):
                FileApi().post(_app_model(), _end_user())

    def test_too_many_files(self, app: Flask) -> None:
        data = {
            "file": (BytesIO(b"a"), "a.txt"),
            "file2": (BytesIO(b"b"), "b.txt"),
        }
        with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
            # Now has "file" key but len(request.files) > 1
            with pytest.raises(TooManyFilesError):
                FileApi().post(_app_model(), _end_user())

    def test_filename_missing(self, app: Flask) -> None:
        # A part with an empty filename is rejected.
        data = {"file": (BytesIO(b"content"), "")}
        with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
            with pytest.raises(FilenameNotExistsError):
                FileApi().post(_app_model(), _end_user())

    @patch("controllers.web.files.FileService")
    @patch("controllers.web.files.db")
    def test_upload_success(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None:
        # The stored file's metadata is marshalled back with HTTP 201.
        mock_db.engine = "engine"
        from datetime import datetime

        upload_file = SimpleNamespace(
            id="file-1",
            name="test.txt",
            size=100,
            extension="txt",
            mime_type="text/plain",
            created_by="eu-1",
            created_at=datetime(2024, 1, 1),
        )
        mock_file_svc_cls.return_value.upload_file.return_value = upload_file

        data = {"file": (BytesIO(b"content"), "test.txt")}
        with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
            result, status = FileApi().post(_app_model(), _end_user())

        assert status == 201
        assert result["id"] == "file-1"
        assert result["name"] == "test.txt"

    @patch("controllers.web.files.FileService")
    @patch("controllers.web.files.db")
    def test_file_too_large_from_service(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None:
        # The service-layer size error is translated to the controller-layer error.
        import services.errors.file

        mock_db.engine = "engine"
        mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError(
            description="max 10MB"
        )

        data = {"file": (BytesIO(b"big"), "big.txt")}
        with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
            with pytest.raises(FileTooLargeError):
                FileApi().post(_app_model(), _end_user())
||||
156
api/tests/unit_tests/controllers/web/test_message_endpoints.py
Normal file
156
api/tests/unit_tests/controllers/web/test_message_endpoints.py
Normal file
@ -0,0 +1,156 @@
|
||||
"""Unit tests for controllers.web.message — feedback, more-like-this, suggested questions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
)
|
||||
from controllers.web.message import (
|
||||
MessageFeedbackApi,
|
||||
MessageMoreLikeThisApi,
|
||||
MessageSuggestedQuestionApi,
|
||||
)
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageFeedbackApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestMessageFeedbackApi:
    """Tests for the POST /messages/<id>/feedbacks endpoint."""

    @patch("controllers.web.message.MessageService.create_feedback")
    @patch("controllers.web.message.web_ns")
    def test_feedback_success(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
        # A rating with optional content is forwarded to the service once.
        mock_ns.payload = {"rating": "like", "content": "great"}
        msg_id = uuid4()

        with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
            result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)

        assert result == {"result": "success"}
        mock_create.assert_called_once()

    @patch("controllers.web.message.MessageService.create_feedback")
    @patch("controllers.web.message.web_ns")
    def test_feedback_null_rating(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
        # A null rating (clearing feedback) is accepted.
        mock_ns.payload = {"rating": None}
        msg_id = uuid4()

        with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
            result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)

        assert result == {"result": "success"}

    @patch(
        "controllers.web.message.MessageService.create_feedback",
        side_effect=MessageNotExistsError(),
    )
    @patch("controllers.web.message.web_ns")
    def test_feedback_message_not_found(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
        # A missing message surfaces as HTTP 404.
        mock_ns.payload = {"rating": "dislike"}
        msg_id = uuid4()

        with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
            with pytest.raises(NotFound, match="Message Not Exists"):
                MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageMoreLikeThisApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestMessageMoreLikeThisApi:
    """Tests for the GET /messages/<id>/more-like-this endpoint."""

    def test_wrong_mode_raises(self, app: Flask) -> None:
        # More-like-this is a completion-app feature; chat apps are rejected.
        msg_id = uuid4()
        with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
            with pytest.raises(NotCompletionAppError):
                MessageMoreLikeThisApi().get(_chat_app(), _end_user(), msg_id)

    @patch("controllers.web.message.helper.compact_generate_response", return_value={"answer": "similar"})
    @patch("controllers.web.message.AppGenerateService.generate_more_like_this")
    def test_happy_path(self, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
        # The generated response is passed through compact_generate_response.
        msg_id = uuid4()
        mock_gen.return_value = "response"

        with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
            result = MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)

        assert result == {"answer": "similar"}

    @patch(
        "controllers.web.message.AppGenerateService.generate_more_like_this",
        side_effect=MessageNotExistsError(),
    )
    def test_message_not_found(self, mock_gen: MagicMock, app: Flask) -> None:
        # A missing message surfaces as HTTP 404.
        msg_id = uuid4()
        with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
            with pytest.raises(NotFound, match="Message Not Exists"):
                MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)

    @patch(
        "controllers.web.message.AppGenerateService.generate_more_like_this",
        side_effect=MoreLikeThisDisabledError(),
    )
    def test_feature_disabled(self, mock_gen: MagicMock, app: Flask) -> None:
        # The service-layer disabled error maps to the controller-layer 403 error.
        msg_id = uuid4()
        with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
            with pytest.raises(AppMoreLikeThisDisabledError):
                MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageSuggestedQuestionApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestMessageSuggestedQuestionApi:
    """Tests for the GET /messages/<id>/suggested-questions endpoint.

    Fix: ``test_wrong_mode_raises`` was defined twice with an identical body;
    the second definition silently shadowed the first so pytest only collected
    one of them. The duplicate is removed.
    """

    def test_wrong_mode_raises(self, app: Flask) -> None:
        """Suggested questions are only available for chat apps."""
        msg_id = uuid4()
        with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
            with pytest.raises(NotChatAppError):
                MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)

    @patch("controllers.web.message.MessageService.get_suggested_questions_after_answer")
    def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None:
        """The service's suggestions are returned under the ``data`` key."""
        msg_id = uuid4()
        mock_suggest.return_value = ["What about X?", "Tell me more about Y."]

        with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
            result = MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id)

        assert result["data"] == ["What about X?", "Tell me more about Y."]

    @patch(
        "controllers.web.message.MessageService.get_suggested_questions_after_answer",
        side_effect=MessageNotExistsError(),
    )
    def test_message_not_found(self, mock_suggest: MagicMock, app: Flask) -> None:
        """A missing message surfaces as HTTP 404."""
        msg_id = uuid4()
        with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
            with pytest.raises(NotFound, match="Message not found"):
                MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id)
||||
103
api/tests/unit_tests/controllers/web/test_passport.py
Normal file
103
api/tests/unit_tests/controllers/web/test_passport.py
Normal file
@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from controllers.web.passport import (
|
||||
PassportService,
|
||||
decode_enterprise_webapp_user_id,
|
||||
exchange_token_for_existing_web_user,
|
||||
generate_session_id,
|
||||
)
|
||||
from services.webapp_auth_service import WebAppAuthType
|
||||
|
||||
|
||||
def test_decode_enterprise_webapp_user_id_none() -> None:
    """A missing token decodes to ``None`` rather than raising."""
    result = decode_enterprise_webapp_user_id(None)
    assert result is None
|
||||
|
||||
def test_decode_enterprise_webapp_user_id_invalid_source(monkeypatch: pytest.MonkeyPatch) -> None:
    """A token whose source is not the webapp login token is rejected."""

    def fake_verify(*_args, **_kwargs):
        return {"token_source": "bad"}

    monkeypatch.setattr(PassportService, "verify", fake_verify)
    with pytest.raises(Unauthorized):
        decode_enterprise_webapp_user_id("token")
|
||||
|
||||
def test_decode_enterprise_webapp_user_id_valid(monkeypatch: pytest.MonkeyPatch) -> None:
    """A webapp login token decodes to the payload returned by PassportService."""
    payload = {"token_source": "webapp_login_token", "user_id": "u1"}
    monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: payload)
    result = decode_enterprise_webapp_user_id("token")
    assert result == payload
|
||||
|
||||
def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None:
    """A public auth type delegates to the public-app token exchange."""
    site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
    app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)

    lookups = [0]

    def fake_scalar(*_args, **_kwargs):
        # First DB lookup resolves the Site, every later one the App.
        lookups[0] += 1
        return site if lookups[0] == 1 else app_model

    fake_db = SimpleNamespace(session=SimpleNamespace(scalar=fake_scalar))
    monkeypatch.setattr("controllers.web.passport.db", fake_db)
    monkeypatch.setattr("controllers.web.passport._exchange_for_public_app_token", lambda *_args, **_kwargs: "resp")

    decoded = {"auth_type": "public"}
    result = exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.PUBLIC)
    assert result == "resp"
||||
|
||||
|
||||
def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> None:
    """An internal-auth token cannot be exchanged on an external-auth web app."""
    site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
    app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)

    lookups = [0]

    def fake_scalar(*_args, **_kwargs):
        # First DB lookup resolves the Site, every later one the App.
        lookups[0] += 1
        return site if lookups[0] == 1 else app_model

    fake_db = SimpleNamespace(session=SimpleNamespace(scalar=fake_scalar))
    monkeypatch.setattr("controllers.web.passport.db", fake_db)

    decoded = {"auth_type": "internal"}
    with pytest.raises(WebAppAuthRequiredError):
        exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.EXTERNAL)
||||
|
||||
|
||||
def test_exchange_token_missing_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
    """When later lookups return nothing, the exchange raises NotFound."""
    site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
    app_model = SimpleNamespace(id="a1", status="normal", enable_site=True, tenant_id="t1")

    # Successive DB lookups yield the Site, then the App, then None forever.
    responses = [site, app_model]

    def fake_scalar(*_args, **_kwargs):
        return responses.pop(0) if responses else None

    fake_session = SimpleNamespace(scalar=fake_scalar, add=lambda *_a, **_k: None, commit=lambda: None)
    monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=fake_session))

    decoded = {"auth_type": "internal"}
    with pytest.raises(NotFound):
        exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.INTERNAL)
|
||||
|
||||
def test_generate_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
    """generate_session_id keeps drawing until the candidate id is unused (count == 0)."""
    remaining_counts = [1, 0]

    def fake_scalar(*_args, **_kwargs):
        return remaining_counts.pop(0)

    fake_db = SimpleNamespace(session=SimpleNamespace(scalar=fake_scalar))
    monkeypatch.setattr("controllers.web.passport.db", fake_db)

    session_id = generate_session_id()
    assert session_id
||||
423
api/tests/unit_tests/controllers/web/test_pydantic_models.py
Normal file
423
api/tests/unit_tests/controllers/web/test_pydantic_models.py
Normal file
@ -0,0 +1,423 @@
|
||||
"""Unit tests for Pydantic models defined in controllers.web modules.
|
||||
|
||||
Covers validation logic, field defaults, constraints, and custom validators
|
||||
for all ~15 Pydantic models across the web controller layer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# app.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.app import AppAccessModeQuery
|
||||
|
||||
|
||||
class TestAppAccessModeQuery:
    """Validation of the camelCase-aliased app access-mode query model."""

    def test_alias_resolution(self) -> None:
        # camelCase aliases map onto the snake_case fields.
        q = AppAccessModeQuery.model_validate({"appId": "abc", "appCode": "xyz"})
        assert q.app_id == "abc"
        assert q.app_code == "xyz"

    def test_defaults_to_none(self) -> None:
        q = AppAccessModeQuery.model_validate({})
        assert q.app_id is None
        assert q.app_code is None

    def test_accepts_snake_case(self) -> None:
        # populate_by_name-style construction with the field names also works.
        q = AppAccessModeQuery(app_id="id1", app_code="code1")
        assert q.app_id == "id1"
        assert q.app_code == "code1"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# audio.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.audio import TextToAudioPayload
|
||||
|
||||
|
||||
class TestTextToAudioPayload:
    """Validation of the text-to-audio request payload model."""

    def test_defaults(self) -> None:
        p = TextToAudioPayload.model_validate({})
        assert p.message_id is None
        assert p.voice is None
        assert p.text is None
        assert p.streaming is None

    def test_valid_uuid_message_id(self) -> None:
        uid = str(uuid4())
        p = TextToAudioPayload(message_id=uid)
        assert p.message_id == uid

    def test_none_message_id_passthrough(self) -> None:
        # The UUID validator must let an explicit None through.
        p = TextToAudioPayload(message_id=None)
        assert p.message_id is None

    def test_invalid_uuid_message_id(self) -> None:
        with pytest.raises(ValidationError, match="not a valid uuid"):
            TextToAudioPayload(message_id="not-a-uuid")
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# completion.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.completion import ChatMessagePayload, CompletionMessagePayload
|
||||
|
||||
|
||||
class TestCompletionMessagePayload:
    """Validation of the completion-message request payload model."""

    def test_defaults(self) -> None:
        p = CompletionMessagePayload(inputs={})
        assert p.query == ""
        assert p.files is None
        assert p.response_mode is None
        assert p.retriever_from == "web_app"

    def test_accepts_full_payload(self) -> None:
        p = CompletionMessagePayload(
            inputs={"key": "val"},
            query="test",
            files=[{"id": "f1"}],
            response_mode="streaming",
        )
        assert p.response_mode == "streaming"
        assert p.files == [{"id": "f1"}]

    def test_invalid_response_mode(self) -> None:
        # response_mode is constrained to a fixed set of values.
        with pytest.raises(ValidationError):
            CompletionMessagePayload(inputs={}, response_mode="invalid")
||||
|
||||
|
||||
class TestChatMessagePayload:
    """Validation of the chat-message request payload model."""

    def test_valid_uuid_fields(self) -> None:
        cid = str(uuid4())
        pid = str(uuid4())
        p = ChatMessagePayload(inputs={}, query="hi", conversation_id=cid, parent_message_id=pid)
        assert p.conversation_id == cid
        assert p.parent_message_id == pid

    def test_none_uuid_fields(self) -> None:
        # Both UUID fields are optional.
        p = ChatMessagePayload(inputs={}, query="hi")
        assert p.conversation_id is None
        assert p.parent_message_id is None

    def test_invalid_conversation_id(self) -> None:
        with pytest.raises(ValidationError, match="not a valid uuid"):
            ChatMessagePayload(inputs={}, query="hi", conversation_id="bad")

    def test_invalid_parent_message_id(self) -> None:
        with pytest.raises(ValidationError, match="not a valid uuid"):
            ChatMessagePayload(inputs={}, query="hi", parent_message_id="bad")

    def test_query_required(self) -> None:
        # Unlike the completion payload, query has no default here.
        with pytest.raises(ValidationError):
            ChatMessagePayload(inputs={})
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# conversation.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.conversation import ConversationListQuery, ConversationRenamePayload
|
||||
|
||||
|
||||
class TestConversationListQuery:
    """Validation of the conversation-list query model (limits, sort, cursor)."""

    def test_defaults(self) -> None:
        q = ConversationListQuery()
        assert q.last_id is None
        assert q.limit == 20
        assert q.pinned is None
        assert q.sort_by == "-updated_at"

    def test_limit_lower_bound(self) -> None:
        with pytest.raises(ValidationError):
            ConversationListQuery(limit=0)

    def test_limit_upper_bound(self) -> None:
        with pytest.raises(ValidationError):
            ConversationListQuery(limit=101)

    def test_limit_boundaries_valid(self) -> None:
        # 1 and 100 are the inclusive bounds.
        assert ConversationListQuery(limit=1).limit == 1
        assert ConversationListQuery(limit=100).limit == 100

    def test_valid_sort_by_options(self) -> None:
        for opt in ("created_at", "-created_at", "updated_at", "-updated_at"):
            assert ConversationListQuery(sort_by=opt).sort_by == opt

    def test_invalid_sort_by(self) -> None:
        with pytest.raises(ValidationError):
            ConversationListQuery(sort_by="invalid")

    def test_valid_last_id(self) -> None:
        uid = str(uuid4())
        assert ConversationListQuery(last_id=uid).last_id == uid

    def test_invalid_last_id(self) -> None:
        with pytest.raises(ValidationError, match="not a valid uuid"):
            ConversationListQuery(last_id="not-uuid")
||||
|
||||
|
||||
class TestConversationRenamePayload:
|
||||
def test_auto_generate_true_no_name_required(self) -> None:
|
||||
p = ConversationRenamePayload(auto_generate=True)
|
||||
assert p.name is None
|
||||
|
||||
def test_auto_generate_false_requires_name(self) -> None:
|
||||
with pytest.raises(ValidationError, match="name is required"):
|
||||
ConversationRenamePayload(auto_generate=False)
|
||||
|
||||
def test_auto_generate_false_blank_name_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="name is required"):
|
||||
ConversationRenamePayload(auto_generate=False, name=" ")
|
||||
|
||||
def test_auto_generate_false_with_valid_name(self) -> None:
|
||||
p = ConversationRenamePayload(auto_generate=False, name="My Chat")
|
||||
assert p.name == "My Chat"
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
p = ConversationRenamePayload(name="test")
|
||||
assert p.auto_generate is False
|
||||
assert p.name == "test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# message.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.message import MessageFeedbackPayload, MessageListQuery, MessageMoreLikeThisQuery
|
||||
|
||||
|
||||
class TestMessageListQuery:
|
||||
def test_valid_query(self) -> None:
|
||||
cid = str(uuid4())
|
||||
q = MessageListQuery(conversation_id=cid)
|
||||
assert q.conversation_id == cid
|
||||
assert q.first_id is None
|
||||
assert q.limit == 20
|
||||
|
||||
def test_invalid_conversation_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
MessageListQuery(conversation_id="bad")
|
||||
|
||||
def test_limit_bounds(self) -> None:
|
||||
cid = str(uuid4())
|
||||
with pytest.raises(ValidationError):
|
||||
MessageListQuery(conversation_id=cid, limit=0)
|
||||
with pytest.raises(ValidationError):
|
||||
MessageListQuery(conversation_id=cid, limit=101)
|
||||
|
||||
def test_valid_first_id(self) -> None:
|
||||
cid = str(uuid4())
|
||||
fid = str(uuid4())
|
||||
q = MessageListQuery(conversation_id=cid, first_id=fid)
|
||||
assert q.first_id == fid
|
||||
|
||||
def test_invalid_first_id(self) -> None:
|
||||
cid = str(uuid4())
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
MessageListQuery(conversation_id=cid, first_id="invalid")
|
||||
|
||||
|
||||
class TestMessageFeedbackPayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = MessageFeedbackPayload()
|
||||
assert p.rating is None
|
||||
assert p.content is None
|
||||
|
||||
def test_valid_ratings(self) -> None:
|
||||
assert MessageFeedbackPayload(rating="like").rating == "like"
|
||||
assert MessageFeedbackPayload(rating="dislike").rating == "dislike"
|
||||
|
||||
def test_invalid_rating(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MessageFeedbackPayload(rating="neutral")
|
||||
|
||||
|
||||
class TestMessageMoreLikeThisQuery:
|
||||
def test_valid_modes(self) -> None:
|
||||
assert MessageMoreLikeThisQuery(response_mode="blocking").response_mode == "blocking"
|
||||
assert MessageMoreLikeThisQuery(response_mode="streaming").response_mode == "streaming"
|
||||
|
||||
def test_invalid_mode(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MessageMoreLikeThisQuery(response_mode="invalid")
|
||||
|
||||
def test_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MessageMoreLikeThisQuery()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# remote_files.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.remote_files import RemoteFileUploadPayload
|
||||
|
||||
|
||||
class TestRemoteFileUploadPayload:
|
||||
def test_valid_url(self) -> None:
|
||||
p = RemoteFileUploadPayload(url="https://example.com/file.pdf")
|
||||
assert str(p.url) == "https://example.com/file.pdf"
|
||||
|
||||
def test_invalid_url(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RemoteFileUploadPayload(url="not-a-url")
|
||||
|
||||
def test_url_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RemoteFileUploadPayload()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# saved_message.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.saved_message import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
|
||||
|
||||
class TestSavedMessageListQuery:
|
||||
def test_defaults(self) -> None:
|
||||
q = SavedMessageListQuery()
|
||||
assert q.last_id is None
|
||||
assert q.limit == 20
|
||||
|
||||
def test_limit_bounds(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SavedMessageListQuery(limit=0)
|
||||
with pytest.raises(ValidationError):
|
||||
SavedMessageListQuery(limit=101)
|
||||
|
||||
def test_valid_last_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
q = SavedMessageListQuery(last_id=uid)
|
||||
assert q.last_id == uid
|
||||
|
||||
def test_empty_last_id(self) -> None:
|
||||
q = SavedMessageListQuery(last_id="")
|
||||
assert q.last_id == ""
|
||||
|
||||
|
||||
class TestSavedMessageCreatePayload:
|
||||
def test_valid_message_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
p = SavedMessageCreatePayload(message_id=uid)
|
||||
assert p.message_id == uid
|
||||
|
||||
def test_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SavedMessageCreatePayload()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# workflow.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.workflow import WorkflowRunPayload
|
||||
|
||||
|
||||
class TestWorkflowRunPayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = WorkflowRunPayload(inputs={})
|
||||
assert p.inputs == {}
|
||||
assert p.files is None
|
||||
|
||||
def test_with_files(self) -> None:
|
||||
p = WorkflowRunPayload(inputs={"k": "v"}, files=[{"id": "f1"}])
|
||||
assert p.files == [{"id": "f1"}]
|
||||
|
||||
def test_inputs_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
WorkflowRunPayload()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# forgot_password.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.forgot_password import (
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordSendPayload,
|
||||
)
|
||||
|
||||
|
||||
class TestForgotPasswordSendPayload:
|
||||
def test_valid_email(self) -> None:
|
||||
p = ForgotPasswordSendPayload(email="user@example.com")
|
||||
assert p.email == "user@example.com"
|
||||
|
||||
def test_invalid_email(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid email"):
|
||||
ForgotPasswordSendPayload(email="not-an-email")
|
||||
|
||||
def test_language_optional(self) -> None:
|
||||
p = ForgotPasswordSendPayload(email="a@b.com")
|
||||
assert p.language is None
|
||||
|
||||
|
||||
class TestForgotPasswordCheckPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="tok")
|
||||
assert p.email == "a@b.com"
|
||||
assert p.code == "1234"
|
||||
assert p.token == "tok"
|
||||
|
||||
def test_empty_token_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="")
|
||||
|
||||
|
||||
class TestForgotPasswordResetPayload:
|
||||
def test_valid_passwords(self) -> None:
|
||||
p = ForgotPasswordResetPayload(token="tok", new_password="Valid1234", password_confirm="Valid1234")
|
||||
assert p.new_password == "Valid1234"
|
||||
|
||||
def test_weak_password_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
ForgotPasswordResetPayload(token="tok", new_password="short", password_confirm="short")
|
||||
|
||||
def test_letters_only_password_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
ForgotPasswordResetPayload(token="tok", new_password="abcdefghi", password_confirm="abcdefghi")
|
||||
|
||||
def test_digits_only_password_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
ForgotPasswordResetPayload(token="tok", new_password="123456789", password_confirm="123456789")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# login.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.login import EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload, LoginPayload
|
||||
|
||||
|
||||
class TestLoginPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = LoginPayload(email="a@b.com", password="Valid1234")
|
||||
assert p.email == "a@b.com"
|
||||
|
||||
def test_invalid_email(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid email"):
|
||||
LoginPayload(email="bad", password="Valid1234")
|
||||
|
||||
def test_weak_password(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
LoginPayload(email="a@b.com", password="weak")
|
||||
|
||||
|
||||
class TestEmailCodeLoginSendPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = EmailCodeLoginSendPayload(email="a@b.com")
|
||||
assert p.language is None
|
||||
|
||||
def test_with_language(self) -> None:
|
||||
p = EmailCodeLoginSendPayload(email="a@b.com", language="zh-Hans")
|
||||
assert p.language == "zh-Hans"
|
||||
|
||||
|
||||
class TestEmailCodeLoginVerifyPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="tok")
|
||||
assert p.code == "1234"
|
||||
|
||||
def test_empty_token_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="")
|
||||
147
api/tests/unit_tests/controllers/web/test_remote_files.py
Normal file
147
api/tests/unit_tests/controllers/web/test_remote_files.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""Unit tests for controllers.web.remote_files endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError
|
||||
from controllers.web.remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
|
||||
def _app_model() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RemoteFileInfoApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRemoteFileInfoApi:
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
def test_head_success(self, mock_proxy: MagicMock, app: Flask) -> None:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/pdf", "Content-Length": "1024"}
|
||||
mock_proxy.head.return_value = mock_resp
|
||||
|
||||
with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.pdf"):
|
||||
result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.pdf")
|
||||
|
||||
assert result["file_type"] == "application/pdf"
|
||||
assert result["file_length"] == 1024
|
||||
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
def test_fallback_to_get(self, mock_proxy: MagicMock, app: Flask) -> None:
|
||||
head_resp = MagicMock()
|
||||
head_resp.status_code = 405 # Method not allowed
|
||||
get_resp = MagicMock()
|
||||
get_resp.status_code = 200
|
||||
get_resp.headers = {"Content-Type": "text/plain", "Content-Length": "42"}
|
||||
get_resp.raise_for_status = MagicMock()
|
||||
mock_proxy.head.return_value = head_resp
|
||||
mock_proxy.get.return_value = get_resp
|
||||
|
||||
with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.txt"):
|
||||
result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.txt")
|
||||
|
||||
assert result["file_type"] == "text/plain"
|
||||
mock_proxy.get.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RemoteFileUploadApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRemoteFileUploadApi:
|
||||
@patch("controllers.web.remote_files.file_helpers.get_signed_file_url", return_value="https://signed-url")
|
||||
@patch("controllers.web.remote_files.FileService")
|
||||
@patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
@patch("controllers.web.remote_files.web_ns")
|
||||
@patch("controllers.web.remote_files.db")
|
||||
def test_upload_success(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_ns: MagicMock,
|
||||
mock_proxy: MagicMock,
|
||||
mock_guess: MagicMock,
|
||||
mock_file_svc_cls: MagicMock,
|
||||
mock_signed: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_db.engine = "engine"
|
||||
mock_ns.payload = {"url": "https://example.com/file.pdf"}
|
||||
head_resp = MagicMock()
|
||||
head_resp.status_code = 200
|
||||
head_resp.content = b"pdf-content"
|
||||
head_resp.request.method = "HEAD"
|
||||
mock_proxy.head.return_value = head_resp
|
||||
get_resp = MagicMock()
|
||||
get_resp.content = b"pdf-content"
|
||||
mock_proxy.get.return_value = get_resp
|
||||
|
||||
mock_guess.return_value = SimpleNamespace(
|
||||
filename="file.pdf", extension="pdf", mimetype="application/pdf", size=100
|
||||
)
|
||||
mock_file_svc_cls.is_file_size_within_limit.return_value = True
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
upload_file = SimpleNamespace(
|
||||
id="f-1",
|
||||
name="file.pdf",
|
||||
size=100,
|
||||
extension="pdf",
|
||||
mime_type="application/pdf",
|
||||
created_by="eu-1",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
mock_file_svc_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context("/remote-files/upload", method="POST"):
|
||||
result, status = RemoteFileUploadApi().post(_app_model(), _end_user())
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "f-1"
|
||||
|
||||
@patch("controllers.web.remote_files.FileService.is_file_size_within_limit", return_value=False)
|
||||
@patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
@patch("controllers.web.remote_files.web_ns")
|
||||
def test_file_too_large(
|
||||
self,
|
||||
mock_ns: MagicMock,
|
||||
mock_proxy: MagicMock,
|
||||
mock_guess: MagicMock,
|
||||
mock_size_check: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_ns.payload = {"url": "https://example.com/big.zip"}
|
||||
head_resp = MagicMock()
|
||||
head_resp.status_code = 200
|
||||
mock_proxy.head.return_value = head_resp
|
||||
mock_guess.return_value = SimpleNamespace(
|
||||
filename="big.zip", extension="zip", mimetype="application/zip", size=999999999
|
||||
)
|
||||
|
||||
with app.test_request_context("/remote-files/upload", method="POST"):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
RemoteFileUploadApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
@patch("controllers.web.remote_files.web_ns")
|
||||
def test_fetch_failure_raises(self, mock_ns: MagicMock, mock_proxy: MagicMock, app: Flask) -> None:
|
||||
import httpx
|
||||
|
||||
mock_ns.payload = {"url": "https://example.com/bad"}
|
||||
mock_proxy.head.side_effect = httpx.RequestError("connection failed")
|
||||
|
||||
with app.test_request_context("/remote-files/upload", method="POST"):
|
||||
with pytest.raises(RemoteFileUploadError):
|
||||
RemoteFileUploadApi().post(_app_model(), _end_user())
|
||||
97
api/tests/unit_tests/controllers/web/test_saved_message.py
Normal file
97
api/tests/unit_tests/controllers/web/test_saved_message.py
Normal file
@ -0,0 +1,97 @@
|
||||
"""Unit tests for controllers.web.saved_message endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.saved_message import SavedMessageApi, SavedMessageListApi
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SavedMessageListApi (GET)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSavedMessageListApiGet:
|
||||
def test_non_completion_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/saved-messages"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
SavedMessageListApi().get(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.pagination_by_last_id")
|
||||
def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None:
|
||||
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[])
|
||||
|
||||
with app.test_request_context("/saved-messages?limit=20"):
|
||||
result = SavedMessageListApi().get(_completion_app(), _end_user())
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SavedMessageListApi (POST)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSavedMessageListApiPost:
|
||||
def test_non_completion_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/saved-messages", method="POST"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
SavedMessageListApi().post(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.save")
|
||||
@patch("controllers.web.saved_message.web_ns")
|
||||
def test_save_success(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None:
|
||||
msg_id = str(uuid4())
|
||||
mock_ns.payload = {"message_id": msg_id}
|
||||
|
||||
with app.test_request_context("/saved-messages", method="POST"):
|
||||
result = SavedMessageListApi().post(_completion_app(), _end_user())
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.save", side_effect=MessageNotExistsError())
|
||||
@patch("controllers.web.saved_message.web_ns")
|
||||
def test_save_not_found(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"message_id": str(uuid4())}
|
||||
|
||||
with app.test_request_context("/saved-messages", method="POST"):
|
||||
with pytest.raises(NotFound, match="Message Not Exists"):
|
||||
SavedMessageListApi().post(_completion_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SavedMessageApi (DELETE)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSavedMessageApi:
|
||||
def test_non_completion_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
SavedMessageApi().delete(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.delete")
|
||||
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"):
|
||||
result, status = SavedMessageApi().delete(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
assert status == 204
|
||||
assert result["result"] == "success"
|
||||
126
api/tests/unit_tests/controllers/web/test_site.py
Normal file
126
api/tests/unit_tests/controllers/web/test_site.py
Normal file
@ -0,0 +1,126 @@
|
||||
"""Unit tests for controllers.web.site endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.web.site import AppSiteApi, AppSiteInfo
|
||||
|
||||
|
||||
def _tenant(*, status: str = "normal") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=status,
|
||||
plan="basic",
|
||||
custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False},
|
||||
)
|
||||
|
||||
|
||||
def _site() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
title="Site",
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
icon_background="#fff",
|
||||
description="desc",
|
||||
default_language="en",
|
||||
chat_color_theme="light",
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppSiteApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppSiteApi:
|
||||
@patch("controllers.web.site.FeatureService.get_features")
|
||||
@patch("controllers.web.site.db")
|
||||
def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
|
||||
site_obj = _site()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = site_obj
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
result = AppSiteApi().get(app_model, end_user)
|
||||
|
||||
# marshal_with serializes AppSiteInfo to a dict
|
||||
assert result["app_id"] == "app-1"
|
||||
assert result["plan"] == "basic"
|
||||
assert result["enable_site"] is True
|
||||
|
||||
@patch("controllers.web.site.db")
|
||||
def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
with pytest.raises(Forbidden):
|
||||
AppSiteApi().get(app_model, end_user)
|
||||
|
||||
@patch("controllers.web.site.db")
|
||||
def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
from models.account import TenantStatus
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = _site()
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=TenantStatus.ARCHIVE,
|
||||
plan="basic",
|
||||
custom_config_dict={},
|
||||
)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
with pytest.raises(Forbidden):
|
||||
AppSiteApi().get(app_model, end_user)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppSiteInfo
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppSiteInfo:
|
||||
def test_basic_fields(self) -> None:
|
||||
tenant = _tenant()
|
||||
site_obj = _site()
|
||||
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False)
|
||||
|
||||
assert info.app_id == "app-1"
|
||||
assert info.end_user_id == "eu-1"
|
||||
assert info.enable_site is True
|
||||
assert info.plan == "basic"
|
||||
assert info.can_replace_logo is False
|
||||
assert info.model_config is None
|
||||
|
||||
@patch("controllers.web.site.dify_config", SimpleNamespace(FILES_URL="https://files.example.com"))
|
||||
def test_can_replace_logo_sets_custom_config(self) -> None:
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
plan="pro",
|
||||
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
|
||||
)
|
||||
site_obj = _site()
|
||||
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True)
|
||||
|
||||
assert info.can_replace_logo is True
|
||||
assert info.custom_config["remove_webapp_brand"] is True
|
||||
assert "webapp-logo" in info.custom_config["replace_webapp_logo"]
|
||||
@ -5,7 +5,8 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
|
||||
import services.errors.account
|
||||
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi
|
||||
|
||||
|
||||
def encode_code(code: str) -> str:
|
||||
@ -89,3 +90,114 @@ class TestEmailCodeLoginApi:
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_login.assert_called_once()
|
||||
mock_reset_login_rate.assert_called_once_with("user@example.com")
|
||||
|
||||
|
||||
class TestLoginApi:
|
||||
@patch("controllers.web.login.WebAppAuthService.login", return_value="access-tok")
|
||||
@patch("controllers.web.login.WebAppAuthService.authenticate")
|
||||
def test_login_success(self, mock_auth: MagicMock, mock_login: MagicMock, app: Flask) -> None:
|
||||
mock_auth.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
response = LoginApi().post()
|
||||
|
||||
assert response.get_json()["data"]["access_token"] == "access-tok"
|
||||
mock_auth.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"controllers.web.login.WebAppAuthService.authenticate",
|
||||
side_effect=services.errors.account.AccountLoginError(),
|
||||
)
|
||||
def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None:
|
||||
from controllers.console.error import AccountBannedError
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AccountBannedError):
|
||||
LoginApi().post()
|
||||
|
||||
@patch(
|
||||
"controllers.web.login.WebAppAuthService.authenticate",
|
||||
side_effect=services.errors.account.AccountPasswordError(),
|
||||
)
|
||||
def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None:
|
||||
from controllers.console.auth.error import AuthenticationFailedError
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AuthenticationFailedError):
|
||||
LoginApi().post()
|
||||
|
||||
|
||||
class TestLoginStatusApi:
|
||||
@patch("controllers.web.login.extract_webapp_access_token", return_value=None)
|
||||
def test_no_app_code_returns_logged_in_false(self, mock_extract: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/web/login/status"):
|
||||
result = LoginStatusApi().get()
|
||||
|
||||
assert result["logged_in"] is False
|
||||
assert result["app_logged_in"] is False
|
||||
|
||||
@patch("controllers.web.login.decode_jwt_token")
|
||||
@patch("controllers.web.login.PassportService")
|
||||
@patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
@patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1")
|
||||
@patch("controllers.web.login.extract_webapp_access_token", return_value="tok")
|
||||
def test_public_app_user_logged_in(
|
||||
self,
|
||||
mock_extract: MagicMock,
|
||||
mock_app_id: MagicMock,
|
||||
mock_perm: MagicMock,
|
||||
mock_passport: MagicMock,
|
||||
mock_decode: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_decode.return_value = (MagicMock(), MagicMock())
|
||||
|
||||
with app.test_request_context("/web/login/status?app_code=code1"):
|
||||
result = LoginStatusApi().get()
|
||||
|
||||
assert result["logged_in"] is True
|
||||
assert result["app_logged_in"] is True
|
||||
|
||||
@patch("controllers.web.login.decode_jwt_token", side_effect=Exception("bad"))
|
||||
@patch("controllers.web.login.PassportService")
|
||||
@patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=True)
|
||||
@patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1")
|
||||
@patch("controllers.web.login.extract_webapp_access_token", return_value="tok")
|
||||
def test_private_app_passport_fails(
|
||||
self,
|
||||
mock_extract: MagicMock,
|
||||
mock_app_id: MagicMock,
|
||||
mock_perm: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_decode: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_passport_cls.return_value.verify.side_effect = Exception("bad")
|
||||
|
||||
with app.test_request_context("/web/login/status?app_code=code1"):
|
||||
result = LoginStatusApi().get()
|
||||
|
||||
assert result["logged_in"] is False
|
||||
assert result["app_logged_in"] is False
|
||||
|
||||
|
||||
class TestLogoutApi:
    """Tests for the web-app logout endpoint."""

    @patch("controllers.web.login.clear_webapp_access_token_from_cookie")
    def test_logout_success(self, mock_clear: MagicMock, app: Flask) -> None:
        """Logging out clears the access-token cookie and reports success."""
        with app.test_request_context("/web/logout", method="POST"):
            resp = LogoutApi().post()

        payload = resp.get_json()
        assert payload == {"result": "success"}
        mock_clear.assert_called_once()
|
||||
|
||||
192
api/tests/unit_tests/controllers/web/test_web_passport.py
Normal file
192
api/tests/unit_tests/controllers/web/test_web_passport.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""Unit tests for controllers.web.passport — token issuance and enterprise auth exchange."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from controllers.web.passport import (
|
||||
PassportResource,
|
||||
decode_enterprise_webapp_user_id,
|
||||
exchange_token_for_existing_web_user,
|
||||
generate_session_id,
|
||||
)
|
||||
from services.webapp_auth_service import WebAppAuthType
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# decode_enterprise_webapp_user_id
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDecodeEnterpriseWebappUserId:
    """Behaviour of the enterprise webapp login-token decoder."""

    def test_none_token_returns_none(self) -> None:
        """A missing token short-circuits to None."""
        assert decode_enterprise_webapp_user_id(None) is None

    @patch("controllers.web.passport.PassportService")
    def test_valid_token_returns_decoded(self, mock_passport_cls: MagicMock) -> None:
        """A token minted as a webapp login token decodes to its payload."""
        payload = {
            "token_source": "webapp_login_token",
            "user_id": "u1",
        }
        mock_passport_cls.return_value.verify.return_value = payload
        decoded = decode_enterprise_webapp_user_id("valid-jwt")
        assert decoded["user_id"] == "u1"

    @patch("controllers.web.passport.PassportService")
    def test_wrong_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None:
        """A token from any other source is rejected with 401."""
        mock_passport_cls.return_value.verify.return_value = {"token_source": "other_source"}
        with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"):
            decode_enterprise_webapp_user_id("bad-jwt")

    @patch("controllers.web.passport.PassportService")
    def test_missing_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None:
        """A payload without token_source is treated the same as a wrong source."""
        mock_passport_cls.return_value.verify.return_value = {}
        with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"):
            decode_enterprise_webapp_user_id("no-source-jwt")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_session_id
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestGenerateSessionId:
    """generate_session_id must produce a UUID that is unique among end users."""

    @patch("controllers.web.passport.db")
    def test_returns_unique_session_id(self, mock_db: MagicMock) -> None:
        """With no collision (count 0) the first candidate UUID is returned."""
        mock_db.session.scalar.return_value = 0
        session_id = generate_session_id()
        assert isinstance(session_id, str)
        # Canonical UUID string: 32 hex digits plus 4 dashes.
        assert len(session_id) == 36

    @patch("controllers.web.passport.db")
    def test_retries_on_collision(self, mock_db: MagicMock) -> None:
        """A duplicate id (count 1) triggers another draw until the id is free."""
        mock_db.session.scalar.side_effect = [1, 0]
        session_id = generate_session_id()
        assert isinstance(session_id, str)
        assert mock_db.session.scalar.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# exchange_token_for_existing_web_user
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestExchangeTokenForExistingWebUser:
    """Auth-type consistency checks when exchanging an enterprise token for a web-user token."""

    @patch("controllers.web.passport.PassportService")
    @patch("controllers.web.passport.db")
    def test_external_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
        """A token granted as 'internal' cannot be exchanged when the app expects 'external' auth."""
        site = SimpleNamespace(code="code1", app_id="app-1")
        app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
        # Consecutive scalar() calls: Site lookup first, then App lookup.
        mock_db.session.scalar.side_effect = [site, app_model]

        decoded = {"user_id": "u1", "auth_type": "internal"}  # mismatch: expected "external"
        with pytest.raises(WebAppAuthRequiredError, match="external"):
            exchange_token_for_existing_web_user(
                app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL
            )

    @patch("controllers.web.passport.PassportService")
    @patch("controllers.web.passport.db")
    def test_internal_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
        """A token granted as 'external' cannot be exchanged when the app expects 'internal' auth."""
        site = SimpleNamespace(code="code1", app_id="app-1")
        app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
        mock_db.session.scalar.side_effect = [site, app_model]

        decoded = {"user_id": "u1", "auth_type": "external"}  # mismatch: expected "internal"
        with pytest.raises(WebAppAuthRequiredError, match="internal"):
            exchange_token_for_existing_web_user(
                app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.INTERNAL
            )

    @patch("controllers.web.passport.PassportService")
    @patch("controllers.web.passport.db")
    def test_site_not_found_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
        """An unknown app code (no Site row) yields 404."""
        mock_db.session.scalar.return_value = None
        decoded = {"user_id": "u1", "auth_type": "external"}
        with pytest.raises(NotFound):
            exchange_token_for_existing_web_user(
                app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PassportResource.get
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestPassportResource:
    """End-to-end behaviour of the /passport token-issuance endpoint (webapp auth disabled)."""

    @patch("controllers.web.passport.FeatureService.get_system_features")
    def test_missing_app_code_raises_unauthorized(self, mock_features: MagicMock, app: Flask) -> None:
        """Requests without the X-App-Code header are rejected with 401."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
        with app.test_request_context("/passport"):
            with pytest.raises(Unauthorized, match="X-App-Code"):
                PassportResource().get()

    @patch("controllers.web.passport.PassportService")
    @patch("controllers.web.passport.generate_session_id", return_value="new-sess-id")
    @patch("controllers.web.passport.db")
    @patch("controllers.web.passport.FeatureService.get_system_features")
    def test_creates_new_end_user_when_no_user_id(
        self,
        mock_features: MagicMock,
        mock_db: MagicMock,
        mock_gen_session: MagicMock,
        mock_passport_cls: MagicMock,
        app: Flask,
    ) -> None:
        """Without a user_id query param a brand-new EndUser is persisted and a token issued."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
        site = SimpleNamespace(app_id="app-1", code="code1")
        app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
        # Consecutive scalar() calls: Site lookup, then App lookup.
        mock_db.session.scalar.side_effect = [site, app_model]
        mock_passport_cls.return_value.issue.return_value = "issued-token"

        with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
            response = PassportResource().get()

        assert response.get_json()["access_token"] == "issued-token"
        mock_db.session.add.assert_called_once()
        mock_db.session.commit.assert_called_once()

    @patch("controllers.web.passport.PassportService")
    @patch("controllers.web.passport.db")
    @patch("controllers.web.passport.FeatureService.get_system_features")
    def test_reuses_existing_end_user_when_user_id_provided(
        self,
        mock_features: MagicMock,
        mock_db: MagicMock,
        mock_passport_cls: MagicMock,
        app: Flask,
    ) -> None:
        """A user_id matching an existing session reuses that EndUser instead of creating one."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
        site = SimpleNamespace(app_id="app-1", code="code1")
        app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
        existing_user = SimpleNamespace(id="eu-1", session_id="sess-existing")
        # Consecutive scalar() calls: Site, App, then the existing EndUser.
        mock_db.session.scalar.side_effect = [site, app_model, existing_user]
        mock_passport_cls.return_value.issue.return_value = "reused-token"

        with app.test_request_context("/passport?user_id=sess-existing", headers={"X-App-Code": "code1"}):
            response = PassportResource().get()

        assert response.get_json()["access_token"] == "reused-token"
        # Should not create a new end user
        mock_db.session.add.assert_not_called()

    @patch("controllers.web.passport.db")
    @patch("controllers.web.passport.FeatureService.get_system_features")
    def test_site_not_found_raises(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None:
        """An app code with no matching Site row yields 404."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
        mock_db.session.scalar.return_value = None
        with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
            with pytest.raises(NotFound):
                PassportResource().get()

    @patch("controllers.web.passport.db")
    @patch("controllers.web.passport.FeatureService.get_system_features")
    def test_disabled_app_raises_not_found(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None:
        """An app whose site is disabled (enable_site=False) yields 404."""
        mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
        site = SimpleNamespace(app_id="app-1", code="code1")
        disabled_app = SimpleNamespace(id="app-1", status="normal", enable_site=False)
        mock_db.session.scalar.side_effect = [site, disabled_app]
        with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
            with pytest.raises(NotFound):
                PassportResource().get()
|
||||
95
api/tests/unit_tests/controllers/web/test_workflow.py
Normal file
95
api/tests/unit_tests/controllers/web/test_workflow.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""Unit tests for controllers.web.workflow endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.error import (
|
||||
NotWorkflowAppError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.workflow import WorkflowRunApi, WorkflowTaskStopApi
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError
|
||||
|
||||
|
||||
def _workflow_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="workflow")
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowRunApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWorkflowRunApi:
    """Mode validation and error translation for the workflow-run endpoint."""

    def test_wrong_mode_raises(self, app: Flask) -> None:
        """Posting to the workflow endpoint from a chat-mode app is rejected."""
        with app.test_request_context("/workflows/run", method="POST"):
            with pytest.raises(NotWorkflowAppError):
                WorkflowRunApi().post(_chat_app(), _end_user())

    @patch("controllers.web.workflow.helper.compact_generate_response", return_value={"result": "ok"})
    @patch("controllers.web.workflow.AppGenerateService.generate")
    @patch("controllers.web.workflow.web_ns")
    def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
        """A workflow-mode app runs generation and returns the compacted response."""
        mock_ns.payload = {"inputs": {"key": "val"}}
        mock_gen.return_value = "response"

        with app.test_request_context("/workflows/run", method="POST"):
            result = WorkflowRunApi().post(_workflow_app(), _end_user())

        assert result == {"result": "ok"}

    @patch(
        "controllers.web.workflow.AppGenerateService.generate",
        side_effect=ProviderTokenNotInitError(description="not init"),
    )
    @patch("controllers.web.workflow.web_ns")
    def test_provider_not_init(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        """ProviderTokenNotInitError from the service is translated to the web-API error type."""
        mock_ns.payload = {"inputs": {}}

        with app.test_request_context("/workflows/run", method="POST"):
            with pytest.raises(ProviderNotInitializeError):
                WorkflowRunApi().post(_workflow_app(), _end_user())

    @patch(
        "controllers.web.workflow.AppGenerateService.generate",
        side_effect=QuotaExceededError(),
    )
    @patch("controllers.web.workflow.web_ns")
    def test_quota_exceeded(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
        """QuotaExceededError from the service is translated to the web-API error type."""
        mock_ns.payload = {"inputs": {}}

        with app.test_request_context("/workflows/run", method="POST"):
            with pytest.raises(ProviderQuotaExceededError):
                WorkflowRunApi().post(_workflow_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowTaskStopApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWorkflowTaskStopApi:
    """Stopping a running workflow task through the web API."""

    def test_wrong_mode_raises(self, app: Flask) -> None:
        """A chat-mode app may not stop workflow tasks."""
        with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
            with pytest.raises(NotWorkflowAppError):
                WorkflowTaskStopApi().post(_chat_app(), _end_user(), "task-1")

    @patch("controllers.web.workflow.GraphEngineManager.send_stop_command")
    @patch("controllers.web.workflow.AppQueueManager.set_stop_flag_no_user_check")
    def test_stop_calls_both_mechanisms(self, mock_legacy: MagicMock, mock_graph: MagicMock, app: Flask) -> None:
        """Stopping fans out to both the legacy queue flag and the graph engine."""
        with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
            outcome = WorkflowTaskStopApi().post(_workflow_app(), _end_user(), "task-1")

        assert outcome == {"result": "success"}
        mock_legacy.assert_called_once_with("task-1")
        mock_graph.assert_called_once_with("task-1")
|
||||
127
api/tests/unit_tests/controllers/web/test_workflow_events.py
Normal file
127
api/tests/unit_tests/controllers/web/test_workflow_events.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""Unit tests for controllers.web.workflow_events endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.error import NotFoundError
|
||||
from controllers.web.workflow_events import WorkflowEventsApi
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
def _workflow_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="workflow")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowEventsApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWorkflowEventsApi:
    """Ownership/visibility checks and SSE replay for the workflow-events endpoint."""

    @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
    @patch("controllers.web.workflow_events.db")
    def test_workflow_run_not_found(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
        """An unknown run id yields NotFoundError."""
        mock_db.engine = "engine"
        mock_repo = MagicMock()
        mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = None
        mock_factory.create_api_workflow_run_repository.return_value = mock_repo

        with app.test_request_context("/workflow/run-1/events"):
            with pytest.raises(NotFoundError):
                WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")

    @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
    @patch("controllers.web.workflow_events.db")
    def test_workflow_run_wrong_app(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
        """A run belonging to a different app is hidden (NotFoundError, not a 403)."""
        mock_db.engine = "engine"
        run = SimpleNamespace(
            id="run-1",
            app_id="other-app",
            created_by_role=CreatorUserRole.END_USER,
            created_by="eu-1",
            finished_at=None,
        )
        mock_repo = MagicMock()
        mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
        mock_factory.create_api_workflow_run_repository.return_value = mock_repo

        with app.test_request_context("/workflow/run-1/events"):
            with pytest.raises(NotFoundError):
                WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")

    @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
    @patch("controllers.web.workflow_events.db")
    def test_workflow_run_not_created_by_end_user(
        self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask
    ) -> None:
        """A run created by an account (not an end user) is hidden from web-app callers."""
        mock_db.engine = "engine"
        run = SimpleNamespace(
            id="run-1",
            app_id="app-1",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by="eu-1",
            finished_at=None,
        )
        mock_repo = MagicMock()
        mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
        mock_factory.create_api_workflow_run_repository.return_value = mock_repo

        with app.test_request_context("/workflow/run-1/events"):
            with pytest.raises(NotFoundError):
                WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")

    @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
    @patch("controllers.web.workflow_events.db")
    def test_workflow_run_wrong_end_user(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
        """A run created by a different end user is hidden from the caller."""
        mock_db.engine = "engine"
        run = SimpleNamespace(
            id="run-1",
            app_id="app-1",
            created_by_role=CreatorUserRole.END_USER,
            created_by="other-user",
            finished_at=None,
        )
        mock_repo = MagicMock()
        mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
        mock_factory.create_api_workflow_run_repository.return_value = mock_repo

        with app.test_request_context("/workflow/run-1/events"):
            with pytest.raises(NotFoundError):
                WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")

    @patch("controllers.web.workflow_events.WorkflowResponseConverter")
    @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
    @patch("controllers.web.workflow_events.db")
    def test_finished_run_returns_sse_response(
        self, mock_db: MagicMock, mock_factory: MagicMock, mock_converter: MagicMock, app: Flask
    ) -> None:
        """An already-finished run is replayed as a server-sent-events stream."""
        from datetime import datetime

        mock_db.engine = "engine"
        run = SimpleNamespace(
            id="run-1",
            app_id="app-1",
            created_by_role=CreatorUserRole.END_USER,
            created_by="eu-1",
            finished_at=datetime(2024, 1, 1),
        )
        mock_repo = MagicMock()
        mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
        mock_factory.create_api_workflow_run_repository.return_value = mock_repo

        finish_response = MagicMock()
        finish_response.model_dump.return_value = {"task_id": "run-1"}
        finish_response.event.value = "workflow_finished"
        mock_converter.workflow_run_result_to_finish_response.return_value = finish_response

        with app.test_request_context("/workflow/run-1/events"):
            response = WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")

        assert response.mimetype == "text/event-stream"
|
||||
393
api/tests/unit_tests/controllers/web/test_wraps.py
Normal file
393
api/tests/unit_tests/controllers/web/test_wraps.py
Normal file
@ -0,0 +1,393 @@
|
||||
"""Unit tests for controllers.web.wraps — JWT auth decorator and validation helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
|
||||
from controllers.web.wraps import (
|
||||
_validate_user_accessibility,
|
||||
_validate_webapp_token,
|
||||
decode_jwt_token,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_webapp_token
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestValidateWebappToken:
    """Token-source checks performed by _validate_webapp_token for public vs. auth-required apps."""

    def test_enterprise_enabled_and_app_auth_requires_webapp_source(self) -> None:
        """When both flags are true, a non-webapp source must raise."""
        with pytest.raises(WebAppAuthRequiredError):
            _validate_webapp_token(
                {"token_source": "other"}, app_web_auth_enabled=True, system_webapp_auth_enabled=True
            )

    def test_enterprise_enabled_and_app_auth_accepts_webapp_source(self) -> None:
        """A webapp-sourced token satisfies an auth-required app."""
        _validate_webapp_token({"token_source": "webapp"}, app_web_auth_enabled=True, system_webapp_auth_enabled=True)

    def test_enterprise_enabled_and_app_auth_missing_source_raises(self) -> None:
        """No token_source at all counts as non-webapp and is rejected."""
        with pytest.raises(WebAppAuthRequiredError):
            _validate_webapp_token({}, app_web_auth_enabled=True, system_webapp_auth_enabled=True)

    def test_public_app_rejects_webapp_source(self) -> None:
        """When auth is not required, a webapp-sourced token must be rejected."""
        with pytest.raises(Unauthorized):
            _validate_webapp_token(
                {"token_source": "webapp"}, app_web_auth_enabled=False, system_webapp_auth_enabled=False
            )

    def test_public_app_accepts_non_webapp_source(self) -> None:
        """An ordinary token passes on a fully public app."""
        _validate_webapp_token({"token_source": "other"}, app_web_auth_enabled=False, system_webapp_auth_enabled=False)

    def test_public_app_accepts_no_source(self) -> None:
        """A token with no source passes on a fully public app."""
        _validate_webapp_token({}, app_web_auth_enabled=False, system_webapp_auth_enabled=False)

    def test_system_enabled_but_app_public(self) -> None:
        """system_webapp_auth_enabled=True but app is public — webapp source rejected."""
        with pytest.raises(Unauthorized):
            _validate_webapp_token(
                {"token_source": "webapp"}, app_web_auth_enabled=False, system_webapp_auth_enabled=True
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_user_accessibility
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestValidateUserAccessibility:
    """Claim, SSO-freshness, and permission checks performed by _validate_user_accessibility."""

    def test_skips_when_auth_disabled(self) -> None:
        """No checks when system or app auth is disabled."""
        _validate_user_accessibility(
            decoded={},
            app_code="code",
            app_web_auth_enabled=False,
            system_webapp_auth_enabled=False,
            webapp_settings=None,
        )

    def test_missing_user_id_raises(self) -> None:
        """A decoded token without user_id is rejected when auth is required."""
        decoded = {}
        with pytest.raises(WebAppAuthRequiredError):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=SimpleNamespace(access_mode="internal"),
            )

    def test_missing_webapp_settings_raises(self) -> None:
        """Auth-required apps must have webapp settings available."""
        decoded = {"user_id": "u1"}
        with pytest.raises(WebAppAuthRequiredError, match="settings not found"):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=None,
            )

    def test_missing_auth_type_raises(self) -> None:
        """A token without the auth_type claim is denied access."""
        decoded = {"user_id": "u1", "granted_at": 1}
        settings = SimpleNamespace(access_mode="public")
        with pytest.raises(WebAppAuthAccessDeniedError, match="auth_type"):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=settings,
            )

    def test_missing_granted_at_raises(self) -> None:
        """A token without the granted_at claim is denied access."""
        decoded = {"user_id": "u1", "auth_type": "external"}
        settings = SimpleNamespace(access_mode="public")
        with pytest.raises(WebAppAuthAccessDeniedError, match="granted_at"):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=settings,
            )

    @patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time")
    @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
    def test_external_auth_type_checks_sso_update_time(
        self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
    ) -> None:
        """External tokens granted before the app's SSO settings changed are denied."""
        # granted_at is before SSO update time → denied
        mock_sso_time.return_value = datetime.now(UTC)
        old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
        decoded = {"user_id": "u1", "auth_type": "external", "granted_at": old_granted}
        settings = SimpleNamespace(access_mode="public")
        with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=settings,
            )

    @patch("controllers.web.wraps.EnterpriseService.get_workspace_sso_settings_last_update_time")
    @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
    def test_internal_auth_type_checks_workspace_sso_update_time(
        self, mock_perm_check: MagicMock, mock_workspace_sso: MagicMock
    ) -> None:
        """Internal tokens are validated against the workspace-level SSO update time instead."""
        mock_workspace_sso.return_value = datetime.now(UTC)
        old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
        decoded = {"user_id": "u1", "auth_type": "internal", "granted_at": old_granted}
        settings = SimpleNamespace(access_mode="public")
        with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=settings,
            )

    @patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time")
    @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
    def test_external_auth_passes_when_granted_after_sso_update(
        self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
    ) -> None:
        """A token granted after the last SSO settings change is still valid."""
        mock_sso_time.return_value = datetime.now(UTC) - timedelta(hours=2)
        recent_granted = int(datetime.now(UTC).timestamp())
        decoded = {"user_id": "u1", "auth_type": "external", "granted_at": recent_granted}
        settings = SimpleNamespace(access_mode="public")
        # Should not raise
        _validate_user_accessibility(
            decoded=decoded,
            app_code="code",
            app_web_auth_enabled=True,
            system_webapp_auth_enabled=True,
            webapp_settings=settings,
        )

    @patch("controllers.web.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", return_value=False)
    @patch("controllers.web.wraps.AppService.get_app_id_by_code", return_value="app-id-1")
    @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=True)
    def test_permission_check_denies_unauthorized_user(
        self, mock_perm: MagicMock, mock_app_id: MagicMock, mock_allowed: MagicMock
    ) -> None:
        """When per-app permission checks apply, a disallowed user is denied access."""
        decoded = {"user_id": "u1", "auth_type": "external", "granted_at": int(datetime.now(UTC).timestamp())}
        settings = SimpleNamespace(access_mode="internal")
        with pytest.raises(WebAppAuthAccessDeniedError):
            _validate_user_accessibility(
                decoded=decoded,
                app_code="code",
                app_web_auth_enabled=True,
                system_webapp_auth_enabled=True,
                webapp_settings=settings,
            )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# decode_jwt_token
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDecodeJwtToken:
|
||||
@patch("controllers.web.wraps._validate_user_accessibility")
|
||||
@patch("controllers.web.wraps._validate_webapp_token")
|
||||
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@patch("controllers.web.wraps.AppService.get_app_id_by_code")
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_happy_path(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
mock_app_id: MagicMock,
|
||||
mock_access_mode: MagicMock,
|
||||
mock_validate_token: MagicMock,
|
||||
mock_validate_user: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
|
||||
|
||||
# Configure session mock to return correct objects via scalar()
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, end_user]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
result_app, result_user = decode_jwt_token()
|
||||
|
||||
assert result_app.id == "app-1"
|
||||
assert result_user.id == "eu-1"
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
def test_missing_token_raises_unauthorized(
|
||||
self, mock_extract: MagicMock, mock_features: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
mock_extract.return_value = None
|
||||
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(Unauthorized):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_missing_app_raises_not_found(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.return_value = None # No app found
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_disabled_site_raises_bad_request(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=False)
|
||||
|
||||
session_mock = MagicMock()
|
||||
# scalar calls: app_model, site (code found), then end_user
|
||||
session_mock.scalar.side_effect = [app_model, SimpleNamespace(code="code1"), None]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(BadRequest, match="Site is disabled"):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_missing_end_user_raises_not_found(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, None] # end_user is None
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_user_id_mismatch_raises_unauthorized(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, end_user]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(Unauthorized, match="expired"):
|
||||
decode_jwt_token(user_id="different-user")
|
||||
@ -265,61 +265,6 @@ def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup.run()
|
||||
|
||||
|
||||
def test_run_records_metrics_on_success(monkeypatch: pytest.MonkeyPatch) -> None:
    """run() emits one batch metric for the deleted run and a 'success' completion metric."""
    # Counts of related rows the fake repo reports as deleted alongside the run.
    related_counts = {
        "runs": 0,
        "node_executions": 2,
        "offloads": 1,
        "app_logs": 3,
        "trigger_logs": 4,
        "pauses": 5,
        "pause_reasons": 6,
    }
    expired_run = FakeRun("run-free", "t_free", datetime.datetime.now())
    repo = FakeRepo(batches=[[expired_run]], delete_result=related_counts)
    cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
    monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)

    # Capture metric recorder calls instead of letting them hit real instrumentation.
    batch_calls: list[dict[str, object]] = []
    completion_calls: list[dict[str, object]] = []
    monkeypatch.setattr(cleanup._metrics, "record_batch", lambda **kw: batch_calls.append(kw))
    monkeypatch.setattr(cleanup._metrics, "record_completion", lambda **kw: completion_calls.append(kw))

    cleanup.run()

    # Exactly one batch of one run was targeted and fully deleted.
    assert len(batch_calls) == 1
    assert batch_calls[0]["batch_rows"] == 1
    assert batch_calls[0]["targeted_runs"] == 1
    assert batch_calls[0]["deleted_runs"] == 1
    assert batch_calls[0]["related_action"] == "deleted"
    assert len(completion_calls) == 1
    assert completion_calls[0]["status"] == "success"
|
||||
|
||||
|
||||
def test_run_records_failed_metrics(monkeypatch: pytest.MonkeyPatch) -> None:
    """When deletion blows up, run() re-raises and records a 'failed' completion metric."""

    class FailingRepo(FakeRepo):
        # Override the deletion hook so every delete attempt fails.
        def delete_runs_with_related(
            self, runs: list[FakeRun], delete_node_executions=None, delete_trigger_logs=None
        ) -> dict[str, int]:
            raise RuntimeError("delete failed")

    expired_run = FakeRun("run-free", "t_free", datetime.datetime.now())
    cleanup = create_cleanup(
        monkeypatch, repo=FailingRepo(batches=[[expired_run]]), days=30, batch_size=10
    )
    monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)

    completion_calls: list[dict[str, object]] = []
    monkeypatch.setattr(cleanup._metrics, "record_completion", lambda **kw: completion_calls.append(kw))

    with pytest.raises(RuntimeError, match="delete failed"):
        cleanup.run()

    assert len(completion_calls) == 1
    assert completion_calls[0]["status"] == "failed"
|
||||
|
||||
|
||||
def test_run_dry_run_skips_deletions(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
cutoff = datetime.datetime.now()
|
||||
repo = FakeRepo(
|
||||
|
||||
@ -619,53 +619,3 @@ class TestMessagesCleanServiceFromDays:
|
||||
assert service._end_before == expected_end_before
|
||||
assert service._batch_size == 1000 # default
|
||||
assert service._dry_run is False # default
|
||||
|
||||
|
||||
class TestMessagesCleanServiceRun:
    """Unit tests for MessagesCleanService.run instrumentation behavior."""

    @staticmethod
    def _build_service() -> MessagesCleanService:
        # Shared fixture: a service over a one-day window with billing disabled.
        return MessagesCleanService(
            policy=BillingDisabledPolicy(),
            start_from=datetime.datetime(2024, 1, 1),
            end_before=datetime.datetime(2024, 1, 2),
            batch_size=100,
            dry_run=False,
        )

    def test_run_records_completion_metrics_on_success(self):
        """run() returns the cleaner's stats and records a single 'success' completion."""
        service = self._build_service()
        expected_stats = {
            "batches": 1,
            "total_messages": 10,
            "filtered_messages": 5,
            "total_deleted": 5,
        }
        service._clean_messages_by_time_range = MagicMock(return_value=expected_stats)  # type: ignore[method-assign]
        completion_calls: list[dict[str, object]] = []
        service._metrics.record_completion = lambda **kw: completion_calls.append(kw)  # type: ignore[method-assign]

        result = service.run()

        assert result == expected_stats
        assert len(completion_calls) == 1
        assert completion_calls[0]["status"] == "success"

    def test_run_records_completion_metrics_on_failure(self):
        """run() propagates the cleaner's error and records a single 'failed' completion."""
        service = self._build_service()
        service._clean_messages_by_time_range = MagicMock(side_effect=RuntimeError("clean failed"))  # type: ignore[method-assign]
        completion_calls: list[dict[str, object]] = []
        service._metrics.record_completion = lambda **kw: completion_calls.append(kw)  # type: ignore[method-assign]

        with pytest.raises(RuntimeError, match="clean failed"):
            service.run()
        assert len(completion_calls) == 1
        assert completion_calls[0]["status"] == "failed"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user