Merge branch 'main' into feat/pull-a-variable

This commit is contained in:
zhsama
2026-01-15 16:14:15 +08:00
141 changed files with 17063 additions and 4450 deletions

.agent/skills Symbolic link
View File

@ -0,0 +1 @@
../.claude/skills

View File

@ -0,0 +1,46 @@
---
name: orpc-contract-first
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories.
---
# oRPC Contract-First Development
## Project Structure
```
web/contract/
├── base.ts # Base contract (inputStructure: 'detailed')
├── router.ts # Router composition & type exports
├── marketplace.ts # Marketplace contracts
└── console/ # Console contracts by domain
├── system.ts
└── billing.ts
```
## Workflow
1. **Create contract** in `web/contract/console/{domain}.ts`
- Import `base` from `../base` and `type` from `@orpc/contract`
- Define route with `path`, `method`, `input`, `output`
2. **Register in router** at `web/contract/router.ts`
- Import directly from domain file (no barrel files)
- Nest by API prefix: `billing: { invoices, bindPartnerStack }`
3. **Create hooks** in `web/service/use-{domain}.ts`
- Use `consoleQuery.{group}.{contract}.queryKey()` for query keys
- Use `consoleClient.{group}.{contract}()` for API calls
## Key Rules
- **Input structure**: Always use `{ params, query?, body? }` format
- **Path params**: Use `{paramName}` in path, match in `params` object
- **Router nesting**: Group by API prefix (e.g., `/billing/*` → `billing: {}`)
- **No barrel files**: Import directly from specific files
- **Types**: Import from `@/types/`, use `type<T>()` helper
## Type Export
```typescript
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
```

View File

@ -90,7 +90,7 @@ jobs:
uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

View File

@ -16,10 +16,6 @@ jobs:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
strategy:
matrix:
node-version: [16, 18, 20, 22]
defaults:
run:
working-directory: sdks/nodejs-client
@ -29,10 +25,10 @@ jobs:
with:
persist-credentials: false
- name: Use Node.js ${{ matrix.node-version }}
- name: Use Node.js
uses: actions/setup-node@v6
with:
node-version: ${{ matrix.node-version }}
node-version: 24
cache: ''
cache-dependency-path: 'pnpm-lock.yaml'

View File

@ -57,7 +57,7 @@ jobs:
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: 'lts/*'
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

View File

@ -31,7 +31,7 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

View File

@ -417,6 +417,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
# Optional: override the local hostname used for SMTP HELO/EHLO
SMTP_LOCAL_HOSTNAME=
# Sendgrid configuration
SENDGRID_API_KEY=
# Sentry configuration
@ -713,3 +715,4 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30

View File

@ -3,6 +3,7 @@ import datetime
import json
import logging
import secrets
import time
from typing import Any
import click
@ -46,6 +47,8 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
from services.retention.conversation.messages_clean_policy import create_message_clean_policy
from services.retention.conversation.messages_clean_service import MessagesCleanService
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
@ -2172,3 +2175,79 @@ def migrate_oss(
except Exception as e:
db.session.rollback()
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
@click.command("clean-expired-messages", help="Clean expired messages.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
"--graceful-period",
default=21,
show_default=True,
help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
)
@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
def clean_expired_messages(
batch_size: int,
graceful_period: int,
start_from: datetime.datetime,
end_before: datetime.datetime,
dry_run: bool,
):
"""
Clean expired messages and related data for tenants based on clean policy.
"""
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
# Create policy based on billing configuration
# NOTE: graceful_period will be ignored when billing is disabled.
policy = create_message_clean_policy(graceful_period_days=graceful_period)
# Create and run the cleanup service
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
stats = service.run()
end_at = time.perf_counter()
click.echo(
click.style(
f"clean_messages: completed successfully\n"
f" - Latency: {end_at - start_at:.2f}s\n"
f" - Batches processed: {stats['batches']}\n"
f" - Total messages scanned: {stats['total_messages']}\n"
f" - Messages filtered: {stats['filtered_messages']}\n"
f" - Messages deleted: {stats['total_deleted']}",
fg="green",
)
)
except Exception as e:
end_at = time.perf_counter()
logger.exception("clean_messages failed")
click.echo(
click.style(
f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
fg="red",
)
)
raise
click.echo(click.style("messages cleanup completed.", fg="green"))

View File

@ -949,6 +949,12 @@ class MailConfig(BaseSettings):
default=False,
)
SMTP_LOCAL_HOSTNAME: str | None = Field(
description="Override the local hostname used in SMTP HELO/EHLO. "
"Useful behind NAT or when the default hostname causes rejections.",
default=None,
)
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
default=50,
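How the new `SMTP_LOCAL_HOSTNAME` setting reaches the SMTP client is not shown in this diff; as a minimal sketch, `smtplib` itself exposes the override via its `local_hostname` argument (the helper function below is hypothetical):

```python
import smtplib

def open_smtp_session(host: str, port: int, local_hostname: str | None) -> None:
    # Hypothetical helper: when local_hostname is None, smtplib announces the
    # machine's own FQDN in HELO/EHLO, which strict servers behind NAT may reject.
    with smtplib.SMTP(host, port, local_hostname=local_hostname) as client:
        client.noop()  # simple round trip to confirm the session is accepted
```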

View File

@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseSettings):
"""
Configuration settings for Volcengine Tinder Object Storage (TOS)
Configuration settings for Volcengine Torch Object Storage (TOS)
"""
VOLCENGINE_TOS_BUCKET_NAME: str | None = Field(

View File

@ -592,9 +592,12 @@ def _get_conversation(app_model, conversation_id):
if not conversation:
raise NotFound("Conversation Not Exists.")
if not conversation.read_at:
conversation.read_at = naive_utc_now()
conversation.read_account_id = current_user.id
db.session.commit()
db.session.execute(
sa.update(Conversation)
.where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
.values(read_at=naive_utc_now(), read_account_id=current_user.id)
)
db.session.commit()
db.session.refresh(conversation)
return conversation

View File

@ -7,7 +7,7 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
@ -104,6 +104,15 @@ class DocumentRenamePayload(BaseModel):
name: str
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")
search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.")
sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.")
status: str | None = Field(None, title="Status", description="Document status.")
fetch_val: str = Field("false", alias="fetch")
register_schema_models(
console_ns,
KnowledgeConfig,
@ -225,14 +234,16 @@ class DatasetDocumentListApi(Resource):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
sort = request.args.get("sort", default="-created_at", type=str)
status = request.args.get("status", default=None, type=str)
raw_args = request.args.to_dict()
param = DocumentDatasetListParam.model_validate(raw_args)
page = param.page
limit = param.limit
search = param.search
sort = param.sort_by
status = param.status
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
fetch_val = request.args.get("fetch", default="false")
fetch_val = param.fetch_val
if isinstance(fetch_val, bool):
fetch = fetch_val
else:
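For reference, a small sketch of how the alias-based validation above behaves (argument values are illustrative):

```python
# Hypothetical query-string values as produced by request.args.to_dict()
raw_args = {"page": "2", "keyword": "invoice", "sort": "created_at", "fetch": "true"}
param = DocumentDatasetListParam.model_validate(raw_args)

assert param.page == 2                 # numeric string coerced to int
assert param.search == "invoice"       # populated through the "keyword" alias
assert param.sort_by == "created_at"   # populated through the "sort" alias
assert param.fetch_val == "true"       # still a string; converted to bool later in get()
```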

View File

@ -81,7 +81,7 @@ class ExternalKnowledgeApiPayload(BaseModel):
class ExternalDatasetCreatePayload(BaseModel):
external_knowledge_api_id: str
external_knowledge_id: str
name: str = Field(..., min_length=1, max_length=40)
name: str = Field(..., min_length=1, max_length=100)
description: str | None = Field(None, max_length=400)
external_retrieval_model: dict[str, object] | None = None

View File

@ -33,6 +33,10 @@ class MaxRetriesExceededError(ValueError):
pass
request_error = httpx.RequestError
max_retries_exceeded_error = MaxRetriesExceededError
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
return {
"http://": httpx.HTTPTransport(

View File

@ -55,7 +55,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
@ -275,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,

View File

@ -1,5 +1,6 @@
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
from core.plugin.impl.exc import PluginDaemonInternalServerError
class PluginEndpointClient(BasePluginClient):
@ -70,18 +71,27 @@ class PluginEndpointClient(BasePluginClient):
def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
Delete the given endpoint.
This operation is idempotent: if the endpoint is already deleted (record not found),
it will return True instead of raising an error.
"""
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/endpoint/remove",
bool,
data={
"endpoint_id": endpoint_id,
},
headers={
"Content-Type": "application/json",
},
)
try:
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/endpoint/remove",
bool,
data={
"endpoint_id": endpoint_id,
},
headers={
"Content-Type": "application/json",
},
)
except PluginDaemonInternalServerError as e:
# Make delete idempotent: if record is not found, consider it a success
if "record not found" in str(e.description).lower():
return True
raise
def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""

View File

@ -17,6 +17,7 @@ from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
from ..protocols import FileManagerProtocol, HttpClientProtocol
from .entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
@ -78,6 +79,8 @@ class Executor:
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
http_client: HttpClientProtocol = ssrf_proxy,
file_manager: FileManagerProtocol = file_manager,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@ -104,6 +107,8 @@ class Executor:
self.data = None
self.json = None
self.max_retries = max_retries
self._http_client = http_client
self._file_manager = file_manager
# init template
self.variable_pool = variable_pool
@ -200,7 +205,7 @@ class Executor:
if file_variable is None:
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
file = file_variable.value
self.content = file_manager.download(file)
self.content = self._file_manager.download(file)
case "x-www-form-urlencoded":
form_data = {
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
@ -239,7 +244,7 @@ class Executor:
):
file_tuple = (
file.filename,
file_manager.download(file),
self._file_manager.download(file),
file.mime_type or "application/octet-stream",
)
if key not in files:
@ -332,19 +337,18 @@ class Executor:
do http request depending on api bundle
"""
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
"post": ssrf_proxy.post,
"put": ssrf_proxy.put,
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
"get": self._http_client.get,
"head": self._http_client.head,
"post": self._http_client.post,
"put": self._http_client.put,
"delete": self._http_client.delete,
"patch": self._http_client.patch,
}
method_lc = self.method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
"url": self.url,
"data": self.data,
"files": self.files,
"json": self.json,
@ -357,8 +361,12 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
response: httpx.Response = _METHOD_MAP[method_lc](
url=self.url,
**request_args,
max_retries=self.max_retries,
)
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this may be an httpx typing issue
return response

View File

@ -1,10 +1,11 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.file import File, FileTransferMethod
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@ -13,6 +14,7 @@ from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from factories import file_factory
from .entities import (
@ -30,10 +32,35 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
http_client: HttpClientProtocol = ssrf_proxy,
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
file_manager: FileManagerProtocol = file_manager,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._http_client = http_client
self._tool_file_manager_factory = tool_file_manager_factory
self._file_manager = file_manager
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@ -71,6 +98,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
http_client=self._http_client,
file_manager=self._file_manager,
)
process_data["request"] = http_executor.to_log()
@ -199,7 +228,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
mime_type = (
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
)
tool_file_manager = ToolFileManager()
tool_file_manager = self._tool_file_manager_factory()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,

View File

@ -1,16 +1,21 @@
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
from core.file import file_manager
from core.helper import ssrf_proxy
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
@ -43,6 +48,9 @@ class DifyNodeFactory(NodeFactory):
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
http_request_http_client: HttpClientProtocol = ssrf_proxy,
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
http_request_file_manager: FileManagerProtocol = file_manager,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@ -61,6 +69,9 @@ class DifyNodeFactory(NodeFactory):
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
self._http_request_http_client = http_request_http_client
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@ -113,6 +124,7 @@ class DifyNodeFactory(NodeFactory):
code_providers=self._code_providers,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
@ -122,6 +134,17 @@ class DifyNodeFactory(NodeFactory):
template_renderer=self._template_renderer,
)
if node_type == NodeType.HTTP_REQUEST:
return HttpRequestNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
http_client=self._http_request_http_client,
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
file_manager=self._http_request_file_manager,
)
return node_class(
id=node_id,
config=node_config,

View File

@ -0,0 +1,29 @@
from typing import Protocol
import httpx
from core.file import File
class HttpClientProtocol(Protocol):
@property
def max_retries_exceeded_error(self) -> type[Exception]: ...
@property
def request_error(self) -> type[Exception]: ...
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
class FileManagerProtocol(Protocol):
def download(self, f: File, /) -> bytes: ...
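Because the HTTP request node now depends on these protocols rather than on `ssrf_proxy` and `file_manager` directly, tests can supply lightweight fakes. A minimal sketch, where `FakeHttpClient` is hypothetical and not part of this change:

```python
import httpx

class FakeHttpClient:
    """Hypothetical test double that structurally satisfies HttpClientProtocol."""

    max_retries_exceeded_error: type[Exception] = RuntimeError
    request_error: type[Exception] = httpx.RequestError

    def _respond(self, method: str, url: str) -> httpx.Response:
        # Canned response so Executor logic can run without touching the network.
        return httpx.Response(200, json={"method": method, "url": url}, request=httpx.Request(method, url))

    def get(self, url: str, max_retries: int = 0, **kwargs: object) -> httpx.Response:
        return self._respond("GET", url)

    def head(self, url: str, max_retries: int = 0, **kwargs: object) -> httpx.Response:
        return self._respond("HEAD", url)

    def post(self, url: str, max_retries: int = 0, **kwargs: object) -> httpx.Response:
        return self._respond("POST", url)

    def put(self, url: str, max_retries: int = 0, **kwargs: object) -> httpx.Response:
        return self._respond("PUT", url)

    def delete(self, url: str, max_retries: int = 0, **kwargs: object) -> httpx.Response:
        return self._respond("DELETE", url)

    def patch(self, url: str, max_retries: int = 0, **kwargs: object) -> httpx.Response:
        return self._respond("PATCH", url)
```

An instance of such a fake could then be passed as `http_client` to `HttpRequestNode`, or to `DifyNodeFactory(http_request_http_client=...)`.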

View File

@ -189,8 +189,7 @@ class WorkflowEntry:
)
try:
# run node
generator = node.run()
generator = cls._traced_node_run(node)
except Exception as e:
logger.exception(
"error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s",
@ -323,8 +322,7 @@ class WorkflowEntry:
tenant_id=tenant_id,
)
# run node
generator = node.run()
generator = cls._traced_node_run(node)
return node, generator
except Exception as e:
@ -430,3 +428,26 @@ class WorkflowEntry:
input_value = current_variable.value | input_value
variable_pool.add([variable_node_id] + variable_key_list, input_value)
@staticmethod
def _traced_node_run(node: Node) -> Generator[GraphNodeEventBase, None, None]:
"""
Wraps a node's run method with OpenTelemetry tracing and returns a generator.
"""
# Wrap node.run() with ObservabilityLayer hooks to produce node-level spans
layer = ObservabilityLayer()
layer.on_graph_start()
node.ensure_execution_id()
def _gen():
error: Exception | None = None
layer.on_node_run_start(node)
try:
yield from node.run()
except Exception as exc:
error = exc
raise
finally:
layer.on_node_run_end(node, error)
return _gen()

View File

@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle as handle_create_site_re
from .delete_tool_parameters_cache_when_sync_draft_workflow import (
handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow,
)
from .queue_credential_sync_when_tenant_created import handle as handle_queue_credential_sync_when_tenant_created
from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created
from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created
from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published
@ -30,6 +31,7 @@ __all__ = [
"handle_create_installed_app_when_app_created",
"handle_create_site_record_when_app_created",
"handle_delete_tool_parameters_cache_when_sync_draft_workflow",
"handle_queue_credential_sync_when_tenant_created",
"handle_sync_plugin_trigger_when_app_created",
"handle_sync_webhook_when_app_created",
"handle_sync_workflow_schedule_when_app_published",

View File

@ -0,0 +1,19 @@
from configs import dify_config
from events.tenant_event import tenant_was_created
from services.enterprise.workspace_sync import WorkspaceSyncService
@tenant_was_created.connect
def handle(sender, **kwargs):
"""Queue credential sync when a tenant/workspace is created."""
# Only queue sync tasks if plugin manager (enterprise feature) is enabled
if not dify_config.ENTERPRISE_ENABLED:
return
tenant = sender
# Determine source from kwargs if available, otherwise use generic
source = kwargs.get("source", "tenant_created")
# Queue credential sync task to Redis for enterprise backend to process
WorkspaceSyncService.queue_credential_sync(tenant.id, source=source)

View File

@ -4,6 +4,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
clean_expired_messages,
clean_workflow_runs,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
@ -58,6 +59,7 @@ def init_app(app: DifyApp):
transform_datasource_credentials,
install_rag_pipeline_plugins,
clean_workflow_runs,
clean_expired_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@ -10,6 +10,7 @@ import os
from dotenv import load_dotenv
from configs import dify_config
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@ -19,12 +20,17 @@ def is_enabled() -> bool:
"""
Check if logstore extension is enabled.
Logstore is considered enabled when:
1. All required Aliyun SLS environment variables are set
2. At least one repository configuration points to a logstore implementation
Returns:
True if all required Aliyun SLS environment variables are set, False otherwise
True if logstore should be initialized, False otherwise
"""
# Load environment variables from .env file
load_dotenv()
# Check if Aliyun SLS connection parameters are configured
required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET",
@ -33,24 +39,32 @@ def is_enabled() -> bool:
"ALIYUN_SLS_PROJECT_NAME",
]
all_set = all(os.environ.get(var) for var in required_vars)
sls_vars_set = all(os.environ.get(var) for var in required_vars)
if not all_set:
logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
if not sls_vars_set:
return False
return all_set
# Check if any repository configuration points to logstore implementation
repository_configs = [
dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY,
dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY,
dify_config.API_WORKFLOW_RUN_REPOSITORY,
]
uses_logstore = any("logstore" in config.lower() for config in repository_configs)
if not uses_logstore:
return False
logger.info("Logstore extension enabled: SLS variables set and repository configured to use logstore")
return True
def init_app(app: DifyApp):
"""
Initialize logstore on application startup.
This function:
1. Creates Aliyun SLS project if it doesn't exist
2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
3. Creates indexes with field configurations based on PostgreSQL table structures
This operation is idempotent and only executes once during application startup.
If initialization fails, the application continues running without logstore features.
Args:
app: The Dify application instance
@ -58,17 +72,23 @@ def init_app(app: DifyApp):
try:
from extensions.logstore.aliyun_logstore import AliyunLogStore
logger.info("Initializing logstore...")
logger.info("Initializing Aliyun SLS Logstore...")
# Create logstore client and initialize project/logstores/indexes
# Create logstore client and initialize resources
logstore_client = AliyunLogStore()
logstore_client.init_project_logstore()
# Attach to app for potential later use
app.extensions["logstore"] = logstore_client
logger.info("Logstore initialized successfully")
except Exception:
logger.exception("Failed to initialize logstore")
# Don't raise - allow application to continue even if logstore init fails
# This ensures that the application can still run if logstore is misconfigured
logger.exception(
"Logstore initialization failed. Configuration: endpoint=%s, region=%s, project=%s, timeout=%ss. "
"Application will continue but logstore features will NOT work.",
os.environ.get("ALIYUN_SLS_ENDPOINT"),
os.environ.get("ALIYUN_SLS_REGION"),
os.environ.get("ALIYUN_SLS_PROJECT_NAME"),
os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", "30"),
)
# Don't raise - allow application to continue even if logstore setup fails

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import os
import socket
import threading
import time
from collections.abc import Sequence
@ -179,9 +180,18 @@ class AliyunLogStore:
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
self.log_enabled: bool = (
os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
or os.environ.get("LOGSTORE_SQL_ECHO", "false").lower() == "true"
)
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
# Get timeout configuration
check_timeout = int(os.environ.get("ALIYUN_SLS_CHECK_CONNECTIVITY_TIMEOUT", 30))
# Pre-check endpoint connectivity to prevent indefinite hangs
self._check_endpoint_connectivity(self.endpoint, check_timeout)
# Initialize SDK client
self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
@ -199,6 +209,49 @@ class AliyunLogStore:
self.__class__._initialized = True
@staticmethod
def _check_endpoint_connectivity(endpoint: str, timeout: int) -> None:
"""
Check if the SLS endpoint is reachable before creating LogClient.
Prevents indefinite hangs when the endpoint is unreachable.
Args:
endpoint: SLS endpoint URL
timeout: Connection timeout in seconds
Raises:
ConnectionError: If endpoint is not reachable
"""
# Parse endpoint URL to extract hostname and port
from urllib.parse import urlparse
parsed_url = urlparse(endpoint if "://" in endpoint else f"http://{endpoint}")
hostname = parsed_url.hostname
port = parsed_url.port or (443 if parsed_url.scheme == "https" else 80)
if not hostname:
raise ConnectionError(f"Invalid endpoint URL: {endpoint}")
sock = None
try:
# Create socket and set timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
sock.connect((hostname, port))
except Exception as e:
# Catch all exceptions and provide clear error message
error_type = type(e).__name__
raise ConnectionError(
f"Cannot connect to {hostname}:{port} (timeout={timeout}s): [{error_type}] {e}"
) from e
finally:
# Ensure socket is properly closed
if sock:
try:
sock.close()
except Exception: # noqa: S110
pass # Ignore errors during cleanup
@property
def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled."""
@ -220,19 +273,16 @@ class AliyunLogStore:
try:
self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol:
logger.info("Successfully connected to project %s using PG protocol", self.project_name)
logger.info("Using PG protocol for project %s", self.project_name)
# Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled()
return True
else:
logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
logger.info("Using SDK mode for project %s", self.project_name)
return False
except Exception as e:
logger.warning(
"Failed to establish PG connection for project %s: %s. Will use SDK mode.",
self.project_name,
str(e),
)
logger.info("Using SDK mode for project %s", self.project_name)
logger.debug("PG connection details: %s", str(e))
self._use_pg_protocol = False
return False
@ -246,10 +296,6 @@ class AliyunLogStore:
if self._use_pg_protocol:
return
logger.info(
"Attempting delayed PG connection for newly created project %s ...",
self.project_name,
)
self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None
@ -284,11 +330,7 @@ class AliyunLogStore:
if project_is_new:
# For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False
logger.info(
"Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
self.project_name,
self.__class__._pg_connection_delay,
)
logger.info("Using SDK mode for project %s (newly created)", self.project_name)
if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer(
@ -299,7 +341,6 @@ class AliyunLogStore:
self.__class__._pg_connection_timer.start()
else:
# For existing projects, attempt PG connection immediately
logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init()
def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
@ -318,9 +359,9 @@ class AliyunLogStore:
existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index:
logger.info(
"Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
"PG protocol requires scan_index to be enabled.",
"Logstore %s requires scan_index enabled, using SDK mode for project %s",
logstore_name,
self.project_name,
)
self._use_pg_protocol = False
# Close PG connection if it was initialized
@ -748,7 +789,6 @@ class AliyunLogStore:
reverse=reverse,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
@ -770,7 +810,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
@ -845,7 +884,6 @@ class AliyunLogStore:
query=full_query,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
@ -853,8 +891,7 @@ class AliyunLogStore:
self.project_name,
from_time,
to_time,
query,
sql,
full_query,
)
try:
@ -865,7 +902,6 @@ class AliyunLogStore:
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",

View File

@ -7,8 +7,7 @@ from contextlib import contextmanager
from typing import Any
import psycopg2
import psycopg2.pool
from psycopg2 import InterfaceError, OperationalError
from sqlalchemy import create_engine
from configs import dify_config
@ -16,11 +15,7 @@ logger = logging.getLogger(__name__)
class AliyunLogStorePG:
"""
PostgreSQL protocol support for Aliyun SLS LogStore.
Handles PG connection pooling and operations for regions that support PG protocol.
"""
"""PostgreSQL protocol support for Aliyun SLS LogStore using SQLAlchemy connection pool."""
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
"""
@ -36,24 +31,11 @@ class AliyunLogStorePG:
self._access_key_secret = access_key_secret
self._endpoint = endpoint
self.project_name = project_name
self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
self._engine: Any = None # SQLAlchemy Engine
self._use_pg_protocol = False
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
"""
Check if a TCP port is reachable using socket connection.
This provides a fast check before attempting full database connection,
preventing long waits when connecting to unsupported regions.
Args:
host: Hostname or IP address
port: Port number
timeout: Connection timeout in seconds (default: 2.0)
Returns:
True if port is reachable, False otherwise
"""
"""Fast TCP port check to avoid long waits on unsupported regions."""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
@ -65,166 +47,101 @@ class AliyunLogStorePG:
return False
def init_connection(self) -> bool:
"""
Initialize PostgreSQL connection pool for SLS PG protocol support.
Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
_use_pg_protocol to True and creates a connection pool. If connection fails
(region doesn't support PG protocol or other errors), returns False.
Returns:
True if PG protocol is supported and initialized, False otherwise
"""
"""Initialize SQLAlchemy connection pool with pool_recycle and TCP keepalive support."""
try:
# Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "")
# Get pool configuration
pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
# Pool configuration
pool_size = int(os.environ.get("ALIYUN_SLS_PG_POOL_SIZE", 5))
max_overflow = int(os.environ.get("ALIYUN_SLS_PG_MAX_OVERFLOW", 5))
pool_recycle = int(os.environ.get("ALIYUN_SLS_PG_POOL_RECYCLE", 3600))
pool_pre_ping = os.environ.get("ALIYUN_SLS_PG_POOL_PRE_PING", "false").lower() == "true"
logger.debug(
"Check PG protocol connection to SLS: host=%s, project=%s",
pg_host,
self.project_name,
)
logger.debug("Check PG protocol connection to SLS: host=%s, project=%s", pg_host, self.project_name)
# Fast port connectivity check before attempting full connection
# This prevents long waits when connecting to unsupported regions
# Fast port check to avoid long waits
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
logger.info(
"USE SDK mode for read/write operations, host=%s",
pg_host,
)
logger.debug("Using SDK mode for host=%s", pg_host)
return False
# Create connection pool
self._pg_pool = psycopg2.pool.SimpleConnectionPool(
minconn=1,
maxconn=pg_max_connections,
host=pg_host,
port=5432,
database=self.project_name,
user=self._access_key_id,
password=self._access_key_secret,
sslmode="require",
connect_timeout=5,
application_name=f"Dify-{dify_config.project.version}",
# Build connection URL
from urllib.parse import quote_plus
username = quote_plus(self._access_key_id)
password = quote_plus(self._access_key_secret)
database_url = (
f"postgresql+psycopg2://{username}:{password}@{pg_host}:5432/{self.project_name}?sslmode=require"
)
# Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
# Connection pool creation success already indicates connectivity
# Create SQLAlchemy engine with connection pool
self._engine = create_engine(
database_url,
pool_size=pool_size,
max_overflow=max_overflow,
pool_recycle=pool_recycle,
pool_pre_ping=pool_pre_ping,
pool_timeout=30,
connect_args={
"connect_timeout": 5,
"application_name": f"Dify-{dify_config.project.version}-fixautocommit",
"keepalives": 1,
"keepalives_idle": 60,
"keepalives_interval": 10,
"keepalives_count": 5,
},
)
self._use_pg_protocol = True
logger.info(
"PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
"PG protocol initialized for SLS project=%s (pool_size=%d, pool_recycle=%ds)",
self.project_name,
pool_size,
pool_recycle,
)
return True
except Exception as e:
# PG connection failed - fallback to SDK mode
self._use_pg_protocol = False
if self._pg_pool:
if self._engine:
try:
self._pg_pool.closeall()
self._engine.dispose()
except Exception:
logger.debug("Failed to close PG connection pool during cleanup, ignoring")
self._pg_pool = None
logger.debug("Failed to dispose engine during cleanup, ignoring")
self._engine = None
logger.info(
"PG protocol connection failed (region may not support PG protocol): %s. "
"Falling back to SDK mode for read/write operations.",
str(e),
)
return False
def _is_connection_valid(self, conn: Any) -> bool:
"""
Check if a connection is still valid.
Args:
conn: psycopg2 connection object
Returns:
True if connection is valid, False otherwise
"""
try:
# Check if connection is closed
if conn.closed:
return False
# Quick ping test - execute a lightweight query
# For SLS PG protocol, we can't use SELECT 1 without FROM,
# so we just check the connection status
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.fetchone()
return True
except Exception:
logger.debug("Using SDK mode for region: %s", str(e))
return False
@contextmanager
def _get_connection(self):
"""
Context manager to get a PostgreSQL connection from the pool.
"""Get connection from SQLAlchemy pool. Pool handles recycle, invalidation, and keepalive automatically."""
if not self._engine:
raise RuntimeError("SQLAlchemy engine is not initialized")
Automatically validates and refreshes stale connections.
Note: Aliyun SLS PG protocol does not support transactions, so we always
use autocommit mode.
Yields:
psycopg2 connection object
Raises:
RuntimeError: If PG pool is not initialized
"""
if not self._pg_pool:
raise RuntimeError("PG connection pool is not initialized")
conn = self._pg_pool.getconn()
connection = self._engine.raw_connection()
try:
# Validate connection and get a fresh one if needed
if not self._is_connection_valid(conn):
logger.debug("Connection is stale, marking as bad and getting a new one")
# Mark connection as bad and get a new one
self._pg_pool.putconn(conn, close=True)
conn = self._pg_pool.getconn()
# Aliyun SLS PG protocol does not support transactions, always use autocommit
conn.autocommit = True
yield conn
connection.autocommit = True # SLS PG protocol does not support transactions
yield connection
except Exception:
raise
finally:
# Return connection to pool (or close if it's bad)
if self._is_connection_valid(conn):
self._pg_pool.putconn(conn)
else:
self._pg_pool.putconn(conn, close=True)
connection.close()
def close(self) -> None:
"""Close the PostgreSQL connection pool."""
if self._pg_pool:
"""Dispose SQLAlchemy engine and close all connections."""
if self._engine:
try:
self._pg_pool.closeall()
logger.info("PG connection pool closed")
self._engine.dispose()
logger.info("SQLAlchemy engine disposed")
except Exception:
logger.exception("Failed to close PG connection pool")
logger.exception("Failed to dispose engine")
def _is_retriable_error(self, error: Exception) -> bool:
"""
Check if an error is retriable (connection-related issues).
Args:
error: Exception to check
Returns:
True if the error is retriable, False otherwise
"""
# Retry on connection-related errors
if isinstance(error, (OperationalError, InterfaceError)):
"""Check if error is retriable (connection-related issues)."""
# Check for psycopg2 connection errors directly
if isinstance(error, (psycopg2.OperationalError, psycopg2.InterfaceError)):
return True
# Check error message for specific connection issues
error_msg = str(error).lower()
retriable_patterns = [
"connection",
@ -234,34 +151,18 @@ class AliyunLogStorePG:
"reset by peer",
"no route to host",
"network",
"operational error",
"interface error",
]
return any(pattern in error_msg for pattern in retriable_patterns)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
"""
Write log to SLS using PostgreSQL protocol with automatic retry.
Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
writes with log_version field for versioning, same as SDK implementation.
Args:
logstore: Name of the logstore table
contents: List of (field_name, value) tuples
log_enabled: Whether to enable logging
Raises:
psycopg2.Error: If database operation fails after all retries
"""
"""Write log to SLS using INSERT with automatic retry (3 attempts with exponential backoff)."""
if not contents:
return
# Extract field names and values from contents
fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents]
# Build INSERT statement with literal values
# Note: Aliyun SLS PG protocol doesn't support parameterized queries,
# so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields])
if log_enabled:
@ -272,67 +173,40 @@ class AliyunLogStorePG:
len(contents),
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql)
# Success - exit retry loop
return
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., data validation error), fail immediately
logger.exception(
"Failed to put logs to logstore %s via PG protocol (non-retriable error)",
logstore,
)
logger.exception("Failed to put logs to logstore %s (non-retriable error)", logstore)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
"Failed to put logs to logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
retry_delay *= 2
else:
# Last attempt failed
logger.exception(
"Failed to put logs to logstore %s via PG protocol after %d attempts",
logstore,
max_retries,
)
logger.exception("Failed to put logs to logstore %s after %d attempts", logstore, max_retries)
raise
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
"""
Execute SQL query using PostgreSQL protocol with automatic retry.
Args:
sql: SQL query string
logstore: Name of the logstore (for logging purposes)
log_enabled: Whether to enable logging
Returns:
List of result rows as dictionaries
Raises:
psycopg2.Error: If database operation fails after all retries
"""
"""Execute SQL query with automatic retry (3 attempts with exponential backoff)."""
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
@ -341,20 +215,16 @@ class AliyunLogStorePG:
sql,
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
retry_delay = 0.1
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql)
# Get column names from cursor description
columns = [desc[0] for desc in cursor.description]
# Fetch all results and convert to list of dicts
result = []
for row in cursor.fetchall():
row_dict = {}
@ -372,36 +242,31 @@ class AliyunLogStorePG:
return result
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
"Failed to execute SQL on logstore %s (non-retriable error): sql=%s",
logstore,
sql,
)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
"Failed to execute SQL on logstore %s (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
retry_delay *= 2
else:
# Last attempt failed
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
"Failed to execute SQL on logstore %s after %d attempts: sql=%s",
logstore,
max_retries,
sql,
)
raise
# This line should never be reached due to raise above, but makes type checker happy
return []

View File

@ -0,0 +1,29 @@
"""
LogStore repository utilities.
"""
from typing import Any
def safe_float(value: Any, default: float = 0.0) -> float:
"""
Safely convert a value to float, handling 'null' strings and None.
"""
if value is None or value in {"null", ""}:
return default
try:
return float(value)
except (ValueError, TypeError):
return default
def safe_int(value: Any, default: int = 0) -> int:
"""
Safely convert a value to int, handling 'null' strings and None.
"""
if value is None or value in {"null", ""}:
return default
try:
return int(float(value))
except (ValueError, TypeError):
return default
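A quick illustration of the intended behavior, consistent with the definitions above (logstore rows may carry the literal string "null"):

```python
assert safe_int("null") == 0                     # sentinel string falls back to the default
assert safe_int("3.0") == 3                      # parsed via float() first, then truncated
assert safe_float(None, default=1.5) == 1.5      # None falls back to the caller's default
assert safe_float("not-a-number") == 0.0         # unparseable input falls back to 0.0
```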

View File

@ -14,6 +14,8 @@ from typing import Any
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
@ -52,9 +54,8 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.index = int(data.get("index", 0))
model.elapsed_time = float(data.get("elapsed_time", 0))
model.index = safe_int(data.get("index", 0))
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
# Optional fields
model.workflow_run_id = data.get("workflow_run_id")
@ -130,6 +131,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
node_id,
)
try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_id = escape_identifier(workflow_id)
escaped_node_id = escape_identifier(node_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@ -138,10 +145,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_id = '{workflow_id}'
AND node_id = '{node_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_id = '{escaped_workflow_id}'
AND node_id = '{escaped_node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
@ -153,7 +160,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
else:
# Use SDK with LogStore query syntax
query = (
f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_id: {escaped_workflow_id} and node_id: {escaped_node_id}"
)
from_time = 0
to_time = int(time.time()) # now
@ -227,6 +235,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
workflow_run_id,
)
try:
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_workflow_run_id = escape_identifier(workflow_run_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
@ -235,9 +248,9 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_run_id = '{workflow_run_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_run_id = '{escaped_workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1000
@ -248,7 +261,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
query = (
f"tenant_id: {escaped_tenant_id} and app_id: {escaped_app_id} "
f"and workflow_run_id: {escaped_workflow_run_id}"
)
from_time = 0
to_time = int(time.time()) # now
@ -313,16 +329,24 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
"""
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try:
# Escape parameters to prevent SQL injection
escaped_execution_id = escape_identifier(execution_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
if tenant_id:
escaped_tenant_id = escape_identifier(tenant_id)
tenant_filter = f"AND tenant_id = '{escaped_tenant_id}'"
else:
tenant_filter = ""
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1
"""
@ -332,10 +356,14 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
)
else:
# Use SDK with LogStore query syntax
# Note: Values must be quoted in LogStore query syntax to prevent injection
if tenant_id:
query = f"id: {execution_id} and tenant_id: {tenant_id}"
query = (
f"id:{escape_logstore_query_value(execution_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)}"
)
else:
query = f"id: {execution_id}"
query = f"id:{escape_logstore_query_value(execution_id)}"
from_time = 0
to_time = int(time.time()) # now

View File

@ -10,6 +10,7 @@ Key Features:
- Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security
- SQL injection prevention via parameter escaping
"""
import logging
@ -22,6 +23,8 @@ from typing import Any, cast
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
@ -63,10 +66,9 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.total_tokens = int(data.get("total_tokens", 0))
model.total_steps = int(data.get("total_steps", 0))
model.exceptions_count = int(data.get("exceptions_count", 0))
model.total_tokens = safe_int(data.get("total_tokens", 0))
model.total_steps = safe_int(data.get("total_steps", 0))
model.exceptions_count = safe_int(data.get("exceptions_count", 0))
# Optional fields
model.graph = data.get("graph")
@ -101,7 +103,8 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else:
model.elapsed_time = float(data.get("elapsed_time", 0))
# Use safe conversion to handle 'null' strings and None values
model.elapsed_time = safe_float(data.get("elapsed_time", 0))
return model
@ -165,16 +168,26 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
status,
)
# Convert triggered_from to list if needed
if isinstance(triggered_from, WorkflowRunTriggeredFrom):
if isinstance(triggered_from, (WorkflowRunTriggeredFrom, str)):
triggered_from_list = [triggered_from]
else:
triggered_from_list = list(triggered_from)
# Build triggered_from filter
triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
# Build status filter
status_filter = f"AND status='{status}'" if status else ""
# Build triggered_from filter with escaped values
# Support both enum and string values for triggered_from
triggered_from_filter = " OR ".join(
[
f"triggered_from='{escape_sql_string(tf.value if isinstance(tf, WorkflowRunTriggeredFrom) else tf)}'"
for tf in triggered_from_list
]
)
# Build status filter with escaped value
status_filter = f"AND status='{escape_sql_string(status)}'" if status else ""
# Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from last record
@ -188,8 +201,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND ({triggered_from_filter})
{status_filter}
{last_id_filter}
@ -232,6 +245,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
try:
# Escape parameters to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@ -240,7 +258,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
WHERE id = '{escaped_run_id}'
AND tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@ -250,7 +271,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
# Note: Values must be quoted in LogStore query syntax to prevent injection
query = (
f"id:{escape_logstore_query_value(run_id)} "
f"and tenant_id:{escape_logstore_query_value(tenant_id)} "
f"and app_id:{escape_logstore_query_value(app_id)}"
)
from_time = 0
to_time = int(time.time()) # now
@ -323,6 +349,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
try:
# Escape parameter to prevent SQL injection
escaped_run_id = escape_identifier(run_id)
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
@ -331,7 +360,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND __time__ > 0
WHERE id = '{escaped_run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
@ -341,7 +370,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id}"
# Note: Values must be quoted in LogStore query syntax
query = f"id:{escape_logstore_query_value(run_id)}"
from_time = 0
to_time = int(time.time()) # now
@ -410,6 +440,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
triggered_from,
status,
)
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter
time_filter = ""
if time_range:
@ -418,6 +453,8 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# If status is provided, simple count
if status:
escaped_status = escape_sql_string(status)
if status == "running":
# Running status requires window function
sql = f"""
@ -425,9 +462,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@ -438,10 +475,10 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='{status}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='{escaped_status}'
AND finished_at IS NOT NULL
{time_filter}
"""
@ -467,13 +504,14 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# No status filter - get counts grouped by status
# Use optimized query for finished runs, separate query for running
try:
# Escape parameters (already escaped above, reuse variables)
# Count finished runs grouped by status
finished_sql = f"""
SELECT status, COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY status
@ -485,9 +523,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND status='running'
{time_filter}
) t
@ -546,7 +584,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.debug(
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -557,9 +601,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -601,7 +645,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -611,9 +661,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -655,7 +705,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -665,9 +721,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
@ -709,7 +765,13 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
app_id,
triggered_from,
)
# Build time range filter
# Escape parameters to prevent SQL injection
escaped_tenant_id = escape_identifier(tenant_id)
escaped_app_id = escape_identifier(app_id)
escaped_triggered_from = escape_sql_string(triggered_from)
# Build time range filter (datetime.isoformat() is safe)
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
@ -726,9 +788,9 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
created_by,
COUNT(DISTINCT id) AS interactions
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
WHERE tenant_id='{escaped_tenant_id}'
AND app_id='{escaped_app_id}'
AND triggered_from='{escaped_triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date, created_by

View File

@ -10,6 +10,7 @@ from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
@ -22,18 +23,6 @@ from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
def to_serializable(obj):
"""
Convert non-JSON-serializable objects into JSON-compatible formats.
- Uses `to_dict()` if it's a callable method.
- Falls back to string representation.
"""
if hasattr(obj, "to_dict") and callable(obj.to_dict):
return obj.to_dict()
return str(obj)
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
@ -79,7 +68,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
# Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
@ -113,6 +102,9 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
# Use WorkflowRuntimeTypeConverter to handle complex types (Segment, File, etc.)
json_converter = WorkflowRuntimeTypeConverter()
logstore_model = [
("id", domain_model.id_),
("log_version", log_version), # Add log_version field for append-only writes
@ -127,19 +119,19 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
("version", domain_model.workflow_version),
(
"graph",
json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
json.dumps(json_converter.to_json_encodable(domain_model.graph), ensure_ascii=False)
if domain_model.graph and self._enable_put_graph_field
else "{}",
),
(
"inputs",
json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
if domain_model.inputs
else "{}",
),
(
"outputs",
json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
if domain_model.outputs
else "{}",
),

View File

@ -24,6 +24,8 @@ from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier
from libs.helper import extract_tenant_id
from models import (
Account,
@ -73,7 +75,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
node_execution_id=data.get("node_execution_id"),
workflow_id=data.get("workflow_id", ""),
workflow_execution_id=data.get("workflow_run_id"),
index=int(data.get("index", 0)),
index=safe_int(data.get("index", 0)),
predecessor_node_id=data.get("predecessor_node_id"),
node_id=data.get("node_id", ""),
node_type=NodeType(data.get("node_type", "start")),
@ -83,7 +85,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
outputs=outputs,
status=status,
error=data.get("error"),
elapsed_time=float(data.get("elapsed_time", 0.0)),
elapsed_time=safe_float(data.get("elapsed_time", 0.0)),
metadata=domain_metadata,
created_at=created_at,
finished_at=finished_at,
@ -147,7 +149,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "false").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
logger.debug(
@ -274,16 +276,34 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
Save or update the inputs, process_data, or outputs associated with a specific
node_execution record.
For LogStore implementation, this is similar to save() since we always write
complete records. We append a new record with updated data fields.
For the LogStore implementation, the LogStore write is a no-op because save()
already writes all fields, including inputs, process_data, and outputs. The caller
typically calls save() first to persist status/metadata, then calls save_execution_data()
to persist data fields. Since LogStore writes complete records atomically, a second
write here would only create a duplicate record, so it is skipped.
However, if dual-write is enabled, we still need to call the SQL repository's
save_execution_data() method to properly update the SQL database.
Args:
execution: The NodeExecution instance with data to save
"""
logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id)
# In LogStore, we simply write a new complete record with the data
# The log_version timestamp will ensure this is treated as the latest version
self.save(execution)
logger.debug(
"save_execution_data: no-op for LogStore (data already saved by save()): id=%s, node_execution_id=%s",
execution.id,
execution.node_execution_id,
)
# No-op for LogStore: save() already writes all fields including inputs, process_data, and outputs
# Calling save() again would create a duplicate record in the append-only LogStore
# Dual-write to SQL database if enabled (for safe migration)
if self._enable_dual_write:
try:
self.sql_repository.save_execution_data(execution)
logger.debug("Dual-write: saved node execution data to SQL database: id=%s", execution.id)
except Exception:
logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
# Don't raise - LogStore write succeeded, SQL is just a backup
def get_by_workflow_run(
self,
@ -292,8 +312,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
Uses LogStore SQL query with finished_at IS NOT NULL filter for deduplication.
This ensures we only get the final version of each node execution.
Uses LogStore SQL query with window function to get the latest version of each node execution.
This ensures we only get the most recent version of each node execution record.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
@ -304,16 +324,19 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
A list of NodeExecution instances
Note:
This method filters by finished_at IS NOT NULL to avoid duplicates from
version updates. For complete history including intermediate states,
a different query strategy would be needed.
This method uses a ROW_NUMBER() window function partitioned by node_execution_id
to get the latest version (highest log_version) of each node execution.
"""
logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
# Build SQL query with deduplication using finished_at IS NOT NULL
# This optimization avoids window functions for common case where we only
# want the final state of each node execution
# Build SQL query with deduplication using window function
# ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC)
# ensures we get the latest version of each node execution
# Build ORDER BY clause
# Escape parameters to prevent SQL injection
escaped_workflow_run_id = escape_identifier(workflow_run_id)
escaped_tenant_id = escape_identifier(self._tenant_id)
# Build ORDER BY clause for outer query
order_clause = ""
if order_config and order_config.order_by:
order_fields = []
@ -327,16 +350,23 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
if order_fields:
order_clause = "ORDER BY " + ", ".join(order_fields)
sql = f"""
SELECT *
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{workflow_run_id}'
AND tenant_id='{self._tenant_id}'
AND finished_at IS NOT NULL
"""
# Build app_id filter for subquery
app_id_filter = ""
if self._app_id:
sql += f" AND app_id='{self._app_id}'"
escaped_app_id = escape_identifier(self._app_id)
app_id_filter = f" AND app_id='{escaped_app_id}'"
# Use window function to get latest version of each node execution
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{escaped_workflow_run_id}'
AND tenant_id='{escaped_tenant_id}'
{app_id_filter}
) t
WHERE rn = 1
"""
if order_clause:
sql += f" {order_clause}"

View File

@ -0,0 +1,134 @@
"""
SQL Escape Utility for LogStore Queries
This module provides escaping utilities to prevent injection attacks in LogStore queries.
LogStore supports two query modes:
1. PG Protocol Mode: Uses SQL syntax with single quotes for strings
2. SDK Mode: Uses LogStore query syntax (key: value) with double quotes
Key Security Concerns:
- Prevent tenant A from accessing tenant B's data via injection
- SLS queries are read-only, so we focus on data access control
- Different escaping strategies for SQL vs LogStore query syntax
"""
def escape_sql_string(value: str) -> str:
"""
Escape a string value for safe use in SQL queries.
This function escapes single quotes by doubling them, which is the standard
SQL escaping method. This prevents SQL injection by ensuring that user input
cannot break out of string literals.
Args:
value: The string value to escape
Returns:
Escaped string safe for use in SQL queries
Examples:
>>> escape_sql_string("normal_value")
"normal_value"
>>> escape_sql_string("value' OR '1'='1")
"value'' OR ''1''=''1"
>>> escape_sql_string("tenant's_id")
"tenant''s_id"
Security:
- Prevents breaking out of string literals
- Stops injection attacks like: ' OR '1'='1
- Protects against cross-tenant data access
"""
if not value:
return value
# Escape single quotes by doubling them (standard SQL escaping)
# This prevents breaking out of string literals in SQL queries
return value.replace("'", "''")
def escape_identifier(value: str) -> str:
"""
Escape an identifier (tenant_id, app_id, run_id, etc.) for safe SQL use.
This function is for PG protocol mode (SQL syntax).
For SDK mode, use escape_logstore_query_value() instead.
Args:
value: The identifier value to escape
Returns:
Escaped identifier safe for use in SQL queries
Examples:
>>> escape_identifier("550e8400-e29b-41d4-a716-446655440000")
"550e8400-e29b-41d4-a716-446655440000"
>>> escape_identifier("tenant_id' OR '1'='1")
"tenant_id'' OR ''1''=''1"
Security:
- Prevents SQL injection via identifiers
- Stops cross-tenant access attempts
- Works for UUIDs, alphanumeric IDs, and similar identifiers
"""
# For identifiers, use the same escaping as strings
# This is simple and effective for preventing injection
return escape_sql_string(value)
def escape_logstore_query_value(value: str) -> str:
"""
Escape value for LogStore query syntax (SDK mode).
LogStore query syntax rules:
1. Keywords (and/or/not) are case-insensitive
2. Single quotes are ordinary characters (no special meaning)
3. Double quotes wrap values: key:"value"
4. Backslash is the escape character:
- \" for double quote inside value
- \\ for backslash itself
5. Parentheses can change query structure
To prevent injection:
- Wrap value in double quotes to treat special chars as literals
- Escape backslashes and double quotes using backslash
Args:
value: The value to escape for LogStore query syntax
Returns:
Quoted and escaped value safe for LogStore query syntax (includes the quotes)
Examples:
>>> escape_logstore_query_value("normal_value")
'"normal_value"'
>>> escape_logstore_query_value("value or field:evil")
'"value or field:evil"' # 'or' and ':' are now literals
>>> escape_logstore_query_value('value"test')
'"value\\"test"' # Internal double quote escaped
>>> escape_logstore_query_value('value\\test')
'"value\\\\test"' # Backslash escaped
Security:
- Prevents injection via and/or/not keywords
- Prevents injection via colons (:)
- Prevents injection via parentheses
- Protects against cross-tenant data access
Note:
Escape order is critical: backslash first, then double quotes.
Otherwise, we'd double-escape the escape character itself.
"""
if not value:
return '""'
# IMPORTANT: Escape backslashes FIRST, then double quotes
# This prevents double-escaping (e.g., " -> \" -> \\" incorrectly)
escaped = value.replace("\\", "\\\\") # \ -> \\
escaped = escaped.replace('"', '\\"') # " -> \"
# Wrap in double quotes to treat as literal string
# This prevents and/or/not/:/() from being interpreted as operators
return f'"{escaped}"'
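
The two escaping modes pair with the two query paths used in the repository diffs above. A minimal sketch of how a hostile identifier is neutralized in each mode (the import path matches the new module added in this diff; the printed results follow from the replacement rules documented above):

```python
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value

malicious = "abc' OR '1'='1"

# PG protocol (SQL) mode: single quotes are doubled, so the value stays inside the literal
print(f"WHERE tenant_id='{escape_identifier(malicious)}'")
# WHERE tenant_id='abc'' OR ''1''=''1'

# SDK (LogStore query syntax) mode: the value is wrapped in double quotes,
# so and/or/not, colons, and parentheses are treated as literal characters
print(f"tenant_id:{escape_logstore_query_value(malicious)}")
# tenant_id:"abc' OR '1'='1"
```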

View File

@ -115,7 +115,18 @@ def build_from_mappings(
# TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
# Implement batch processing to reduce database load when handling multiple files.
# Filter out None/empty mappings to avoid errors
valid_mappings = [m for m in mappings if m and m.get("transfer_method")]
def is_valid_mapping(m: Mapping[str, Any]) -> bool:
if not m or not m.get("transfer_method"):
return False
# For REMOTE_URL transfer method, ensure url or remote_url is provided and not None
transfer_method = m.get("transfer_method")
if transfer_method == FileTransferMethod.REMOTE_URL:
url = m.get("url") or m.get("remote_url")
if not url:
return False
return True
valid_mappings = [m for m in mappings if is_valid_mapping(m)]
files = [
build_from_mapping(
mapping=mapping,

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import datetime
from typing import TypeAlias
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, field_validator
@ -20,8 +21,8 @@ class SimpleFeedback(ResponseModel):
class RetrieverResource(ResponseModel):
id: str
message_id: str
id: str = Field(default_factory=lambda: str(uuid4()))
message_id: str = Field(default_factory=lambda: str(uuid4()))
position: int
dataset_id: str | None = None
dataset_name: str | None = None

View File

@ -3,6 +3,8 @@ import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from configs import dify_config
logger = logging.getLogger(__name__)
@ -19,20 +21,21 @@ class SMTPClient:
self.opportunistic_tls = opportunistic_tls
def send(self, mail: dict):
smtp = None
smtp: smtplib.SMTP | None = None
local_host = dify_config.SMTP_LOCAL_HOSTNAME
try:
if self.use_tls:
if self.opportunistic_tls:
smtp = smtplib.SMTP(self.server, self.port, timeout=10)
# Send EHLO command with the HELO domain name as the server address
smtp.ehlo(self.server)
smtp.starttls()
# Resend EHLO command to identify the TLS session
smtp.ehlo(self.server)
else:
smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
if self.use_tls and not self.opportunistic_tls:
# SMTP with SSL (implicit TLS)
smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10, local_hostname=local_host)
else:
smtp = smtplib.SMTP(self.server, self.port, timeout=10)
# Plain SMTP or SMTP with STARTTLS (explicit TLS)
smtp = smtplib.SMTP(self.server, self.port, timeout=10, local_hostname=local_host)
assert smtp is not None
if self.use_tls and self.opportunistic_tls:
smtp.ehlo(self.server)
smtp.starttls()
smtp.ehlo(self.server)
# Only authenticate if both username and password are non-empty
if self.username and self.password and self.username.strip() and self.password.strip():
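
For reference, the connection handling above reduces to three cases. Below is a minimal standalone sketch of the same decision using only smtplib; the function name and parameters are illustrative, while the timeout, local_hostname, and EHLO/STARTTLS ordering mirror the diff:

```python
import smtplib


def open_smtp(server: str, port: int, use_tls: bool, opportunistic_tls: bool,
              local_hostname: str | None = None) -> smtplib.SMTP:
    # use_tls without opportunistic_tls -> implicit TLS (SMTPS)
    if use_tls and not opportunistic_tls:
        return smtplib.SMTP_SSL(server, port, timeout=10, local_hostname=local_hostname)
    # plain SMTP, optionally upgraded to TLS via STARTTLS
    smtp = smtplib.SMTP(server, port, timeout=10, local_hostname=local_hostname)
    if use_tls and opportunistic_tls:
        smtp.ehlo(server)   # identify before upgrading
        smtp.starttls()
        smtp.ehlo(server)   # re-identify on the TLS session
    return smtp
```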

View File

@ -0,0 +1,33 @@
"""feat: add created_at id index to messages
Revision ID: 3334862ee907
Revises: 905527cc8fd3
Create Date: 2026-01-12 17:29:44.846544
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '3334862ee907'
down_revision = '905527cc8fd3'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.drop_index('message_created_at_id_idx')
# ### end Alembic commands ###

View File

@ -1149,7 +1149,7 @@ class DatasetCollectionBinding(TypeBase):
)
class TidbAuthBinding(Base):
class TidbAuthBinding(TypeBase):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@ -1158,7 +1158,13 @@ class TidbAuthBinding(Base):
sa.Index("tidb_auth_bindings_created_at_idx", "created_at"),
sa.Index("tidb_auth_bindings_status_idx", "status"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
insert_default=lambda: str(uuid4()),
default_factory=lambda: str(uuid4()),
init=False,
)
tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
cluster_id: Mapped[str] = mapped_column(String(255), nullable=False)
cluster_name: Mapped[str] = mapped_column(String(255), nullable=False)
@ -1166,7 +1172,9 @@ class TidbAuthBinding(Base):
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'CREATING'"))
account: Mapped[str] = mapped_column(String(255), nullable=False)
password: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
class Whitelist(TypeBase):

View File

@ -968,6 +968,7 @@ class Message(Base):
Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
Index("message_created_at_idx", "created_at"),
Index("message_app_mode_idx", "app_mode"),
Index("message_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
@ -1447,7 +1448,7 @@ class MessageAnnotation(Base):
return account
class AppAnnotationHitHistory(Base):
class AppAnnotationHitHistory(TypeBase):
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@ -1457,17 +1458,19 @@ class AppAnnotationHitHistory(Base):
sa.Index("app_annotation_hit_histories_message_idx", "message_id"),
)
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
app_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
source = mapped_column(LongText, nullable=False)
question = mapped_column(LongText, nullable=False)
account_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
score = mapped_column(Float, nullable=False, server_default=sa.text("0"))
message_id = mapped_column(StringUUID, nullable=False)
annotation_question = mapped_column(LongText, nullable=False)
annotation_content = mapped_column(LongText, nullable=False)
source: Mapped[str] = mapped_column(LongText, nullable=False)
question: Mapped[str] = mapped_column(LongText, nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
score: Mapped[float] = mapped_column(Float, nullable=False, server_default=sa.text("0"))
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
annotation_question: Mapped[str] = mapped_column(LongText, nullable=False)
annotation_content: Mapped[str] = mapped_column(LongText, nullable=False)
@property
def account(self):
@ -2083,7 +2086,7 @@ class TraceAppConfig(TypeBase):
}
class TenantCreditPool(Base):
class TenantCreditPool(TypeBase):
__tablename__ = "tenant_credit_pools"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
@ -2091,14 +2094,20 @@ class TenantCreditPool(Base):
sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
quota_limit = mapped_column(BigInteger, nullable=False, default=0)
quota_used = mapped_column(BigInteger, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
quota_used: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
@property

View File

@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.11.3"
version = "1.11.4"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@ -1,90 +1,62 @@
import datetime
import logging
import time
import click
from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import (
App,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.feature_service import FeatureService
from services.retention.conversation.messages_clean_policy import create_message_clean_policy
from services.retention.conversation.messages_clean_service import MessagesCleanService
logger = logging.getLogger(__name__)
@app.celery.task(queue="dataset")
@app.celery.task(queue="retention")
def clean_messages():
click.echo(click.style("Start clean messages.", fg="green"))
start_at = time.perf_counter()
plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
)
while True:
try:
# Main query with join and filter
messages = (
db.session.query(Message)
.where(Message.created_at < plan_sandbox_clean_message_day)
.order_by(Message.created_at.desc())
.limit(100)
.all()
)
"""
Clean expired messages based on clean policy.
except SQLAlchemyError:
raise
if not messages:
break
for message in messages:
app = db.session.query(App).filter_by(id=message.app_id).first()
if not app:
logger.warning(
"Expected App record to exist, but none was found, app_id=%s, message_id=%s",
message.app_id,
message.id,
)
continue
features_cache_key = f"features:{app.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:
features = FeatureService.get_features(app.tenant_id)
redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
plan = features.billing.subscription.plan
else:
plan = plan_cache.decode()
if plan == CloudPlan.SANDBOX:
# clean related message
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(Message).where(Message.id == message.id).delete()
db.session.commit()
end_at = time.perf_counter()
click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green"))
This task uses MessagesCleanService to efficiently clean messages in batches.
The behavior depends on BILLING_ENABLED configuration:
- BILLING_ENABLED=True: only delete messages from sandbox tenants (with whitelist/grace period)
- BILLING_ENABLED=False: delete all messages within the time range
"""
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
# Create policy based on billing configuration
policy = create_message_clean_policy(
graceful_period_days=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD,
)
# Create and run the cleanup service
service = MessagesCleanService.from_days(
policy=policy,
days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
)
stats = service.run()
end_at = time.perf_counter()
click.echo(
click.style(
f"clean_messages: completed successfully\n"
f" - Latency: {end_at - start_at:.2f}s\n"
f" - Batches processed: {stats['batches']}\n"
f" - Total messages scanned: {stats['total_messages']}\n"
f" - Messages filtered: {stats['filtered_messages']}\n"
f" - Messages deleted: {stats['total_deleted']}",
fg="green",
)
)
except Exception as e:
end_at = time.perf_counter()
logger.exception("clean_messages failed")
click.echo(
click.style(
f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
fg="red",
)
)
raise

View File

@ -50,10 +50,13 @@ def create_clusters(batch_size):
)
for new_cluster in new_clusters:
tidb_auth_binding = TidbAuthBinding(
tenant_id=None,
cluster_id=new_cluster["cluster_id"],
cluster_name=new_cluster["cluster_name"],
account=new_cluster["account"],
password=new_cluster["password"],
active=False,
status="CREATING",
)
db.session.add(tidb_auth_binding)
db.session.commit()

View File

@ -0,0 +1,58 @@
import json
import logging
import uuid
from datetime import UTC, datetime
from redis import RedisError
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"
WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing"
class WorkspaceSyncService:
"""Service to publish workspace sync tasks to Redis queue for enterprise backend consumption"""
@staticmethod
def queue_credential_sync(workspace_id: str, *, source: str) -> bool:
"""
Queue a credential sync task for a newly created workspace.
This publishes a task to Redis that will be consumed by the enterprise backend
worker to sync credentials with the plugin-manager.
Args:
workspace_id: The workspace/tenant ID to sync credentials for
source: Source of the sync request (for debugging/tracking)
Returns:
bool: True if task was queued successfully, False otherwise
"""
try:
task = {
"task_id": str(uuid.uuid4()),
"workspace_id": workspace_id,
"retry_count": 0,
"created_at": datetime.now(UTC).isoformat(),
"source": source,
}
# Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
redis_client.lpush(WORKSPACE_SYNC_QUEUE, json.dumps(task))
logger.info(
"Queued credential sync task for workspace %s, task_id: %s, source: %s",
workspace_id,
task["task_id"],
source,
)
return True
except (RedisError, TypeError) as e:
logger.error("Failed to queue credential sync for workspace %s: %s", workspace_id, str(e), exc_info=True)
# Don't raise - we don't want to fail workspace creation if queueing fails
# The scheduled task will catch it later
return False
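
The queue contract implied here is plain FIFO over a Redis list: the producer LPUSHes JSON tasks to the head, and the enterprise worker (not part of this diff) is expected to RPOP from the tail. A hedged sketch of that consumer side, assuming the same key and payload shape:

```python
import json

from extensions.ext_redis import redis_client

WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"


def drain_workspace_sync_queue() -> None:
    # RPOP pops from the tail, giving FIFO order against the LPUSH in queue_credential_sync
    raw = redis_client.rpop(WORKSPACE_SYNC_QUEUE)
    while raw:
        task = json.loads(raw)
        # hand task["workspace_id"] to the plugin-manager credential sync here (placeholder)
        raw = redis_client.rpop(WORKSPACE_SYNC_QUEUE)
```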

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from pydantic import BaseModel
from sqlalchemy import select
from yarl import URL
from configs import dify_config
@ -25,7 +26,9 @@ from core.plugin.entities.plugin_daemon import (
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider import ProviderCredential
from models.provider_ids import GenericProviderID
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope
@ -506,6 +509,33 @@ class PluginService:
@staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
manager = PluginInstaller()
# Get plugin info before uninstalling to delete associated credentials
try:
plugins = manager.list_plugins(tenant_id)
plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
if plugin:
plugin_id = plugin.plugin_id
logger.info("Deleting credentials for plugin: %s", plugin_id)
# Delete provider credentials that match this plugin
credentials = db.session.scalars(
select(ProviderCredential).where(
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name.like(f"{plugin_id}/%"),
)
).all()
for cred in credentials:
db.session.delete(cred)
db.session.commit()
logger.info("Deleted %d credentials for plugin: %s", len(credentials), plugin_id)
except Exception as e:
logger.warning("Failed to delete credentials: %s", e)
# Continue with uninstall even if credential deletion fails
return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod

View File

@ -0,0 +1,216 @@
import datetime
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from configs import dify_config
from enums.cloud_plan import CloudPlan
from services.billing_service import BillingService, SubscriptionPlan
logger = logging.getLogger(__name__)
@dataclass
class SimpleMessage:
id: str
app_id: str
created_at: datetime.datetime
class MessagesCleanPolicy(ABC):
"""
Abstract base class for message cleanup policies.
A policy determines which messages from a batch should be deleted.
"""
@abstractmethod
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
"""
Filter messages and return IDs of messages that should be deleted.
Args:
messages: Batch of messages to evaluate
app_to_tenant: Mapping from app_id to tenant_id
Returns:
List of message IDs that should be deleted
"""
...
class BillingDisabledPolicy(MessagesCleanPolicy):
"""
Policy for community or enterprise edition (billing disabled).
No special filtering logic; all message IDs in the batch are returned.
"""
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
return [msg.id for msg in messages]
class BillingSandboxPolicy(MessagesCleanPolicy):
"""
Policy for sandbox plan tenants in cloud edition (billing enabled).
Filters messages based on sandbox plan expiration rules:
- Skip tenants in the whitelist
- Only delete messages from sandbox plan tenants
- Respect grace period after subscription expiration
- Safe default: if tenant mapping or plan is missing, do NOT delete
"""
def __init__(
self,
plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]],
graceful_period_days: int = 21,
tenant_whitelist: Sequence[str] | None = None,
current_timestamp: int | None = None,
) -> None:
self._graceful_period_days = graceful_period_days
self._tenant_whitelist: Sequence[str] = tenant_whitelist or []
self._plan_provider = plan_provider
self._current_timestamp = current_timestamp
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
"""
Filter messages based on sandbox plan expiration rules.
Args:
messages: Batch of messages to evaluate
app_to_tenant: Mapping from app_id to tenant_id
Returns:
List of message IDs that should be deleted
"""
if not messages or not app_to_tenant:
return []
# Get unique tenant_ids and fetch subscription plans
tenant_ids = list(set(app_to_tenant.values()))
tenant_plans = self._plan_provider(tenant_ids)
if not tenant_plans:
return []
# Apply sandbox deletion rules
return self._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
)
def _filter_expired_sandbox_messages(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
tenant_plans: dict[str, SubscriptionPlan],
) -> list[str]:
"""
Filter messages that should be deleted based on sandbox plan expiration.
A message should be deleted if:
1. It belongs to a sandbox tenant AND
2. Either:
a) The tenant has no previous subscription (expiration_date == -1), OR
b) The subscription expired more than graceful_period_days ago
Args:
messages: List of message objects with id and app_id attributes
app_to_tenant: Mapping from app_id to tenant_id
tenant_plans: Mapping from tenant_id to subscription plan info
Returns:
List of message IDs that should be deleted
"""
current_timestamp = self._current_timestamp
if current_timestamp is None:
current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
sandbox_message_ids: list[str] = []
graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60
for msg in messages:
# Get tenant_id for this message's app
tenant_id = app_to_tenant.get(msg.app_id)
if not tenant_id:
continue
# Skip tenant messages in whitelist
if tenant_id in self._tenant_whitelist:
continue
# Get subscription plan for this tenant
tenant_plan = tenant_plans.get(tenant_id)
if not tenant_plan:
continue
plan = str(tenant_plan["plan"])
expiration_date = int(tenant_plan["expiration_date"])
# Only process sandbox plans
if plan != CloudPlan.SANDBOX:
continue
# Case 1: No previous subscription (-1 means never had a paid subscription)
if expiration_date == -1:
sandbox_message_ids.append(msg.id)
continue
# Case 2: Subscription expired beyond grace period
if current_timestamp - expiration_date > graceful_period_seconds:
sandbox_message_ids.append(msg.id)
return sandbox_message_ids
def create_message_clean_policy(
graceful_period_days: int = 21,
current_timestamp: int | None = None,
) -> MessagesCleanPolicy:
"""
Factory function to create the appropriate message clean policy.
Determines which policy to use based on BILLING_ENABLED configuration:
- If BILLING_ENABLED is True: returns BillingSandboxPolicy
- If BILLING_ENABLED is False: returns BillingDisabledPolicy
Args:
graceful_period_days: Grace period in days after subscription expiration (default: 21)
current_timestamp: Current Unix timestamp for testing (default: None, uses current time)
"""
if not dify_config.BILLING_ENABLED:
logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy")
return BillingDisabledPolicy()
# Billing enabled - fetch whitelist from BillingService
tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist()
plan_provider = BillingService.get_plan_bulk_with_cache
logger.info(
"create_message_clean_policy: billing enabled, using BillingSandboxPolicy "
"(graceful_period_days=%s, whitelist=%s)",
graceful_period_days,
tenant_whitelist,
)
return BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=graceful_period_days,
tenant_whitelist=tenant_whitelist,
current_timestamp=current_timestamp,
)
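
A small sketch of driving BillingSandboxPolicy directly in a test, pinning current_timestamp and stubbing the plan provider. The "sandbox" plan string and the dict-shaped plan payload are assumptions based on how filter_message_ids reads them:

```python
import datetime

from services.retention.conversation.messages_clean_policy import (
    BillingSandboxPolicy,
    SimpleMessage,
)

fixed_now = int(datetime.datetime(2026, 1, 1, tzinfo=datetime.UTC).timestamp())
policy = BillingSandboxPolicy(
    plan_provider=lambda tenant_ids: {"tenant-a": {"plan": "sandbox", "expiration_date": -1}},
    graceful_period_days=21,
    tenant_whitelist=["tenant-keep"],
    current_timestamp=fixed_now,
)

messages = [SimpleMessage(id="m1", app_id="app-1", created_at=datetime.datetime(2025, 11, 1))]
deletable = policy.filter_message_ids(messages, {"app-1": "tenant-a"})
# ["m1"]: sandbox tenant, never had a paid subscription (-1), not whitelisted
```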

View File

@ -0,0 +1,334 @@
import datetime
import logging
import random
from collections.abc import Sequence
from typing import cast
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.model import (
App,
AppAnnotationHitHistory,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.retention.conversation.messages_clean_policy import (
MessagesCleanPolicy,
SimpleMessage,
)
logger = logging.getLogger(__name__)
class MessagesCleanService:
"""
Service for cleaning expired messages based on retention policies.
With billing disabled (non-cloud editions), all messages in the time range are deleted.
If billing is enabled, only sandbox-plan tenant messages are deleted (with whitelist and grace-period support).
"""
def __init__(
self,
policy: MessagesCleanPolicy,
end_before: datetime.datetime,
start_from: datetime.datetime | None = None,
batch_size: int = 1000,
dry_run: bool = False,
) -> None:
"""
Initialize the service with cleanup parameters.
Args:
policy: The policy that determines which messages to delete
end_before: End time (exclusive) of the range
start_from: Optional start time (inclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
"""
self._policy = policy
self._end_before = end_before
self._start_from = start_from
self._batch_size = batch_size
self._dry_run = dry_run
@classmethod
def from_time_range(
cls,
policy: MessagesCleanPolicy,
start_from: datetime.datetime,
end_before: datetime.datetime,
batch_size: int = 1000,
dry_run: bool = False,
) -> "MessagesCleanService":
"""
Create a service instance for cleaning messages within a specific time range.
Time range is [start_from, end_before).
Args:
policy: The policy that determines which messages to delete
start_from: Start time (inclusive) of the range
end_before: End time (exclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
MessagesCleanService instance
Raises:
ValueError: If start_from >= end_before or invalid parameters
"""
if start_from >= end_before:
raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
logger.info(
"clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s",
start_from,
end_before,
batch_size,
policy.__class__.__name__,
)
return cls(
policy=policy,
end_before=end_before,
start_from=start_from,
batch_size=batch_size,
dry_run=dry_run,
)
@classmethod
def from_days(
cls,
policy: MessagesCleanPolicy,
days: int = 30,
batch_size: int = 1000,
dry_run: bool = False,
) -> "MessagesCleanService":
"""
Create a service instance for cleaning messages older than specified days.
Args:
policy: The policy that determines which messages to delete
days: Number of days to look back from now
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
MessagesCleanService instance
Raises:
ValueError: If invalid parameters
"""
if days < 0:
raise ValueError(f"days ({days}) must be greater than or equal to 0")
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
end_before = datetime.datetime.now() - datetime.timedelta(days=days)
logger.info(
"clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
days,
end_before,
batch_size,
policy.__class__.__name__,
)
return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
def run(self) -> dict[str, int]:
"""
Execute the message cleanup operation.
Returns:
Dict with statistics: batches, filtered_messages, total_deleted
"""
return self._clean_messages_by_time_range()
def _clean_messages_by_time_range(self) -> dict[str, int]:
"""
Clean messages within a time range using cursor-based pagination.
Time range is [start_from, end_before)
Steps:
1. Iterate messages using cursor pagination (by created_at, id)
2. Query app_id -> tenant_id mapping
3. Delegate to policy to determine which messages to delete
4. Batch delete messages and their relations
Returns:
Dict with statistics: batches, filtered_messages, total_deleted
"""
stats = {
"batches": 0,
"total_messages": 0,
"filtered_messages": 0,
"total_deleted": 0,
}
# Cursor-based pagination using (created_at, id) to avoid infinite loops
# and ensure proper ordering with time-based filtering
_cursor: tuple[datetime.datetime, str] | None = None
logger.info(
"clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s",
self._dry_run,
self._start_from,
self._end_before,
)
while True:
stats["batches"] += 1
# Step 1: Fetch a batch of messages using cursor
with Session(db.engine, expire_on_commit=False) as session:
msg_stmt = (
select(Message.id, Message.app_id, Message.created_at)
.where(Message.created_at < self._end_before)
.order_by(Message.created_at, Message.id)
.limit(self._batch_size)
)
if self._start_from:
msg_stmt = msg_stmt.where(Message.created_at >= self._start_from)
# Apply cursor condition: (created_at, id) > (last_created_at, last_message_id)
# This translates to:
# created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id)
if _cursor:
# Continuing from previous batch
msg_stmt = msg_stmt.where(
(Message.created_at > _cursor[0])
| ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1]))
)
raw_messages = list(session.execute(msg_stmt).all())
messages = [
SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at)
for msg_id, app_id, msg_created_at in raw_messages
]
# Track total messages fetched across all batches
stats["total_messages"] += len(messages)
if not messages:
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
break
# Update cursor to the last message's (created_at, id)
_cursor = (messages[-1].created_at, messages[-1].id)
# Step 2: Extract app_ids and query tenant_ids
app_ids = list({msg.app_id for msg in messages})
if not app_ids:
logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"])
continue
app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids))
apps = list(session.execute(app_stmt).all())
if not apps:
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
continue
# Build app_id -> tenant_id mapping
app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps}
# Step 3: Delegate to policy to determine which messages to delete
message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant)
if not message_ids_to_delete:
logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
continue
stats["filtered_messages"] += len(message_ids_to_delete)
# Step 4: Batch delete messages and their relations
if not self._dry_run:
with Session(db.engine, expire_on_commit=False) as session:
# Delete related records first
self._batch_delete_message_relations(session, message_ids_to_delete)
# Delete messages
delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete))
delete_result = cast(CursorResult, session.execute(delete_stmt))
messages_deleted = delete_result.rowcount
session.commit()
stats["total_deleted"] += messages_deleted
logger.info(
"clean_messages (batch %s): processed %s messages, deleted %s messages",
stats["batches"],
len(messages),
messages_deleted,
)
else:
# Log random sample of message IDs that would be deleted (up to 10)
sample_size = min(10, len(message_ids_to_delete))
sampled_ids = random.sample(list(message_ids_to_delete), sample_size)
logger.info(
"clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:",
stats["batches"],
len(message_ids_to_delete),
sample_size,
)
for msg_id in sampled_ids:
logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
logger.info(
"clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
stats["batches"],
stats["total_messages"],
stats["filtered_messages"],
stats["total_deleted"],
)
return stats
@staticmethod
def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None:
"""
Batch delete all related records for given message IDs.
Args:
session: Database session
message_ids: List of message IDs to delete relations for
"""
if not message_ids:
return
# Delete all related records in batch
session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)))
session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids)))
session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids)))
session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids)))
session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids)))
session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids)))
session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids)))
session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids)))
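
Putting the pieces together, a typical invocation mirrors the rewritten clean_messages task; dry_run=True is useful for verifying scope before enabling deletion. This is only a sketch and assumes the Flask app/database context is already initialized:

```python
from services.retention.conversation.messages_clean_policy import create_message_clean_policy
from services.retention.conversation.messages_clean_service import MessagesCleanService

policy = create_message_clean_policy(graceful_period_days=21)
service = MessagesCleanService.from_days(
    policy=policy,
    days=30,
    batch_size=500,
    dry_run=True,  # log a sample of candidate message IDs instead of deleting
)
stats = service.run()
# stats keys: batches, total_messages, filtered_messages, total_deleted (0 in a dry run)
```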

View File

@ -103,6 +103,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
# Optional: override the local hostname used for SMTP HELO/EHLO
SMTP_LOCAL_HOSTNAME=
# Sentry configuration
SENTRY_DSN=

View File

@ -0,0 +1 @@
"""Unit tests for `controllers.console.datasets` controllers."""

View File

@ -0,0 +1,49 @@
from __future__ import annotations
"""
Unit tests for the external dataset controller payload schemas.
These tests focus on Pydantic validation rules so we can catch regressions
in request constraints (e.g. max length changes) without exercising the
full Flask/RESTX request stack.
"""
import pytest
from pydantic import ValidationError
from controllers.console.datasets.external import ExternalDatasetCreatePayload
def test_external_dataset_create_payload_allows_name_length_100() -> None:
"""Ensure the `name` field accepts up to 100 characters (inclusive)."""
# Build a request payload with a boundary-length name value.
name_100: str = "a" * 100
payload = {
"external_knowledge_api_id": "ek-api-1",
"external_knowledge_id": "ek-1",
"name": name_100,
}
model = ExternalDatasetCreatePayload.model_validate(payload)
assert model.name == name_100
def test_external_dataset_create_payload_rejects_name_length_101() -> None:
"""Ensure the `name` field rejects values longer than 100 characters."""
# Build a request payload that exceeds the max length by 1.
name_101: str = "a" * 101
payload: dict[str, object] = {
"external_knowledge_api_id": "ek-api-1",
"external_knowledge_id": "ek-1",
"name": name_101,
}
with pytest.raises(ValidationError) as exc_info:
ExternalDatasetCreatePayload.model_validate(payload)
errors = exc_info.value.errors()
assert errors[0]["loc"] == ("name",)
assert errors[0]["type"] == "string_too_long"
assert errors[0]["ctx"]["max_length"] == 100

View File

@ -0,0 +1,279 @@
"""Unit tests for PluginEndpointClient functionality.
This test module covers the endpoint client operations including:
- Successful endpoint deletion
- Idempotent delete behavior (record not found)
- Non-idempotent delete behavior (other errors)
Tests follow the Arrange-Act-Assert pattern for clarity.
"""
from unittest.mock import MagicMock, patch
import pytest
from core.plugin.impl.endpoint import PluginEndpointClient
from core.plugin.impl.exc import PluginDaemonInternalServerError
class TestPluginEndpointClientDelete:
"""Unit tests for PluginEndpointClient delete_endpoint operation.
Tests cover:
- Successful endpoint deletion
- Idempotent behavior when endpoint is already deleted (record not found)
- Non-idempotent behavior for other errors
"""
@pytest.fixture
def endpoint_client(self):
"""Create a PluginEndpointClient instance for testing."""
return PluginEndpointClient()
@pytest.fixture
def mock_config(self):
"""Mock plugin daemon configuration."""
with (
patch("core.plugin.impl.base.dify_config.PLUGIN_DAEMON_URL", "http://127.0.0.1:5002"),
patch("core.plugin.impl.base.dify_config.PLUGIN_DAEMON_KEY", "test-api-key"),
):
yield
def test_delete_endpoint_success(self, endpoint_client, mock_config):
"""Test successful endpoint deletion.
Given:
- A valid tenant_id, user_id, and endpoint_id
- The plugin daemon returns success response
When:
- delete_endpoint is called
Then:
- The method should return True
- The request should be made with correct parameters
"""
# Arrange
tenant_id = "tenant-123"
user_id = "user-456"
endpoint_id = "endpoint-789"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"code": 0,
"message": "success",
"data": True,
}
with patch("httpx.request", return_value=mock_response):
# Act
result = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Assert
assert result is True
def test_delete_endpoint_idempotent_record_not_found(self, endpoint_client, mock_config):
"""Test idempotent delete behavior when endpoint is already deleted.
Given:
- A valid tenant_id, user_id, and endpoint_id
- The plugin daemon returns "record not found" error
When:
- delete_endpoint is called
Then:
- The method should return True (idempotent behavior)
- No exception should be raised
"""
# Arrange
tenant_id = "tenant-123"
user_id = "user-456"
endpoint_id = "endpoint-789"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"code": -1,
"message": (
'{"error_type": "PluginDaemonInternalServerError", '
'"message": "failed to remove endpoint: record not found"}'
),
}
with patch("httpx.request", return_value=mock_response):
# Act
result = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Assert - should return True instead of raising an error
assert result is True
def test_delete_endpoint_non_idempotent_other_errors(self, endpoint_client, mock_config):
"""Test non-idempotent delete behavior for other errors.
Given:
- A valid tenant_id, user_id, and endpoint_id
- The plugin daemon returns a different error (not "record not found")
When:
- delete_endpoint is called
Then:
- The method should raise PluginDaemonInternalServerError
"""
# Arrange
tenant_id = "tenant-123"
user_id = "user-456"
endpoint_id = "endpoint-789"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"code": -1,
"message": (
'{"error_type": "PluginDaemonInternalServerError", '
'"message": "failed to remove endpoint: internal server error"}'
),
}
with patch("httpx.request", return_value=mock_response):
# Act & Assert
with pytest.raises(PluginDaemonInternalServerError) as exc_info:
endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Assert - the error message should not be "record not found"
assert "record not found" not in str(exc_info.value.description)
def test_delete_endpoint_idempotent_case_insensitive(self, endpoint_client, mock_config):
"""Test idempotent delete behavior with case-insensitive error message.
Given:
- A valid tenant_id, user_id, and endpoint_id
- The plugin daemon returns "Record Not Found" error (different case)
When:
- delete_endpoint is called
Then:
- The method should return True (idempotent behavior)
"""
# Arrange
tenant_id = "tenant-123"
user_id = "user-456"
endpoint_id = "endpoint-789"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"code": -1,
"message": '{"error_type": "PluginDaemonInternalServerError", "message": "Record Not Found"}',
}
with patch("httpx.request", return_value=mock_response):
# Act
result = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Assert - should still return True
assert result is True
def test_delete_endpoint_multiple_calls_idempotent(self, endpoint_client, mock_config):
"""Test that multiple delete calls are idempotent.
Given:
- A valid tenant_id, user_id, and endpoint_id
- The first call succeeds
- Subsequent calls return "record not found"
When:
- delete_endpoint is called multiple times
Then:
- All calls should return True
"""
# Arrange
tenant_id = "tenant-123"
user_id = "user-456"
endpoint_id = "endpoint-789"
# First call - success
mock_response_success = MagicMock()
mock_response_success.status_code = 200
mock_response_success.json.return_value = {
"code": 0,
"message": "success",
"data": True,
}
# Second call - record not found
mock_response_not_found = MagicMock()
mock_response_not_found.status_code = 200
mock_response_not_found.json.return_value = {
"code": -1,
"message": (
'{"error_type": "PluginDaemonInternalServerError", '
'"message": "failed to remove endpoint: record not found"}'
),
}
with patch("httpx.request") as mock_request:
# Act - first call
mock_request.return_value = mock_response_success
result1 = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Act - second call (already deleted)
mock_request.return_value = mock_response_not_found
result2 = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Assert - both should return True
assert result1 is True
assert result2 is True
def test_delete_endpoint_non_idempotent_unauthorized_error(self, endpoint_client, mock_config):
"""Test that authorization errors are not treated as idempotent.
Given:
- A valid tenant_id, user_id, and endpoint_id
- The plugin daemon returns an unauthorized error
When:
- delete_endpoint is called
Then:
- The method should raise the appropriate error (not return True)
"""
# Arrange
tenant_id = "tenant-123"
user_id = "user-456"
endpoint_id = "endpoint-789"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"code": -1,
"message": '{"error_type": "PluginDaemonUnauthorizedError", "message": "unauthorized access"}',
}
with patch("httpx.request", return_value=mock_response):
# Act & Assert
with pytest.raises(Exception) as exc_info:
endpoint_client.delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
)
# Assert - should not return True for unauthorized errors
assert exc_info.value.__class__.__name__ == "PluginDaemonUnauthorizedError"
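
The behavior pinned down above can be summarized as: a daemon error whose message contains "record not found" (case-insensitive) is treated as an already-deleted endpoint, while every other error propagates. A hedged sketch of that rule, with `_call_daemon_delete` standing in for the real HTTP request helper.

```python
# Sketch only: the real PluginEndpointClient request plumbing is not shown,
# so the daemon call is passed in as a callable.
from core.plugin.impl.exc import PluginDaemonInternalServerError

def delete_endpoint_idempotent(_call_daemon_delete) -> bool:
    try:
        return _call_daemon_delete()
    except PluginDaemonInternalServerError as e:
        # "record not found" means the endpoint is already gone: succeed quietly.
        if "record not found" in str(e.description).lower():
            return True
        raise  # any other daemon error remains non-idempotent
```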

View File

@ -0,0 +1 @@
"""LogStore extension unit tests."""

View File

@ -0,0 +1,469 @@
"""
Unit tests for SQL escape utility functions.
These tests ensure that SQL injection attacks are properly prevented
in LogStore queries, particularly for cross-tenant access scenarios.
"""
import pytest
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string
class TestEscapeSQLString:
"""Test escape_sql_string function."""
def test_escape_empty_string(self):
"""Test escaping empty string."""
assert escape_sql_string("") == ""
def test_escape_normal_string(self):
"""Test escaping string without special characters."""
assert escape_sql_string("tenant_abc123") == "tenant_abc123"
assert escape_sql_string("app-uuid-1234") == "app-uuid-1234"
def test_escape_single_quote(self):
"""Test escaping single quote."""
# Single quote should be doubled
assert escape_sql_string("tenant'id") == "tenant''id"
assert escape_sql_string("O'Reilly") == "O''Reilly"
def test_escape_multiple_quotes(self):
"""Test escaping multiple single quotes."""
assert escape_sql_string("a'b'c") == "a''b''c"
assert escape_sql_string("'''") == "''''''"
# === SQL Injection Attack Scenarios ===
def test_prevent_boolean_injection(self):
"""Test prevention of boolean injection attacks."""
# Classic OR 1=1 attack
malicious_input = "tenant' OR '1'='1"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' OR ''1''=''1"
# When used in SQL, this becomes a safe string literal
sql = f"WHERE tenant_id='{escaped}'"
assert sql == "WHERE tenant_id='tenant'' OR ''1''=''1'"
# The entire input is now a string literal that won't match any tenant
def test_prevent_or_injection(self):
"""Test prevention of OR-based injection."""
malicious_input = "tenant_a' OR tenant_id='tenant_b"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant_a'' OR tenant_id=''tenant_b"
sql = f"WHERE tenant_id='{escaped}'"
# The OR is now part of the string literal, not SQL logic
assert "OR tenant_id=" in sql
# The SQL has: opening ', doubled internal quotes '', and closing '
assert sql == "WHERE tenant_id='tenant_a'' OR tenant_id=''tenant_b'"
def test_prevent_union_injection(self):
"""Test prevention of UNION-based injection."""
malicious_input = "xxx' UNION SELECT password FROM users WHERE '1'='1"
escaped = escape_sql_string(malicious_input)
assert escaped == "xxx'' UNION SELECT password FROM users WHERE ''1''=''1"
# UNION becomes part of the string literal
assert "UNION" in escaped
assert escaped.count("''") == 4 # All internal quotes are doubled
def test_prevent_comment_injection(self):
"""Test prevention of comment-based injection."""
# SQL comment to bypass remaining conditions
malicious_input = "tenant' --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' --"
sql = f"WHERE tenant_id='{escaped}' AND deleted=false"
# The -- is now inside the string, not a SQL comment
assert "--" in sql
assert "AND deleted=false" in sql # This part is NOT commented out
def test_prevent_semicolon_injection(self):
"""Test prevention of semicolon-based multi-statement injection."""
malicious_input = "tenant'; DROP TABLE users; --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant''; DROP TABLE users; --"
# Semicolons and DROP are now part of the string
assert "DROP TABLE" in escaped
def test_prevent_time_based_blind_injection(self):
"""Test prevention of time-based blind SQL injection."""
malicious_input = "tenant' AND SLEEP(5) --"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' AND SLEEP(5) --"
# SLEEP becomes part of the string
assert "SLEEP" in escaped
def test_prevent_wildcard_injection(self):
"""Test prevention of wildcard-based injection."""
malicious_input = "tenant' OR tenant_id LIKE '%"
escaped = escape_sql_string(malicious_input)
assert escaped == "tenant'' OR tenant_id LIKE ''%"
# The LIKE and wildcard are now part of the string
assert "LIKE" in escaped
def test_prevent_null_byte_injection(self):
"""Test handling of null bytes."""
# Null bytes can sometimes bypass filters
malicious_input = "tenant\x00' OR '1'='1"
escaped = escape_sql_string(malicious_input)
# Null byte is preserved, but quote is escaped
assert "''1''=''1" in escaped
# === Real-world SAAS Scenarios ===
def test_cross_tenant_access_attempt(self):
"""Test prevention of cross-tenant data access."""
# Attacker tries to access another tenant's data
attacker_input = "tenant_b' OR tenant_id='tenant_a"
escaped = escape_sql_string(attacker_input)
sql = f"SELECT * FROM workflow_runs WHERE tenant_id='{escaped}'"
# The query will look for a tenant literally named "tenant_b' OR tenant_id='tenant_a"
# which doesn't exist - preventing access to either tenant's data
assert "tenant_b'' OR tenant_id=''tenant_a" in sql
def test_cross_app_access_attempt(self):
"""Test prevention of cross-application data access."""
attacker_input = "app1' OR app_id='app2"
escaped = escape_sql_string(attacker_input)
sql = f"WHERE app_id='{escaped}'"
# Cannot access app2's data
assert "app1'' OR app_id=''app2" in sql
def test_bypass_status_filter(self):
"""Test prevention of bypassing status filters."""
# Try to see all statuses instead of just 'running'
attacker_input = "running' OR status LIKE '%"
escaped = escape_sql_string(attacker_input)
sql = f"WHERE status='{escaped}'"
# Status condition is not bypassed
assert "running'' OR status LIKE ''%" in sql
# === Edge Cases ===
def test_escape_only_quotes(self):
"""Test string with only quotes."""
assert escape_sql_string("'") == "''"
assert escape_sql_string("''") == "''''"
def test_escape_mixed_content(self):
"""Test string with mixed quotes and other chars."""
input_str = "It's a 'test' of O'Reilly's code"
escaped = escape_sql_string(input_str)
assert escaped == "It''s a ''test'' of O''Reilly''s code"
def test_escape_unicode_with_quotes(self):
"""Test Unicode strings with quotes."""
input_str = "租户' OR '1'='1"
escaped = escape_sql_string(input_str)
assert escaped == "租户'' OR ''1''=''1"
class TestEscapeIdentifier:
"""Test escape_identifier function."""
def test_escape_uuid(self):
"""Test escaping UUID identifiers."""
uuid = "550e8400-e29b-41d4-a716-446655440000"
assert escape_identifier(uuid) == uuid
def test_escape_alphanumeric_id(self):
"""Test escaping alphanumeric identifiers."""
assert escape_identifier("tenant_123") == "tenant_123"
assert escape_identifier("app-abc-123") == "app-abc-123"
def test_escape_identifier_with_quote(self):
"""Test escaping identifier with single quote."""
malicious = "tenant' OR '1'='1"
escaped = escape_identifier(malicious)
assert escaped == "tenant'' OR ''1''=''1"
def test_identifier_injection_attempt(self):
"""Test prevention of injection through identifiers."""
# Common identifier injection patterns
test_cases = [
("id' OR '1'='1", "id'' OR ''1''=''1"),
("id'; DROP TABLE", "id''; DROP TABLE"),
("id' UNION SELECT", "id'' UNION SELECT"),
]
for malicious, expected in test_cases:
assert escape_identifier(malicious) == expected
class TestSQLInjectionIntegration:
"""Integration tests simulating real SQL construction scenarios."""
def test_complete_where_clause_safety(self):
"""Test that a complete WHERE clause is safe from injection."""
# Simulating typical query construction
tenant_id = "tenant' OR '1'='1"
app_id = "app' UNION SELECT"
run_id = "run' --"
escaped_tenant = escape_identifier(tenant_id)
escaped_app = escape_identifier(app_id)
escaped_run = escape_identifier(run_id)
sql = f"""
SELECT * FROM workflow_runs
WHERE tenant_id='{escaped_tenant}'
AND app_id='{escaped_app}'
AND id='{escaped_run}'
"""
# Verify all special characters are escaped
assert "tenant'' OR ''1''=''1" in sql
assert "app'' UNION SELECT" in sql
assert "run'' --" in sql
# Verify SQL structure is preserved (3 conditions with AND)
assert sql.count("AND") == 2
def test_multiple_conditions_with_injection_attempts(self):
"""Test multiple conditions all attempting injection."""
conditions = {
"tenant_id": "t1' OR tenant_id='t2",
"app_id": "a1' OR app_id='a2",
"status": "running' OR '1'='1",
}
where_parts = []
for field, value in conditions.items():
escaped = escape_sql_string(value)
where_parts.append(f"{field}='{escaped}'")
where_clause = " AND ".join(where_parts)
# All injection attempts are neutralized
assert "t1'' OR tenant_id=''t2" in where_clause
assert "a1'' OR app_id=''a2" in where_clause
assert "running'' OR ''1''=''1" in where_clause
# AND structure is preserved
assert where_clause.count(" AND ") == 2
@pytest.mark.parametrize(
("attack_vector", "description"),
[
("' OR '1'='1", "Boolean injection"),
("' OR '1'='1' --", "Boolean with comment"),
("' UNION SELECT * FROM users --", "Union injection"),
("'; DROP TABLE workflow_runs; --", "Destructive command"),
("' AND SLEEP(10) --", "Time-based blind"),
("' OR tenant_id LIKE '%", "Wildcard injection"),
("admin' --", "Comment bypass"),
("' OR 1=1 LIMIT 1 --", "Limit bypass"),
],
)
def test_common_injection_vectors(self, attack_vector, description):
"""Test protection against common injection attack vectors."""
escaped = escape_sql_string(attack_vector)
# Build SQL
sql = f"WHERE tenant_id='{escaped}'"
# Verify the attack string is now a safe literal
# The key indicator: all internal single quotes are doubled
internal_quotes = escaped.count("''")
original_quotes = attack_vector.count("'")
# Each original quote should be doubled
assert internal_quotes == original_quotes
# Verify SQL keeps its surrounding quotes (opening and closing)
assert sql.count("'") >= 2 # At least opening and closing
def test_logstore_specific_scenario(self):
"""Test SQL injection prevention in LogStore-specific scenarios."""
# Simulate LogStore query with window function
tenant_id = "tenant' OR '1'='1"
app_id = "app' UNION SELECT"
escaped_tenant = escape_identifier(tenant_id)
escaped_app = escape_identifier(app_id)
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM workflow_execution_logstore
WHERE tenant_id='{escaped_tenant}'
AND app_id='{escaped_app}'
AND __time__ > 0
) AS subquery WHERE rn = 1
"""
# Complex query structure is maintained
assert "ROW_NUMBER()" in sql
assert "PARTITION BY id" in sql
# Injection attempts are escaped
assert "tenant'' OR ''1''=''1" in sql
assert "app'' UNION SELECT" in sql
# ====================================================================================
# Tests for LogStore Query Syntax (SDK Mode)
# ====================================================================================
class TestLogStoreQueryEscape:
"""Test escape_logstore_query_value for SDK mode query syntax."""
def test_normal_value(self):
"""Test escaping normal alphanumeric value."""
value = "550e8400-e29b-41d4-a716-446655440000"
escaped = escape_logstore_query_value(value)
# Should be wrapped in double quotes
assert escaped == '"550e8400-e29b-41d4-a716-446655440000"'
def test_empty_value(self):
"""Test escaping empty string."""
assert escape_logstore_query_value("") == '""'
def test_value_with_and_keyword(self):
"""Test that 'and' keyword is neutralized when quoted."""
malicious = "value and field:evil"
escaped = escape_logstore_query_value(malicious)
# Should be wrapped in quotes, making 'and' a literal
assert escaped == '"value and field:evil"'
# Simulate using in query
query = f"tenant_id:{escaped}"
assert query == 'tenant_id:"value and field:evil"'
def test_value_with_or_keyword(self):
"""Test that 'or' keyword is neutralized when quoted."""
malicious = "tenant_a or tenant_id:tenant_b"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"tenant_a or tenant_id:tenant_b"'
query = f"tenant_id:{escaped}"
assert "or" in query # Present but as literal string
def test_value_with_not_keyword(self):
"""Test that 'not' keyword is neutralized when quoted."""
malicious = "not field:value"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"not field:value"'
def test_value_with_parentheses(self):
"""Test that parentheses are neutralized when quoted."""
malicious = "(tenant_a or tenant_b)"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"(tenant_a or tenant_b)"'
assert "(" in escaped # Present as literal
assert ")" in escaped # Present as literal
def test_value_with_colon(self):
"""Test that colons are neutralized when quoted."""
malicious = "field:value"
escaped = escape_logstore_query_value(malicious)
assert escaped == '"field:value"'
assert ":" in escaped # Present as literal
def test_value_with_double_quotes(self):
"""Test that internal double quotes are escaped."""
value_with_quotes = 'tenant"test"value'
escaped = escape_logstore_query_value(value_with_quotes)
# Double quotes should be escaped with backslash
assert escaped == '"tenant\\"test\\"value"'
# Should have outer quotes plus escaped inner quotes
assert '\\"' in escaped
def test_value_with_backslash(self):
"""Test that backslashes are escaped."""
value_with_backslash = "tenant\\test"
escaped = escape_logstore_query_value(value_with_backslash)
# Backslash should be escaped
assert escaped == '"tenant\\\\test"'
assert "\\\\" in escaped
def test_value_with_backslash_and_quote(self):
"""Test escaping both backslash and double quote."""
value = 'path\\to\\"file"'
escaped = escape_logstore_query_value(value)
# Both should be escaped
assert escaped == '"path\\\\to\\\\\\"file\\""'
# Verify escape order is correct
assert "\\\\" in escaped # Escaped backslash
assert '\\"' in escaped # Escaped double quote
def test_complex_injection_attempt(self):
"""Test complex injection combining multiple operators."""
malicious = 'tenant_a" or (tenant_id:"tenant_b" and app_id:"evil")'
escaped = escape_logstore_query_value(malicious)
# All special chars should be literals or escaped
assert escaped.startswith('"')
assert escaped.endswith('"')
# Inner double quotes escaped, operators become literals
assert "or" in escaped
assert "and" in escaped
assert '\\"' in escaped # Escaped quotes
def test_only_backslash(self):
"""Test escaping a single backslash."""
assert escape_logstore_query_value("\\") == '"\\\\"'
def test_only_double_quote(self):
"""Test escaping a single double quote."""
assert escape_logstore_query_value('"') == '"\\""'
def test_multiple_backslashes(self):
"""Test escaping multiple consecutive backslashes."""
assert escape_logstore_query_value("\\\\\\") == '"\\\\\\\\\\\\"' # 3 backslashes -> 6
def test_escape_sequence_like_input(self):
"""Test that existing escape sequences are properly escaped."""
# Input looks like already escaped, but we still escape it
value = 'value\\"test'
escaped = escape_logstore_query_value(value)
# \\ -> \\\\, " -> \"
assert escaped == '"value\\\\\\"test"'
@pytest.mark.parametrize(
("attack_scenario", "field", "malicious_value"),
[
("Cross-tenant via OR", "tenant_id", "tenant_a or tenant_id:tenant_b"),
("Cross-app via AND", "app_id", "app_a and (app_id:app_b or app_id:app_c)"),
("Boolean logic", "status", "succeeded or status:failed"),
("Negation", "tenant_id", "not tenant_a"),
("Field injection", "run_id", "run123 and tenant_id:evil_tenant"),
("Parentheses grouping", "app_id", "app1 or (app_id:app2 and tenant_id:tenant2)"),
("Quote breaking attempt", "tenant_id", 'tenant" or "1"="1'),
("Backslash escape bypass", "app_id", "app\\ and app_id:evil"),
],
)
def test_logstore_query_injection_scenarios(attack_scenario: str, field: str, malicious_value: str):
"""Test that various LogStore query injection attempts are neutralized."""
escaped = escape_logstore_query_value(malicious_value)
# Build query
query = f"{field}:{escaped}"
# All operators should be within quoted string (literals)
assert escaped.startswith('"')
assert escaped.endswith('"')
# Verify the full query structure is safe
assert query.count(":") >= 1 # At least the main field:value separator
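
The escaping rules these tests encode are small enough to restate. A reference sketch consistent with the assertions above (per the tests, `escape_identifier` applies the same quote doubling as `escape_sql_string`); the shipped functions in `extensions.logstore.sql_escape` may add further validation.

```python
def escape_sql_string_sketch(value: str) -> str:
    # Double every single quote so the value stays a plain string literal.
    return value.replace("'", "''")

def escape_logstore_query_value_sketch(value: str) -> str:
    # Escape backslashes first, then double quotes, then wrap in quotes so
    # operators such as and/or/not, colons and parentheses become literals.
    escaped = value.replace("\\", "\\\\").replace('"', '\\"')
    return f'"{escaped}"'
```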

View File

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
@ -17,7 +17,7 @@ def test_smtp_plain_success(mock_smtp_cls: MagicMock):
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
client.send(_mail())
mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10)
mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10, local_hostname=ANY)
mock_smtp.sendmail.assert_called_once()
mock_smtp.quit.assert_called_once()
@ -38,7 +38,7 @@ def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
)
client.send(_mail())
mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10)
mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10, local_hostname=ANY)
assert mock_smtp.ehlo.call_count == 2
mock_smtp.starttls.assert_called_once()
mock_smtp.login.assert_called_once_with("user", "pass")

View File

@ -0,0 +1,627 @@
import datetime
from unittest.mock import MagicMock, patch
import pytest
from enums.cloud_plan import CloudPlan
from services.retention.conversation.messages_clean_policy import (
BillingDisabledPolicy,
BillingSandboxPolicy,
SimpleMessage,
create_message_clean_policy,
)
from services.retention.conversation.messages_clean_service import MessagesCleanService
def make_simple_message(msg_id: str, app_id: str) -> SimpleMessage:
"""Helper to create a SimpleMessage with a fixed created_at timestamp."""
return SimpleMessage(id=msg_id, app_id=app_id, created_at=datetime.datetime(2024, 1, 1))
def make_plan_provider(tenant_plans: dict) -> MagicMock:
"""Helper to create a mock plan_provider that returns the given tenant_plans."""
provider = MagicMock()
provider.return_value = tenant_plans
return provider
class TestBillingSandboxPolicyFilterMessageIds:
"""Unit tests for BillingSandboxPolicy.filter_message_ids method."""
# Fixed timestamp for deterministic tests
CURRENT_TIMESTAMP = 1000000
GRACEFUL_PERIOD_DAYS = 8
GRACEFUL_PERIOD_SECONDS = GRACEFUL_PERIOD_DAYS * 24 * 60 * 60
def test_missing_tenant_mapping_excluded(self):
"""Test that messages with missing app-to-tenant mapping are excluded."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant = {} # No mapping
tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert
assert list(result) == []
def test_missing_tenant_plan_excluded(self):
"""Test that messages with missing tenant plan are excluded (safe default)."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"}
tenant_plans = {} # No plans
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert
assert list(result) == []
def test_non_sandbox_plan_excluded(self):
"""Test that messages from non-sandbox plans (PROFESSIONAL/TEAM) are excluded."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
make_simple_message("msg3", "app3"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.TEAM, "expiration_date": -1},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, # Only this one
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - only msg3 (sandbox tenant) should be included
assert set(result) == {"msg3"}
def test_whitelist_skip(self):
"""Test that whitelisted tenants are excluded even if sandbox + expired."""
# Arrange
messages = [
make_simple_message("msg1", "app1"), # Whitelisted - excluded
make_simple_message("msg2", "app2"), # Not whitelisted - included
make_simple_message("msg3", "app3"), # Whitelisted - excluded
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
}
plan_provider = make_plan_provider(tenant_plans)
tenant_whitelist = ["tenant1", "tenant3"]
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
tenant_whitelist=tenant_whitelist,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - only msg2 should be included
assert set(result) == {"msg2"}
def test_no_previous_subscription_included(self):
"""Test that messages with expiration_date=-1 (no previous subscription) are included."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - all messages should be included
assert set(result) == {"msg1", "msg2"}
def test_within_grace_period_excluded(self):
"""Test that messages within grace period are excluded."""
# Arrange
now = self.CURRENT_TIMESTAMP
expired_1_day_ago = now - (1 * 24 * 60 * 60)
expired_5_days_ago = now - (5 * 24 * 60 * 60)
expired_7_days_ago = now - (7 * 24 * 60 * 60)
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
make_simple_message("msg3", "app3"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_1_day_ago},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_5_days_ago},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_7_days_ago},
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS, # 8 days
current_timestamp=now,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - all within 8-day grace period, none should be included
assert list(result) == []
def test_exactly_at_boundary_excluded(self):
"""Test that messages exactly at grace period boundary are excluded (code uses >)."""
# Arrange
now = self.CURRENT_TIMESTAMP
expired_exactly_8_days_ago = now - self.GRACEFUL_PERIOD_SECONDS # Exactly at boundary
messages = [make_simple_message("msg1", "app1")]
app_to_tenant = {"app1": "tenant1"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_exactly_8_days_ago},
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=now,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - exactly at boundary (==) should be excluded (code uses >)
assert list(result) == []
def test_beyond_grace_period_included(self):
"""Test that messages beyond grace period are included."""
# Arrange
now = self.CURRENT_TIMESTAMP
expired_9_days_ago = now - (9 * 24 * 60 * 60) # Just beyond 8-day grace
expired_30_days_ago = now - (30 * 24 * 60 * 60) # Well beyond
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_9_days_ago},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_30_days_ago},
}
plan_provider = make_plan_provider(tenant_plans)
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=now,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - both beyond grace period, should be included
assert set(result) == {"msg1", "msg2"}
def test_empty_messages_returns_empty(self):
"""Test that empty messages returns empty list."""
# Arrange
messages: list[SimpleMessage] = []
app_to_tenant = {"app1": "tenant1"}
plan_provider = make_plan_provider({"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}})
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert
assert list(result) == []
def test_plan_provider_called_with_correct_tenant_ids(self):
"""Test that plan_provider is called with correct tenant_ids."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
make_simple_message("msg3", "app3"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant1"} # tenant1 appears twice
plan_provider = make_plan_provider({})
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
current_timestamp=self.CURRENT_TIMESTAMP,
)
# Act
policy.filter_message_ids(messages, app_to_tenant)
# Assert - plan_provider should be called once with unique tenant_ids
plan_provider.assert_called_once()
called_tenant_ids = set(plan_provider.call_args[0][0])
assert called_tenant_ids == {"tenant1", "tenant2"}
def test_complex_mixed_scenario(self):
"""Test complex scenario with mixed plans, expirations, whitelist, and missing mappings."""
# Arrange
now = self.CURRENT_TIMESTAMP
sandbox_expired_old = now - (15 * 24 * 60 * 60) # Beyond grace
sandbox_expired_recent = now - (3 * 24 * 60 * 60) # Within grace
future_expiration = now + (30 * 24 * 60 * 60)
messages = [
make_simple_message("msg1", "app1"), # Sandbox, no subscription - included
make_simple_message("msg2", "app2"), # Sandbox, expired old - included
make_simple_message("msg3", "app3"), # Sandbox, within grace - excluded
make_simple_message("msg4", "app4"), # Team plan, active - excluded
make_simple_message("msg5", "app5"), # No tenant mapping - excluded
make_simple_message("msg6", "app6"), # No plan info - excluded
make_simple_message("msg7", "app7"), # Sandbox, expired old, whitelisted - excluded
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
"app4": "tenant4",
"app6": "tenant6", # Has mapping but no plan
"app7": "tenant7",
# app5 has no mapping
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_recent},
"tenant4": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration},
"tenant7": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old},
# tenant6 has no plan
}
plan_provider = make_plan_provider(tenant_plans)
tenant_whitelist = ["tenant7"]
policy = BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=self.GRACEFUL_PERIOD_DAYS,
tenant_whitelist=tenant_whitelist,
current_timestamp=now,
)
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - only msg1 and msg2 should be included
assert set(result) == {"msg1", "msg2"}
class TestBillingDisabledPolicyFilterMessageIds:
"""Unit tests for BillingDisabledPolicy.filter_message_ids method."""
def test_returns_all_message_ids(self):
"""Test that all message IDs are returned (order-preserving)."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
make_simple_message("msg3", "app3"),
]
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"}
policy = BillingDisabledPolicy()
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - all message IDs returned in order
assert list(result) == ["msg1", "msg2", "msg3"]
def test_ignores_app_to_tenant(self):
"""Test that app_to_tenant mapping is ignored."""
# Arrange
messages = [
make_simple_message("msg1", "app1"),
make_simple_message("msg2", "app2"),
]
app_to_tenant: dict[str, str] = {} # Empty - should be ignored
policy = BillingDisabledPolicy()
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert - all message IDs still returned
assert list(result) == ["msg1", "msg2"]
def test_empty_messages_returns_empty(self):
"""Test that empty messages returns empty list."""
# Arrange
messages: list[SimpleMessage] = []
app_to_tenant = {"app1": "tenant1"}
policy = BillingDisabledPolicy()
# Act
result = policy.filter_message_ids(messages, app_to_tenant)
# Assert
assert list(result) == []
class TestCreateMessageCleanPolicy:
"""Unit tests for create_message_clean_policy factory function."""
@patch("services.retention.conversation.messages_clean_policy.dify_config")
def test_billing_disabled_returns_billing_disabled_policy(self, mock_config):
"""Test that BILLING_ENABLED=False returns BillingDisabledPolicy."""
# Arrange
mock_config.BILLING_ENABLED = False
# Act
policy = create_message_clean_policy(graceful_period_days=21)
# Assert
assert isinstance(policy, BillingDisabledPolicy)
@patch("services.retention.conversation.messages_clean_policy.BillingService")
@patch("services.retention.conversation.messages_clean_policy.dify_config")
def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service):
"""Test that BillingSandboxPolicy is created with correct internal values."""
# Arrange
mock_config.BILLING_ENABLED = True
whitelist = ["tenant1", "tenant2"]
mock_billing_service.get_expired_subscription_cleanup_whitelist.return_value = whitelist
mock_plan_provider = MagicMock()
mock_billing_service.get_plan_bulk_with_cache = mock_plan_provider
# Act
policy = create_message_clean_policy(graceful_period_days=14, current_timestamp=1234567)
# Assert
mock_billing_service.get_expired_subscription_cleanup_whitelist.assert_called_once()
assert isinstance(policy, BillingSandboxPolicy)
assert policy._graceful_period_days == 14
assert list(policy._tenant_whitelist) == whitelist
assert policy._plan_provider == mock_plan_provider
assert policy._current_timestamp == 1234567
class TestMessagesCleanServiceFromTimeRange:
"""Unit tests for MessagesCleanService.from_time_range factory method."""
def test_start_from_end_before_raises_value_error(self):
"""Test that start_from == end_before raises ValueError."""
policy = BillingDisabledPolicy()
# Arrange
same_time = datetime.datetime(2024, 1, 1, 12, 0, 0)
# Act & Assert
with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
MessagesCleanService.from_time_range(
policy=policy,
start_from=same_time,
end_before=same_time,
)
# Arrange
start_from = datetime.datetime(2024, 12, 31)
end_before = datetime.datetime(2024, 1, 1)
# Act & Assert
with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
)
def test_batch_size_raises_value_error(self):
"""Test that batch_size=0 raises ValueError."""
# Arrange
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
policy = BillingDisabledPolicy()
# Act & Assert
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=0,
)
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
policy = BillingDisabledPolicy()
# Act & Assert
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=-100,
)
def test_valid_params_creates_instance(self):
"""Test that valid parameters create a correctly configured instance."""
# Arrange
start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
end_before = datetime.datetime(2024, 12, 31, 23, 59, 59)
policy = BillingDisabledPolicy()
batch_size = 500
dry_run = True
# Act
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
# Assert
assert isinstance(service, MessagesCleanService)
assert service._policy is policy
assert service._start_from == start_from
assert service._end_before == end_before
assert service._batch_size == batch_size
assert service._dry_run == dry_run
def test_default_params(self):
"""Test that default parameters are applied correctly."""
# Arrange
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
policy = BillingDisabledPolicy()
# Act
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
)
# Assert
assert service._batch_size == 1000 # default
assert service._dry_run is False # default
class TestMessagesCleanServiceFromDays:
"""Unit tests for MessagesCleanService.from_days factory method."""
def test_days_raises_value_error(self):
"""Test that days < 0 raises ValueError."""
# Arrange
policy = BillingDisabledPolicy()
# Act & Assert
with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"):
MessagesCleanService.from_days(policy=policy, days=-1)
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
service = MessagesCleanService.from_days(policy=policy, days=0)
# Assert
assert service._end_before == fixed_now
def test_batch_size_raises_value_error(self):
"""Test that batch_size=0 raises ValueError."""
# Arrange
policy = BillingDisabledPolicy()
# Act & Assert
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
MessagesCleanService.from_days(policy=policy, days=30, batch_size=0)
# Act & Assert
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
MessagesCleanService.from_days(policy=policy, days=30, batch_size=-500)
def test_valid_params_creates_instance(self):
"""Test that valid parameters create a correctly configured instance."""
# Arrange
policy = BillingDisabledPolicy()
days = 90
batch_size = 500
dry_run = True
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
service = MessagesCleanService.from_days(
policy=policy,
days=days,
batch_size=batch_size,
dry_run=dry_run,
)
# Assert
expected_end_before = fixed_now - datetime.timedelta(days=days)
assert isinstance(service, MessagesCleanService)
assert service._policy is policy
assert service._start_from is None
assert service._end_before == expected_end_before
assert service._batch_size == batch_size
assert service._dry_run == dry_run
def test_default_params(self):
"""Test that default parameters are applied correctly."""
# Arrange
policy = BillingDisabledPolicy()
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
service = MessagesCleanService.from_days(policy=policy)
# Assert
expected_end_before = fixed_now - datetime.timedelta(days=30) # default days=30
assert service._end_before == expected_end_before
assert service._batch_size == 1000 # default
assert service._dry_run is False # default
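
The sandbox filtering decision that `BillingSandboxPolicy` is tested against can be written out as a plain predicate. A hedged sketch of that per-message decision, not the policy's actual method; `plan_info` mirrors the dicts the mocked `plan_provider` returns.

```python
from enums.cloud_plan import CloudPlan

def should_delete_sketch(
    tenant_id: str | None,
    plan_info: dict | None,
    whitelist: set[str],
    current_timestamp: int,
    graceful_period_days: int,
) -> bool:
    # Missing app-to-tenant mapping or missing plan info: keep the message.
    if tenant_id is None or plan_info is None:
        return False
    # Whitelisted tenants and non-sandbox plans are never cleaned.
    if tenant_id in whitelist or plan_info["plan"] != CloudPlan.SANDBOX:
        return False
    expiration = plan_info["expiration_date"]
    if expiration == -1:
        return True  # never had a paid subscription: eligible immediately
    # Strictly beyond the grace period; the boundary itself is excluded.
    return (current_timestamp - expiration) > graceful_period_days * 24 * 60 * 60
```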

View File

@ -9,7 +9,7 @@ This module tests the mail sending functionality including:
"""
import smtplib
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
@ -151,7 +151,7 @@ class TestSMTPIntegration:
client.send(mail_data)
# Assert
mock_smtp_ssl.assert_called_once_with("smtp.example.com", 465, timeout=10)
mock_smtp_ssl.assert_called_once_with("smtp.example.com", 465, timeout=10, local_hostname=ANY)
mock_server.login.assert_called_once_with("user@example.com", "password123")
mock_server.sendmail.assert_called_once()
mock_server.quit.assert_called_once()
@ -181,7 +181,7 @@ class TestSMTPIntegration:
client.send(mail_data)
# Assert
mock_smtp.assert_called_once_with("smtp.example.com", 587, timeout=10)
mock_smtp.assert_called_once_with("smtp.example.com", 587, timeout=10, local_hostname=ANY)
mock_server.ehlo.assert_called()
mock_server.starttls.assert_called_once()
assert mock_server.ehlo.call_count == 2 # Before and after STARTTLS
@ -213,7 +213,7 @@ class TestSMTPIntegration:
client.send(mail_data)
# Assert
mock_smtp.assert_called_once_with("smtp.example.com", 25, timeout=10)
mock_smtp.assert_called_once_with("smtp.example.com", 25, timeout=10, local_hostname=ANY)
mock_server.login.assert_called_once()
mock_server.sendmail.assert_called_once()
mock_server.quit.assert_called_once()

4670
api/uv.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -968,6 +968,8 @@ SMTP_USERNAME=
SMTP_PASSWORD=
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
# Optional: override the local hostname used for SMTP HELO/EHLO
SMTP_LOCAL_HOSTNAME=
# Sendgrid configuration
SENDGRID_API_KEY=
@ -1037,18 +1039,26 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# Options:
# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default)
# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
# - extensions.logstore.repositories.logstore_workflow_execution_repository.LogstoreWorkflowExecutionRepository
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
# Core workflow node execution repository implementation
# Options:
# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default)
# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
# - extensions.logstore.repositories.logstore_workflow_node_execution_repository.LogstoreWorkflowNodeExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
# API workflow run repository implementation
# Options:
# - repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository (default)
# - extensions.logstore.repositories.logstore_api_workflow_run_repository.LogstoreAPIWorkflowRunRepository
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# API workflow node execution repository implementation
# Options:
# - repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository (default)
# - extensions.logstore.repositories.logstore_api_workflow_node_execution_repository.LogstoreAPIWorkflowNodeExecutionRepository
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
# Workflow log cleanup configuration

View File

@ -21,7 +21,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.3
image: langgenius/dify-api:1.11.4
restart: always
environment:
# Use the shared environment variables.
@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.3
image: langgenius/dify-api:1.11.4
restart: always
environment:
# Use the shared environment variables.
@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.3
image: langgenius/dify-api:1.11.4
restart: always
environment:
# Use the shared environment variables.
@ -132,7 +132,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.3
image: langgenius/dify-web:1.11.4
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -425,6 +425,7 @@ x-shared-env: &shared-api-worker-env
SMTP_PASSWORD: ${SMTP_PASSWORD:-}
SMTP_USE_TLS: ${SMTP_USE_TLS:-true}
SMTP_OPPORTUNISTIC_TLS: ${SMTP_OPPORTUNISTIC_TLS:-false}
SMTP_LOCAL_HOSTNAME: ${SMTP_LOCAL_HOSTNAME:-}
SENDGRID_API_KEY: ${SENDGRID_API_KEY:-}
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000}
INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72}
@ -704,7 +705,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.3
image: langgenius/dify-api:1.11.4
restart: always
environment:
# Use the shared environment variables.
@ -746,7 +747,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.3
image: langgenius/dify-api:1.11.4
restart: always
environment:
# Use the shared environment variables.
@ -785,7 +786,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.3
image: langgenius/dify-api:1.11.4
restart: always
environment:
# Use the shared environment variables.
@ -815,7 +816,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.3
image: langgenius/dify-web:1.11.4
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@ -1 +1 @@
22.21.1
24

View File

@ -1,5 +1,5 @@
# base image
FROM node:22.21.1-alpine3.23 AS base
FROM node:24-alpine AS base
LABEL maintainer="takatost@gmail.com"
# if you located in China, you can use aliyun mirror to speed up

View File

@ -8,8 +8,8 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next
Before starting the web frontend service, please make sure the following environment is ready.
- [Node.js](https://nodejs.org) >= v22.11.x
- [pnpm](https://pnpm.io) v10.x
- [Node.js](https://nodejs.org)
- [pnpm](https://pnpm.io)
> [!TIP]
> It is recommended to install and enable Corepack to manage package manager versions automatically:

View File

@ -66,7 +66,9 @@ export default function CheckCode() {
setIsLoading(true)
const ret = await webAppEmailLoginWithCode({ email, code: encryptVerificationCode(code), token })
if (ret.result === 'success') {
setWebAppAccessToken(ret.data.access_token)
if (ret?.data?.access_token) {
setWebAppAccessToken(ret.data.access_token)
}
const { access_token } = await fetchAccessToken({
appCode: appCode!,
userId: embeddedUserId || undefined,

View File

@ -82,7 +82,9 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
body: loginData,
})
if (res.result === 'success') {
setWebAppAccessToken(res.data.access_token)
if (res?.data?.access_token) {
setWebAppAccessToken(res.data.access_token)
}
const { access_token } = await fetchAccessToken({
appCode: appCode!,

View File

@ -183,7 +183,6 @@ const AgentTools: FC = () => {
onSelect={handleSelectTool}
onSelectMultiple={handleSelectMultipleTool}
selectedTools={tools as unknown as ToolValue[]}
canChooseMCPTool
/>
</>
)}

View File

@ -12,7 +12,6 @@ import { useDebounceFn } from 'ahooks'
import dynamic from 'next/dynamic'
import {
useRouter,
useSearchParams,
} from 'next/navigation'
import { parseAsString, useQueryState } from 'nuqs'
import { useCallback, useEffect, useRef, useState } from 'react'
@ -29,7 +28,6 @@ import { CheckModal } from '@/hooks/use-pay'
import { useInfiniteAppList } from '@/service/use-apps'
import { AppModeEnum } from '@/types/app'
import { cn } from '@/utils/classnames'
import { isServer } from '@/utils/client'
import AppCard from './app-card'
import { AppCardSkeleton } from './app-card-skeleton'
import Empty from './empty'
@ -59,7 +57,6 @@ const List = () => {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
const router = useRouter()
const searchParams = useSearchParams()
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext()
const showTagManagementModal = useTagStore(s => s.showTagManagementModal)
const [activeTab, setActiveTab] = useQueryState(
@ -67,33 +64,6 @@ const List = () => {
parseAsString.withDefault('all').withOptions({ history: 'push' }),
)
// valid tabs for apps list; anything else should fallback to 'all'
// 1) Normalize legacy/incorrect query params like ?mode=discover -> ?category=all
useEffect(() => {
// avoid running on server
if (isServer)
return
const mode = searchParams.get('mode')
if (!mode)
return
const url = new URL(window.location.href)
url.searchParams.delete('mode')
if (validTabs.has(mode)) {
// migrate to category key
url.searchParams.set('category', mode)
}
else {
url.searchParams.set('category', 'all')
}
router.replace(url.pathname + url.search)
}, [router, searchParams])
// 2) If category has an invalid value (e.g., 'discover'), reset to 'all'
useEffect(() => {
if (!validTabs.has(activeTab))
setActiveTab('all')
}, [activeTab, setActiveTab])
const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState()
const [isCreatedByMe, setIsCreatedByMe] = useState(queryIsCreatedByMe)
const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)

View File

@ -27,7 +27,9 @@ vi.mock('@/service/billing', () => ({
vi.mock('@/service/client', () => ({
consoleClient: {
billingUrl: vi.fn(),
billing: {
invoices: vi.fn(),
},
},
}))
@ -43,7 +45,7 @@ vi.mock('../../assets', () => ({
const mockUseAppContext = useAppContext as Mock
const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock
const mockBillingUrl = consoleClient.billingUrl as Mock
const mockBillingInvoices = consoleClient.billing.invoices as Mock
const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock
const mockToastNotify = Toast.notify as Mock
@ -75,7 +77,7 @@ beforeEach(() => {
vi.clearAllMocks()
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open()))
mockBillingUrl.mockResolvedValue({ url: 'https://billing.example' })
mockBillingInvoices.mockResolvedValue({ url: 'https://billing.example' })
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://subscription.example' })
assignedHref = ''
})
@ -149,7 +151,7 @@ describe('CloudPlanItem', () => {
type: 'error',
message: 'billing.buyPermissionDeniedTip',
}))
expect(mockBillingUrl).not.toHaveBeenCalled()
expect(mockBillingInvoices).not.toHaveBeenCalled()
})
it('should open billing portal when upgrading current paid plan', async () => {
@ -168,7 +170,7 @@ describe('CloudPlanItem', () => {
fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.currentPlan' }))
await waitFor(() => {
expect(mockBillingUrl).toHaveBeenCalledTimes(1)
expect(mockBillingInvoices).toHaveBeenCalledTimes(1)
})
expect(openWindow).toHaveBeenCalledTimes(1)
})

View File

@ -77,7 +77,7 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
try {
if (isCurrentPaidPlan) {
await openAsyncWindow(async () => {
const res = await consoleClient.billingUrl()
const res = await consoleClient.billing.invoices()
if (res.url)
return res.url
throw new Error('Failed to open billing page')

View File

@ -362,6 +362,18 @@ describe('PreviewDocumentPicker', () => {
expect(screen.getByText('--')).toBeInTheDocument()
})
it('should render when value prop is omitted (optional)', () => {
const files = createMockDocumentList(2)
const onChange = vi.fn()
// Do not pass `value` at all to verify optional behavior
render(<PreviewDocumentPicker files={files} onChange={onChange} />)
// Renders placeholder for missing name
expect(screen.getByText('--')).toBeInTheDocument()
// Portal wrapper renders
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
})
it('should handle empty files array', () => {
renderComponent({ files: [] })

View File

@ -18,7 +18,7 @@ import DocumentList from './document-list'
type Props = {
className?: string
value: DocumentItem
value?: DocumentItem
files: DocumentItem[]
onChange: (value: DocumentItem) => void
}
@ -30,7 +30,8 @@ const PreviewDocumentPicker: FC<Props> = ({
onChange,
}) => {
const { t } = useTranslation()
const { name, extension } = value
const name = value?.name || ''
const extension = value?.extension
const [open, {
set: setOpen,

View File

@ -0,0 +1,201 @@
'use client'
import type { FC } from 'react'
import type { Item } from '@/app/components/base/select'
import type { BuiltInMetadataItem, MetadataItemWithValueLength } from '@/app/components/datasets/metadata/types'
import type { SortType } from '@/service/datasets'
import { PlusIcon } from '@heroicons/react/24/solid'
import { RiDraftLine, RiExternalLinkLine } from '@remixicon/react'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Chip from '@/app/components/base/chip'
import Input from '@/app/components/base/input'
import Sort from '@/app/components/base/sort'
import AutoDisabledDocument from '@/app/components/datasets/common/document-status-with-action/auto-disabled-document'
import IndexFailed from '@/app/components/datasets/common/document-status-with-action/index-failed'
import StatusWithAction from '@/app/components/datasets/common/document-status-with-action/status-with-action'
import DatasetMetadataDrawer from '@/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer'
import { useDocLink } from '@/context/i18n'
import { DataSourceType } from '@/models/datasets'
import { useIndexStatus } from '../status-item/hooks'
type DocumentsHeaderProps = {
// Dataset info
datasetId: string
dataSourceType?: DataSourceType
embeddingAvailable: boolean
isFreePlan: boolean
// Filter & sort
statusFilterValue: string
sortValue: SortType
inputValue: string
onStatusFilterChange: (value: string) => void
onStatusFilterClear: () => void
onSortChange: (value: string) => void
onInputChange: (value: string) => void
// Metadata modal
isShowEditMetadataModal: boolean
showEditMetadataModal: () => void
hideEditMetadataModal: () => void
datasetMetaData?: MetadataItemWithValueLength[]
builtInMetaData?: BuiltInMetadataItem[]
builtInEnabled: boolean
onAddMetaData: (payload: BuiltInMetadataItem) => Promise<void>
onRenameMetaData: (payload: MetadataItemWithValueLength) => Promise<void>
onDeleteMetaData: (metaDataId: string) => Promise<void>
onBuiltInEnabledChange: (enabled: boolean) => void
// Actions
onAddDocument: () => void
}
const DocumentsHeader: FC<DocumentsHeaderProps> = ({
datasetId,
dataSourceType,
embeddingAvailable,
isFreePlan,
statusFilterValue,
sortValue,
inputValue,
onStatusFilterChange,
onStatusFilterClear,
onSortChange,
onInputChange,
isShowEditMetadataModal,
showEditMetadataModal,
hideEditMetadataModal,
datasetMetaData,
builtInMetaData,
builtInEnabled,
onAddMetaData,
onRenameMetaData,
onDeleteMetaData,
onBuiltInEnabledChange,
onAddDocument,
}) => {
const { t } = useTranslation()
const docLink = useDocLink()
const DOC_INDEX_STATUS_MAP = useIndexStatus()
const isDataSourceNotion = dataSourceType === DataSourceType.NOTION
const isDataSourceWeb = dataSourceType === DataSourceType.WEB
const statusFilterItems: Item[] = useMemo(() => [
{ value: 'all', name: t('list.index.all', { ns: 'datasetDocuments' }) as string },
{ value: 'queuing', name: DOC_INDEX_STATUS_MAP.queuing.text },
{ value: 'indexing', name: DOC_INDEX_STATUS_MAP.indexing.text },
{ value: 'paused', name: DOC_INDEX_STATUS_MAP.paused.text },
{ value: 'error', name: DOC_INDEX_STATUS_MAP.error.text },
{ value: 'available', name: DOC_INDEX_STATUS_MAP.available.text },
{ value: 'enabled', name: DOC_INDEX_STATUS_MAP.enabled.text },
{ value: 'disabled', name: DOC_INDEX_STATUS_MAP.disabled.text },
{ value: 'archived', name: DOC_INDEX_STATUS_MAP.archived.text },
], [DOC_INDEX_STATUS_MAP, t])
const sortItems: Item[] = useMemo(() => [
{ value: 'created_at', name: t('list.sort.uploadTime', { ns: 'datasetDocuments' }) as string },
{ value: 'hit_count', name: t('list.sort.hitCount', { ns: 'datasetDocuments' }) as string },
], [t])
// Determine add button text based on data source type
const addButtonText = useMemo(() => {
if (isDataSourceNotion)
return t('list.addPages', { ns: 'datasetDocuments' })
if (isDataSourceWeb)
return t('list.addUrl', { ns: 'datasetDocuments' })
return t('list.addFile', { ns: 'datasetDocuments' })
}, [isDataSourceNotion, isDataSourceWeb, t])
return (
<>
{/* Title section */}
<div className="flex flex-col justify-center gap-1 px-6 pt-4">
<h1 className="text-base font-semibold text-text-primary">
{t('list.title', { ns: 'datasetDocuments' })}
</h1>
<div className="flex items-center space-x-0.5 text-sm font-normal text-text-tertiary">
<span>{t('list.desc', { ns: 'datasetDocuments' })}</span>
<a
className="flex items-center text-text-accent"
target="_blank"
rel="noopener noreferrer"
href={docLink('/guides/knowledge-base/integrate-knowledge-within-application')}
>
<span>{t('list.learnMore', { ns: 'datasetDocuments' })}</span>
<RiExternalLinkLine className="h-3 w-3" />
</a>
</div>
</div>
{/* Toolbar section */}
<div className="flex flex-wrap items-center justify-between px-6 pt-4">
{/* Left: Filters */}
<div className="flex items-center gap-2">
<Chip
className="w-[160px]"
showLeftIcon={false}
value={statusFilterValue}
items={statusFilterItems}
onSelect={item => onStatusFilterChange(item?.value ? String(item.value) : '')}
onClear={onStatusFilterClear}
/>
<Input
showLeftIcon
showClearIcon
wrapperClassName="!w-[200px]"
value={inputValue}
onChange={e => onInputChange(e.target.value)}
onClear={() => onInputChange('')}
/>
<div className="h-3.5 w-px bg-divider-regular"></div>
<Sort
order={sortValue.startsWith('-') ? '-' : ''}
value={sortValue.replace('-', '')}
items={sortItems}
onSelect={value => onSortChange(String(value))}
/>
</div>
{/* Right: Actions */}
<div className="flex !h-8 items-center justify-center gap-2">
{!isFreePlan && <AutoDisabledDocument datasetId={datasetId} />}
<IndexFailed datasetId={datasetId} />
{!embeddingAvailable && (
<StatusWithAction
type="warning"
description={t('embeddingModelNotAvailable', { ns: 'dataset' })}
/>
)}
{embeddingAvailable && (
<Button variant="secondary" className="shrink-0" onClick={showEditMetadataModal}>
<RiDraftLine className="mr-1 size-4" />
{t('metadata.metadata', { ns: 'dataset' })}
</Button>
)}
{isShowEditMetadataModal && (
<DatasetMetadataDrawer
userMetadata={datasetMetaData ?? []}
onClose={hideEditMetadataModal}
onAdd={onAddMetaData}
onRename={onRenameMetaData}
onRemove={onDeleteMetaData}
builtInMetadata={builtInMetaData ?? []}
isBuiltInEnabled={builtInEnabled}
onIsBuiltInEnabledChange={onBuiltInEnabledChange}
/>
)}
{embeddingAvailable && (
<Button variant="primary" onClick={onAddDocument} className="shrink-0">
<PlusIcon className="mr-2 h-4 w-4 stroke-current" />
{addButtonText}
</Button>
)}
</div>
</div>
</>
)
}
export default DocumentsHeader
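
A minimal wiring sketch for the extracted header: filter, sort, and search state live in the parent page, and the metadata handlers are stubbed as no-ops purely for illustration. The import path and the initial `SortType` value are assumptions.

```tsx
'use client'
import type { FC } from 'react'
import { useState } from 'react'
import type { SortType } from '@/service/datasets'
// Import path assumed; adjust to the actual file location.
import DocumentsHeader from './documents-header'

const DocumentsPageExample: FC<{ datasetId: string }> = ({ datasetId }) => {
  const [statusFilterValue, setStatusFilterValue] = useState('all')
  // Initial sort value assumed; SortType is defined in @/service/datasets.
  const [sortValue, setSortValue] = useState<SortType>('-created_at' as SortType)
  const [inputValue, setInputValue] = useState('')
  const [isShowEditMetadataModal, setIsShowEditMetadataModal] = useState(false)

  return (
    <DocumentsHeader
      datasetId={datasetId}
      embeddingAvailable
      isFreePlan={false}
      statusFilterValue={statusFilterValue}
      sortValue={sortValue}
      inputValue={inputValue}
      onStatusFilterChange={setStatusFilterValue}
      onStatusFilterClear={() => setStatusFilterValue('all')}
      onSortChange={value => setSortValue(value as SortType)}
      onInputChange={setInputValue}
      isShowEditMetadataModal={isShowEditMetadataModal}
      showEditMetadataModal={() => setIsShowEditMetadataModal(true)}
      hideEditMetadataModal={() => setIsShowEditMetadataModal(false)}
      builtInEnabled
      onAddMetaData={async () => {}}
      onRenameMetaData={async () => {}}
      onDeleteMetaData={async () => {}}
      onBuiltInEnabledChange={() => {}}
      onAddDocument={() => {}}
    />
  )
}

export default DocumentsPageExample
```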

View File

@ -0,0 +1,41 @@
'use client'
import type { FC } from 'react'
import { PlusIcon } from '@heroicons/react/24/solid'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import s from '../style.module.css'
import { FolderPlusIcon, NotionIcon, ThreeDotsIcon } from './icons'
type EmptyElementProps = {
canAdd: boolean
onClick: () => void
type?: 'upload' | 'sync'
}
const EmptyElement: FC<EmptyElementProps> = ({ canAdd = true, onClick, type = 'upload' }) => {
const { t } = useTranslation()
return (
<div className={s.emptyWrapper}>
<div className={s.emptyElement}>
<div className={s.emptySymbolIconWrapper}>
{type === 'upload' ? <FolderPlusIcon /> : <NotionIcon />}
</div>
<span className={s.emptyTitle}>
{t('list.empty.title', { ns: 'datasetDocuments' })}
<ThreeDotsIcon className="relative -left-1.5 -top-3 inline" />
</span>
<div className={s.emptyTip}>
{t(`list.empty.${type}.tip`, { ns: 'datasetDocuments' })}
</div>
{type === 'upload' && canAdd && (
<Button onClick={onClick} className={s.addFileBtn} variant="secondary-accent">
<PlusIcon className={s.plusIcon} />
{t('list.addFile', { ns: 'datasetDocuments' })}
</Button>
)}
</div>
</div>
)
}
export default EmptyElement
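
A quick usage sketch for the extracted empty state (import path assumed): upload-type datasets show the add-file button when `canAdd` is true, while sync-type (e.g. Notion) datasets render only the tip.

```tsx
// Import path assumed; adjust to the actual component location.
import EmptyElement from './empty-element'

// Upload dataset with permission to add files: shows the add-file button.
const UploadEmptyState = () => (
  <EmptyElement canAdd type="upload" onClick={() => { /* open file picker */ }} />
)

// Synced dataset (e.g. Notion): only the tip text, no button.
const SyncEmptyState = () => (
  <EmptyElement canAdd={false} type="sync" onClick={() => {}} />
)

export { SyncEmptyState, UploadEmptyState }
```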

View File

@ -0,0 +1,34 @@
import type * as React from 'react'
export const FolderPlusIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M10.8332 5.83333L9.90355 3.9741C9.63601 3.439 9.50222 3.17144 9.30265 2.97597C9.12615 2.80311 8.91344 2.67164 8.6799 2.59109C8.41581 2.5 8.11668 2.5 7.51841 2.5H4.33317C3.39975 2.5 2.93304 2.5 2.57652 2.68166C2.26292 2.84144 2.00795 3.09641 1.84816 3.41002C1.6665 3.76654 1.6665 4.23325 1.6665 5.16667V5.83333M1.6665 5.83333H14.3332C15.7333 5.83333 16.4334 5.83333 16.9681 6.10582C17.4386 6.3455 17.821 6.72795 18.0607 7.19836C18.3332 7.73314 18.3332 8.4332 18.3332 9.83333V13.5C18.3332 14.9001 18.3332 15.6002 18.0607 16.135C17.821 16.6054 17.4386 16.9878 16.9681 17.2275C16.4334 17.5 15.7333 17.5 14.3332 17.5H5.6665C4.26637 17.5 3.56631 17.5 3.03153 17.2275C2.56112 16.9878 2.17867 16.6054 1.93899 16.135C1.6665 15.6002 1.6665 14.9001 1.6665 13.5V5.83333ZM9.99984 14.1667V9.16667M7.49984 11.6667H12.4998" stroke="#667085" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
}
export const ThreeDotsIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M5 6.5V5M8.93934 7.56066L10 6.5M10.0103 11.5H11.5103" stroke="#374151" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
}
export const NotionIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<g clipPath="url(#clip0_2164_11263)">
<path fillRule="evenodd" clipRule="evenodd" d="M3.5725 18.2611L1.4229 15.5832C0.905706 14.9389 0.625 14.1466 0.625 13.3312V3.63437C0.625 2.4129 1.60224 1.39936 2.86295 1.31328L12.8326 0.632614C13.5569 0.583164 14.2768 0.775682 14.8717 1.17794L18.3745 3.5462C19.0015 3.97012 19.375 4.66312 19.375 5.40266V16.427C19.375 17.6223 18.4141 18.6121 17.1798 18.688L6.11458 19.3692C5.12958 19.4298 4.17749 19.0148 3.5725 18.2611Z" fill="white" />
<path d="M7.03006 8.48669V8.35974C7.03006 8.03794 7.28779 7.77104 7.61997 7.74886L10.0396 7.58733L13.3857 12.5147V8.19009L12.5244 8.07528V8.01498C12.5244 7.68939 12.788 7.42074 13.1244 7.4035L15.326 7.29073V7.60755C15.326 7.75628 15.2154 7.88349 15.0638 7.90913L14.534 7.99874V15.0023L13.8691 15.231C13.3136 15.422 12.6952 15.2175 12.3772 14.7377L9.12879 9.83574V14.5144L10.1287 14.7057L10.1147 14.7985C10.0711 15.089 9.82028 15.3087 9.51687 15.3222L7.03006 15.4329C6.99718 15.1205 7.23132 14.841 7.55431 14.807L7.88143 14.7727V8.53453L7.03006 8.48669Z" fill="black" />
<path fillRule="evenodd" clipRule="evenodd" d="M12.9218 1.85424L2.95217 2.53491C2.35499 2.57568 1.89209 3.05578 1.89209 3.63437V13.3312C1.89209 13.8748 2.07923 14.403 2.42402 14.8325L4.57362 17.5104C4.92117 17.9434 5.46812 18.1818 6.03397 18.147L17.0991 17.4658C17.6663 17.4309 18.1078 16.9762 18.1078 16.427V5.40266C18.1078 5.06287 17.9362 4.74447 17.6481 4.54969L14.1453 2.18143C13.7883 1.94008 13.3564 1.82457 12.9218 1.85424ZM3.44654 3.78562C3.30788 3.68296 3.37387 3.46909 3.54806 3.4566L12.9889 2.77944C13.2897 2.75787 13.5886 2.8407 13.8318 3.01305L15.7261 4.35508C15.798 4.40603 15.7642 4.51602 15.6752 4.52086L5.67742 5.0646C5.37485 5.08106 5.0762 4.99217 4.83563 4.81406L3.44654 3.78562ZM5.20848 6.76919C5.20848 6.4444 5.47088 6.1761 5.80642 6.15783L16.3769 5.58216C16.7039 5.56435 16.9792 5.81583 16.9792 6.13239V15.6783C16.9792 16.0025 16.7177 16.2705 16.3829 16.2896L5.8793 16.8872C5.51537 16.9079 5.20848 16.6283 5.20848 16.2759V6.76919Z" fill="black" />
</g>
<defs>
<clipPath id="clip0_2164_11263">
<rect width="20" height="20" fill="white" />
</clipPath>
</defs>
</svg>
)
}

View File

@ -16,13 +16,16 @@ import * as React from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Checkbox from '@/app/components/base/checkbox'
import FileTypeIcon from '@/app/components/base/file-uploader/file-type-icon'
import NotionIcon from '@/app/components/base/notion-icon'
import Pagination from '@/app/components/base/pagination'
import Toast from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import ChunkingModeLabel from '@/app/components/datasets/common/chunking-mode-label'
import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter'
import { extensionToFileType } from '@/app/components/datasets/hit-testing/utils/extension-to-file-type'
import EditMetadataBatchModal from '@/app/components/datasets/metadata/edit-metadata-batch/modal'
import useBatchEditDocumentMetadata from '@/app/components/datasets/metadata/hooks/use-batch-edit-document-metadata'
import { useDatasetDetailContextWithSelector as useDatasetDetailContext } from '@/context/dataset-detail'
import useTimestamp from '@/hooks/use-timestamp'
import { ChunkingMode, DataSourceType, DocumentActionType } from '@/models/datasets'
@ -31,14 +34,11 @@ import { useDocumentArchive, useDocumentBatchRetryIndex, useDocumentDelete, useD
import { asyncRunSafe } from '@/utils'
import { cn } from '@/utils/classnames'
import { formatNumber } from '@/utils/format'
import FileTypeIcon from '../../base/file-uploader/file-type-icon'
import ChunkingModeLabel from '../common/chunking-mode-label'
import useBatchEditDocumentMetadata from '../metadata/hooks/use-batch-edit-document-metadata'
import BatchAction from './detail/completed/common/batch-action'
import BatchAction from '../detail/completed/common/batch-action'
import StatusItem from '../status-item'
import s from '../style.module.css'
import Operations from './operations'
import RenameModal from './rename-modal'
import StatusItem from './status-item'
import s from './style.module.css'
export const renderTdValue = (value: string | number | null, isEmptyStyle = false) => {
return (

View File

@ -1,4 +1,4 @@
import type { OperationName } from './types'
import type { OperationName } from '../types'
import type { CommonResponse } from '@/models/common'
import {
RiArchive2Line,
@ -17,6 +17,12 @@ import * as React from 'react'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import Confirm from '@/app/components/base/confirm'
import Divider from '@/app/components/base/divider'
import CustomPopover from '@/app/components/base/popover'
import Switch from '@/app/components/base/switch'
import { ToastContext } from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import { DataSourceType, DocumentActionType } from '@/models/datasets'
import {
useDocumentArchive,
@ -31,14 +37,8 @@ import {
} from '@/service/knowledge/use-document'
import { asyncRunSafe } from '@/utils'
import { cn } from '@/utils/classnames'
import Confirm from '../../base/confirm'
import Divider from '../../base/divider'
import CustomPopover from '../../base/popover'
import Switch from '../../base/switch'
import { ToastContext } from '../../base/toast'
import Tooltip from '../../base/tooltip'
import s from '../style.module.css'
import RenameModal from './rename-modal'
import s from './style.module.css'
type OperationsProps = {
embeddingAvailable: boolean

View File

@ -7,8 +7,8 @@ import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import Modal from '@/app/components/base/modal'
import Toast from '@/app/components/base/toast'
import { renameDocumentName } from '@/service/datasets'
import Toast from '../../base/toast'
type Props = {
datasetId: string

View File

@ -0,0 +1,5 @@
export { useAddDocumentsSteps } from './use-add-documents-steps'
export { useDatasourceActions } from './use-datasource-actions'
export { useDatasourceOptions } from './use-datasource-options'
export { useLocalFile, useOnlineDocument, useOnlineDrive, useWebsiteCrawl } from './use-datasource-store'
export { useDatasourceUIState } from './use-datasource-ui-state'

View File

@ -0,0 +1,41 @@
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { AddDocumentsStep } from '../types'
/**
* Hook for managing add documents wizard steps
*/
export const useAddDocumentsSteps = () => {
const { t } = useTranslation()
const [currentStep, setCurrentStep] = useState(1)
const handleNextStep = useCallback(() => {
setCurrentStep(preStep => preStep + 1)
}, [])
const handleBackStep = useCallback(() => {
setCurrentStep(preStep => preStep - 1)
}, [])
const steps = [
{
label: t('addDocuments.steps.chooseDatasource', { ns: 'datasetPipeline' }),
value: AddDocumentsStep.dataSource,
},
{
label: t('addDocuments.steps.processDocuments', { ns: 'datasetPipeline' }),
value: AddDocumentsStep.processDocuments,
},
{
label: t('addDocuments.steps.processingDocuments', { ns: 'datasetPipeline' }),
value: AddDocumentsStep.processingDocuments,
},
]
return {
steps,
currentStep,
handleNextStep,
handleBackStep,
}
}
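
Because the step state is plain local state, the hook can be exercised in isolation. A test sketch using @testing-library/react's `renderHook` (the test file and setup are assumptions; it presumes react-i18next resolves without an explicit provider, since the hook only uses `t` for step labels):

```typescript
import { act, renderHook } from '@testing-library/react'
import { describe, expect, it } from 'vitest'
import { useAddDocumentsSteps } from './use-add-documents-steps'

describe('useAddDocumentsSteps', () => {
  it('moves forward and back through the wizard steps', () => {
    const { result } = renderHook(() => useAddDocumentsSteps())

    expect(result.current.currentStep).toBe(1)
    expect(result.current.steps).toHaveLength(3)

    act(() => { result.current.handleNextStep() })
    expect(result.current.currentStep).toBe(2)

    act(() => { result.current.handleBackStep() })
    expect(result.current.currentStep).toBe(1)
  })
})
```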

View File

@ -0,0 +1,321 @@
import type { StoreApi } from 'zustand'
import type { DataSourceShape } from '@/app/components/datasets/documents/create-from-pipeline/data-source/store'
import type { Datasource } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { DataSourceNotionPageMap, NotionPage } from '@/models/common'
import type { CrawlResultItem, DocumentItem, CustomFile as File, FileIndexingEstimateResponse } from '@/models/datasets'
import type {
OnlineDriveFile,
PublishedPipelineRunPreviewResponse,
PublishedPipelineRunResponse,
} from '@/models/pipeline'
import { useCallback, useRef } from 'react'
import { trackEvent } from '@/app/components/base/amplitude'
import { DatasourceType } from '@/models/pipeline'
import { useRunPublishedPipeline } from '@/service/use-pipeline'
import {
buildLocalFileDatasourceInfo,
buildOnlineDocumentDatasourceInfo,
buildOnlineDriveDatasourceInfo,
buildWebsiteCrawlDatasourceInfo,
} from '../utils/datasource-info-builder'
type DatasourceActionsParams = {
datasource: Datasource | undefined
datasourceType: string | undefined
pipelineId: string | undefined
dataSourceStore: StoreApi<DataSourceShape>
setEstimateData: (data: FileIndexingEstimateResponse | undefined) => void
setBatchId: (id: string) => void
setDocuments: (docs: PublishedPipelineRunResponse['documents']) => void
handleNextStep: () => void
PagesMapAndSelectedPagesId: DataSourceNotionPageMap
currentWorkspacePages: { page_id: string }[] | undefined
clearOnlineDocumentData: () => void
clearWebsiteCrawlData: () => void
clearOnlineDriveData: () => void
setDatasource: (ds: Datasource) => void
}
/**
* Hook for datasource-related actions (preview, process, etc.)
*/
export const useDatasourceActions = ({
datasource,
datasourceType,
pipelineId,
dataSourceStore,
setEstimateData,
setBatchId,
setDocuments,
handleNextStep,
PagesMapAndSelectedPagesId,
currentWorkspacePages,
clearOnlineDocumentData,
clearWebsiteCrawlData,
clearOnlineDriveData,
setDatasource,
}: DatasourceActionsParams) => {
const isPreview = useRef(false)
const formRef = useRef<{ submit: () => void } | null>(null)
const { mutateAsync: runPublishedPipeline, isIdle, isPending } = useRunPublishedPipeline()
// Build datasource info for preview (single item)
const buildPreviewDatasourceInfo = useCallback(() => {
const {
previewLocalFileRef,
previewOnlineDocumentRef,
previewWebsitePageRef,
previewOnlineDriveFileRef,
currentCredentialId,
bucket,
} = dataSourceStore.getState()
const datasourceInfoList: Record<string, unknown>[] = []
if (datasourceType === DatasourceType.localFile && previewLocalFileRef.current) {
datasourceInfoList.push(buildLocalFileDatasourceInfo(
previewLocalFileRef.current as File,
currentCredentialId,
))
}
if (datasourceType === DatasourceType.onlineDocument && previewOnlineDocumentRef.current) {
datasourceInfoList.push(buildOnlineDocumentDatasourceInfo(
previewOnlineDocumentRef.current,
currentCredentialId,
))
}
if (datasourceType === DatasourceType.websiteCrawl && previewWebsitePageRef.current) {
datasourceInfoList.push(buildWebsiteCrawlDatasourceInfo(
previewWebsitePageRef.current,
currentCredentialId,
))
}
if (datasourceType === DatasourceType.onlineDrive && previewOnlineDriveFileRef.current) {
datasourceInfoList.push(buildOnlineDriveDatasourceInfo(
previewOnlineDriveFileRef.current,
bucket,
currentCredentialId,
))
}
return datasourceInfoList
}, [dataSourceStore, datasourceType])
// Build datasource info for processing (all items)
const buildProcessDatasourceInfo = useCallback(() => {
const {
currentCredentialId,
localFileList,
onlineDocuments,
websitePages,
bucket,
selectedFileIds,
onlineDriveFileList,
} = dataSourceStore.getState()
const datasourceInfoList: Record<string, unknown>[] = []
if (datasourceType === DatasourceType.localFile) {
localFileList.forEach((file) => {
datasourceInfoList.push(buildLocalFileDatasourceInfo(file.file, currentCredentialId))
})
}
if (datasourceType === DatasourceType.onlineDocument) {
onlineDocuments.forEach((page) => {
datasourceInfoList.push(buildOnlineDocumentDatasourceInfo(page, currentCredentialId))
})
}
if (datasourceType === DatasourceType.websiteCrawl) {
websitePages.forEach((page) => {
datasourceInfoList.push(buildWebsiteCrawlDatasourceInfo(page, currentCredentialId))
})
}
if (datasourceType === DatasourceType.onlineDrive) {
selectedFileIds.forEach((id) => {
const file = onlineDriveFileList.find(f => f.id === id)
if (file)
datasourceInfoList.push(buildOnlineDriveDatasourceInfo(file, bucket, currentCredentialId))
})
}
return datasourceInfoList
}, [dataSourceStore, datasourceType])
// Handle chunk preview
const handlePreviewChunks = useCallback(async (data: Record<string, unknown>) => {
if (!datasource || !pipelineId)
return
const datasourceInfoList = buildPreviewDatasourceInfo()
await runPublishedPipeline({
pipeline_id: pipelineId,
inputs: data,
start_node_id: datasource.nodeId,
datasource_type: datasourceType as DatasourceType,
datasource_info_list: datasourceInfoList,
is_preview: true,
}, {
onSuccess: (res) => {
setEstimateData((res as PublishedPipelineRunPreviewResponse).data.outputs)
},
})
}, [datasource, pipelineId, datasourceType, buildPreviewDatasourceInfo, runPublishedPipeline, setEstimateData])
// Handle document processing
const handleProcess = useCallback(async (data: Record<string, unknown>) => {
if (!datasource || !pipelineId)
return
const datasourceInfoList = buildProcessDatasourceInfo()
await runPublishedPipeline({
pipeline_id: pipelineId,
inputs: data,
start_node_id: datasource.nodeId,
datasource_type: datasourceType as DatasourceType,
datasource_info_list: datasourceInfoList,
is_preview: false,
}, {
onSuccess: (res) => {
setBatchId((res as PublishedPipelineRunResponse).batch || '')
setDocuments((res as PublishedPipelineRunResponse).documents || [])
handleNextStep()
trackEvent('dataset_document_added', {
data_source_type: datasourceType,
indexing_technique: 'pipeline',
})
},
})
}, [datasource, pipelineId, datasourceType, buildProcessDatasourceInfo, runPublishedPipeline, setBatchId, setDocuments, handleNextStep])
// Form submission handlers
const onClickProcess = useCallback(() => {
isPreview.current = false
formRef.current?.submit()
}, [])
const onClickPreview = useCallback(() => {
isPreview.current = true
formRef.current?.submit()
}, [])
const handleSubmit = useCallback((data: Record<string, unknown>) => {
if (isPreview.current)
handlePreviewChunks(data)
else
handleProcess(data)
}, [handlePreviewChunks, handleProcess])
// Preview change handlers
const handlePreviewFileChange = useCallback((file: DocumentItem) => {
const { previewLocalFileRef } = dataSourceStore.getState()
previewLocalFileRef.current = file
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handlePreviewOnlineDocumentChange = useCallback((page: NotionPage) => {
const { previewOnlineDocumentRef } = dataSourceStore.getState()
previewOnlineDocumentRef.current = page
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handlePreviewWebsiteChange = useCallback((website: CrawlResultItem) => {
const { previewWebsitePageRef } = dataSourceStore.getState()
previewWebsitePageRef.current = website
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handlePreviewOnlineDriveFileChange = useCallback((file: OnlineDriveFile) => {
const { previewOnlineDriveFileRef } = dataSourceStore.getState()
previewOnlineDriveFileRef.current = file
onClickPreview()
}, [dataSourceStore, onClickPreview])
// Select all handler
const handleSelectAll = useCallback(() => {
const {
onlineDocuments,
onlineDriveFileList,
selectedFileIds,
setOnlineDocuments,
setSelectedFileIds,
setSelectedPagesId,
} = dataSourceStore.getState()
if (datasourceType === DatasourceType.onlineDocument) {
const allIds = currentWorkspacePages?.map(page => page.page_id) || []
if (onlineDocuments.length < allIds.length) {
const selectedPages = Array.from(allIds).map(pageId => PagesMapAndSelectedPagesId[pageId])
setOnlineDocuments(selectedPages)
setSelectedPagesId(new Set(allIds))
}
else {
setOnlineDocuments([])
setSelectedPagesId(new Set())
}
}
if (datasourceType === DatasourceType.onlineDrive) {
const allKeys = onlineDriveFileList.filter(item => item.type !== 'bucket').map(file => file.id)
if (selectedFileIds.length < allKeys.length)
setSelectedFileIds(allKeys)
else
setSelectedFileIds([])
}
}, [PagesMapAndSelectedPagesId, currentWorkspacePages, dataSourceStore, datasourceType])
// Clear datasource data based on type
const clearDataSourceData = useCallback((dataSource: Datasource) => {
const providerType = dataSource.nodeData.provider_type
const clearFunctions: Record<string, () => void> = {
[DatasourceType.onlineDocument]: clearOnlineDocumentData,
[DatasourceType.websiteCrawl]: clearWebsiteCrawlData,
[DatasourceType.onlineDrive]: clearOnlineDriveData,
[DatasourceType.localFile]: () => {},
}
clearFunctions[providerType]?.()
}, [clearOnlineDocumentData, clearOnlineDriveData, clearWebsiteCrawlData])
// Switch datasource handler
const handleSwitchDataSource = useCallback((dataSource: Datasource) => {
const {
setCurrentCredentialId,
currentNodeIdRef,
} = dataSourceStore.getState()
clearDataSourceData(dataSource)
setCurrentCredentialId('')
currentNodeIdRef.current = dataSource.nodeId
setDatasource(dataSource)
}, [clearDataSourceData, dataSourceStore, setDatasource])
// Credential change handler
const handleCredentialChange = useCallback((credentialId: string) => {
const { setCurrentCredentialId } = dataSourceStore.getState()
if (datasource)
clearDataSourceData(datasource)
setCurrentCredentialId(credentialId)
}, [clearDataSourceData, dataSourceStore, datasource])
return {
isPreview,
formRef,
isIdle,
isPending,
onClickProcess,
onClickPreview,
handleSubmit,
handlePreviewFileChange,
handlePreviewOnlineDocumentChange,
handlePreviewWebsiteChange,
handlePreviewOnlineDriveFileChange,
handleSelectAll,
handleSwitchDataSource,
handleCredentialChange,
}
}
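
The `build*DatasourceInfo` helpers imported above are not shown in this diff. Based on the inline construction they replace (visible in the removed code further down), the local-file builder plausibly looks like the sketch below; field names and the `TransferMethod` import are taken from that removed code, but the actual file contents may differ.

```typescript
// utils/datasource-info-builder.ts (contents assumed, reconstructed from the removed inline code)
import type { CustomFile as File } from '@/models/datasets'
import { TransferMethod } from '@/types/app'

export const buildLocalFileDatasourceInfo = (file: File, credentialId: string) => {
  const { id, name, type, size, extension, mime_type } = file
  return {
    related_id: id,
    name,
    type,
    size,
    extension,
    mime_type,
    url: '',
    transfer_method: TransferMethod.local_file,
    credential_id: credentialId,
  }
}
```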

View File

@ -0,0 +1,27 @@
import type { DataSourceOption } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
import type { Node } from '@/app/components/workflow/types'
import { useMemo } from 'react'
import { BlockEnum } from '@/app/components/workflow/types'
/**
* Hook for getting datasource options from pipeline nodes
*/
export const useDatasourceOptions = (pipelineNodes: Node<DataSourceNodeType>[]) => {
const datasourceNodes = pipelineNodes.filter(node => node.data.type === BlockEnum.DataSource)
const options = useMemo(() => {
const options: DataSourceOption[] = []
datasourceNodes.forEach((node) => {
const label = node.data.title
options.push({
label,
value: node.id,
data: node.data,
})
})
return options
}, [datasourceNodes])
return options
}
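
A small consumption sketch: the options derived here feed a datasource selector, with `value` carrying the node id and `data` the node config. The wrapper hook below is hypothetical; `pipelineNodes` would typically be the published pipeline's graph nodes cast to `Node<DataSourceNodeType>[]`, as done in create-from-pipeline/index.tsx.

```typescript
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
import type { Node } from '@/app/components/workflow/types'
import { useDatasourceOptions } from './use-datasource-options'

// Sketch: pick a default datasource node from the available options.
const useFirstDatasourceNodeId = (pipelineNodes: Node<DataSourceNodeType>[]) => {
  const options = useDatasourceOptions(pipelineNodes)
  // Each option carries { label, value: nodeId, data: nodeData }.
  return options[0]?.value
}

export default useFirstDatasourceNodeId
```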

View File

@ -1,69 +1,12 @@
import type { DataSourceOption } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
import type { Node } from '@/app/components/workflow/types'
import type { DataSourceNotionPageMap, DataSourceNotionWorkspace } from '@/models/common'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useCallback, useMemo } from 'react'
import { useShallow } from 'zustand/react/shallow'
import { BlockEnum } from '@/app/components/workflow/types'
import { CrawlStep } from '@/models/datasets'
import { useDataSourceStore, useDataSourceStoreWithSelector } from './data-source/store'
import { AddDocumentsStep } from './types'
export const useAddDocumentsSteps = () => {
const { t } = useTranslation()
const [currentStep, setCurrentStep] = useState(1)
const handleNextStep = useCallback(() => {
setCurrentStep(preStep => preStep + 1)
}, [])
const handleBackStep = useCallback(() => {
setCurrentStep(preStep => preStep - 1)
}, [])
const steps = [
{
label: t('addDocuments.steps.chooseDatasource', { ns: 'datasetPipeline' }),
value: AddDocumentsStep.dataSource,
},
{
label: t('addDocuments.steps.processDocuments', { ns: 'datasetPipeline' }),
value: AddDocumentsStep.processDocuments,
},
{
label: t('addDocuments.steps.processingDocuments', { ns: 'datasetPipeline' }),
value: AddDocumentsStep.processingDocuments,
},
]
return {
steps,
currentStep,
handleNextStep,
handleBackStep,
}
}
export const useDatasourceOptions = (pipelineNodes: Node<DataSourceNodeType>[]) => {
const datasourceNodes = pipelineNodes.filter(node => node.data.type === BlockEnum.DataSource)
const options = useMemo(() => {
const options: DataSourceOption[] = []
datasourceNodes.forEach((node) => {
const label = node.data.title
options.push({
label,
value: node.id,
data: node.data,
})
})
return options
}, [datasourceNodes])
return options
}
import { useDataSourceStore, useDataSourceStoreWithSelector } from '../data-source/store'
/**
* Hook for local file datasource store operations
*/
export const useLocalFile = () => {
const {
localFileList,
@ -89,6 +32,9 @@ export const useLocalFile = () => {
}
}
/**
* Hook for online document datasource store operations
*/
export const useOnlineDocument = () => {
const {
documentsData,
@ -147,6 +93,9 @@ export const useOnlineDocument = () => {
}
}
/**
* Hook for website crawl datasource store operations
*/
export const useWebsiteCrawl = () => {
const {
websitePages,
@ -186,6 +135,9 @@ export const useWebsiteCrawl = () => {
}
}
/**
* Hook for online drive datasource store operations
*/
export const useOnlineDrive = () => {
const {
onlineDriveFileList,

View File

@ -0,0 +1,132 @@
import type { Datasource } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { OnlineDriveFile } from '@/models/pipeline'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { DatasourceType } from '@/models/pipeline'
type DatasourceUIStateParams = {
datasource: Datasource | undefined
allFileLoaded: boolean
localFileListLength: number
onlineDocumentsLength: number
websitePagesLength: number
selectedFileIdsLength: number
onlineDriveFileList: OnlineDriveFile[]
isVectorSpaceFull: boolean
enableBilling: boolean
currentWorkspacePagesLength: number
fileUploadConfig: { file_size_limit: number, batch_count_limit: number }
}
/**
* Hook for computing datasource UI state based on datasource type
*/
export const useDatasourceUIState = ({
datasource,
allFileLoaded,
localFileListLength,
onlineDocumentsLength,
websitePagesLength,
selectedFileIdsLength,
onlineDriveFileList,
isVectorSpaceFull,
enableBilling,
currentWorkspacePagesLength,
fileUploadConfig,
}: DatasourceUIStateParams) => {
const { t } = useTranslation()
const datasourceType = datasource?.nodeData.provider_type
const isShowVectorSpaceFull = useMemo(() => {
if (!datasource || !datasourceType)
return false
// Lookup table for vector space full condition check
const vectorSpaceFullConditions: Record<string, boolean> = {
[DatasourceType.localFile]: allFileLoaded,
[DatasourceType.onlineDocument]: onlineDocumentsLength > 0,
[DatasourceType.websiteCrawl]: websitePagesLength > 0,
[DatasourceType.onlineDrive]: onlineDriveFileList.length > 0,
}
const condition = vectorSpaceFullConditions[datasourceType]
return condition && isVectorSpaceFull && enableBilling
}, [datasource, datasourceType, allFileLoaded, onlineDocumentsLength, websitePagesLength, onlineDriveFileList.length, isVectorSpaceFull, enableBilling])
// Lookup table for next button disabled conditions
const nextBtnDisabled = useMemo(() => {
if (!datasource || !datasourceType)
return true
const disabledConditions: Record<string, boolean> = {
[DatasourceType.localFile]: isShowVectorSpaceFull || localFileListLength === 0 || !allFileLoaded,
[DatasourceType.onlineDocument]: isShowVectorSpaceFull || onlineDocumentsLength === 0,
[DatasourceType.websiteCrawl]: isShowVectorSpaceFull || websitePagesLength === 0,
[DatasourceType.onlineDrive]: isShowVectorSpaceFull || selectedFileIdsLength === 0,
}
return disabledConditions[datasourceType] ?? true
}, [datasource, datasourceType, isShowVectorSpaceFull, localFileListLength, allFileLoaded, onlineDocumentsLength, websitePagesLength, selectedFileIdsLength])
// Check if select all should be shown
const showSelect = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return currentWorkspacePagesLength > 0
if (datasourceType === DatasourceType.onlineDrive) {
const nonBucketItems = onlineDriveFileList.filter(item => item.type !== 'bucket')
const isBucketList = onlineDriveFileList.some(file => file.type === 'bucket')
return !isBucketList && nonBucketItems.length > 0
}
return false
}, [currentWorkspacePagesLength, datasourceType, onlineDriveFileList])
// Total selectable options count
const totalOptions = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return currentWorkspacePagesLength
if (datasourceType === DatasourceType.onlineDrive)
return onlineDriveFileList.filter(item => item.type !== 'bucket').length
return undefined
}, [currentWorkspacePagesLength, datasourceType, onlineDriveFileList])
// Selected options count
const selectedOptions = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return onlineDocumentsLength
if (datasourceType === DatasourceType.onlineDrive)
return selectedFileIdsLength
return undefined
}, [datasourceType, onlineDocumentsLength, selectedFileIdsLength])
// Tip message for selection
const tip = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return t('addDocuments.selectOnlineDocumentTip', { ns: 'datasetPipeline', count: 50 })
if (datasourceType === DatasourceType.onlineDrive) {
return t('addDocuments.selectOnlineDriveTip', {
ns: 'datasetPipeline',
count: fileUploadConfig.batch_count_limit,
fileSize: fileUploadConfig.file_size_limit,
})
}
return ''
}, [datasourceType, fileUploadConfig.batch_count_limit, fileUploadConfig.file_size_limit, t])
return {
datasourceType,
isShowVectorSpaceFull,
nextBtnDisabled,
showSelect,
totalOptions,
selectedOptions,
tip,
}
}
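
Since the UI state is derived purely from its parameters, it is straightforward to cover the edge cases in isolation. A test sketch (setup assumed; it presumes react-i18next resolves without an explicit provider, as the hook only uses `t` for the tip text):

```typescript
import { renderHook } from '@testing-library/react'
import { describe, expect, it } from 'vitest'
import { useDatasourceUIState } from './use-datasource-ui-state'

describe('useDatasourceUIState', () => {
  it('disables the next step until a datasource is selected', () => {
    const { result } = renderHook(() => useDatasourceUIState({
      datasource: undefined,
      allFileLoaded: false,
      localFileListLength: 0,
      onlineDocumentsLength: 0,
      websitePagesLength: 0,
      selectedFileIdsLength: 0,
      onlineDriveFileList: [],
      isVectorSpaceFull: false,
      enableBilling: false,
      currentWorkspacePagesLength: 0,
      fileUploadConfig: { file_size_limit: 15, batch_count_limit: 5 },
    }))

    expect(result.current.nextBtnDisabled).toBe(true)
    expect(result.current.isShowVectorSpaceFull).toBe(false)
  })
})
```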

File diff suppressed because it is too large

View File

@ -2,75 +2,71 @@
import type { Datasource } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
import type { Node } from '@/app/components/workflow/types'
import type { NotionPage } from '@/models/common'
import type { CrawlResultItem, DocumentItem, CustomFile as File, FileIndexingEstimateResponse } from '@/models/datasets'
import type {
InitialDocumentDetail,
OnlineDriveFile,
PublishedPipelineRunPreviewResponse,
PublishedPipelineRunResponse,
} from '@/models/pipeline'
import type { FileIndexingEstimateResponse } from '@/models/datasets'
import type { InitialDocumentDetail } from '@/models/pipeline'
import { useBoolean } from 'ahooks'
import { useCallback, useMemo, useRef, useState } from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { trackEvent } from '@/app/components/base/amplitude'
import Divider from '@/app/components/base/divider'
import Loading from '@/app/components/base/loading'
import PlanUpgradeModal from '@/app/components/billing/plan-upgrade-modal'
import VectorSpaceFull from '@/app/components/billing/vector-space-full'
import LocalFile from '@/app/components/datasets/documents/create-from-pipeline/data-source/local-file'
import OnlineDocuments from '@/app/components/datasets/documents/create-from-pipeline/data-source/online-documents'
import OnlineDrive from '@/app/components/datasets/documents/create-from-pipeline/data-source/online-drive'
import WebsiteCrawl from '@/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useProviderContextSelector } from '@/context/provider-context'
import { DatasourceType } from '@/models/pipeline'
import { useFileUploadConfig } from '@/service/use-common'
import { usePublishedPipelineInfo, useRunPublishedPipeline } from '@/service/use-pipeline'
import { TransferMethod } from '@/types/app'
import UpgradeCard from '../../create/step-one/upgrade-card'
import Actions from './actions'
import DataSourceOptions from './data-source-options'
import { usePublishedPipelineInfo } from '@/service/use-pipeline'
import { useDataSourceStore } from './data-source/store'
import DataSourceProvider from './data-source/store/provider'
import { useAddDocumentsSteps, useLocalFile, useOnlineDocument, useOnlineDrive, useWebsiteCrawl } from './hooks'
import {
useAddDocumentsSteps,
useDatasourceActions,
useDatasourceUIState,
useLocalFile,
useOnlineDocument,
useOnlineDrive,
useWebsiteCrawl,
} from './hooks'
import LeftHeader from './left-header'
import ChunkPreview from './preview/chunk-preview'
import FilePreview from './preview/file-preview'
import OnlineDocumentPreview from './preview/online-document-preview'
import WebsitePreview from './preview/web-preview'
import ProcessDocuments from './process-documents'
import Processing from './processing'
import { StepOneContent, StepThreeContent, StepTwoContent } from './steps'
import { StepOnePreview, StepTwoPreview } from './steps/preview-panel'
const CreateFormPipeline = () => {
const { t } = useTranslation()
const plan = useProviderContextSelector(state => state.plan)
const enableBilling = useProviderContextSelector(state => state.enableBilling)
const pipelineId = useDatasetDetailContextWithSelector(s => s.dataset?.pipeline_id)
const dataSourceStore = useDataSourceStore()
// Core state
const [datasource, setDatasource] = useState<Datasource>()
const [estimateData, setEstimateData] = useState<FileIndexingEstimateResponse | undefined>(undefined)
const [batchId, setBatchId] = useState('')
const [documents, setDocuments] = useState<InitialDocumentDetail[]>([])
const dataSourceStore = useDataSourceStore()
const isPreview = useRef(false)
const formRef = useRef<any>(null)
// Data fetching
const { data: pipelineInfo, isFetching: isFetchingPipelineInfo } = usePublishedPipelineInfo(pipelineId || '')
const { data: fileUploadConfigResponse } = useFileUploadConfig()
const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? {
file_size_limit: 15,
batch_count_limit: 5,
}, [fileUploadConfigResponse])
// Steps management
const {
steps,
currentStep,
handleNextStep: doHandleNextStep,
handleBackStep,
} = useAddDocumentsSteps()
// Datasource-specific hooks
const {
localFileList,
allFileLoaded,
currentLocalFile,
hidePreviewLocalFile,
} = useLocalFile()
const {
currentWorkspace,
onlineDocuments,
@ -79,12 +75,14 @@ const CreateFormPipeline = () => {
hidePreviewOnlineDocument,
clearOnlineDocumentData,
} = useOnlineDocument()
const {
websitePages,
currentWebsite,
hideWebsitePreview,
clearWebsiteCrawlData,
} = useWebsiteCrawl()
const {
onlineDriveFileList,
selectedFileIds,
@ -92,43 +90,50 @@ const CreateFormPipeline = () => {
clearOnlineDriveData,
} = useOnlineDrive()
const datasourceType = useMemo(() => datasource?.nodeData.provider_type, [datasource])
// Computed values
const isVectorSpaceFull = plan.usage.vectorSpace >= plan.total.vectorSpace
const isShowVectorSpaceFull = useMemo(() => {
if (!datasource)
return false
if (datasourceType === DatasourceType.localFile)
return allFileLoaded && isVectorSpaceFull && enableBilling
if (datasourceType === DatasourceType.onlineDocument)
return onlineDocuments.length > 0 && isVectorSpaceFull && enableBilling
if (datasourceType === DatasourceType.websiteCrawl)
return websitePages.length > 0 && isVectorSpaceFull && enableBilling
if (datasourceType === DatasourceType.onlineDrive)
return onlineDriveFileList.length > 0 && isVectorSpaceFull && enableBilling
return false
}, [allFileLoaded, datasource, datasourceType, enableBilling, isVectorSpaceFull, onlineDocuments.length, onlineDriveFileList.length, websitePages.length])
const supportBatchUpload = !enableBilling || plan.type !== 'sandbox'
// UI state
const {
datasourceType,
isShowVectorSpaceFull,
nextBtnDisabled,
showSelect,
totalOptions,
selectedOptions,
tip,
} = useDatasourceUIState({
datasource,
allFileLoaded,
localFileListLength: localFileList.length,
onlineDocumentsLength: onlineDocuments.length,
websitePagesLength: websitePages.length,
selectedFileIdsLength: selectedFileIds.length,
onlineDriveFileList,
isVectorSpaceFull,
enableBilling,
currentWorkspacePagesLength: currentWorkspace?.pages.length ?? 0,
fileUploadConfig,
})
// Plan upgrade modal
const [isShowPlanUpgradeModal, {
setTrue: showPlanUpgradeModal,
setFalse: hidePlanUpgradeModal,
}] = useBoolean(false)
// Next step with batch upload check
const handleNextStep = useCallback(() => {
if (!supportBatchUpload) {
let isMultiple = false
if (datasourceType === DatasourceType.localFile && localFileList.length > 1)
isMultiple = true
if (datasourceType === DatasourceType.onlineDocument && onlineDocuments.length > 1)
isMultiple = true
if (datasourceType === DatasourceType.websiteCrawl && websitePages.length > 1)
isMultiple = true
if (datasourceType === DatasourceType.onlineDrive && selectedFileIds.length > 1)
isMultiple = true
if (isMultiple) {
const multipleCheckMap: Record<string, number> = {
[DatasourceType.localFile]: localFileList.length,
[DatasourceType.onlineDocument]: onlineDocuments.length,
[DatasourceType.websiteCrawl]: websitePages.length,
[DatasourceType.onlineDrive]: selectedFileIds.length,
}
const count = datasourceType ? multipleCheckMap[datasourceType] : 0
if (count > 1) {
showPlanUpgradeModal()
return
}
@ -136,334 +141,44 @@ const CreateFormPipeline = () => {
doHandleNextStep()
}, [datasourceType, doHandleNextStep, localFileList.length, onlineDocuments.length, selectedFileIds.length, showPlanUpgradeModal, supportBatchUpload, websitePages.length])
const nextBtnDisabled = useMemo(() => {
if (!datasource)
return true
if (datasourceType === DatasourceType.localFile)
return isShowVectorSpaceFull || !localFileList.length || !allFileLoaded
if (datasourceType === DatasourceType.onlineDocument)
return isShowVectorSpaceFull || !onlineDocuments.length
if (datasourceType === DatasourceType.websiteCrawl)
return isShowVectorSpaceFull || !websitePages.length
if (datasourceType === DatasourceType.onlineDrive)
return isShowVectorSpaceFull || !selectedFileIds.length
return false
}, [datasource, datasourceType, isShowVectorSpaceFull, localFileList.length, allFileLoaded, onlineDocuments.length, websitePages.length, selectedFileIds.length])
// Datasource actions
const {
isPreview,
formRef,
isIdle,
isPending,
onClickProcess,
onClickPreview,
handleSubmit,
handlePreviewFileChange,
handlePreviewOnlineDocumentChange,
handlePreviewWebsiteChange,
handlePreviewOnlineDriveFileChange,
handleSelectAll,
handleSwitchDataSource,
handleCredentialChange,
} = useDatasourceActions({
datasource,
datasourceType,
pipelineId,
dataSourceStore,
setEstimateData,
setBatchId,
setDocuments,
handleNextStep,
PagesMapAndSelectedPagesId,
currentWorkspacePages: currentWorkspace?.pages,
clearOnlineDocumentData,
clearWebsiteCrawlData,
clearOnlineDriveData,
setDatasource,
})
const fileUploadConfig = useMemo(() => fileUploadConfigResponse ?? {
file_size_limit: 15,
batch_count_limit: 5,
}, [fileUploadConfigResponse])
const showSelect = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument) {
const pagesCount = currentWorkspace?.pages.length ?? 0
return pagesCount > 0
}
if (datasourceType === DatasourceType.onlineDrive) {
const isBucketList = onlineDriveFileList.some(file => file.type === 'bucket')
return !isBucketList && onlineDriveFileList.filter((item) => {
return item.type !== 'bucket'
}).length > 0
}
return false
}, [currentWorkspace?.pages.length, datasourceType, onlineDriveFileList])
const totalOptions = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return currentWorkspace?.pages.length
if (datasourceType === DatasourceType.onlineDrive) {
return onlineDriveFileList.filter((item) => {
return item.type !== 'bucket'
}).length
}
}, [currentWorkspace?.pages.length, datasourceType, onlineDriveFileList])
const selectedOptions = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return onlineDocuments.length
if (datasourceType === DatasourceType.onlineDrive)
return selectedFileIds.length
}, [datasourceType, onlineDocuments.length, selectedFileIds.length])
const tip = useMemo(() => {
if (datasourceType === DatasourceType.onlineDocument)
return t('addDocuments.selectOnlineDocumentTip', { ns: 'datasetPipeline', count: 50 })
if (datasourceType === DatasourceType.onlineDrive) {
return t('addDocuments.selectOnlineDriveTip', {
ns: 'datasetPipeline',
count: fileUploadConfig.batch_count_limit,
fileSize: fileUploadConfig.file_size_limit,
})
}
return ''
}, [datasourceType, fileUploadConfig.batch_count_limit, fileUploadConfig.file_size_limit, t])
const { mutateAsync: runPublishedPipeline, isIdle, isPending } = useRunPublishedPipeline()
const handlePreviewChunks = useCallback(async (data: Record<string, any>) => {
if (!datasource)
return
const {
previewLocalFileRef,
previewOnlineDocumentRef,
previewWebsitePageRef,
previewOnlineDriveFileRef,
currentCredentialId,
} = dataSourceStore.getState()
const datasourceInfoList: Record<string, any>[] = []
if (datasourceType === DatasourceType.localFile) {
const { id, name, type, size, extension, mime_type } = previewLocalFileRef.current as File
const documentInfo = {
related_id: id,
name,
type,
size,
extension,
mime_type,
url: '',
transfer_method: TransferMethod.local_file,
credential_id: currentCredentialId,
}
datasourceInfoList.push(documentInfo)
}
if (datasourceType === DatasourceType.onlineDocument) {
const { workspace_id, ...rest } = previewOnlineDocumentRef.current!
const documentInfo = {
workspace_id,
page: rest,
credential_id: currentCredentialId,
}
datasourceInfoList.push(documentInfo)
}
if (datasourceType === DatasourceType.websiteCrawl) {
datasourceInfoList.push({
...previewWebsitePageRef.current!,
credential_id: currentCredentialId,
})
}
if (datasourceType === DatasourceType.onlineDrive) {
const { bucket } = dataSourceStore.getState()
const { id, type, name } = previewOnlineDriveFileRef.current!
datasourceInfoList.push({
bucket,
id,
name,
type,
credential_id: currentCredentialId,
})
}
await runPublishedPipeline({
pipeline_id: pipelineId!,
inputs: data,
start_node_id: datasource.nodeId,
datasource_type: datasourceType as DatasourceType,
datasource_info_list: datasourceInfoList,
is_preview: true,
}, {
onSuccess: (res) => {
setEstimateData((res as PublishedPipelineRunPreviewResponse).data.outputs)
},
})
}, [datasource, datasourceType, runPublishedPipeline, pipelineId, dataSourceStore])
const handleProcess = useCallback(async (data: Record<string, any>) => {
if (!datasource)
return
const { currentCredentialId } = dataSourceStore.getState()
const datasourceInfoList: Record<string, any>[] = []
if (datasourceType === DatasourceType.localFile) {
const {
localFileList,
} = dataSourceStore.getState()
localFileList.forEach((file) => {
const { id, name, type, size, extension, mime_type } = file.file
const documentInfo = {
related_id: id,
name,
type,
size,
extension,
mime_type,
url: '',
transfer_method: TransferMethod.local_file,
credential_id: currentCredentialId,
}
datasourceInfoList.push(documentInfo)
})
}
if (datasourceType === DatasourceType.onlineDocument) {
const {
onlineDocuments,
} = dataSourceStore.getState()
onlineDocuments.forEach((page) => {
const { workspace_id, ...rest } = page
const documentInfo = {
workspace_id,
page: rest,
credential_id: currentCredentialId,
}
datasourceInfoList.push(documentInfo)
})
}
if (datasourceType === DatasourceType.websiteCrawl) {
const {
websitePages,
} = dataSourceStore.getState()
websitePages.forEach((websitePage) => {
datasourceInfoList.push({
...websitePage,
credential_id: currentCredentialId,
})
})
}
if (datasourceType === DatasourceType.onlineDrive) {
const {
bucket,
selectedFileIds,
onlineDriveFileList,
} = dataSourceStore.getState()
selectedFileIds.forEach((id) => {
const file = onlineDriveFileList.find(file => file.id === id)
datasourceInfoList.push({
bucket,
id: file?.id,
name: file?.name,
type: file?.type,
credential_id: currentCredentialId,
})
})
}
await runPublishedPipeline({
pipeline_id: pipelineId!,
inputs: data,
start_node_id: datasource.nodeId,
datasource_type: datasourceType as DatasourceType,
datasource_info_list: datasourceInfoList,
is_preview: false,
}, {
onSuccess: (res) => {
setBatchId((res as PublishedPipelineRunResponse).batch || '')
setDocuments((res as PublishedPipelineRunResponse).documents || [])
handleNextStep()
trackEvent('dataset_document_added', {
data_source_type: datasourceType,
indexing_technique: 'pipeline',
})
},
})
}, [dataSourceStore, datasource, datasourceType, handleNextStep, pipelineId, runPublishedPipeline])
const onClickProcess = useCallback(() => {
isPreview.current = false
formRef.current?.submit()
}, [])
const onClickPreview = useCallback(() => {
isPreview.current = true
formRef.current?.submit()
}, [])
const handleSubmit = useCallback((data: Record<string, any>) => {
if (isPreview.current)
handlePreviewChunks(data)
else
handleProcess(data)
}, [handlePreviewChunks, handleProcess])
const handlePreviewFileChange = useCallback((file: DocumentItem) => {
const { previewLocalFileRef } = dataSourceStore.getState()
previewLocalFileRef.current = file
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handlePreviewOnlineDocumentChange = useCallback((page: NotionPage) => {
const { previewOnlineDocumentRef } = dataSourceStore.getState()
previewOnlineDocumentRef.current = page
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handlePreviewWebsiteChange = useCallback((website: CrawlResultItem) => {
const { previewWebsitePageRef } = dataSourceStore.getState()
previewWebsitePageRef.current = website
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handlePreviewOnlineDriveFileChange = useCallback((file: OnlineDriveFile) => {
const { previewOnlineDriveFileRef } = dataSourceStore.getState()
previewOnlineDriveFileRef.current = file
onClickPreview()
}, [dataSourceStore, onClickPreview])
const handleSelectAll = useCallback(() => {
const {
onlineDocuments,
onlineDriveFileList,
selectedFileIds,
setOnlineDocuments,
setSelectedFileIds,
setSelectedPagesId,
} = dataSourceStore.getState()
if (datasourceType === DatasourceType.onlineDocument) {
const allIds = currentWorkspace?.pages.map(page => page.page_id) || []
if (onlineDocuments.length < allIds.length) {
const selectedPages = Array.from(allIds).map(pageId => PagesMapAndSelectedPagesId[pageId])
setOnlineDocuments(selectedPages)
setSelectedPagesId(new Set(allIds))
}
else {
setOnlineDocuments([])
setSelectedPagesId(new Set())
}
}
if (datasourceType === DatasourceType.onlineDrive) {
const allKeys = onlineDriveFileList.filter((item) => {
return item.type !== 'bucket'
}).map(file => file.id)
if (selectedFileIds.length < allKeys.length)
setSelectedFileIds(allKeys)
else
setSelectedFileIds([])
}
}, [PagesMapAndSelectedPagesId, currentWorkspace?.pages, dataSourceStore, datasourceType])
const clearDataSourceData = useCallback((dataSource: Datasource) => {
const providerType = dataSource.nodeData.provider_type
if (providerType === DatasourceType.onlineDocument)
clearOnlineDocumentData()
else if (providerType === DatasourceType.websiteCrawl)
clearWebsiteCrawlData()
else if (providerType === DatasourceType.onlineDrive)
clearOnlineDriveData()
}, [clearOnlineDocumentData, clearOnlineDriveData, clearWebsiteCrawlData])
const handleSwitchDataSource = useCallback((dataSource: Datasource) => {
const {
setCurrentCredentialId,
currentNodeIdRef,
} = dataSourceStore.getState()
clearDataSourceData(dataSource)
setCurrentCredentialId('')
currentNodeIdRef.current = dataSource.nodeId
setDatasource(dataSource)
}, [clearDataSourceData, dataSourceStore])
const handleCredentialChange = useCallback((credentialId: string) => {
const { setCurrentCredentialId } = dataSourceStore.getState()
clearDataSourceData(datasource!)
setCurrentCredentialId(credentialId)
}, [clearDataSourceData, dataSourceStore, datasource])
if (isFetchingPipelineInfo) {
return (
<Loading type="app" />
)
}
if (isFetchingPipelineInfo)
return <Loading type="app" />
return (
<div
className="relative flex h-[calc(100vh-56px)] w-full min-w-[1024px] overflow-x-auto rounded-t-2xl border-t border-effects-highlight bg-background-default-subtle"
>
<div className="relative flex h-[calc(100vh-56px)] w-full min-w-[1024px] overflow-x-auto rounded-t-2xl border-t border-effects-highlight bg-background-default-subtle">
<div className="h-full min-w-0 flex-1">
<div className="flex h-full flex-col px-14">
<LeftHeader
@ -472,139 +187,77 @@ const CreateFormPipeline = () => {
currentStep={currentStep}
/>
<div className="grow overflow-y-auto">
{
currentStep === 1 && (
<div className="flex flex-col gap-y-5 pt-4">
<DataSourceOptions
datasourceNodeId={datasource?.nodeId || ''}
onSelect={handleSwitchDataSource}
pipelineNodes={(pipelineInfo?.graph.nodes || []) as Node<DataSourceNodeType>[]}
/>
{datasourceType === DatasourceType.localFile && (
<LocalFile
allowedExtensions={datasource!.nodeData.fileExtensions || []}
supportBatchUpload={supportBatchUpload}
/>
)}
{datasourceType === DatasourceType.onlineDocument && (
<OnlineDocuments
nodeId={datasource!.nodeId}
nodeData={datasource!.nodeData}
onCredentialChange={handleCredentialChange}
/>
)}
{datasourceType === DatasourceType.websiteCrawl && (
<WebsiteCrawl
nodeId={datasource!.nodeId}
nodeData={datasource!.nodeData}
onCredentialChange={handleCredentialChange}
/>
)}
{datasourceType === DatasourceType.onlineDrive && (
<OnlineDrive
nodeId={datasource!.nodeId}
nodeData={datasource!.nodeData}
onCredentialChange={handleCredentialChange}
/>
)}
{isShowVectorSpaceFull && (
<VectorSpaceFull />
)}
<Actions
showSelect={showSelect}
totalOptions={totalOptions}
selectedOptions={selectedOptions}
onSelectAll={handleSelectAll}
disabled={nextBtnDisabled}
handleNextStep={handleNextStep}
tip={tip}
/>
{
!supportBatchUpload && datasourceType === DatasourceType.localFile && localFileList.length > 0 && (
<>
<Divider type="horizontal" className="my-4 h-px bg-divider-subtle" />
<UpgradeCard />
</>
)
}
</div>
)
}
{
currentStep === 2 && (
<ProcessDocuments
ref={formRef}
dataSourceNodeId={datasource!.nodeId}
isRunning={isPending}
onProcess={onClickProcess}
onPreview={onClickPreview}
onSubmit={handleSubmit}
onBack={handleBackStep}
/>
)
}
{
currentStep === 3 && (
<Processing
batchId={batchId}
documents={documents}
/>
)
}
{currentStep === 1 && (
<StepOneContent
datasource={datasource}
datasourceType={datasourceType}
pipelineNodes={(pipelineInfo?.graph.nodes || []) as Node<DataSourceNodeType>[]}
supportBatchUpload={supportBatchUpload}
localFileListLength={localFileList.length}
isShowVectorSpaceFull={isShowVectorSpaceFull}
showSelect={showSelect}
totalOptions={totalOptions}
selectedOptions={selectedOptions}
tip={tip}
nextBtnDisabled={nextBtnDisabled}
onSelectDataSource={handleSwitchDataSource}
onCredentialChange={handleCredentialChange}
onSelectAll={handleSelectAll}
onNextStep={handleNextStep}
/>
)}
{currentStep === 2 && (
<StepTwoContent
formRef={formRef}
dataSourceNodeId={datasource!.nodeId}
isRunning={isPending}
onProcess={onClickProcess}
onPreview={onClickPreview}
onSubmit={handleSubmit}
onBack={handleBackStep}
/>
)}
{currentStep === 3 && (
<StepThreeContent
batchId={batchId}
documents={documents}
/>
)}
</div>
</div>
</div>
{/* Preview */}
{
currentStep === 1 && (
<div className="h-full min-w-0 flex-1">
<div className="flex h-full flex-col pl-2 pt-2">
{currentLocalFile && (
<FilePreview
file={currentLocalFile}
hidePreview={hidePreviewLocalFile}
/>
)}
{currentDocument && (
<OnlineDocumentPreview
datasourceNodeId={datasource!.nodeId}
currentPage={currentDocument}
hidePreview={hidePreviewOnlineDocument}
/>
)}
{currentWebsite && (
<WebsitePreview
currentWebsite={currentWebsite}
hidePreview={hideWebsitePreview}
/>
)}
</div>
</div>
)
}
{
currentStep === 2 && (
<div className="h-full min-w-0 flex-1">
<div className="flex h-full flex-col pl-2 pt-2">
<ChunkPreview
dataSourceType={datasourceType as DatasourceType}
localFiles={localFileList.map(file => file.file)}
onlineDocuments={onlineDocuments}
websitePages={websitePages}
onlineDriveFiles={selectedOnlineDriveFileList}
isIdle={isIdle}
isPending={isPending && isPreview.current}
estimateData={estimateData}
onPreview={onClickPreview}
handlePreviewFileChange={handlePreviewFileChange}
handlePreviewOnlineDocumentChange={handlePreviewOnlineDocumentChange}
handlePreviewWebsitePageChange={handlePreviewWebsiteChange}
handlePreviewOnlineDriveFileChange={handlePreviewOnlineDriveFileChange}
/>
</div>
</div>
)
}
{/* Preview Panel */}
{currentStep === 1 && (
<StepOnePreview
datasource={datasource}
currentLocalFile={currentLocalFile}
currentDocument={currentDocument}
currentWebsite={currentWebsite}
hidePreviewLocalFile={hidePreviewLocalFile}
hidePreviewOnlineDocument={hidePreviewOnlineDocument}
hideWebsitePreview={hideWebsitePreview}
/>
)}
{currentStep === 2 && (
<StepTwoPreview
datasourceType={datasourceType}
localFileList={localFileList}
onlineDocuments={onlineDocuments}
websitePages={websitePages}
selectedOnlineDriveFileList={selectedOnlineDriveFileList}
isIdle={isIdle}
isPendingPreview={isPending && isPreview.current}
estimateData={estimateData}
onPreview={onClickPreview}
handlePreviewFileChange={handlePreviewFileChange}
handlePreviewOnlineDocumentChange={handlePreviewOnlineDocumentChange}
handlePreviewWebsitePageChange={handlePreviewWebsiteChange}
handlePreviewOnlineDriveFileChange={handlePreviewOnlineDriveFileChange}
/>
)}
{/* Plan Upgrade Modal */}
{isShowPlanUpgradeModal && (
<PlanUpgradeModal
show
View File
@ -0,0 +1,3 @@
export { default as StepOneContent } from './step-one-content'
export { default as StepThreeContent } from './step-three-content'
export { default as StepTwoContent } from './step-two-content'
View File
@ -0,0 +1,112 @@
'use client'
import type { Datasource } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { NotionPage } from '@/models/common'
import type { CrawlResultItem, CustomFile, DocumentItem, FileIndexingEstimateResponse, FileItem } from '@/models/datasets'
import type { DatasourceType, OnlineDriveFile } from '@/models/pipeline'
import { memo } from 'react'
import ChunkPreview from '../preview/chunk-preview'
import FilePreview from '../preview/file-preview'
import OnlineDocumentPreview from '../preview/online-document-preview'
import WebsitePreview from '../preview/web-preview'
type StepOnePreviewProps = {
datasource: Datasource | undefined
currentLocalFile: CustomFile | undefined
currentDocument: (NotionPage & { workspace_id: string }) | undefined
currentWebsite: CrawlResultItem | undefined
hidePreviewLocalFile: () => void
hidePreviewOnlineDocument: () => void
hideWebsitePreview: () => void
}
export const StepOnePreview = memo(({
datasource,
currentLocalFile,
currentDocument,
currentWebsite,
hidePreviewLocalFile,
hidePreviewOnlineDocument,
hideWebsitePreview,
}: StepOnePreviewProps) => {
return (
<div className="h-full min-w-0 flex-1">
<div className="flex h-full flex-col pl-2 pt-2">
{currentLocalFile && (
<FilePreview
file={currentLocalFile}
hidePreview={hidePreviewLocalFile}
/>
)}
{currentDocument && (
<OnlineDocumentPreview
datasourceNodeId={datasource!.nodeId}
currentPage={currentDocument}
hidePreview={hidePreviewOnlineDocument}
/>
)}
{currentWebsite && (
<WebsitePreview
currentWebsite={currentWebsite}
hidePreview={hideWebsitePreview}
/>
)}
</div>
</div>
)
})
StepOnePreview.displayName = 'StepOnePreview'
type StepTwoPreviewProps = {
datasourceType: string | undefined
localFileList: FileItem[]
onlineDocuments: (NotionPage & { workspace_id: string })[]
websitePages: CrawlResultItem[]
selectedOnlineDriveFileList: OnlineDriveFile[]
isIdle: boolean
isPendingPreview: boolean
estimateData: FileIndexingEstimateResponse | undefined
onPreview: () => void
handlePreviewFileChange: (file: DocumentItem) => void
handlePreviewOnlineDocumentChange: (page: NotionPage) => void
handlePreviewWebsitePageChange: (website: CrawlResultItem) => void
handlePreviewOnlineDriveFileChange: (file: OnlineDriveFile) => void
}
export const StepTwoPreview = memo(({
datasourceType,
localFileList,
onlineDocuments,
websitePages,
selectedOnlineDriveFileList,
isIdle,
isPendingPreview,
estimateData,
onPreview,
handlePreviewFileChange,
handlePreviewOnlineDocumentChange,
handlePreviewWebsitePageChange,
handlePreviewOnlineDriveFileChange,
}: StepTwoPreviewProps) => {
return (
<div className="h-full min-w-0 flex-1">
<div className="flex h-full flex-col pl-2 pt-2">
<ChunkPreview
dataSourceType={datasourceType as DatasourceType}
localFiles={localFileList.map(file => file.file)}
onlineDocuments={onlineDocuments}
websitePages={websitePages}
onlineDriveFiles={selectedOnlineDriveFileList}
isIdle={isIdle}
isPending={isPendingPreview}
estimateData={estimateData}
onPreview={onPreview}
handlePreviewFileChange={handlePreviewFileChange}
handlePreviewOnlineDocumentChange={handlePreviewOnlineDocumentChange}
handlePreviewWebsitePageChange={handlePreviewWebsitePageChange}
handlePreviewOnlineDriveFileChange={handlePreviewOnlineDriveFileChange}
/>
</div>
</div>
)
})
StepTwoPreview.displayName = 'StepTwoPreview'
View File
@ -0,0 +1,110 @@
'use client'
import type { Datasource } from '@/app/components/rag-pipeline/components/panel/test-run/types'
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
import type { Node } from '@/app/components/workflow/types'
import { memo } from 'react'
import Divider from '@/app/components/base/divider'
import VectorSpaceFull from '@/app/components/billing/vector-space-full'
import LocalFile from '@/app/components/datasets/documents/create-from-pipeline/data-source/local-file'
import OnlineDocuments from '@/app/components/datasets/documents/create-from-pipeline/data-source/online-documents'
import OnlineDrive from '@/app/components/datasets/documents/create-from-pipeline/data-source/online-drive'
import WebsiteCrawl from '@/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl'
import { DatasourceType } from '@/models/pipeline'
import UpgradeCard from '../../../create/step-one/upgrade-card'
import Actions from '../actions'
import DataSourceOptions from '../data-source-options'
type StepOneContentProps = {
datasource: Datasource | undefined
datasourceType: string | undefined
pipelineNodes: Node<DataSourceNodeType>[]
supportBatchUpload: boolean
localFileListLength: number
isShowVectorSpaceFull: boolean
showSelect: boolean
totalOptions: number | undefined
selectedOptions: number | undefined
tip: string
nextBtnDisabled: boolean
onSelectDataSource: (dataSource: Datasource) => void
onCredentialChange: (credentialId: string) => void
onSelectAll: () => void
onNextStep: () => void
}
const StepOneContent = ({
datasource,
datasourceType,
pipelineNodes,
supportBatchUpload,
localFileListLength,
isShowVectorSpaceFull,
showSelect,
totalOptions,
selectedOptions,
tip,
nextBtnDisabled,
onSelectDataSource,
onCredentialChange,
onSelectAll,
onNextStep,
}: StepOneContentProps) => {
const showUpgradeCard = !supportBatchUpload
&& datasourceType === DatasourceType.localFile
&& localFileListLength > 0
return (
<div className="flex flex-col gap-y-5 pt-4">
<DataSourceOptions
datasourceNodeId={datasource?.nodeId || ''}
onSelect={onSelectDataSource}
pipelineNodes={pipelineNodes}
/>
{datasourceType === DatasourceType.localFile && (
<LocalFile
allowedExtensions={datasource!.nodeData.fileExtensions || []}
supportBatchUpload={supportBatchUpload}
/>
)}
{datasourceType === DatasourceType.onlineDocument && (
<OnlineDocuments
nodeId={datasource!.nodeId}
nodeData={datasource!.nodeData}
onCredentialChange={onCredentialChange}
/>
)}
{datasourceType === DatasourceType.websiteCrawl && (
<WebsiteCrawl
nodeId={datasource!.nodeId}
nodeData={datasource!.nodeData}
onCredentialChange={onCredentialChange}
/>
)}
{datasourceType === DatasourceType.onlineDrive && (
<OnlineDrive
nodeId={datasource!.nodeId}
nodeData={datasource!.nodeData}
onCredentialChange={onCredentialChange}
/>
)}
{isShowVectorSpaceFull && <VectorSpaceFull />}
<Actions
showSelect={showSelect}
totalOptions={totalOptions}
selectedOptions={selectedOptions}
onSelectAll={onSelectAll}
disabled={nextBtnDisabled}
handleNextStep={onNextStep}
tip={tip}
/>
{showUpgradeCard && (
<>
<Divider type="horizontal" className="my-4 h-px bg-divider-subtle" />
<UpgradeCard />
</>
)}
</div>
)
}
export default memo(StepOneContent)
View File
@ -0,0 +1,23 @@
'use client'
import type { InitialDocumentDetail } from '@/models/pipeline'
import { memo } from 'react'
import Processing from '../processing'
type StepThreeContentProps = {
batchId: string
documents: InitialDocumentDetail[]
}
const StepThreeContent = ({
batchId,
documents,
}: StepThreeContentProps) => {
return (
<Processing
batchId={batchId}
documents={documents}
/>
)
}
export default memo(StepThreeContent)
View File
@ -0,0 +1,38 @@
'use client'
import type { RefObject } from 'react'
import { memo } from 'react'
import ProcessDocuments from '../process-documents'
type StepTwoContentProps = {
formRef: RefObject<{ submit: () => void } | null>
dataSourceNodeId: string
isRunning: boolean
onProcess: () => void
onPreview: () => void
onSubmit: (data: Record<string, unknown>) => void
onBack: () => void
}
const StepTwoContent = ({
formRef,
dataSourceNodeId,
isRunning,
onProcess,
onPreview,
onSubmit,
onBack,
}: StepTwoContentProps) => {
return (
<ProcessDocuments
ref={formRef}
dataSourceNodeId={dataSourceNodeId}
isRunning={isRunning}
onProcess={onProcess}
onPreview={onPreview}
onSubmit={onSubmit}
onBack={onBack}
/>
)
}
export default memo(StepTwoContent)
View File
@ -0,0 +1,63 @@
import type { NotionPage } from '@/models/common'
import type { CrawlResultItem, CustomFile as File } from '@/models/datasets'
import type { OnlineDriveFile } from '@/models/pipeline'
import { TransferMethod } from '@/types/app'
/**
* Build datasource info for local files
*/
export const buildLocalFileDatasourceInfo = (
file: File,
credentialId: string,
): Record<string, unknown> => ({
related_id: file.id,
name: file.name,
type: file.type,
size: file.size,
extension: file.extension,
mime_type: file.mime_type,
url: '',
transfer_method: TransferMethod.local_file,
credential_id: credentialId,
})
/**
* Build datasource info for online documents
*/
export const buildOnlineDocumentDatasourceInfo = (
page: NotionPage & { workspace_id: string },
credentialId: string,
): Record<string, unknown> => {
const { workspace_id, ...rest } = page
return {
workspace_id,
page: rest,
credential_id: credentialId,
}
}
/**
* Build datasource info for website crawl
*/
export const buildWebsiteCrawlDatasourceInfo = (
page: CrawlResultItem,
credentialId: string,
): Record<string, unknown> => ({
...page,
credential_id: credentialId,
})
/**
* Build datasource info for online drive
*/
export const buildOnlineDriveDatasourceInfo = (
file: OnlineDriveFile,
bucket: string,
credentialId: string,
): Record<string, unknown> => ({
bucket,
id: file.id,
name: file.name,
type: file.type,
credential_id: credentialId,
})
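
A minimal usage sketch (not part of the commit) of how these builders might be combined when assembling datasource info for a pipeline submission; the relative import path and the two `to*Payloads` helper names are assumptions for illustration, while the builder signatures are the ones defined above.

```typescript
// Sketch only: the './datasource-info' path is assumed, the builders are the ones above.
import type { FileItem } from '@/models/datasets'
import type { OnlineDriveFile } from '@/models/pipeline'
import {
  buildLocalFileDatasourceInfo,
  buildOnlineDriveDatasourceInfo,
} from './datasource-info'

// Map uploaded FileItems (each wrapping a CustomFile) to datasource info objects.
const toLocalFilePayloads = (files: FileItem[], credentialId: string) =>
  files.map(item => buildLocalFileDatasourceInfo(item.file, credentialId))

// Same idea for online-drive selections; the bucket comes from the drive browser state.
const toOnlineDrivePayloads = (files: OnlineDriveFile[], bucket: string, credentialId: string) =>
  files.map(file => buildOnlineDriveDatasourceInfo(file, bucket, credentialId))
```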
View File
@ -18,7 +18,7 @@ import { useDocumentDetail, useDocumentMetadata, useInvalidDocumentList } from '
import { useCheckSegmentBatchImportProgress, useChildSegmentListKey, useSegmentBatchImport, useSegmentListKey } from '@/service/knowledge/use-segment'
import { useInvalid } from '@/service/use-base'
import { cn } from '@/utils/classnames'
import Operations from '../operations'
import Operations from '../components/operations'
import StatusItem from '../status-item'
import BatchModal from './batch-modal'
import Completed from './completed'
View File
@ -0,0 +1,197 @@
import type { DocumentListResponse } from '@/models/datasets'
import type { SortType } from '@/service/datasets'
import { useDebounce, useDebounceFn } from 'ahooks'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { normalizeStatusForQuery, sanitizeStatusValue } from '../status-filter'
import useDocumentListQueryState from './use-document-list-query-state'
/**
* Custom hook to manage documents page state including:
* - Search state (input value, debounced search value)
* - Filter state (status filter, sort value)
* - Pagination state (current page, limit)
* - Selection state (selected document ids)
* - Polling state (timer control for auto-refresh)
*/
export function useDocumentsPageState() {
const { query, updateQuery } = useDocumentListQueryState()
// Search state
const [inputValue, setInputValue] = useState<string>('')
const [searchValue, setSearchValue] = useState<string>('')
const debouncedSearchValue = useDebounce(searchValue, { wait: 500 })
// Filter & sort state
const [statusFilterValue, setStatusFilterValue] = useState<string>(() => sanitizeStatusValue(query.status))
const [sortValue, setSortValue] = useState<SortType>(query.sort)
const normalizedStatusFilterValue = useMemo(
() => normalizeStatusForQuery(statusFilterValue),
[statusFilterValue],
)
// Pagination state
const [currPage, setCurrPage] = useState<number>(query.page - 1)
const [limit, setLimit] = useState<number>(query.limit)
// Selection state
const [selectedIds, setSelectedIds] = useState<string[]>([])
// Polling state
const [timerCanRun, setTimerCanRun] = useState(true)
// Initialize search value from URL on mount
useEffect(() => {
if (query.keyword) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
}, []) // Only run on mount
// Sync local state with URL query changes
useEffect(() => {
setCurrPage(query.page - 1)
setLimit(query.limit)
if (query.keyword !== searchValue) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
setStatusFilterValue((prev) => {
const nextValue = sanitizeStatusValue(query.status)
return prev === nextValue ? prev : nextValue
})
setSortValue(query.sort)
}, [query])
// Update URL when search changes
useEffect(() => {
if (debouncedSearchValue !== query.keyword) {
setCurrPage(0)
updateQuery({ keyword: debouncedSearchValue, page: 1 })
}
}, [debouncedSearchValue, query.keyword, updateQuery])
// Clear selection when search changes
useEffect(() => {
if (searchValue !== query.keyword)
setSelectedIds([])
}, [searchValue, query.keyword])
// Clear selection when status filter changes
useEffect(() => {
setSelectedIds([])
}, [normalizedStatusFilterValue])
// Page change handler
const handlePageChange = useCallback((newPage: number) => {
setCurrPage(newPage)
updateQuery({ page: newPage + 1 })
}, [updateQuery])
// Limit change handler
const handleLimitChange = useCallback((newLimit: number) => {
setLimit(newLimit)
setCurrPage(0)
updateQuery({ limit: newLimit, page: 1 })
}, [updateQuery])
// Debounced search handler
const { run: handleSearch } = useDebounceFn(() => {
setSearchValue(inputValue)
}, { wait: 500 })
// Input change handler
const handleInputChange = useCallback((value: string) => {
setInputValue(value)
handleSearch()
}, [handleSearch])
// Status filter change handler
const handleStatusFilterChange = useCallback((value: string) => {
const selectedValue = sanitizeStatusValue(value)
setStatusFilterValue(selectedValue)
setCurrPage(0)
updateQuery({ status: selectedValue, page: 1 })
}, [updateQuery])
// Status filter clear handler
const handleStatusFilterClear = useCallback(() => {
if (statusFilterValue === 'all')
return
setStatusFilterValue('all')
setCurrPage(0)
updateQuery({ status: 'all', page: 1 })
}, [statusFilterValue, updateQuery])
// Sort change handler
const handleSortChange = useCallback((value: string) => {
const next = value as SortType
if (next === sortValue)
return
setSortValue(next)
setCurrPage(0)
updateQuery({ sort: next, page: 1 })
}, [sortValue, updateQuery])
// Update polling state based on documents response
const updatePollingState = useCallback((documentsRes: DocumentListResponse | undefined) => {
if (!documentsRes?.data)
return
let completedNum = 0
documentsRes.data.forEach((documentItem) => {
const { indexing_status } = documentItem
const isEmbedded = indexing_status === 'completed' || indexing_status === 'paused' || indexing_status === 'error'
if (isEmbedded)
completedNum++
})
const hasIncompleteDocuments = completedNum !== documentsRes.data.length
const transientStatuses = ['queuing', 'indexing', 'paused']
const shouldForcePolling = normalizedStatusFilterValue === 'all'
? false
: transientStatuses.includes(normalizedStatusFilterValue)
setTimerCanRun(shouldForcePolling || hasIncompleteDocuments)
}, [normalizedStatusFilterValue])
// Adjust page when total pages change
const adjustPageForTotal = useCallback((documentsRes: DocumentListResponse | undefined) => {
if (!documentsRes)
return
const totalPages = Math.ceil(documentsRes.total / limit)
if (currPage > 0 && currPage + 1 > totalPages)
handlePageChange(totalPages > 0 ? totalPages - 1 : 0)
}, [limit, currPage, handlePageChange])
return {
// Search state
inputValue,
searchValue,
debouncedSearchValue,
handleInputChange,
// Filter & sort state
statusFilterValue,
sortValue,
normalizedStatusFilterValue,
handleStatusFilterChange,
handleStatusFilterClear,
handleSortChange,
// Pagination state
currPage,
limit,
handlePageChange,
handleLimitChange,
// Selection state
selectedIds,
setSelectedIds,
// Polling state
timerCanRun,
updatePollingState,
adjustPageForTotal,
}
}
export default useDocumentsPageState
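
A condensed consumption sketch (not part of the commit); the refactored Documents component further below wires the hook up in full. The `useDocumentsPage` wrapper name, the relative import path, and the exact `query` field names passed to `useDocumentList` are assumptions here, since that hunk is truncated in this diff.

```typescript
import { useEffect } from 'react'
import { useDocumentList } from '@/service/knowledge/use-document'
import useDocumentsPageState from './use-documents-page-state'

export const useDocumentsPage = (datasetId: string) => {
  const {
    currPage, limit, debouncedSearchValue, normalizedStatusFilterValue, sortValue,
    timerCanRun, updatePollingState, adjustPageForTotal,
  } = useDocumentsPageState()

  // Poll every 2.5s while documents are still indexing or a transient status filter is active.
  const { data: documentsRes, isLoading } = useDocumentList({
    datasetId,
    query: {
      page: currPage + 1, // the hook keeps a 0-based page; the API is 1-based
      limit,
      keyword: debouncedSearchValue,
      status: normalizedStatusFilterValue,
      sort: sortValue,
    },
    refetchInterval: timerCanRun ? 2500 : 0,
  })

  // Keep polling and pagination in sync with the latest response.
  useEffect(() => { updatePollingState(documentsRes) }, [documentsRes, updatePollingState])
  useEffect(() => { adjustPageForTotal(documentsRes) }, [documentsRes, adjustPageForTotal])

  return { documentsRes, isLoading }
}
```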
View File
@ -1,185 +1,55 @@
'use client'
import type { FC } from 'react'
import type { Item } from '@/app/components/base/select'
import type { SortType } from '@/service/datasets'
import { PlusIcon } from '@heroicons/react/24/solid'
import { RiDraftLine, RiExternalLinkLine } from '@remixicon/react'
import { useDebounce, useDebounceFn } from 'ahooks'
import { useRouter } from 'next/navigation'
import * as React from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import { useCallback, useEffect } from 'react'
import Loading from '@/app/components/base/loading'
import IndexFailed from '@/app/components/datasets/common/document-status-with-action/index-failed'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useDocLink } from '@/context/i18n'
import { useProviderContext } from '@/context/provider-context'
import { DataSourceType } from '@/models/datasets'
import { useDocumentList, useInvalidDocumentDetail, useInvalidDocumentList } from '@/service/knowledge/use-document'
import { useChildSegmentListKey, useSegmentListKey } from '@/service/knowledge/use-segment'
import { useInvalid } from '@/service/use-base'
import { cn } from '@/utils/classnames'
import Chip from '../../base/chip'
import Sort from '../../base/sort'
import AutoDisabledDocument from '../common/document-status-with-action/auto-disabled-document'
import StatusWithAction from '../common/document-status-with-action/status-with-action'
import useEditDocumentMetadata from '../metadata/hooks/use-edit-dataset-metadata'
import DatasetMetadataDrawer from '../metadata/metadata-dataset/dataset-metadata-drawer'
import useDocumentListQueryState from './hooks/use-document-list-query-state'
import List from './list'
import { normalizeStatusForQuery, sanitizeStatusValue } from './status-filter'
import { useIndexStatus } from './status-item/hooks'
import s from './style.module.css'
const FolderPlusIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M10.8332 5.83333L9.90355 3.9741C9.63601 3.439 9.50222 3.17144 9.30265 2.97597C9.12615 2.80311 8.91344 2.67164 8.6799 2.59109C8.41581 2.5 8.11668 2.5 7.51841 2.5H4.33317C3.39975 2.5 2.93304 2.5 2.57652 2.68166C2.26292 2.84144 2.00795 3.09641 1.84816 3.41002C1.6665 3.76654 1.6665 4.23325 1.6665 5.16667V5.83333M1.6665 5.83333H14.3332C15.7333 5.83333 16.4334 5.83333 16.9681 6.10582C17.4386 6.3455 17.821 6.72795 18.0607 7.19836C18.3332 7.73314 18.3332 8.4332 18.3332 9.83333V13.5C18.3332 14.9001 18.3332 15.6002 18.0607 16.135C17.821 16.6054 17.4386 16.9878 16.9681 17.2275C16.4334 17.5 15.7333 17.5 14.3332 17.5H5.6665C4.26637 17.5 3.56631 17.5 3.03153 17.2275C2.56112 16.9878 2.17867 16.6054 1.93899 16.135C1.6665 15.6002 1.6665 14.9001 1.6665 13.5V5.83333ZM9.99984 14.1667V9.16667M7.49984 11.6667H12.4998" stroke="#667085" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
}
const ThreeDotsIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<path d="M5 6.5V5M8.93934 7.56066L10 6.5M10.0103 11.5H11.5103" stroke="#374151" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round" />
</svg>
)
}
const NotionIcon = ({ className }: React.SVGProps<SVGElement>) => {
return (
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg" className={className ?? ''}>
<g clipPath="url(#clip0_2164_11263)">
<path fillRule="evenodd" clipRule="evenodd" d="M3.5725 18.2611L1.4229 15.5832C0.905706 14.9389 0.625 14.1466 0.625 13.3312V3.63437C0.625 2.4129 1.60224 1.39936 2.86295 1.31328L12.8326 0.632614C13.5569 0.583164 14.2768 0.775682 14.8717 1.17794L18.3745 3.5462C19.0015 3.97012 19.375 4.66312 19.375 5.40266V16.427C19.375 17.6223 18.4141 18.6121 17.1798 18.688L6.11458 19.3692C5.12958 19.4298 4.17749 19.0148 3.5725 18.2611Z" fill="white" />
<path d="M7.03006 8.48669V8.35974C7.03006 8.03794 7.28779 7.77104 7.61997 7.74886L10.0396 7.58733L13.3857 12.5147V8.19009L12.5244 8.07528V8.01498C12.5244 7.68939 12.788 7.42074 13.1244 7.4035L15.326 7.29073V7.60755C15.326 7.75628 15.2154 7.88349 15.0638 7.90913L14.534 7.99874V15.0023L13.8691 15.231C13.3136 15.422 12.6952 15.2175 12.3772 14.7377L9.12879 9.83574V14.5144L10.1287 14.7057L10.1147 14.7985C10.0711 15.089 9.82028 15.3087 9.51687 15.3222L7.03006 15.4329C6.99718 15.1205 7.23132 14.841 7.55431 14.807L7.88143 14.7727V8.53453L7.03006 8.48669Z" fill="black" />
<path fillRule="evenodd" clipRule="evenodd" d="M12.9218 1.85424L2.95217 2.53491C2.35499 2.57568 1.89209 3.05578 1.89209 3.63437V13.3312C1.89209 13.8748 2.07923 14.403 2.42402 14.8325L4.57362 17.5104C4.92117 17.9434 5.46812 18.1818 6.03397 18.147L17.0991 17.4658C17.6663 17.4309 18.1078 16.9762 18.1078 16.427V5.40266C18.1078 5.06287 17.9362 4.74447 17.6481 4.54969L14.1453 2.18143C13.7883 1.94008 13.3564 1.82457 12.9218 1.85424ZM3.44654 3.78562C3.30788 3.68296 3.37387 3.46909 3.54806 3.4566L12.9889 2.77944C13.2897 2.75787 13.5886 2.8407 13.8318 3.01305L15.7261 4.35508C15.798 4.40603 15.7642 4.51602 15.6752 4.52086L5.67742 5.0646C5.37485 5.08106 5.0762 4.99217 4.83563 4.81406L3.44654 3.78562ZM5.20848 6.76919C5.20848 6.4444 5.47088 6.1761 5.80642 6.15783L16.3769 5.58216C16.7039 5.56435 16.9792 5.81583 16.9792 6.13239V15.6783C16.9792 16.0025 16.7177 16.2705 16.3829 16.2896L5.8793 16.8872C5.51537 16.9079 5.20848 16.6283 5.20848 16.2759V6.76919Z" fill="black" />
</g>
<defs>
<clipPath id="clip0_2164_11263">
<rect width="20" height="20" fill="white" />
</clipPath>
</defs>
</svg>
)
}
const EmptyElement: FC<{ canAdd: boolean, onClick: () => void, type?: 'upload' | 'sync' }> = ({ canAdd = true, onClick, type = 'upload' }) => {
const { t } = useTranslation()
return (
<div className={s.emptyWrapper}>
<div className={s.emptyElement}>
<div className={s.emptySymbolIconWrapper}>
{type === 'upload' ? <FolderPlusIcon /> : <NotionIcon />}
</div>
<span className={s.emptyTitle}>
{t('list.empty.title', { ns: 'datasetDocuments' })}
<ThreeDotsIcon className="relative -left-1.5 -top-3 inline" />
</span>
<div className={s.emptyTip}>
{t(`list.empty.${type}.tip`, { ns: 'datasetDocuments' })}
</div>
{type === 'upload' && canAdd && (
<Button onClick={onClick} className={s.addFileBtn} variant="secondary-accent">
<PlusIcon className={s.plusIcon} />
{t('list.addFile', { ns: 'datasetDocuments' })}
</Button>
)}
</div>
</div>
)
}
import DocumentsHeader from './components/documents-header'
import EmptyElement from './components/empty-element'
import List from './components/list'
import useDocumentsPageState from './hooks/use-documents-page-state'
type IDocumentsProps = {
datasetId: string
}
const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
const { t } = useTranslation()
const docLink = useDocLink()
const router = useRouter()
const { plan } = useProviderContext()
const isFreePlan = plan.type === 'sandbox'
const { query, updateQuery } = useDocumentListQueryState()
const [inputValue, setInputValue] = useState<string>('') // the input value
const [searchValue, setSearchValue] = useState<string>('')
const [statusFilterValue, setStatusFilterValue] = useState<string>(() => sanitizeStatusValue(query.status))
const [sortValue, setSortValue] = useState<SortType>(query.sort)
const DOC_INDEX_STATUS_MAP = useIndexStatus()
const [currPage, setCurrPage] = React.useState<number>(query.page - 1) // Convert to 0-based index
const [limit, setLimit] = useState<number>(query.limit)
const router = useRouter()
const dataset = useDatasetDetailContextWithSelector(s => s.dataset)
const [timerCanRun, setTimerCanRun] = useState(true)
const isDataSourceNotion = dataset?.data_source_type === DataSourceType.NOTION
const isDataSourceWeb = dataset?.data_source_type === DataSourceType.WEB
const isDataSourceFile = dataset?.data_source_type === DataSourceType.FILE
const embeddingAvailable = !!dataset?.embedding_available
const debouncedSearchValue = useDebounce(searchValue, { wait: 500 })
const statusFilterItems: Item[] = useMemo(() => [
{ value: 'all', name: t('list.index.all', { ns: 'datasetDocuments' }) as string },
{ value: 'queuing', name: DOC_INDEX_STATUS_MAP.queuing.text },
{ value: 'indexing', name: DOC_INDEX_STATUS_MAP.indexing.text },
{ value: 'paused', name: DOC_INDEX_STATUS_MAP.paused.text },
{ value: 'error', name: DOC_INDEX_STATUS_MAP.error.text },
{ value: 'available', name: DOC_INDEX_STATUS_MAP.available.text },
{ value: 'enabled', name: DOC_INDEX_STATUS_MAP.enabled.text },
{ value: 'disabled', name: DOC_INDEX_STATUS_MAP.disabled.text },
{ value: 'archived', name: DOC_INDEX_STATUS_MAP.archived.text },
], [DOC_INDEX_STATUS_MAP, t])
const normalizedStatusFilterValue = useMemo(() => normalizeStatusForQuery(statusFilterValue), [statusFilterValue])
const sortItems: Item[] = useMemo(() => [
{ value: 'created_at', name: t('list.sort.uploadTime', { ns: 'datasetDocuments' }) as string },
{ value: 'hit_count', name: t('list.sort.hitCount', { ns: 'datasetDocuments' }) as string },
], [t])
// Initialize search value from URL on mount
useEffect(() => {
if (query.keyword) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
}, []) // Only run on mount
// Sync local state with URL query changes
useEffect(() => {
setCurrPage(query.page - 1)
setLimit(query.limit)
if (query.keyword !== searchValue) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
setStatusFilterValue((prev) => {
const nextValue = sanitizeStatusValue(query.status)
return prev === nextValue ? prev : nextValue
})
setSortValue(query.sort)
}, [query])
// Update URL when pagination changes
const handlePageChange = (newPage: number) => {
setCurrPage(newPage)
updateQuery({ page: newPage + 1 }) // Pagination emits 0-based page, convert to 1-based for URL
}
// Update URL when limit changes
const handleLimitChange = (newLimit: number) => {
setLimit(newLimit)
setCurrPage(0) // Reset to first page when limit changes
updateQuery({ limit: newLimit, page: 1 })
}
// Update URL when search changes
useEffect(() => {
if (debouncedSearchValue !== query.keyword) {
setCurrPage(0) // Reset to first page when search changes
updateQuery({ keyword: debouncedSearchValue, page: 1 })
}
}, [debouncedSearchValue, query.keyword, updateQuery])
// Use custom hook for page state management
const {
inputValue,
debouncedSearchValue,
handleInputChange,
statusFilterValue,
sortValue,
normalizedStatusFilterValue,
handleStatusFilterChange,
handleStatusFilterClear,
handleSortChange,
currPage,
limit,
handlePageChange,
handleLimitChange,
selectedIds,
setSelectedIds,
timerCanRun,
updatePollingState,
adjustPageForTotal,
} = useDocumentsPageState()
// Fetch document list
const { data: documentsRes, isLoading: isListLoading } = useDocumentList({
datasetId,
query: {
@ -192,16 +62,18 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
refetchInterval: timerCanRun ? 2500 : 0,
})
const invalidDocumentList = useInvalidDocumentList(datasetId)
// Update polling state when documents change
useEffect(() => {
if (documentsRes) {
const totalPages = Math.ceil(documentsRes.total / limit)
if (totalPages < currPage + 1)
setCurrPage(totalPages === 0 ? 0 : totalPages - 1)
}
}, [documentsRes])
updatePollingState(documentsRes)
}, [documentsRes, updatePollingState])
// Adjust page when total changes
useEffect(() => {
adjustPageForTotal(documentsRes)
}, [documentsRes, adjustPageForTotal])
// Invalidation hooks
const invalidDocumentList = useInvalidDocumentList(datasetId)
const invalidDocumentDetail = useInvalidDocumentDetail()
const invalidChunkList = useInvalid(useSegmentListKey)
const invalidChildChunkList = useInvalid(useChildSegmentListKey)
@ -213,73 +85,9 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
invalidChunkList()
invalidChildChunkList()
}, 5000)
}, [])
useEffect(() => {
let completedNum = 0
let percent = 0
documentsRes?.data?.forEach((documentItem) => {
const { indexing_status, completed_segments, total_segments } = documentItem
const isEmbedded = indexing_status === 'completed' || indexing_status === 'paused' || indexing_status === 'error'
if (isEmbedded)
completedNum++
const completedCount = completed_segments || 0
const totalCount = total_segments || 0
if (totalCount === 0 && completedCount === 0) {
percent = isEmbedded ? 100 : 0
}
else {
const per = Math.round(completedCount * 100 / totalCount)
percent = per > 100 ? 100 : per
}
return {
...documentItem,
percent,
}
})
const hasIncompleteDocuments = completedNum !== documentsRes?.data?.length
const transientStatuses = ['queuing', 'indexing', 'paused']
const shouldForcePolling = normalizedStatusFilterValue === 'all'
? false
: transientStatuses.includes(normalizedStatusFilterValue)
setTimerCanRun(shouldForcePolling || hasIncompleteDocuments)
}, [documentsRes, normalizedStatusFilterValue])
const total = documentsRes?.total || 0
const routeToDocCreate = () => {
// if dataset is created from pipeline, go to create from pipeline page
if (dataset?.runtime_mode === 'rag_pipeline') {
router.push(`/datasets/${datasetId}/documents/create-from-pipeline`)
return
}
router.push(`/datasets/${datasetId}/documents/create`)
}
const documentsList = documentsRes?.data
const [selectedIds, setSelectedIds] = useState<string[]>([])
// Clear selection when search changes to avoid confusion
useEffect(() => {
if (searchValue !== query.keyword)
setSelectedIds([])
}, [searchValue, query.keyword])
useEffect(() => {
setSelectedIds([])
}, [normalizedStatusFilterValue])
const { run: handleSearch } = useDebounceFn(() => {
setSearchValue(inputValue)
}, { wait: 500 })
const handleInputChange = (value: string) => {
setInputValue(value)
handleSearch()
}
}, [invalidDocumentList, invalidDocumentDetail, invalidChunkList, invalidChildChunkList])
// Metadata editing hook
const {
isShowEditModal: isShowEditMetadataModal,
showEditModal: showEditMetadataModal,
@ -297,130 +105,84 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
onUpdateDocList: invalidDocumentList,
})
// Route to document creation page
const routeToDocCreate = useCallback(() => {
if (dataset?.runtime_mode === 'rag_pipeline') {
router.push(`/datasets/${datasetId}/documents/create-from-pipeline`)
return
}
router.push(`/datasets/${datasetId}/documents/create`)
}, [dataset?.runtime_mode, datasetId, router])
const total = documentsRes?.total || 0
const documentsList = documentsRes?.data
// Render content based on loading and data state
const renderContent = () => {
if (isListLoading)
return <Loading type="app" />
if (total > 0) {
return (
<List
embeddingAvailable={embeddingAvailable}
documents={documentsList || []}
datasetId={datasetId}
onUpdate={handleUpdate}
selectedIds={selectedIds}
onSelectedIdChange={setSelectedIds}
statusFilterValue={normalizedStatusFilterValue}
remoteSortValue={sortValue}
pagination={{
total,
limit,
onLimitChange: handleLimitChange,
current: currPage,
onChange: handlePageChange,
}}
onManageMetadata={showEditMetadataModal}
/>
)
}
const isDataSourceNotion = dataset?.data_source_type === DataSourceType.NOTION
return (
<EmptyElement
canAdd={embeddingAvailable}
onClick={routeToDocCreate}
type={isDataSourceNotion ? 'sync' : 'upload'}
/>
)
}
return (
<div className="flex h-full flex-col">
<div className="flex flex-col justify-center gap-1 px-6 pt-4">
<h1 className="text-base font-semibold text-text-primary">{t('list.title', { ns: 'datasetDocuments' })}</h1>
<div className="flex items-center space-x-0.5 text-sm font-normal text-text-tertiary">
<span>{t('list.desc', { ns: 'datasetDocuments' })}</span>
<a
className="flex items-center text-text-accent"
target="_blank"
href={docLink('/guides/knowledge-base/integrate-knowledge-within-application')}
>
<span>{t('list.learnMore', { ns: 'datasetDocuments' })}</span>
<RiExternalLinkLine className="h-3 w-3" />
</a>
</div>
</div>
<DocumentsHeader
datasetId={datasetId}
dataSourceType={dataset?.data_source_type}
embeddingAvailable={embeddingAvailable}
isFreePlan={isFreePlan}
statusFilterValue={statusFilterValue}
sortValue={sortValue}
inputValue={inputValue}
onStatusFilterChange={handleStatusFilterChange}
onStatusFilterClear={handleStatusFilterClear}
onSortChange={handleSortChange}
onInputChange={handleInputChange}
isShowEditMetadataModal={isShowEditMetadataModal}
showEditMetadataModal={showEditMetadataModal}
hideEditMetadataModal={hideEditMetadataModal}
datasetMetaData={datasetMetaData}
builtInMetaData={builtInMetaData}
builtInEnabled={!!builtInEnabled}
onAddMetaData={handleAddMetaData}
onRenameMetaData={handleRename}
onDeleteMetaData={handleDeleteMetaData}
onBuiltInEnabledChange={setBuiltInEnabled}
onAddDocument={routeToDocCreate}
/>
<div className="flex h-0 grow flex-col px-6 pt-4">
<div className="flex flex-wrap items-center justify-between">
<div className="flex items-center gap-2">
<Chip
className="w-[160px]"
showLeftIcon={false}
value={statusFilterValue}
items={statusFilterItems}
onSelect={(item) => {
const selectedValue = sanitizeStatusValue(item?.value ? String(item.value) : '')
setStatusFilterValue(selectedValue)
setCurrPage(0)
updateQuery({ status: selectedValue, page: 1 })
}}
onClear={() => {
if (statusFilterValue === 'all')
return
setStatusFilterValue('all')
setCurrPage(0)
updateQuery({ status: 'all', page: 1 })
}}
/>
<Input
showLeftIcon
showClearIcon
wrapperClassName="!w-[200px]"
value={inputValue}
onChange={e => handleInputChange(e.target.value)}
onClear={() => handleInputChange('')}
/>
<div className="h-3.5 w-px bg-divider-regular"></div>
<Sort
order={sortValue.startsWith('-') ? '-' : ''}
value={sortValue.replace('-', '')}
items={sortItems}
onSelect={(value) => {
const next = String(value) as SortType
if (next === sortValue)
return
setSortValue(next)
setCurrPage(0)
updateQuery({ sort: next, page: 1 })
}}
/>
</div>
<div className="flex !h-8 items-center justify-center gap-2">
{!isFreePlan && <AutoDisabledDocument datasetId={datasetId} />}
<IndexFailed datasetId={datasetId} />
{!embeddingAvailable && <StatusWithAction type="warning" description={t('embeddingModelNotAvailable', { ns: 'dataset' })} />}
{embeddingAvailable && (
<Button variant="secondary" className="shrink-0" onClick={showEditMetadataModal}>
<RiDraftLine className="mr-1 size-4" />
{t('metadata.metadata', { ns: 'dataset' })}
</Button>
)}
{isShowEditMetadataModal && (
<DatasetMetadataDrawer
userMetadata={datasetMetaData || []}
onClose={hideEditMetadataModal}
onAdd={handleAddMetaData}
onRename={handleRename}
onRemove={handleDeleteMetaData}
builtInMetadata={builtInMetaData || []}
isBuiltInEnabled={!!builtInEnabled}
onIsBuiltInEnabledChange={setBuiltInEnabled}
/>
)}
{embeddingAvailable && (
<Button variant="primary" onClick={routeToDocCreate} className="shrink-0">
<PlusIcon className={cn('mr-2 h-4 w-4 stroke-current')} />
{isDataSourceNotion && t('list.addPages', { ns: 'datasetDocuments' })}
{isDataSourceWeb && t('list.addUrl', { ns: 'datasetDocuments' })}
{(!dataset?.data_source_type || isDataSourceFile) && t('list.addFile', { ns: 'datasetDocuments' })}
</Button>
)}
</div>
</div>
{isListLoading
? <Loading type="app" />
// eslint-disable-next-line sonarjs/no-nested-conditional
: total > 0
? (
<List
embeddingAvailable={embeddingAvailable}
documents={documentsList || []}
datasetId={datasetId}
onUpdate={handleUpdate}
selectedIds={selectedIds}
onSelectedIdChange={setSelectedIds}
statusFilterValue={normalizedStatusFilterValue}
remoteSortValue={sortValue}
pagination={{
total,
limit,
onLimitChange: handleLimitChange,
current: currPage,
onChange: handlePageChange,
}}
onManageMetadata={showEditMetadataModal}
/>
)
: (
<EmptyElement
canAdd={embeddingAvailable}
onClick={routeToDocCreate}
type={isDataSourceNotion ? 'sync' : 'upload'}
/>
)}
{renderContent()}
</div>
</div>
)
View File
@ -30,7 +30,7 @@ const ApiBasedExtensionModal: FC<ApiBasedExtensionModalProps> = ({
onSave,
}) => {
const { t } = useTranslation()
const docLink = useDocLink()
const docLink = useDocLink('https://docs.dify.ai/versions/3-0-x')
const [localeData, setLocaleData] = useState(data)
const [loading, setLoading] = useState(false)
const { notify } = useToastContext()
@ -102,7 +102,7 @@ const ApiBasedExtensionModal: FC<ApiBasedExtensionModalProps> = ({
<div className="flex h-9 items-center justify-between text-sm font-medium text-text-primary">
{t('apiBasedExtension.modal.apiEndpoint.title', { ns: 'common' })}
<a
href={docLink('/guides/extension/api-based-extension/README')}
href={docLink('/user-guide/extension/api-based-extension/README#api-based-extension')}
target="_blank"
rel="noopener noreferrer"
className="group flex items-center text-xs font-normal text-text-accent"
View File
@ -55,7 +55,6 @@ type FormProps<
nodeId?: string
nodeOutputVars?: NodeOutPutVar[]
availableNodes?: Node[]
canChooseMCPTool?: boolean
}
function Form<
@ -81,7 +80,6 @@ function Form<
nodeId,
nodeOutputVars,
availableNodes,
canChooseMCPTool,
}: FormProps<CustomFormSchema>) {
const language = useLanguage()
const [changeKey, setChangeKey] = useState('')
@ -407,7 +405,6 @@ function Form<
value={value[variable] || []}
onChange={item => handleFormChange(variable, item as any)}
supportCollapse
canChooseMCPTool={canChooseMCPTool}
/>
{fieldMoreInfo?.(formSchema)}
{validating && changeKey === variable && <ValidatingTip />}
View File
@ -2,9 +2,9 @@
import type { INavSelectorProps } from './nav-selector'
import Link from 'next/link'
import { usePathname, useSearchParams, useSelectedLayoutSegment } from 'next/navigation'
import { useSelectedLayoutSegment } from 'next/navigation'
import * as React from 'react'
import { useEffect, useState } from 'react'
import { useState } from 'react'
import { useStore as useAppStore } from '@/app/components/app/store'
import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows'
import { cn } from '@/utils/classnames'
@ -36,14 +36,6 @@ const Nav = ({
const [hovered, setHovered] = useState(false)
const segment = useSelectedLayoutSegment()
const isActivated = Array.isArray(activeSegment) ? activeSegment.includes(segment!) : segment === activeSegment
const pathname = usePathname()
const searchParams = useSearchParams()
const [linkLastSearchParams, setLinkLastSearchParams] = useState('')
useEffect(() => {
if (pathname === link)
setLinkLastSearchParams(searchParams.toString())
}, [pathname, searchParams])
return (
<div className={`
@ -52,7 +44,7 @@ const Nav = ({
${!curNav && !isActivated && 'hover:bg-components-main-nav-nav-button-bg-hover'}
`}
>
<Link href={link + (linkLastSearchParams && `?${linkLastSearchParams}`)}>
<Link href={link}>
<div
onClick={(e) => {
// Don't clear state if opening in new tab/window
View File
@ -897,6 +897,58 @@ describe('Icon', () => {
const iconDiv = container.firstChild as HTMLElement
expect(iconDiv).toHaveStyle({ backgroundImage: 'url(/icon?name=test&size=large)' })
})
it('should not render status indicators when src is object with installed=true', () => {
render(<Icon src={{ content: '🎉', background: '#fff' }} installed={true} />)
// Status indicators should not render for object src
expect(screen.queryByTestId('ri-check-line')).not.toBeInTheDocument()
})
it('should not render status indicators when src is object with installFailed=true', () => {
render(<Icon src={{ content: '🎉', background: '#fff' }} installFailed={true} />)
// Status indicators should not render for object src
expect(screen.queryByTestId('ri-close-line')).not.toBeInTheDocument()
})
it('should render object src with all size variants', () => {
const sizes: Array<'xs' | 'tiny' | 'small' | 'medium' | 'large'> = ['xs', 'tiny', 'small', 'medium', 'large']
sizes.forEach((size) => {
const { unmount } = render(<Icon src={{ content: '🔗', background: '#fff' }} size={size} />)
expect(screen.getByTestId('app-icon')).toHaveAttribute('data-size', size)
unmount()
})
})
it('should render object src with custom className', () => {
const { container } = render(
<Icon src={{ content: '🎉', background: '#fff' }} className="custom-object-icon" />,
)
expect(container.querySelector('.custom-object-icon')).toBeInTheDocument()
})
it('should pass correct props to AppIcon for object src', () => {
render(<Icon src={{ content: '😀', background: '#123456' }} />)
const appIcon = screen.getByTestId('app-icon')
expect(appIcon).toHaveAttribute('data-icon', '😀')
expect(appIcon).toHaveAttribute('data-background', '#123456')
expect(appIcon).toHaveAttribute('data-icon-type', 'emoji')
})
it('should render inner icon only when shouldUseMcpIcon returns true', () => {
// Test with MCP icon content
const { unmount } = render(<Icon src={{ content: '🔗', background: '#fff' }} />)
expect(screen.getByTestId('inner-icon')).toBeInTheDocument()
unmount()
// Test without MCP icon content
render(<Icon src={{ content: '🎉', background: '#fff' }} />)
expect(screen.queryByTestId('inner-icon')).not.toBeInTheDocument()
})
})
})
View File
@ -180,7 +180,14 @@ const AppPicker: FC<Props> = ({
background={app.icon_background}
imageUrl={app.icon_url}
/>
<div title={app.name} className="system-sm-medium grow text-components-input-text-filled">{app.name}</div>
<div title={`${app.name} (${app.id})`} className="system-sm-medium grow text-components-input-text-filled">
<span className="mr-1">{app.name}</span>
<span className="text-text-tertiary">
(
{app.id.slice(0, 8)}
)
</span>
</div>
<div className="system-2xs-medium-uppercase shrink-0 text-text-tertiary">{getAppType(app)}</div>
</div>
))}
Some files were not shown because too many files have changed in this diff