Compare commits

...

35 Commits

Author SHA1 Message Date
fde8efa4a2 fix: summary index in parent child chunk 2026-01-16 10:49:38 +08:00
f02adc26e5 fix: pipeline run panel summary 2026-01-15 18:02:19 +08:00
1126a2aa95 merge main 2026-01-15 16:08:29 +08:00
4a197b9458 fix: fix log updated_at is refreshed (#31045) 2026-01-15 15:42:46 +08:00
772ff636ec feat: credential sync fix for enterprise edition (#30626) 2026-01-14 23:33:24 -08:00
ab1c5a2027 refactor: remove manual set query logic (#31039) 2026-01-15 15:25:43 +08:00
33e99f069b fix: message clean service ut (#31038) 2026-01-15 15:13:25 +08:00
52af829f1f refactor: enhance clean messages task (#29638)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-15 14:03:17 +08:00
0ef8b5a0ca chore: bump version to 1.11.4 (#30961) 2026-01-15 11:36:15 +08:00
2bfc54314e feat: single run add opentelemetry (#31020) 2026-01-15 11:10:55 +08:00
bdd8d5b470 test: add unit tests for PluginPage and related components (#30908)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2026-01-15 10:56:02 +08:00
4955de5905 fix: validation error when uploading images with None URL values (#31012)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-01-15 10:54:10 +08:00
yyh
3bee2ee067 refactor(contract): restructure console contracts with nested billing module (#30999) 2026-01-15 10:41:18 +08:00
328897f81c build: require node 24.13.0 (#30945) 2026-01-15 10:38:55 +08:00
ab078380a3 feat(web): refactor documents component structure and enhance functionality (#30854)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2026-01-15 10:33:58 +08:00
a33ac77a22 feat: implement document creation pipeline with multi-step wizard and datasource management (#30843)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2026-01-15 10:33:48 +08:00
d3923e7b56 refactor: port AppAnnotationHitHistory (#30922)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-15 10:14:55 +08:00
2f633de45e refactor: port TenantCreditPool (#30926)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-15 10:14:15 +08:00
98c88cec34 refactor: delete_endpoint should be idempotent (#30954) 2026-01-15 10:10:10 +08:00
c6999fb5be fix: fix plugin edit endpoint app disappear (#30951) 2026-01-15 10:09:57 +08:00
f7f9a08fa5 refactor: port TidbAuthBinding( (#31006)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-01-15 10:07:02 +08:00
5008f5e89b fix: Use raw SQL UPDATE to set read status without triggering updated… (#31015) 2026-01-15 09:51:44 +08:00
1dd89a02ea fix: fix missing id and message_id (#31008) 2026-01-14 23:26:17 +09:00
5bf4114d6f fix: increase name length limit in ExternalDatasetCreatePayload (#31000)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-01-14 22:13:53 +09:00
yyh
a56e94ba8e feat: add .agent/skills symlink and orpc-contract-first skill (#30968) 2026-01-14 21:13:14 +08:00
11f1782df0 fix: correct API Extension documentation link (#30962) 2026-01-14 21:21:15 +09:00
8cf5d9a6a1 fix: fix Cannot destructure property 'name' of 'value' as it is undef… (#30991) 2026-01-14 19:30:47 +08:00
0ec2b12e65 feat: allow pass hostname in docker env (#30975) 2026-01-14 19:30:37 +08:00
f33b1a3332 fix: redirect after login (#30985) 2026-01-14 17:20:49 +08:00
08026f7399 fix(deps): security updates for pdfminer.six, authlib, werkzeug, aiohttp and others (#30976)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2026-01-14 17:03:46 +08:00
yyh
18e051bd66 chore(web): remove unused demo service component (#30979) 2026-01-14 17:03:35 +08:00
yyh
42f991dbef chore(web): disable Serwist dev logs (#30980) 2026-01-14 16:23:58 +08:00
yyh
b1b2c9636f fix(web): preserve HTTP method in ORPC fetchCompat mode (#30971)
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-01-14 16:18:12 +08:00
01f17b7ddc refactor(http_request_node): apply DI for http request node (#30509)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-14 14:19:48 +08:00
yyh
14b2e5bd0d refactor(web): MCP tool availability to context-based version gating (#30955) 2026-01-14 13:40:16 +08:00
132 changed files with 16045 additions and 4121 deletions

1
.agent/skills Symbolic link
View File

@ -0,0 +1 @@
../.claude/skills

View File

@ -0,0 +1,46 @@
---
name: orpc-contract-first
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories.
---
# oRPC Contract-First Development
## Project Structure
```
web/contract/
├── base.ts # Base contract (inputStructure: 'detailed')
├── router.ts # Router composition & type exports
├── marketplace.ts # Marketplace contracts
└── console/ # Console contracts by domain
├── system.ts
└── billing.ts
```
## Workflow
1. **Create contract** in `web/contract/console/{domain}.ts`
- Import `base` from `../base` and `type` from `@orpc/contract`
- Define route with `path`, `method`, `input`, `output`
2. **Register in router** at `web/contract/router.ts`
- Import directly from domain file (no barrel files)
- Nest by API prefix: `billing: { invoices, bindPartnerStack }`
3. **Create hooks** in `web/service/use-{domain}.ts`
- Use `consoleQuery.{group}.{contract}.queryKey()` for query keys
- Use `consoleClient.{group}.{contract}()` for API calls
## Key Rules
- **Input structure**: Always use `{ params, query?, body? }` format
- **Path params**: Use `{paramName}` in path, match in `params` object
- **Router nesting**: Group by API prefix (e.g., `/billing/*``billing: {}`)
- **No barrel files**: Import directly from specific files
- **Types**: Import from `@/types/`, use `type<T>()` helper
## Type Export
```typescript
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
```

View File

@ -90,7 +90,7 @@ jobs:
uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

View File

@ -16,10 +16,6 @@ jobs:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
strategy:
matrix:
node-version: [16, 18, 20, 22]
defaults:
run:
working-directory: sdks/nodejs-client
@ -29,10 +25,10 @@ jobs:
with:
persist-credentials: false
- name: Use Node.js ${{ matrix.node-version }}
- name: Use Node.js
uses: actions/setup-node@v6
with:
node-version: ${{ matrix.node-version }}
node-version: 24
cache: ''
cache-dependency-path: 'pnpm-lock.yaml'

View File

@ -57,7 +57,7 @@ jobs:
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: 'lts/*'
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

View File

@ -31,7 +31,7 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

View File

@ -417,6 +417,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
# Optional: override the local hostname used for SMTP HELO/EHLO
SMTP_LOCAL_HOSTNAME=
# Sendgid configuration
SENDGRID_API_KEY=
# Sentry configuration
@ -713,3 +715,4 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30

View File

@ -3,6 +3,7 @@ import datetime
import json
import logging
import secrets
import time
from typing import Any
import click
@ -46,6 +47,8 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
from services.retention.conversation.messages_clean_policy import create_message_clean_policy
from services.retention.conversation.messages_clean_service import MessagesCleanService
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
@ -2172,3 +2175,79 @@ def migrate_oss(
except Exception as e:
db.session.rollback()
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
@click.command("clean-expired-messages", help="Clean expired messages.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
"--graceful-period",
default=21,
show_default=True,
help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
)
@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
def clean_expired_messages(
batch_size: int,
graceful_period: int,
start_from: datetime.datetime,
end_before: datetime.datetime,
dry_run: bool,
):
"""
Clean expired messages and related data for tenants based on clean policy.
"""
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
# Create policy based on billing configuration
# NOTE: graceful_period will be ignored when billing is disabled.
policy = create_message_clean_policy(graceful_period_days=graceful_period)
# Create and run the cleanup service
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
stats = service.run()
end_at = time.perf_counter()
click.echo(
click.style(
f"clean_messages: completed successfully\n"
f" - Latency: {end_at - start_at:.2f}s\n"
f" - Batches processed: {stats['batches']}\n"
f" - Total messages scanned: {stats['total_messages']}\n"
f" - Messages filtered: {stats['filtered_messages']}\n"
f" - Messages deleted: {stats['total_deleted']}",
fg="green",
)
)
except Exception as e:
end_at = time.perf_counter()
logger.exception("clean_messages failed")
click.echo(
click.style(
f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
fg="red",
)
)
raise
click.echo(click.style("messages cleanup completed.", fg="green"))

View File

@ -949,6 +949,12 @@ class MailConfig(BaseSettings):
default=False,
)
SMTP_LOCAL_HOSTNAME: str | None = Field(
description="Override the local hostname used in SMTP HELO/EHLO. "
"Useful behind NAT or when the default hostname causes rejections.",
default=None,
)
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
default=50,

View File

@ -592,9 +592,12 @@ def _get_conversation(app_model, conversation_id):
if not conversation:
raise NotFound("Conversation Not Exists.")
if not conversation.read_at:
conversation.read_at = naive_utc_now()
conversation.read_account_id = current_user.id
db.session.commit()
db.session.execute(
sa.update(Conversation)
.where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
.values(read_at=naive_utc_now(), read_account_id=current_user.id)
)
db.session.commit()
db.session.refresh(conversation)
return conversation

View File

@ -81,7 +81,7 @@ class ExternalKnowledgeApiPayload(BaseModel):
class ExternalDatasetCreatePayload(BaseModel):
external_knowledge_api_id: str
external_knowledge_id: str
name: str = Field(..., min_length=1, max_length=40)
name: str = Field(..., min_length=1, max_length=100)
description: str | None = Field(None, max_length=400)
external_retrieval_model: dict[str, object] | None = None

View File

@ -33,6 +33,10 @@ class MaxRetriesExceededError(ValueError):
pass
request_error = httpx.RequestError
max_retries_exceeded_error = MaxRetriesExceededError
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
return {
"http://": httpx.HTTPTransport(

View File

@ -1,5 +1,6 @@
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
from core.plugin.impl.exc import PluginDaemonInternalServerError
class PluginEndpointClient(BasePluginClient):
@ -70,18 +71,27 @@ class PluginEndpointClient(BasePluginClient):
def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
Delete the given endpoint.
This operation is idempotent: if the endpoint is already deleted (record not found),
it will return True instead of raising an error.
"""
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/endpoint/remove",
bool,
data={
"endpoint_id": endpoint_id,
},
headers={
"Content-Type": "application/json",
},
)
try:
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/endpoint/remove",
bool,
data={
"endpoint_id": endpoint_id,
},
headers={
"Content-Type": "application/json",
},
)
except PluginDaemonInternalServerError as e:
# Make delete idempotent: if record is not found, consider it a success
if "record not found" in str(e.description).lower():
return True
raise
def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""

View File

@ -17,6 +17,7 @@ from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
from ..protocols import FileManagerProtocol, HttpClientProtocol
from .entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
@ -78,6 +79,8 @@ class Executor:
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
http_client: HttpClientProtocol = ssrf_proxy,
file_manager: FileManagerProtocol = file_manager,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@ -104,6 +107,8 @@ class Executor:
self.data = None
self.json = None
self.max_retries = max_retries
self._http_client = http_client
self._file_manager = file_manager
# init template
self.variable_pool = variable_pool
@ -200,7 +205,7 @@ class Executor:
if file_variable is None:
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
file = file_variable.value
self.content = file_manager.download(file)
self.content = self._file_manager.download(file)
case "x-www-form-urlencoded":
form_data = {
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
@ -239,7 +244,7 @@ class Executor:
):
file_tuple = (
file.filename,
file_manager.download(file),
self._file_manager.download(file),
file.mime_type or "application/octet-stream",
)
if key not in files:
@ -332,19 +337,18 @@ class Executor:
do http request depending on api bundle
"""
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
"post": ssrf_proxy.post,
"put": ssrf_proxy.put,
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
"get": self._http_client.get,
"head": self._http_client.head,
"post": self._http_client.post,
"put": self._http_client.put,
"delete": self._http_client.delete,
"patch": self._http_client.patch,
}
method_lc = self.method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
"url": self.url,
"data": self.data,
"files": self.files,
"json": self.json,
@ -357,8 +361,12 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
response: httpx.Response = _METHOD_MAP[method_lc](
url=self.url,
**request_args,
max_retries=self.max_retries,
)
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response

View File

@ -1,10 +1,11 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.file import File, FileTransferMethod
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -13,6 +14,7 @@ from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from factories import file_factory
from .entities import (
@ -30,10 +32,35 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
http_client: HttpClientProtocol = ssrf_proxy,
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
file_manager: FileManagerProtocol = file_manager,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._http_client = http_client
self._tool_file_manager_factory = tool_file_manager_factory
self._file_manager = file_manager
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@ -71,6 +98,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
http_client=self._http_client,
file_manager=self._file_manager,
)
process_data["request"] = http_executor.to_log()
@ -199,7 +228,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
mime_type = (
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
)
tool_file_manager = ToolFileManager()
tool_file_manager = self._tool_file_manager_factory()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,

View File

@ -1,16 +1,21 @@
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
from core.file import file_manager
from core.helper import ssrf_proxy
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
@ -43,6 +48,9 @@ class DifyNodeFactory(NodeFactory):
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
http_request_http_client: HttpClientProtocol = ssrf_proxy,
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
http_request_file_manager: FileManagerProtocol = file_manager,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@ -61,6 +69,9 @@ class DifyNodeFactory(NodeFactory):
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
self._http_request_http_client = http_request_http_client
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@ -113,6 +124,7 @@ class DifyNodeFactory(NodeFactory):
code_providers=self._code_providers,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
@ -122,6 +134,17 @@ class DifyNodeFactory(NodeFactory):
template_renderer=self._template_renderer,
)
if node_type == NodeType.HTTP_REQUEST:
return HttpRequestNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
http_client=self._http_request_http_client,
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
file_manager=self._http_request_file_manager,
)
return node_class(
id=node_id,
config=node_config,

View File

@ -0,0 +1,29 @@
from typing import Protocol
import httpx
from core.file import File
class HttpClientProtocol(Protocol):
@property
def max_retries_exceeded_error(self) -> type[Exception]: ...
@property
def request_error(self) -> type[Exception]: ...
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
class FileManagerProtocol(Protocol):
def download(self, f: File, /) -> bytes: ...

View File

@ -189,8 +189,7 @@ class WorkflowEntry:
)
try:
# run node
generator = node.run()
generator = cls._traced_node_run(node)
except Exception as e:
logger.exception(
"error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s",
@ -323,8 +322,7 @@ class WorkflowEntry:
tenant_id=tenant_id,
)
# run node
generator = node.run()
generator = cls._traced_node_run(node)
return node, generator
except Exception as e:
@ -430,3 +428,26 @@ class WorkflowEntry:
input_value = current_variable.value | input_value
variable_pool.add([variable_node_id] + variable_key_list, input_value)
@staticmethod
def _traced_node_run(node: Node) -> Generator[GraphNodeEventBase, None, None]:
"""
Wraps a node's run method with OpenTelemetry tracing and returns a generator.
"""
# Wrap node.run() with ObservabilityLayer hooks to produce node-level spans
layer = ObservabilityLayer()
layer.on_graph_start()
node.ensure_execution_id()
def _gen():
error: Exception | None = None
layer.on_node_run_start(node)
try:
yield from node.run()
except Exception as exc:
error = exc
raise
finally:
layer.on_node_run_end(node, error)
return _gen()

View File

@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle as handle_create_site_re
from .delete_tool_parameters_cache_when_sync_draft_workflow import (
handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow,
)
from .queue_credential_sync_when_tenant_created import handle as handle_queue_credential_sync_when_tenant_created
from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created
from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created
from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published
@ -30,6 +31,7 @@ __all__ = [
"handle_create_installed_app_when_app_created",
"handle_create_site_record_when_app_created",
"handle_delete_tool_parameters_cache_when_sync_draft_workflow",
"handle_queue_credential_sync_when_tenant_created",
"handle_sync_plugin_trigger_when_app_created",
"handle_sync_webhook_when_app_created",
"handle_sync_workflow_schedule_when_app_published",

View File

@ -0,0 +1,19 @@
from configs import dify_config
from events.tenant_event import tenant_was_created
from services.enterprise.workspace_sync import WorkspaceSyncService
@tenant_was_created.connect
def handle(sender, **kwargs):
"""Queue credential sync when a tenant/workspace is created."""
# Only queue sync tasks if plugin manager (enterprise feature) is enabled
if not dify_config.ENTERPRISE_ENABLED:
return
tenant = sender
# Determine source from kwargs if available, otherwise use generic
source = kwargs.get("source", "tenant_created")
# Queue credential sync task to Redis for enterprise backend to process
WorkspaceSyncService.queue_credential_sync(tenant.id, source=source)

View File

@ -4,6 +4,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
clean_expired_messages,
clean_workflow_runs,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
@ -58,6 +59,7 @@ def init_app(app: DifyApp):
transform_datasource_credentials,
install_rag_pipeline_plugins,
clean_workflow_runs,
clean_expired_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@ -115,7 +115,18 @@ def build_from_mappings(
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
# Implement batch processing to reduce database load when handling multiple files.
# Filter out None/empty mappings to avoid errors
valid_mappings = [m for m in mappings if m and m.get("transfer_method")]
def is_valid_mapping(m: Mapping[str, Any]) -> bool:
if not m or not m.get("transfer_method"):
return False
# For REMOTE_URL transfer method, ensure url or remote_url is provided and not None
transfer_method = m.get("transfer_method")
if transfer_method == FileTransferMethod.REMOTE_URL:
url = m.get("url") or m.get("remote_url")
if not url:
return False
return True
valid_mappings = [m for m in mappings if is_valid_mapping(m)]
files = [
build_from_mapping(
mapping=mapping,

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import datetime
from typing import TypeAlias
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, field_validator
@ -20,8 +21,8 @@ class SimpleFeedback(ResponseModel):
class RetrieverResource(ResponseModel):
id: str
message_id: str
id: str = Field(default_factory=lambda: str(uuid4()))
message_id: str = Field(default_factory=lambda: str(uuid4()))
position: int
dataset_id: str | None = None
dataset_name: str | None = None

View File

@ -3,6 +3,8 @@ import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from configs import dify_config
logger = logging.getLogger(__name__)
@ -19,20 +21,21 @@ class SMTPClient:
self.opportunistic_tls = opportunistic_tls
def send(self, mail: dict):
smtp = None
smtp: smtplib.SMTP | None = None
local_host = dify_config.SMTP_LOCAL_HOSTNAME
try:
if self.use_tls:
if self.opportunistic_tls:
smtp = smtplib.SMTP(self.server, self.port, timeout=10)
# Send EHLO command with the HELO domain name as the server address
smtp.ehlo(self.server)
smtp.starttls()
# Resend EHLO command to identify the TLS session
smtp.ehlo(self.server)
else:
smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
if self.use_tls and not self.opportunistic_tls:
# SMTP with SSL (implicit TLS)
smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10, local_hostname=local_host)
else:
smtp = smtplib.SMTP(self.server, self.port, timeout=10)
# Plain SMTP or SMTP with STARTTLS (explicit TLS)
smtp = smtplib.SMTP(self.server, self.port, timeout=10, local_hostname=local_host)
assert smtp is not None
if self.use_tls and self.opportunistic_tls:
smtp.ehlo(self.server)
smtp.starttls()
smtp.ehlo(self.server)
# Only authenticate if both username and password are non-empty
if self.username and self.password and self.username.strip() and self.password.strip():

View File

@ -0,0 +1,33 @@
"""feat: add created_at id index to messages
Revision ID: 3334862ee907
Revises: 905527cc8fd3
Create Date: 2026-01-12 17:29:44.846544
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '3334862ee907'
down_revision = '905527cc8fd3'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.drop_index('message_created_at_id_idx')
# ### end Alembic commands ###

View File

@ -1149,7 +1149,7 @@ class DatasetCollectionBinding(TypeBase):
)
class TidbAuthBinding(Base):
class TidbAuthBinding(TypeBase):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@ -1158,7 +1158,13 @@ class TidbAuthBinding(Base):
sa.Index("tidb_auth_bindings_created_at_idx", "created_at"),
sa.Index("tidb_auth_bindings_status_idx", "status"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
@ -1166,7 +1172,9 @@ class TidbAuthBinding(Base):
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
class Whitelist(TypeBase):

View File

@ -968,6 +968,7 @@ class Message(Base):
Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
Index("message_created_at_idx", "created_at"),
Index("message_app_mode_idx", "app_mode"),
Index("message_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
@ -1447,7 +1448,7 @@ class MessageAnnotation(Base):
return account
class AppAnnotationHitHistory(Base):
class AppAnnotationHitHistory(TypeBase):
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@ -1457,17 +1458,19 @@ class AppAnnotationHitHistory(Base):
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
)
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
app_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
source = mapped_column(LongText, nullable=False)
question = mapped_column(LongText, nullable=False)
account_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
score = mapped_column(Float, nullable=False, server_default=sa.text("0"))
message_id = mapped_column(StringUUID, nullable=False)
annotation_question = mapped_column(LongText, nullable=False)
annotation_content = mapped_column(LongText, nullable=False)
source: Mapped[str] = mapped_column(LongText, nullable=False)
question: Mapped[str] = mapped_column(LongText, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
score: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"))
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
annotation_question: Mapped[str] = mapped_column(LongText, nullable=False)
annotation_content: Mapped[str] = mapped_column(LongText, nullable=False)
@property
def account(self):
@ -2083,7 +2086,7 @@ class TraceAppConfig(TypeBase):
}
class TenantCreditPool(Base):
class TenantCreditPool(TypeBase):
__tablename__ = "tenant_credit_pools"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
@ -2091,14 +2094,20 @@ class TenantCreditPool(Base):
sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
quota_limit = mapped_column(BigInteger, nullable=False, default=0)
quota_used = mapped_column(BigInteger, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
@property

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.11.3"
version = "1.11.4"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -1,90 +1,62 @@
import datetime
import logging
import time
import click
from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import (
App,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.feature_service import FeatureService
from services.retention.conversation.messages_clean_policy import create_message_clean_policy
from services.retention.conversation.messages_clean_service import MessagesCleanService
logger = logging.getLogger(__name__)
@app.celery.task(queue="dataset")
@app.celery.task(queue="retention")
def clean_messages():
click.echo(click.style("Start clean messages.", fg="green"))
start_at = time.perf_counter()
plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
)
while True:
try:
# Main query with join and filter
messages = (
db.session.query(Message)
.where(Message.created_at < plan_sandbox_clean_message_day)
.order_by(Message.created_at.desc())
.limit(100)
.all()
)
"""
Clean expired messages based on clean policy.
except SQLAlchemyError:
raise
if not messages:
break
for message in messages:
app = db.session.query(App).filter_by(id=message.app_id).first()
if not app:
logger.warning(
"Expected App record to exist, but none was found, app_id=%s, message_id=%s",
message.app_id,
message.id,
)
continue
features_cache_key = f"features:{app.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:
features = FeatureService.get_features(app.tenant_id)
redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
plan = features.billing.subscription.plan
else:
plan = plan_cache.decode()
if plan == CloudPlan.SANDBOX:
# clean related message
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(Message).where(Message.id == message.id).delete()
db.session.commit()
end_at = time.perf_counter()
click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green"))
This task uses MessagesCleanService to efficiently clean messages in batches.
The behavior depends on BILLING_ENABLED configuration:
- BILLING_ENABLED=True: only delete messages from sandbox tenants (with whitelist/grace period)
- BILLING_ENABLED=False: delete all messages within the time range
"""
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
# Create policy based on billing configuration
policy = create_message_clean_policy(
graceful_period_days=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD,
)
# Create and run the cleanup service
service = MessagesCleanService.from_days(
policy=policy,
days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
)
stats = service.run()
end_at = time.perf_counter()
click.echo(
click.style(
f"clean_messages: completed successfully\n"
f" - Latency: {end_at - start_at:.2f}s\n"
f" - Batches processed: {stats['batches']}\n"
f" - Total messages scanned: {stats['total_messages']}\n"
f" - Messages filtered: {stats['filtered_messages']}\n"
f" - Messages deleted: {stats['total_deleted']}",
fg="green",
)
)
except Exception as e:
end_at = time.perf_counter()
logger.exception("clean_messages failed")
click.echo(
click.style(
f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
fg="red",
)
)
raise

View File

@ -50,10 +50,13 @@ def create_clusters(batch_size):
)
for new_cluster in new_clusters:
tidb_auth_binding = TidbAuthBinding(
tenant_id=None,
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
active=False,
status="CREATING",
)
db.session.add(tidb_auth_binding)
db.session.commit()

View File

@ -0,0 +1,58 @@
import json
import logging
import uuid
from datetime import UTC, datetime
from redis import RedisError
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"
WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing"
class WorkspaceSyncService:
"""Service to publish workspace sync tasks to Redis queue for enterprise backend consumption"""
@staticmethod
def queue_credential_sync(workspace_id: str, *, source: str) -> bool:
"""
Queue a credential sync task for a newly created workspace.
This publishes a task to Redis that will be consumed by the enterprise backend
worker to sync credentials with the plugin-manager.
Args:
workspace_id: The workspace/tenant ID to sync credentials for
source: Source of the sync request (for debugging/tracking)
Returns:
bool: True if task was queued successfully, False otherwise
"""
try:
task = {
"task_id": str(uuid.uuid4()),
"workspace_id": workspace_id,
"retry_count": 0,
"created_at": datetime.now(UTC).isoformat(),
"source": source,
}
# Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
redis_client.lpush(WORKSPACE_SYNC_QUEUE, json.dumps(task))
logger.info(
"Queued credential sync task for workspace %s, task_id: %s, source: %s",
workspace_id,
task["task_id"],
source,
)
return True
except (RedisError, TypeError) as e:
logger.error("Failed to queue credential sync for workspace %s: %s", workspace_id, str(e), exc_info=True)
# Don't raise - we don't want to fail workspace creation if queueing fails
# The scheduled task will catch it later
return False

View File

@ -0,0 +1,216 @@
import datetime
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from configs import dify_config
from enums.cloud_plan import CloudPlan
from services.billing_service import BillingService, SubscriptionPlan
logger = logging.getLogger(__name__)
@dataclass
class SimpleMessage:
id: str
app_id: str
created_at: datetime.datetime
class MessagesCleanPolicy(ABC):
"""
Abstract base class for message cleanup policies.
A policy determines which messages from a batch should be deleted.
"""
@abstractmethod
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
"""
Filter messages and return IDs of messages that should be deleted.
Args:
messages: Batch of messages to evaluate
app_to_tenant: Mapping from app_id to tenant_id
Returns:
List of message IDs that should be deleted
"""
...
class BillingDisabledPolicy(MessagesCleanPolicy):
"""
Policy for community or enterpriseedition (billing disabled).
No special filter logic, just return all message ids.
"""
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
return [msg.id for msg in messages]
class BillingSandboxPolicy(MessagesCleanPolicy):
"""
Policy for sandbox plan tenants in cloud edition (billing enabled).
Filters messages based on sandbox plan expiration rules:
- Skip tenants in the whitelist
- Only delete messages from sandbox plan tenants
- Respect grace period after subscription expiration
- Safe default: if tenant mapping or plan is missing, do NOT delete
"""
def __init__(
self,
plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]],
graceful_period_days: int = 21,
tenant_whitelist: Sequence[str] | None = None,
current_timestamp: int | None = None,
) -> None:
self._graceful_period_days = graceful_period_days
self._tenant_whitelist: Sequence[str] = tenant_whitelist or []
self._plan_provider = plan_provider
self._current_timestamp = current_timestamp
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
"""
Filter messages based on sandbox plan expiration rules.
Args:
messages: Batch of messages to evaluate
app_to_tenant: Mapping from app_id to tenant_id
Returns:
List of message IDs that should be deleted
"""
if not messages or not app_to_tenant:
return []
# Get unique tenant_ids and fetch subscription plans
tenant_ids = list(set(app_to_tenant.values()))
tenant_plans = self._plan_provider(tenant_ids)
if not tenant_plans:
return []
# Apply sandbox deletion rules
return self._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
)
def _filter_expired_sandbox_messages(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
tenant_plans: dict[str, SubscriptionPlan],
) -> list[str]:
"""
Filter messages that should be deleted based on sandbox plan expiration.
A message should be deleted if:
1. It belongs to a sandbox tenant AND
2. Either:
a) The tenant has no previous subscription (expiration_date == -1), OR
b) The subscription expired more than graceful_period_days ago
Args:
messages: List of message objects with id and app_id attributes
app_to_tenant: Mapping from app_id to tenant_id
tenant_plans: Mapping from tenant_id to subscription plan info
Returns:
List of message IDs that should be deleted
"""
current_timestamp = self._current_timestamp
if current_timestamp is None:
current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
sandbox_message_ids: list[str] = []
graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60
for msg in messages:
# Get tenant_id for this message's app
tenant_id = app_to_tenant.get(msg.app_id)
if not tenant_id:
continue
# Skip tenant messages in whitelist
if tenant_id in self._tenant_whitelist:
continue
# Get subscription plan for this tenant
tenant_plan = tenant_plans.get(tenant_id)
if not tenant_plan:
continue
plan = str(tenant_plan["plan"])
expiration_date = int(tenant_plan["expiration_date"])
# Only process sandbox plans
if plan != CloudPlan.SANDBOX:
continue
# Case 1: No previous subscription (-1 means never had a paid subscription)
if expiration_date == -1:
sandbox_message_ids.append(msg.id)
continue
# Case 2: Subscription expired beyond grace period
if current_timestamp - expiration_date > graceful_period_seconds:
sandbox_message_ids.append(msg.id)
return sandbox_message_ids
def create_message_clean_policy(
graceful_period_days: int = 21,
current_timestamp: int | None = None,
) -> MessagesCleanPolicy:
"""
Factory function to create the appropriate message clean policy.
Determines which policy to use based on BILLING_ENABLED configuration:
- If BILLING_ENABLED is True: returns BillingSandboxPolicy
- If BILLING_ENABLED is False: returns BillingDisabledPolicy
Args:
graceful_period_days: Grace period in days after subscription expiration (default: 21)
current_timestamp: Current Unix timestamp for testing (default: None, uses current time)
"""
if not dify_config.BILLING_ENABLED:
logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy")
return BillingDisabledPolicy()
# Billing enabled - fetch whitelist from BillingService
tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist()
plan_provider = BillingService.get_plan_bulk_with_cache
logger.info(
"create_message_clean_policy: billing enabled, using BillingSandboxPolicy "
"(graceful_period_days=%s, whitelist=%s)",
graceful_period_days,
tenant_whitelist,
)
return BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=graceful_period_days,
tenant_whitelist=tenant_whitelist,
current_timestamp=current_timestamp,
)

View File

@ -0,0 +1,334 @@
import datetime
import logging
import random
from collections.abc import Sequence
from typing import cast
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.model import (
App,
AppAnnotationHitHistory,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.retention.conversation.messages_clean_policy import (
MessagesCleanPolicy,
SimpleMessage,
)
logger = logging.getLogger(__name__)
class MessagesCleanService:
"""
Service for cleaning expired messages based on retention policies.
Compatible with non cloud edition (billing disabled): all messages in the time range will be deleted.
If billing is enabled: only sandbox plan tenant messages are deleted (with whitelist and grace period support).
"""
def __init__(
self,
policy: MessagesCleanPolicy,
end_before: datetime.datetime,
start_from: datetime.datetime | None = None,
batch_size: int = 1000,
dry_run: bool = False,
) -> None:
"""
Initialize the service with cleanup parameters.
Args:
policy: The policy that determines which messages to delete
end_before: End time (exclusive) of the range
start_from: Optional start time (inclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
"""
self._policy = policy
self._end_before = end_before
self._start_from = start_from
self._batch_size = batch_size
self._dry_run = dry_run
@classmethod
def from_time_range(
cls,
policy: MessagesCleanPolicy,
start_from: datetime.datetime,
end_before: datetime.datetime,
batch_size: int = 1000,
dry_run: bool = False,
) -> "MessagesCleanService":
"""
Create a service instance for cleaning messages within a specific time range.
Time range is [start_from, end_before).
Args:
policy: The policy that determines which messages to delete
start_from: Start time (inclusive) of the range
end_before: End time (exclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
MessagesCleanService instance
Raises:
ValueError: If start_from >= end_before or invalid parameters
"""
if start_from >= end_before:
raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
logger.info(
"clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s",
start_from,
end_before,
batch_size,
policy.__class__.__name__,
)
return cls(
policy=policy,
end_before=end_before,
start_from=start_from,
batch_size=batch_size,
dry_run=dry_run,
)
@classmethod
def from_days(
cls,
policy: MessagesCleanPolicy,
days: int = 30,
batch_size: int = 1000,
dry_run: bool = False,
) -> "MessagesCleanService":
"""
Create a service instance for cleaning messages older than specified days.
Args:
policy: The policy that determines which messages to delete
days: Number of days to look back from now
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
MessagesCleanService instance
Raises:
ValueError: If invalid parameters
"""
if days < 0:
raise ValueError(f"days ({days}) must be greater than or equal to 0")
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
end_before = datetime.datetime.now() - datetime.timedelta(days=days)
logger.info(
"clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
days,
end_before,
batch_size,
policy.__class__.__name__,
)
return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
def run(self) -> dict[str, int]:
"""
Execute the message cleanup operation.
Returns:
Dict with statistics: batches, filtered_messages, total_deleted
"""
return self._clean_messages_by_time_range()
def _clean_messages_by_time_range(self) -> dict[str, int]:
"""
Clean messages within a time range using cursor-based pagination.
Time range is [start_from, end_before)
Steps:
1. Iterate messages using cursor pagination (by created_at, id)
2. Query app_id -> tenant_id mapping
3. Delegate to policy to determine which messages to delete
4. Batch delete messages and their relations
Returns:
Dict with statistics: batches, filtered_messages, total_deleted
"""
stats = {
"batches": 0,
"total_messages": 0,
"filtered_messages": 0,
"total_deleted": 0,
}
# Cursor-based pagination using (created_at, id) to avoid infinite loops
# and ensure proper ordering with time-based filtering
_cursor: tuple[datetime.datetime, str] | None = None
logger.info(
"clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s",
self._dry_run,
self._start_from,
self._end_before,
)
while True:
stats["batches"] += 1
# Step 1: Fetch a batch of messages using cursor
with Session(db.engine, expire_on_commit=False) as session:
msg_stmt = (
select(Message.id, Message.app_id, Message.created_at)
.where(Message.created_at < self._end_before)
.order_by(Message.created_at, Message.id)
.limit(self._batch_size)
)
if self._start_from:
msg_stmt = msg_stmt.where(Message.created_at >= self._start_from)
# Apply cursor condition: (created_at, id) > (last_created_at, last_message_id)
# This translates to:
# created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id)
if _cursor:
# Continuing from previous batch
msg_stmt = msg_stmt.where(
(Message.created_at > _cursor[0])
| ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1]))
)
raw_messages = list(session.execute(msg_stmt).all())
messages = [
SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at)
for msg_id, app_id, msg_created_at in raw_messages
]
# Track total messages fetched across all batches
stats["total_messages"] += len(messages)
if not messages:
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
break
# Update cursor to the last message's (created_at, id)
_cursor = (messages[-1].created_at, messages[-1].id)
# Step 2: Extract app_ids and query tenant_ids
app_ids = list({msg.app_id for msg in messages})
if not app_ids:
logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"])
continue
app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids))
apps = list(session.execute(app_stmt).all())
if not apps:
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
continue
# Build app_id -> tenant_id mapping
app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps}
# Step 3: Delegate to policy to determine which messages to delete
message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant)
if not message_ids_to_delete:
logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
continue
stats["filtered_messages"] += len(message_ids_to_delete)
# Step 4: Batch delete messages and their relations
if not self._dry_run:
with Session(db.engine, expire_on_commit=False) as session:
# Delete related records first
self._batch_delete_message_relations(session, message_ids_to_delete)
# Delete messages
delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete))
delete_result = cast(CursorResult, session.execute(delete_stmt))
messages_deleted = delete_result.rowcount
session.commit()
stats["total_deleted"] += messages_deleted
logger.info(
"clean_messages (batch %s): processed %s messages, deleted %s messages",
stats["batches"],
len(messages),
messages_deleted,
)
else:
# Log random sample of message IDs that would be deleted (up to 10)
sample_size = min(10, len(message_ids_to_delete))
sampled_ids = random.sample(list(message_ids_to_delete), sample_size)
logger.info(
"clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:",
stats["batches"],
len(message_ids_to_delete),
sample_size,
)
for msg_id in sampled_ids:
logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
logger.info(
"clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
stats["batches"],
stats["total_messages"],
stats["filtered_messages"],
stats["total_deleted"],
)
return stats
@staticmethod
def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None:
"""
Batch delete all related records for given message IDs.
Args:
session: Database session
message_ids: List of message IDs to delete relations for
"""
if not message_ids:
return
# Delete all related records in batch
session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)))
session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids)))
session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids)))
session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids)))
session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids)))
session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids)))
session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids)))
session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids)))

View File

@ -103,6 +103,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
# Optional: override the local hostname used for SMTP HELO/EHLO
SMTP_LOCAL_HOSTNAME=
# Sentry configuration
SENTRY_DSN=

View File

@ -0,0 +1 @@
"""Unit tests for `controllers.console.datasets` controllers."""

View File

@ -0,0 +1,49 @@
from __future__ import annotations
"""
Unit tests for the external dataset controller payload schemas.
These tests focus on Pydantic validation rules so we can catch regressions
in request constraints (e.g. max length changes) without exercising the
full Flask/RESTX request stack.
"""
import pytest
from pydantic import ValidationError
from controllers.console.datasets.external import ExternalDatasetCreatePayload
def test_external_dataset_create_payload_allows_name_length_100() -> None:
"""Ensure the `name` field accepts up to 100 characters (inclusive)."""
# Build a request payload with a boundary-length name value.
name_100: str = "a" * 100
payload = {
"external_knowledge_api_id": "ek-api-1",
"external_knowledge_id": "ek-1",
"name": name_100,
}
model = ExternalDatasetCreatePayload.model_validate(payload)
assert model.name == name_100
def test_external_dataset_create_payload_rejects_name_length_101() -> None:
"""Ensure the `name` field rejects values longer than 100 characters."""
# Build a request payload that exceeds the max length by 1.
name_101: str = "a" * 101
payload: dict[str, object] = {
"external_knowledge_api_id": "ek-api-1",
"external_knowledge_id": "ek-1",
"name": name_101,
}
with pytest.raises(ValidationError) as exc_info:
ExternalDatasetCreatePayload.model_validate(payload)
errors = exc_info.value.errors()
assert errors[0]["loc"] == ("name",)
assert errors[0]["type"] == "string_too_long"
assert errors[0]["ctx"]["max_length"] == 100

View File

@ -0,0 +1,279 @@
"""Unit tests for PluginEndpointClient functionality.
This test module covers the endpoint client operations including:
- Successful endpoint deletion
- Idempotent delete behavior (record not found)
- Non-idempotent delete behavior (other errors)
Tests follow the Arrange-Act-Assert pattern for clarity.
"""
from unittest.mock import MagicMock, patch
import pytest
from core.plugin.impl.endpoint import PluginEndpointClient
from core.plugin.impl.exc import PluginDaemonInternalServerError
class TestPluginEndpointClientDelete:
"""Unit tests for PluginEndpointClient delete_endpoint operation.
Tests cover:
- Successful endpoint deletion
- Idempotent behavior when endpoint is already deleted (record not found)
- Non-idempotent behavior for other errors
"""
@pytest.fixture
def endpoint_client(self):
"""Create a PluginEndpointClient instance for testing."""
return PluginEndpointClient()
@pytest.fixture
def mock_config(self):
"""Mock plugin daemon configuration."""
with (
patch("core.plugin.impl.base.dify_config.PLUGIN_DAEMON_URL", "http://127.0.0.1:5002"),
patch("core.plugin.impl.base.dify_config.PLUGIN_DAEMON_KEY", "test-api-key"),
):
yield
def test_delete_endpoint_success(self, endpoint_client, mock_config):
"""Test successful endpoint deletion.
Given:
- A valid tenant_id, user_id, and endpoint_id
- The plugin daemon returns success response
When:
- delete_endpoint is called
Then:
- The method should return True
- The request should be made with correct parameters
"""
# Arrange
tenant_id = "tenant-123"
user_id = "user-456"
endpoint_id = "endpoint-789"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"code": 0,
"message": "success",
"data": True,
}
with patch("httpx.request", return_value=mock_response):
# Act
result = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Assert
assert result is True
def test_delete_endpoint_idempotent_record_not_found(self, endpoint_client, mock_config):
"""Test idempotent delete behavior when endpoint is already deleted.
Given:
- A valid tenant_id, user_id, and endpoint_id
- The plugin daemon returns "record not found" error
When:
- delete_endpoint is called
Then:
- The method should return True (idempotent behavior)
- No exception should be raised
"""
# Arrange
tenant_id = "tenant-123"
user_id = "user-456"
endpoint_id = "endpoint-789"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"code": -1,
"message": (
'{"error_type": "PluginDaemonInternalServerError", '
'"message": "failed to remove endpoint: record not found"}'
),
}
with patch("httpx.request", return_value=mock_response):
# Act
result = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Assert - should return True instead of raising an error
assert result is True
def test_delete_endpoint_non_idempotent_other_errors(self, endpoint_client, mock_config):
"""Test non-idempotent delete behavior for other errors.
Given:
- A valid tenant_id, user_id, and endpoint_id
- The plugin daemon returns a different error (not "record not found")
When:
- delete_endpoint is called
Then:
- The method should raise PluginDaemonInternalServerError
"""
# Arrange
tenant_id = "tenant-123"
user_id = "user-456"
endpoint_id = "endpoint-789"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"code": -1,
"message": (
'{"error_type": "PluginDaemonInternalServerError", '
'"message": "failed to remove endpoint: internal server error"}'
),
}
with patch("httpx.request", return_value=mock_response):
# Act & Assert
with pytest.raises(PluginDaemonInternalServerError) as exc_info:
endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Assert - the error message should not be "record not found"
assert "record not found" not in str(exc_info.value.description)
def test_delete_endpoint_idempotent_case_insensitive(self, endpoint_client, mock_config):
"""Test idempotent delete behavior with case-insensitive error message.
Given:
- A valid tenant_id, user_id, and endpoint_id
- The plugin daemon returns "Record Not Found" error (different case)
When:
- delete_endpoint is called
Then:
- The method should return True (idempotent behavior)
"""
# Arrange
tenant_id = "tenant-123"
user_id = "user-456"
endpoint_id = "endpoint-789"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"code": -1,
"message": '{"error_type": "PluginDaemonInternalServerError", "message": "Record Not Found"}',
}
with patch("httpx.request", return_value=mock_response):
# Act
result = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Assert - should still return True
assert result is True
def test_delete_endpoint_multiple_calls_idempotent(self, endpoint_client, mock_config):
"""Test that multiple delete calls are idempotent.
Given:
- A valid tenant_id, user_id, and endpoint_id
- The first call succeeds
- Subsequent calls return "record not found"
When:
- delete_endpoint is called multiple times
Then:
- All calls should return True
"""
# Arrange
tenant_id = "tenant-123"
user_id = "user-456"
endpoint_id = "endpoint-789"
# First call - success
mock_response_success = MagicMock()
mock_response_success.status_code = 200
mock_response_success.json.return_value = {
"code": 0,
"message": "success",
"data": True,
}
# Second call - record not found
mock_response_not_found = MagicMock()
mock_response_not_found.status_code = 200
mock_response_not_found.json.return_value = {
"code": -1,
"message": (
'{"error_type": "PluginDaemonInternalServerError", '
'"message": "failed to remove endpoint: record not found"}'
),
}
with patch("httpx.request") as mock_request:
# Act - first call
mock_request.return_value = mock_response_success
result1 = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Act - second call (already deleted)
mock_request.return_value = mock_response_not_found
result2 = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Assert - both should return True
assert result1 is True
assert result2 is True
def test_delete_endpoint_non_idempotent_unauthorized_error(self, endpoint_client, mock_config):
"""Test that authorization errors are not treated as idempotent.
Given:
- A valid tenant_id, user_id, and endpoint_id
- The plugin daemon returns an unauthorized error
When:
- delete_endpoint is called
Then:
- The method should raise the appropriate error (not return True)
"""
# Arrange
tenant_id = "tenant-123"
user_id = "user-456"
endpoint_id = "endpoint-789"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"code": -1,
"message": '{"error_type": "PluginDaemonUnauthorizedError", "message": "unauthorized access"}',
}
with patch("httpx.request", return_value=mock_response):
# Act & Assert
with pytest.raises(Exception) as exc_info:
endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Assert - should not return True for unauthorized errors
assert exc_info.value.__class__.__name__ == "PluginDaemonUnauthorizedError"

View File

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
@ -17,7 +17,7 @@ def test_smtp_plain_success(mock_smtp_cls: MagicMock):
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
client.send(_mail())
mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10)
mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10, local_hostname=ANY)
mock_smtp.sendmail.assert_called_once()
mock_smtp.quit.assert_called_once()
@ -38,7 +38,7 @@ def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
)
client.send(_mail())
mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10)
mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10, local_hostname=ANY)
assert mock_smtp.ehlo.call_count == 2
mock_smtp.starttls.assert_called_once()
mock_smtp.login.assert_called_once_with("user", "pass")

View File

@ -0,0 +1,627 @@
import datetime
from unittest.mock import MagicMock, patch
import pytest
from enums.cloud_plan import CloudPlan
from services.retention.conversation.messages_clean_policy import (
BillingDisabledPolicy,
BillingSandboxPolicy,
SimpleMessage,
create_message_clean_policy,
)
from services.retention.conversation.messages_clean_service import MessagesCleanService
def make_simple_message(msg_id: str, app_id: str) -> SimpleMessage:
"""Helper to create a SimpleMessage with a fixed created_at timestamp."""
return SimpleMessage(id=msg_id, app_id=app_id, created_at=datetime.datetime(2024, 1, 1))
def make_plan_provider(tenant_plans: dict) -> MagicMock:
"""Helper to create a mock plan_provider that returns the given tenant_plans."""
provider = MagicMock()
provider.return_value = tenant_plans
return provider
class TestBillingSandboxPolicyFilterMessageIds:
"""Unit tests for BillingSandboxPolicy.filter_message_ids method."""
# Fixed timestamp for deterministic tests
CURRENT_TIMESTAMP = 1000000
GRACEFUL_PERIOD_DAYS = 8
GRACEFUL_PERIOD_SECONDS = GRACEFUL_PERIOD_DAYS * 24 * 60 * 60
def test_missing_tenant_mapping_excluded(self):
"""Test that messages with missing app-to-tenant mapping are excluded."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant = {} # No mapping
tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert
assert list(result) == []
def test_missing_tenant_plan_excluded(self):
"""Test that messages with missing tenant plan are excluded (safe default)."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"}
tenant_plans = {} # No plans
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert
assert list(result) == []
def test_non_sandbox_plan_excluded(self):
"""Test that messages from non-sandbox plans (PROFESSIONAL/TEAM) are excluded."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
make_simple_message("msg3", "app3"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.TEAM, "expiration_date": -1},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, # Only this one
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - only msg3 (sandbox tenant) should be included
assert set(result) == {"msg3"}
def test_whitelist_skip(self):
"""Test that whitelisted tenants are excluded even if sandbox + expired."""
# Arrange
messages = [
make_simple_message("msg1", "app1"), # Whitelisted - excluded
make_simple_message("msg2", "app2"), # Not whitelisted - included
make_simple_message("msg3", "app3"), # Whitelisted - excluded
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
}
plan_provider = make_plan_provider(tenant_plans)
tenant_whitelist = ["tenant1", "tenant3"]
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
tenant_whitelist=tenant_whitelist,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - only msg2 should be included
assert set(result) == {"msg2"}
def test_no_previous_subscription_included(self):
"""Test that messages with expiration_date=-1 (no previous subscription) are included."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - all messages should be included
assert set(result) == {"msg1", "msg2"}
def test_within_grace_period_excluded(self):
"""Test that messages within grace period are excluded."""
# Arrange
now = self.CURRENT_TIMESTAMP
expired_1_day_ago = now - (1 * 24 * 60 * 60)
expired_5_days_ago = now - (5 * 24 * 60 * 60)
expired_7_days_ago = now - (7 * 24 * 60 * 60)
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
make_simple_message("msg3", "app3"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_1_day_ago},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_5_days_ago},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_7_days_ago},
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS, # 8 days
current_timestamp=now,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - all within 8-day grace period, none should be included
assert list(result) == []
def test_exactly_at_boundary_excluded(self):
"""Test that messages exactly at grace period boundary are excluded (code uses >)."""
# Arrange
now = self.CURRENT_TIMESTAMP
expired_exactly_8_days_ago = now - self.GRACEFUL_PERIOD_SECONDS # Exactly at boundary
messages = [make_simple_message("msg1", "app1")]
app_to_tenant = {"app1": "tenant1"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_exactly_8_days_ago},
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=now,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - exactly at boundary (==) should be excluded (code uses >)
assert list(result) == []
def test_beyond_grace_period_included(self):
"""Test that messages beyond grace period are included."""
# Arrange
now = self.CURRENT_TIMESTAMP
expired_9_days_ago = now - (9 * 24 * 60 * 60) # Just beyond 8-day grace
expired_30_days_ago = now - (30 * 24 * 60 * 60) # Well beyond
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_9_days_ago},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_30_days_ago},
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=now,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - both beyond grace period, should be included
assert set(result) == {"msg1", "msg2"}
def test_empty_messages_returns_empty(self):
"""Test that empty messages returns empty list."""
# Arrange
messages: list[SimpleMessage] = []
app_to_tenant = {"app1": "tenant1"}
plan_provider = make_plan_provider({"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}})
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert
assert list(result) == []
def test_plan_provider_called_with_correct_tenant_ids(self):
"""Test that plan_provider is called with correct tenant_ids."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
make_simple_message("msg3", "app3"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant1"} # tenant1 appears twice
plan_provider = make_plan_provider({})
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
policy.filter_message_ids(messages, app_to_tenant)
# Assert - plan_provider should be called once with unique tenant_ids
plan_provider.assert_called_once()
called_tenant_ids = set(plan_provider.call_args[0][0])
assert called_tenant_ids == {"tenant1", "tenant2"}
def test_complex_mixed_scenario(self):
"""Test complex scenario with mixed plans, expirations, whitelist, and missing mappings."""
# Arrange
now = self.CURRENT_TIMESTAMP
sandbox_expired_old = now - (15 * 24 * 60 * 60) # Beyond grace
sandbox_expired_recent = now - (3 * 24 * 60 * 60) # Within grace
future_expiration = now + (30 * 24 * 60 * 60)
messages = [
make_simple_message("msg1", "app1"), # Sandbox, no subscription - included
make_simple_message("msg2", "app2"), # Sandbox, expired old - included
make_simple_message("msg3", "app3"), # Sandbox, within grace - excluded
make_simple_message("msg4", "app4"), # Team plan, active - excluded
make_simple_message("msg5", "app5"), # No tenant mapping - excluded
make_simple_message("msg6", "app6"), # No plan info - excluded
make_simple_message("msg7", "app7"), # Sandbox, expired old, whitelisted - excluded
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
"app4": "tenant4",
"app6": "tenant6", # Has mapping but no plan
"app7": "tenant7",
# app5 has no mapping
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_recent},
"tenant4": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration},
"tenant7": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old},
# tenant6 has no plan
}
plan_provider = make_plan_provider(tenant_plans)
tenant_whitelist = ["tenant7"]
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
tenant_whitelist=tenant_whitelist,
current_timestamp=now,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - only msg1 and msg2 should be included
assert set(result) == {"msg1", "msg2"}
class TestBillingDisabledPolicyFilterMessageIds:
"""Unit tests for BillingDisabledPolicy.filter_message_ids method."""
def test_returns_all_message_ids(self):
"""Test that all message IDs are returned (order-preserving)."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
make_simple_message("msg3", "app3"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"}
policy = BillingDisabledPolicy()
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - all message IDs returned in order
assert list(result) == ["msg1", "msg2", "msg3"]
def test_ignores_app_to_tenant(self):
"""Test that app_to_tenant mapping is ignored."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant: dict[str, str] = {} # Empty - should be ignored
policy = BillingDisabledPolicy()
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - all message IDs still returned
assert list(result) == ["msg1", "msg2"]
def test_empty_messages_returns_empty(self):
"""Test that empty messages returns empty list."""
# Arrange
messages: list[SimpleMessage] = []
app_to_tenant = {"app1": "tenant1"}
policy = BillingDisabledPolicy()
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert
assert list(result) == []
class TestCreateMessageCleanPolicy:
"""Unit tests for create_message_clean_policy factory function."""
@patch("services.retention.conversation.messages_clean_policy.dify_config")
def test_billing_disabled_returns_billing_disabled_policy(self, mock_config):
"""Test that BILLING_ENABLED=False returns BillingDisabledPolicy."""
# Arrange
mock_config.BILLING_ENABLED = False
# Act
policy = create_message_clean_policy(graceful_period_days=21)
# Assert
assert isinstance(policy, BillingDisabledPolicy)
@patch("services.retention.conversation.messages_clean_policy.BillingService")
@patch("services.retention.conversation.messages_clean_policy.dify_config")
def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service):
"""Test that BillingSandboxPolicy is created with correct internal values."""
# Arrange
mock_config.BILLING_ENABLED = True
whitelist = ["tenant1", "tenant2"]
mock_billing_service.get_expired_subscription_cleanup_whitelist.return_value = whitelist
mock_plan_provider = MagicMock()
mock_billing_service.get_plan_bulk_with_cache = mock_plan_provider
# Act
policy = create_message_clean_policy(graceful_period_days=14, current_timestamp=1234567)
# Assert
mock_billing_service.get_expired_subscription_cleanup_whitelist.assert_called_once()
assert isinstance(policy, BillingSandboxPolicy)
assert policy._graceful_period_days == 14
assert list(policy._tenant_whitelist) == whitelist
assert policy._plan_provider == mock_plan_provider
assert policy._current_timestamp == 1234567
class TestMessagesCleanServiceFromTimeRange:
"""Unit tests for MessagesCleanService.from_time_range factory method."""
def test_start_from_end_before_raises_value_error(self):
"""Test that start_from == end_before raises ValueError."""
policy = BillingDisabledPolicy()
# Arrange
same_time = datetime.datetime(2024, 1, 1, 12, 0, 0)
# Act & Assert
with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
MessagesCleanService.from_time_range(
policy=policy,
start_from=same_time,
end_before=same_time,
)
# Arrange
start_from = datetime.datetime(2024, 12, 31)
end_before = datetime.datetime(2024, 1, 1)
# Act & Assert
with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
)
def test_batch_size_raises_value_error(self):
"""Test that batch_size=0 raises ValueError."""
# Arrange
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
policy = BillingDisabledPolicy()
# Act & Assert
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=0,
)
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
policy = BillingDisabledPolicy()
# Act & Assert
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=-100,
)
def test_valid_params_creates_instance(self):
"""Test that valid parameters create a correctly configured instance."""
# Arrange
start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
end_before = datetime.datetime(2024, 12, 31, 23, 59, 59)
policy = BillingDisabledPolicy()
batch_size = 500
dry_run = True
# Act
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
# Assert
assert isinstance(service, MessagesCleanService)
assert service._policy is policy
assert service._start_from == start_from
assert service._end_before == end_before
assert service._batch_size == batch_size
assert service._dry_run == dry_run
def test_default_params(self):
"""Test that default parameters are applied correctly."""
# Arrange
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
policy = BillingDisabledPolicy()
# Act
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
)
# Assert
assert service._batch_size == 1000 # default
assert service._dry_run is False # default
class TestMessagesCleanServiceFromDays:
"""Unit tests for MessagesCleanService.from_days factory method."""
def test_days_raises_value_error(self):
"""Test that days < 0 raises ValueError."""
# Arrange
policy = BillingDisabledPolicy()
# Act & Assert
with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"):
MessagesCleanService.from_days(policy=policy, days=-1)
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
service = MessagesCleanService.from_days(policy=policy, days=0)
# Assert
assert service._end_before == fixed_now
def test_batch_size_raises_value_error(self):
"""Test that batch_size=0 raises ValueError."""
# Arrange
policy = BillingDisabledPolicy()
# Act & Assert
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
MessagesCleanService.from_days(policy=policy, days=30, batch_size=0)
# Act & Assert
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
MessagesCleanService.from_days(policy=policy, days=30, batch_size=-500)
def test_valid_params_creates_instance(self):
"""Test that valid parameters create a correctly configured instance."""
# Arrange
policy = BillingDisabledPolicy()
days = 90
batch_size = 500
dry_run = True
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
service = MessagesCleanService.from_days(
policy=policy,
days=days,
batch_size=batch_size,
dry_run=dry_run,
)
# Assert
expected_end_before = fixed_now - datetime.timedelta(days=days)
assert isinstance(service, MessagesCleanService)
assert service._policy is policy
assert service._start_from is None
assert service._end_before == expected_end_before
assert service._batch_size == batch_size
assert service._dry_run == dry_run
def test_default_params(self):
"""Test that default parameters are applied correctly."""
# Arrange
policy = BillingDisabledPolicy()
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
service = MessagesCleanService.from_days(policy=policy)
# Assert
expected_end_before = fixed_now - datetime.timedelta(days=30) # default days=30
assert service._end_before == expected_end_before
assert service._batch_size == 1000 # default
assert service._dry_run is False # default

View File

@ -9,7 +9,7 @@ This module tests the mail sending functionality including:
"""
import smtplib
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
@ -151,7 +151,7 @@ class TestSMTPIntegration:
client.send(mail_data)
# Assert
mock_smtp_ssl.assert_called_once_with("smtp.example.com", 465, timeout=10)
mock_smtp_ssl.assert_called_once_with("smtp.example.com", 465, timeout=10, local_hostname=ANY)
mock_server.login.assert_called_once_with("user@example.com", "password123")
mock_server.sendmail.assert_called_once()
mock_server.quit.assert_called_once()
@ -181,7 +181,7 @@ class TestSMTPIntegration:
client.send(mail_data)
# Assert
mock_smtp.assert_called_once_with("smtp.example.com", 587, timeout=10)
mock_smtp.assert_called_once_with("smtp.example.com", 587, timeout=10, local_hostname=ANY)
mock_server.ehlo.assert_called()
mock_server.starttls.assert_called_once()
assert mock_server.ehlo.call_count == 2 # Before and after STARTTLS
@ -213,7 +213,7 @@ class TestSMTPIntegration:
client.send(mail_data)
# Assert
mock_smtp.assert_called_once_with("smtp.example.com", 25, timeout=10)
mock_smtp.assert_called_once_with("smtp.example.com", 25, timeout=10, local_hostname=ANY)
mock_server.login.assert_called_once()
mock_server.sendmail.assert_called_once()
mock_server.quit.assert_called_once()

4666
api/uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -968,6 +968,8 @@ SMTP_USERNAME=
SMTP_PASSWORD=
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
# Optional: override the local hostname used for SMTP HELO/EHLO
SMTP_LOCAL_HOSTNAME=
# Sendgid configuration
SENDGRID_API_KEY=

View File

@ -21,7 +21,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.3
image: langgenius/dify-api:1.11.4
restart: always
environment:
# Use the shared environment variables.
@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.3
image: langgenius/dify-api:1.11.4
restart: always
environment:
# Use the shared environment variables.
@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.3
image: langgenius/dify-api:1.11.4
restart: always
environment:
# Use the shared environment variables.
@ -132,7 +132,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.3
image: langgenius/dify-web:1.11.4
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -425,6 +425,7 @@ x-shared-env: &shared-api-worker-env
SMTP_PASSWORD: ${SMTP_PASSWORD:-}
SMTP_USE_TLS: ${SMTP_USE_TLS:-true}
SMTP_OPPORTUNISTIC_TLS: ${SMTP_OPPORTUNISTIC_TLS:-false}
SMTP_LOCAL_HOSTNAME: ${SMTP_LOCAL_HOSTNAME:-}
SENDGRID_API_KEY: ${SENDGRID_API_KEY:-}
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000}
INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72}
@ -704,7 +705,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.3
image: langgenius/dify-api:1.11.4
restart: always
environment:
# Use the shared environment variables.
@ -746,7 +747,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.3
image: langgenius/dify-api:1.11.4
restart: always
environment:
# Use the shared environment variables.
@ -785,7 +786,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.3
image: langgenius/dify-api:1.11.4
restart: always
environment:
# Use the shared environment variables.
@ -815,7 +816,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.3
image: langgenius/dify-web:1.11.4
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -1 +1 @@
22.21.1
24

View File

@ -1,5 +1,5 @@
# base image
FROM node:22.21.1-alpine3.23 AS base
FROM node:24-alpine AS base
LABEL maintainer="takatost@gmail.com"
# if you located in China, you can use aliyun mirror to speed up

View File

@ -8,8 +8,8 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next
Before starting the web frontend service, please make sure the following environment is ready.
- [Node.js](https://nodejs.org) >= v22.11.x
- [pnpm](https://pnpm.io) v10.x
- [Node.js](https://nodejs.org)
- [pnpm](https://pnpm.io)
> [!TIP]
> It is recommended to install and enable Corepack to manage package manager versions automatically:

View File

@ -66,7 +66,9 @@ export default function CheckCode() {
setIsLoading(true)
const ret = await webAppEmailLoginWithCode({ email, code: encryptVerificationCode(code), token })
if (ret.result === 'success') {
setWebAppAccessToken(ret.data.access_token)
if (ret?.data?.access_token) {
setWebAppAccessToken(ret.data.access_token)
}
const { access_token } = await fetchAccessToken({
appCode: appCode!,
userId: embeddedUserId || undefined,

View File

@ -82,7 +82,9 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
body: loginData,
})
if (res.result === 'success') {
setWebAppAccessToken(res.data.access_token)
if (res?.data?.access_token) {
setWebAppAccessToken(res.data.access_token)
}
const { access_token } = await fetchAccessToken({
appCode: appCode!,

View File

@ -183,7 +183,6 @@ const AgentTools: FC = () => {
onSelect={handleSelectTool}
onSelectMultiple={handleSelectMultipleTool}
selectedTools={tools as unknown as ToolValue[]}
canChooseMCPTool
/>
</>
)}

View File

@ -12,7 +12,6 @@ import { useDebounceFn } from 'ahooks'
import dynamic from 'next/dynamic'
import {
useRouter,
useSearchParams,
} from 'next/navigation'
import { parseAsString, useQueryState } from 'nuqs'
import { useCallback, useEffect, useRef, useState } from 'react'
@ -29,7 +28,6 @@ import { CheckModal } from '@/hooks/use-pay'
import { useInfiniteAppList } from '@/service/use-apps'
import { AppModeEnum } from '@/types/app'
import { cn } from '@/utils/classnames'
import { isServer } from '@/utils/client'
import AppCard from './app-card'
import { AppCardSkeleton } from './app-card-skeleton'
import Empty from './empty'
@ -59,7 +57,6 @@ const List = () => {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
const router = useRouter()
const searchParams = useSearchParams()
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext()
const showTagManagementModal = useTagStore(s => s.showTagManagementModal)
const [activeTab, setActiveTab] = useQueryState(
@ -67,33 +64,6 @@ const List = () => {
parseAsString.withDefault('all').withOptions({ history: 'push' }),
)
// valid tabs for apps list; anything else should fallback to 'all'
// 1) Normalize legacy/incorrect query params like ?mode=discover -> ?category=all
useEffect(() => {
// avoid running on server
if (isServer)
return
const mode = searchParams.get('mode')
if (!mode)
return
const url = new URL(window.location.href)
url.searchParams.delete('mode')
if (validTabs.has(mode)) {
// migrate to category key
url.searchParams.set('category', mode)
}
else {
url.searchParams.set('category', 'all')
}
router.replace(url.pathname + url.search)
}, [router, searchParams])
// 2) If category has an invalid value (e.g., 'discover'), reset to 'all'
useEffect(() => {
if (!validTabs.has(activeTab))
setActiveTab('all')
}, [activeTab, setActiveTab])
const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState()
const [isCreatedByMe, setIsCreatedByMe] = useState(queryIsCreatedByMe)
const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)

View File

@ -27,7 +27,9 @@ vi.mock('@/service/billing', () => ({
vi.mock('@/service/client', () => ({
consoleClient: {
billingUrl: vi.fn(),
billing: {
invoices: vi.fn(),
},
},
}))
@ -43,7 +45,7 @@ vi.mock('../../assets', () => ({
const mockUseAppContext = useAppContext as Mock
const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock
const mockBillingUrl = consoleClient.billingUrl as Mock
const mockBillingInvoices = consoleClient.billing.invoices as Mock
const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock
const mockToastNotify = Toast.notify as Mock
@ -75,7 +77,7 @@ beforeEach(() => {
vi.clearAllMocks()
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open()))
mockBillingUrl.mockResolvedValue({ url: 'https://billing.example' })
mockBillingInvoices.mockResolvedValue({ url: 'https://billing.example' })
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://subscription.example' })
assignedHref = ''
})
@ -149,7 +151,7 @@ describe('CloudPlanItem', () => {
type: 'error',
message: 'billing.buyPermissionDeniedTip',
}))
expect(mockBillingUrl).not.toHaveBeenCalled()
expect(mockBillingInvoices).not.toHaveBeenCalled()
})
it('should open billing portal when upgrading current paid plan', async () => {
@ -168,7 +170,7 @@ describe('CloudPlanItem', () => {
fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' }))
await waitFor(() => {
expect(mockBillingUrl).toHaveBeenCalledTimes(1)
expect(mockBillingInvoices).toHaveBeenCalledTimes(1)
})
expect(openWindow).toHaveBeenCalledTimes(1)
})

View File

@ -77,7 +77,7 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
try {
if (isCurrentPaidPlan) {
await openAsyncWindow(async () => {
const res = await consoleClient.billingUrl()
const res = await consoleClient.billing.invoices()
if (res.url)
return res.url
throw new Error('Failed to open billing page')

View File

@ -362,6 +362,18 @@ describe('PreviewDocumentPicker', () => {
expect(screen.getByText('--')).toBeInTheDocument()
})
it('should render when value prop is omitted (optional)', () => {
const files = createMockDocumentList(2)
const onChange = vi.fn()
// Do not pass `value` at all to verify optional behavior
render(<PreviewDocumentPicker files={files} onChange={onChange} />)
// Renders placeholder for missing name
expect(screen.getByText('--')).toBeInTheDocument()
// Portal wrapper renders
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
})
it('should handle empty files array', () => {
renderComponent({ files: [] })

View File

@ -18,7 +18,7 @@ import DocumentList from './document-list'
type Props = {
className?: string
value: DocumentItem
value?: DocumentItem
files: DocumentItem[]
onChange: (value: DocumentItem) => void
}
@ -30,7 +30,8 @@ const PreviewDocumentPicker: FC<Props> = ({
onChange,
}) => {
const { t } = useTranslation()
const { name, extension } = value
const name = value?.name || ''
const extension = value?.extension
const [open, {
set: setOpen,

View File

@ -0,0 +1,201 @@
'use client'
import type { FC } from 'react'
import type { Item } from '@/app/components/base/select'
import type { BuiltInMetadataItem, MetadataItemWithValueLength } from '@/app/components/datasets/metadata/types'
import type { SortType } from '@/service/datasets'
import { PlusIcon } from '@heroicons/react/24/solid'
import { RiDraftLine, RiExternalLinkLine } from '@remixicon/react'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Chip from '@/app/components/base/chip'
import Input from '@/app/components/base/input'
import Sort from '@/app/components/base/sort'
import AutoDisabledDocument from '@/app/components/datasets/common/document-status-with-action/auto-disabled-document'
import IndexFailed from '@/app/components/datasets/common/document-status-with-action/index-failed'
import StatusWithAction from '@/app/components/datasets/common/document-status-with-action/status-with-action'
import DatasetMetadataDrawer from '@/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer'
import { useDocLink } from '@/context/i18n'
import { DataSourceType } from '@/models/datasets'
import { useIndexStatus } from '../status-item/hooks'
type DocumentsHeaderProps = {
// Dataset info
datasetId: string
dataSourceType?: DataSourceType
embeddingAvailable: boolean
isFreePlan: boolean
// Filter & sort
statusFilterValue: string
sortValue: SortType
inputValue: string
onStatusFilterChange: (value: string) => void
onStatusFilterClear: () => void
onSortChange: (value: string) => void
onInputChange: (value: string) => void
// Metadata modal
isShowEditMetadataModal: boolean
showEditMetadataModal: () => void
hideEditMetadataModal: () => void
datasetMetaData?: MetadataItemWithValueLength[]
builtInMetaData?: BuiltInMetadataItem[]
builtInEnabled: boolean
onAddMetaData: (payload: BuiltInMetadataItem) => Promise<void>
onRenameMetaData: (payload: MetadataItemWithValueLength) => Promise<void>
onDeleteMetaData: (metaDataId: string) => Promise<void>
onBuiltInEnabledChange: (enabled: boolean) => void
// Actions
onAddDocument: () => void
}
const DocumentsHeader: FC<DocumentsHeaderProps> = ({
datasetId,
dataSourceType,
embeddingAvailable,
isFreePlan,
statusFilterValue,
sortValue,
inputValue,
onStatusFilterChange,
onStatusFilterClear,
onSortChange,
onInputChange,
isShowEditMetadataModal,
showEditMetadataModal,
hideEditMetadataModal,
datasetMetaData,
builtInMetaData,
builtInEnabled,
onAddMetaData,
onRenameMetaData,
onDeleteMetaData,
onBuiltInEnabledChange,
onAddDocument,
}) => {
const { t } = useTranslation()
const docLink = useDocLink()
const DOC_INDEX_STATUS_MAP = useIndexStatus()
const isDataSourceNotion = dataSourceType === DataSourceType.NOTION
const isDataSourceWeb = dataSourceType === DataSourceType.WEB
const statusFilterItems: Item[] = useMemo(() => [
{ value: 'all', name: t('list.index.all', { ns: 'datasetDocuments' }) as string },
{ value: 'queuing', name: DOC_INDEX_STATUS_MAP.queuing.text },
{ value: 'indexing', name: DOC_INDEX_STATUS_MAP.indexing.text },
{ value: 'paused', name: DOC_INDEX_STATUS_MAP.paused.text },
{ value: 'error', name: DOC_INDEX_STATUS_MAP.error.text },
{ value: 'available', name: DOC_INDEX_STATUS_MAP.available.text },
{ value: 'enabled', name: DOC_INDEX_STATUS_MAP.enabled.text },
{ value: 'disabled', name: DOC_INDEX_STATUS_MAP.disabled.text },
{ value: 'archived', name: DOC_INDEX_STATUS_MAP.archived.text },
], [DOC_INDEX_STATUS_MAP, t])
const sortItems: Item[] = useMemo(() => [
{ value: 'created_at', name: t('list.sort.uploadTime', { ns: 'datasetDocuments' }) as string },
{ value: 'hit_count', name: t('list.sort.hitCount', { ns: 'datasetDocuments' }) as string },
], [t])
// Determine add button text based on data source type
const addButtonText = useMemo(() => {
if (isDataSourceNotion)
return t('list.addPages', { ns: 'datasetDocuments' })
if (isDataSourceWeb)
return t('list.addUrl', { ns: 'datasetDocuments' })
return t('list.addFile', { ns: 'datasetDocuments' })
}, [isDataSourceNotion, isDataSourceWeb, t])
return (
<>
{/* Title section */}
<div className="flex flex-col justify-center gap-1 px-6 pt-4">
<h1 className="text-base font-semibold text-text-primary">
{t('list.title', { ns: 'datasetDocuments' })}
</h1>
<div className="flex items-center space-x-0.5 text-sm font-normal text-text-tertiary">
<span>{t('list.desc', { ns: 'datasetDocuments' })}</span>
<a
className="flex items-center text-text-accent"
target="_blank"
rel="noopener noreferrer"
href={docLink('/guides/knowledge-base/integrate-knowledge-within-application')}
>
<span>{t('list.learnMore', { ns: 'datasetDocuments' })}</span>
<RiExternalLinkLine className="h-3 w-3" />
</a>
</div>
</div>
{/* Toolbar section */}
<div className="flex flex-wrap items-center justify-between px-6 pt-4">
{/* Left: Filters */}
<div className="flex items-center gap-2">
<Chip
className="w-[160px]"
showLeftIcon={false}
value={statusFilterValue}
items={statusFilterItems}
onSelect={item => onStatusFilterChange(item?.value ? String(item.value) : '')}
onClear={onStatusFilterClear}
/>
<Input
showLeftIcon
showClearIcon
wrapperClassName="!w-[200px]"
value={inputValue}
onChange={e => onInputChange(e.target.value)}
onClear={() => onInputChange('')}
/>
<div className="h-3.5 w-px bg-divider-regular"></div>
<Sort
order={sortValue.startsWith('-') ? '-' : ''}
value={sortValue.replace('-', '')}
items={sortItems}
onSelect={value => onSortChange(String(value))}
/>
</div>
{/* Right: Actions */}
<div className="flex !h-8 items-center justify-center gap-2">
{!isFreePlan && <AutoDisabledDocument datasetId={datasetId} />}
<IndexFailed datasetId={datasetId} />
{!embeddingAvailable && (
<StatusWithAction
type="warning"
description={t('embeddingModelNotAvailable', { ns: 'dataset' })}
/>
)}
{embeddingAvailable && (
<Button variant="secondary" className="shrink-0" onClick={showEditMetadataModal}>
<RiDraftLine className="mr-1 size-4" />
{t('metadata.metadata', { ns: 'dataset' })}
</Button>
)}
{isShowEditMetadataModal && (
<DatasetMetadataDrawer
userMetadata={datasetMetaData ?? []}
onClose={hideEditMetadataModal}
onAdd={onAddMetaData}
onRename={onRenameMetaData}
onRemove={onDeleteMetaData}
builtInMetadata={builtInMetaData ?? []}
isBuiltInEnabled={builtInEnabled}
onIsBuiltInEnabledChange={onBuiltInEnabledChange}
/>
)}
{embeddingAvailable && (
<Button variant="primary" onClick={onAddDocument} className="shrink-0">
<PlusIcon className="mr-2 h-4 w-4 stroke-current" />
{addButtonText}
</Button>
)}
</div>
</div>
</>
)
}
export default DocumentsHeader

View File

@ -0,0 +1,41 @@
'use client'
import type { FC } from 'react'
import { PlusIcon } from '@heroicons/react/24/solid'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import s from '../style.module.css'
import { FolderPlusIcon, NotionIcon, ThreeDotsIcon } from './icons'
type EmptyElementProps = {
canAdd: boolean
onClick: () => void
type?: 'upload' | 'sync'
}
const EmptyElement: FC<EmptyElementProps> = ({ canAdd = true, onClick, type = 'upload' }) => {
const { t } = useTranslation()
return (
<div className={s.emptyWrapper}>
<div className={s.emptyElement}>
<div className={s.emptySymbolIconWrapper}>
{type === 'upload' ? <FolderPlusIcon /> : <NotionIcon />}
</div>
<span className={s.emptyTitle}>
{t('list.empty.title', { ns: 'datasetDocuments' })}
<ThreeDotsIcon className="relative -left-1.5 -top-3 inline" />
</span>
<div className={s.emptyTip}>
{t(`list.empty.${type}.tip`, { ns: 'datasetDocuments' })}
</div>
{type === 'upload' && canAdd && (
<Button onClick={onClick} className={s.addFileBtn} variant="secondary-accent">
<PlusIcon className={s.plusIcon} />
{t('list.addFile', { ns: 'datasetDocuments' })}
</Button>
)}
</div>
</div>
)
}
export default EmptyElement

View File

@ -0,0 +1,34 @@
import type * as React from 'react'
export const FolderPlusIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M10.8332 5.83333L9.90355 3.9741C9.63601 3.439 9.50222 3.17144 9.30265 2.97597C9.12615 2.80311 8.91344 2.67164 8.6799 2.59109C8.41581 2.5 8.11668 2.5 7.51841 2.5H4.33317C3.39975 2.5 2.93304 2.5 2.57652 2.68166C2.26292 2.84144 2.00795 3.09641 1.84816 3.41002C1.6665 3.76654 1.6665 4.23325 1.6665 5.16667V5.83333M1.6665 5.83333H14.3332C15.7333 5.83333 16.4334 5.83333 16.9681 6.10582C17.4386 6.3455 17.821 6.72795 18.0607 7.19836C18.3332 7.73314 18.3332 8.4332 18.3332 9.83333V13.5C18.3332 14.9001 18.3332 15.6002 18.0607 16.135C17.821 16.6054 17.4386 16.9878 16.9681 17.2275C16.4334 17.5 15.7333 17.5 14.3332 17.5H5.6665C4.26637 17.5 3.56631 17.5 3.03153 17.2275C2.56112 16.9878 2.17867 16.6054 1.93899 16.135C1.6665 15.6002 1.6665 14.9001 1.6665 13.5V5.83333ZM9.99984 14.1667V9.16667M7.49984 11.6667H12.4998" stroke="#667085" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
}
export const ThreeDotsIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M5 6.5V5M8.93934 7.56066L10 6.5M10.0103 11.5H11.5103" stroke="#374151" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
}
export const NotionIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<g clipPath="url(#clip0_2164_11263)">
<path fillRule="evenodd" clipRule="evenodd" d="M3.5725 18.2611L1.4229 15.5832C0.905706 14.9389 0.625 14.1466 0.625 13.3312V3.63437C0.625 2.4129 1.60224 1.39936 2.86295 1.31328L12.8326 0.632614C13.5569 0.583164 14.2768 0.775682 14.8717 1.17794L18.3745 3.5462C19.0015 3.97012 19.375 4.66312 19.375 5.40266V16.427C19.375 17.6223 18.4141 18.6121 17.1798 18.688L6.11458 19.3692C5.12958 19.4298 4.17749 19.0148 3.5725 18.2611Z" fill="white" />
<path d="M7.03006 8.48669V8.35974C7.03006 8.03794 7.28779 7.77104 7.61997 7.74886L10.0396 7.58733L13.3857 12.5147V8.19009L12.5244 8.07528V8.01498C12.5244 7.68939 12.788 7.42074 13.1244 7.4035L15.326 7.29073V7.60755C15.326 7.75628 15.2154 7.88349 15.0638 7.90913L14.534 7.99874V15.0023L13.8691 15.231C13.3136 15.422 12.6952 15.2175 12.3772 14.7377L9.12879 9.83574V14.5144L10.1287 14.7057L10.1147 14.7985C10.0711 15.089 9.82028 15.3087 9.51687 15.3222L7.03006 15.4329C6.99718 15.1205 7.23132 14.841 7.55431 14.807L7.88143 14.7727V8.53453L7.03006 8.48669Z" fill="black" />
<path fillRule="evenodd" clipRule="evenodd" d="M12.9218 1.85424L2.95217 2.53491C2.35499 2.57568 1.89209 3.05578 1.89209 3.63437V13.3312C1.89209 13.8748 2.07923 14.403 2.42402 14.8325L4.57362 17.5104C4.92117 17.9434 5.46812 18.1818 6.03397 18.147L17.0991 17.4658C17.6663 17.4309 18.1078 16.9762 18.1078 16.427V5.40266C18.1078 5.06287 17.9362 4.74447 17.6481 4.54969L14.1453 2.18143C13.7883 1.94008 13.3564 1.82457 12.9218 1.85424ZM3.44654 3.78562C3.30788 3.68296 3.37387 3.46909 3.54806 3.4566L12.9889 2.77944C13.2897 2.75787 13.5886 2.8407 13.8318 3.01305L15.7261 4.35508C15.798 4.40603 15.7642 4.51602 15.6752 4.52086L5.67742 5.0646C5.37485 5.08106 5.0762 4.99217 4.83563 4.81406L3.44654 3.78562ZM5.20848 6.76919C5.20848 6.4444 5.47088 6.1761 5.80642 6.15783L16.3769 5.58216C16.7039 5.56435 16.9792 5.81583 16.9792 6.13239V15.6783C16.9792 16.0025 16.7177 16.2705 16.3829 16.2896L5.8793 16.8872C5.51537 16.9079 5.20848 16.6283 5.20848 16.2759V6.76919Z" fill="black" />
</g>
<defs>
<clipPath id="clip0_2164_11263">
<rect width="20" height="20" fill="white" />
</clipPath>
</defs>
</svg>
)
}

View File

@ -16,13 +16,16 @@ import * as React from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Checkbox from '@/app/components/base/checkbox'
import FileTypeIcon from '@/app/components/base/file-uploader/file-type-icon'
import NotionIcon from '@/app/components/base/notion-icon'
import Pagination from '@/app/components/base/pagination'
import Toast from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import ChunkingModeLabel from '@/app/components/datasets/common/chunking-mode-label'
import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter'
import { extensionToFileType } from '@/app/components/datasets/hit-testing/utils/extension-to-file-type'
import EditMetadataBatchModal from '@/app/components/datasets/metadata/edit-metadata-batch/modal'
import useBatchEditDocumentMetadata from '@/app/components/datasets/metadata/hooks/use-batch-edit-document-metadata'
import { useDatasetDetailContextWithSelector as useDatasetDetailContext } from '@/context/dataset-detail'
import useTimestamp from '@/hooks/use-timestamp'
import { ChunkingMode, DataSourceType, DocumentActionType } from '@/models/datasets'
@ -31,15 +34,12 @@ import { useDocumentArchive, useDocumentBatchRetryIndex, useDocumentDelete, useD
import { asyncRunSafe } from '@/utils'
import { cn } from '@/utils/classnames'
import { formatNumber } from '@/utils/format'
import FileTypeIcon from '../../base/file-uploader/file-type-icon'
import ChunkingModeLabel from '../common/chunking-mode-label'
import useBatchEditDocumentMetadata from '../metadata/hooks/use-batch-edit-document-metadata'
import BatchAction from './detail/completed/common/batch-action'
import SummaryStatus from './detail/completed/common/summary-status'
import BatchAction from '../detail/completed/common/batch-action'
import SummaryStatus from '../detail/completed/common/summary-status'
import StatusItem from '../status-item'
import s from '../style.module.css'
import Operations from './operations'
import RenameModal from './rename-modal'
import StatusItem from './status-item'
import s from './style.module.css'
export const renderTdValue = (value: string | number | null, isEmptyStyle = false) => {
return (

View File

@ -1,4 +1,4 @@
import type { OperationName } from './types'
import type { OperationName } from '../types'
import type { CommonResponse } from '@/models/common'
import {
RiArchive2Line,
@ -17,7 +17,13 @@ import * as React from 'react'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import Confirm from '@/app/components/base/confirm'
import Divider from '@/app/components/base/divider'
import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
import CustomPopover from '@/app/components/base/popover'
import Switch from '@/app/components/base/switch'
import { ToastContext } from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import { DataSourceType, DocumentActionType } from '@/models/datasets'
import {
useDocumentArchive,
@ -33,14 +39,8 @@ import {
} from '@/service/knowledge/use-document'
import { asyncRunSafe } from '@/utils'
import { cn } from '@/utils/classnames'
import Confirm from '../../base/confirm'
import Divider from '../../base/divider'
import CustomPopover from '../../base/popover'
import Switch from '../../base/switch'
import { ToastContext } from '../../base/toast'
import Tooltip from '../../base/tooltip'
import s from '../style.module.css'
import RenameModal from './rename-modal'
import s from './style.module.css'
type OperationsProps = {
embeddingAvailable: boolean

View File

@ -7,8 +7,8 @@ import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import Modal from '@/app/components/base/modal'
import Toast from '@/app/components/base/toast'
import { renameDocumentName } from '@/service/datasets'
import Toast from '../../base/toast'
type Props = {
datasetId: string

View File

@ -0,0 +1,5 @@
export { useAddDocumentsSteps } from './use-add-documents-steps'
export { useDatasourceActions } from './use-datasource-actions'
export { useDatasourceOptions } from './use-datasource-options'
export { useLocalFile, useOnlineDocument, useOnlineDrive, useWebsiteCrawl } from './use-datasource-store'
export { useDatasourceUIState } from './use-datasource-ui-state'

View File

@ -0,0 +1,41 @@
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { AddDocumentsStep } from '../types'
/**
* Hook for managing add documents wizard steps
*/
export const useAddDocumentsSteps = () => {
const { t } = useTranslation()
const [currentStep, setCurrentStep] = useState(1)
const handleNextStep = useCallback(() => {
setCurrentStep(preStep => preStep + 1)
}, [])
const handleBackStep = useCallback(() => {
setCurrentStep(preStep => preStep - 1)
}, [])
const steps = [
{
label: t('addDocuments.steps.chooseDatasource', { ns: 'datasetPipeline' }),
value: AddDocumentsStep.dataSource,
},
{
label: t('addDocuments.steps.processDocuments', { ns: 'datasetPipeline' }),
value: AddDocumentsStep.processDocuments,
},
{
label: t('addDocuments.steps.processingDocuments', { ns: 'datasetPipeline' }),
value: AddDocumentsStep.processingDocuments,
},
]
return {
steps,
currentStep,
handleNextStep,
handleBackStep,
}
}

View File

@ -0,0 +1,321 @@
import type { StoreApi } from 'zustand'
import type { DataSourceShape } from '@/app/components/datasets/documents/create-from-pipeline/data-source/store'
import type { Datasource } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { DataSourceNotionPageMap, NotionPage } from '@/models/common'
import type { CrawlResultItem, DocumentItem, CustomFile as File, FileIndexingEstimateResponse } from '@/models/datasets'
import type {
OnlineDriveFile,
PublishedPipelineRunPreviewResponse,
PublishedPipelineRunResponse,
} from '@/models/pipeline'
import { useCallback, useRef } from 'react'
import { trackEvent } from '@/app/components/base/amplitude'
import { DatasourceType } from '@/models/pipeline'
import { useRunPublishedPipeline } from '@/service/use-pipeline'
import {
buildLocalFileDatasourceInfo,
buildOnlineDocumentDatasourceInfo,
buildOnlineDriveDatasourceInfo,
buildWebsiteCrawlDatasourceInfo,
} from '../utils/datasource-info-builder'
type DatasourceActionsParams = {
datasource: Datasource | undefined
datasourceType: string | undefined
pipelineId: string | undefined
dataSourceStore: StoreApi<DataSourceShape>
setEstimateData: (data: FileIndexingEstimateResponse | undefined) => void
setBatchId: (id: string) => void
setDocuments: (docs: PublishedPipelineRunResponse['documents']) => void
handleNextStep: () => void
PagesMapAndSelectedPagesId: DataSourceNotionPageMap
currentWorkspacePages: { page_id: string }[] | undefined
clearOnlineDocumentData: () => void
clearWebsiteCrawlData: () => void
clearOnlineDriveData: () => void
setDatasource: (ds: Datasource) => void
}
/**
* Hook for datasource-related actions (preview, process, etc.)
*/
export const useDatasourceActions = ({
datasource,
datasourceType,
pipelineId,
dataSourceStore,
setEstimateData,
setBatchId,
setDocuments,
handleNextStep,
PagesMapAndSelectedPagesId,
currentWorkspacePages,
clearOnlineDocumentData,
clearWebsiteCrawlData,
clearOnlineDriveData,
setDatasource,
}: DatasourceActionsParams) => {
const isPreview = useRef(false)
const formRef = useRef<{ submit: () => void } | null>(null)
const { mutateAsync: runPublishedPipeline, isIdle, isPending } = useRunPublishedPipeline()
// Build datasource info for preview (single item)
const buildPreviewDatasourceInfo = useCallback(() => {
const {
previewLocalFileRef,
previewOnlineDocumentRef,
previewWebsitePageRef,
previewOnlineDriveFileRef,
currentCredentialId,
bucket,
} = dataSourceStore.getState()
const datasourceInfoList: Record<string, unknown>[] = []
if (datasourceType === DatasourceType.localFile && previewLocalFileRef.current) {
datasourceInfoList.push(buildLocalFileDatasourceInfo(
previewLocalFileRef.current as File,
currentCredentialId,
))
}
if (datasourceType === DatasourceType.onlineDocument && previewOnlineDocumentRef.current) {
datasourceInfoList.push(buildOnlineDocumentDatasourceInfo(
previewOnlineDocumentRef.current,
currentCredentialId,
))
}
if (datasourceType === DatasourceType.websiteCrawl && previewWebsitePageRef.current) {
datasourceInfoList.push(buildWebsiteCrawlDatasourceInfo(
previewWebsitePageRef.current,
currentCredentialId,
))
}
if (datasourceType === DatasourceType.onlineDrive && previewOnlineDriveFileRef.current) {
datasourceInfoList.push(buildOnlineDriveDatasourceInfo(
previewOnlineDriveFileRef.current,
bucket,
currentCredentialId,
))
}
return datasourceInfoList
}, [dataSourceStore, datasourceType])
// Build datasource info for processing (all items)
const buildProcessDatasourceInfo = useCallback(() => {
const {
currentCredentialId,
localFileList,
onlineDocuments,
websitePages,
bucket,
selectedFileIds,
onlineDriveFileList,
} = dataSourceStore.getState()
const datasourceInfoList: Record<string, unknown>[] = []
if (datasourceType === DatasourceType.localFile) {
localFileList.forEach((file) => {
datasourceInfoList.push(buildLocalFileDatasourceInfo(file.file, currentCredentialId))
})
}
if (datasourceType === DatasourceType.onlineDocument) {
onlineDocuments.forEach((page) => {
datasourceInfoList.push(buildOnlineDocumentDatasourceInfo(page, currentCredentialId))
})
}
if (datasourceType === DatasourceType.websiteCrawl) {
websitePages.forEach((page) => {
datasourceInfoList.push(buildWebsiteCrawlDatasourceInfo(page, currentCredentialId))
})
}
if (datasourceType === DatasourceType.onlineDrive) {
selectedFileIds.forEach((id) => {
const file = onlineDriveFileList.find(f => f.id === id)
if (file)
datasourceInfoList.push(buildOnlineDriveDatasourceInfo(file, bucket, currentCredentialId))
})
}
return datasourceInfoList
}, [dataSourceStore, datasourceType])
// Handle chunk preview
const handlePreviewChunks = useCallback(async (data: Record<string, unknown>) => {
if (!datasource || !pipelineId)
return
const datasourceInfoList = buildPreviewDatasourceInfo()
await runPublishedPipeline({
pipeline_id: pipelineId,
inputs: data,
start_node_id: datasource.nodeId,
datasource_type: datasourceType as DatasourceType,
datasource_info_list: datasourceInfoList,
is_preview: true,
}, {
onSuccess: (res) => {
setEstimateData((res as PublishedPipelineRunPreviewResponse).data.outputs)
},
})
}, [datasource, pipelineId, datasourceType, buildPreviewDatasourceInfo, runPublishedPipeline, setEstimateData])
// Handle document processing
const handleProcess = useCallback(async (data: Record<string, unknown>) => {
if (!datasource || !pipelineId)
return
const datasourceInfoList = buildProcessDatasourceInfo()
await runPublishedPipeline({
pipeline_id: pipelineId,
inputs: data,
start_node_id: datasource.nodeId,
datasource_type: datasourceType as DatasourceType,
datasource_info_list: datasourceInfoList,
is_preview: false,
}, {
onSuccess: (res) => {
setBatchId((res as PublishedPipelineRunResponse).batch || '')
setDocuments((res as PublishedPipelineRunResponse).documents || [])
handleNextStep()
trackEvent('dataset_document_added', {
data_source_type: datasourceType,
indexing_technique: 'pipeline',
})
},
})
}, [datasource, pipelineId, datasourceType, buildProcessDatasourceInfo, runPublishedPipeline, setBatchId, setDocuments, handleNextStep])
// Form submission handlers
const onClickProcess = useCallback(() => {
isPreview.current = false
formRef.current?.submit()
}, [])
const onClickPreview = useCallback(() => {
isPreview.current = true
formRef.current?.submit()
}, [])
const handleSubmit = useCallback((data: Record<string, unknown>) => {
if (isPreview.current)
handlePreviewChunks(data)
else
handleProcess(data)
}, [handlePreviewChunks, handleProcess])
// Preview change handlers
const handlePreviewFileChange = useCallback((file: DocumentItem) => {
const { previewLocalFileRef } = dataSourceStore.getState()
previewLocalFileRef.current = file
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handlePreviewOnlineDocumentChange = useCallback((page: NotionPage) => {
const { previewOnlineDocumentRef } = dataSourceStore.getState()
previewOnlineDocumentRef.current = page
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handlePreviewWebsiteChange = useCallback((website: CrawlResultItem) => {
const { previewWebsitePageRef } = dataSourceStore.getState()
previewWebsitePageRef.current = website
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handlePreviewOnlineDriveFileChange = useCallback((file: OnlineDriveFile) => {
const { previewOnlineDriveFileRef } = dataSourceStore.getState()
previewOnlineDriveFileRef.current = file
onClickPreview()
}, [dataSourceStore, onClickPreview])
// Select all handler
const handleSelectAll = useCallback(() => {
const {
onlineDocuments,
onlineDriveFileList,
selectedFileIds,
setOnlineDocuments,
setSelectedFileIds,
setSelectedPagesId,
} = dataSourceStore.getState()
if (datasourceType === DatasourceType.onlineDocument) {
const allIds = currentWorkspacePages?.map(page => page.page_id) || []
if (onlineDocuments.length < allIds.length) {
const selectedPages = Array.from(allIds).map(pageId => PagesMapAndSelectedPagesId[pageId])
setOnlineDocuments(selectedPages)
setSelectedPagesId(new Set(allIds))
}
else {
setOnlineDocuments([])
setSelectedPagesId(new Set())
}
}
if (datasourceType === DatasourceType.onlineDrive) {
const allKeys = onlineDriveFileList.filter(item => item.type !== 'bucket').map(file => file.id)
if (selectedFileIds.length < allKeys.length)
setSelectedFileIds(allKeys)
else
setSelectedFileIds([])
}
}, [PagesMapAndSelectedPagesId, currentWorkspacePages, dataSourceStore, datasourceType])
// Clear datasource data based on type
const clearDataSourceData = useCallback((dataSource: Datasource) => {
const providerType = dataSource.nodeData.provider_type
const clearFunctions: Record<string, () => void> = {
[DatasourceType.onlineDocument]: clearOnlineDocumentData,
[DatasourceType.websiteCrawl]: clearWebsiteCrawlData,
[DatasourceType.onlineDrive]: clearOnlineDriveData,
[DatasourceType.localFile]: () => {},
}
clearFunctions[providerType]?.()
}, [clearOnlineDocumentData, clearOnlineDriveData, clearWebsiteCrawlData])
// Switch datasource handler
const handleSwitchDataSource = useCallback((dataSource: Datasource) => {
const {
setCurrentCredentialId,
currentNodeIdRef,
} = dataSourceStore.getState()
clearDataSourceData(dataSource)
setCurrentCredentialId('')
currentNodeIdRef.current = dataSource.nodeId
setDatasource(dataSource)
}, [clearDataSourceData, dataSourceStore, setDatasource])
// Credential change handler
const handleCredentialChange = useCallback((credentialId: string) => {
const { setCurrentCredentialId } = dataSourceStore.getState()
if (datasource)
clearDataSourceData(datasource)
setCurrentCredentialId(credentialId)
}, [clearDataSourceData, dataSourceStore, datasource])
return {
isPreview,
formRef,
isIdle,
isPending,
onClickProcess,
onClickPreview,
handleSubmit,
handlePreviewFileChange,
handlePreviewOnlineDocumentChange,
handlePreviewWebsiteChange,
handlePreviewOnlineDriveFileChange,
handleSelectAll,
handleSwitchDataSource,
handleCredentialChange,
}
}

View File

@ -0,0 +1,27 @@
import type { DataSourceOption } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
import type { Node } from '@/app/components/workflow/types'
import { useMemo } from 'react'
import { BlockEnum } from '@/app/components/workflow/types'
/**
* Hook for getting datasource options from pipeline nodes
*/
export const useDatasourceOptions = (pipelineNodes: Node<DataSourceNodeType>[]) => {
const datasourceNodes = pipelineNodes.filter(node => node.data.type === BlockEnum.DataSource)
const options = useMemo(() => {
const options: DataSourceOption[] = []
datasourceNodes.forEach((node) => {
const label = node.data.title
options.push({
label,
value: node.id,
data: node.data,
})
})
return options
}, [datasourceNodes])
return options
}

View File

@ -1,69 +1,12 @@
import type { DataSourceOption } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
import type { Node } from '@/app/components/workflow/types'
import type { DataSourceNotionPageMap, DataSourceNotionWorkspace } from '@/models/common'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useCallback, useMemo } from 'react'
import { useShallow } from 'zustand/react/shallow'
import { BlockEnum } from '@/app/components/workflow/types'
import { CrawlStep } from '@/models/datasets'
import { useDataSourceStore, useDataSourceStoreWithSelector } from './data-source/store'
import { AddDocumentsStep } from './types'
export const useAddDocumentsSteps = () => {
const { t } = useTranslation()
const [currentStep, setCurrentStep] = useState(1)
const handleNextStep = useCallback(() => {
setCurrentStep(preStep => preStep + 1)
}, [])
const handleBackStep = useCallback(() => {
setCurrentStep(preStep => preStep - 1)
}, [])
const steps = [
{
label: t('addDocuments.steps.chooseDatasource', { ns: 'datasetPipeline' }),
value: AddDocumentsStep.dataSource,
},
{
label: t('addDocuments.steps.processDocuments', { ns: 'datasetPipeline' }),
value: AddDocumentsStep.processDocuments,
},
{
label: t('addDocuments.steps.processingDocuments', { ns: 'datasetPipeline' }),
value: AddDocumentsStep.processingDocuments,
},
]
return {
steps,
currentStep,
handleNextStep,
handleBackStep,
}
}
export const useDatasourceOptions = (pipelineNodes: Node<DataSourceNodeType>[]) => {
const datasourceNodes = pipelineNodes.filter(node => node.data.type === BlockEnum.DataSource)
const options = useMemo(() => {
const options: DataSourceOption[] = []
datasourceNodes.forEach((node) => {
const label = node.data.title
options.push({
label,
value: node.id,
data: node.data,
})
})
return options
}, [datasourceNodes])
return options
}
import { useDataSourceStore, useDataSourceStoreWithSelector } from '../data-source/store'
/**
* Hook for local file datasource store operations
*/
export const useLocalFile = () => {
const {
localFileList,
@ -89,6 +32,9 @@ export const useLocalFile = () => {
}
}
/**
* Hook for online document datasource store operations
*/
export const useOnlineDocument = () => {
const {
documentsData,
@ -147,6 +93,9 @@ export const useOnlineDocument = () => {
}
}
/**
* Hook for website crawl datasource store operations
*/
export const useWebsiteCrawl = () => {
const {
websitePages,
@ -186,6 +135,9 @@ export const useWebsiteCrawl = () => {
}
}
/**
* Hook for online drive datasource store operations
*/
export const useOnlineDrive = () => {
const {
onlineDriveFileList,

View File

@ -0,0 +1,132 @@
import type { Datasource } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { OnlineDriveFile } from '@/models/pipeline'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { DatasourceType } from '@/models/pipeline'
type DatasourceUIStateParams = {
datasource: Datasource | undefined
allFileLoaded: boolean
localFileListLength: number
onlineDocumentsLength: number
websitePagesLength: number
selectedFileIdsLength: number
onlineDriveFileList: OnlineDriveFile[]
isVectorSpaceFull: boolean
enableBilling: boolean
currentWorkspacePagesLength: number
fileUploadConfig: { file_size_limit: number, batch_count_limit: number }
}
/**
* Hook for computing datasource UI state based on datasource type
*/
export const useDatasourceUIState = ({
datasource,
allFileLoaded,
localFileListLength,
onlineDocumentsLength,
websitePagesLength,
selectedFileIdsLength,
onlineDriveFileList,
isVectorSpaceFull,
enableBilling,
currentWorkspacePagesLength,
fileUploadConfig,
}: DatasourceUIStateParams) => {
const { t } = useTranslation()
const datasourceType = datasource?.nodeData.provider_type
const isShowVectorSpaceFull = useMemo(() => {
if (!datasource || !datasourceType)
return false
// Lookup table for vector space full condition check
const vectorSpaceFullConditions: Record<string, boolean> = {
[DatasourceType.localFile]: allFileLoaded,
[DatasourceType.onlineDocument]: onlineDocumentsLength > 0,
[DatasourceType.websiteCrawl]: websitePagesLength > 0,
[DatasourceType.onlineDrive]: onlineDriveFileList.length > 0,
}
const condition = vectorSpaceFullConditions[datasourceType]
return condition && isVectorSpaceFull && enableBilling
}, [datasource, datasourceType, allFileLoaded, onlineDocumentsLength, websitePagesLength, onlineDriveFileList.length, isVectorSpaceFull, enableBilling])
// Lookup table for next button disabled conditions
const nextBtnDisabled = useMemo(() => {
if (!datasource || !datasourceType)
return true
const disabledConditions: Record<string, boolean> = {
[DatasourceType.localFile]: isShowVectorSpaceFull || localFileListLength === 0 || !allFileLoaded,
[DatasourceType.onlineDocument]: isShowVectorSpaceFull || onlineDocumentsLength === 0,
[DatasourceType.websiteCrawl]: isShowVectorSpaceFull || websitePagesLength === 0,
[DatasourceType.onlineDrive]: isShowVectorSpaceFull || selectedFileIdsLength === 0,
}
return disabledConditions[datasourceType] ?? true
}, [datasource, datasourceType, isShowVectorSpaceFull, localFileListLength, allFileLoaded, onlineDocumentsLength, websitePagesLength, selectedFileIdsLength])
// Check if select all should be shown
const showSelect = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return currentWorkspacePagesLength > 0
if (datasourceType === DatasourceType.onlineDrive) {
const nonBucketItems = onlineDriveFileList.filter(item => item.type !== 'bucket')
const isBucketList = onlineDriveFileList.some(file => file.type === 'bucket')
return !isBucketList && nonBucketItems.length > 0
}
return false
}, [currentWorkspacePagesLength, datasourceType, onlineDriveFileList])
// Total selectable options count
const totalOptions = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return currentWorkspacePagesLength
if (datasourceType === DatasourceType.onlineDrive)
return onlineDriveFileList.filter(item => item.type !== 'bucket').length
return undefined
}, [currentWorkspacePagesLength, datasourceType, onlineDriveFileList])
// Selected options count
const selectedOptions = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return onlineDocumentsLength
if (datasourceType === DatasourceType.onlineDrive)
return selectedFileIdsLength
return undefined
}, [datasourceType, onlineDocumentsLength, selectedFileIdsLength])
// Tip message for selection
const tip = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return t('addDocuments.selectOnlineDocumentTip', { ns: 'datasetPipeline', count: 50 })
if (datasourceType === DatasourceType.onlineDrive) {
return t('addDocuments.selectOnlineDriveTip', {
ns: 'datasetPipeline',
count: fileUploadConfig.batch_count_limit,
fileSize: fileUploadConfig.file_size_limit,
})
}
return ''
}, [datasourceType, fileUploadConfig.batch_count_limit, fileUploadConfig.file_size_limit, t])
return {
datasourceType,
isShowVectorSpaceFull,
nextBtnDisabled,
showSelect,
totalOptions,
selectedOptions,
tip,
}
}

File diff suppressed because it is too large Load Diff

View File

@ -2,75 +2,71 @@
import type { Datasource } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
import type { Node } from '@/app/components/workflow/types'
import type { NotionPage } from '@/models/common'
import type { CrawlResultItem, DocumentItem, CustomFile as File, FileIndexingEstimateResponse } from '@/models/datasets'
import type {
InitialDocumentDetail,
OnlineDriveFile,
PublishedPipelineRunPreviewResponse,
PublishedPipelineRunResponse,
} from '@/models/pipeline'
import type { FileIndexingEstimateResponse } from '@/models/datasets'
import type { InitialDocumentDetail } from '@/models/pipeline'
import { useBoolean } from 'ahooks'
import { useCallback, useMemo, useRef, useState } from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { trackEvent } from '@/app/components/base/amplitude'
import Divider from '@/app/components/base/divider'
import Loading from '@/app/components/base/loading'
import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal'
import VectorSpaceFull from '@/app/components/billing/vector-space-full'
import LocalFile from '@/app/components/datasets/documents/create-from-pipeline/data-source/local-file'
import OnlineDocuments from '@/app/components/datasets/documents/create-from-pipeline/data-source/online-documents'
import OnlineDrive from '@/app/components/datasets/documents/create-from-pipeline/data-source/online-drive'
import WebsiteCrawl from '@/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useProviderContextSelector } from '@/context/provider-context'
import { DatasourceType } from '@/models/pipeline'
import { useFileUploadConfig } from '@/service/use-common'
import { usePublishedPipelineInfo, useRunPublishedPipeline } from '@/service/use-pipeline'
import { TransferMethod } from '@/types/app'
import UpgradeCard from '../../create/step-one/upgrade-card'
import Actions from './actions'
import DataSourceOptions from './data-source-options'
import { usePublishedPipelineInfo } from '@/service/use-pipeline'
import { useDataSourceStore } from './data-source/store'
import DataSourceProvider from './data-source/store/provider'
import { useAddDocumentsSteps, useLocalFile, useOnlineDocument, useOnlineDrive, useWebsiteCrawl } from './hooks'
import {
useAddDocumentsSteps,
useDatasourceActions,
useDatasourceUIState,
useLocalFile,
useOnlineDocument,
useOnlineDrive,
useWebsiteCrawl,
} from './hooks'
import LeftHeader from './left-header'
import ChunkPreview from './preview/chunk-preview'
import FilePreview from './preview/file-preview'
import OnlineDocumentPreview from './preview/online-document-preview'
import WebsitePreview from './preview/web-preview'
import ProcessDocuments from './process-documents'
import Processing from './processing'
import { StepOneContent, StepThreeContent, StepTwoContent } from './steps'
import { StepOnePreview, StepTwoPreview } from './steps/preview-panel'
const CreateFormPipeline = () => {
const { t } = useTranslation()
const plan = useProviderContextSelector(state => state.plan)
const enableBilling = useProviderContextSelector(state => state.enableBilling)
const pipelineId = useDatasetDetailContextWithSelector(s => s.dataset?.pipeline_id)
const dataSourceStore = useDataSourceStore()
// Core state
const [datasource, setDatasource] = useState<Datasource>()
const [estimateData, setEstimateData] = useState<FileIndexingEstimateResponse | undefined>(undefined)
const [batchId, setBatchId] = useState('')
const [documents, setDocuments] = useState<InitialDocumentDetail[]>([])
const dataSourceStore = useDataSourceStore()
const isPreview = useRef(false)
const formRef = useRef<any>(null)
// Data fetching
const { data: pipelineInfo, isFetching: isFetchingPipelineInfo } = usePublishedPipelineInfo(pipelineId || '')
const { data: fileUploadConfigResponse } = useFileUploadConfig()
const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? {
file_size_limit: 15,
batch_count_limit: 5,
}, [fileUploadConfigResponse])
// Steps management
const {
steps,
currentStep,
handleNextStep: doHandleNextStep,
handleBackStep,
} = useAddDocumentsSteps()
// Datasource-specific hooks
const {
localFileList,
allFileLoaded,
currentLocalFile,
hidePreviewLocalFile,
} = useLocalFile()
const {
currentWorkspace,
onlineDocuments,
@ -79,12 +75,14 @@ const CreateFormPipeline = () => {
hidePreviewOnlineDocument,
clearOnlineDocumentData,
} = useOnlineDocument()
const {
websitePages,
currentWebsite,
hideWebsitePreview,
clearWebsiteCrawlData,
} = useWebsiteCrawl()
const {
onlineDriveFileList,
selectedFileIds,
@ -92,43 +90,50 @@ const CreateFormPipeline = () => {
clearOnlineDriveData,
} = useOnlineDrive()
const datasourceType = useMemo(() => datasource?.nodeData.provider_type, [datasource])
// Computed values
const isVectorSpaceFull = plan.usage.vectorSpace >= plan.total.vectorSpace
const isShowVectorSpaceFull = useMemo(() => {
if (!datasource)
return false
if (datasourceType === DatasourceType.localFile)
return allFileLoaded && isVectorSpaceFull && enableBilling
if (datasourceType === DatasourceType.onlineDocument)
return onlineDocuments.length > 0 && isVectorSpaceFull && enableBilling
if (datasourceType === DatasourceType.websiteCrawl)
return websitePages.length > 0 && isVectorSpaceFull && enableBilling
if (datasourceType === DatasourceType.onlineDrive)
return onlineDriveFileList.length > 0 && isVectorSpaceFull && enableBilling
return false
}, [allFileLoaded, datasource, datasourceType, enableBilling, isVectorSpaceFull, onlineDocuments.length, onlineDriveFileList.length, websitePages.length])
const supportBatchUpload = !enableBilling || plan.type !== 'sandbox'
// UI state
const {
datasourceType,
isShowVectorSpaceFull,
nextBtnDisabled,
showSelect,
totalOptions,
selectedOptions,
tip,
} = useDatasourceUIState({
datasource,
allFileLoaded,
localFileListLength: localFileList.length,
onlineDocumentsLength: onlineDocuments.length,
websitePagesLength: websitePages.length,
selectedFileIdsLength: selectedFileIds.length,
onlineDriveFileList,
isVectorSpaceFull,
enableBilling,
currentWorkspacePagesLength: currentWorkspace?.pages.length ?? 0,
fileUploadConfig,
})
// Plan upgrade modal
const [isShowPlanUpgradeModal, {
setTrue: showPlanUpgradeModal,
setFalse: hidePlanUpgradeModal,
}] = useBoolean(false)
// Next step with batch upload check
const handleNextStep = useCallback(() => {
if (!supportBatchUpload) {
let isMultiple = false
if (datasourceType === DatasourceType.localFile && localFileList.length > 1)
isMultiple = true
if (datasourceType === DatasourceType.onlineDocument && onlineDocuments.length > 1)
isMultiple = true
if (datasourceType === DatasourceType.websiteCrawl && websitePages.length > 1)
isMultiple = true
if (datasourceType === DatasourceType.onlineDrive && selectedFileIds.length > 1)
isMultiple = true
if (isMultiple) {
const multipleCheckMap: Record<string, number> = {
[DatasourceType.localFile]: localFileList.length,
[DatasourceType.onlineDocument]: onlineDocuments.length,
[DatasourceType.websiteCrawl]: websitePages.length,
[DatasourceType.onlineDrive]: selectedFileIds.length,
}
const count = datasourceType ? multipleCheckMap[datasourceType] : 0
if (count > 1) {
showPlanUpgradeModal()
return
}
@ -136,334 +141,44 @@ const CreateFormPipeline = () => {
doHandleNextStep()
}, [datasourceType, doHandleNextStep, localFileList.length, onlineDocuments.length, selectedFileIds.length, showPlanUpgradeModal, supportBatchUpload, websitePages.length])
const nextBtnDisabled = useMemo(() => {
if (!datasource)
return true
if (datasourceType === DatasourceType.localFile)
return isShowVectorSpaceFull || !localFileList.length || !allFileLoaded
if (datasourceType === DatasourceType.onlineDocument)
return isShowVectorSpaceFull || !onlineDocuments.length
if (datasourceType === DatasourceType.websiteCrawl)
return isShowVectorSpaceFull || !websitePages.length
if (datasourceType === DatasourceType.onlineDrive)
return isShowVectorSpaceFull || !selectedFileIds.length
return false
}, [datasource, datasourceType, isShowVectorSpaceFull, localFileList.length, allFileLoaded, onlineDocuments.length, websitePages.length, selectedFileIds.length])
// Datasource actions
const {
isPreview,
formRef,
isIdle,
isPending,
onClickProcess,
onClickPreview,
handleSubmit,
handlePreviewFileChange,
handlePreviewOnlineDocumentChange,
handlePreviewWebsiteChange,
handlePreviewOnlineDriveFileChange,
handleSelectAll,
handleSwitchDataSource,
handleCredentialChange,
} = useDatasourceActions({
datasource,
datasourceType,
pipelineId,
dataSourceStore,
setEstimateData,
setBatchId,
setDocuments,
handleNextStep,
PagesMapAndSelectedPagesId,
currentWorkspacePages: currentWorkspace?.pages,
clearOnlineDocumentData,
clearWebsiteCrawlData,
clearOnlineDriveData,
setDatasource,
})
const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? {
file_size_limit: 15,
batch_count_limit: 5,
}, [fileUploadConfigResponse])
const showSelect = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument) {
const pagesCount = currentWorkspace?.pages.length ?? 0
return pagesCount > 0
}
if (datasourceType === DatasourceType.onlineDrive) {
const isBucketList = onlineDriveFileList.some(file => file.type === 'bucket')
return !isBucketList && onlineDriveFileList.filter((item) => {
return item.type !== 'bucket'
}).length > 0
}
return false
}, [currentWorkspace?.pages.length, datasourceType, onlineDriveFileList])
const totalOptions = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return currentWorkspace?.pages.length
if (datasourceType === DatasourceType.onlineDrive) {
return onlineDriveFileList.filter((item) => {
return item.type !== 'bucket'
}).length
}
}, [currentWorkspace?.pages.length, datasourceType, onlineDriveFileList])
const selectedOptions = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return onlineDocuments.length
if (datasourceType === DatasourceType.onlineDrive)
return selectedFileIds.length
}, [datasourceType, onlineDocuments.length, selectedFileIds.length])
const tip = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return t('addDocuments.selectOnlineDocumentTip', { ns: 'datasetPipeline', count: 50 })
if (datasourceType === DatasourceType.onlineDrive) {
return t('addDocuments.selectOnlineDriveTip', {
ns: 'datasetPipeline',
count: fileUploadConfig.batch_count_limit,
fileSize: fileUploadConfig.file_size_limit,
})
}
return ''
}, [datasourceType, fileUploadConfig.batch_count_limit, fileUploadConfig.file_size_limit, t])
const { mutateAsync: runPublishedPipeline, isIdle, isPending } = useRunPublishedPipeline()
const handlePreviewChunks = useCallback(async (data: Record<string, any>) => {
if (!datasource)
return
const {
previewLocalFileRef,
previewOnlineDocumentRef,
previewWebsitePageRef,
previewOnlineDriveFileRef,
currentCredentialId,
} = dataSourceStore.getState()
const datasourceInfoList: Record<string, any>[] = []
if (datasourceType === DatasourceType.localFile) {
const { id, name, type, size, extension, mime_type } = previewLocalFileRef.current as File
const documentInfo = {
related_id: id,
name,
type,
size,
extension,
mime_type,
url: '',
transfer_method: TransferMethod.local_file,
credential_id: currentCredentialId,
}
datasourceInfoList.push(documentInfo)
}
if (datasourceType === DatasourceType.onlineDocument) {
const { workspace_id, ...rest } = previewOnlineDocumentRef.current!
const documentInfo = {
workspace_id,
page: rest,
credential_id: currentCredentialId,
}
datasourceInfoList.push(documentInfo)
}
if (datasourceType === DatasourceType.websiteCrawl) {
datasourceInfoList.push({
...previewWebsitePageRef.current!,
credential_id: currentCredentialId,
})
}
if (datasourceType === DatasourceType.onlineDrive) {
const { bucket } = dataSourceStore.getState()
const { id, type, name } = previewOnlineDriveFileRef.current!
datasourceInfoList.push({
bucket,
id,
name,
type,
credential_id: currentCredentialId,
})
}
await runPublishedPipeline({
pipeline_id: pipelineId!,
inputs: data,
start_node_id: datasource.nodeId,
datasource_type: datasourceType as DatasourceType,
datasource_info_list: datasourceInfoList,
is_preview: true,
}, {
onSuccess: (res) => {
setEstimateData((res as PublishedPipelineRunPreviewResponse).data.outputs)
},
})
}, [datasource, datasourceType, runPublishedPipeline, pipelineId, dataSourceStore])
const handleProcess = useCallback(async (data: Record<string, any>) => {
if (!datasource)
return
const { currentCredentialId } = dataSourceStore.getState()
const datasourceInfoList: Record<string, any>[] = []
if (datasourceType === DatasourceType.localFile) {
const {
localFileList,
} = dataSourceStore.getState()
localFileList.forEach((file) => {
const { id, name, type, size, extension, mime_type } = file.file
const documentInfo = {
related_id: id,
name,
type,
size,
extension,
mime_type,
url: '',
transfer_method: TransferMethod.local_file,
credential_id: currentCredentialId,
}
datasourceInfoList.push(documentInfo)
})
}
if (datasourceType === DatasourceType.onlineDocument) {
const {
onlineDocuments,
} = dataSourceStore.getState()
onlineDocuments.forEach((page) => {
const { workspace_id, ...rest } = page
const documentInfo = {
workspace_id,
page: rest,
credential_id: currentCredentialId,
}
datasourceInfoList.push(documentInfo)
})
}
if (datasourceType === DatasourceType.websiteCrawl) {
const {
websitePages,
} = dataSourceStore.getState()
websitePages.forEach((websitePage) => {
datasourceInfoList.push({
...websitePage,
credential_id: currentCredentialId,
})
})
}
if (datasourceType === DatasourceType.onlineDrive) {
const {
bucket,
selectedFileIds,
onlineDriveFileList,
} = dataSourceStore.getState()
selectedFileIds.forEach((id) => {
const file = onlineDriveFileList.find(file => file.id === id)
datasourceInfoList.push({
bucket,
id: file?.id,
name: file?.name,
type: file?.type,
credential_id: currentCredentialId,
})
})
}
await runPublishedPipeline({
pipeline_id: pipelineId!,
inputs: data,
start_node_id: datasource.nodeId,
datasource_type: datasourceType as DatasourceType,
datasource_info_list: datasourceInfoList,
is_preview: false,
}, {
onSuccess: (res) => {
setBatchId((res as PublishedPipelineRunResponse).batch || '')
setDocuments((res as PublishedPipelineRunResponse).documents || [])
handleNextStep()
trackEvent('dataset_document_added', {
data_source_type: datasourceType,
indexing_technique: 'pipeline',
})
},
})
}, [dataSourceStore, datasource, datasourceType, handleNextStep, pipelineId, runPublishedPipeline])
const onClickProcess = useCallback(() => {
isPreview.current = false
formRef.current?.submit()
}, [])
const onClickPreview = useCallback(() => {
isPreview.current = true
formRef.current?.submit()
}, [])
const handleSubmit = useCallback((data: Record<string, any>) => {
if (isPreview.current)
handlePreviewChunks(data)
else
handleProcess(data)
}, [handlePreviewChunks, handleProcess])
const handlePreviewFileChange = useCallback((file: DocumentItem) => {
const { previewLocalFileRef } = dataSourceStore.getState()
previewLocalFileRef.current = file
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handlePreviewOnlineDocumentChange = useCallback((page: NotionPage) => {
const { previewOnlineDocumentRef } = dataSourceStore.getState()
previewOnlineDocumentRef.current = page
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handlePreviewWebsiteChange = useCallback((website: CrawlResultItem) => {
const { previewWebsitePageRef } = dataSourceStore.getState()
previewWebsitePageRef.current = website
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handlePreviewOnlineDriveFileChange = useCallback((file: OnlineDriveFile) => {
const { previewOnlineDriveFileRef } = dataSourceStore.getState()
previewOnlineDriveFileRef.current = file
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handleSelectAll = useCallback(() => {
const {
onlineDocuments,
onlineDriveFileList,
selectedFileIds,
setOnlineDocuments,
setSelectedFileIds,
setSelectedPagesId,
} = dataSourceStore.getState()
if (datasourceType === DatasourceType.onlineDocument) {
const allIds = currentWorkspace?.pages.map(page => page.page_id) || []
if (onlineDocuments.length < allIds.length) {
const selectedPages = Array.from(allIds).map(pageId => PagesMapAndSelectedPagesId[pageId])
setOnlineDocuments(selectedPages)
setSelectedPagesId(new Set(allIds))
}
else {
setOnlineDocuments([])
setSelectedPagesId(new Set())
}
}
if (datasourceType === DatasourceType.onlineDrive) {
const allKeys = onlineDriveFileList.filter((item) => {
return item.type !== 'bucket'
}).map(file => file.id)
if (selectedFileIds.length < allKeys.length)
setSelectedFileIds(allKeys)
else
setSelectedFileIds([])
}
}, [PagesMapAndSelectedPagesId, currentWorkspace?.pages, dataSourceStore, datasourceType])
const clearDataSourceData = useCallback((dataSource: Datasource) => {
const providerType = dataSource.nodeData.provider_type
if (providerType === DatasourceType.onlineDocument)
clearOnlineDocumentData()
else if (providerType === DatasourceType.websiteCrawl)
clearWebsiteCrawlData()
else if (providerType === DatasourceType.onlineDrive)
clearOnlineDriveData()
}, [clearOnlineDocumentData, clearOnlineDriveData, clearWebsiteCrawlData])
const handleSwitchDataSource = useCallback((dataSource: Datasource) => {
const {
setCurrentCredentialId,
currentNodeIdRef,
} = dataSourceStore.getState()
clearDataSourceData(dataSource)
setCurrentCredentialId('')
currentNodeIdRef.current = dataSource.nodeId
setDatasource(dataSource)
}, [clearDataSourceData, dataSourceStore])
const handleCredentialChange = useCallback((credentialId: string) => {
const { setCurrentCredentialId } = dataSourceStore.getState()
clearDataSourceData(datasource!)
setCurrentCredentialId(credentialId)
}, [clearDataSourceData, dataSourceStore, datasource])
if (isFetchingPipelineInfo) {
return (
<Loading type="app" />
)
}
if (isFetchingPipelineInfo)
return <Loading type="app" />
return (
<div
className="relative flex h-[calc(100vh-56px)] w-full min-w-[1024px] overflow-x-auto rounded-t-2xl border-t border-effects-highlight bg-background-default-subtle"
>
<div className="relative flex h-[calc(100vh-56px)] w-full min-w-[1024px] overflow-x-auto rounded-t-2xl border-t border-effects-highlight bg-background-default-subtle">
<div className="h-full min-w-0 flex-1">
<div className="flex h-full flex-col px-14">
<LeftHeader
@ -472,139 +187,77 @@ const CreateFormPipeline = () => {
currentStep={currentStep}
/>
<div className="grow overflow-y-auto">
{
currentStep === 1 && (
<div className="flex flex-col gap-y-5 pt-4">
<DataSourceOptions
datasourceNodeId={datasource?.nodeId || ''}
onSelect={handleSwitchDataSource}
pipelineNodes={(pipelineInfo?.graph.nodes || []) as Node<DataSourceNodeType>[]}
/>
{datasourceType === DatasourceType.localFile && (
<LocalFile
allowedExtensions={datasource!.nodeData.fileExtensions || []}
supportBatchUpload={supportBatchUpload}
/>
)}
{datasourceType === DatasourceType.onlineDocument && (
<OnlineDocuments
nodeId={datasource!.nodeId}
nodeData={datasource!.nodeData}
onCredentialChange={handleCredentialChange}
/>
)}
{datasourceType === DatasourceType.websiteCrawl && (
<WebsiteCrawl
nodeId={datasource!.nodeId}
nodeData={datasource!.nodeData}
onCredentialChange={handleCredentialChange}
/>
)}
{datasourceType === DatasourceType.onlineDrive && (
<OnlineDrive
nodeId={datasource!.nodeId}
nodeData={datasource!.nodeData}
onCredentialChange={handleCredentialChange}
/>
)}
{isShowVectorSpaceFull && (
<VectorSpaceFull />
)}
<Actions
showSelect={showSelect}
totalOptions={totalOptions}
selectedOptions={selectedOptions}
onSelectAll={handleSelectAll}
disabled={nextBtnDisabled}
handleNextStep={handleNextStep}
tip={tip}
/>
{
!supportBatchUpload && datasourceType === DatasourceType.localFile && localFileList.length > 0 && (
<>
<Divider type="horizontal" className="my-4 h-px bg-divider-subtle" />
<UpgradeCard />
</>
)
}
</div>
)
}
{
currentStep === 2 && (
<ProcessDocuments
ref={formRef}
dataSourceNodeId={datasource!.nodeId}
isRunning={isPending}
onProcess={onClickProcess}
onPreview={onClickPreview}
onSubmit={handleSubmit}
onBack={handleBackStep}
/>
)
}
{
currentStep === 3 && (
<Processing
batchId={batchId}
documents={documents}
/>
)
}
{currentStep === 1 && (
<StepOneContent
datasource={datasource}
datasourceType={datasourceType}
pipelineNodes={(pipelineInfo?.graph.nodes || []) as Node<DataSourceNodeType>[]}
supportBatchUpload={supportBatchUpload}
localFileListLength={localFileList.length}
isShowVectorSpaceFull={isShowVectorSpaceFull}
showSelect={showSelect}
totalOptions={totalOptions}
selectedOptions={selectedOptions}
tip={tip}
nextBtnDisabled={nextBtnDisabled}
onSelectDataSource={handleSwitchDataSource}
onCredentialChange={handleCredentialChange}
onSelectAll={handleSelectAll}
onNextStep={handleNextStep}
/>
)}
{currentStep === 2 && (
<StepTwoContent
formRef={formRef}
dataSourceNodeId={datasource!.nodeId}
isRunning={isPending}
onProcess={onClickProcess}
onPreview={onClickPreview}
onSubmit={handleSubmit}
onBack={handleBackStep}
/>
)}
{currentStep === 3 && (
<StepThreeContent
batchId={batchId}
documents={documents}
/>
)}
</div>
</div>
</div>
{/* Preview */}
{
currentStep === 1 && (
<div className="h-full min-w-0 flex-1">
<div className="flex h-full flex-col pl-2 pt-2">
{currentLocalFile && (
<FilePreview
file={currentLocalFile}
hidePreview={hidePreviewLocalFile}
/>
)}
{currentDocument && (
<OnlineDocumentPreview
datasourceNodeId={datasource!.nodeId}
currentPage={currentDocument}
hidePreview={hidePreviewOnlineDocument}
/>
)}
{currentWebsite && (
<WebsitePreview
currentWebsite={currentWebsite}
hidePreview={hideWebsitePreview}
/>
)}
</div>
</div>
)
}
{
currentStep === 2 && (
<div className="h-full min-w-0 flex-1">
<div className="flex h-full flex-col pl-2 pt-2">
<ChunkPreview
dataSourceType={datasourceType as DatasourceType}
localFiles={localFileList.map(file => file.file)}
onlineDocuments={onlineDocuments}
websitePages={websitePages}
onlineDriveFiles={selectedOnlineDriveFileList}
isIdle={isIdle}
isPending={isPending && isPreview.current}
estimateData={estimateData}
onPreview={onClickPreview}
handlePreviewFileChange={handlePreviewFileChange}
handlePreviewOnlineDocumentChange={handlePreviewOnlineDocumentChange}
handlePreviewWebsitePageChange={handlePreviewWebsiteChange}
handlePreviewOnlineDriveFileChange={handlePreviewOnlineDriveFileChange}
/>
</div>
</div>
)
}
{/* Preview Panel */}
{currentStep === 1 && (
<StepOnePreview
datasource={datasource}
currentLocalFile={currentLocalFile}
currentDocument={currentDocument}
currentWebsite={currentWebsite}
hidePreviewLocalFile={hidePreviewLocalFile}
hidePreviewOnlineDocument={hidePreviewOnlineDocument}
hideWebsitePreview={hideWebsitePreview}
/>
)}
{currentStep === 2 && (
<StepTwoPreview
datasourceType={datasourceType}
localFileList={localFileList}
onlineDocuments={onlineDocuments}
websitePages={websitePages}
selectedOnlineDriveFileList={selectedOnlineDriveFileList}
isIdle={isIdle}
isPendingPreview={isPending && isPreview.current}
estimateData={estimateData}
onPreview={onClickPreview}
handlePreviewFileChange={handlePreviewFileChange}
handlePreviewOnlineDocumentChange={handlePreviewOnlineDocumentChange}
handlePreviewWebsitePageChange={handlePreviewWebsiteChange}
handlePreviewOnlineDriveFileChange={handlePreviewOnlineDriveFileChange}
/>
)}
{/* Plan Upgrade Modal */}
{isShowPlanUpgradeModal && (
<PlanUpgradeModal
show

View File

@ -0,0 +1,3 @@
export { default as StepOneContent } from './step-one-content'
export { default as StepThreeContent } from './step-three-content'
export { default as StepTwoContent } from './step-two-content'

View File

@ -0,0 +1,112 @@
'use client'
import type { Datasource } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { NotionPage } from '@/models/common'
import type { CrawlResultItem, CustomFile, DocumentItem, FileIndexingEstimateResponse, FileItem } from '@/models/datasets'
import type { DatasourceType, OnlineDriveFile } from '@/models/pipeline'
import { memo } from 'react'
import ChunkPreview from '../preview/chunk-preview'
import FilePreview from '../preview/file-preview'
import OnlineDocumentPreview from '../preview/online-document-preview'
import WebsitePreview from '../preview/web-preview'
type StepOnePreviewProps = {
datasource: Datasource | undefined
currentLocalFile: CustomFile | undefined
currentDocument: (NotionPage & { workspace_id: string }) | undefined
currentWebsite: CrawlResultItem | undefined
hidePreviewLocalFile: () => void
hidePreviewOnlineDocument: () => void
hideWebsitePreview: () => void
}
export const StepOnePreview = memo(({
datasource,
currentLocalFile,
currentDocument,
currentWebsite,
hidePreviewLocalFile,
hidePreviewOnlineDocument,
hideWebsitePreview,
}: StepOnePreviewProps) => {
return (
<div className="h-full min-w-0 flex-1">
<div className="flex h-full flex-col pl-2 pt-2">
{currentLocalFile && (
<FilePreview
file={currentLocalFile}
hidePreview={hidePreviewLocalFile}
/>
)}
{currentDocument && (
<OnlineDocumentPreview
datasourceNodeId={datasource!.nodeId}
currentPage={currentDocument}
hidePreview={hidePreviewOnlineDocument}
/>
)}
{currentWebsite && (
<WebsitePreview
currentWebsite={currentWebsite}
hidePreview={hideWebsitePreview}
/>
)}
</div>
</div>
)
})
StepOnePreview.displayName = 'StepOnePreview'
type StepTwoPreviewProps = {
datasourceType: string | undefined
localFileList: FileItem[]
onlineDocuments: (NotionPage & { workspace_id: string })[]
websitePages: CrawlResultItem[]
selectedOnlineDriveFileList: OnlineDriveFile[]
isIdle: boolean
isPendingPreview: boolean
estimateData: FileIndexingEstimateResponse | undefined
onPreview: () => void
handlePreviewFileChange: (file: DocumentItem) => void
handlePreviewOnlineDocumentChange: (page: NotionPage) => void
handlePreviewWebsitePageChange: (website: CrawlResultItem) => void
handlePreviewOnlineDriveFileChange: (file: OnlineDriveFile) => void
}
export const StepTwoPreview = memo(({
datasourceType,
localFileList,
onlineDocuments,
websitePages,
selectedOnlineDriveFileList,
isIdle,
isPendingPreview,
estimateData,
onPreview,
handlePreviewFileChange,
handlePreviewOnlineDocumentChange,
handlePreviewWebsitePageChange,
handlePreviewOnlineDriveFileChange,
}: StepTwoPreviewProps) => {
return (
<div className="h-full min-w-0 flex-1">
<div className="flex h-full flex-col pl-2 pt-2">
<ChunkPreview
dataSourceType={datasourceType as DatasourceType}
localFiles={localFileList.map(file => file.file)}
onlineDocuments={onlineDocuments}
websitePages={websitePages}
onlineDriveFiles={selectedOnlineDriveFileList}
isIdle={isIdle}
isPending={isPendingPreview}
estimateData={estimateData}
onPreview={onPreview}
handlePreviewFileChange={handlePreviewFileChange}
handlePreviewOnlineDocumentChange={handlePreviewOnlineDocumentChange}
handlePreviewWebsitePageChange={handlePreviewWebsitePageChange}
handlePreviewOnlineDriveFileChange={handlePreviewOnlineDriveFileChange}
/>
</div>
</div>
)
})
StepTwoPreview.displayName = 'StepTwoPreview'

View File

@ -0,0 +1,110 @@
'use client'
import type { Datasource } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
import type { Node } from '@/app/components/workflow/types'
import { memo } from 'react'
import Divider from '@/app/components/base/divider'
import VectorSpaceFull from '@/app/components/billing/vector-space-full'
import LocalFile from '@/app/components/datasets/documents/create-from-pipeline/data-source/local-file'
import OnlineDocuments from '@/app/components/datasets/documents/create-from-pipeline/data-source/online-documents'
import OnlineDrive from '@/app/components/datasets/documents/create-from-pipeline/data-source/online-drive'
import WebsiteCrawl from '@/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl'
import { DatasourceType } from '@/models/pipeline'
import UpgradeCard from '../../../create/step-one/upgrade-card'
import Actions from '../actions'
import DataSourceOptions from '../data-source-options'
type StepOneContentProps = {
datasource: Datasource | undefined
datasourceType: string | undefined
pipelineNodes: Node<DataSourceNodeType>[]
supportBatchUpload: boolean
localFileListLength: number
isShowVectorSpaceFull: boolean
showSelect: boolean
totalOptions: number | undefined
selectedOptions: number | undefined
tip: string
nextBtnDisabled: boolean
onSelectDataSource: (dataSource: Datasource) => void
onCredentialChange: (credentialId: string) => void
onSelectAll: () => void
onNextStep: () => void
}
const StepOneContent = ({
datasource,
datasourceType,
pipelineNodes,
supportBatchUpload,
localFileListLength,
isShowVectorSpaceFull,
showSelect,
totalOptions,
selectedOptions,
tip,
nextBtnDisabled,
onSelectDataSource,
onCredentialChange,
onSelectAll,
onNextStep,
}: StepOneContentProps) => {
const showUpgradeCard = !supportBatchUpload
&& datasourceType === DatasourceType.localFile
&& localFileListLength > 0
return (
<div className="flex flex-col gap-y-5 pt-4">
<DataSourceOptions
datasourceNodeId={datasource?.nodeId || ''}
onSelect={onSelectDataSource}
pipelineNodes={pipelineNodes}
/>
{datasourceType === DatasourceType.localFile && (
<LocalFile
allowedExtensions={datasource!.nodeData.fileExtensions || []}
supportBatchUpload={supportBatchUpload}
/>
)}
{datasourceType === DatasourceType.onlineDocument && (
<OnlineDocuments
nodeId={datasource!.nodeId}
nodeData={datasource!.nodeData}
onCredentialChange={onCredentialChange}
/>
)}
{datasourceType === DatasourceType.websiteCrawl && (
<WebsiteCrawl
nodeId={datasource!.nodeId}
nodeData={datasource!.nodeData}
onCredentialChange={onCredentialChange}
/>
)}
{datasourceType === DatasourceType.onlineDrive && (
<OnlineDrive
nodeId={datasource!.nodeId}
nodeData={datasource!.nodeData}
onCredentialChange={onCredentialChange}
/>
)}
{isShowVectorSpaceFull && <VectorSpaceFull />}
<Actions
showSelect={showSelect}
totalOptions={totalOptions}
selectedOptions={selectedOptions}
onSelectAll={onSelectAll}
disabled={nextBtnDisabled}
handleNextStep={onNextStep}
tip={tip}
/>
{showUpgradeCard && (
<>
<Divider type="horizontal" className="my-4 h-px bg-divider-subtle" />
<UpgradeCard />
</>
)}
</div>
)
}
export default memo(StepOneContent)

View File

@ -0,0 +1,23 @@
'use client'
import type { InitialDocumentDetail } from '@/models/pipeline'
import { memo } from 'react'
import Processing from '../processing'
type StepThreeContentProps = {
batchId: string
documents: InitialDocumentDetail[]
}
const StepThreeContent = ({
batchId,
documents,
}: StepThreeContentProps) => {
return (
<Processing
batchId={batchId}
documents={documents}
/>
)
}
export default memo(StepThreeContent)

View File

@ -0,0 +1,38 @@
'use client'
import type { RefObject } from 'react'
import { memo } from 'react'
import ProcessDocuments from '../process-documents'
type StepTwoContentProps = {
formRef: RefObject<{ submit: () => void } | null>
dataSourceNodeId: string
isRunning: boolean
onProcess: () => void
onPreview: () => void
onSubmit: (data: Record<string, unknown>) => void
onBack: () => void
}
const StepTwoContent = ({
formRef,
dataSourceNodeId,
isRunning,
onProcess,
onPreview,
onSubmit,
onBack,
}: StepTwoContentProps) => {
return (
<ProcessDocuments
ref={formRef}
dataSourceNodeId={dataSourceNodeId}
isRunning={isRunning}
onProcess={onProcess}
onPreview={onPreview}
onSubmit={onSubmit}
onBack={onBack}
/>
)
}
export default memo(StepTwoContent)

View File

@ -0,0 +1,63 @@
import type { NotionPage } from '@/models/common'
import type { CrawlResultItem, CustomFile as File } from '@/models/datasets'
import type { OnlineDriveFile } from '@/models/pipeline'
import { TransferMethod } from '@/types/app'
/**
* Build datasource info for local files
*/
export const buildLocalFileDatasourceInfo = (
file: File,
credentialId: string,
): Record<string, unknown> => ({
related_id: file.id,
name: file.name,
type: file.type,
size: file.size,
extension: file.extension,
mime_type: file.mime_type,
url: '',
transfer_method: TransferMethod.local_file,
credential_id: credentialId,
})
/**
* Build datasource info for online documents
*/
export const buildOnlineDocumentDatasourceInfo = (
page: NotionPage & { workspace_id: string },
credentialId: string,
): Record<string, unknown> => {
const { workspace_id, ...rest } = page
return {
workspace_id,
page: rest,
credential_id: credentialId,
}
}
/**
* Build datasource info for website crawl
*/
export const buildWebsiteCrawlDatasourceInfo = (
page: CrawlResultItem,
credentialId: string,
): Record<string, unknown> => ({
...page,
credential_id: credentialId,
})
/**
* Build datasource info for online drive
*/
export const buildOnlineDriveDatasourceInfo = (
file: OnlineDriveFile,
bucket: string,
credentialId: string,
): Record<string, unknown> => ({
bucket,
id: file.id,
name: file.name,
type: file.type,
credential_id: credentialId,
})

View File

@ -18,7 +18,7 @@ import { useDocumentDetail, useDocumentMetadata, useInvalidDocumentList } from '
import { useCheckSegmentBatchImportProgress, useChildSegmentListKey, useSegmentBatchImport, useSegmentListKey } from '@/service/knowledge/use-segment'
import { useInvalid } from '@/service/use-base'
import { cn } from '@/utils/classnames'
import Operations from '../operations'
import Operations from '../components/operations'
import StatusItem from '../status-item'
import BatchModal from './batch-modal'
import Completed from './completed'

View File

@ -0,0 +1,197 @@
import type { DocumentListResponse } from '@/models/datasets'
import type { SortType } from '@/service/datasets'
import { useDebounce, useDebounceFn } from 'ahooks'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { normalizeStatusForQuery, sanitizeStatusValue } from '../status-filter'
import useDocumentListQueryState from './use-document-list-query-state'
/**
* Custom hook to manage documents page state including:
* - Search state (input value, debounced search value)
* - Filter state (status filter, sort value)
* - Pagination state (current page, limit)
* - Selection state (selected document ids)
* - Polling state (timer control for auto-refresh)
*/
export function useDocumentsPageState() {
const { query, updateQuery } = useDocumentListQueryState()
// Search state
const [inputValue, setInputValue] = useState<string>('')
const [searchValue, setSearchValue] = useState<string>('')
const debouncedSearchValue = useDebounce(searchValue, { wait: 500 })
// Filter & sort state
const [statusFilterValue, setStatusFilterValue] = useState<string>(() => sanitizeStatusValue(query.status))
const [sortValue, setSortValue] = useState<SortType>(query.sort)
const normalizedStatusFilterValue = useMemo(
() => normalizeStatusForQuery(statusFilterValue),
[statusFilterValue],
)
// Pagination state
const [currPage, setCurrPage] = useState<number>(query.page - 1)
const [limit, setLimit] = useState<number>(query.limit)
// Selection state
const [selectedIds, setSelectedIds] = useState<string[]>([])
// Polling state
const [timerCanRun, setTimerCanRun] = useState(true)
// Initialize search value from URL on mount
useEffect(() => {
if (query.keyword) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
}, []) // Only run on mount
// Sync local state with URL query changes
useEffect(() => {
setCurrPage(query.page - 1)
setLimit(query.limit)
if (query.keyword !== searchValue) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
setStatusFilterValue((prev) => {
const nextValue = sanitizeStatusValue(query.status)
return prev === nextValue ? prev : nextValue
})
setSortValue(query.sort)
}, [query])
// Update URL when search changes
useEffect(() => {
if (debouncedSearchValue !== query.keyword) {
setCurrPage(0)
updateQuery({ keyword: debouncedSearchValue, page: 1 })
}
}, [debouncedSearchValue, query.keyword, updateQuery])
// Clear selection when search changes
useEffect(() => {
if (searchValue !== query.keyword)
setSelectedIds([])
}, [searchValue, query.keyword])
// Clear selection when status filter changes
useEffect(() => {
setSelectedIds([])
}, [normalizedStatusFilterValue])
// Page change handler
const handlePageChange = useCallback((newPage: number) => {
setCurrPage(newPage)
updateQuery({ page: newPage + 1 })
}, [updateQuery])
// Limit change handler
const handleLimitChange = useCallback((newLimit: number) => {
setLimit(newLimit)
setCurrPage(0)
updateQuery({ limit: newLimit, page: 1 })
}, [updateQuery])
// Debounced search handler
const { run: handleSearch } = useDebounceFn(() => {
setSearchValue(inputValue)
}, { wait: 500 })
// Input change handler
const handleInputChange = useCallback((value: string) => {
setInputValue(value)
handleSearch()
}, [handleSearch])
// Status filter change handler
const handleStatusFilterChange = useCallback((value: string) => {
const selectedValue = sanitizeStatusValue(value)
setStatusFilterValue(selectedValue)
setCurrPage(0)
updateQuery({ status: selectedValue, page: 1 })
}, [updateQuery])
// Status filter clear handler
const handleStatusFilterClear = useCallback(() => {
if (statusFilterValue === 'all')
return
setStatusFilterValue('all')
setCurrPage(0)
updateQuery({ status: 'all', page: 1 })
}, [statusFilterValue, updateQuery])
// Sort change handler
const handleSortChange = useCallback((value: string) => {
const next = value as SortType
if (next === sortValue)
return
setSortValue(next)
setCurrPage(0)
updateQuery({ sort: next, page: 1 })
}, [sortValue, updateQuery])
// Update polling state based on documents response
const updatePollingState = useCallback((documentsRes: DocumentListResponse | undefined) => {
if (!documentsRes?.data)
return
let completedNum = 0
documentsRes.data.forEach((documentItem) => {
const { indexing_status } = documentItem
const isEmbedded = indexing_status === 'completed' || indexing_status === 'paused' || indexing_status === 'error'
if (isEmbedded)
completedNum++
})
const hasIncompleteDocuments = completedNum !== documentsRes.data.length
const transientStatuses = ['queuing', 'indexing', 'paused']
const shouldForcePolling = normalizedStatusFilterValue === 'all'
? false
: transientStatuses.includes(normalizedStatusFilterValue)
setTimerCanRun(shouldForcePolling || hasIncompleteDocuments)
}, [normalizedStatusFilterValue])
// Adjust page when total pages change
const adjustPageForTotal = useCallback((documentsRes: DocumentListResponse | undefined) => {
if (!documentsRes)
return
const totalPages = Math.ceil(documentsRes.total / limit)
if (currPage > 0 && currPage + 1 > totalPages)
handlePageChange(totalPages > 0 ? totalPages - 1 : 0)
}, [limit, currPage, handlePageChange])
return {
// Search state
inputValue,
searchValue,
debouncedSearchValue,
handleInputChange,
// Filter & sort state
statusFilterValue,
sortValue,
normalizedStatusFilterValue,
handleStatusFilterChange,
handleStatusFilterClear,
handleSortChange,
// Pagination state
currPage,
limit,
handlePageChange,
handleLimitChange,
// Selection state
selectedIds,
setSelectedIds,
// Polling state
timerCanRun,
updatePollingState,
adjustPageForTotal,
}
}
export default useDocumentsPageState

View File

@ -1,185 +1,55 @@
'use client'
import type { FC } from 'react'
import type { Item } from '@/app/components/base/select'
import type { SortType } from '@/service/datasets'
import { PlusIcon } from '@heroicons/react/24/solid'
import { RiDraftLine, RiExternalLinkLine } from '@remixicon/react'
import { useDebounce, useDebounceFn } from 'ahooks'
import { useRouter } from 'next/navigation'
import * as React from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import { useCallback, useEffect } from 'react'
import Loading from '@/app/components/base/loading'
import IndexFailed from '@/app/components/datasets/common/document-status-with-action/index-failed'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useDocLink } from '@/context/i18n'
import { useProviderContext } from '@/context/provider-context'
import { DataSourceType } from '@/models/datasets'
import { useDocumentList, useInvalidDocumentDetail, useInvalidDocumentList } from '@/service/knowledge/use-document'
import { useChildSegmentListKey, useSegmentListKey } from '@/service/knowledge/use-segment'
import { useInvalid } from '@/service/use-base'
import { cn } from '@/utils/classnames'
import Chip from '../../base/chip'
import Sort from '../../base/sort'
import AutoDisabledDocument from '../common/document-status-with-action/auto-disabled-document'
import StatusWithAction from '../common/document-status-with-action/status-with-action'
import useEditDocumentMetadata from '../metadata/hooks/use-edit-dataset-metadata'
import DatasetMetadataDrawer from '../metadata/metadata-dataset/dataset-metadata-drawer'
import useDocumentListQueryState from './hooks/use-document-list-query-state'
import List from './list'
import { normalizeStatusForQuery, sanitizeStatusValue } from './status-filter'
import { useIndexStatus } from './status-item/hooks'
import s from './style.module.css'
const FolderPlusIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M10.8332 5.83333L9.90355 3.9741C9.63601 3.439 9.50222 3.17144 9.30265 2.97597C9.12615 2.80311 8.91344 2.67164 8.6799 2.59109C8.41581 2.5 8.11668 2.5 7.51841 2.5H4.33317C3.39975 2.5 2.93304 2.5 2.57652 2.68166C2.26292 2.84144 2.00795 3.09641 1.84816 3.41002C1.6665 3.76654 1.6665 4.23325 1.6665 5.16667V5.83333M1.6665 5.83333H14.3332C15.7333 5.83333 16.4334 5.83333 16.9681 6.10582C17.4386 6.3455 17.821 6.72795 18.0607 7.19836C18.3332 7.73314 18.3332 8.4332 18.3332 9.83333V13.5C18.3332 14.9001 18.3332 15.6002 18.0607 16.135C17.821 16.6054 17.4386 16.9878 16.9681 17.2275C16.4334 17.5 15.7333 17.5 14.3332 17.5H5.6665C4.26637 17.5 3.56631 17.5 3.03153 17.2275C2.56112 16.9878 2.17867 16.6054 1.93899 16.135C1.6665 15.6002 1.6665 14.9001 1.6665 13.5V5.83333ZM9.99984 14.1667V9.16667M7.49984 11.6667H12.4998" stroke="#667085" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
}
const ThreeDotsIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M5 6.5V5M8.93934 7.56066L10 6.5M10.0103 11.5H11.5103" stroke="#374151" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
}
const NotionIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<g clipPath="url(#clip0_2164_11263)">
<path fillRule="evenodd" clipRule="evenodd" d="M3.5725 18.2611L1.4229 15.5832C0.905706 14.9389 0.625 14.1466 0.625 13.3312V3.63437C0.625 2.4129 1.60224 1.39936 2.86295 1.31328L12.8326 0.632614C13.5569 0.583164 14.2768 0.775682 14.8717 1.17794L18.3745 3.5462C19.0015 3.97012 19.375 4.66312 19.375 5.40266V16.427C19.375 17.6223 18.4141 18.6121 17.1798 18.688L6.11458 19.3692C5.12958 19.4298 4.17749 19.0148 3.5725 18.2611Z" fill="white" />
<path d="M7.03006 8.48669V8.35974C7.03006 8.03794 7.28779 7.77104 7.61997 7.74886L10.0396 7.58733L13.3857 12.5147V8.19009L12.5244 8.07528V8.01498C12.5244 7.68939 12.788 7.42074 13.1244 7.4035L15.326 7.29073V7.60755C15.326 7.75628 15.2154 7.88349 15.0638 7.90913L14.534 7.99874V15.0023L13.8691 15.231C13.3136 15.422 12.6952 15.2175 12.3772 14.7377L9.12879 9.83574V14.5144L10.1287 14.7057L10.1147 14.7985C10.0711 15.089 9.82028 15.3087 9.51687 15.3222L7.03006 15.4329C6.99718 15.1205 7.23132 14.841 7.55431 14.807L7.88143 14.7727V8.53453L7.03006 8.48669Z" fill="black" />
<path fillRule="evenodd" clipRule="evenodd" d="M12.9218 1.85424L2.95217 2.53491C2.35499 2.57568 1.89209 3.05578 1.89209 3.63437V13.3312C1.89209 13.8748 2.07923 14.403 2.42402 14.8325L4.57362 17.5104C4.92117 17.9434 5.46812 18.1818 6.03397 18.147L17.0991 17.4658C17.6663 17.4309 18.1078 16.9762 18.1078 16.427V5.40266C18.1078 5.06287 17.9362 4.74447 17.6481 4.54969L14.1453 2.18143C13.7883 1.94008 13.3564 1.82457 12.9218 1.85424ZM3.44654 3.78562C3.30788 3.68296 3.37387 3.46909 3.54806 3.4566L12.9889 2.77944C13.2897 2.75787 13.5886 2.8407 13.8318 3.01305L15.7261 4.35508C15.798 4.40603 15.7642 4.51602 15.6752 4.52086L5.67742 5.0646C5.37485 5.08106 5.0762 4.99217 4.83563 4.81406L3.44654 3.78562ZM5.20848 6.76919C5.20848 6.4444 5.47088 6.1761 5.80642 6.15783L16.3769 5.58216C16.7039 5.56435 16.9792 5.81583 16.9792 6.13239V15.6783C16.9792 16.0025 16.7177 16.2705 16.3829 16.2896L5.8793 16.8872C5.51537 16.9079 5.20848 16.6283 5.20848 16.2759V6.76919Z" fill="black" />
</g>
<defs>
<clipPath id="clip0_2164_11263">
<rect width="20" height="20" fill="white" />
</clipPath>
</defs>
</svg>
)
}
const EmptyElement: FC<{ canAdd: boolean, onClick: () => void, type?: 'upload' | 'sync' }> = ({ canAdd = true, onClick, type = 'upload' }) => {
const { t } = useTranslation()
return (
<div className={s.emptyWrapper}>
<div className={s.emptyElement}>
<div className={s.emptySymbolIconWrapper}>
{type === 'upload' ? <FolderPlusIcon /> : <NotionIcon />}
</div>
<span className={s.emptyTitle}>
{t('list.empty.title', { ns: 'datasetDocuments' })}
<ThreeDotsIcon className="relative -left-1.5 -top-3 inline" />
</span>
<div className={s.emptyTip}>
{t(`list.empty.${type}.tip`, { ns: 'datasetDocuments' })}
</div>
{type === 'upload' && canAdd && (
<Button onClick={onClick} className={s.addFileBtn} variant="secondary-accent">
<PlusIcon className={s.plusIcon} />
{t('list.addFile', { ns: 'datasetDocuments' })}
</Button>
)}
</div>
</div>
)
}
import DocumentsHeader from './components/documents-header'
import EmptyElement from './components/empty-element'
import List from './components/list'
import useDocumentsPageState from './hooks/use-documents-page-state'
type IDocumentsProps = {
datasetId: string
}
const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
const { t } = useTranslation()
const docLink = useDocLink()
const router = useRouter()
const { plan } = useProviderContext()
const isFreePlan = plan.type === 'sandbox'
const { query, updateQuery } = useDocumentListQueryState()
const [inputValue, setInputValue] = useState<string>('') // the input value
const [searchValue, setSearchValue] = useState<string>('')
const [statusFilterValue, setStatusFilterValue] = useState<string>(() => sanitizeStatusValue(query.status))
const [sortValue, setSortValue] = useState<SortType>(query.sort)
const DOC_INDEX_STATUS_MAP = useIndexStatus()
const [currPage, setCurrPage] = React.useState<number>(query.page - 1) // Convert to 0-based index
const [limit, setLimit] = useState<number>(query.limit)
const router = useRouter()
const dataset = useDatasetDetailContextWithSelector(s => s.dataset)
const [timerCanRun, setTimerCanRun] = useState(true)
const isDataSourceNotion = dataset?.data_source_type === DataSourceType.NOTION
const isDataSourceWeb = dataset?.data_source_type === DataSourceType.WEB
const isDataSourceFile = dataset?.data_source_type === DataSourceType.FILE
const embeddingAvailable = !!dataset?.embedding_available
const debouncedSearchValue = useDebounce(searchValue, { wait: 500 })
const statusFilterItems: Item[] = useMemo(() => [
{ value: 'all', name: t('list.index.all', { ns: 'datasetDocuments' }) as string },
{ value: 'queuing', name: DOC_INDEX_STATUS_MAP.queuing.text },
{ value: 'indexing', name: DOC_INDEX_STATUS_MAP.indexing.text },
{ value: 'paused', name: DOC_INDEX_STATUS_MAP.paused.text },
{ value: 'error', name: DOC_INDEX_STATUS_MAP.error.text },
{ value: 'available', name: DOC_INDEX_STATUS_MAP.available.text },
{ value: 'enabled', name: DOC_INDEX_STATUS_MAP.enabled.text },
{ value: 'disabled', name: DOC_INDEX_STATUS_MAP.disabled.text },
{ value: 'archived', name: DOC_INDEX_STATUS_MAP.archived.text },
], [DOC_INDEX_STATUS_MAP, t])
const normalizedStatusFilterValue = useMemo(() => normalizeStatusForQuery(statusFilterValue), [statusFilterValue])
const sortItems: Item[] = useMemo(() => [
{ value: 'created_at', name: t('list.sort.uploadTime', { ns: 'datasetDocuments' }) as string },
{ value: 'hit_count', name: t('list.sort.hitCount', { ns: 'datasetDocuments' }) as string },
], [t])
// Initialize search value from URL on mount
useEffect(() => {
if (query.keyword) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
}, []) // Only run on mount
// Sync local state with URL query changes
useEffect(() => {
setCurrPage(query.page - 1)
setLimit(query.limit)
if (query.keyword !== searchValue) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
setStatusFilterValue((prev) => {
const nextValue = sanitizeStatusValue(query.status)
return prev === nextValue ? prev : nextValue
})
setSortValue(query.sort)
}, [query])
// Update URL when pagination changes
const handlePageChange = (newPage: number) => {
setCurrPage(newPage)
updateQuery({ page: newPage + 1 }) // Pagination emits 0-based page, convert to 1-based for URL
}
// Update URL when limit changes
const handleLimitChange = (newLimit: number) => {
setLimit(newLimit)
setCurrPage(0) // Reset to first page when limit changes
updateQuery({ limit: newLimit, page: 1 })
}
// Update URL when search changes
useEffect(() => {
if (debouncedSearchValue !== query.keyword) {
setCurrPage(0) // Reset to first page when search changes
updateQuery({ keyword: debouncedSearchValue, page: 1 })
}
}, [debouncedSearchValue, query.keyword, updateQuery])
// Use custom hook for page state management
const {
inputValue,
debouncedSearchValue,
handleInputChange,
statusFilterValue,
sortValue,
normalizedStatusFilterValue,
handleStatusFilterChange,
handleStatusFilterClear,
handleSortChange,
currPage,
limit,
handlePageChange,
handleLimitChange,
selectedIds,
setSelectedIds,
timerCanRun,
updatePollingState,
adjustPageForTotal,
} = useDocumentsPageState()
// Fetch document list
const { data: documentsRes, isLoading: isListLoading } = useDocumentList({
datasetId,
query: {
@ -192,16 +62,18 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
refetchInterval: timerCanRun ? 2500 : 0,
})
const invalidDocumentList = useInvalidDocumentList(datasetId)
// Update polling state when documents change
useEffect(() => {
if (documentsRes) {
const totalPages = Math.ceil(documentsRes.total / limit)
if (totalPages < currPage + 1)
setCurrPage(totalPages === 0 ? 0 : totalPages - 1)
}
}, [documentsRes])
updatePollingState(documentsRes)
}, [documentsRes, updatePollingState])
// Adjust page when total changes
useEffect(() => {
adjustPageForTotal(documentsRes)
}, [documentsRes, adjustPageForTotal])
// Invalidation hooks
const invalidDocumentList = useInvalidDocumentList(datasetId)
const invalidDocumentDetail = useInvalidDocumentDetail()
const invalidChunkList = useInvalid(useSegmentListKey)
const invalidChildChunkList = useInvalid(useChildSegmentListKey)
@ -213,73 +85,9 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
invalidChunkList()
invalidChildChunkList()
}, 5000)
}, [])
useEffect(() => {
let completedNum = 0
let percent = 0
documentsRes?.data?.forEach((documentItem) => {
const { indexing_status, completed_segments, total_segments } = documentItem
const isEmbedded = indexing_status === 'completed' || indexing_status === 'paused' || indexing_status === 'error'
if (isEmbedded)
completedNum++
const completedCount = completed_segments || 0
const totalCount = total_segments || 0
if (totalCount === 0 && completedCount === 0) {
percent = isEmbedded ? 100 : 0
}
else {
const per = Math.round(completedCount * 100 / totalCount)
percent = per > 100 ? 100 : per
}
return {
...documentItem,
percent,
}
})
const hasIncompleteDocuments = completedNum !== documentsRes?.data?.length
const transientStatuses = ['queuing', 'indexing', 'paused']
const shouldForcePolling = normalizedStatusFilterValue === 'all'
? false
: transientStatuses.includes(normalizedStatusFilterValue)
setTimerCanRun(shouldForcePolling || hasIncompleteDocuments)
}, [documentsRes, normalizedStatusFilterValue])
const total = documentsRes?.total || 0
const routeToDocCreate = () => {
// if dataset is created from pipeline, go to create from pipeline page
if (dataset?.runtime_mode === 'rag_pipeline') {
router.push(`/datasets/${datasetId}/documents/create-from-pipeline`)
return
}
router.push(`/datasets/${datasetId}/documents/create`)
}
const documentsList = documentsRes?.data
const [selectedIds, setSelectedIds] = useState<string[]>([])
// Clear selection when search changes to avoid confusion
useEffect(() => {
if (searchValue !== query.keyword)
setSelectedIds([])
}, [searchValue, query.keyword])
useEffect(() => {
setSelectedIds([])
}, [normalizedStatusFilterValue])
const { run: handleSearch } = useDebounceFn(() => {
setSearchValue(inputValue)
}, { wait: 500 })
const handleInputChange = (value: string) => {
setInputValue(value)
handleSearch()
}
}, [invalidDocumentList, invalidDocumentDetail, invalidChunkList, invalidChildChunkList])
// Metadata editing hook
const {
isShowEditModal: isShowEditMetadataModal,
showEditModal: showEditMetadataModal,
@ -297,130 +105,84 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
onUpdateDocList: invalidDocumentList,
})
// Route to document creation page
const routeToDocCreate = useCallback(() => {
if (dataset?.runtime_mode === 'rag_pipeline') {
router.push(`/datasets/${datasetId}/documents/create-from-pipeline`)
return
}
router.push(`/datasets/${datasetId}/documents/create`)
}, [dataset?.runtime_mode, datasetId, router])
const total = documentsRes?.total || 0
const documentsList = documentsRes?.data
// Render content based on loading and data state
const renderContent = () => {
if (isListLoading)
return <Loading type="app" />
if (total > 0) {
return (
<List
embeddingAvailable={embeddingAvailable}
documents={documentsList || []}
datasetId={datasetId}
onUpdate={handleUpdate}
selectedIds={selectedIds}
onSelectedIdChange={setSelectedIds}
statusFilterValue={normalizedStatusFilterValue}
remoteSortValue={sortValue}
pagination={{
total,
limit,
onLimitChange: handleLimitChange,
current: currPage,
onChange: handlePageChange,
}}
onManageMetadata={showEditMetadataModal}
/>
)
}
const isDataSourceNotion = dataset?.data_source_type === DataSourceType.NOTION
return (
<EmptyElement
canAdd={embeddingAvailable}
onClick={routeToDocCreate}
type={isDataSourceNotion ? 'sync' : 'upload'}
/>
)
}
return (
<div className="flex h-full flex-col">
<div className="flex flex-col justify-center gap-1 px-6 pt-4">
<h1 className="text-base font-semibold text-text-primary">{t('list.title', { ns: 'datasetDocuments' })}</h1>
<div className="flex items-center space-x-0.5 text-sm font-normal text-text-tertiary">
<span>{t('list.desc', { ns: 'datasetDocuments' })}</span>
<a
className="flex items-center text-text-accent"
target="_blank"
href={docLink('/guides/knowledge-base/integrate-knowledge-within-application')}
>
<span>{t('list.learnMore', { ns: 'datasetDocuments' })}</span>
<RiExternalLinkLine className="h-3 w-3" />
</a>
</div>
</div>
<DocumentsHeader
datasetId={datasetId}
dataSourceType={dataset?.data_source_type}
embeddingAvailable={embeddingAvailable}
isFreePlan={isFreePlan}
statusFilterValue={statusFilterValue}
sortValue={sortValue}
inputValue={inputValue}
onStatusFilterChange={handleStatusFilterChange}
onStatusFilterClear={handleStatusFilterClear}
onSortChange={handleSortChange}
onInputChange={handleInputChange}
isShowEditMetadataModal={isShowEditMetadataModal}
showEditMetadataModal={showEditMetadataModal}
hideEditMetadataModal={hideEditMetadataModal}
datasetMetaData={datasetMetaData}
builtInMetaData={builtInMetaData}
builtInEnabled={!!builtInEnabled}
onAddMetaData={handleAddMetaData}
onRenameMetaData={handleRename}
onDeleteMetaData={handleDeleteMetaData}
onBuiltInEnabledChange={setBuiltInEnabled}
onAddDocument={routeToDocCreate}
/>
<div className="flex h-0 grow flex-col px-6 pt-4">
<div className="flex flex-wrap items-center justify-between">
<div className="flex items-center gap-2">
<Chip
className="w-[160px]"
showLeftIcon={false}
value={statusFilterValue}
items={statusFilterItems}
onSelect={(item) => {
const selectedValue = sanitizeStatusValue(item?.value ? String(item.value) : '')
setStatusFilterValue(selectedValue)
setCurrPage(0)
updateQuery({ status: selectedValue, page: 1 })
}}
onClear={() => {
if (statusFilterValue === 'all')
return
setStatusFilterValue('all')
setCurrPage(0)
updateQuery({ status: 'all', page: 1 })
}}
/>
<Input
showLeftIcon
showClearIcon
wrapperClassName="!w-[200px]"
value={inputValue}
onChange={e => handleInputChange(e.target.value)}
onClear={() => handleInputChange('')}
/>
<div className="h-3.5 w-px bg-divider-regular"></div>
<Sort
order={sortValue.startsWith('-') ? '-' : ''}
value={sortValue.replace('-', '')}
items={sortItems}
onSelect={(value) => {
const next = String(value) as SortType
if (next === sortValue)
return
setSortValue(next)
setCurrPage(0)
updateQuery({ sort: next, page: 1 })
}}
/>
</div>
<div className="flex !h-8 items-center justify-center gap-2">
{!isFreePlan && <AutoDisabledDocument datasetId={datasetId} />}
<IndexFailed datasetId={datasetId} />
{!embeddingAvailable && <StatusWithAction type="warning" description={t('embeddingModelNotAvailable', { ns: 'dataset' })} />}
{embeddingAvailable && (
<Button variant="secondary" className="shrink-0" onClick={showEditMetadataModal}>
<RiDraftLine className="mr-1 size-4" />
{t('metadata.metadata', { ns: 'dataset' })}
</Button>
)}
{isShowEditMetadataModal && (
<DatasetMetadataDrawer
userMetadata={datasetMetaData || []}
onClose={hideEditMetadataModal}
onAdd={handleAddMetaData}
onRename={handleRename}
onRemove={handleDeleteMetaData}
builtInMetadata={builtInMetaData || []}
isBuiltInEnabled={!!builtInEnabled}
onIsBuiltInEnabledChange={setBuiltInEnabled}
/>
)}
{embeddingAvailable && (
<Button variant="primary" onClick={routeToDocCreate} className="shrink-0">
<PlusIcon className={cn('mr-2 h-4 w-4 stroke-current')} />
{isDataSourceNotion && t('list.addPages', { ns: 'datasetDocuments' })}
{isDataSourceWeb && t('list.addUrl', { ns: 'datasetDocuments' })}
{(!dataset?.data_source_type || isDataSourceFile) && t('list.addFile', { ns: 'datasetDocuments' })}
</Button>
)}
</div>
</div>
{isListLoading
? <Loading type="app" />
// eslint-disable-next-line sonarjs/no-nested-conditional
: total > 0
? (
<List
embeddingAvailable={embeddingAvailable}
documents={documentsList || []}
datasetId={datasetId}
onUpdate={handleUpdate}
selectedIds={selectedIds}
onSelectedIdChange={setSelectedIds}
statusFilterValue={normalizedStatusFilterValue}
remoteSortValue={sortValue}
pagination={{
total,
limit,
onLimitChange: handleLimitChange,
current: currPage,
onChange: handlePageChange,
}}
onManageMetadata={showEditMetadataModal}
/>
)
: (
<EmptyElement
canAdd={embeddingAvailable}
onClick={routeToDocCreate}
type={isDataSourceNotion ? 'sync' : 'upload'}
/>
)}
{renderContent()}
</div>
</div>
)

View File

@ -357,7 +357,9 @@ const Form = () => {
</div>
)}
{
indexMethod === IndexingType.QUALIFIED && (
indexMethod === IndexingType.QUALIFIED
&& [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode)
&& (
<>
<Divider
type="horizontal"

View File

@ -30,7 +30,7 @@ const ApiBasedExtensionModal: FC<ApiBasedExtensionModalProps> = ({
onSave,
}) => {
const { t } = useTranslation()
const docLink = useDocLink()
const docLink = useDocLink('https://docs.dify.ai/versions/3-0-x')
const [localeData, setLocaleData] = useState(data)
const [loading, setLoading] = useState(false)
const { notify } = useToastContext()
@ -102,7 +102,7 @@ const ApiBasedExtensionModal: FC<ApiBasedExtensionModalProps> = ({
<div className="flex h-9 items-center justify-between text-sm font-medium text-text-primary">
{t('apiBasedExtension.modal.apiEndpoint.title', { ns: 'common' })}
<a
href={docLink('/guides/extension/api-based-extension/README')}
href={docLink('/user-guide/extension/api-based-extension/README#api-based-extension')}
target="_blank"
rel="noopener noreferrer"
className="group flex items-center text-xs font-normal text-text-accent"

View File

@ -55,7 +55,6 @@ type FormProps<
nodeId?: string
nodeOutputVars?: NodeOutPutVar[]
availableNodes?: Node[]
canChooseMCPTool?: boolean
}
function Form<
@ -81,7 +80,6 @@ function Form<
nodeId,
nodeOutputVars,
availableNodes,
canChooseMCPTool,
}: FormProps<CustomFormSchema>) {
const language = useLanguage()
const [changeKey, setChangeKey] = useState('')
@ -407,7 +405,6 @@ function Form<
value={value[variable] || []}
onChange={item => handleFormChange(variable, item as any)}
supportCollapse
canChooseMCPTool={canChooseMCPTool}
/>
{fieldMoreInfo?.(formSchema)}
{validating && changeKey === variable && <ValidatingTip />}

View File

@ -2,9 +2,9 @@
import type { INavSelectorProps } from './nav-selector'
import Link from 'next/link'
import { usePathname, useSearchParams, useSelectedLayoutSegment } from 'next/navigation'
import { useSelectedLayoutSegment } from 'next/navigation'
import * as React from 'react'
import { useEffect, useState } from 'react'
import { useState } from 'react'
import { useStore as useAppStore } from '@/app/components/app/store'
import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows'
import { cn } from '@/utils/classnames'
@ -36,14 +36,6 @@ const Nav = ({
const [hovered, setHovered] = useState(false)
const segment = useSelectedLayoutSegment()
const isActivated = Array.isArray(activeSegment) ? activeSegment.includes(segment!) : segment === activeSegment
const pathname = usePathname()
const searchParams = useSearchParams()
const [linkLastSearchParams, setLinkLastSearchParams] = useState('')
useEffect(() => {
if (pathname === link)
setLinkLastSearchParams(searchParams.toString())
}, [pathname, searchParams])
return (
<div className={`
@ -52,7 +44,7 @@ const Nav = ({
${!curNav && !isActivated && 'hover:bg-components-main-nav-nav-button-bg-hover'}
`}
>
<Link href={link + (linkLastSearchParams && `?${linkLastSearchParams}`)}>
<Link href={link}>
<div
onClick={(e) => {
// Don't clear state if opening in new tab/window

View File

@ -897,6 +897,58 @@ describe('Icon', () => {
const iconDiv = container.firstChild as HTMLElement
expect(iconDiv).toHaveStyle({ backgroundImage: 'url(/icon?name=test&size=large)' })
})
it('should not render status indicators when src is object with installed=true', () => {
render(<Icon src={{ content: '🎉', background: '#fff' }} installed={true} />)
// Status indicators should not render for object src
expect(screen.queryByTestId('ri-check-line')).not.toBeInTheDocument()
})
it('should not render status indicators when src is object with installFailed=true', () => {
render(<Icon src={{ content: '🎉', background: '#fff' }} installFailed={true} />)
// Status indicators should not render for object src
expect(screen.queryByTestId('ri-close-line')).not.toBeInTheDocument()
})
it('should render object src with all size variants', () => {
const sizes: Array<'xs' | 'tiny' | 'small' | 'medium' | 'large'> = ['xs', 'tiny', 'small', 'medium', 'large']
sizes.forEach((size) => {
const { unmount } = render(<Icon src={{ content: '🔗', background: '#fff' }} size={size} />)
expect(screen.getByTestId('app-icon')).toHaveAttribute('data-size', size)
unmount()
})
})
it('should render object src with custom className', () => {
const { container } = render(
<Icon src={{ content: '🎉', background: '#fff' }} className="custom-object-icon" />,
)
expect(container.querySelector('.custom-object-icon')).toBeInTheDocument()
})
it('should pass correct props to AppIcon for object src', () => {
render(<Icon src={{ content: '😀', background: '#123456' }} />)
const appIcon = screen.getByTestId('app-icon')
expect(appIcon).toHaveAttribute('data-icon', '😀')
expect(appIcon).toHaveAttribute('data-background', '#123456')
expect(appIcon).toHaveAttribute('data-icon-type', 'emoji')
})
it('should render inner icon only when shouldUseMcpIcon returns true', () => {
// Test with MCP icon content
const { unmount } = render(<Icon src={{ content: '🔗', background: '#fff' }} />)
expect(screen.getByTestId('inner-icon')).toBeInTheDocument()
unmount()
// Test without MCP icon content
render(<Icon src={{ content: '🎉', background: '#fff' }} />)
expect(screen.queryByTestId('inner-icon')).not.toBeInTheDocument()
})
})
})

View File

@ -180,7 +180,14 @@ const AppPicker: FC<Props> = ({
background={app.icon_background}
imageUrl={app.icon_url}
/>
<div title={app.name} className="system-sm-medium grow text-components-input-text-filled">{app.name}</div>
<div title={`${app.name} (${app.id})`} className="system-sm-medium grow text-components-input-text-filled">
<span className="mr-1">{app.name}</span>
<span className="text-text-tertiary">
(
{app.id.slice(0, 8)}
)
</span>
</div>
<div className="system-2xs-medium-uppercase shrink-0 text-text-tertiary">{getAppType(app)}</div>
</div>
))}

View File

@ -16,7 +16,7 @@ import {
import AppInputsPanel from '@/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel'
import AppPicker from '@/app/components/plugins/plugin-detail-panel/app-selector/app-picker'
import AppTrigger from '@/app/components/plugins/plugin-detail-panel/app-selector/app-trigger'
import { useInfiniteAppList } from '@/service/use-apps'
import { useAppDetail, useInfiniteAppList } from '@/service/use-apps'
const PAGE_SIZE = 20
@ -70,6 +70,30 @@ const AppSelector: FC<Props> = ({
return pages.flatMap(({ data: apps }) => apps)
}, [pages])
// fetch selected app by id to avoid pagination gaps
const { data: selectedAppDetail } = useAppDetail(value?.app_id || '')
// Ensure the currently selected app is available for display and in the picker options
const currentAppInfo = useMemo(() => {
if (!value?.app_id)
return undefined
return selectedAppDetail || displayedApps.find(app => app.id === value.app_id)
}, [value?.app_id, selectedAppDetail, displayedApps])
const appsForPicker = useMemo(() => {
if (!currentAppInfo)
return displayedApps
const appIndex = displayedApps.findIndex(a => a.id === currentAppInfo.id)
if (appIndex === -1)
return [currentAppInfo, ...displayedApps]
const updatedApps = [...displayedApps]
updatedApps[appIndex] = currentAppInfo
return updatedApps
}, [currentAppInfo, displayedApps])
const hasMore = hasNextPage ?? true
const handleLoadMore = useCallback(async () => {
@ -127,12 +151,6 @@ const AppSelector: FC<Props> = ({
}
}, [value])
const currentAppInfo = useMemo(() => {
if (!displayedApps || !value)
return undefined
return displayedApps.find(app => app.id === value.app_id)
}, [displayedApps, value])
return (
<>
<PortalToFollowElem
@ -168,7 +186,7 @@ const AppSelector: FC<Props> = ({
disabled={false}
onSelect={handleSelectApp}
scope={scope || 'all'}
apps={displayedApps}
apps={appsForPicker}
isLoading={isLoading || isLoadingMore || isFetchingNextPage}
hasMore={hasMore}
onLoadMore={handleLoadMore}

View File

@ -7,6 +7,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
// ==================== Imports (after mocks) ====================
import { MCPToolAvailabilityProvider } from '@/app/components/workflow/nodes/_base/components/mcp-tool-availability'
import MultipleToolSelector from './index'
// ==================== Mock Setup ====================
@ -190,10 +191,11 @@ type RenderOptions = {
nodeOutputVars?: NodeOutPutVar[]
availableNodes?: Node[]
nodeId?: string
canChooseMCPTool?: boolean
versionSupported?: boolean
}
const renderComponent = (options: RenderOptions = {}) => {
const { versionSupported, ...overrides } = options
const defaultProps = {
disabled: false,
value: [],
@ -206,16 +208,17 @@ const renderComponent = (options: RenderOptions = {}) => {
nodeOutputVars: [createNodeOutputVar()],
availableNodes: [createNode()],
nodeId: 'test-node-id',
canChooseMCPTool: false,
}
const props = { ...defaultProps, ...options }
const props = { ...defaultProps, ...overrides }
const queryClient = createQueryClient()
return {
...render(
<QueryClientProvider client={queryClient}>
<MultipleToolSelector {...props} />
<MCPToolAvailabilityProvider versionSupported={versionSupported}>
<MultipleToolSelector {...props} />
</MCPToolAvailabilityProvider>
</QueryClientProvider>,
),
props,
@ -410,7 +413,7 @@ describe('MultipleToolSelector', () => {
expect(screen.getByText('2/3')).toBeInTheDocument()
})
it('should track enabled count with MCP tools when canChooseMCPTool is true', () => {
it('should track enabled count with MCP tools when version is supported', () => {
// Arrange
const mcpTools = [createMCPTool({ id: 'mcp-provider' })]
mockMCPToolsData.mockReturnValue(mcpTools)
@ -421,13 +424,13 @@ describe('MultipleToolSelector', () => {
]
// Act
renderComponent({ value: tools, canChooseMCPTool: true })
renderComponent({ value: tools, versionSupported: true })
// Assert
expect(screen.getByText('2/2')).toBeInTheDocument()
})
it('should not count MCP tools when canChooseMCPTool is false', () => {
it('should not count MCP tools when version is unsupported', () => {
// Arrange
const mcpTools = [createMCPTool({ id: 'mcp-provider' })]
mockMCPToolsData.mockReturnValue(mcpTools)
@ -438,7 +441,7 @@ describe('MultipleToolSelector', () => {
]
// Act
renderComponent({ value: tools, canChooseMCPTool: false })
renderComponent({ value: tools, versionSupported: false })
// Assert
expect(screen.getByText('1/2')).toBeInTheDocument()
@ -721,14 +724,6 @@ describe('MultipleToolSelector', () => {
expect(screen.getByTestId('tool-selector-add')).toBeInTheDocument()
})
it('should pass canChooseMCPTool prop correctly', () => {
// Arrange & Act
renderComponent({ canChooseMCPTool: true })
// Assert
expect(screen.getByTestId('tool-selector-add')).toBeInTheDocument()
})
it('should render with supportEnableSwitch for edit selectors', () => {
// Arrange
const tools = [createToolValue()]
@ -771,13 +766,13 @@ describe('MultipleToolSelector', () => {
]
// Act
renderComponent({ value: tools, canChooseMCPTool: true })
renderComponent({ value: tools, versionSupported: true })
// Assert
expect(screen.getByText('2/2')).toBeInTheDocument()
})
it('should exclude MCP tools from enabled count when canChooseMCPTool is false', () => {
it('should exclude MCP tools from enabled count when strategy version is unsupported', () => {
// Arrange
const mcpTools = [createMCPTool({ id: 'mcp-provider' })]
mockMCPToolsData.mockReturnValue(mcpTools)
@ -788,7 +783,7 @@ describe('MultipleToolSelector', () => {
]
// Act
renderComponent({ value: tools, canChooseMCPTool: false })
renderComponent({ value: tools, versionSupported: false })
// Assert - Only regular tool should be counted
expect(screen.getByText('1/2')).toBeInTheDocument()

View File

@ -12,6 +12,7 @@ import Divider from '@/app/components/base/divider'
import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general'
import Tooltip from '@/app/components/base/tooltip'
import ToolSelector from '@/app/components/plugins/plugin-detail-panel/tool-selector'
import { useMCPToolAvailability } from '@/app/components/workflow/nodes/_base/components/mcp-tool-availability'
import { useAllMCPTools } from '@/service/use-tools'
import { cn } from '@/utils/classnames'
@ -27,7 +28,6 @@ type Props = {
nodeOutputVars: NodeOutPutVar[]
availableNodes: Node[]
nodeId?: string
canChooseMCPTool?: boolean
}
const MultipleToolSelector = ({
@ -42,14 +42,14 @@ const MultipleToolSelector = ({
nodeOutputVars,
availableNodes,
nodeId,
canChooseMCPTool,
}: Props) => {
const { t } = useTranslation()
const { allowed: isMCPToolAllowed } = useMCPToolAvailability()
const { data: mcpTools } = useAllMCPTools()
const enabledCount = value.filter((item) => {
const isMCPTool = mcpTools?.find(tool => tool.id === item.provider_name)
if (isMCPTool)
return item.enabled && canChooseMCPTool
return item.enabled && isMCPToolAllowed
return item.enabled
}).length
// collapse control
@ -167,7 +167,6 @@ const MultipleToolSelector = ({
onSelectMultiple={handleAddMultiple}
onDelete={() => handleDelete(index)}
supportEnableSwitch
canChooseMCPTool={canChooseMCPTool}
isEdit
/>
</div>
@ -190,7 +189,6 @@ const MultipleToolSelector = ({
panelShowState={panelShowState}
onPanelShowStateChange={setPanelShowState}
isEdit={false}
canChooseMCPTool={canChooseMCPTool}
onSelectMultiple={handleAddMultiple}
/>
</>

View File

@ -64,7 +64,6 @@ type Props = {
nodeOutputVars: NodeOutPutVar[]
availableNodes: Node[]
nodeId?: string
canChooseMCPTool?: boolean
}
const ToolSelector: FC<Props> = ({
value,
@ -86,7 +85,6 @@ const ToolSelector: FC<Props> = ({
nodeOutputVars,
availableNodes,
nodeId = '',
canChooseMCPTool,
}) => {
const { t } = useTranslation()
const [isShow, onShowChange] = useState(false)
@ -267,7 +265,6 @@ const ToolSelector: FC<Props> = ({
</p>
</div>
)}
canChooseMCPTool={canChooseMCPTool}
/>
)}
</PortalToFollowElemTrigger>
@ -300,7 +297,6 @@ const ToolSelector: FC<Props> = ({
onSelectMultiple={handleSelectMultipleTool}
scope={scope}
selectedTools={selectedTools}
canChooseMCPTool={canChooseMCPTool}
/>
</div>
<div className="flex flex-col gap-1">

View File

@ -16,6 +16,7 @@ import Tooltip from '@/app/components/base/tooltip'
import { ToolTipContent } from '@/app/components/base/tooltip/content'
import Indicator from '@/app/components/header/indicator'
import { InstallPluginButton } from '@/app/components/workflow/nodes/_base/components/install-plugin-button'
import { useMCPToolAvailability } from '@/app/components/workflow/nodes/_base/components/mcp-tool-availability'
import McpToolNotSupportTooltip from '@/app/components/workflow/nodes/_base/components/mcp-tool-not-support-tooltip'
import { SwitchPluginVersion } from '@/app/components/workflow/nodes/_base/components/switch-plugin-version'
import { cn } from '@/utils/classnames'
@ -39,7 +40,6 @@ type Props = {
versionMismatch?: boolean
open: boolean
authRemoved?: boolean
canChooseMCPTool?: boolean
}
const ToolItem = ({
@ -61,13 +61,13 @@ const ToolItem = ({
errorTip,
versionMismatch,
authRemoved,
canChooseMCPTool,
}: Props) => {
const { t } = useTranslation()
const { allowed: isMCPToolAllowed } = useMCPToolAvailability()
const providerNameText = isMCPTool ? providerShowName : providerName?.split('/').pop()
const isTransparent = uninstalled || versionMismatch || isError
const [isDeleting, setIsDeleting] = useState(false)
const isShowCanNotChooseMCPTip = isMCPTool && !canChooseMCPTool
const isShowCanNotChooseMCPTip = isMCPTool && !isMCPToolAllowed
return (
<div className={cn(
@ -125,9 +125,7 @@ const ToolItem = ({
/>
</div>
)}
{isShowCanNotChooseMCPTip && (
<McpToolNotSupportTooltip />
)}
{isShowCanNotChooseMCPTip && <McpToolNotSupportTooltip />}
{!isError && !uninstalled && !versionMismatch && noAuth && (
<Button variant="secondary" size="small">
{t('notAuthorized', { ns: 'tools' })}

View File

@ -0,0 +1,123 @@
import { render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Import mocks
import { useGlobalPublicStore } from '@/context/global-public-context'
import { PluginPageContext, PluginPageContextProvider, usePluginPageContext } from './context'
// Mock dependencies
vi.mock('nuqs', () => ({
useQueryState: vi.fn(() => ['plugins', vi.fn()]),
}))
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: vi.fn(),
}))
vi.mock('../hooks', () => ({
PLUGIN_PAGE_TABS_MAP: {
plugins: 'plugins',
marketplace: 'discover',
},
usePluginPageTabs: () => [
{ value: 'plugins', text: 'Plugins' },
{ value: 'discover', text: 'Explore Marketplace' },
],
}))
// Helper function to mock useGlobalPublicStore with marketplace setting
const mockGlobalPublicStore = (enableMarketplace: boolean) => {
vi.mocked(useGlobalPublicStore).mockImplementation((selector) => {
const state = { systemFeatures: { enable_marketplace: enableMarketplace } }
return selector(state as Parameters<typeof selector>[0])
})
}
// Test component that uses the context
const TestConsumer = () => {
const containerRef = usePluginPageContext(v => v.containerRef)
const options = usePluginPageContext(v => v.options)
const activeTab = usePluginPageContext(v => v.activeTab)
return (
<div>
<span data-testid="has-container-ref">{containerRef ? 'true' : 'false'}</span>
<span data-testid="options-count">{options.length}</span>
<span data-testid="active-tab">{activeTab}</span>
{options.map((opt: { value: string, text: string }) => (
<span key={opt.value} data-testid={`option-${opt.value}`}>{opt.text}</span>
))}
</div>
)
}
describe('PluginPageContext', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('PluginPageContextProvider', () => {
it('should provide context values to children', () => {
mockGlobalPublicStore(true)
render(
<PluginPageContextProvider>
<TestConsumer />
</PluginPageContextProvider>,
)
expect(screen.getByTestId('has-container-ref')).toHaveTextContent('true')
expect(screen.getByTestId('options-count')).toHaveTextContent('2')
})
it('should include marketplace tab when enable_marketplace is true', () => {
mockGlobalPublicStore(true)
render(
<PluginPageContextProvider>
<TestConsumer />
</PluginPageContextProvider>,
)
expect(screen.getByTestId('option-plugins')).toBeInTheDocument()
expect(screen.getByTestId('option-discover')).toBeInTheDocument()
})
it('should filter out marketplace tab when enable_marketplace is false', () => {
mockGlobalPublicStore(false)
render(
<PluginPageContextProvider>
<TestConsumer />
</PluginPageContextProvider>,
)
expect(screen.getByTestId('option-plugins')).toBeInTheDocument()
expect(screen.queryByTestId('option-discover')).not.toBeInTheDocument()
expect(screen.getByTestId('options-count')).toHaveTextContent('1')
})
})
describe('usePluginPageContext', () => {
it('should select specific context values', () => {
mockGlobalPublicStore(true)
render(
<PluginPageContextProvider>
<TestConsumer />
</PluginPageContextProvider>,
)
// activeTab should be 'plugins' from the mock
expect(screen.getByTestId('active-tab')).toHaveTextContent('plugins')
})
})
describe('Default Context Values', () => {
it('should have empty options by default from context', () => {
// Test that the context has proper default values by checking the exported constant
// The PluginPageContext is created with default values including empty options array
expect(PluginPageContext).toBeDefined()
})
})
})

File diff suppressed because it is too large Load Diff

View File

@ -207,6 +207,7 @@ const PluginPage = ({
popupContent={t('privilege.title', { ns: 'plugin' })}
>
<Button
data-testid="plugin-settings-button"
className="group h-full w-full p-2 text-components-button-secondary-text"
onClick={setShowPluginSettingModal}
>

View File

@ -0,0 +1,219 @@
import type { FC, ReactNode } from 'react'
import type { PluginStatus } from '@/app/components/plugins/types'
import type { Locale } from '@/i18n-config'
import {
RiCheckboxCircleFill,
RiErrorWarningFill,
RiLoaderLine,
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import CardIcon from '@/app/components/plugins/card/base/card-icon'
import { useGetLanguage } from '@/context/i18n'
// Types
type PluginItemProps = {
plugin: PluginStatus
getIconUrl: (icon: string) => string
language: Locale
statusIcon: ReactNode
statusText: string
statusClassName?: string
action?: ReactNode
}
type PluginSectionProps = {
title: string
count: number
plugins: PluginStatus[]
getIconUrl: (icon: string) => string
language: Locale
statusIcon: ReactNode
defaultStatusText: string
statusClassName?: string
headerAction?: ReactNode
renderItemAction?: (plugin: PluginStatus) => ReactNode
}
type PluginTaskListProps = {
runningPlugins: PluginStatus[]
successPlugins: PluginStatus[]
errorPlugins: PluginStatus[]
getIconUrl: (icon: string) => string
onClearAll: () => void
onClearErrors: () => void
onClearSingle: (taskId: string, pluginId: string) => void
}
// Plugin Item Component
const PluginItem: FC<PluginItemProps> = ({
plugin,
getIconUrl,
language,
statusIcon,
statusText,
statusClassName,
action,
}) => {
return (
<div className="flex items-center rounded-lg p-2 hover:bg-state-base-hover">
<div className="relative mr-2 flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-dodge">
{statusIcon}
<CardIcon
size="tiny"
src={getIconUrl(plugin.icon)}
/>
</div>
<div className="grow">
<div className="system-md-regular truncate text-text-secondary">
{plugin.labels[language]}
</div>
<div className={`system-xs-regular ${statusClassName || 'text-text-tertiary'}`}>
{statusText}
</div>
</div>
{action}
</div>
)
}
// Plugin Section Component
const PluginSection: FC<PluginSectionProps> = ({
title,
count,
plugins,
getIconUrl,
language,
statusIcon,
defaultStatusText,
statusClassName,
headerAction,
renderItemAction,
}) => {
if (plugins.length === 0)
return null
return (
<>
<div className="system-sm-semibold-uppercase sticky top-0 flex h-7 items-center justify-between px-2 pt-1 text-text-secondary">
{title}
{' '}
(
{count}
)
{headerAction}
</div>
<div className="max-h-[200px] overflow-y-auto">
{plugins.map(plugin => (
<PluginItem
key={plugin.plugin_unique_identifier}
plugin={plugin}
getIconUrl={getIconUrl}
language={language}
statusIcon={statusIcon}
statusText={plugin.message || defaultStatusText}
statusClassName={statusClassName}
action={renderItemAction?.(plugin)}
/>
))}
</div>
</>
)
}
// Main Plugin Task List Component
const PluginTaskList: FC<PluginTaskListProps> = ({
runningPlugins,
successPlugins,
errorPlugins,
getIconUrl,
onClearAll,
onClearErrors,
onClearSingle,
}) => {
const { t } = useTranslation()
const language = useGetLanguage()
return (
<div className="w-[360px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg">
{/* Running Plugins Section */}
{runningPlugins.length > 0 && (
<PluginSection
title={t('task.installing', { ns: 'plugin' })}
count={runningPlugins.length}
plugins={runningPlugins}
getIconUrl={getIconUrl}
language={language}
statusIcon={
<RiLoaderLine className="absolute -bottom-0.5 -right-0.5 z-10 h-3 w-3 animate-spin text-text-accent" />
}
defaultStatusText={t('task.installing', { ns: 'plugin' })}
/>
)}
{/* Success Plugins Section */}
{successPlugins.length > 0 && (
<PluginSection
title={t('task.installed', { ns: 'plugin' })}
count={successPlugins.length}
plugins={successPlugins}
getIconUrl={getIconUrl}
language={language}
statusIcon={
<RiCheckboxCircleFill className="absolute -bottom-0.5 -right-0.5 z-10 h-3 w-3 text-text-success" />
}
defaultStatusText={t('task.installed', { ns: 'plugin' })}
statusClassName="text-text-success"
headerAction={(
<Button
className="shrink-0"
size="small"
variant="ghost"
onClick={onClearAll}
>
{t('task.clearAll', { ns: 'plugin' })}
</Button>
)}
/>
)}
{/* Error Plugins Section */}
{errorPlugins.length > 0 && (
<PluginSection
title={t('task.installError', { ns: 'plugin', errorLength: errorPlugins.length })}
count={errorPlugins.length}
plugins={errorPlugins}
getIconUrl={getIconUrl}
language={language}
statusIcon={
<RiErrorWarningFill className="absolute -bottom-0.5 -right-0.5 z-10 h-3 w-3 text-text-destructive" />
}
defaultStatusText={t('task.installError', { ns: 'plugin', errorLength: errorPlugins.length })}
statusClassName="text-text-destructive break-all"
headerAction={(
<Button
className="shrink-0"
size="small"
variant="ghost"
onClick={onClearErrors}
>
{t('task.clearAll', { ns: 'plugin' })}
</Button>
)}
renderItemAction={plugin => (
<Button
className="shrink-0"
size="small"
variant="ghost"
onClick={() => onClearSingle(plugin.taskId, plugin.plugin_unique_identifier)}
>
{t('operation.clear', { ns: 'common' })}
</Button>
)}
/>
)}
</div>
)
}
export default PluginTaskList

View File

@ -0,0 +1,96 @@
import type { FC } from 'react'
import {
RiCheckboxCircleFill,
RiErrorWarningFill,
RiInstallLine,
} from '@remixicon/react'
import ProgressCircle from '@/app/components/base/progress-bar/progress-circle'
import Tooltip from '@/app/components/base/tooltip'
import DownloadingIcon from '@/app/components/header/plugins-nav/downloading-icon'
import { cn } from '@/utils/classnames'
export type TaskStatusIndicatorProps = {
tip: string
isInstalling: boolean
isInstallingWithSuccess: boolean
isInstallingWithError: boolean
isSuccess: boolean
isFailed: boolean
successPluginsLength: number
runningPluginsLength: number
totalPluginsLength: number
onClick: () => void
}
const TaskStatusIndicator: FC<TaskStatusIndicatorProps> = ({
tip,
isInstalling,
isInstallingWithSuccess,
isInstallingWithError,
isSuccess,
isFailed,
successPluginsLength,
runningPluginsLength,
totalPluginsLength,
onClick,
}) => {
const showDownloadingIcon = isInstalling || isInstallingWithError
const showErrorStyle = isInstallingWithError || isFailed
const showSuccessIcon = isSuccess || (successPluginsLength > 0 && runningPluginsLength === 0)
return (
<Tooltip
popupContent={tip}
asChild
offset={8}
>
<div
className={cn(
'relative flex h-8 w-8 items-center justify-center rounded-lg border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg shadow-xs hover:bg-components-button-secondary-bg-hover',
showErrorStyle && 'cursor-pointer border-components-button-destructive-secondary-border-hover bg-state-destructive-hover hover:bg-state-destructive-hover-alt',
(isInstalling || isInstallingWithSuccess || isSuccess) && 'cursor-pointer hover:bg-components-button-secondary-bg-hover',
)}
id="plugin-task-trigger"
onClick={onClick}
>
{/* Main Icon */}
{showDownloadingIcon
? <DownloadingIcon />
: (
<RiInstallLine
className={cn(
'h-4 w-4 text-components-button-secondary-text',
showErrorStyle && 'text-components-button-destructive-secondary-text',
)}
/>
)}
{/* Status Indicator Badge */}
<div className="absolute -right-1 -top-1">
{(isInstalling || isInstallingWithSuccess) && (
<ProgressCircle
percentage={(totalPluginsLength > 0 ? successPluginsLength / totalPluginsLength : 0) * 100}
circleFillColor="fill-components-progress-brand-bg"
/>
)}
{isInstallingWithError && (
<ProgressCircle
percentage={(totalPluginsLength > 0 ? runningPluginsLength / totalPluginsLength : 0) * 100}
circleFillColor="fill-components-progress-brand-bg"
sectorFillColor="fill-components-progress-error-border"
circleStrokeColor="stroke-components-progress-error-border"
/>
)}
{showSuccessIcon && !isInstalling && !isInstallingWithSuccess && !isInstallingWithError && (
<RiCheckboxCircleFill className="h-3.5 w-3.5 text-text-success" />
)}
{isFailed && (
<RiErrorWarningFill className="h-3.5 w-3.5 text-text-destructive" />
)}
</div>
</div>
</Tooltip>
)
}
export default TaskStatusIndicator

View File

@ -0,0 +1,856 @@
import type { PluginStatus } from '@/app/components/plugins/types'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { TaskStatus } from '@/app/components/plugins/types'
// Import mocked modules
import { useMutationClearTaskPlugin, usePluginTaskList } from '@/service/use-plugins'
import PluginTaskList from './components/plugin-task-list'
import TaskStatusIndicator from './components/task-status-indicator'
import { usePluginTaskStatus } from './hooks'
import PluginTasks from './index'
// Mock external dependencies
vi.mock('@/service/use-plugins', () => ({
usePluginTaskList: vi.fn(),
useMutationClearTaskPlugin: vi.fn(),
}))
vi.mock('@/app/components/plugins/install-plugin/base/use-get-icon', () => ({
default: () => ({
getIconUrl: (icon: string) => `https://example.com/${icon}`,
}),
}))
vi.mock('@/context/i18n', () => ({
useGetLanguage: () => 'en_US',
}))
// Helper to create mock plugin
const createMockPlugin = (overrides: Partial<PluginStatus> = {}): PluginStatus => ({
plugin_unique_identifier: `plugin-${Math.random().toString(36).substr(2, 9)}`,
plugin_id: 'test-plugin',
status: TaskStatus.running,
message: '',
icon: 'test-icon.png',
labels: {
en_US: 'Test Plugin',
zh_Hans: '测试插件',
} as Record<string, string>,
taskId: 'task-1',
...overrides,
})
// Helper to setup mock hook returns
const setupMocks = (plugins: PluginStatus[] = []) => {
const mockMutateAsync = vi.fn().mockResolvedValue({})
const mockHandleRefetch = vi.fn()
vi.mocked(usePluginTaskList).mockReturnValue({
pluginTasks: plugins.length > 0
? [{ id: 'task-1', plugins, created_at: '', updated_at: '', status: 'running', total_plugins: plugins.length, completed_plugins: 0 }]
: [],
handleRefetch: mockHandleRefetch,
} as any)
vi.mocked(useMutationClearTaskPlugin).mockReturnValue({
mutateAsync: mockMutateAsync,
} as any)
return { mockMutateAsync, mockHandleRefetch }
}
// ============================================================================
// usePluginTaskStatus Hook Tests
// ============================================================================
describe('usePluginTaskStatus Hook', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('Plugin categorization', () => {
it('should categorize running plugins correctly', () => {
const runningPlugin = createMockPlugin({ status: TaskStatus.running })
setupMocks([runningPlugin])
const TestComponent = () => {
const { runningPlugins, runningPluginsLength } = usePluginTaskStatus()
return (
<div>
<span data-testid="running-count">{runningPluginsLength}</span>
<span data-testid="running-id">{runningPlugins[0]?.plugin_unique_identifier}</span>
</div>
)
}
render(<TestComponent />)
expect(screen.getByTestId('running-count')).toHaveTextContent('1')
expect(screen.getByTestId('running-id')).toHaveTextContent(runningPlugin.plugin_unique_identifier)
})
it('should categorize success plugins correctly', () => {
const successPlugin = createMockPlugin({ status: TaskStatus.success })
setupMocks([successPlugin])
const TestComponent = () => {
const { successPlugins, successPluginsLength } = usePluginTaskStatus()
return (
<div>
<span data-testid="success-count">{successPluginsLength}</span>
<span data-testid="success-id">{successPlugins[0]?.plugin_unique_identifier}</span>
</div>
)
}
render(<TestComponent />)
expect(screen.getByTestId('success-count')).toHaveTextContent('1')
expect(screen.getByTestId('success-id')).toHaveTextContent(successPlugin.plugin_unique_identifier)
})
it('should categorize error plugins correctly', () => {
const errorPlugin = createMockPlugin({ status: TaskStatus.failed, message: 'Install failed' })
setupMocks([errorPlugin])
const TestComponent = () => {
const { errorPlugins, errorPluginsLength } = usePluginTaskStatus()
return (
<div>
<span data-testid="error-count">{errorPluginsLength}</span>
<span data-testid="error-id">{errorPlugins[0]?.plugin_unique_identifier}</span>
</div>
)
}
render(<TestComponent />)
expect(screen.getByTestId('error-count')).toHaveTextContent('1')
expect(screen.getByTestId('error-id')).toHaveTextContent(errorPlugin.plugin_unique_identifier)
})
it('should categorize mixed plugins correctly', () => {
const plugins = [
createMockPlugin({ status: TaskStatus.running, plugin_unique_identifier: 'running-1' }),
createMockPlugin({ status: TaskStatus.success, plugin_unique_identifier: 'success-1' }),
createMockPlugin({ status: TaskStatus.failed, plugin_unique_identifier: 'error-1' }),
]
setupMocks(plugins)
const TestComponent = () => {
const { runningPluginsLength, successPluginsLength, errorPluginsLength, totalPluginsLength } = usePluginTaskStatus()
return (
<div>
<span data-testid="running">{runningPluginsLength}</span>
<span data-testid="success">{successPluginsLength}</span>
<span data-testid="error">{errorPluginsLength}</span>
<span data-testid="total">{totalPluginsLength}</span>
</div>
)
}
render(<TestComponent />)
expect(screen.getByTestId('running')).toHaveTextContent('1')
expect(screen.getByTestId('success')).toHaveTextContent('1')
expect(screen.getByTestId('error')).toHaveTextContent('1')
expect(screen.getByTestId('total')).toHaveTextContent('3')
})
})
describe('Status flags', () => {
it('should set isInstalling when only running plugins exist', () => {
setupMocks([createMockPlugin({ status: TaskStatus.running })])
const TestComponent = () => {
const { isInstalling, isInstallingWithSuccess, isInstallingWithError, isSuccess, isFailed } = usePluginTaskStatus()
return (
<div>
<span data-testid="isInstalling">{String(isInstalling)}</span>
<span data-testid="isInstallingWithSuccess">{String(isInstallingWithSuccess)}</span>
<span data-testid="isInstallingWithError">{String(isInstallingWithError)}</span>
<span data-testid="isSuccess">{String(isSuccess)}</span>
<span data-testid="isFailed">{String(isFailed)}</span>
</div>
)
}
render(<TestComponent />)
expect(screen.getByTestId('isInstalling')).toHaveTextContent('true')
expect(screen.getByTestId('isInstallingWithSuccess')).toHaveTextContent('false')
expect(screen.getByTestId('isInstallingWithError')).toHaveTextContent('false')
expect(screen.getByTestId('isSuccess')).toHaveTextContent('false')
expect(screen.getByTestId('isFailed')).toHaveTextContent('false')
})
it('should set isInstallingWithSuccess when running and success plugins exist', () => {
setupMocks([
createMockPlugin({ status: TaskStatus.running }),
createMockPlugin({ status: TaskStatus.success }),
])
const TestComponent = () => {
const { isInstallingWithSuccess } = usePluginTaskStatus()
return <span data-testid="flag">{String(isInstallingWithSuccess)}</span>
}
render(<TestComponent />)
expect(screen.getByTestId('flag')).toHaveTextContent('true')
})
it('should set isInstallingWithError when running and error plugins exist', () => {
setupMocks([
createMockPlugin({ status: TaskStatus.running }),
createMockPlugin({ status: TaskStatus.failed }),
])
const TestComponent = () => {
const { isInstallingWithError } = usePluginTaskStatus()
return <span data-testid="flag">{String(isInstallingWithError)}</span>
}
render(<TestComponent />)
expect(screen.getByTestId('flag')).toHaveTextContent('true')
})
it('should set isSuccess when all plugins succeeded', () => {
setupMocks([
createMockPlugin({ status: TaskStatus.success }),
createMockPlugin({ status: TaskStatus.success }),
])
const TestComponent = () => {
const { isSuccess } = usePluginTaskStatus()
return <span data-testid="flag">{String(isSuccess)}</span>
}
render(<TestComponent />)
expect(screen.getByTestId('flag')).toHaveTextContent('true')
})
it('should set isFailed when no running plugins and some failed', () => {
setupMocks([
createMockPlugin({ status: TaskStatus.success }),
createMockPlugin({ status: TaskStatus.failed }),
])
const TestComponent = () => {
const { isFailed } = usePluginTaskStatus()
return <span data-testid="flag">{String(isFailed)}</span>
}
render(<TestComponent />)
expect(screen.getByTestId('flag')).toHaveTextContent('true')
})
})
describe('handleClearErrorPlugin', () => {
it('should call mutateAsync and handleRefetch', async () => {
const { mockMutateAsync, mockHandleRefetch } = setupMocks([
createMockPlugin({ status: TaskStatus.failed }),
])
const TestComponent = () => {
const { handleClearErrorPlugin } = usePluginTaskStatus()
return (
<button onClick={() => handleClearErrorPlugin('task-1', 'plugin-1')}>
Clear
</button>
)
}
render(<TestComponent />)
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
expect(mockMutateAsync).toHaveBeenCalledWith({
taskId: 'task-1',
pluginId: 'plugin-1',
})
expect(mockHandleRefetch).toHaveBeenCalled()
})
})
})
})
// ============================================================================
// TaskStatusIndicator Component Tests
// ============================================================================
describe('TaskStatusIndicator Component', () => {
const defaultProps = {
tip: 'Test tooltip',
isInstalling: false,
isInstallingWithSuccess: false,
isInstallingWithError: false,
isSuccess: false,
isFailed: false,
successPluginsLength: 0,
runningPluginsLength: 0,
totalPluginsLength: 1,
onClick: vi.fn(),
}
beforeEach(() => {
vi.clearAllMocks()
})
describe('Rendering', () => {
it('should render without crashing', () => {
render(<TaskStatusIndicator {...defaultProps} />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should render with correct id', () => {
render(<TaskStatusIndicator {...defaultProps} />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
})
describe('Icon display', () => {
it('should show downloading icon when installing', () => {
render(<TaskStatusIndicator {...defaultProps} isInstalling />)
// DownloadingIcon is rendered when isInstalling is true
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should show downloading icon when installing with error', () => {
render(<TaskStatusIndicator {...defaultProps} isInstallingWithError />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should show install icon when not installing', () => {
render(<TaskStatusIndicator {...defaultProps} isSuccess />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
})
describe('Status badge', () => {
it('should show progress circle when installing', () => {
render(
<TaskStatusIndicator
{...defaultProps}
isInstalling
successPluginsLength={1}
totalPluginsLength={3}
/>,
)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should show progress circle when installing with success', () => {
render(
<TaskStatusIndicator
{...defaultProps}
isInstallingWithSuccess
successPluginsLength={2}
totalPluginsLength={3}
/>,
)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should show error progress circle when installing with error', () => {
render(
<TaskStatusIndicator
{...defaultProps}
isInstallingWithError
runningPluginsLength={1}
totalPluginsLength={3}
/>,
)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should show success icon when all completed successfully', () => {
render(
<TaskStatusIndicator
{...defaultProps}
isSuccess
successPluginsLength={3}
runningPluginsLength={0}
totalPluginsLength={3}
/>,
)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should show error icon when failed', () => {
render(<TaskStatusIndicator {...defaultProps} isFailed />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
})
describe('Styling', () => {
it('should apply error styles when installing with error', () => {
render(<TaskStatusIndicator {...defaultProps} isInstallingWithError />)
const trigger = document.getElementById('plugin-task-trigger')
expect(trigger).toHaveClass('bg-state-destructive-hover')
})
it('should apply error styles when failed', () => {
render(<TaskStatusIndicator {...defaultProps} isFailed />)
const trigger = document.getElementById('plugin-task-trigger')
expect(trigger).toHaveClass('bg-state-destructive-hover')
})
it('should apply cursor-pointer when clickable', () => {
render(<TaskStatusIndicator {...defaultProps} isInstalling />)
const trigger = document.getElementById('plugin-task-trigger')
expect(trigger).toHaveClass('cursor-pointer')
})
})
describe('User interactions', () => {
it('should call onClick when clicked', () => {
const handleClick = vi.fn()
render(<TaskStatusIndicator {...defaultProps} onClick={handleClick} />)
fireEvent.click(document.getElementById('plugin-task-trigger')!)
expect(handleClick).toHaveBeenCalledTimes(1)
})
})
})
// ============================================================================
// PluginTaskList Component Tests
// ============================================================================
describe('PluginTaskList Component', () => {
const defaultProps = {
runningPlugins: [] as PluginStatus[],
successPlugins: [] as PluginStatus[],
errorPlugins: [] as PluginStatus[],
getIconUrl: (icon: string) => `https://example.com/${icon}`,
onClearAll: vi.fn(),
onClearErrors: vi.fn(),
onClearSingle: vi.fn(),
}
beforeEach(() => {
vi.clearAllMocks()
})
describe('Rendering', () => {
it('should render without crashing with empty lists', () => {
render(<PluginTaskList {...defaultProps} />)
expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument()
})
it('should render running plugins section when plugins exist', () => {
const runningPlugins = [createMockPlugin({ status: TaskStatus.running })]
render(<PluginTaskList {...defaultProps} runningPlugins={runningPlugins} />)
// Translation key is returned as text in tests, multiple matches expected (title + status)
expect(screen.getAllByText(/task\.installing/i).length).toBeGreaterThan(0)
// Verify section container is rendered
expect(document.querySelector('.max-h-\\[200px\\]')).toBeInTheDocument()
})
it('should render success plugins section when plugins exist', () => {
const successPlugins = [createMockPlugin({ status: TaskStatus.success })]
render(<PluginTaskList {...defaultProps} successPlugins={successPlugins} />)
// Translation key is returned as text in tests, multiple matches expected
expect(screen.getAllByText(/task\.installed/i).length).toBeGreaterThan(0)
})
it('should render error plugins section when plugins exist', () => {
const errorPlugins = [createMockPlugin({ status: TaskStatus.failed, message: 'Error occurred' })]
render(<PluginTaskList {...defaultProps} errorPlugins={errorPlugins} />)
expect(screen.getByText('Error occurred')).toBeInTheDocument()
})
it('should render all sections when all types exist', () => {
render(
<PluginTaskList
{...defaultProps}
runningPlugins={[createMockPlugin({ status: TaskStatus.running })]}
successPlugins={[createMockPlugin({ status: TaskStatus.success })]}
errorPlugins={[createMockPlugin({ status: TaskStatus.failed })]}
/>,
)
// All sections should be present
expect(document.querySelectorAll('.max-h-\\[200px\\]').length).toBe(3)
})
})
describe('User interactions', () => {
it('should call onClearAll when clear all button is clicked in success section', () => {
const handleClearAll = vi.fn()
const successPlugins = [createMockPlugin({ status: TaskStatus.success })]
render(
<PluginTaskList
{...defaultProps}
successPlugins={successPlugins}
onClearAll={handleClearAll}
/>,
)
fireEvent.click(screen.getByRole('button', { name: /task\.clearAll/i }))
expect(handleClearAll).toHaveBeenCalledTimes(1)
})
it('should call onClearErrors when clear all button is clicked in error section', () => {
const handleClearErrors = vi.fn()
const errorPlugins = [createMockPlugin({ status: TaskStatus.failed })]
render(
<PluginTaskList
{...defaultProps}
errorPlugins={errorPlugins}
onClearErrors={handleClearErrors}
/>,
)
const clearButtons = screen.getAllByRole('button')
fireEvent.click(clearButtons.find(btn => btn.textContent?.includes('task.clearAll'))!)
expect(handleClearErrors).toHaveBeenCalledTimes(1)
})
it('should call onClearSingle with correct args when individual clear is clicked', () => {
const handleClearSingle = vi.fn()
const errorPlugin = createMockPlugin({
status: TaskStatus.failed,
plugin_unique_identifier: 'error-plugin-1',
taskId: 'task-123',
})
render(
<PluginTaskList
{...defaultProps}
errorPlugins={[errorPlugin]}
onClearSingle={handleClearSingle}
/>,
)
// The individual clear button has the text 'operation.clear'
fireEvent.click(screen.getByRole('button', { name: /operation\.clear/i }))
expect(handleClearSingle).toHaveBeenCalledWith('task-123', 'error-plugin-1')
})
})
describe('Plugin display', () => {
it('should display plugin name from labels', () => {
const plugin = createMockPlugin({
status: TaskStatus.running,
labels: { en_US: 'My Test Plugin' } as Record<string, string>,
})
render(<PluginTaskList {...defaultProps} runningPlugins={[plugin]} />)
expect(screen.getByText('My Test Plugin')).toBeInTheDocument()
})
it('should display plugin message when available', () => {
const plugin = createMockPlugin({
status: TaskStatus.success,
message: 'Successfully installed!',
})
render(<PluginTaskList {...defaultProps} successPlugins={[plugin]} />)
expect(screen.getByText('Successfully installed!')).toBeInTheDocument()
})
it('should display multiple plugins in each section', () => {
const runningPlugins = [
createMockPlugin({ status: TaskStatus.running, labels: { en_US: 'Plugin A' } as Record<string, string> }),
createMockPlugin({ status: TaskStatus.running, labels: { en_US: 'Plugin B' } as Record<string, string> }),
]
render(<PluginTaskList {...defaultProps} runningPlugins={runningPlugins} />)
expect(screen.getByText('Plugin A')).toBeInTheDocument()
expect(screen.getByText('Plugin B')).toBeInTheDocument()
// Count is rendered, verify multiple items are in list
expect(document.querySelectorAll('.hover\\:bg-state-base-hover').length).toBe(2)
})
})
})
// ============================================================================
// PluginTasks Main Component Tests
// ============================================================================
describe('PluginTasks Component', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('Rendering', () => {
it('should return null when no plugins exist', () => {
setupMocks([])
const { container } = render(<PluginTasks />)
expect(container.firstChild).toBeNull()
})
it('should render when plugins exist', () => {
setupMocks([createMockPlugin({ status: TaskStatus.running })])
render(<PluginTasks />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
})
describe('Tooltip text (tip memoization)', () => {
it('should show installing tip when isInstalling', () => {
setupMocks([createMockPlugin({ status: TaskStatus.running })])
render(<PluginTasks />)
// The component renders with a tooltip, we verify it exists
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should show success tip when all succeeded', () => {
setupMocks([createMockPlugin({ status: TaskStatus.success })])
render(<PluginTasks />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should show error tip when some failed', () => {
setupMocks([
createMockPlugin({ status: TaskStatus.success }),
createMockPlugin({ status: TaskStatus.failed }),
])
render(<PluginTasks />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
})
describe('Popover interaction', () => {
it('should toggle popover when trigger is clicked and status allows', () => {
setupMocks([createMockPlugin({ status: TaskStatus.running })])
render(<PluginTasks />)
// Click to open
fireEvent.click(document.getElementById('plugin-task-trigger')!)
// The popover content should be visible (PluginTaskList)
expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument()
})
it('should not toggle when status does not allow', () => {
// Setup with no actionable status (edge case - should not happen in practice)
setupMocks([createMockPlugin({ status: TaskStatus.running })])
render(<PluginTasks />)
// Component should still render
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
})
describe('Clear handlers', () => {
it('should clear all completed plugins when onClearAll is called', async () => {
const { mockMutateAsync } = setupMocks([
createMockPlugin({ status: TaskStatus.success, plugin_unique_identifier: 'success-1' }),
createMockPlugin({ status: TaskStatus.failed, plugin_unique_identifier: 'error-1' }),
])
render(<PluginTasks />)
// Open popover
fireEvent.click(document.getElementById('plugin-task-trigger')!)
// Wait for popover content to render
await waitFor(() => {
expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument()
})
// Find and click clear all button
const clearButtons = screen.getAllByRole('button')
const clearAllButton = clearButtons.find(btn => btn.textContent?.includes('clearAll'))
if (clearAllButton)
fireEvent.click(clearAllButton)
// Verify mutateAsync was called for each completed plugin
await waitFor(() => {
expect(mockMutateAsync).toHaveBeenCalled()
})
})
it('should clear only error plugins when onClearErrors is called', async () => {
const { mockMutateAsync } = setupMocks([
createMockPlugin({ status: TaskStatus.failed, plugin_unique_identifier: 'error-1' }),
])
render(<PluginTasks />)
// Open popover
fireEvent.click(document.getElementById('plugin-task-trigger')!)
await waitFor(() => {
expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument()
})
// Find and click the clear all button in error section
const clearButtons = screen.getAllByRole('button')
if (clearButtons.length > 0)
fireEvent.click(clearButtons[0])
await waitFor(() => {
expect(mockMutateAsync).toHaveBeenCalled()
})
})
it('should clear single plugin when onClearSingle is called', async () => {
const { mockMutateAsync } = setupMocks([
createMockPlugin({
status: TaskStatus.failed,
plugin_unique_identifier: 'error-plugin',
taskId: 'task-1',
}),
])
render(<PluginTasks />)
// Open popover
fireEvent.click(document.getElementById('plugin-task-trigger')!)
await waitFor(() => {
expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument()
})
// Find and click individual clear button (usually the last one)
const clearButtons = screen.getAllByRole('button')
const individualClearButton = clearButtons[clearButtons.length - 1]
fireEvent.click(individualClearButton)
await waitFor(() => {
expect(mockMutateAsync).toHaveBeenCalledWith({
taskId: 'task-1',
pluginId: 'error-plugin',
})
})
})
})
describe('Edge cases', () => {
it('should handle empty plugin tasks array', () => {
setupMocks([])
const { container } = render(<PluginTasks />)
expect(container.firstChild).toBeNull()
})
it('should handle single running plugin', () => {
setupMocks([createMockPlugin({ status: TaskStatus.running })])
render(<PluginTasks />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should handle many plugins', () => {
const manyPlugins = Array.from({ length: 10 }, (_, i) =>
createMockPlugin({
status: i % 3 === 0 ? TaskStatus.running : i % 3 === 1 ? TaskStatus.success : TaskStatus.failed,
plugin_unique_identifier: `plugin-${i}`,
}))
setupMocks(manyPlugins)
render(<PluginTasks />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should handle plugins with empty labels', () => {
const plugin = createMockPlugin({
status: TaskStatus.running,
labels: {} as Record<string, string>,
})
setupMocks([plugin])
render(<PluginTasks />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should handle plugins with long messages', () => {
const plugin = createMockPlugin({
status: TaskStatus.failed,
message: 'A'.repeat(500),
})
setupMocks([plugin])
render(<PluginTasks />)
// Open popover
fireEvent.click(document.getElementById('plugin-task-trigger')!)
expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument()
})
})
})
// ============================================================================
// Integration Tests
// ============================================================================
describe('PluginTasks Integration', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should show correct UI flow from installing to success', async () => {
// Start with installing state
setupMocks([createMockPlugin({ status: TaskStatus.running })])
const { rerender } = render(<PluginTasks />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
// Simulate completion by re-rendering with success
setupMocks([createMockPlugin({ status: TaskStatus.success })])
rerender(<PluginTasks />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should show correct UI flow from installing to failure', async () => {
// Start with installing state
setupMocks([createMockPlugin({ status: TaskStatus.running })])
const { rerender } = render(<PluginTasks />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
// Simulate failure by re-rendering with failed
setupMocks([createMockPlugin({ status: TaskStatus.failed, message: 'Network error' })])
rerender(<PluginTasks />)
expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument()
})
it('should handle mixed status during installation', () => {
setupMocks([
createMockPlugin({ status: TaskStatus.running, plugin_unique_identifier: 'p1' }),
createMockPlugin({ status: TaskStatus.success, plugin_unique_identifier: 'p2' }),
createMockPlugin({ status: TaskStatus.failed, plugin_unique_identifier: 'p3' }),
])
render(<PluginTasks />)
// Open popover
fireEvent.click(document.getElementById('plugin-task-trigger')!)
// All sections should be visible
const sections = document.querySelectorAll('.max-h-\\[200px\\]')
expect(sections.length).toBe(3)
})
})

View File

@ -1,33 +1,21 @@
import {
RiCheckboxCircleFill,
RiErrorWarningFill,
RiInstallLine,
RiLoaderLine,
} from '@remixicon/react'
import {
useCallback,
useMemo,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import ProgressCircle from '@/app/components/base/progress-bar/progress-circle'
import Tooltip from '@/app/components/base/tooltip'
import DownloadingIcon from '@/app/components/header/plugins-nav/downloading-icon'
import CardIcon from '@/app/components/plugins/card/base/card-icon'
import useGetIcon from '@/app/components/plugins/install-plugin/base/use-get-icon'
import { useGetLanguage } from '@/context/i18n'
import { cn } from '@/utils/classnames'
import PluginTaskList from './components/plugin-task-list'
import TaskStatusIndicator from './components/task-status-indicator'
import { usePluginTaskStatus } from './hooks'
const PluginTasks = () => {
const { t } = useTranslation()
const language = useGetLanguage()
const [open, setOpen] = useState(false)
const {
errorPlugins,
@ -46,35 +34,7 @@ const PluginTasks = () => {
} = usePluginTaskStatus()
const { getIconUrl } = useGetIcon()
const handleClearAllWithModal = useCallback(async () => {
// Clear all completed plugins (success and error) but keep running ones
const completedPlugins = [...successPlugins, ...errorPlugins]
// Clear all completed plugins individually
for (const plugin of completedPlugins)
await handleClearErrorPlugin(plugin.taskId, plugin.plugin_unique_identifier)
// Only close modal if no plugins are still installing
if (runningPluginsLength === 0)
setOpen(false)
}, [successPlugins, errorPlugins, handleClearErrorPlugin, runningPluginsLength])
const handleClearErrorsWithModal = useCallback(async () => {
// Clear only error plugins, not all plugins
for (const plugin of errorPlugins)
await handleClearErrorPlugin(plugin.taskId, plugin.plugin_unique_identifier)
// Only close modal if no plugins are still installing
if (runningPluginsLength === 0)
setOpen(false)
}, [errorPlugins, handleClearErrorPlugin, runningPluginsLength])
const handleClearSingleWithModal = useCallback(async (taskId: string, pluginId: string) => {
await handleClearErrorPlugin(taskId, pluginId)
// Only close modal if no plugins are still installing
if (runningPluginsLength === 0)
setOpen(false)
}, [handleClearErrorPlugin, runningPluginsLength])
// Generate tooltip text based on status
const tip = useMemo(() => {
if (isInstallingWithError)
return t('task.installingWithError', { ns: 'plugin', installingLength: runningPluginsLength, successLength: successPluginsLength, errorLength: errorPluginsLength })
@ -99,8 +59,38 @@ const PluginTasks = () => {
t,
])
// Show icon if there are any plugin tasks (completed, running, or failed)
// Only hide when there are absolutely no plugin tasks
// Generic clear function that handles clearing and modal closing
const clearPluginsAndClose = useCallback(async (
plugins: Array<{ taskId: string, plugin_unique_identifier: string }>,
) => {
for (const plugin of plugins)
await handleClearErrorPlugin(plugin.taskId, plugin.plugin_unique_identifier)
if (runningPluginsLength === 0)
setOpen(false)
}, [handleClearErrorPlugin, runningPluginsLength])
// Clear handlers using the generic function
const handleClearAll = useCallback(
() => clearPluginsAndClose([...successPlugins, ...errorPlugins]),
[clearPluginsAndClose, successPlugins, errorPlugins],
)
const handleClearErrors = useCallback(
() => clearPluginsAndClose(errorPlugins),
[clearPluginsAndClose, errorPlugins],
)
const handleClearSingle = useCallback(
(taskId: string, pluginId: string) => clearPluginsAndClose([{ taskId, plugin_unique_identifier: pluginId }]),
[clearPluginsAndClose],
)
const handleTriggerClick = useCallback(() => {
if (isFailed || isInstalling || isInstallingWithSuccess || isInstallingWithError || isSuccess)
setOpen(v => !v)
}, [isFailed, isInstalling, isInstallingWithSuccess, isInstallingWithError, isSuccess])
// Hide when no plugin tasks
if (totalPluginsLength === 0)
return null
@ -115,206 +105,30 @@ const PluginTasks = () => {
crossAxis: 79,
}}
>
<PortalToFollowElemTrigger
onClick={() => {
if (isFailed || isInstalling || isInstallingWithSuccess || isInstallingWithError || isSuccess)
setOpen(v => !v)
}}
>
<Tooltip
popupContent={tip}
asChild
offset={8}
>
<div
className={cn(
'relative flex h-8 w-8 items-center justify-center rounded-lg border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg shadow-xs hover:bg-components-button-secondary-bg-hover',
(isInstallingWithError || isFailed) && 'cursor-pointer border-components-button-destructive-secondary-border-hover bg-state-destructive-hover hover:bg-state-destructive-hover-alt',
(isInstalling || isInstallingWithSuccess || isSuccess) && 'cursor-pointer hover:bg-components-button-secondary-bg-hover',
)}
id="plugin-task-trigger"
>
{
(isInstalling || isInstallingWithError) && (
<DownloadingIcon />
)
}
{
!(isInstalling || isInstallingWithError) && (
<RiInstallLine
className={cn(
'h-4 w-4 text-components-button-secondary-text',
(isInstallingWithError || isFailed) && 'text-components-button-destructive-secondary-text',
)}
/>
)
}
<div className="absolute -right-1 -top-1">
{
(isInstalling || isInstallingWithSuccess) && (
<ProgressCircle
percentage={successPluginsLength / totalPluginsLength * 100}
circleFillColor="fill-components-progress-brand-bg"
/>
)
}
{
isInstallingWithError && (
<ProgressCircle
percentage={runningPluginsLength / totalPluginsLength * 100}
circleFillColor="fill-components-progress-brand-bg"
sectorFillColor="fill-components-progress-error-border"
circleStrokeColor="stroke-components-progress-error-border"
/>
)
}
{
(isSuccess || (successPluginsLength > 0 && runningPluginsLength === 0 && errorPluginsLength === 0)) && (
<RiCheckboxCircleFill className="h-3.5 w-3.5 text-text-success" />
)
}
{
isFailed && (
<RiErrorWarningFill className="h-3.5 w-3.5 text-text-destructive" />
)
}
</div>
</div>
</Tooltip>
<PortalToFollowElemTrigger onClick={handleTriggerClick}>
<TaskStatusIndicator
tip={tip}
isInstalling={isInstalling}
isInstallingWithSuccess={isInstallingWithSuccess}
isInstallingWithError={isInstallingWithError}
isSuccess={isSuccess}
isFailed={isFailed}
successPluginsLength={successPluginsLength}
runningPluginsLength={runningPluginsLength}
totalPluginsLength={totalPluginsLength}
onClick={() => {}}
/>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className="z-[11]">
<div className="w-[360px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg">
{/* Running Plugins */}
{runningPlugins.length > 0 && (
<>
<div className="system-sm-semibold-uppercase sticky top-0 flex h-7 items-center justify-between px-2 pt-1 text-text-secondary">
{t('task.installing', { ns: 'plugin' })}
{' '}
(
{runningPlugins.length}
)
</div>
<div className="max-h-[200px] overflow-y-auto">
{runningPlugins.map(runningPlugin => (
<div
key={runningPlugin.plugin_unique_identifier}
className="flex items-center rounded-lg p-2 hover:bg-state-base-hover"
>
<div className="relative mr-2 flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-dodge">
<RiLoaderLine className="absolute -bottom-0.5 -right-0.5 z-10 h-3 w-3 animate-spin text-text-accent" />
<CardIcon
size="tiny"
src={getIconUrl(runningPlugin.icon)}
/>
</div>
<div className="grow">
<div className="system-md-regular truncate text-text-secondary">
{runningPlugin.labels[language]}
</div>
<div className="system-xs-regular text-text-tertiary">
{t('task.installing', { ns: 'plugin' })}
</div>
</div>
</div>
))}
</div>
</>
)}
{/* Success Plugins */}
{successPlugins.length > 0 && (
<>
<div className="system-sm-semibold-uppercase sticky top-0 flex h-7 items-center justify-between px-2 pt-1 text-text-secondary">
{t('task.installed', { ns: 'plugin' })}
{' '}
(
{successPlugins.length}
)
<Button
className="shrink-0"
size="small"
variant="ghost"
onClick={() => handleClearAllWithModal()}
>
{t('task.clearAll', { ns: 'plugin' })}
</Button>
</div>
<div className="max-h-[200px] overflow-y-auto">
{successPlugins.map(successPlugin => (
<div
key={successPlugin.plugin_unique_identifier}
className="flex items-center rounded-lg p-2 hover:bg-state-base-hover"
>
<div className="relative mr-2 flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-dodge">
<RiCheckboxCircleFill className="absolute -bottom-0.5 -right-0.5 z-10 h-3 w-3 text-text-success" />
<CardIcon
size="tiny"
src={getIconUrl(successPlugin.icon)}
/>
</div>
<div className="grow">
<div className="system-md-regular truncate text-text-secondary">
{successPlugin.labels[language]}
</div>
<div className="system-xs-regular text-text-success">
{successPlugin.message || t('task.installed', { ns: 'plugin' })}
</div>
</div>
</div>
))}
</div>
</>
)}
{/* Error Plugins */}
{errorPlugins.length > 0 && (
<>
<div className="system-sm-semibold-uppercase sticky top-0 flex h-7 items-center justify-between px-2 pt-1 text-text-secondary">
{t('task.installError', { ns: 'plugin', errorLength: errorPlugins.length })}
<Button
className="shrink-0"
size="small"
variant="ghost"
onClick={() => handleClearErrorsWithModal()}
>
{t('task.clearAll', { ns: 'plugin' })}
</Button>
</div>
<div className="max-h-[200px] overflow-y-auto">
{errorPlugins.map(errorPlugin => (
<div
key={errorPlugin.plugin_unique_identifier}
className="flex items-center rounded-lg p-2 hover:bg-state-base-hover"
>
<div className="relative mr-2 flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-dodge">
<RiErrorWarningFill className="absolute -bottom-0.5 -right-0.5 z-10 h-3 w-3 text-text-destructive" />
<CardIcon
size="tiny"
src={getIconUrl(errorPlugin.icon)}
/>
</div>
<div className="grow">
<div className="system-md-regular truncate text-text-secondary">
{errorPlugin.labels[language]}
</div>
<div className="system-xs-regular break-all text-text-destructive">
{errorPlugin.message}
</div>
</div>
<Button
className="shrink-0"
size="small"
variant="ghost"
onClick={() => handleClearSingleWithModal(errorPlugin.taskId, errorPlugin.plugin_unique_identifier)}
>
{t('operation.clear', { ns: 'common' })}
</Button>
</div>
))}
</div>
</>
)}
</div>
<PluginTaskList
runningPlugins={runningPlugins}
successPlugins={successPlugins}
errorPlugins={errorPlugins}
getIconUrl={getIconUrl}
onClearAll={handleClearAll}
onClearErrors={handleClearErrors}
onClearSingle={handleClearSingle}
/>
</PortalToFollowElemContent>
</PortalToFollowElem>
</div>

View File

@ -0,0 +1,388 @@
import { renderHook, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Import mocks for assertions
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useInvalidateReferenceSettings, useMutationReferenceSettings, useReferenceSettings } from '@/service/use-plugins'
import Toast from '../../base/toast'
import { PermissionType } from '../types'
import useReferenceSetting, { useCanInstallPluginFromMarketplace } from './use-reference-setting'
// Mock dependencies
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
vi.mock('@/context/app-context', () => ({
useAppContext: vi.fn(),
}))
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: vi.fn(),
}))
vi.mock('@/service/use-plugins', () => ({
useReferenceSettings: vi.fn(),
useMutationReferenceSettings: vi.fn(),
useInvalidateReferenceSettings: vi.fn(),
}))
vi.mock('../../base/toast', () => ({
default: {
notify: vi.fn(),
},
}))
describe('useReferenceSetting Hook', () => {
beforeEach(() => {
vi.clearAllMocks()
// Default mocks
vi.mocked(useAppContext).mockReturnValue({
isCurrentWorkspaceManager: false,
isCurrentWorkspaceOwner: false,
} as ReturnType<typeof useAppContext>)
vi.mocked(useReferenceSettings).mockReturnValue({
data: {
permission: {
install_permission: PermissionType.everyone,
debug_permission: PermissionType.everyone,
},
},
} as ReturnType<typeof useReferenceSettings>)
vi.mocked(useMutationReferenceSettings).mockReturnValue({
mutate: vi.fn(),
isPending: false,
} as unknown as ReturnType<typeof useMutationReferenceSettings>)
vi.mocked(useInvalidateReferenceSettings).mockReturnValue(vi.fn())
})
describe('hasPermission logic', () => {
it('should return false when permission is undefined', () => {
vi.mocked(useReferenceSettings).mockReturnValue({
data: {
permission: {
install_permission: undefined,
debug_permission: undefined,
},
},
} as unknown as ReturnType<typeof useReferenceSettings>)
const { result } = renderHook(() => useReferenceSetting())
expect(result.current.canManagement).toBe(false)
expect(result.current.canDebugger).toBe(false)
})
it('should return false when permission is noOne', () => {
vi.mocked(useReferenceSettings).mockReturnValue({
data: {
permission: {
install_permission: PermissionType.noOne,
debug_permission: PermissionType.noOne,
},
},
} as ReturnType<typeof useReferenceSettings>)
const { result } = renderHook(() => useReferenceSetting())
expect(result.current.canManagement).toBe(false)
expect(result.current.canDebugger).toBe(false)
})
it('should return true when permission is everyone', () => {
vi.mocked(useReferenceSettings).mockReturnValue({
data: {
permission: {
install_permission: PermissionType.everyone,
debug_permission: PermissionType.everyone,
},
},
} as ReturnType<typeof useReferenceSettings>)
const { result } = renderHook(() => useReferenceSetting())
expect(result.current.canManagement).toBe(true)
expect(result.current.canDebugger).toBe(true)
})
it('should return isAdmin when permission is admin and user is manager', () => {
vi.mocked(useAppContext).mockReturnValue({
isCurrentWorkspaceManager: true,
isCurrentWorkspaceOwner: false,
} as ReturnType<typeof useAppContext>)
vi.mocked(useReferenceSettings).mockReturnValue({
data: {
permission: {
install_permission: PermissionType.admin,
debug_permission: PermissionType.admin,
},
},
} as ReturnType<typeof useReferenceSettings>)
const { result } = renderHook(() => useReferenceSetting())
expect(result.current.canManagement).toBe(true)
expect(result.current.canDebugger).toBe(true)
})
it('should return isAdmin when permission is admin and user is owner', () => {
vi.mocked(useAppContext).mockReturnValue({
isCurrentWorkspaceManager: false,
isCurrentWorkspaceOwner: true,
} as ReturnType<typeof useAppContext>)
vi.mocked(useReferenceSettings).mockReturnValue({
data: {
permission: {
install_permission: PermissionType.admin,
debug_permission: PermissionType.admin,
},
},
} as ReturnType<typeof useReferenceSettings>)
const { result } = renderHook(() => useReferenceSetting())
expect(result.current.canManagement).toBe(true)
expect(result.current.canDebugger).toBe(true)
})
it('should return false when permission is admin and user is not admin', () => {
vi.mocked(useAppContext).mockReturnValue({
isCurrentWorkspaceManager: false,
isCurrentWorkspaceOwner: false,
} as ReturnType<typeof useAppContext>)
vi.mocked(useReferenceSettings).mockReturnValue({
data: {
permission: {
install_permission: PermissionType.admin,
debug_permission: PermissionType.admin,
},
},
} as ReturnType<typeof useReferenceSettings>)
const { result } = renderHook(() => useReferenceSetting())
expect(result.current.canManagement).toBe(false)
expect(result.current.canDebugger).toBe(false)
})
})
describe('canSetPermissions', () => {
it('should be true when user is workspace manager', () => {
vi.mocked(useAppContext).mockReturnValue({
isCurrentWorkspaceManager: true,
isCurrentWorkspaceOwner: false,
} as ReturnType<typeof useAppContext>)
const { result } = renderHook(() => useReferenceSetting())
expect(result.current.canSetPermissions).toBe(true)
})
it('should be true when user is workspace owner', () => {
vi.mocked(useAppContext).mockReturnValue({
isCurrentWorkspaceManager: false,
isCurrentWorkspaceOwner: true,
} as ReturnType<typeof useAppContext>)
const { result } = renderHook(() => useReferenceSetting())
expect(result.current.canSetPermissions).toBe(true)
})
it('should be false when user is neither manager nor owner', () => {
vi.mocked(useAppContext).mockReturnValue({
isCurrentWorkspaceManager: false,
isCurrentWorkspaceOwner: false,
} as ReturnType<typeof useAppContext>)
const { result } = renderHook(() => useReferenceSetting())
expect(result.current.canSetPermissions).toBe(false)
})
})
describe('setReferenceSettings callback', () => {
it('should call invalidateReferenceSettings and show toast on success', async () => {
const mockInvalidate = vi.fn()
vi.mocked(useInvalidateReferenceSettings).mockReturnValue(mockInvalidate)
let onSuccessCallback: (() => void) | undefined
vi.mocked(useMutationReferenceSettings).mockImplementation((options) => {
onSuccessCallback = options?.onSuccess as () => void
return {
mutate: vi.fn(),
isPending: false,
} as unknown as ReturnType<typeof useMutationReferenceSettings>
})
renderHook(() => useReferenceSetting())
// Trigger the onSuccess callback
if (onSuccessCallback)
onSuccessCallback()
await waitFor(() => {
expect(mockInvalidate).toHaveBeenCalled()
expect(Toast.notify).toHaveBeenCalledWith({
type: 'success',
message: 'api.actionSuccess',
})
})
})
})
describe('returned values', () => {
it('should return referenceSetting data', () => {
const mockData = {
permission: {
install_permission: PermissionType.everyone,
debug_permission: PermissionType.everyone,
},
}
vi.mocked(useReferenceSettings).mockReturnValue({
data: mockData,
} as ReturnType<typeof useReferenceSettings>)
const { result } = renderHook(() => useReferenceSetting())
expect(result.current.referenceSetting).toEqual(mockData)
})
it('should return isUpdatePending from mutation', () => {
vi.mocked(useMutationReferenceSettings).mockReturnValue({
mutate: vi.fn(),
isPending: true,
} as unknown as ReturnType<typeof useMutationReferenceSettings>)
const { result } = renderHook(() => useReferenceSetting())
expect(result.current.isUpdatePending).toBe(true)
})
it('should handle null data', () => {
vi.mocked(useReferenceSettings).mockReturnValue({
data: null,
} as unknown as ReturnType<typeof useReferenceSettings>)
const { result } = renderHook(() => useReferenceSetting())
expect(result.current.canManagement).toBe(false)
expect(result.current.canDebugger).toBe(false)
})
})
})
describe('useCanInstallPluginFromMarketplace Hook', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.mocked(useAppContext).mockReturnValue({
isCurrentWorkspaceManager: true,
isCurrentWorkspaceOwner: false,
} as ReturnType<typeof useAppContext>)
vi.mocked(useReferenceSettings).mockReturnValue({
data: {
permission: {
install_permission: PermissionType.everyone,
debug_permission: PermissionType.everyone,
},
},
} as ReturnType<typeof useReferenceSettings>)
vi.mocked(useMutationReferenceSettings).mockReturnValue({
mutate: vi.fn(),
isPending: false,
} as unknown as ReturnType<typeof useMutationReferenceSettings>)
vi.mocked(useInvalidateReferenceSettings).mockReturnValue(vi.fn())
})
it('should return true when marketplace is enabled and canManagement is true', () => {
vi.mocked(useGlobalPublicStore).mockImplementation((selector) => {
const state = {
systemFeatures: {
enable_marketplace: true,
},
}
return selector(state as Parameters<typeof selector>[0])
})
const { result } = renderHook(() => useCanInstallPluginFromMarketplace())
expect(result.current.canInstallPluginFromMarketplace).toBe(true)
})
it('should return false when marketplace is disabled', () => {
vi.mocked(useGlobalPublicStore).mockImplementation((selector) => {
const state = {
systemFeatures: {
enable_marketplace: false,
},
}
return selector(state as Parameters<typeof selector>[0])
})
const { result } = renderHook(() => useCanInstallPluginFromMarketplace())
expect(result.current.canInstallPluginFromMarketplace).toBe(false)
})
it('should return false when canManagement is false', () => {
vi.mocked(useGlobalPublicStore).mockImplementation((selector) => {
const state = {
systemFeatures: {
enable_marketplace: true,
},
}
return selector(state as Parameters<typeof selector>[0])
})
vi.mocked(useReferenceSettings).mockReturnValue({
data: {
permission: {
install_permission: PermissionType.noOne,
debug_permission: PermissionType.noOne,
},
},
} as ReturnType<typeof useReferenceSettings>)
const { result } = renderHook(() => useCanInstallPluginFromMarketplace())
expect(result.current.canInstallPluginFromMarketplace).toBe(false)
})
it('should return false when both marketplace is disabled and canManagement is false', () => {
vi.mocked(useGlobalPublicStore).mockImplementation((selector) => {
const state = {
systemFeatures: {
enable_marketplace: false,
},
}
return selector(state as Parameters<typeof selector>[0])
})
vi.mocked(useReferenceSettings).mockReturnValue({
data: {
permission: {
install_permission: PermissionType.noOne,
debug_permission: PermissionType.noOne,
},
},
} as ReturnType<typeof useReferenceSettings>)
const { result } = renderHook(() => useCanInstallPluginFromMarketplace())
expect(result.current.canInstallPluginFromMarketplace).toBe(false)
})
})

View File

@ -0,0 +1,487 @@
import type { RefObject } from 'react'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useUploader } from './use-uploader'
describe('useUploader Hook', () => {
let mockContainerRef: RefObject<HTMLDivElement | null>
let mockOnFileChange: (file: File | null) => void
let mockContainer: HTMLDivElement
beforeEach(() => {
vi.clearAllMocks()
mockContainer = document.createElement('div')
document.body.appendChild(mockContainer)
mockContainerRef = { current: mockContainer }
mockOnFileChange = vi.fn()
})
afterEach(() => {
if (mockContainer.parentNode)
document.body.removeChild(mockContainer)
})
describe('Initial State', () => {
it('should return initial state with dragging false', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
expect(result.current.dragging).toBe(false)
expect(result.current.fileUploader.current).toBeNull()
expect(result.current.fileChangeHandle).not.toBeNull()
expect(result.current.removeFile).not.toBeNull()
})
it('should return null handlers when disabled', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
enabled: false,
}),
)
expect(result.current.dragging).toBe(false)
expect(result.current.fileChangeHandle).toBeNull()
expect(result.current.removeFile).toBeNull()
})
})
describe('Drag Events', () => {
it('should handle dragenter and set dragging to true', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true })
Object.defineProperty(dragEnterEvent, 'dataTransfer', {
value: { types: ['Files'] },
})
act(() => {
mockContainer.dispatchEvent(dragEnterEvent)
})
expect(result.current.dragging).toBe(true)
})
it('should not set dragging when dragenter without Files type', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true })
Object.defineProperty(dragEnterEvent, 'dataTransfer', {
value: { types: ['text/plain'] },
})
act(() => {
mockContainer.dispatchEvent(dragEnterEvent)
})
expect(result.current.dragging).toBe(false)
})
it('should handle dragover event', () => {
renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
const dragOverEvent = new Event('dragover', { bubbles: true, cancelable: true })
act(() => {
mockContainer.dispatchEvent(dragOverEvent)
})
// dragover should prevent default and stop propagation
expect(mockContainer).toBeInTheDocument()
})
it('should handle dragleave when relatedTarget is null', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
// First set dragging to true
const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true })
Object.defineProperty(dragEnterEvent, 'dataTransfer', {
value: { types: ['Files'] },
})
act(() => {
mockContainer.dispatchEvent(dragEnterEvent)
})
expect(result.current.dragging).toBe(true)
// Then trigger dragleave with null relatedTarget
const dragLeaveEvent = new Event('dragleave', { bubbles: true, cancelable: true })
Object.defineProperty(dragLeaveEvent, 'relatedTarget', {
value: null,
})
act(() => {
mockContainer.dispatchEvent(dragLeaveEvent)
})
expect(result.current.dragging).toBe(false)
})
it('should handle dragleave when relatedTarget is outside container', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
// First set dragging to true
const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true })
Object.defineProperty(dragEnterEvent, 'dataTransfer', {
value: { types: ['Files'] },
})
act(() => {
mockContainer.dispatchEvent(dragEnterEvent)
})
expect(result.current.dragging).toBe(true)
// Create element outside container
const outsideElement = document.createElement('div')
document.body.appendChild(outsideElement)
// Trigger dragleave with relatedTarget outside container
const dragLeaveEvent = new Event('dragleave', { bubbles: true, cancelable: true })
Object.defineProperty(dragLeaveEvent, 'relatedTarget', {
value: outsideElement,
})
act(() => {
mockContainer.dispatchEvent(dragLeaveEvent)
})
expect(result.current.dragging).toBe(false)
document.body.removeChild(outsideElement)
})
it('should not set dragging to false when relatedTarget is inside container', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
// First set dragging to true
const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true })
Object.defineProperty(dragEnterEvent, 'dataTransfer', {
value: { types: ['Files'] },
})
act(() => {
mockContainer.dispatchEvent(dragEnterEvent)
})
expect(result.current.dragging).toBe(true)
// Create element inside container
const insideElement = document.createElement('div')
mockContainer.appendChild(insideElement)
// Trigger dragleave with relatedTarget inside container
const dragLeaveEvent = new Event('dragleave', { bubbles: true, cancelable: true })
Object.defineProperty(dragLeaveEvent, 'relatedTarget', {
value: insideElement,
})
act(() => {
mockContainer.dispatchEvent(dragLeaveEvent)
})
// Should still be dragging since relatedTarget is inside container
expect(result.current.dragging).toBe(true)
})
})
describe('Drop Events', () => {
it('should handle drop event with files', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
// First set dragging to true
const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true })
Object.defineProperty(dragEnterEvent, 'dataTransfer', {
value: { types: ['Files'] },
})
act(() => {
mockContainer.dispatchEvent(dragEnterEvent)
})
// Create mock file
const file = new File(['content'], 'test.difypkg', { type: 'application/octet-stream' })
// Trigger drop event
const dropEvent = new Event('drop', { bubbles: true, cancelable: true })
Object.defineProperty(dropEvent, 'dataTransfer', {
value: { files: [file] },
})
act(() => {
mockContainer.dispatchEvent(dropEvent)
})
expect(result.current.dragging).toBe(false)
expect(mockOnFileChange).toHaveBeenCalledWith(file)
})
it('should not call onFileChange when drop has no dataTransfer', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
// Set dragging first
const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true })
Object.defineProperty(dragEnterEvent, 'dataTransfer', {
value: { types: ['Files'] },
})
act(() => {
mockContainer.dispatchEvent(dragEnterEvent)
})
// Drop without dataTransfer
const dropEvent = new Event('drop', { bubbles: true, cancelable: true })
// No dataTransfer property
act(() => {
mockContainer.dispatchEvent(dropEvent)
})
expect(result.current.dragging).toBe(false)
expect(mockOnFileChange).not.toHaveBeenCalled()
})
it('should not call onFileChange when drop has empty files array', () => {
renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
const dropEvent = new Event('drop', { bubbles: true, cancelable: true })
Object.defineProperty(dropEvent, 'dataTransfer', {
value: { files: [] },
})
act(() => {
mockContainer.dispatchEvent(dropEvent)
})
expect(mockOnFileChange).not.toHaveBeenCalled()
})
})
describe('File Change Handler', () => {
it('should call onFileChange with file from input', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
const file = new File(['content'], 'test.difypkg', { type: 'application/octet-stream' })
const mockEvent = {
target: {
files: [file],
},
} as unknown as React.ChangeEvent<HTMLInputElement>
act(() => {
result.current.fileChangeHandle?.(mockEvent)
})
expect(mockOnFileChange).toHaveBeenCalledWith(file)
})
it('should call onFileChange with null when no files', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
const mockEvent = {
target: {
files: null,
},
} as unknown as React.ChangeEvent<HTMLInputElement>
act(() => {
result.current.fileChangeHandle?.(mockEvent)
})
expect(mockOnFileChange).toHaveBeenCalledWith(null)
})
})
describe('Remove File', () => {
it('should call onFileChange with null', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
act(() => {
result.current.removeFile?.()
})
expect(mockOnFileChange).toHaveBeenCalledWith(null)
})
it('should handle removeFile when fileUploader has a value', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
// Create a mock input element with value property
const mockInput = {
value: 'test.difypkg',
}
// Override the fileUploader ref
Object.defineProperty(result.current.fileUploader, 'current', {
value: mockInput,
writable: true,
})
act(() => {
result.current.removeFile?.()
})
expect(mockOnFileChange).toHaveBeenCalledWith(null)
expect(mockInput.value).toBe('')
})
it('should handle removeFile when fileUploader is null', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
}),
)
// fileUploader.current is null by default
act(() => {
result.current.removeFile?.()
})
expect(mockOnFileChange).toHaveBeenCalledWith(null)
})
})
describe('Enabled/Disabled State', () => {
it('should not add event listeners when disabled', () => {
const addEventListenerSpy = vi.spyOn(mockContainer, 'addEventListener')
renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
enabled: false,
}),
)
expect(addEventListenerSpy).not.toHaveBeenCalled()
})
it('should add event listeners when enabled', () => {
const addEventListenerSpy = vi.spyOn(mockContainer, 'addEventListener')
renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
enabled: true,
}),
)
expect(addEventListenerSpy).toHaveBeenCalledWith('dragenter', expect.any(Function))
expect(addEventListenerSpy).toHaveBeenCalledWith('dragover', expect.any(Function))
expect(addEventListenerSpy).toHaveBeenCalledWith('dragleave', expect.any(Function))
expect(addEventListenerSpy).toHaveBeenCalledWith('drop', expect.any(Function))
})
it('should remove event listeners on cleanup', () => {
const removeEventListenerSpy = vi.spyOn(mockContainer, 'removeEventListener')
const { unmount } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
enabled: true,
}),
)
unmount()
expect(removeEventListenerSpy).toHaveBeenCalledWith('dragenter', expect.any(Function))
expect(removeEventListenerSpy).toHaveBeenCalledWith('dragover', expect.any(Function))
expect(removeEventListenerSpy).toHaveBeenCalledWith('dragleave', expect.any(Function))
expect(removeEventListenerSpy).toHaveBeenCalledWith('drop', expect.any(Function))
})
it('should return false for dragging when disabled', () => {
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: mockContainerRef,
enabled: false,
}),
)
expect(result.current.dragging).toBe(false)
})
})
describe('Container Ref Edge Cases', () => {
it('should handle null containerRef.current', () => {
const nullRef: RefObject<HTMLDivElement | null> = { current: null }
const { result } = renderHook(() =>
useUploader({
onFileChange: mockOnFileChange,
containerRef: nullRef,
}),
)
expect(result.current.dragging).toBe(false)
})
})
})

Some files were not shown because too many files have changed in this diff Show More