mirror of
https://github.com/langgenius/dify.git
synced 2026-06-30 10:57:47 +08:00
Compare commits
9 Commits
deploy/dev
...
codex/repl
| Author | SHA1 | Date | |
|---|---|---|---|
| b252f4f2d3 | |||
| f148ab9a99 | |||
| 49a92f096f | |||
| fa1ac75922 | |||
| cb35c6fa98 | |||
| 34f62e7df6 | |||
| 07b5dcbb19 | |||
| 23917c7b3e | |||
| 8a6ce28855 |
@ -1,12 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import click
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from models import TenantAccountJoin, TenantAccountRole
|
||||
from services.enterprise.rbac_service import ListOption, RBACService
|
||||
|
||||
_LEGACY_ROLE_TO_BUILTIN_TAG = {
|
||||
TenantAccountRole.OWNER.value: "owner",
|
||||
TenantAccountRole.ADMIN.value: "admin",
|
||||
TenantAccountRole.EDITOR.value: "editor",
|
||||
TenantAccountRole.NORMAL.value: "normal",
|
||||
TenantAccountRole.DATASET_OPERATOR.value: "dataset_operator",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_builtin_role_ids(tenant_id: str, operator_account_id: str) -> dict[str, str]:
|
||||
"""Resolve every legacy workspace role to the current tenant's builtin RBAC role id.
|
||||
|
||||
The migration replays the old `TenantAccountJoin.role` values onto the
|
||||
RBAC member-role binding API. Builtin RBAC roles are tenant-scoped and
|
||||
identified by runtime ids, so the command must look them up per tenant.
|
||||
"""
|
||||
roles = RBACService.Roles.list(
|
||||
tenant_id=tenant_id,
|
||||
account_id=operator_account_id,
|
||||
options=ListOption(page_number=1, results_per_page=100),
|
||||
).data
|
||||
role_id_by_tag = {
|
||||
role.role_tag: role.id
|
||||
for role in roles
|
||||
if role.is_builtin and role.category == "global_system_default" and role.role_tag
|
||||
}
|
||||
resolved: dict[str, str] = {}
|
||||
for legacy_role, expected_builtin_tag in _LEGACY_ROLE_TO_BUILTIN_TAG.items():
|
||||
role_id = role_id_by_tag.get(expected_builtin_tag)
|
||||
if expected_builtin_tag == "dataset_operator" and not dify_config.DATASET_OPERATOR_ENABLED:
|
||||
continue
|
||||
if not role_id:
|
||||
raise ValueError(f"Builtin RBAC role not found for tenant={tenant_id}, legacy_role={legacy_role}")
|
||||
resolved[legacy_role] = role_id
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_builtin_role_id(tenant_id: str, operator_account_id: str, legacy_role: str) -> str:
|
||||
"""Resolve a legacy workspace role to the current tenant's builtin RBAC role id.
|
||||
@ -15,26 +55,86 @@ def _resolve_builtin_role_id(tenant_id: str, operator_account_id: str, legacy_ro
|
||||
RBAC member-role binding API. Builtin RBAC roles are tenant-scoped and
|
||||
identified by runtime ids, so the command must look them up per tenant.
|
||||
"""
|
||||
expected_builtin_tag = {
|
||||
TenantAccountRole.OWNER.value: "owner",
|
||||
TenantAccountRole.ADMIN.value: "admin",
|
||||
TenantAccountRole.EDITOR.value: "editor",
|
||||
TenantAccountRole.NORMAL.value: "normal",
|
||||
TenantAccountRole.DATASET_OPERATOR.value: "dataset_operator",
|
||||
}.get(legacy_role)
|
||||
if not expected_builtin_tag:
|
||||
if legacy_role not in _LEGACY_ROLE_TO_BUILTIN_TAG:
|
||||
raise ValueError(f"Unsupported legacy workspace role: {legacy_role}")
|
||||
|
||||
roles = RBACService.Roles.list(
|
||||
return _resolve_builtin_role_ids(tenant_id, operator_account_id)[legacy_role]
|
||||
|
||||
|
||||
def _iter_tenant_member_batches(
|
||||
tenant_id: str | None,
|
||||
*,
|
||||
db_batch_size: int,
|
||||
api_batch_size: int,
|
||||
) -> Iterator[tuple[str, str, list[tuple[str, str]]]]:
|
||||
"""Yield legacy member roles in tenant-scoped API-sized batches.
|
||||
|
||||
Rows are projected to primitive values and streamed from the database, so
|
||||
the command never materializes every TenantAccountJoin ORM object. The
|
||||
iterator only keeps one tenant's API-sized batches in memory while it
|
||||
finds that tenant's owner account.
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
stmt = (
|
||||
select(TenantAccountJoin.tenant_id, TenantAccountJoin.account_id, TenantAccountJoin.role)
|
||||
.order_by(TenantAccountJoin.tenant_id.asc(), TenantAccountJoin.id.asc())
|
||||
.execution_options(yield_per=db_batch_size)
|
||||
)
|
||||
if tenant_id:
|
||||
stmt = stmt.where(TenantAccountJoin.tenant_id == tenant_id)
|
||||
|
||||
current_tenant_id: str | None = None
|
||||
owner_account_id: str | None = None
|
||||
batches: list[list[tuple[str, str]]] = []
|
||||
batch: list[tuple[str, str]] = []
|
||||
|
||||
def flush_current_tenant() -> Iterator[tuple[str, str, list[tuple[str, str]]]]:
|
||||
if current_tenant_id is None:
|
||||
return
|
||||
if batch:
|
||||
batches.append(batch.copy())
|
||||
if not owner_account_id:
|
||||
raise ValueError(f"Workspace owner not found for tenant={current_tenant_id}")
|
||||
for item in batches:
|
||||
yield current_tenant_id, owner_account_id, item
|
||||
|
||||
for row in session.execute(stmt):
|
||||
workspace_id = str(row.tenant_id)
|
||||
if current_tenant_id is not None and workspace_id != current_tenant_id:
|
||||
yield from flush_current_tenant()
|
||||
owner_account_id = None
|
||||
batches = []
|
||||
batch = []
|
||||
current_tenant_id = workspace_id
|
||||
account_id = str(row.account_id)
|
||||
role = str(row.role)
|
||||
if role == TenantAccountRole.OWNER.value:
|
||||
owner_account_id = account_id
|
||||
batch.append((account_id, role))
|
||||
if len(batch) >= api_batch_size:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
yield from flush_current_tenant()
|
||||
|
||||
|
||||
def _member_already_has_role(current_roles_by_account_id: dict[str, set[str]], account_id: str, role_id: str) -> bool:
|
||||
return current_roles_by_account_id.get(account_id) == {role_id}
|
||||
|
||||
|
||||
def _replace_member_role(
|
||||
tenant_id: str,
|
||||
operator_account_id: str,
|
||||
member_account_id: str,
|
||||
role_id: str,
|
||||
) -> str:
|
||||
RBACService.MemberRoles.replace(
|
||||
tenant_id=tenant_id,
|
||||
account_id=operator_account_id,
|
||||
options=ListOption(page_number=1, results_per_page=100),
|
||||
).data
|
||||
for role in roles:
|
||||
if role.is_builtin and role.category == "global_system_default" and role.role_tag == expected_builtin_tag:
|
||||
return str(role.id)
|
||||
|
||||
raise ValueError(f"Builtin RBAC role not found for tenant={tenant_id}, legacy_role={legacy_role}")
|
||||
member_account_id=member_account_id,
|
||||
role_ids=[role_id],
|
||||
)
|
||||
return member_account_id
|
||||
|
||||
|
||||
@click.command(
|
||||
@ -42,7 +142,16 @@ def _resolve_builtin_role_id(tenant_id: str, operator_account_id: str, legacy_ro
|
||||
)
|
||||
@click.option("--tenant-id", help="Only migrate a single workspace.")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Preview the migration without writing RBAC bindings.")
|
||||
def migrate_member_roles_to_rbac(tenant_id: str | None, dry_run: bool) -> None:
|
||||
@click.option("--db-batch-size", default=5000, show_default=True, help="Rows fetched per database batch.")
|
||||
@click.option("--api-batch-size", default=200, show_default=True, help="Members checked per RBAC batch_get call.")
|
||||
@click.option("--workers", default=1, show_default=True, help="Concurrent member role replace calls per tenant batch.")
|
||||
def migrate_member_roles_to_rbac(
|
||||
tenant_id: str | None,
|
||||
dry_run: bool,
|
||||
db_batch_size: int,
|
||||
api_batch_size: int,
|
||||
workers: int,
|
||||
) -> None:
|
||||
"""Backfill RBAC member-role bindings from legacy `TenantAccountJoin.role` data.
|
||||
|
||||
This is an offline migration command for workspaces that already have
|
||||
@ -50,63 +159,102 @@ def migrate_member_roles_to_rbac(tenant_id: str | None, dry_run: bool) -> None:
|
||||
member-role binding store.
|
||||
"""
|
||||
click.echo(click.style("Starting RBAC member-role migration.", fg="green"))
|
||||
if workers < 1:
|
||||
raise click.BadParameter("workers must be >= 1", param_hint="--workers")
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(TenantAccountJoin).order_by(TenantAccountJoin.tenant_id.asc(), TenantAccountJoin.id.asc())
|
||||
if tenant_id:
|
||||
stmt = stmt.where(TenantAccountJoin.tenant_id == tenant_id)
|
||||
tenant_count = 0
|
||||
scanned_count = 0
|
||||
skipped_count = 0
|
||||
migrated_count = 0
|
||||
current_tenant_id: str | None = None
|
||||
role_ids_by_legacy_role: dict[str, str] = {}
|
||||
|
||||
joins = list(session.scalars(stmt).all())
|
||||
for workspace_id, owner_account_id, batch in _iter_tenant_member_batches(
|
||||
tenant_id,
|
||||
db_batch_size=db_batch_size,
|
||||
api_batch_size=api_batch_size,
|
||||
):
|
||||
scanned_count += len(batch)
|
||||
if workspace_id != current_tenant_id:
|
||||
tenant_count += 1
|
||||
current_tenant_id = workspace_id
|
||||
role_ids_by_legacy_role = _resolve_builtin_role_ids(workspace_id, owner_account_id)
|
||||
click.echo(f"tenant={workspace_id}")
|
||||
|
||||
if not joins:
|
||||
current_roles_by_account_id: dict[str, set[str]] = {}
|
||||
if not dry_run:
|
||||
current_roles = RBACService.MemberRoles.batch_get(
|
||||
tenant_id=workspace_id,
|
||||
account_id=owner_account_id,
|
||||
member_account_ids=[account_id for account_id, _ in batch],
|
||||
)
|
||||
current_roles_by_account_id = {
|
||||
item.account_id: {str(role.id) for role in item.roles} for item in current_roles
|
||||
}
|
||||
|
||||
replace_jobs: list[tuple[str, str]] = []
|
||||
for member_account_id, legacy_role in batch:
|
||||
resolved_role_id = role_ids_by_legacy_role.get(legacy_role)
|
||||
if not resolved_role_id:
|
||||
raise ValueError(f"Unsupported legacy workspace role: {legacy_role}")
|
||||
|
||||
if dry_run:
|
||||
click.echo(
|
||||
f"tenant={workspace_id} member={member_account_id} "
|
||||
f"legacy_role={legacy_role} -> rbac_role_id={resolved_role_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
if _member_already_has_role(current_roles_by_account_id, member_account_id, resolved_role_id):
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
replace_jobs.append((member_account_id, resolved_role_id))
|
||||
|
||||
if replace_jobs:
|
||||
if workers == 1:
|
||||
for member_account_id, resolved_role_id in replace_jobs:
|
||||
_replace_member_role(workspace_id, owner_account_id, member_account_id, resolved_role_id)
|
||||
migrated_count += 1
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
futures = [
|
||||
executor.submit(
|
||||
_replace_member_role,
|
||||
workspace_id,
|
||||
owner_account_id,
|
||||
member_account_id,
|
||||
resolved_role_id,
|
||||
)
|
||||
for member_account_id, resolved_role_id in replace_jobs
|
||||
]
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
migrated_count += 1
|
||||
|
||||
if scanned_count % 10000 == 0:
|
||||
click.echo(
|
||||
f"progress scanned={scanned_count} migrated={migrated_count} skipped={skipped_count}",
|
||||
err=True,
|
||||
)
|
||||
|
||||
if scanned_count == 0:
|
||||
click.echo(click.style("No workspace members found for migration.", fg="yellow"))
|
||||
return
|
||||
|
||||
owner_account_by_tenant: dict[str, str] = {}
|
||||
resolved_role_ids: dict[tuple[str, str], str] = {}
|
||||
migrated_count = 0
|
||||
|
||||
for join in joins:
|
||||
workspace_id = str(join.tenant_id)
|
||||
member_account_id = str(join.account_id)
|
||||
legacy_role = str(join.role)
|
||||
|
||||
if workspace_id not in owner_account_by_tenant:
|
||||
owner_join = next(
|
||||
(
|
||||
item
|
||||
for item in joins
|
||||
if str(item.tenant_id) == workspace_id and str(item.role) == TenantAccountRole.OWNER.value
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not owner_join:
|
||||
raise ValueError(f"Workspace owner not found for tenant={workspace_id}")
|
||||
owner_account_by_tenant[workspace_id] = str(owner_join.account_id)
|
||||
|
||||
operator_account_id = owner_account_by_tenant[workspace_id]
|
||||
cache_key = (workspace_id, legacy_role)
|
||||
if cache_key not in resolved_role_ids:
|
||||
resolved_role_ids[cache_key] = _resolve_builtin_role_id(workspace_id, operator_account_id, legacy_role)
|
||||
|
||||
resolved_role_id = resolved_role_ids[cache_key]
|
||||
click.echo(
|
||||
f"tenant={workspace_id} member={member_account_id} "
|
||||
f"legacy_role={legacy_role} -> rbac_role_id={resolved_role_id}"
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
continue
|
||||
|
||||
RBACService.MemberRoles.replace(
|
||||
tenant_id=workspace_id,
|
||||
account_id=operator_account_id,
|
||||
member_account_id=member_account_id,
|
||||
role_ids=[resolved_role_id],
|
||||
)
|
||||
migrated_count += 1
|
||||
|
||||
if dry_run:
|
||||
click.echo(click.style("Dry run completed. No RBAC bindings were written.", fg="yellow"))
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Dry run completed. Scanned {scanned_count} members across {tenant_count} tenants. "
|
||||
"No RBAC bindings were written.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(click.style(f"RBAC member-role migration completed. Migrated {migrated_count} members.", fg="green"))
|
||||
click.echo(
|
||||
click.style(
|
||||
f"RBAC member-role migration completed. Scanned {scanned_count} members across {tenant_count} tenants, "
|
||||
f"migrated {migrated_count}, skipped {skipped_count} already up-to-date.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
@ -144,7 +144,7 @@ class AnalyticdbVectorBySql:
|
||||
f"id text PRIMARY KEY,"
|
||||
f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
|
||||
f"to_tsvector TSVECTOR"
|
||||
f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
|
||||
f") DISTRIBUTED BY (id);"
|
||||
)
|
||||
if embedding_dimension is not None:
|
||||
index_name = f"{self._collection_name}_embedding_idx"
|
||||
@ -153,7 +153,7 @@ class AnalyticdbVectorBySql:
|
||||
cur.execute(
|
||||
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
|
||||
f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
|
||||
f"pq_enable=0, external_storage=0)"
|
||||
f"pq_enable=0)"
|
||||
)
|
||||
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
|
||||
except Exception as e:
|
||||
|
||||
@ -435,6 +435,7 @@ _LEGACY_APP_EDITOR_KEYS: list[str] = [
|
||||
"app.acl.delete",
|
||||
"app.acl.release_and_version",
|
||||
"app.acl.monitor",
|
||||
"app.acl.log_and_annotation",
|
||||
"app.acl.access_config",
|
||||
]
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@ from services.errors.app import QuotaExceededError
|
||||
from services.quota_service import QuotaService
|
||||
from services.trigger.app_trigger_service import AppTriggerService
|
||||
from services.workflow.entities import WebhookTriggerData
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
try:
|
||||
import magic
|
||||
@ -114,6 +115,7 @@ class WebhookService:
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == webhook_trigger.tenant_id,
|
||||
Workflow.app_id == webhook_trigger.app_id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
@ -125,6 +127,7 @@ class WebhookService:
|
||||
app_trigger = session.scalar(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.tenant_id == webhook_trigger.tenant_id,
|
||||
AppTrigger.app_id == webhook_trigger.app_id,
|
||||
AppTrigger.node_id == webhook_trigger.node_id,
|
||||
AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK,
|
||||
@ -145,16 +148,18 @@ class WebhookService:
|
||||
if app_trigger.status != AppTriggerStatus.ENABLED:
|
||||
raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}")
|
||||
|
||||
# Get workflow
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
app = session.scalar(
|
||||
select(App)
|
||||
.where(
|
||||
Workflow.app_id == webhook_trigger.app_id,
|
||||
Workflow.version != Workflow.VERSION_DRAFT,
|
||||
App.tenant_id == webhook_trigger.tenant_id,
|
||||
App.id == webhook_trigger.app_id,
|
||||
)
|
||||
.order_by(Workflow.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
if not app:
|
||||
raise ValueError(f"App not found for webhook {webhook_id}")
|
||||
|
||||
workflow = WorkflowService().get_published_workflow(app, session=session)
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}")
|
||||
|
||||
|
||||
@ -333,7 +333,7 @@ class VectorService:
|
||||
|
||||
# Add documents to vector store if any
|
||||
if documents and dataset.is_multimodal:
|
||||
vector.add_texts(documents, duplicate_check=True)
|
||||
vector.create_multimodal(documents)
|
||||
|
||||
# Single commit for all operations
|
||||
db.session.commit()
|
||||
|
||||
@ -12,7 +12,7 @@ from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
@ -35,7 +35,7 @@ from models.enums import (
|
||||
WorkflowRunTriggeredFrom,
|
||||
WorkflowTriggerStatus,
|
||||
)
|
||||
from models.model import EndUser
|
||||
from models.model import App, EndUser
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerSubscription, WorkflowPluginTrigger, WorkflowTriggerLog
|
||||
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowRun
|
||||
@ -99,23 +99,25 @@ def dispatch_trigger_debug_event(
|
||||
return 0
|
||||
|
||||
|
||||
def _get_latest_workflows_by_app_ids(
|
||||
def _get_published_workflows_by_app_ids(
|
||||
session: Session, subscribers: Sequence[WorkflowPluginTrigger]
|
||||
) -> Mapping[str, Workflow]:
|
||||
"""Get the latest workflows by app_ids"""
|
||||
workflow_query = (
|
||||
select(Workflow.app_id, func.max(Workflow.created_at).label("max_created_at"))
|
||||
.where(
|
||||
Workflow.app_id.in_({t.app_id for t in subscribers}),
|
||||
Workflow.version != Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.group_by(Workflow.app_id)
|
||||
.subquery()
|
||||
)
|
||||
"""Get current published workflows through apps.workflow_id."""
|
||||
app_ids = {trigger.app_id for trigger in subscribers}
|
||||
tenant_ids = {trigger.tenant_id for trigger in subscribers}
|
||||
if not app_ids or not tenant_ids:
|
||||
return {}
|
||||
|
||||
workflows = session.scalars(
|
||||
select(Workflow).join(
|
||||
workflow_query,
|
||||
(Workflow.app_id == workflow_query.c.app_id) & (Workflow.created_at == workflow_query.c.max_created_at),
|
||||
select(Workflow)
|
||||
.join(App, App.workflow_id == Workflow.id)
|
||||
.where(
|
||||
App.id.in_(app_ids),
|
||||
App.tenant_id.in_(tenant_ids),
|
||||
App.workflow_id.isnot(None),
|
||||
Workflow.app_id == App.id,
|
||||
Workflow.tenant_id == App.tenant_id,
|
||||
Workflow.version != Workflow.VERSION_DRAFT,
|
||||
)
|
||||
).all()
|
||||
return {w.app_id: w for w in workflows}
|
||||
@ -262,7 +264,7 @@ def dispatch_triggered_workflow(
|
||||
|
||||
# Ensure expire_on_commit is set to False to remain workflows available
|
||||
with session_factory.create_session() as session:
|
||||
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
|
||||
workflows: Mapping[str, Workflow] = _get_published_workflows_by_app_ids(session, subscribers)
|
||||
|
||||
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(
|
||||
type=EndUserType.TRIGGER,
|
||||
|
||||
@ -127,6 +127,9 @@ class TestWebhookService:
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
app.workflow_id = workflow.id
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create webhook trigger
|
||||
webhook_id = fake.uuid4()[:16]
|
||||
webhook_trigger = WorkflowWebhookTrigger(
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
@ -240,6 +241,40 @@ class TestWebhookServiceLookupWithContainers:
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_uses_app_workflow_id(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers: Flask
|
||||
):
|
||||
del flask_app_with_containers
|
||||
factory = WebhookServiceRelationshipFactory
|
||||
account, tenant = factory.create_account_and_tenant(db_session_with_containers)
|
||||
app = factory.create_app(db_session_with_containers, tenant, account)
|
||||
current_workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001"
|
||||
)
|
||||
newer_workflow = factory.create_workflow(
|
||||
db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-15.001"
|
||||
)
|
||||
current_workflow.created_at = datetime(2026, 4, 14)
|
||||
newer_workflow.created_at = datetime(2026, 4, 15)
|
||||
app.workflow_id = current_workflow.id
|
||||
db_session_with_containers.commit()
|
||||
|
||||
webhook_trigger = factory.create_webhook_trigger(
|
||||
db_session_with_containers, app=app, account=account, node_id="node-1"
|
||||
)
|
||||
factory.create_app_trigger(
|
||||
db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.ENABLED
|
||||
)
|
||||
|
||||
got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow(
|
||||
webhook_trigger.webhook_id
|
||||
)
|
||||
|
||||
assert got_trigger.id == webhook_trigger.id
|
||||
assert got_workflow.id == current_workflow.id
|
||||
assert got_workflow.id != newer_workflow.id
|
||||
assert got_node_config["id"] == "node-1"
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow(
|
||||
self, db_session_with_containers: Session, flask_app_with_containers: Flask
|
||||
):
|
||||
|
||||
@ -633,6 +633,8 @@ class TestMyPermissions:
|
||||
assert "dataset.acl.preview" in out.workspace.permission_keys
|
||||
assert "app.acl.preview" in out.app.default_permission_keys
|
||||
assert "dataset.acl.preview" in out.dataset.default_permission_keys
|
||||
if role == "editor":
|
||||
assert "app.acl.log_and_annotation" in out.app.default_permission_keys
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "expected_snippet_keys"),
|
||||
|
||||
@ -639,8 +639,8 @@ def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_up
|
||||
assert len(bindings) == 1
|
||||
assert bindings[0]["attachment_id"] == "file-1"
|
||||
|
||||
vector_instance.add_texts.assert_called_once()
|
||||
documents = vector_instance.add_texts.call_args.args[0]
|
||||
vector_instance.create_multimodal.assert_called_once()
|
||||
documents = vector_instance.create_multimodal.call_args.args[0]
|
||||
assert len(documents) == 1
|
||||
assert documents[0].page_content == "img.png"
|
||||
assert documents[0].metadata["doc_id"] == "file-1"
|
||||
|
||||
@ -98,7 +98,7 @@ class TestDispatchTriggeredWorkflow:
|
||||
),
|
||||
patch.object(
|
||||
trigger_processing_tasks_module,
|
||||
"_get_latest_workflows_by_app_ids",
|
||||
"_get_published_workflows_by_app_ids",
|
||||
) as get_workflows,
|
||||
patch.object(
|
||||
trigger_processing_tasks_module.EndUserService,
|
||||
|
||||
@ -2430,11 +2430,6 @@
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"web/app/components/base/voice-input/utils.ts": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 4
|
||||
}
|
||||
},
|
||||
"web/app/components/billing/plan/assets/index.tsx": {
|
||||
"no-barrel-files/no-barrel-files": {
|
||||
"count": 4
|
||||
@ -7449,11 +7444,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"web/types/lamejs.d.ts": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 3
|
||||
}
|
||||
},
|
||||
"web/types/pipeline.tsx": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 3
|
||||
|
||||
62
pnpm-lock.yaml
generated
62
pnpm-lock.yaml
generated
@ -84,6 +84,9 @@ catalogs:
|
||||
'@mdx-js/rollup':
|
||||
specifier: 3.1.1
|
||||
version: 3.1.1
|
||||
'@mediabunny/mp3-encoder':
|
||||
specifier: 1.49.0
|
||||
version: 1.49.0
|
||||
'@monaco-editor/react':
|
||||
specifier: 4.7.0
|
||||
version: 4.7.0
|
||||
@ -429,9 +432,6 @@ catalogs:
|
||||
ky:
|
||||
specifier: 2.0.2
|
||||
version: 2.0.2
|
||||
lamejs:
|
||||
specifier: 1.2.1
|
||||
version: 1.2.1
|
||||
lexical:
|
||||
specifier: 0.46.0
|
||||
version: 0.46.0
|
||||
@ -441,6 +441,9 @@ catalogs:
|
||||
loro-crdt:
|
||||
specifier: 1.13.6
|
||||
version: 1.13.6
|
||||
mediabunny:
|
||||
specifier: 1.49.0
|
||||
version: 1.49.0
|
||||
mermaid:
|
||||
specifier: 11.16.0
|
||||
version: 11.16.0
|
||||
@ -1108,6 +1111,9 @@ importers:
|
||||
'@lexical/utils':
|
||||
specifier: 'catalog:'
|
||||
version: 0.46.0(typescript@6.0.3)
|
||||
'@mediabunny/mp3-encoder':
|
||||
specifier: 'catalog:'
|
||||
version: 1.49.0(mediabunny@1.49.0)
|
||||
'@monaco-editor/react':
|
||||
specifier: 'catalog:'
|
||||
version: 4.7.0(react-dom@19.2.7(react@19.2.7))(react@19.2.7)
|
||||
@ -1267,15 +1273,15 @@ importers:
|
||||
ky:
|
||||
specifier: 'catalog:'
|
||||
version: 2.0.2
|
||||
lamejs:
|
||||
specifier: 'catalog:'
|
||||
version: 1.2.1
|
||||
lexical:
|
||||
specifier: 'catalog:'
|
||||
version: 0.46.0(typescript@6.0.3)
|
||||
loro-crdt:
|
||||
specifier: 'catalog:'
|
||||
version: 1.13.6
|
||||
mediabunny:
|
||||
specifier: 'catalog:'
|
||||
version: 1.49.0
|
||||
mermaid:
|
||||
specifier: 'catalog:'
|
||||
version: 11.16.0
|
||||
@ -2983,6 +2989,11 @@ packages:
|
||||
peerDependencies:
|
||||
rollup: 4.62.2
|
||||
|
||||
'@mediabunny/mp3-encoder@1.49.0':
|
||||
resolution: {integrity: sha512-b2US4tYGGou3Yo4nZnCU3NoIu8DPW6UO6WFz5KE7gbVH3iCQRT8hSg6nTAVHtO+isvm9Vtl0fsF0KnSVpAPwRQ==}
|
||||
peerDependencies:
|
||||
mediabunny: ^1.0.0
|
||||
|
||||
'@mermaid-js/parser@1.2.0':
|
||||
resolution: {integrity: sha512-oYPyv8A4As1yH5Bx+04iQEQxXuIQDe0GKCNSRgao6z8AM9jixXIfP0vsppRLvGf+nKIOb9/LdpWA4YuJiVvESA==}
|
||||
|
||||
@ -4886,6 +4897,12 @@ packages:
|
||||
'@types/doctrine@0.0.9':
|
||||
resolution: {integrity: sha512-eOIHzCUSH7SMfonMG1LsC2f8vxBFtho6NGBznK41R84YzPuvSBzrhEps33IsQiOW9+VL6NQ9DbjQJznk/S4uRA==}
|
||||
|
||||
'@types/dom-mediacapture-transform@0.1.11':
|
||||
resolution: {integrity: sha512-Y2p+nGf1bF2XMttBnsVPHUWzRRZzqUoJAKmiP10b5umnO6DDrWI0BrGDJy1pOHoOULVmGSfFNkQrAlC5dcj6nQ==}
|
||||
|
||||
'@types/dom-webcodecs@0.1.13':
|
||||
resolution: {integrity: sha512-O5hkiFIcjjszPIYyUSyvScyvrBoV3NOEEZx/pMlsu44TKzWNkLVBBxnxJz42in5n3QIolYOcBYFCPZZ0h8SkwQ==}
|
||||
|
||||
'@types/esrecurse@4.3.1':
|
||||
resolution: {integrity: sha512-xJBAbDifo5hpffDBuHl0Y8ywswbiAp/Wi7Y/GtAgSlZyIABppyurxVueOPE8LUQOxdlgi6Zqce7uoEpqNTeiUw==}
|
||||
|
||||
@ -7492,9 +7509,6 @@ packages:
|
||||
resolution: {integrity: sha512-/GmXpo9F9W+f8n4Ivr2iH+7h7wL7jLbLKWkMlpflcCRb6kGjBfTlASEXaZ9qUgNTn4VgS0P2pwxxzQ4EM6Ulgg==}
|
||||
engines: {node: '>=22'}
|
||||
|
||||
lamejs@1.2.1:
|
||||
resolution: {integrity: sha512-s7bxvjvYthw6oPLCm5pFxvA84wUROODB8jEO2+CE1adhKgrIvVOlmMgY8zyugxGrvRaDHNJanOiS21/emty6dQ==}
|
||||
|
||||
language-subtag-registry@0.3.23:
|
||||
resolution: {integrity: sha512-0K65Lea881pHotoGEa5gDlMxt3pctLi2RplBb7Ezh4rRdLEOtgi7n4EwK9lamnUCkKBqaeKRVebTq6BAxSkpXQ==}
|
||||
|
||||
@ -7764,6 +7778,9 @@ packages:
|
||||
mdn-data@2.28.1:
|
||||
resolution: {integrity: sha512-U9w+PzSZ00Z5m9rZ5ARVFL5xOfuCHdKYi/1RRwDCJsboFgJDNT3zT6PIPD7mZQYaQLhsZM3GfDRgSMRHhSmVng==}
|
||||
|
||||
mediabunny@1.49.0:
|
||||
resolution: {integrity: sha512-hDMtS/q22GFjyB3Yum+a6NHhDtVnye6hgaYTjx3DPUfnckRjzxFmXmCc3UQXdvo7d6PUCl7KM6Y9MRWzlvnAJQ==}
|
||||
|
||||
merge2@1.4.1:
|
||||
resolution: {integrity: sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==}
|
||||
engines: {node: '>= 8'}
|
||||
@ -9576,9 +9593,6 @@ packages:
|
||||
'@types/react':
|
||||
optional: true
|
||||
|
||||
use-strict@1.0.1:
|
||||
resolution: {integrity: sha512-IeiWvvEXfW5ltKVMkxq6FvNf2LojMKvB2OCeja6+ct24S1XOmQw2dGr2JyndwACWAGJva9B7yPHwAmeA9QCqAQ==}
|
||||
|
||||
use-sync-external-store@1.6.0:
|
||||
resolution: {integrity: sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==}
|
||||
peerDependencies:
|
||||
@ -11709,6 +11723,10 @@ snapshots:
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@mediabunny/mp3-encoder@1.49.0(mediabunny@1.49.0)':
|
||||
dependencies:
|
||||
mediabunny: 1.49.0
|
||||
|
||||
'@mermaid-js/parser@1.2.0':
|
||||
dependencies:
|
||||
'@chevrotain/types': 11.1.2
|
||||
@ -13340,6 +13358,12 @@ snapshots:
|
||||
|
||||
'@types/doctrine@0.0.9': {}
|
||||
|
||||
'@types/dom-mediacapture-transform@0.1.11':
|
||||
dependencies:
|
||||
'@types/dom-webcodecs': 0.1.13
|
||||
|
||||
'@types/dom-webcodecs@0.1.13': {}
|
||||
|
||||
'@types/esrecurse@4.3.1': {}
|
||||
|
||||
'@types/estree-jsx@1.0.5':
|
||||
@ -16768,10 +16792,6 @@ snapshots:
|
||||
|
||||
ky@2.0.2: {}
|
||||
|
||||
lamejs@1.2.1:
|
||||
dependencies:
|
||||
use-strict: 1.0.1
|
||||
|
||||
language-subtag-registry@0.3.23: {}
|
||||
|
||||
language-tags@1.0.9:
|
||||
@ -17141,6 +17161,11 @@ snapshots:
|
||||
|
||||
mdn-data@2.28.1: {}
|
||||
|
||||
mediabunny@1.49.0:
|
||||
dependencies:
|
||||
'@types/dom-mediacapture-transform': 0.1.11
|
||||
'@types/dom-webcodecs': 0.1.13
|
||||
|
||||
merge2@1.4.1: {}
|
||||
|
||||
mermaid@11.16.0:
|
||||
@ -19426,8 +19451,6 @@ snapshots:
|
||||
optionalDependencies:
|
||||
'@types/react': 19.2.17
|
||||
|
||||
use-strict@1.0.1: {}
|
||||
|
||||
use-sync-external-store@1.6.0(react@19.2.7):
|
||||
dependencies:
|
||||
react: 19.2.7
|
||||
@ -20066,6 +20089,7 @@ time:
|
||||
'@mdx-js/loader@3.1.1': '2025-08-29T18:03:05.606Z'
|
||||
'@mdx-js/react@3.1.1': '2025-08-29T18:02:56.462Z'
|
||||
'@mdx-js/rollup@3.1.1': '2025-08-29T18:03:10.680Z'
|
||||
'@mediabunny/mp3-encoder@1.49.0': '2026-06-18T17:50:53.396Z'
|
||||
'@monaco-editor/react@4.7.0': '2025-02-13T16:13:41.390Z'
|
||||
'@napi-rs/keyring@1.3.0': '2026-04-30T09:56:44.246Z'
|
||||
'@next/eslint-plugin-next@16.2.9': '2026-06-09T23:01:50.881Z'
|
||||
@ -20182,11 +20206,11 @@ time:
|
||||
katex@0.17.0: '2026-05-22T08:06:26.967Z'
|
||||
knip@6.22.0: '2026-06-27T13:30:51.275Z'
|
||||
ky@2.0.2: '2026-04-21T08:58:46.923Z'
|
||||
lamejs@1.2.1: '2021-12-02T15:44:40.036Z'
|
||||
lexical-code-no-prism@0.41.0: '2026-03-08T16:50:40.266Z'
|
||||
lexical@0.46.0: '2026-06-26T04:53:00.532Z'
|
||||
lockfile@1.0.4: '2018-04-17T00:36:12.565Z'
|
||||
loro-crdt@1.13.6: '2026-06-21T15:40:04.671Z'
|
||||
mediabunny@1.49.0: '2026-06-18T17:50:36.336Z'
|
||||
mermaid@11.16.0: '2026-06-25T11:30:40.280Z'
|
||||
mime@4.1.0: '2025-09-12T17:53:01.376Z'
|
||||
mitt@3.0.1: '2023-07-04T17:31:47.638Z'
|
||||
|
||||
@ -78,6 +78,7 @@ catalog:
|
||||
'@mdx-js/loader': 3.1.1
|
||||
'@mdx-js/react': 3.1.1
|
||||
'@mdx-js/rollup': 3.1.1
|
||||
'@mediabunny/mp3-encoder': 1.49.0
|
||||
'@monaco-editor/react': 4.7.0
|
||||
'@napi-rs/keyring': 1.3.0
|
||||
'@next/eslint-plugin-next': 16.2.9
|
||||
@ -193,10 +194,10 @@ catalog:
|
||||
katex: 0.17.0
|
||||
knip: 6.22.0
|
||||
ky: 2.0.2
|
||||
lamejs: 1.2.1
|
||||
lexical: 0.46.0
|
||||
lockfile: 1.0.4
|
||||
loro-crdt: 1.13.6
|
||||
mediabunny: 1.49.0
|
||||
mermaid: 11.16.0
|
||||
mime: 4.1.0
|
||||
mitt: 3.0.1
|
||||
|
||||
@ -232,7 +232,20 @@ describe('Billing Page + Plan Integration', () => {
|
||||
|
||||
// Verify billing URL button visibility and behavior
|
||||
describe('Billing URL button', () => {
|
||||
it('should show billing button when subscription management permission is granted', () => {
|
||||
it('should show billing button when manager has subscription management permission', () => {
|
||||
setupProviderContext({ type: Plan.sandbox })
|
||||
setupAppContext({
|
||||
isCurrentWorkspaceManager: true,
|
||||
workspacePermissionKeys: ['billing.subscription.manage'],
|
||||
})
|
||||
|
||||
render(<Billing />)
|
||||
|
||||
expect(screen.getByText(/viewBillingTitle/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/viewBillingAction/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide billing button when subscription management permission is granted without manager role', () => {
|
||||
setupProviderContext({ type: Plan.sandbox })
|
||||
setupAppContext({
|
||||
isCurrentWorkspaceManager: false,
|
||||
@ -241,8 +254,7 @@ describe('Billing Page + Plan Integration', () => {
|
||||
|
||||
render(<Billing />)
|
||||
|
||||
expect(screen.getByText(/viewBillingTitle/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/viewBillingAction/i)).toBeInTheDocument()
|
||||
expect(screen.queryByText(/viewBillingTitle/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide billing button when subscription management permission is missing', () => {
|
||||
|
||||
@ -21,6 +21,7 @@ let mockChatConversationDetail: Record<string, unknown> | undefined
|
||||
let mockCompletionConversationDetail: Record<string, unknown> | undefined
|
||||
let mockShowMessageLogModal = false
|
||||
let mockShowPromptLogModal = false
|
||||
let mockShowAgentLogModal = false
|
||||
let mockCurrentLogItem: Record<string, unknown> | undefined
|
||||
let mockCurrentLogModalActiveTab = 'messages'
|
||||
|
||||
@ -81,6 +82,7 @@ vi.mock('@/app/components/app/store', () => ({
|
||||
setShowAgentLogModal: mockSetShowAgentLogModal,
|
||||
setShowMessageLogModal: mockSetShowMessageLogModal,
|
||||
showPromptLogModal: mockShowPromptLogModal,
|
||||
showAgentLogModal: mockShowAgentLogModal,
|
||||
currentLogModalActiveTab: mockCurrentLogModalActiveTab,
|
||||
}),
|
||||
}))
|
||||
@ -126,6 +128,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({
|
||||
onAnnotationEdited,
|
||||
onAnnotationRemoved,
|
||||
switchSibling,
|
||||
hideLogModal,
|
||||
}: {
|
||||
chatList: Array<{ id: string }>
|
||||
onFeedback: (mid: string, value: { rating: string, content?: string }) => Promise<boolean>
|
||||
@ -133,8 +136,9 @@ vi.mock('@/app/components/base/chat/chat', () => ({
|
||||
onAnnotationEdited: (query: string, answer: string, index: number) => void
|
||||
onAnnotationRemoved: (index: number) => Promise<boolean>
|
||||
switchSibling: (siblingMessageId: string) => void
|
||||
hideLogModal?: boolean
|
||||
}) => (
|
||||
<div data-testid="chat-panel">
|
||||
<div data-testid="chat-panel" data-hide-log-modal={String(hideLogModal)}>
|
||||
<div>{chatList.length}</div>
|
||||
<button onClick={() => void onFeedback('message-1', { rating: 'like', content: 'nice' })}>chat-feedback</button>
|
||||
<button onClick={() => onAnnotationAdded('annotation-2', 'Admin', 'Edited question', 'Edited answer', 1)}>chat-add-annotation</button>
|
||||
@ -145,6 +149,14 @@ vi.mock('@/app/components/base/chat/chat', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/agent-log-modal', () => ({
|
||||
default: ({ floating, onCancel }: { floating?: boolean, onCancel: () => void }) => (
|
||||
<div data-testid="agent-log-modal" data-floating={String(floating)}>
|
||||
<button onClick={onCancel}>close-agent-log-modal</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/message-log-modal', () => ({
|
||||
default: ({ onCancel }: { onCancel: () => void }) => (
|
||||
<div data-testid="message-log-modal">
|
||||
@ -255,6 +267,7 @@ describe('ConversationList', () => {
|
||||
mockCompletionConversationDetail = undefined
|
||||
mockShowMessageLogModal = false
|
||||
mockShowPromptLogModal = false
|
||||
mockShowAgentLogModal = false
|
||||
mockCurrentLogItem = undefined
|
||||
mockCurrentLogModalActiveTab = 'messages'
|
||||
mockDelAnnotation.mockResolvedValue(undefined)
|
||||
@ -383,6 +396,7 @@ describe('ConversationList', () => {
|
||||
|
||||
expect(screen.getByTestId('var-panel')).toHaveTextContent('query:Latest question')
|
||||
expect(screen.getByTestId('model-info')).toHaveTextContent('gpt-4o')
|
||||
expect(screen.getByTestId('chat-panel')).toHaveAttribute('data-hide-log-modal', 'true')
|
||||
expect(screen.getByTestId('message-log-modal')).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByText('chat-feedback'))
|
||||
@ -399,6 +413,61 @@ describe('ConversationList', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should mount agent log modals from the detail panel instead of the nested chat layout', async () => {
|
||||
mockChatConversationDetail = {
|
||||
id: 'conversation-1',
|
||||
created_at: 1710000000,
|
||||
model_config: {
|
||||
model: 'gpt-4o',
|
||||
configs: {
|
||||
introduction: 'Hello there',
|
||||
},
|
||||
user_input_form: [],
|
||||
},
|
||||
message: {
|
||||
inputs: {},
|
||||
},
|
||||
}
|
||||
mockShowAgentLogModal = true
|
||||
mockCurrentLogItem = {
|
||||
id: 'message-1',
|
||||
conversationId: 'conversation-1',
|
||||
}
|
||||
mockFetchChatMessages.mockResolvedValue({
|
||||
data: [
|
||||
{
|
||||
id: 'message-1',
|
||||
answer: 'Assistant reply',
|
||||
query: 'Latest question',
|
||||
created_at: 1710000000,
|
||||
inputs: {},
|
||||
feedbacks: [],
|
||||
message: [],
|
||||
message_files: [],
|
||||
agent_thoughts: [{ id: 'thought-1' }],
|
||||
},
|
||||
],
|
||||
has_more: false,
|
||||
})
|
||||
|
||||
renderConversationList({
|
||||
searchParams: '?page=2&conversation_id=conversation-1',
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('chat-panel')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(screen.getByTestId('chat-panel')).toHaveAttribute('data-hide-log-modal', 'true')
|
||||
expect(screen.getByTestId('agent-log-modal')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('agent-log-modal')).toHaveAttribute('data-floating', 'true')
|
||||
|
||||
fireEvent.click(screen.getByText('close-agent-log-modal'))
|
||||
|
||||
expect(mockSetCurrentLogItem).toHaveBeenCalled()
|
||||
expect(mockSetShowAgentLogModal).toHaveBeenCalledWith(false)
|
||||
})
|
||||
|
||||
it('should render completion details and refetch after feedback updates', async () => {
|
||||
mockCompletionConversationDetail = {
|
||||
id: 'conversation-1',
|
||||
@ -424,7 +493,7 @@ describe('ConversationList', () => {
|
||||
},
|
||||
}
|
||||
mockShowPromptLogModal = true
|
||||
mockCurrentLogItem = { id: 'log-2' }
|
||||
mockCurrentLogItem = { id: 'log-2', log: [{ role: 'user', text: 'Prompt body' }] }
|
||||
|
||||
renderConversationList({
|
||||
appDetail: { id: 'app-1', mode: AppModeEnum.COMPLETION } as any,
|
||||
@ -626,7 +695,7 @@ describe('ConversationList', () => {
|
||||
},
|
||||
}
|
||||
mockShowPromptLogModal = true
|
||||
mockCurrentLogItem = { id: 'log-2' }
|
||||
mockCurrentLogItem = { id: 'log-2', log: [{ role: 'user', text: 'Prompt body' }] }
|
||||
|
||||
renderConversationList({
|
||||
appDetail: { id: 'app-1', mode: AppModeEnum.COMPLETION } as any,
|
||||
|
||||
@ -36,6 +36,7 @@ import ModelInfo from '@/app/components/app/log/model-info'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import TextGeneration from '@/app/components/app/text-generate/item'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
import AgentLogModal from '@/app/components/base/agent-log-modal'
|
||||
import Chat from '@/app/components/base/chat/chat'
|
||||
import CopyIcon from '@/app/components/base/copy-icon'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
@ -165,13 +166,25 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
})
|
||||
const { formatTime } = useTimestamp()
|
||||
const { onClose, appDetail } = useContext(DrawerContext)
|
||||
const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow((state: AppStoreState) => ({
|
||||
const {
|
||||
currentLogItem,
|
||||
setCurrentLogItem,
|
||||
showMessageLogModal,
|
||||
setShowMessageLogModal,
|
||||
showPromptLogModal,
|
||||
setShowPromptLogModal,
|
||||
showAgentLogModal,
|
||||
setShowAgentLogModal,
|
||||
currentLogModalActiveTab,
|
||||
} = useAppStore(useShallow((state: AppStoreState) => ({
|
||||
currentLogItem: state.currentLogItem,
|
||||
setCurrentLogItem: state.setCurrentLogItem,
|
||||
showMessageLogModal: state.showMessageLogModal,
|
||||
setShowMessageLogModal: state.setShowMessageLogModal,
|
||||
showPromptLogModal: state.showPromptLogModal,
|
||||
setShowPromptLogModal: state.setShowPromptLogModal,
|
||||
showAgentLogModal: state.showAgentLogModal,
|
||||
setShowAgentLogModal: state.setShowAgentLogModal,
|
||||
currentLogModalActiveTab: state.currentLogModalActiveTab,
|
||||
})))
|
||||
const { t } = useTranslation()
|
||||
@ -395,6 +408,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
|
||||
const isChatMode = appDetail?.mode !== AppModeEnum.COMPLETION
|
||||
const isAdvanced = appDetail?.mode === AppModeEnum.ADVANCED_CHAT
|
||||
const shouldShowPromptLogModal = showPromptLogModal && !!currentLogItem?.log
|
||||
|
||||
const varList = getDetailVarList(detail, varValues)
|
||||
const message_files = getCompletionMessageFiles(detail, isChatMode)
|
||||
@ -507,6 +521,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
noChatInput
|
||||
showPromptLog
|
||||
hideProcessDetail
|
||||
hideLogModal
|
||||
chatContainerInnerClassName="px-3"
|
||||
switchSibling={switchSibling}
|
||||
/>
|
||||
@ -546,6 +561,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
noChatInput
|
||||
showPromptLog
|
||||
hideProcessDetail
|
||||
hideLogModal
|
||||
chatContainerInnerClassName="px-3"
|
||||
switchSibling={switchSibling}
|
||||
/>
|
||||
@ -574,7 +590,18 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
/>
|
||||
</WorkflowContextProvider>
|
||||
)}
|
||||
{!isChatMode && showPromptLogModal && (
|
||||
{showAgentLogModal && (
|
||||
<AgentLogModal
|
||||
floating
|
||||
width={width}
|
||||
currentLogItem={currentLogItem}
|
||||
onCancel={() => {
|
||||
setCurrentLogItem()
|
||||
setShowAgentLogModal(false)
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{shouldShowPromptLogModal && (
|
||||
<PromptLogModal
|
||||
width={width}
|
||||
currentLogItem={currentLogItem}
|
||||
|
||||
@ -119,6 +119,17 @@ describe('AgentLogModal', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should render the floating modal through a dialog portal', () => {
|
||||
vi.mocked(fetchAgentLogDetail).mockReturnValue(new Promise(() => {}))
|
||||
|
||||
const { container } = render(<AgentLogModal {...mockProps} floating />)
|
||||
|
||||
const modal = screen.getByRole('dialog')
|
||||
expect(container).not.toContainElement(modal)
|
||||
expect(document.body).toContainElement(modal)
|
||||
expect(modal).toHaveClass('fixed', 'z-50', 'w-[480px]!', 'left-[max(8px,calc(100vw-1136px))]!')
|
||||
})
|
||||
|
||||
it('should call onCancel when close button is clicked', () => {
|
||||
vi.mocked(fetchAgentLogDetail).mockReturnValue(new Promise(() => {}))
|
||||
|
||||
@ -158,4 +169,18 @@ describe('AgentLogModal', () => {
|
||||
|
||||
expect(mockProps.onCancel).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not use click-away to close the floating dialog', () => {
|
||||
vi.mocked(fetchAgentLogDetail).mockReturnValue(new Promise(() => {}))
|
||||
|
||||
let clickAwayHandler!: (event: Event) => void
|
||||
vi.mocked(useClickAway).mockImplementation((callback) => {
|
||||
clickAwayHandler = callback
|
||||
})
|
||||
|
||||
render(<AgentLogModal {...mockProps} floating />)
|
||||
clickAwayHandler(new Event('click'))
|
||||
|
||||
expect(mockProps.onCancel).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import type { FC } from 'react'
|
||||
import type { IChatItem } from '@/app/components/base/chat/chat/type'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
import { Dialog, DialogContent, DialogTitle } from '@langgenius/dify-ui/dialog'
|
||||
import { RiCloseLine } from '@remixicon/react'
|
||||
import { useClickAway } from 'ahooks'
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
@ -10,11 +11,13 @@ import AgentLogDetail from './detail'
|
||||
type AgentLogModalProps = Readonly<{
|
||||
currentLogItem?: IChatItem
|
||||
width: number
|
||||
floating?: boolean
|
||||
onCancel: () => void
|
||||
}>
|
||||
const AgentLogModal: FC<AgentLogModalProps> = ({
|
||||
currentLogItem,
|
||||
width,
|
||||
floating,
|
||||
onCancel,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
@ -22,7 +25,7 @@ const AgentLogModal: FC<AgentLogModalProps> = ({
|
||||
const [mounted, setMounted] = useState(false)
|
||||
|
||||
useClickAway(() => {
|
||||
if (mounted)
|
||||
if (mounted && !floating)
|
||||
onCancel()
|
||||
}, ref)
|
||||
|
||||
@ -33,6 +36,44 @@ const AgentLogModal: FC<AgentLogModalProps> = ({
|
||||
if (!currentLogItem || !currentLogItem.conversationId)
|
||||
return null
|
||||
|
||||
const detailContent = (
|
||||
<>
|
||||
<AgentLogDetail
|
||||
conversationID={currentLogItem.conversationId}
|
||||
messageID={currentLogItem.id}
|
||||
log={currentLogItem}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
|
||||
if (floating) {
|
||||
return (
|
||||
<Dialog
|
||||
open
|
||||
onOpenChange={(open) => {
|
||||
if (!open)
|
||||
onCancel()
|
||||
}}
|
||||
>
|
||||
<DialogContent
|
||||
backdropClassName="bg-transparent!"
|
||||
className="top-16! bottom-4! left-[max(8px,calc(100vw-1136px))]! flex max-h-none! w-[480px]! max-w-[calc(100vw-16px)]! translate-x-0! translate-y-0! flex-col overflow-hidden! rounded-xl! border-[0.5px]! border-components-panel-border! bg-components-panel-bg! p-0! pt-3! pb-3! shadow-xl!"
|
||||
>
|
||||
<DialogTitle className="text-md shrink-0 px-4 py-1 font-semibold text-text-primary">{t('runDetail.workflowTitle', { ns: 'appLog' })}</DialogTitle>
|
||||
<button
|
||||
type="button"
|
||||
aria-label={t('operation.close', { ns: 'common' })}
|
||||
className="absolute top-4 right-3 z-20 cursor-pointer border-none bg-transparent p-1 focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:outline-hidden"
|
||||
onClick={onCancel}
|
||||
>
|
||||
<RiCloseLine className="size-4 text-text-tertiary" aria-hidden="true" />
|
||||
</button>
|
||||
{detailContent}
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn('relative z-10 flex flex-col rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg py-3 shadow-xl')}
|
||||
@ -54,11 +95,7 @@ const AgentLogModal: FC<AgentLogModalProps> = ({
|
||||
>
|
||||
<RiCloseLine className="size-4 text-text-tertiary" aria-hidden="true" />
|
||||
</button>
|
||||
<AgentLogDetail
|
||||
conversationID={currentLogItem.conversationId}
|
||||
messageID={currentLogItem.id}
|
||||
log={currentLogItem}
|
||||
/>
|
||||
{detailContent}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,196 +1,285 @@
|
||||
import { convertToMp3 } from '../utils'
|
||||
|
||||
// ── Hoisted mocks ──
|
||||
|
||||
const mocks = vi.hoisted(() => {
|
||||
const readHeader = vi.fn()
|
||||
const encodeBuffer = vi.fn()
|
||||
const flush = vi.fn()
|
||||
const state = {
|
||||
targets: [] as Array<{ buffer: ArrayBuffer | null }>,
|
||||
outputs: [] as Array<{
|
||||
format: unknown
|
||||
target: { buffer: ArrayBuffer | null }
|
||||
addAudioTrack: ReturnType<typeof vi.fn>
|
||||
start: ReturnType<typeof vi.fn>
|
||||
finalize: ReturnType<typeof vi.fn>
|
||||
}>,
|
||||
sources: [] as Array<{
|
||||
encodingConfig: unknown
|
||||
add: ReturnType<typeof vi.fn>
|
||||
close: ReturnType<typeof vi.fn>
|
||||
}>,
|
||||
samples: [] as Array<{
|
||||
init: {
|
||||
data: Int16Array
|
||||
format: string
|
||||
numberOfChannels: number
|
||||
sampleRate: number
|
||||
timestamp: number
|
||||
}
|
||||
close: ReturnType<typeof vi.fn>
|
||||
}>,
|
||||
formats: [] as unknown[],
|
||||
outputBuffer: new Uint8Array([1, 2, 3, 4, 5]).buffer as ArrayBuffer | null,
|
||||
}
|
||||
|
||||
return { readHeader, encodeBuffer, flush }
|
||||
const registerMp3Encoder = vi.fn()
|
||||
const canEncodeAudio = vi.fn()
|
||||
|
||||
class MockBufferTarget {
|
||||
buffer: ArrayBuffer | null = state.outputBuffer
|
||||
|
||||
constructor() {
|
||||
state.targets.push(this)
|
||||
}
|
||||
}
|
||||
|
||||
class MockMp3OutputFormat {
|
||||
constructor() {
|
||||
state.formats.push(this)
|
||||
}
|
||||
}
|
||||
|
||||
class MockOutput {
|
||||
format: unknown
|
||||
target: { buffer: ArrayBuffer | null }
|
||||
addAudioTrack = vi.fn()
|
||||
start = vi.fn(async () => {})
|
||||
finalize = vi.fn(async () => {})
|
||||
|
||||
constructor(options: { format: unknown, target: { buffer: ArrayBuffer | null } }) {
|
||||
this.format = options.format
|
||||
this.target = options.target
|
||||
state.outputs.push(this)
|
||||
}
|
||||
}
|
||||
|
||||
class MockAudioSampleSource {
|
||||
encodingConfig: unknown
|
||||
add = vi.fn(async () => {})
|
||||
close = vi.fn()
|
||||
|
||||
constructor(encodingConfig: unknown) {
|
||||
this.encodingConfig = encodingConfig
|
||||
state.sources.push(this)
|
||||
}
|
||||
}
|
||||
|
||||
class MockAudioSample {
|
||||
init: {
|
||||
data: Int16Array
|
||||
format: string
|
||||
numberOfChannels: number
|
||||
sampleRate: number
|
||||
timestamp: number
|
||||
}
|
||||
|
||||
close = vi.fn()
|
||||
|
||||
constructor(init: MockAudioSample['init']) {
|
||||
this.init = init
|
||||
state.samples.push(this)
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
state,
|
||||
registerMp3Encoder,
|
||||
canEncodeAudio,
|
||||
MockAudioSample,
|
||||
MockAudioSampleSource,
|
||||
MockBufferTarget,
|
||||
MockMp3OutputFormat,
|
||||
MockOutput,
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('lamejs', () => ({
|
||||
default: {
|
||||
WavHeader: {
|
||||
readHeader: mocks.readHeader,
|
||||
},
|
||||
Mp3Encoder: class MockMp3Encoder {
|
||||
encodeBuffer = mocks.encodeBuffer
|
||||
flush = mocks.flush
|
||||
},
|
||||
},
|
||||
vi.mock('@mediabunny/mp3-encoder', () => ({
|
||||
registerMp3Encoder: mocks.registerMp3Encoder,
|
||||
}))
|
||||
|
||||
vi.mock('lamejs/src/js/BitStream', () => ({ default: {} }))
|
||||
vi.mock('lamejs/src/js/Lame', () => ({ default: {} }))
|
||||
vi.mock('lamejs/src/js/MPEGMode', () => ({ default: {} }))
|
||||
vi.mock('mediabunny', () => ({
|
||||
AudioSample: mocks.MockAudioSample,
|
||||
AudioSampleSource: mocks.MockAudioSampleSource,
|
||||
BufferTarget: mocks.MockBufferTarget,
|
||||
Mp3OutputFormat: mocks.MockMp3OutputFormat,
|
||||
Output: mocks.MockOutput,
|
||||
canEncodeAudio: mocks.canEncodeAudio,
|
||||
}))
|
||||
|
||||
// ── helpers ──
|
||||
function createWavHeader(channels: number, sampleRate: number) {
|
||||
const view = new DataView(new ArrayBuffer(44))
|
||||
|
||||
view.setUint16(22, channels, true)
|
||||
view.setUint32(24, sampleRate, true)
|
||||
|
||||
return view
|
||||
}
|
||||
|
||||
function createPcmDataView(samples: number[]) {
|
||||
const view = new DataView(new ArrayBuffer(samples.length * 2))
|
||||
|
||||
samples.forEach((sample, index) => {
|
||||
view.setInt16(index * 2, sample, true)
|
||||
})
|
||||
|
||||
return view
|
||||
}
|
||||
|
||||
/** Build a fake recorder whose getChannelData returns DataView-like objects with .buffer and .byteLength. */
|
||||
function createMockRecorder(opts: {
|
||||
channels: number
|
||||
sampleRate: number
|
||||
leftSamples: number[]
|
||||
rightSamples?: number[]
|
||||
}) {
|
||||
const toDataView = (samples: number[]) => {
|
||||
const buf = new ArrayBuffer(samples.length * 2)
|
||||
const view = new DataView(buf)
|
||||
samples.forEach((v, i) => {
|
||||
view.setInt16(i * 2, v, true)
|
||||
})
|
||||
return view
|
||||
}
|
||||
|
||||
const leftView = toDataView(opts.leftSamples)
|
||||
const rightView = opts.rightSamples ? toDataView(opts.rightSamples) : null
|
||||
|
||||
mocks.readHeader.mockReturnValue({
|
||||
channels: opts.channels,
|
||||
sampleRate: opts.sampleRate,
|
||||
})
|
||||
|
||||
return {
|
||||
getWAV: vi.fn(() => new ArrayBuffer(44)),
|
||||
getWAV: vi.fn(() => createWavHeader(opts.channels, opts.sampleRate)),
|
||||
getChannelData: vi.fn(() => ({
|
||||
left: leftView,
|
||||
right: rightView,
|
||||
left: createPcmDataView(opts.leftSamples),
|
||||
right: opts.rightSamples ? createPcmDataView(opts.rightSamples) : null,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
async function expectBlobBytes(blob: Blob, bytes: number[]) {
|
||||
expect(new Uint8Array(await blob.arrayBuffer())).toEqual(new Uint8Array(bytes))
|
||||
}
|
||||
|
||||
function getOnly<T>(items: T[]) {
|
||||
expect(items).toHaveLength(1)
|
||||
return items[0]!
|
||||
}
|
||||
|
||||
describe('convertToMp3', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mocks.state.targets = []
|
||||
mocks.state.outputs = []
|
||||
mocks.state.sources = []
|
||||
mocks.state.samples = []
|
||||
mocks.state.formats = []
|
||||
mocks.state.outputBuffer = new Uint8Array([1, 2, 3, 4, 5]).buffer
|
||||
mocks.canEncodeAudio.mockResolvedValue(false)
|
||||
})
|
||||
|
||||
it('should convert mono WAV data to an MP3 blob', () => {
|
||||
it('should encode mono recorder PCM data with Mediabunny', async () => {
|
||||
const recorder = createMockRecorder({
|
||||
channels: 1,
|
||||
sampleRate: 44100,
|
||||
leftSamples: [100, 200, 300, 400],
|
||||
sampleRate: 16000,
|
||||
leftSamples: [-32768, 0, 32767],
|
||||
})
|
||||
|
||||
mocks.encodeBuffer.mockReturnValue(new Int8Array([1, 2, 3]))
|
||||
mocks.flush.mockReturnValue(new Int8Array([4, 5]))
|
||||
|
||||
const result = convertToMp3(recorder)
|
||||
const result = await convertToMp3(recorder)
|
||||
const output = getOnly(mocks.state.outputs)
|
||||
const source = getOnly(mocks.state.sources)
|
||||
const sample = getOnly(mocks.state.samples)
|
||||
|
||||
expect(result).toBeInstanceOf(Blob)
|
||||
expect(result.type).toBe('audio/mp3')
|
||||
expect(mocks.encodeBuffer).toHaveBeenCalled()
|
||||
// Mono: encodeBuffer called with only left data
|
||||
const firstCall = mocks.encodeBuffer.mock.calls[0]
|
||||
expect(firstCall).toHaveLength(1)
|
||||
expect(mocks.flush).toHaveBeenCalled()
|
||||
expect(mocks.canEncodeAudio).toHaveBeenCalledWith('mp3')
|
||||
expect(mocks.registerMp3Encoder).toHaveBeenCalled()
|
||||
expect(source.encodingConfig).toEqual({
|
||||
codec: 'mp3',
|
||||
bitrate: 128000,
|
||||
})
|
||||
expect(sample.init).toMatchObject({
|
||||
format: 's16',
|
||||
numberOfChannels: 1,
|
||||
sampleRate: 16000,
|
||||
timestamp: 0,
|
||||
})
|
||||
expect(Array.from(sample.init.data)).toEqual([-32768, 0, 32767])
|
||||
expect(output.addAudioTrack).toHaveBeenCalledWith(source)
|
||||
expect(output.start).toHaveBeenCalled()
|
||||
expect(source.add).toHaveBeenCalledWith(sample)
|
||||
expect(sample.close).toHaveBeenCalled()
|
||||
expect(source.close).toHaveBeenCalled()
|
||||
expect(output.finalize).toHaveBeenCalled()
|
||||
await expectBlobBytes(result, [1, 2, 3, 4, 5])
|
||||
})
|
||||
|
||||
it('should convert stereo WAV data to an MP3 blob', () => {
|
||||
it('should encode stereo recorder PCM data as interleaved samples', async () => {
|
||||
const recorder = createMockRecorder({
|
||||
channels: 2,
|
||||
sampleRate: 48000,
|
||||
leftSamples: [100, 200],
|
||||
rightSamples: [300, 400],
|
||||
leftSamples: [100, -100],
|
||||
rightSamples: [300, -300],
|
||||
})
|
||||
|
||||
mocks.encodeBuffer.mockReturnValue(new Int8Array([10, 20]))
|
||||
mocks.flush.mockReturnValue(new Int8Array([30]))
|
||||
const result = await convertToMp3(recorder)
|
||||
const sample = getOnly(mocks.state.samples)
|
||||
|
||||
const result = convertToMp3(recorder)
|
||||
|
||||
expect(result).toBeInstanceOf(Blob)
|
||||
expect(result.type).toBe('audio/mp3')
|
||||
// Stereo: encodeBuffer called with left AND right
|
||||
const firstCall = mocks.encodeBuffer.mock.calls[0]
|
||||
expect(firstCall).toHaveLength(2)
|
||||
expect(sample.init).toMatchObject({
|
||||
format: 's16',
|
||||
numberOfChannels: 2,
|
||||
sampleRate: 48000,
|
||||
timestamp: 0,
|
||||
})
|
||||
expect(Array.from(sample.init.data)).toEqual([100, 300, -100, -300])
|
||||
await expectBlobBytes(result, [1, 2, 3, 4, 5])
|
||||
})
|
||||
|
||||
it('should skip empty encoded buffers', () => {
|
||||
it('should skip custom encoder registration when native MP3 encoding is available', async () => {
|
||||
mocks.canEncodeAudio.mockResolvedValue(true)
|
||||
const recorder = createMockRecorder({
|
||||
channels: 1,
|
||||
sampleRate: 44100,
|
||||
leftSamples: [100, 200],
|
||||
leftSamples: [100],
|
||||
})
|
||||
|
||||
mocks.encodeBuffer.mockReturnValue(new Int8Array(0))
|
||||
mocks.flush.mockReturnValue(new Int8Array(0))
|
||||
await convertToMp3(recorder)
|
||||
|
||||
const result = convertToMp3(recorder)
|
||||
|
||||
expect(result).toBeInstanceOf(Blob)
|
||||
expect(result.type).toBe('audio/mp3')
|
||||
expect(result.size).toBe(0)
|
||||
expect(mocks.registerMp3Encoder).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should include flush data when flush returns non-empty buffer', () => {
|
||||
it('should return an empty MP3 blob when the target has no buffer', async () => {
|
||||
mocks.state.outputBuffer = null
|
||||
const recorder = createMockRecorder({
|
||||
channels: 1,
|
||||
sampleRate: 22050,
|
||||
leftSamples: [1],
|
||||
})
|
||||
|
||||
mocks.encodeBuffer.mockReturnValue(new Int8Array(0))
|
||||
mocks.flush.mockReturnValue(new Int8Array([99, 98, 97]))
|
||||
const result = await convertToMp3(recorder)
|
||||
|
||||
const result = convertToMp3(recorder)
|
||||
|
||||
expect(result).toBeInstanceOf(Blob)
|
||||
expect(result.size).toBe(3)
|
||||
expect(result.size).toBe(0)
|
||||
})
|
||||
|
||||
it('should omit flush data when flush returns empty buffer', () => {
|
||||
it('should reject unsupported WAV channel counts', async () => {
|
||||
const recorder = createMockRecorder({
|
||||
channels: 1,
|
||||
channels: 3,
|
||||
sampleRate: 44100,
|
||||
leftSamples: [10, 20],
|
||||
leftSamples: [100],
|
||||
})
|
||||
|
||||
mocks.encodeBuffer.mockReturnValue(new Int8Array([1, 2]))
|
||||
mocks.flush.mockReturnValue(new Int8Array(0))
|
||||
|
||||
const result = convertToMp3(recorder)
|
||||
|
||||
expect(result).toBeInstanceOf(Blob)
|
||||
expect(result.size).toBe(2)
|
||||
await expect(convertToMp3(recorder)).rejects.toThrow('Unsupported WAV channel count: 3')
|
||||
expect(mocks.canEncodeAudio).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should process multiple chunks when sample count exceeds maxSamples (1152)', () => {
|
||||
const samples = Array.from({ length: 2400 }, (_, i) => i % 32767)
|
||||
const recorder = createMockRecorder({
|
||||
channels: 1,
|
||||
sampleRate: 44100,
|
||||
leftSamples: samples,
|
||||
})
|
||||
|
||||
mocks.encodeBuffer.mockReturnValue(new Int8Array([1]))
|
||||
mocks.flush.mockReturnValue(new Int8Array(0))
|
||||
|
||||
const result = convertToMp3(recorder)
|
||||
|
||||
expect(mocks.encodeBuffer.mock.calls.length).toBeGreaterThan(1)
|
||||
expect(result).toBeInstanceOf(Blob)
|
||||
})
|
||||
|
||||
it('should encode stereo with right channel subarray', () => {
|
||||
it('should reject stereo WAV data without a right channel', async () => {
|
||||
const recorder = createMockRecorder({
|
||||
channels: 2,
|
||||
sampleRate: 44100,
|
||||
leftSamples: [100, 200, 300],
|
||||
rightSamples: [400, 500, 600],
|
||||
leftSamples: [100],
|
||||
})
|
||||
|
||||
mocks.encodeBuffer.mockReturnValue(new Int8Array([5, 6, 7]))
|
||||
mocks.flush.mockReturnValue(new Int8Array([8]))
|
||||
await expect(convertToMp3(recorder)).rejects.toThrow('Missing right channel data for stereo WAV')
|
||||
})
|
||||
|
||||
const result = convertToMp3(recorder)
|
||||
it('should reject stereo WAV data with mismatched channel lengths', async () => {
|
||||
const recorder = createMockRecorder({
|
||||
channels: 2,
|
||||
sampleRate: 44100,
|
||||
leftSamples: [100, 200],
|
||||
rightSamples: [300],
|
||||
})
|
||||
|
||||
expect(result).toBeInstanceOf(Blob)
|
||||
for (const call of mocks.encodeBuffer.mock.calls) {
|
||||
expect(call).toHaveLength(2)
|
||||
expect(call[0]).toBeInstanceOf(Int16Array)
|
||||
expect(call[1]).toBeInstanceOf(Int16Array)
|
||||
}
|
||||
await expect(convertToMp3(recorder)).rejects.toThrow('Stereo WAV channel sample counts do not match')
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import type { AudioRecorder } from './utils'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
import { useRafInterval } from 'ahooks'
|
||||
import Recorder from 'js-audio-recorder'
|
||||
@ -21,12 +22,12 @@ const VoiceInput = ({
|
||||
wordTimestamps,
|
||||
}: VoiceInputTypes) => {
|
||||
const { t } = useTranslation()
|
||||
const recorder = useRef(new Recorder({
|
||||
const recorder = useRef<Recorder & AudioRecorder>(new Recorder({
|
||||
sampleBits: 16,
|
||||
sampleRate: 16000,
|
||||
numChannels: 1,
|
||||
compiling: false,
|
||||
}))
|
||||
}) as Recorder & AudioRecorder)
|
||||
const canvasRef = useRef<HTMLCanvasElement | null>(null)
|
||||
const ctxRef = useRef<CanvasRenderingContext2D | null>(null)
|
||||
const drawRecordId = useRef<number | null>(null)
|
||||
@ -83,7 +84,7 @@ const VoiceInput = ({
|
||||
const canvas = canvasRef.current!
|
||||
const ctx = ctxRef.current!
|
||||
ctx.clearRect(0, 0, canvas.width, canvas.height)
|
||||
const mp3Blob = convertToMp3(recorder.current)
|
||||
const mp3Blob = await convertToMp3(recorder.current)
|
||||
const mp3File = new File([mp3Blob], 'temp.mp3', { type: 'audio/mp3' })
|
||||
const formData = new FormData()
|
||||
formData.append('file', mp3File)
|
||||
|
||||
@ -1,54 +1,111 @@
|
||||
import lamejs from 'lamejs'
|
||||
import BitStream from 'lamejs/src/js/BitStream'
|
||||
import Lame from 'lamejs/src/js/Lame'
|
||||
import MPEGMode from 'lamejs/src/js/MPEGMode'
|
||||
import { registerMp3Encoder } from '@mediabunny/mp3-encoder'
|
||||
import {
|
||||
AudioSample,
|
||||
AudioSampleSource,
|
||||
BufferTarget,
|
||||
canEncodeAudio,
|
||||
Mp3OutputFormat,
|
||||
Output,
|
||||
} from 'mediabunny'
|
||||
|
||||
/* v8 ignore next - @preserve */
|
||||
if (globalThis) {
|
||||
(globalThis as any).MPEGMode = MPEGMode
|
||||
; (globalThis as any).Lame = Lame
|
||||
; (globalThis as any).BitStream = BitStream
|
||||
type SupportedChannelCount = 1 | 2
|
||||
|
||||
type RecorderChannelData = {
|
||||
left: DataView
|
||||
right?: DataView | null
|
||||
}
|
||||
|
||||
export const convertToMp3 = (recorder: any) => {
|
||||
const wav = lamejs.WavHeader.readHeader(recorder.getWAV())
|
||||
const { channels, sampleRate } = wav
|
||||
const mp3enc = new lamejs.Mp3Encoder(channels, sampleRate, 128)
|
||||
export type AudioRecorder = {
|
||||
getWAV: () => ArrayBuffer | DataView
|
||||
getChannelData: () => RecorderChannelData
|
||||
}
|
||||
|
||||
const wavChannelsOffset = 22
|
||||
const wavSampleRateOffset = 24
|
||||
const bytesPerSample = 2
|
||||
const mp3Bitrate = 128_000
|
||||
|
||||
const ensureMp3Encoder = async () => {
|
||||
if (!(await canEncodeAudio('mp3')))
|
||||
registerMp3Encoder()
|
||||
}
|
||||
|
||||
const readWavInfo = (wav: ArrayBuffer | DataView) => {
|
||||
const view = wav instanceof DataView ? wav : new DataView(wav)
|
||||
const channels = view.getUint16(wavChannelsOffset, true)
|
||||
|
||||
if (channels !== 1 && channels !== 2)
|
||||
throw new Error(`Unsupported WAV channel count: ${channels}`)
|
||||
|
||||
return {
|
||||
channels: channels as SupportedChannelCount,
|
||||
sampleRate: view.getUint32(wavSampleRateOffset, true),
|
||||
}
|
||||
}
|
||||
|
||||
const readInt16Samples = (data: DataView) => {
|
||||
const samples = new Int16Array(data.byteLength / bytesPerSample)
|
||||
|
||||
for (let i = 0; i < samples.length; i++)
|
||||
samples[i] = data.getInt16(i * bytesPerSample, true)
|
||||
|
||||
return samples
|
||||
}
|
||||
|
||||
const createPcmData = (data: RecorderChannelData, channels: SupportedChannelCount) => {
|
||||
const leftSamples = readInt16Samples(data.left)
|
||||
|
||||
if (channels === 1)
|
||||
return leftSamples
|
||||
|
||||
if (!data.right)
|
||||
throw new Error('Missing right channel data for stereo WAV')
|
||||
|
||||
const rightSamples = readInt16Samples(data.right)
|
||||
|
||||
if (leftSamples.length !== rightSamples.length)
|
||||
throw new Error('Stereo WAV channel sample counts do not match')
|
||||
|
||||
const samples = new Int16Array(leftSamples.length * channels)
|
||||
|
||||
for (let i = 0; i < leftSamples.length; i++) {
|
||||
samples[i * channels] = leftSamples[i]!
|
||||
samples[i * channels + 1] = rightSamples[i]!
|
||||
}
|
||||
|
||||
return samples
|
||||
}
|
||||
|
||||
export const convertToMp3 = async (recorder: AudioRecorder) => {
|
||||
const { channels, sampleRate } = readWavInfo(recorder.getWAV())
|
||||
const result = recorder.getChannelData()
|
||||
const buffer: BlobPart[] = []
|
||||
const pcmData = createPcmData(result, channels)
|
||||
|
||||
const leftData = result.left && new Int16Array(result.left.buffer, 0, result.left.byteLength / 2)
|
||||
const rightData = result.right && new Int16Array(result.right.buffer, 0, result.right.byteLength / 2)
|
||||
const remaining = leftData.length + (rightData ? rightData.length : 0)
|
||||
await ensureMp3Encoder()
|
||||
|
||||
const maxSamples = 1152
|
||||
const toArrayBuffer = (bytes: Int8Array) => {
|
||||
const arrayBuffer = new ArrayBuffer(bytes.length)
|
||||
new Uint8Array(arrayBuffer).set(bytes)
|
||||
return arrayBuffer
|
||||
}
|
||||
const target = new BufferTarget()
|
||||
const output = new Output({
|
||||
format: new Mp3OutputFormat(),
|
||||
target,
|
||||
})
|
||||
const source = new AudioSampleSource({
|
||||
codec: 'mp3',
|
||||
bitrate: mp3Bitrate,
|
||||
})
|
||||
const sample = new AudioSample({
|
||||
data: pcmData,
|
||||
format: 's16',
|
||||
numberOfChannels: channels,
|
||||
sampleRate,
|
||||
timestamp: 0,
|
||||
})
|
||||
|
||||
for (let i = 0; i < remaining; i += maxSamples) {
|
||||
const left = leftData.subarray(i, i + maxSamples)
|
||||
let right = null
|
||||
let mp3buf = null
|
||||
output.addAudioTrack(source)
|
||||
await output.start()
|
||||
await source.add(sample)
|
||||
sample.close()
|
||||
source.close()
|
||||
await output.finalize()
|
||||
|
||||
if (channels === 2) {
|
||||
right = rightData.subarray(i, i + maxSamples)
|
||||
mp3buf = mp3enc.encodeBuffer(left, right)
|
||||
}
|
||||
else {
|
||||
mp3buf = mp3enc.encodeBuffer(left)
|
||||
}
|
||||
|
||||
if (mp3buf.length > 0)
|
||||
buffer.push(toArrayBuffer(mp3buf))
|
||||
}
|
||||
|
||||
const enc = mp3enc.flush()
|
||||
|
||||
if (enc.length > 0)
|
||||
buffer.push(toArrayBuffer(enc))
|
||||
|
||||
return new Blob(buffer, { type: 'audio/mp3' })
|
||||
return new Blob([target.buffer ?? new ArrayBuffer(0)], { type: 'audio/mp3' })
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ let fetching = false
|
||||
let isManager = true
|
||||
let enableBilling = true
|
||||
let workspacePermissionKeys: string[] = ['billing.subscription.manage']
|
||||
let billingUrlEnabled = false
|
||||
|
||||
const refetchMock = vi.fn()
|
||||
const openAsyncWindowMock = vi.fn()
|
||||
@ -19,11 +20,14 @@ type BillingWindowOptions = {
|
||||
type OpenAsyncWindowCall = [BillingUrlCallback, BillingWindowOptions]
|
||||
|
||||
vi.mock('@/service/use-billing', () => ({
|
||||
useBillingUrl: () => ({
|
||||
data: currentBillingUrl,
|
||||
isFetching: fetching,
|
||||
refetch: refetchMock,
|
||||
}),
|
||||
useBillingUrl: (enabled: boolean) => {
|
||||
billingUrlEnabled = enabled
|
||||
return {
|
||||
data: currentBillingUrl,
|
||||
isFetching: fetching,
|
||||
refetch: refetchMock,
|
||||
}
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-async-window-open', () => ({
|
||||
@ -54,28 +58,32 @@ describe('Billing', () => {
|
||||
fetching = false
|
||||
isManager = true
|
||||
enableBilling = true
|
||||
billingUrlEnabled = false
|
||||
workspacePermissionKeys = ['billing.subscription.manage']
|
||||
refetchMock.mockResolvedValue({ data: 'https://billing' })
|
||||
})
|
||||
|
||||
it('shows the billing action when subscription management permission is granted without manager role', () => {
|
||||
it('hides the billing action when subscription management permission is granted without manager role', () => {
|
||||
isManager = false
|
||||
|
||||
render(<Billing />)
|
||||
|
||||
expect(screen.getByRole('button', { name: /billing\.viewBillingTitle/ })).toBeInTheDocument()
|
||||
expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
|
||||
expect(billingUrlEnabled).toBe(false)
|
||||
})
|
||||
|
||||
it('hides the billing action when subscription management permission is missing or billing is disabled', () => {
|
||||
workspacePermissionKeys = []
|
||||
render(<Billing />)
|
||||
expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
|
||||
expect(billingUrlEnabled).toBe(false)
|
||||
|
||||
vi.clearAllMocks()
|
||||
workspacePermissionKeys = ['billing.subscription.manage']
|
||||
enableBilling = false
|
||||
render(<Billing />)
|
||||
expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
|
||||
expect(billingUrlEnabled).toBe(false)
|
||||
})
|
||||
|
||||
it('opens the billing window with the immediate url when the button is clicked', async () => {
|
||||
|
||||
@ -11,9 +11,9 @@ import PlanComp from '../plan'
|
||||
|
||||
const Billing: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
const { workspacePermissionKeys } = useAppContext()
|
||||
const { isCurrentWorkspaceManager, workspacePermissionKeys } = useAppContext()
|
||||
const { enableBilling } = useProviderContext()
|
||||
const canManageBillingSubscription = hasPermission(workspacePermissionKeys, BillingPermission.SubscriptionManage)
|
||||
const canManageBillingSubscription = isCurrentWorkspaceManager && hasPermission(workspacePermissionKeys, BillingPermission.SubscriptionManage)
|
||||
const { data: billingUrl, isFetching, refetch } = useBillingUrl(enableBilling && canManageBillingSubscription)
|
||||
const openAsyncWindow = useAsyncWindowOpen()
|
||||
|
||||
|
||||
5
web/global.d.ts
vendored
5
web/global.d.ts
vendored
@ -3,11 +3,6 @@ import './types/jsx'
|
||||
import './types/mdx'
|
||||
import './types/assets'
|
||||
|
||||
declare module 'lamejs';
|
||||
declare module 'lamejs/src/js/MPEGMode';
|
||||
declare module 'lamejs/src/js/Lame';
|
||||
declare module 'lamejs/src/js/BitStream';
|
||||
|
||||
declare global {
|
||||
// Google Analytics gtag types
|
||||
type GtagEventParams = {
|
||||
|
||||
@ -65,6 +65,7 @@
|
||||
"@lexical/selection": "catalog:",
|
||||
"@lexical/text": "catalog:",
|
||||
"@lexical/utils": "catalog:",
|
||||
"@mediabunny/mp3-encoder": "catalog:",
|
||||
"@monaco-editor/react": "catalog:",
|
||||
"@orpc/client": "catalog:",
|
||||
"@orpc/contract": "catalog:",
|
||||
@ -118,9 +119,9 @@
|
||||
"jsonschema": "catalog:",
|
||||
"katex": "catalog:",
|
||||
"ky": "catalog:",
|
||||
"lamejs": "catalog:",
|
||||
"lexical": "catalog:",
|
||||
"loro-crdt": "catalog:",
|
||||
"mediabunny": "catalog:",
|
||||
"mermaid": "catalog:",
|
||||
"mime": "catalog:",
|
||||
"mitt": "catalog:",
|
||||
|
||||
36
web/types/lamejs.d.ts
vendored
36
web/types/lamejs.d.ts
vendored
@ -1,36 +0,0 @@
|
||||
declare module 'lamejs' {
|
||||
export class Mp3Encoder {
|
||||
constructor(channels: number, sampleRate: number, bitRate: number)
|
||||
encodeBuffer(left: Int16Array, right?: Int16Array | null): Int8Array
|
||||
flush(): Int8Array
|
||||
}
|
||||
|
||||
export class WavHeader {
|
||||
static readHeader(data: DataView): {
|
||||
channels: number
|
||||
sampleRate: number
|
||||
}
|
||||
}
|
||||
|
||||
const lamejs: {
|
||||
Mp3Encoder: typeof Mp3Encoder
|
||||
WavHeader: typeof WavHeader
|
||||
}
|
||||
|
||||
export default lamejs
|
||||
}
|
||||
|
||||
declare module 'lamejs/src/js/MPEGMode' {
|
||||
const MPEGMode: any
|
||||
export default MPEGMode
|
||||
}
|
||||
|
||||
declare module 'lamejs/src/js/Lame' {
|
||||
const Lame: any
|
||||
export default Lame
|
||||
}
|
||||
|
||||
declare module 'lamejs/src/js/BitStream' {
|
||||
const BitStream: any
|
||||
export default BitStream
|
||||
}
|
||||
Reference in New Issue
Block a user