Compare commits


37 Commits

Author SHA1 Message Date
yyh 92d44c0caa Run pnpm lint:fix 2025-12-17 17:01:21 +08:00
yyh 0478e4cde5 Refactor clear button implementation in type selector 2025-12-17 16:26:23 +08:00
yyh 0f816b52c2 test(web): add AppTypeSelector tests
- Add Jest/RTL coverage for AppTypeSelector, AppTypeLabel, AppTypeIcon
- Improve clear control accessibility by using a button with aria-label
2025-12-17 16:20:14 +08:00
4fce99379e test(api): add a test for detect_file_encodings (#29778)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-17 14:33:30 +08:00
8d1e36540a fix: detect_file_encodings TypeError: tuple indices must be integers or slices, not str (#29595)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-12-17 13:58:05 +08:00
1d1351393a feat: update RAG recommended plugins hook to accept type parameter (#29735) 2025-12-17 13:48:23 +08:00
44f8915e30 feat: Add Aliyun SLS (Simple Log Service) integration for workflow execution logging (#28986)
Co-authored-by: hieheihei <270985384@qq.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-12-17 13:43:54 +08:00
94a5fd3617 chore: tests for webapp run batch (#29767) 2025-12-17 13:36:50 +08:00
5bb1346da8 chore: tests form add annotation (#29770)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-17 13:36:40 +08:00
a93eecaeee feat: Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. (#29736)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
2025-12-17 11:26:08 +08:00
86131d4bd8 feat: add datasource_parameters handling for API requests (#29757)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-17 10:37:55 +08:00
581b62cf01 feat: add automated tests for pipeline setting (#29478)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2025-12-17 10:26:58 +08:00
yyh 91714ee413 chore(web): add some jest tests (#29754) 2025-12-17 10:21:32 +08:00
yyh 232149e63f chore: add tests for config string and dataset card item (#29743) 2025-12-17 10:19:10 +08:00
4a1ddea431 ci: show missing lines in coverage report summary (#29717) 2025-12-17 10:18:41 +08:00
5539bf8788 fix: add Slovenian and Tunisian Arabic translations across multiple language files (#29759) 2025-12-17 10:18:10 +08:00
dda7eb03c9 feat: _truncate_json_primitives support file (#29760) 2025-12-17 08:10:43 +09:00
c2f2be6b08 fix: oxlint no unused expressions (#29675)
Co-authored-by: daniel <daniel@example.com>
2025-12-16 18:00:04 +08:00
b7649f61f8 fix: Login secret text transmission (#29659)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2025-12-16 16:55:51 +08:00
ae4a9040df Feat/update notion preview (#29345)
Co-authored-by: twwu <twwu@dify.ai>
2025-12-16 16:43:45 +08:00
d2b63df7a1 chore: tests for components in config (#29739) 2025-12-16 16:39:04 +08:00
0749e6e090 test: Stabilize sharded Redis broadcast multi-subscriber test (#29733)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-16 16:35:55 +08:00
yyh 4589157963 test: Add comprehensive Jest test for AppCard component (#29667) 2025-12-16 15:44:51 +08:00
37d4dbeb96 feat: Remove TLS 1.1 from default NGINX protocols (#29728) 2025-12-16 15:39:42 +08:00
yyh c036a12999 test: add comprehensive unit tests for APIKeyInfoPanel component (#29719) 2025-12-16 15:07:30 +08:00
47cd94ec3e chore: tests for billings (#29720) 2025-12-16 15:06:53 +08:00
e5cf0d0bf6 chore: Disable Swagger UI by default in docker samples (#29723) 2025-12-16 15:01:51 +08:00
yyh 240e1d155a test: add comprehensive tests for CustomizeModal component (#29709) 2025-12-16 14:21:05 +08:00
a915b8a584 revert: "security/fix-swagger-info-leak-m02" (#29721) 2025-12-16 14:19:33 +08:00
yyh 4553e4c12f test: add comprehensive Jest tests for CustomPage and WorkflowOnboardingModal components (#29714) 2025-12-16 14:18:09 +08:00
7695f9151c chore: webhook with bin file should guess mimetype (#29704)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Maries <xh001x@hotmail.com>
2025-12-16 13:34:27 +08:00
bdccbb6e86 feat: add GraphEngine layer node execution hooks (#28583) 2025-12-16 13:26:31 +08:00
c904c58c43 test: add unit tests for DocumentPicker, PreviewDocumentPicker, and R… (#29695)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2025-12-16 13:06:50 +08:00
yyh cb5162f37a test: add comprehensive Jest test for CreateAppTemplateDialog component (#29713) 2025-12-16 12:57:51 +08:00
yyh eeb5129a17 refactor: create shared react-i18next mock to reduce duplication (#29711) 2025-12-16 12:45:17 +08:00
4cc6652424 feat: VECTOR_STORE supports seekdb (#29658) 2025-12-16 12:35:04 +09:00
yyh a232da564a test: try to use Anthropic Skills to add tests for web/app/components/apps/ (#29607)
Signed-off-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-12-16 10:42:34 +08:00
217 changed files with 23057 additions and 4856 deletions

View File

@ -76,7 +76,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] `jest.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n mock returns keys (not empty strings)
- [ ] i18n uses shared mock (auto-loaded); only override locally for custom translations
- [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs

View File

@ -318,3 +318,4 @@ For more detailed information, refer to:
- `web/jest.config.ts` - Jest configuration
- `web/jest.setup.ts` - Test environment setup
- `web/testing/analyze-component.js` - Component analysis tool
- `web/__mocks__/react-i18next.ts` - Shared i18n mock (auto-loaded by Jest, no explicit mock needed; override locally only for custom translations)

View File

@ -46,12 +46,22 @@ Only mock these categories:
## Essential Mocks
### 1. i18n (Always Required)
### 1. i18n (Auto-loaded via Shared Mock)
A shared mock is available at `web/__mocks__/react-i18next.ts` and is auto-loaded by Jest.
**No explicit mock needed** for most tests - it returns translation keys as-is.
For tests requiring custom translations, override the mock:
```typescript
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
t: (key: string) => {
const translations: Record<string, string> = {
'my.custom.key': 'Custom translation',
}
return translations[key] || key
},
}),
}))
```
@ -313,7 +323,7 @@ Need to use a component in test?
│ └─ YES → Mock it (next/navigation, external SDKs)
└─ Is it i18n?
└─ YES → Mock to return keys
└─ YES → Uses shared mock (auto-loaded). Override only for custom translations
```
## Factory Function Pattern

View File

@ -26,13 +26,20 @@ import userEvent from '@testing-library/user-event'
// WHY: Mocks must be hoisted to top of file (Jest requirement).
// They run BEFORE imports, so keep them before component imports.
// i18n (always required in Dify)
// WHY: Returns key instead of translation so tests don't depend on i18n files
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
// i18n (automatically mocked)
// WHY: Shared mock at web/__mocks__/react-i18next.ts is auto-loaded by Jest
// No explicit mock needed - it returns translation keys as-is
// Override only if custom translations are required:
// jest.mock('react-i18next', () => ({
// useTranslation: () => ({
// t: (key: string) => {
// const customTranslations: Record<string, string> = {
// 'my.custom.key': 'Custom Translation',
// }
// return customTranslations[key] || key
// },
// }),
// }))
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior

View File

@ -93,4 +93,12 @@ jobs:
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
{
echo ""
echo "<details><summary>File-level coverage (click to expand)</summary>"
echo ""
echo '```'
uv run --project api coverage report -m
echo '```'
echo "</details>"
} >> $GITHUB_STEP_SUMMARY

View File

@ -543,6 +543,25 @@ APP_MAX_EXECUTION_TIME=1200
APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0
# Aliyun SLS Logstore Configuration
# Aliyun Access Key ID
ALIYUN_SLS_ACCESS_KEY_ID=
# Aliyun Access Key Secret
ALIYUN_SLS_ACCESS_KEY_SECRET=
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
ALIYUN_SLS_ENDPOINT=
# Aliyun SLS Region (e.g., cn-hangzhou)
ALIYUN_SLS_REGION=
# Aliyun SLS Project Name
ALIYUN_SLS_PROJECT_NAME=
# Number of days to retain workflow run logs (default: 365 days; set to 3650 for permanent storage)
ALIYUN_SLS_LOGSTORE_TTL=365
# Enable dual-write to both SLS LogStore and SQL database (default: false)
LOGSTORE_DUAL_WRITE_ENABLED=false
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in SQL database
LOGSTORE_DUAL_READ_ENABLED=true
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
@ -626,17 +645,7 @@ QUEUE_MONITOR_ALERT_EMAILS=
QUEUE_MONITOR_INTERVAL=30
# Swagger UI configuration
# SECURITY: Swagger UI is automatically disabled in PRODUCTION environment (DEPLOY_ENV=PRODUCTION)
# to prevent API information disclosure.
#
# Behavior:
# - DEPLOY_ENV=PRODUCTION + SWAGGER_UI_ENABLED not set -> Swagger DISABLED (secure default)
# - DEPLOY_ENV=DEVELOPMENT/TESTING + SWAGGER_UI_ENABLED not set -> Swagger ENABLED
# - SWAGGER_UI_ENABLED=true -> Swagger ENABLED (overrides environment check)
# - SWAGGER_UI_ENABLED=false -> Swagger DISABLED (explicit disable)
#
# For development, you can uncomment below or set DEPLOY_ENV=DEVELOPMENT
# SWAGGER_UI_ENABLED=false
SWAGGER_UI_ENABLED=true
SWAGGER_UI_PATH=/swagger-ui.html
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
@ -680,4 +689,4 @@ ANNOTATION_IMPORT_MIN_RECORDS=1
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5
ANNOTATION_IMPORT_MAX_CONCURRENT=5

View File

@ -75,6 +75,7 @@ def initialize_extensions(app: DifyApp):
ext_import_modules,
ext_logging,
ext_login,
ext_logstore,
ext_mail,
ext_migrate,
ext_orjson,
@ -105,6 +106,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate,
ext_redis,
ext_storage,
ext_logstore, # Initialize logstore after storage, before celery
ext_celery,
ext_login,
ext_mail,

View File

@ -1252,19 +1252,9 @@ class WorkflowLogConfig(BaseSettings):
class SwaggerUIConfig(BaseSettings):
"""
Configuration for Swagger UI documentation.
Security Note: Swagger UI is automatically disabled in PRODUCTION environment
to prevent API information disclosure. Set SWAGGER_UI_ENABLED=true explicitly
to enable in production if needed.
"""
SWAGGER_UI_ENABLED: bool | None = Field(
description="Whether to enable Swagger UI in api module. "
"Automatically disabled in PRODUCTION environment for security. "
"Set to true explicitly to enable in production.",
default=None,
SWAGGER_UI_ENABLED: bool = Field(
description="Whether to enable Swagger UI in api module",
default=True,
)
SWAGGER_UI_PATH: str = Field(
@ -1272,23 +1262,6 @@ class SwaggerUIConfig(BaseSettings):
default="/swagger-ui.html",
)
@property
def swagger_ui_enabled(self) -> bool:
"""
Compute whether Swagger UI should be enabled.
If SWAGGER_UI_ENABLED is explicitly set, use that value.
Otherwise, disable in PRODUCTION environment for security.
"""
if self.SWAGGER_UI_ENABLED is not None:
return self.SWAGGER_UI_ENABLED
# Auto-disable in production environment
import os
deploy_env = os.environ.get("DEPLOY_ENV", "PRODUCTION")
return deploy_env.upper() != "PRODUCTION"
class TenantIsolatedTaskQueueConfig(BaseSettings):
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(

View File

@ -107,7 +107,7 @@ class KeywordStoreConfig(BaseSettings):
class DatabaseConfig(BaseSettings):
# Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
description="Database type to use. OceanBase is MySQL-compatible.",
default="postgresql",
)

View File

@ -22,7 +22,12 @@ from controllers.console.error import (
NotAllowedCreateWorkspace,
WorkspacesLimitExceeded,
)
from controllers.console.wraps import email_password_login_enabled, setup_required
from controllers.console.wraps import (
decrypt_code_field,
decrypt_password_field,
email_password_login_enabled,
setup_required,
)
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.login import current_account_with_tenant
@ -79,6 +84,7 @@ class LoginApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[LoginPayload.__name__])
@decrypt_password_field
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
@ -218,6 +224,7 @@ class EmailCodeLoginSendEmailApi(Resource):
class EmailCodeLoginApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
@decrypt_code_field
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)

View File

@ -140,6 +140,18 @@ class DataSourceNotionListApi(Resource):
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
datasource_parameters = {}
if datasource_parameters_str:
try:
datasource_parameters = json.loads(datasource_parameters_str)
if not isinstance(datasource_parameters, dict):
raise ValueError("datasource_parameters must be a JSON object.")
except json.JSONDecodeError:
raise ValueError("Invalid datasource_parameters JSON format.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
@ -187,7 +199,7 @@ class DataSourceNotionListApi(Resource):
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters={},
datasource_parameters=datasource_parameters,
provider_type=datasource_runtime.datasource_provider_type(),
)
)
@ -218,14 +230,14 @@ class DataSourceNotionListApi(Resource):
@console_ns.route(
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, workspace_id, page_id, page_type):
def get(self, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str)
@ -239,11 +251,10 @@ class DataSourceNotionApi(Resource):
plugin_id="langgenius/notion_datasource",
)
workspace_id = str(workspace_id)
page_id = str(page_id)
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_workspace_id="",
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
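The endpoint above expects `datasource_parameters` as a URL-encoded JSON object in the query string, rejecting anything that doesn't parse to a dict. A sketch of how a caller might build such a query string (the helper name and the `repo`/`branch` keys are illustrative, not part of the API):

```python
import json
from urllib.parse import parse_qs, urlencode


def build_query(credential_id: str, datasource_parameters: dict) -> str:
    """Serialize datasource_parameters as the JSON query param the endpoint expects."""
    if not isinstance(datasource_parameters, dict):
        # Mirrors the server-side check: the value must be a JSON object.
        raise ValueError("datasource_parameters must be a JSON object.")
    return urlencode({
        "credential_id": credential_id,
        "datasource_parameters": json.dumps(datasource_parameters),
    })


query = build_query("cred-123", {"repo": "octocat/hello-world", "branch": "main"})
# Round trip: the server can json.loads the raw param back into a dict.
raw = parse_qs(query)["datasource_parameters"][0]
assert json.loads(raw) == {"repo": "octocat/hello-world", "branch": "main"}
```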

View File

@ -223,6 +223,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.COUCHBASE,
VectorType.OPENGAUSS,
VectorType.OCEANBASE,
VectorType.SEEKDB,
VectorType.TABLESTORE,
VectorType.HUAWEI_CLOUD,
VectorType.TENCENT,

View File

@ -4,7 +4,7 @@ from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource, marshal_with # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -975,6 +975,11 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, location="args", required=False, default="all")
args = parser.parse_args()
type = args["type"]
rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
return recommended_plugins

View File

@ -9,10 +9,12 @@ from typing import ParamSpec, TypeVar
from flask import abort, request
from configs import dify_config
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
from controllers.console.workspace.error import AccountNotInitializedError
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.encryption import FieldEncryption
from libs.login import current_account_with_tenant
from models.account import AccountStatus
from models.dataset import RateLimitLog
@ -25,6 +27,14 @@ from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogo
P = ParamSpec("P")
R = TypeVar("R")
# Field names for decryption
FIELD_NAME_PASSWORD = "password"
FIELD_NAME_CODE = "code"
# Error messages for decryption failures
ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
@ -419,3 +429,75 @@ def annotation_import_concurrency_limit(view: Callable[P, R]):
return view(*args, **kwargs)
return decorated
def _decrypt_field(field_name: str, error_class: type[Exception], error_message: str) -> None:
"""
Helper to decode a Base64 encoded field in the request payload.
Args:
field_name: Name of the field to decode
error_class: Exception class to raise on decoding failure
error_message: Error message to include in the exception
"""
if not request or not request.is_json:
return
# Get the payload dict - it's cached and mutable
payload = request.get_json()
if not payload or field_name not in payload:
return
encrypted_value = payload[field_name]
decrypted_value = FieldEncryption.decrypt_field(encrypted_value)
# If decryption failed, raise the error immediately
if decrypted_value is None:
raise error_class(error_message)
# Update the payload dict in-place with the decrypted value.
# Since payload is a mutable dict and get_json() returns the cached reference,
# modifying it affects all subsequent accesses, including console_ns.payload
payload[field_name] = decrypted_value
def decrypt_password_field(view: Callable[P, R]):
"""
Decorator to decrypt password field in request payload.
Automatically decrypts the 'password' field if encryption is enabled.
If decryption fails, raises AuthenticationFailedError.
Usage:
@decrypt_password_field
def post(self):
args = LoginPayload.model_validate(console_ns.payload)
# args.password is now decrypted
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_decrypt_field(FIELD_NAME_PASSWORD, AuthenticationFailedError, ERROR_MSG_INVALID_ENCRYPTED_DATA)
return view(*args, **kwargs)
return decorated
def decrypt_code_field(view: Callable[P, R]):
"""
Decorator to decrypt verification code field in request payload.
Automatically decrypts the 'code' field if encryption is enabled.
If decryption fails, raises EmailCodeError.
Usage:
@decrypt_code_field
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
# args.code is now decrypted
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_decrypt_field(FIELD_NAME_CODE, EmailCodeError, ERROR_MSG_INVALID_ENCRYPTED_CODE)
return view(*args, **kwargs)
return decorated

View File

@ -163,7 +163,7 @@ class Vector:
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
return LindormVectorStoreFactory
case VectorType.OCEANBASE:
case VectorType.OCEANBASE | VectorType.SEEKDB:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
return OceanBaseVectorFactory

View File

@ -27,6 +27,7 @@ class VectorType(StrEnum):
UPSTASH = "upstash"
TIDB_ON_QDRANT = "tidb_on_qdrant"
OCEANBASE = "oceanbase"
SEEKDB = "seekdb"
OPENGAUSS = "opengauss"
TABLESTORE = "tablestore"
HUAWEI_CLOUD = "huawei_cloud"

View File

@ -10,7 +10,7 @@ class NotionInfo(BaseModel):
"""
credential_id: str | None = None
notion_workspace_id: str
notion_workspace_id: str | None = ""
notion_obj_id: str
notion_page_type: str
document: Document | None = None

View File

@ -166,7 +166,7 @@ class ExtractProcessor:
elif extract_setting.datasource_type == DatasourceType.NOTION:
assert extract_setting.notion_info is not None, "notion_info is required"
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_workspace_id=extract_setting.notion_info.notion_workspace_id or "",
notion_obj_id=extract_setting.notion_info.notion_obj_id,
notion_page_type=extract_setting.notion_info.notion_page_type,
document_model=extract_setting.notion_info.document,

View File

@ -45,6 +45,6 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
except concurrent.futures.TimeoutError:
raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}")
if all(encoding["encoding"] is None for encoding in encodings):
if all(encoding.encoding is None for encoding in encodings):
raise RuntimeError(f"Could not detect encoding for {file_path}")
return [FileEncoding(**enc) for enc in encodings if enc["encoding"] is not None]
return [enc for enc in encodings if enc.encoding is not None]

View File

@ -140,6 +140,10 @@ class GraphEngine:
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
# === Worker Pool Setup ===
# Capture Flask app context for worker threads
flask_app: Flask | None = None
@ -158,6 +162,7 @@ class GraphEngine:
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
flask_app=flask_app,
context_vars=context_vars,
min_workers=self._min_workers,
@ -196,10 +201,6 @@ class GraphEngine:
event_emitter=self._event_manager,
)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
# === Validation ===
# Ensure all nodes share the same GraphRuntimeState instance
self._validate_graph_state_consistency()

View File

@ -8,9 +8,11 @@ with middleware-like components that can observe events and interact with execut
from .base import GraphEngineLayer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer
from .observability import ObservabilityLayer
__all__ = [
"DebugLoggingLayer",
"ExecutionLimitsLayer",
"GraphEngineLayer",
"ObservabilityLayer",
]

View File

@ -9,6 +9,7 @@ from abc import ABC, abstractmethod
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import ReadOnlyGraphRuntimeState
@ -83,3 +84,29 @@ class GraphEngineLayer(ABC):
error: The exception that caused execution to fail, or None if successful
"""
pass
def on_node_run_start(self, node: Node) -> None: # noqa: B027
"""
Called immediately before a node begins execution.
Layers can override to inject behavior (e.g., start spans) prior to node execution.
The node's execution ID is available via `node._node_execution_id` and will be
consistent with all events emitted by this node execution.
Args:
node: The node instance about to be executed
"""
pass
def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027
"""
Called after a node finishes execution.
The node's execution ID is available via `node._node_execution_id` and matches
the `id` field in all events emitted by this node execution.
Args:
node: The node instance that just finished execution
error: Exception instance if the node failed, otherwise None
"""
pass

View File

@ -0,0 +1,61 @@
"""
Node-level OpenTelemetry parser interfaces and defaults.
"""
import json
from typing import Protocol
from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.tool.entities import ToolNodeData
class NodeOTelParser(Protocol):
"""Parser interface for node-specific OpenTelemetry enrichment."""
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ...
class DefaultNodeOTelParser:
"""Fallback parser used when no node-specific parser is registered."""
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
span.set_attribute("node.id", node.id)
if node.execution_id:
span.set_attribute("node.execution_id", node.execution_id)
if hasattr(node, "node_type") and node.node_type:
span.set_attribute("node.type", node.node_type.value)
if error:
span.record_exception(error)
span.set_status(Status(StatusCode.ERROR, str(error)))
else:
span.set_status(Status(StatusCode.OK))
class ToolNodeOTelParser:
"""Parser for tool nodes that captures tool-specific metadata."""
def __init__(self) -> None:
self._delegate = DefaultNodeOTelParser()
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
self._delegate.parse(node=node, span=span, error=error)
tool_data = getattr(node, "_node_data", None)
if not isinstance(tool_data, ToolNodeData):
return
span.set_attribute("tool.provider.id", tool_data.provider_id)
span.set_attribute("tool.provider.type", tool_data.provider_type.value)
span.set_attribute("tool.provider.name", tool_data.provider_name)
span.set_attribute("tool.name", tool_data.tool_name)
span.set_attribute("tool.label", tool_data.tool_label)
if tool_data.plugin_unique_identifier:
span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier)
if tool_data.credential_id:
span.set_attribute("tool.credential.id", tool_data.credential_id)
if tool_data.tool_configurations:
span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False))

View File

@ -0,0 +1,169 @@
"""
Observability layer for GraphEngine.
This layer creates OpenTelemetry spans for node execution, enabling distributed
tracing of workflow execution. It establishes OTel context during node execution
so that automatic instrumentation (HTTP requests, DB queries, etc.) automatically
associates with the node span.
"""
import logging
from dataclasses import dataclass
from typing import cast, final
from opentelemetry import context as context_api
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
from typing_extensions import override
from configs import dify_config
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.node_parsers import (
DefaultNodeOTelParser,
NodeOTelParser,
ToolNodeOTelParser,
)
from core.workflow.nodes.base.node import Node
from extensions.otel.runtime import is_instrument_flag_enabled
logger = logging.getLogger(__name__)
@dataclass(slots=True)
class _NodeSpanContext:
span: "Span"
token: object
@final
class ObservabilityLayer(GraphEngineLayer):
"""
Layer that creates OpenTelemetry spans for node execution.
This layer:
- Creates a span when a node starts execution
- Establishes OTel context so automatic instrumentation associates with the span
- Sets complete attributes and status when node execution ends
"""
def __init__(self) -> None:
super().__init__()
self._node_contexts: dict[str, _NodeSpanContext] = {}
self._parsers: dict[NodeType, NodeOTelParser] = {}
self._default_parser: NodeOTelParser = cast(NodeOTelParser, DefaultNodeOTelParser())
self._is_disabled: bool = False
self._tracer: Tracer | None = None
self._build_parser_registry()
self._init_tracer()
def _init_tracer(self) -> None:
"""Initialize OpenTelemetry tracer in constructor."""
if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
self._is_disabled = True
return
try:
self._tracer = get_tracer(__name__)
except Exception as e:
logger.warning("Failed to get OpenTelemetry tracer: %s", e)
self._is_disabled = True
def _build_parser_registry(self) -> None:
"""Initialize parser registry for node types."""
self._parsers = {
NodeType.TOOL: ToolNodeOTelParser(),
}
def _get_parser(self, node: Node) -> NodeOTelParser:
node_type = getattr(node, "node_type", None)
if isinstance(node_type, NodeType):
return self._parsers.get(node_type, self._default_parser)
return self._default_parser
@override
def on_graph_start(self) -> None:
"""Called when graph execution starts."""
self._node_contexts.clear()
@override
def on_node_run_start(self, node: Node) -> None:
"""
Called when a node starts execution.
Creates a span and establishes OTel context for automatic instrumentation.
"""
if self._is_disabled:
return
try:
if not self._tracer:
return
execution_id = node.execution_id
if not execution_id:
return
parent_context = context_api.get_current()
span = self._tracer.start_span(
f"{node.title}",
kind=SpanKind.INTERNAL,
context=parent_context,
)
new_context = set_span_in_context(span)
token = context_api.attach(new_context)
self._node_contexts[execution_id] = _NodeSpanContext(span=span, token=token)
except Exception as e:
logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e)
@override
def on_node_run_end(self, node: Node, error: Exception | None) -> None:
"""
Called when a node finishes execution.
Sets complete attributes, records exceptions, and ends the span.
"""
if self._is_disabled:
return
try:
execution_id = node.execution_id
if not execution_id:
return
node_context = self._node_contexts.get(execution_id)
if not node_context:
return
span = node_context.span
parser = self._get_parser(node)
try:
parser.parse(node=node, span=span, error=error)
span.end()
finally:
token = node_context.token
if token is not None:
try:
context_api.detach(token)
except Exception:
logger.warning("Failed to detach OpenTelemetry token: %s", token)
self._node_contexts.pop(execution_id, None)
except Exception as e:
logger.warning("Failed to end OpenTelemetry span for node %s: %s", node.id, e)
@override
def on_event(self, event) -> None:
"""Not used in this layer."""
pass
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Called when graph execution ends."""
if self._node_contexts:
logger.warning(
"ObservabilityLayer: %d node spans were not properly ended",
len(self._node_contexts),
)
self._node_contexts.clear()

View File

@ -9,6 +9,7 @@ import contextvars
import queue
import threading
import time
from collections.abc import Sequence
from datetime import datetime
from typing import final
from uuid import uuid4
@ -17,6 +18,7 @@ from flask import Flask
from typing_extensions import override
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts
@ -39,6 +41,7 @@ class Worker(threading.Thread):
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
@ -50,6 +53,7 @@ class Worker(threading.Thread):
ready_queue: Ready queue containing node IDs ready for execution
event_queue: Queue for pushing execution events
graph: Graph containing nodes to execute
layers: Graph engine layers for node execution hooks
worker_id: Unique identifier for this worker
flask_app: Optional Flask application for context preservation
context_vars: Optional context variables to preserve in worker thread
@ -63,6 +67,7 @@ class Worker(threading.Thread):
self._context_vars = context_vars
self._stop_event = threading.Event()
self._last_task_time = time.time()
self._layers = layers if layers is not None else []
def stop(self) -> None:
"""Signal the worker to stop processing."""
@ -122,20 +127,51 @@ class Worker(threading.Thread):
Args:
node: The node instance to execute
"""
# Execute the node with preserved context if Flask app is provided
node.ensure_execution_id()
error: Exception | None = None
if self._flask_app and self._context_vars:
with preserve_flask_contexts(
flask_app=self._flask_app,
context_vars=self._context_vars,
):
# Execute the node
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
self._event_queue.put(event)
except Exception as exc:
error = exc
raise
finally:
self._invoke_node_run_end_hooks(node, error)
else:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
# Forward event to dispatcher immediately for streaming
self._event_queue.put(event)
except Exception as exc:
error = exc
raise
finally:
self._invoke_node_run_end_hooks(node, error)
def _invoke_node_run_start_hooks(self, node: Node) -> None:
"""Invoke on_node_run_start hooks for all layers."""
for layer in self._layers:
try:
layer.on_node_run_start(node)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue
def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None:
"""Invoke on_node_run_end hooks for all layers."""
for layer in self._layers:
try:
layer.on_node_run_end(node, error)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue
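The hook invocations above deliberately swallow layer exceptions so a misbehaving layer cannot break node execution or starve later layers. A small self-contained sketch of that isolation (toy layer classes, not the real GraphEngineLayer):

```python
class NoisyLayer:
    """A layer whose hooks always raise, to show errors are contained."""
    def on_node_run_start(self, node):
        raise RuntimeError("boom")

class RecordingLayer:
    """A well-behaved layer that records the hooks it receives."""
    def __init__(self):
        self.calls = []
    def on_node_run_start(self, node):
        self.calls.append(("start", node))

def invoke_start_hooks(layers, node):
    # Mirrors Worker._invoke_node_run_start_hooks: a failing layer must not
    # prevent later layers from running or the node from executing.
    for layer in layers:
        try:
            layer.on_node_run_start(node)
        except Exception:
            continue

recorder = RecordingLayer()
invoke_start_hooks([NoisyLayer(), recorder, NoisyLayer()], "node-1")
print(recorder.calls)  # [('start', 'node-1')]
```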

View File

@ -14,6 +14,7 @@ from configs import dify_config
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
from ..layers.base import GraphEngineLayer
from ..ready_queue import ReadyQueue
from ..worker import Worker
@ -39,6 +40,7 @@ class WorkerPool:
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
@ -53,6 +55,7 @@ class WorkerPool:
ready_queue: Ready queue for nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
layers: Graph engine layers for node execution hooks
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
min_workers: Minimum number of workers
@ -65,6 +68,7 @@ class WorkerPool:
self._graph = graph
self._flask_app = flask_app
self._context_vars = context_vars
self._layers = layers
# Scaling parameters with defaults
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
@ -144,6 +148,7 @@ class WorkerPool:
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,

View File

@ -244,6 +244,15 @@ class Node(Generic[NodeDataT]):
def graph_init_params(self) -> "GraphInitParams":
return self._graph_init_params
@property
def execution_id(self) -> str:
return self._node_execution_id
def ensure_execution_id(self) -> str:
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
return self._node_execution_id
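`ensure_execution_id` makes ID creation lazy and idempotent, so the worker (which calls it before `run()`) and `run()` itself always agree on a single execution ID. A stripped-down sketch of that contract (illustrative `NodeSketch`, not the real `Node` class):

```python
from uuid import uuid4

class NodeSketch:
    """Sketch of the lazy, idempotent execution-ID pattern above."""

    def __init__(self) -> None:
        self._node_execution_id = ""

    def ensure_execution_id(self) -> str:
        # Generate once; every later call returns the same ID, so all
        # events emitted for this node execution share one identifier.
        if not self._node_execution_id:
            self._node_execution_id = str(uuid4())
        return self._node_execution_id

n = NodeSketch()
first = n.ensure_execution_id()
assert n.ensure_execution_id() == first  # stable across calls
```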
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@ -256,14 +265,12 @@ class Node(Generic[NodeDataT]):
raise NotImplementedError
def run(self) -> Generator[GraphNodeEventBase, None, None]:
# Generate a single node execution ID to use for all events
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
id=self._node_execution_id,
id=execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.title,
@ -321,7 +328,7 @@ class Node(Generic[NodeDataT]):
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
event.id = self._node_execution_id
event.id = self.execution_id
yield event
else:
yield event
@ -333,7 +340,7 @@ class Node(Generic[NodeDataT]):
error_type="WorkflowNodeError",
)
yield NodeRunFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@ -512,7 +519,7 @@ class Node(Generic[NodeDataT]):
match result.status:
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
@ -521,7 +528,7 @@ class Node(Generic[NodeDataT]):
)
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
start_at=self._start_at,
@ -537,7 +544,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
@ -550,7 +557,7 @@ class Node(Generic[NodeDataT]):
match event.node_run_result.status:
case WorkflowNodeExecutionStatus.SUCCEEDED:
return NodeRunSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@ -558,7 +565,7 @@ class Node(Generic[NodeDataT]):
)
case WorkflowNodeExecutionStatus.FAILED:
return NodeRunFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
@ -573,7 +580,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
return NodeRunPauseRequestedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
@ -583,7 +590,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
return NodeRunAgentLogEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
message_id=event.message_id,
@ -599,7 +606,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
return NodeRunLoopStartedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -612,7 +619,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopNextEvent) -> NodeRunLoopNextEvent:
return NodeRunLoopNextEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -623,7 +630,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent:
return NodeRunLoopSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -637,7 +644,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent:
return NodeRunLoopFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -652,7 +659,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent:
return NodeRunIterationStartedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -665,7 +672,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent:
return NodeRunIterationNextEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -676,7 +683,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent:
return NodeRunIterationSucceededEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -690,7 +697,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent:
return NodeRunIterationFailedEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=self.node_data.title,
@ -705,7 +712,7 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent:
return NodeRunRetrieverResourceEvent(
id=self._node_execution_id,
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
retriever_resources=event.retriever_resources,

View File

@ -14,7 +14,7 @@ from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer, ObservabilityLayer
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType
@ -23,6 +23,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from extensions.otel.runtime import is_instrument_flag_enabled
from factories import file_factory
from models.enums import UserFrom
from models.workflow import Workflow
@ -98,6 +99,10 @@ class WorkflowEntry:
)
self.graph_engine.layer(limits_layer)
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
self.graph_engine.layer(ObservabilityLayer())
def run(self) -> Generator[GraphEngineEvent, None, None]:
graph_engine = self.graph_engine

View File

@ -22,8 +22,8 @@ login_manager = flask_login.LoginManager()
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
# Skip authentication for documentation endpoints
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
return None
auth_token = extract_access_token(request)

View File

@ -0,0 +1,74 @@
"""
Logstore extension for Dify application.
This extension initializes the logstore (Aliyun SLS) on application startup,
creating necessary projects, logstores, and indexes if they don't exist.
"""
import logging
import os
from dotenv import load_dotenv
from dify_app import DifyApp
logger = logging.getLogger(__name__)
def is_enabled() -> bool:
"""
Check if logstore extension is enabled.
Returns:
True if all required Aliyun SLS environment variables are set, False otherwise
"""
# Load environment variables from .env file
load_dotenv()
required_vars = [
"ALIYUN_SLS_ACCESS_KEY_ID",
"ALIYUN_SLS_ACCESS_KEY_SECRET",
"ALIYUN_SLS_ENDPOINT",
"ALIYUN_SLS_REGION",
"ALIYUN_SLS_PROJECT_NAME",
]
all_set = all(os.environ.get(var) for var in required_vars)
if not all_set:
logger.info("Logstore extension disabled: required Aliyun SLS environment variables not set")
return all_set
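The gating above enables the extension only when every required variable is present and non-empty (an empty string counts as unset, because `all(os.environ.get(...))` is falsy for `""`). The same check can be exercised against a plain dict:

```python
REQUIRED_VARS = [
    "ALIYUN_SLS_ACCESS_KEY_ID",
    "ALIYUN_SLS_ACCESS_KEY_SECRET",
    "ALIYUN_SLS_ENDPOINT",
    "ALIYUN_SLS_REGION",
    "ALIYUN_SLS_PROJECT_NAME",
]

def extension_enabled(env: dict) -> bool:
    # Empty strings count as unset, matching all(os.environ.get(var) ...).
    return all(env.get(var) for var in REQUIRED_VARS)

print(extension_enabled({v: "x" for v in REQUIRED_VARS}))  # True
print(extension_enabled({"ALIYUN_SLS_ENDPOINT": ""}))      # False
```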
def init_app(app: DifyApp):
"""
Initialize logstore on application startup.
This function:
1. Creates Aliyun SLS project if it doesn't exist
2. Creates logstores (workflow_execution, workflow_node_execution) if they don't exist
3. Creates indexes with field configurations based on PostgreSQL table structures
This operation is idempotent and only executes once during application startup.
Args:
app: The Dify application instance
"""
try:
from extensions.logstore.aliyun_logstore import AliyunLogStore
logger.info("Initializing logstore...")
# Create logstore client and initialize project/logstores/indexes
logstore_client = AliyunLogStore()
logstore_client.init_project_logstore()
# Attach to app for potential later use
app.extensions["logstore"] = logstore_client
logger.info("Logstore initialized successfully")
except Exception:
logger.exception("Failed to initialize logstore")
# Don't raise - allow application to continue even if logstore init fails
# This ensures that the application can still run if logstore is misconfigured
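This "fail open" startup pattern (log the failure, never raise) keeps an optional integration from blocking the app. A minimal sketch of the same shape, with a hypothetical factory standing in for `AliyunLogStore()`:

```python
import logging

logger = logging.getLogger(__name__)

def init_optional_extension(app_extensions: dict, factory) -> None:
    """Sketch of the fail-open init pattern: a broken optional
    integration logs the failure but never blocks app startup."""
    try:
        app_extensions["logstore"] = factory()
    except Exception:
        logger.exception("Failed to initialize logstore")
        # Swallow the error so the app still boots without logstore.

def broken_factory():
    # Hypothetical failure mode; the real client might raise on bad credentials.
    raise ConnectionError("SLS endpoint unreachable")

extensions: dict = {}
init_optional_extension(extensions, broken_factory)
print("logstore" in extensions)  # False, but no exception escaped
```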

View File

@ -0,0 +1,890 @@
import logging
import os
import threading
import time
from collections.abc import Sequence
from typing import Any
import sqlalchemy as sa
from aliyun.log import ( # type: ignore[import-untyped]
GetLogsRequest,
IndexConfig,
IndexKeyConfig,
IndexLineConfig,
LogClient,
LogItem,
PutLogsRequest,
)
from aliyun.log.auth import AUTH_VERSION_4 # type: ignore[import-untyped]
from aliyun.log.logexception import LogException # type: ignore[import-untyped]
from dotenv import load_dotenv
from sqlalchemy.orm import DeclarativeBase
from configs import dify_config
from extensions.logstore.aliyun_logstore_pg import AliyunLogStorePG
logger = logging.getLogger(__name__)
class AliyunLogStore:
"""
Singleton class for Aliyun SLS LogStore operations.
Ensures only one instance exists to prevent multiple PG connection pools.
"""
_instance: "AliyunLogStore | None" = None
_initialized: bool = False
# Track delayed PG connection for newly created projects
_pg_connection_timer: threading.Timer | None = None
_pg_connection_delay: int = 90  # delay in seconds
# Default tokenizer for text/json fields and full-text index
# Common delimiters: comma, space, quotes, punctuation, operators, brackets, special chars
DEFAULT_TOKEN_LIST = [
",",
" ",
'"',
'"',
";",
"=",
"(",
")",
"[",
"]",
"{",
"}",
"?",
"@",
"&",
"<",
">",
"/",
":",
"\n",
"\t",
]
def __new__(cls) -> "AliyunLogStore":
"""Implement singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
project_des = "dify"
workflow_execution_logstore = "workflow_execution"
workflow_node_execution_logstore = "workflow_node_execution"
@staticmethod
def _sqlalchemy_type_to_logstore_type(column: Any) -> str:
"""
Map SQLAlchemy column type to Aliyun LogStore index type.
Args:
column: SQLAlchemy column object
Returns:
LogStore index type: 'text', 'long', 'double', or 'json'
"""
column_type = column.type
# Integer types -> long
if isinstance(column_type, (sa.Integer, sa.BigInteger, sa.SmallInteger)):
return "long"
# Float types -> double
if isinstance(column_type, (sa.Float, sa.Numeric)):
return "double"
# String and Text types -> text
if isinstance(column_type, (sa.String, sa.Text)):
return "text"
# DateTime -> text (stored as ISO format string in logstore)
if isinstance(column_type, sa.DateTime):
return "text"
# Boolean -> long (stored as 0/1)
if isinstance(column_type, sa.Boolean):
return "long"
# JSON -> json
if isinstance(column_type, sa.JSON):
return "json"
# Default to text for unknown types
return "text"
@staticmethod
def _generate_index_keys_from_model(model_class: type[DeclarativeBase]) -> dict[str, IndexKeyConfig]:
"""
Automatically generate LogStore field index configuration from SQLAlchemy model.
This method introspects the SQLAlchemy model's column definitions and creates
corresponding LogStore index configurations. When the PG schema is updated via
Flask-Migrate, this method will automatically pick up the new fields on next startup.
Args:
model_class: SQLAlchemy model class (e.g., WorkflowRun, WorkflowNodeExecutionModel)
Returns:
Dictionary mapping field names to IndexKeyConfig objects
"""
index_keys = {}
# Iterate over all mapped columns in the model
if hasattr(model_class, "__mapper__"):
for column_name, column_property in model_class.__mapper__.columns.items():
# Skip relationship properties and other non-column attributes
if not hasattr(column_property, "type"):
continue
# Map SQLAlchemy type to LogStore type
logstore_type = AliyunLogStore._sqlalchemy_type_to_logstore_type(column_property)
# Create index configuration
# - text fields: case_insensitive for better search, with tokenizer and Chinese support
# - all fields: doc_value=True for analytics
if logstore_type == "text":
index_keys[column_name] = IndexKeyConfig(
index_type="text",
case_sensitive=False,
doc_value=True,
token_list=AliyunLogStore.DEFAULT_TOKEN_LIST,
chinese=True,
)
else:
index_keys[column_name] = IndexKeyConfig(index_type=logstore_type, doc_value=True)
# Add log_version field (not in PG model, but used in logstore for versioning)
index_keys["log_version"] = IndexKeyConfig(index_type="long", doc_value=True)
return index_keys
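The column-type dispatch behind this generation can be summarized as a lookup table. A simplified sketch keyed by type *name* (the real `_sqlalchemy_type_to_logstore_type` dispatches on SQLAlchemy type classes such as `sa.Integer` and `sa.JSON`, not strings):

```python
# Hypothetical simplified mapping; mirrors the rules in
# _sqlalchemy_type_to_logstore_type above.
_TYPE_MAP = {
    "Integer": "long", "BigInteger": "long", "SmallInteger": "long",
    "Boolean": "long",          # stored as 0/1
    "Float": "double", "Numeric": "double",
    "String": "text", "Text": "text",
    "DateTime": "text",         # ISO-format string in the logstore
    "JSON": "json",
}

def logstore_index_type(sa_type_name: str) -> str:
    # Unknown types fall back to "text", as in the method above.
    return _TYPE_MAP.get(sa_type_name, "text")

print(logstore_index_type("BigInteger"))  # long
print(logstore_index_type("Geometry"))    # text (fallback)
```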
def __init__(self) -> None:
# Skip initialization if already initialized (singleton pattern)
if self.__class__._initialized:
return
load_dotenv()
self.access_key_id: str = os.environ.get("ALIYUN_SLS_ACCESS_KEY_ID", "")
self.access_key_secret: str = os.environ.get("ALIYUN_SLS_ACCESS_KEY_SECRET", "")
self.endpoint: str = os.environ.get("ALIYUN_SLS_ENDPOINT", "")
self.region: str = os.environ.get("ALIYUN_SLS_REGION", "")
self.project_name: str = os.environ.get("ALIYUN_SLS_PROJECT_NAME", "")
self.logstore_ttl: int = int(os.environ.get("ALIYUN_SLS_LOGSTORE_TTL", 365))
self.log_enabled: bool = os.environ.get("SQLALCHEMY_ECHO", "false").lower() == "true"
self.pg_mode_enabled: bool = os.environ.get("LOGSTORE_PG_MODE_ENABLED", "true").lower() == "true"
# Initialize SDK client
self.client = LogClient(
self.endpoint, self.access_key_id, self.access_key_secret, auth_version=AUTH_VERSION_4, region=self.region
)
# Append Dify identification to the existing user agent
original_user_agent = self.client._user_agent # pyright: ignore[reportPrivateUsage]
dify_version = dify_config.project.version
enhanced_user_agent = f"Dify,Dify-{dify_version},{original_user_agent}"
self.client.set_user_agent(enhanced_user_agent)
# PG client will be initialized in init_project_logstore
self._pg_client: AliyunLogStorePG | None = None
self._use_pg_protocol: bool = False
self.__class__._initialized = True
@property
def supports_pg_protocol(self) -> bool:
"""Check if PG protocol is supported and enabled."""
return self._use_pg_protocol
def _attempt_pg_connection_init(self) -> bool:
"""
Attempt to initialize PG connection.
This method tries to establish PG connection and performs necessary checks.
It's used both for immediate connection (existing projects) and delayed connection (new projects).
Returns:
True if PG connection was successfully established, False otherwise.
"""
if not self.pg_mode_enabled or not self._pg_client:
return False
try:
self._use_pg_protocol = self._pg_client.init_connection()
if self._use_pg_protocol:
logger.info("Successfully connected to project %s using PG protocol", self.project_name)
# Check if scan_index is enabled for all logstores
self._check_and_disable_pg_if_scan_index_disabled()
return True
else:
logger.info("PG connection failed for project %s. Will use SDK mode.", self.project_name)
return False
except Exception as e:
logger.warning(
"Failed to establish PG connection for project %s: %s. Will use SDK mode.",
self.project_name,
str(e),
)
self._use_pg_protocol = False
return False
def _delayed_pg_connection_init(self) -> None:
"""
Delayed initialization of PG connection for newly created projects.
This method is called by a background timer (_pg_connection_delay, 90 seconds) after project creation.
"""
# Double check conditions in case state changed
if self._use_pg_protocol:
return
logger.info(
"Attempting delayed PG connection for newly created project %s ...",
self.project_name,
)
self._attempt_pg_connection_init()
self.__class__._pg_connection_timer = None
def init_project_logstore(self):
"""
Initialize project, logstore, index, and PG connection.
This method should be called once during application startup to ensure
all required resources exist and connections are established.
"""
# Step 1: Ensure project and logstore exist
project_is_new = False
if not self.is_project_exist():
self.create_project()
project_is_new = True
self.create_logstore_if_not_exist()
# Step 2: Initialize PG client and connection (if enabled)
if not self.pg_mode_enabled:
logger.info("PG mode is disabled. Will use SDK mode.")
return
# Create PG client if not already created
if self._pg_client is None:
logger.info("Initializing PG client for project %s...", self.project_name)
self._pg_client = AliyunLogStorePG(
self.access_key_id, self.access_key_secret, self.endpoint, self.project_name
)
# Step 3: Establish PG connection based on project status
if project_is_new:
# For newly created projects, schedule delayed PG connection
self._use_pg_protocol = False
logger.info(
"Project %s is newly created. Will use SDK mode and schedule PG connection attempt in %d seconds.",
self.project_name,
self.__class__._pg_connection_delay,
)
if self.__class__._pg_connection_timer is not None:
self.__class__._pg_connection_timer.cancel()
self.__class__._pg_connection_timer = threading.Timer(
self.__class__._pg_connection_delay,
self._delayed_pg_connection_init,
)
self.__class__._pg_connection_timer.daemon = True # Don't block app shutdown
self.__class__._pg_connection_timer.start()
else:
# For existing projects, attempt PG connection immediately
logger.info("Project %s already exists. Attempting PG connection...", self.project_name)
self._attempt_pg_connection_init()
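The delayed-connection branch above boils down to scheduling one retry on a daemon `threading.Timer` so a freshly created project has time to provision before the PG endpoint is probed. A runnable sketch of that pattern (with a short delay for illustration; the real code waits 90 seconds and never sleeps):

```python
import threading
import time

def schedule_delayed_retry(attempt, delay_seconds: float) -> threading.Timer:
    """Sketch of the delayed PG-connection pattern: run `attempt` once
    on a daemon Timer after `delay_seconds`."""
    timer = threading.Timer(delay_seconds, attempt)
    timer.daemon = True  # never blocks interpreter shutdown
    timer.start()
    return timer

results = []
schedule_delayed_retry(lambda: results.append("connected"), delay_seconds=0.05)
time.sleep(0.2)  # only here so the demo can observe the timer firing
print(results)  # ['connected']
```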
def _check_and_disable_pg_if_scan_index_disabled(self) -> None:
"""
Check if scan_index is enabled for all logstores.
If any logstore has scan_index=false, disable PG protocol.
This is necessary because PG protocol requires scan_index to be enabled.
"""
logstore_name_list = [
AliyunLogStore.workflow_execution_logstore,
AliyunLogStore.workflow_node_execution_logstore,
]
for logstore_name in logstore_name_list:
existing_config = self.get_existing_index_config(logstore_name)
if existing_config and not existing_config.scan_index:
logger.info(
"Logstore %s has scan_index=false, USE SDK mode for read/write operations. "
"PG protocol requires scan_index to be enabled.",
logstore_name,
)
self._use_pg_protocol = False
# Close PG connection if it was initialized
if self._pg_client:
self._pg_client.close()
self._pg_client = None
return
def is_project_exist(self) -> bool:
try:
self.client.get_project(self.project_name)
return True
except Exception as e:
if e.args[0] == "ProjectNotExist":
return False
else:
raise e
def create_project(self):
try:
self.client.create_project(self.project_name, AliyunLogStore.project_des)
logger.info("Project %s created successfully", self.project_name)
except LogException as e:
logger.exception(
"Failed to create project %s: errorCode=%s, errorMessage=%s, requestId=%s",
self.project_name,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
def is_logstore_exist(self, logstore_name: str) -> bool:
try:
_ = self.client.get_logstore(self.project_name, logstore_name)
return True
except Exception as e:
if e.args[0] == "LogStoreNotExist":
return False
else:
raise e
def create_logstore_if_not_exist(self) -> None:
logstore_name_list = [
AliyunLogStore.workflow_execution_logstore,
AliyunLogStore.workflow_node_execution_logstore,
]
for logstore_name in logstore_name_list:
if not self.is_logstore_exist(logstore_name):
try:
self.client.create_logstore(
project_name=self.project_name, logstore_name=logstore_name, ttl=self.logstore_ttl
)
logger.info("logstore %s created successfully", logstore_name)
except LogException as e:
logger.exception(
"Failed to create logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
logstore_name,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
# Ensure index contains all Dify-required fields
# This intelligently merges with existing config, preserving custom indexes
self.ensure_index_config(logstore_name)
def is_index_exist(self, logstore_name: str) -> bool:
try:
_ = self.client.get_index_config(self.project_name, logstore_name)
return True
except Exception as e:
if e.args[0] == "IndexConfigNotExist":
return False
else:
raise e
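`is_project_exist`, `is_logstore_exist`, and `is_index_exist` all share one probe-by-exception shape: call the SDK, translate the specific "not found" error code into `False`, and let every other failure propagate. A generic sketch of that shape (the fake probe below stands in for an SDK call):

```python
def resource_exists(probe, missing_code: str) -> bool:
    # Probe the resource; translate the "not found" error code into False
    # and let every other failure propagate, as the methods above do.
    try:
        probe()
        return True
    except Exception as e:
        if e.args and e.args[0] == missing_code:
            return False
        raise

def probe_missing_project():
    raise Exception("ProjectNotExist")  # stand-in for client.get_project

print(resource_exists(probe_missing_project, "ProjectNotExist"))  # False
print(resource_exists(lambda: None, "ProjectNotExist"))           # True
```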
def get_existing_index_config(self, logstore_name: str) -> IndexConfig | None:
"""
Get existing index configuration from logstore.
Args:
logstore_name: Name of the logstore
Returns:
IndexConfig object if index exists, None otherwise
"""
try:
response = self.client.get_index_config(self.project_name, logstore_name)
return response.get_index_config()
except Exception as e:
if e.args[0] == "IndexConfigNotExist":
return None
else:
logger.exception("Failed to get index config for logstore %s", logstore_name)
raise e
def _get_workflow_execution_index_keys(self) -> dict[str, IndexKeyConfig]:
"""
Get field index configuration for workflow_execution logstore.
This method automatically generates index configuration from the WorkflowRun SQLAlchemy model.
When the PG schema is updated via Flask-Migrate, the index configuration will be automatically
updated on next application startup.
"""
from models.workflow import WorkflowRun
index_keys = self._generate_index_keys_from_model(WorkflowRun)
# Add custom fields that are in logstore but not in PG model
# These fields are added by the repository layer
index_keys["error_message"] = IndexKeyConfig(
index_type="text",
case_sensitive=False,
doc_value=True,
token_list=self.DEFAULT_TOKEN_LIST,
chinese=True,
) # Maps to 'error' in PG
index_keys["started_at"] = IndexKeyConfig(
index_type="text",
case_sensitive=False,
doc_value=True,
token_list=self.DEFAULT_TOKEN_LIST,
chinese=True,
) # Maps to 'created_at' in PG
logger.info("Generated %d index keys for workflow_execution from WorkflowRun model", len(index_keys))
return index_keys
def _get_workflow_node_execution_index_keys(self) -> dict[str, IndexKeyConfig]:
"""
Get field index configuration for workflow_node_execution logstore.
This method automatically generates index configuration from the WorkflowNodeExecutionModel.
When the PG schema is updated via Flask-Migrate, the index configuration will be automatically
updated on next application startup.
"""
from models.workflow import WorkflowNodeExecutionModel
index_keys = self._generate_index_keys_from_model(WorkflowNodeExecutionModel)
logger.debug(
"Generated %d index keys for workflow_node_execution from WorkflowNodeExecutionModel", len(index_keys)
)
return index_keys
def _get_index_config(self, logstore_name: str) -> IndexConfig:
"""
Get index configuration for the specified logstore.
Args:
logstore_name: Name of the logstore
Returns:
IndexConfig object with line and field indexes
"""
# Create full-text index (line config) with tokenizer
line_config = IndexLineConfig(token_list=self.DEFAULT_TOKEN_LIST, case_sensitive=False, chinese=True)
# Get field index configuration based on logstore name
field_keys = {}
if logstore_name == AliyunLogStore.workflow_execution_logstore:
field_keys = self._get_workflow_execution_index_keys()
elif logstore_name == AliyunLogStore.workflow_node_execution_logstore:
field_keys = self._get_workflow_node_execution_index_keys()
# key_config_list should be a dict, not a list
# Create index config with both line and field indexes
return IndexConfig(line_config=line_config, key_config_list=field_keys, scan_index=True)
def create_index(self, logstore_name: str) -> None:
"""
Create index for the specified logstore with both full-text and field indexes.
Field indexes are automatically generated from the corresponding SQLAlchemy model.
"""
index_config = self._get_index_config(logstore_name)
try:
self.client.create_index(self.project_name, logstore_name, index_config)
logger.info(
"index for %s created successfully with %d field indexes",
logstore_name,
len(index_config.key_config_list or {}),
)
except LogException as e:
logger.exception(
"Failed to create index for logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
logstore_name,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
def _merge_index_configs(
self, existing_config: IndexConfig, required_keys: dict[str, IndexKeyConfig], logstore_name: str
) -> tuple[IndexConfig, bool]:
"""
Intelligently merge existing index config with Dify's required field indexes.
This method:
1. Preserves all existing field indexes in logstore (including custom fields)
2. Adds missing Dify-required fields
3. Updates fields where type doesn't match (with json/text compatibility)
4. Corrects case mismatches (e.g., if Dify needs 'status' but logstore has 'Status')
Type compatibility rules:
- json and text types are considered compatible (users can manually choose either)
- All other type mismatches will be corrected to match Dify requirements
Note: Logstore is case-sensitive and doesn't allow duplicate fields with different cases.
Case mismatch means: existing field name differs from required name only in case.
Args:
existing_config: Current index configuration from logstore
required_keys: Dify's required field index configurations
logstore_name: Name of the logstore (for logging)
Returns:
Tuple of (merged_config, needs_update)
"""
# key_config_list is already a dict in the SDK
# Make a copy to avoid modifying the original
existing_keys = dict(existing_config.key_config_list) if existing_config.key_config_list else {}
# Track changes
needs_update = False
case_corrections = [] # Fields that need case correction (e.g., 'Status' -> 'status')
missing_fields = []
type_mismatches = []
# First pass: Check for and resolve case mismatches with required fields
# Note: Logstore itself doesn't allow duplicate fields with different cases,
# so we only need to check if the existing case matches the required case
for required_name in required_keys:
lower_name = required_name.lower()
# Find key that matches case-insensitively but not exactly
wrong_case_key = None
for existing_key in existing_keys:
if existing_key.lower() == lower_name and existing_key != required_name:
wrong_case_key = existing_key
break
if wrong_case_key:
# Field exists but with wrong case (e.g., 'Status' when we need 'status')
# Remove the wrong-case key, will be added back with correct case later
case_corrections.append((wrong_case_key, required_name))
del existing_keys[wrong_case_key]
needs_update = True
# Second pass: Check each required field
for required_name, required_config in required_keys.items():
# Check for exact match (case-sensitive)
if required_name in existing_keys:
existing_type = existing_keys[required_name].index_type
required_type = required_config.index_type
# Check if type matches
# Special case: json and text are interchangeable for JSON content fields
# Allow users to manually configure text instead of json (or vice versa) without forcing updates
is_compatible = existing_type == required_type or ({existing_type, required_type} == {"json", "text"})
if not is_compatible:
type_mismatches.append((required_name, existing_type, required_type))
# Update with correct type
existing_keys[required_name] = required_config
needs_update = True
# else: field exists with compatible type, no action needed
else:
# Field doesn't exist (may have been removed in first pass due to case conflict)
missing_fields.append(required_name)
existing_keys[required_name] = required_config
needs_update = True
# Log changes
if missing_fields:
logger.info(
"Logstore %s: Adding %d missing Dify-required fields: %s",
logstore_name,
len(missing_fields),
", ".join(missing_fields[:10]) + ("..." if len(missing_fields) > 10 else ""),
)
if type_mismatches:
logger.info(
"Logstore %s: Fixing %d type mismatches: %s",
logstore_name,
len(type_mismatches),
", ".join([f"{name}({old}->{new})" for name, old, new in type_mismatches[:5]])
+ ("..." if len(type_mismatches) > 5 else ""),
)
if case_corrections:
logger.info(
"Logstore %s: Correcting %d field name cases: %s",
logstore_name,
len(case_corrections),
", ".join([f"'{old}' -> '{new}'" for old, new in case_corrections[:5]])
+ ("..." if len(case_corrections) > 5 else ""),
)
# Create merged config
# key_config_list should be a dict, not a list
# Preserve the original scan_index value - don't force it to True
merged_config = IndexConfig(
line_config=existing_config.line_config
or IndexLineConfig(token_list=self.DEFAULT_TOKEN_LIST, case_sensitive=False, chinese=True),
key_config_list=existing_keys,
scan_index=existing_config.scan_index,
)
return merged_config, needs_update
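The json/text compatibility rule used by `_merge_index_configs` can be sketched in isolation. This is an illustrative standalone helper (not part of the SDK or this module), mirroring the set-comparison trick above:

```python
# Sketch of the index-type compatibility rule from _merge_index_configs
# (hypothetical helper, for illustration only):
def is_index_type_compatible(existing_type: str, required_type: str) -> bool:
    # Exact match is always compatible
    if existing_type == required_type:
        return True
    # json and text are interchangeable for JSON content fields,
    # so users may manually pick either without forcing an update
    return {existing_type, required_type} == {"json", "text"}

print(is_index_type_compatible("json", "text"))  # True
print(is_index_type_compatible("long", "text"))  # False
```

Using a set comparison covers both directions (json→text and text→json) in one expression.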
def ensure_index_config(self, logstore_name: str) -> None:
"""
Ensure index configuration includes all Dify-required fields.
This method intelligently manages index configuration:
1. If index doesn't exist, create it with Dify's required fields
2. If index exists:
- Check if all Dify-required fields are present
- Check if field types match requirements
- Only update if fields are missing or types are incorrect
- Preserve any additional custom index configurations
This approach allows users to add their own custom indexes without being overwritten.
"""
# Get Dify's required field indexes
required_keys = {}
if logstore_name == AliyunLogStore.workflow_execution_logstore:
required_keys = self._get_workflow_execution_index_keys()
elif logstore_name == AliyunLogStore.workflow_node_execution_logstore:
required_keys = self._get_workflow_node_execution_index_keys()
# Check if index exists
existing_config = self.get_existing_index_config(logstore_name)
if existing_config is None:
# Index doesn't exist, create it
logger.info(
"Logstore %s: Index doesn't exist, creating with %d required fields",
logstore_name,
len(required_keys),
)
self.create_index(logstore_name)
else:
merged_config, needs_update = self._merge_index_configs(existing_config, required_keys, logstore_name)
if needs_update:
logger.info("Logstore %s: Updating index to include Dify-required fields", logstore_name)
try:
self.client.update_index(self.project_name, logstore_name, merged_config)
logger.info(
"Logstore %s: Index updated successfully, now has %d total field indexes",
logstore_name,
len(merged_config.key_config_list or {}),
)
except LogException as e:
logger.exception(
"Failed to update index for logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
logstore_name,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
else:
logger.info(
"Logstore %s: Index already contains all %d Dify-required fields with correct types, "
"no update needed",
logstore_name,
len(required_keys),
)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]]) -> None:
# Route to PG or SDK based on protocol availability
if self._use_pg_protocol and self._pg_client:
self._pg_client.put_log(logstore, contents, self.log_enabled)
else:
log_item = LogItem(contents=contents)
request = PutLogsRequest(project=self.project_name, logstore=logstore, logitems=[log_item])
if self.log_enabled:
logger.info(
"[LogStore-SDK] PUT_LOG | logstore=%s | project=%s | items_count=%d",
logstore,
self.project_name,
len(contents),
)
try:
self.client.put_logs(request)
except LogException as e:
logger.exception(
"Failed to put logs to logstore %s: errorCode=%s, errorMessage=%s, requestId=%s",
logstore,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
def get_logs(
self,
logstore: str,
from_time: int,
to_time: int,
topic: str = "",
query: str = "",
line: int = 100,
offset: int = 0,
reverse: bool = True,
) -> list[dict]:
request = GetLogsRequest(
project=self.project_name,
logstore=logstore,
fromTime=from_time,
toTime=to_time,
topic=topic,
query=query,
line=line,
offset=offset,
reverse=reverse,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS | logstore=%s | project=%s | query=%s | "
"from_time=%d | to_time=%d | line=%d | offset=%d | reverse=%s",
logstore,
self.project_name,
query,
from_time,
to_time,
line,
offset,
reverse,
)
try:
response = self.client.get_logs(request)
result = []
logs = response.get_logs() if response else []
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore] GET_LOGS RESULT | logstore=%s | returned_count=%d",
logstore,
len(result),
)
return result
except LogException as e:
logger.exception(
"Failed to get logs from logstore %s with query '%s': errorCode=%s, errorMessage=%s, requestId=%s",
logstore,
query,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
)
raise
def execute_sql(
self,
sql: str,
logstore: str | None = None,
query: str = "*",
from_time: int | None = None,
to_time: int | None = None,
power_sql: bool = False,
) -> list[dict]:
"""
Execute SQL query for aggregation and analysis.
Args:
sql: SQL query string (SELECT statement)
logstore: Name of the logstore (required)
query: Search/filter query for SDK mode (default: "*" for all logs).
Only used in SDK mode. PG mode ignores this parameter.
from_time: Start time (Unix timestamp) - only used in SDK mode
to_time: End time (Unix timestamp) - only used in SDK mode
power_sql: Whether to use enhanced SQL mode (default: False)
Returns:
List of result rows as dictionaries
Note:
- PG mode: Only executes the SQL directly
- SDK mode: Combines query and sql as "query | sql"
"""
# Logstore is required
if not logstore:
raise ValueError("logstore parameter is required for execute_sql")
# Route to PG or SDK based on protocol availability
if self._use_pg_protocol and self._pg_client:
# PG mode: execute SQL directly (ignore query parameter)
return self._pg_client.execute_sql(sql, logstore, self.log_enabled)
else:
# SDK mode: combine query and sql as "query | sql"
full_query = f"{query} | {sql}"
# Provide default time range if not specified
if from_time is None:
from_time = 0
if to_time is None:
to_time = int(time.time()) # now
request = GetLogsRequest(
project=self.project_name,
logstore=logstore,
fromTime=from_time,
toTime=to_time,
query=full_query,
)
# Log query info if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL | logstore=%s | project=%s | from_time=%d | to_time=%d | full_query=%s",
logstore,
self.project_name,
from_time,
to_time,
full_query,
)
try:
response = self.client.get_logs(request)
result = []
logs = response.get_logs() if response else []
for log in logs:
result.append(log.get_contents())
# Log result count if SQLALCHEMY_ECHO is enabled
if self.log_enabled:
logger.info(
"[LogStore-SDK] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
logstore,
len(result),
)
return result
except LogException as e:
logger.exception(
"Failed to execute SQL, logstore %s: errorCode=%s, errorMessage=%s, requestId=%s, full_query=%s",
logstore,
e.get_error_code(),
e.get_error_message(),
e.get_request_id(),
full_query,
)
raise
if __name__ == "__main__":
aliyun_logstore = AliyunLogStore()
# aliyun_logstore.init_project_logstore()
aliyun_logstore.put_log(AliyunLogStore.workflow_execution_logstore, [("key1", "value1")])


@@ -0,0 +1,407 @@
import logging
import os
import socket
import time
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any
import psycopg2
import psycopg2.pool
from psycopg2 import InterfaceError, OperationalError
from configs import dify_config
logger = logging.getLogger(__name__)
class AliyunLogStorePG:
"""
PostgreSQL protocol support for Aliyun SLS LogStore.
Handles PG connection pooling and operations for regions that support PG protocol.
"""
def __init__(self, access_key_id: str, access_key_secret: str, endpoint: str, project_name: str):
"""
Initialize PG connection for SLS.
Args:
access_key_id: Aliyun access key ID
access_key_secret: Aliyun access key secret
endpoint: SLS endpoint
project_name: SLS project name
"""
self._access_key_id = access_key_id
self._access_key_secret = access_key_secret
self._endpoint = endpoint
self.project_name = project_name
self._pg_pool: psycopg2.pool.SimpleConnectionPool | None = None
self._use_pg_protocol = False
def _check_port_connectivity(self, host: str, port: int, timeout: float = 2.0) -> bool:
"""
Check if a TCP port is reachable using socket connection.
This provides a fast check before attempting full database connection,
preventing long waits when connecting to unsupported regions.
Args:
host: Hostname or IP address
port: Port number
timeout: Connection timeout in seconds (default: 2.0)
Returns:
True if port is reachable, False otherwise
"""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout)
result = sock.connect_ex((host, port))
sock.close()
return result == 0
except Exception as e:
logger.debug("Port connectivity check failed for %s:%d: %s", host, port, str(e))
return False
def init_connection(self) -> bool:
"""
Initialize PostgreSQL connection pool for SLS PG protocol support.
Attempts to connect to SLS using PostgreSQL protocol. If successful, sets
_use_pg_protocol to True and creates a connection pool. If connection fails
(region doesn't support PG protocol or other errors), returns False.
Returns:
True if PG protocol is supported and initialized, False otherwise
"""
try:
# Extract hostname from endpoint (remove protocol if present)
pg_host = self._endpoint.replace("http://", "").replace("https://", "")
# Get pool configuration
pg_max_connections = int(os.environ.get("ALIYUN_SLS_PG_MAX_CONNECTIONS", 10))
logger.debug(
"Check PG protocol connection to SLS: host=%s, project=%s",
pg_host,
self.project_name,
)
# Fast port connectivity check before attempting full connection
# This prevents long waits when connecting to unsupported regions
if not self._check_port_connectivity(pg_host, 5432, timeout=1.0):
logger.info(
"PG port 5432 is not reachable; using SDK mode for read/write operations, host=%s",
pg_host,
)
return False
# Create connection pool
self._pg_pool = psycopg2.pool.SimpleConnectionPool(
minconn=1,
maxconn=pg_max_connections,
host=pg_host,
port=5432,
database=self.project_name,
user=self._access_key_id,
password=self._access_key_secret,
sslmode="require",
connect_timeout=5,
application_name=f"Dify-{dify_config.project.version}",
)
# Note: Skip test query because SLS PG protocol only supports SELECT/INSERT on actual tables
# Connection pool creation success already indicates connectivity
self._use_pg_protocol = True
logger.info(
"PG protocol initialized successfully for SLS project=%s. Will use PG for read/write operations.",
self.project_name,
)
return True
except Exception as e:
# PG connection failed - fallback to SDK mode
self._use_pg_protocol = False
if self._pg_pool:
try:
self._pg_pool.closeall()
except Exception:
logger.debug("Failed to close PG connection pool during cleanup, ignoring")
self._pg_pool = None
logger.info(
"PG protocol connection failed (region may not support PG protocol): %s. "
"Falling back to SDK mode for read/write operations.",
str(e),
)
return False
def _is_connection_valid(self, conn: Any) -> bool:
"""
Check if a connection is still valid.
Args:
conn: psycopg2 connection object
Returns:
True if connection is valid, False otherwise
"""
try:
# Check if connection is closed
if conn.closed:
return False
# Quick ping test: attempt a lightweight query.
# Note: some SLS PG endpoints may reject SELECT without FROM;
# any failure here simply marks the connection as invalid.
with conn.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.fetchone()
return True
except Exception:
return False
@contextmanager
def _get_connection(self):
"""
Context manager to get a PostgreSQL connection from the pool.
Automatically validates and refreshes stale connections.
Note: Aliyun SLS PG protocol does not support transactions, so we always
use autocommit mode.
Yields:
psycopg2 connection object
Raises:
RuntimeError: If PG pool is not initialized
"""
if not self._pg_pool:
raise RuntimeError("PG connection pool is not initialized")
conn = self._pg_pool.getconn()
try:
# Validate connection and get a fresh one if needed
if not self._is_connection_valid(conn):
logger.debug("Connection is stale, marking as bad and getting a new one")
# Mark connection as bad and get a new one
self._pg_pool.putconn(conn, close=True)
conn = self._pg_pool.getconn()
# Aliyun SLS PG protocol does not support transactions, always use autocommit
conn.autocommit = True
yield conn
finally:
# Return connection to pool (or close if it's bad)
if self._is_connection_valid(conn):
self._pg_pool.putconn(conn)
else:
self._pg_pool.putconn(conn, close=True)
def close(self) -> None:
"""Close the PostgreSQL connection pool."""
if self._pg_pool:
try:
self._pg_pool.closeall()
logger.info("PG connection pool closed")
except Exception:
logger.exception("Failed to close PG connection pool")
def _is_retriable_error(self, error: Exception) -> bool:
"""
Check if an error is retriable (connection-related issues).
Args:
error: Exception to check
Returns:
True if the error is retriable, False otherwise
"""
# Retry on connection-related errors
if isinstance(error, (OperationalError, InterfaceError)):
return True
# Check error message for specific connection issues
error_msg = str(error).lower()
retriable_patterns = [
"connection",
"timeout",
"closed",
"broken pipe",
"reset by peer",
"no route to host",
"network",
]
return any(pattern in error_msg for pattern in retriable_patterns)
def put_log(self, logstore: str, contents: Sequence[tuple[str, str]], log_enabled: bool = False) -> None:
"""
Write log to SLS using PostgreSQL protocol with automatic retry.
Note: SLS PG protocol only supports INSERT (not UPDATE). This uses append-only
writes with log_version field for versioning, same as SDK implementation.
Args:
logstore: Name of the logstore table
contents: List of (field_name, value) tuples
log_enabled: Whether to enable logging
Raises:
psycopg2.Error: If database operation fails after all retries
"""
if not contents:
return
# Extract field names and values from contents
fields = [field_name for field_name, _ in contents]
values = [value for _, value in contents]
# Build INSERT statement with literal values
# Note: Aliyun SLS PG protocol doesn't support parameterized queries,
# so we need to use mogrify to safely create literal values
field_list = ", ".join([f'"{field}"' for field in fields])
if log_enabled:
logger.info(
"[LogStore-PG] PUT_LOG | logstore=%s | project=%s | items_count=%d",
logstore,
self.project_name,
len(contents),
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
# Use mogrify to safely convert values to SQL literals
placeholders = ", ".join(["%s"] * len(fields))
values_literal = cursor.mogrify(f"({placeholders})", values).decode("utf-8")
insert_sql = f'INSERT INTO "{logstore}" ({field_list}) VALUES {values_literal}'
cursor.execute(insert_sql)
# Success - exit retry loop
return
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., data validation error), fail immediately
logger.exception(
"Failed to put logs to logstore %s via PG protocol (non-retriable error)",
logstore,
)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to put logs to logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
else:
# Last attempt failed
logger.exception(
"Failed to put logs to logstore %s via PG protocol after %d attempts",
logstore,
max_retries,
)
raise
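The retry policy used by `put_log` and `execute_sql` (up to 3 attempts with exponential backoff starting at 100ms) can be factored out as a generic helper. This is a sketch under the assumption that any exception is retriable; the real code additionally consults `_is_retriable_error`:

```python
import time

# Standalone sketch of the retry-with-exponential-backoff pattern above
# (hypothetical helper; the sleep function is injectable for testing):
def with_retries(fn, max_retries: int = 3, retry_delay: float = 0.1, sleep=time.sleep):
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception:
            if attempt == max_retries - 1:
                raise  # last attempt failed; propagate
            sleep(retry_delay)
            retry_delay *= 2  # exponential backoff

calls = {"n": 0}
def flaky():
    # Fails twice, then succeeds
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient")
    return "ok"

delays = []
result = with_retries(flaky, sleep=delays.append)
print(result)  # ok
print(delays)  # [0.1, 0.2]
```

Injecting `sleep` keeps the sketch testable without real waits; the production code sleeps for real and classifies errors before retrying.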
def execute_sql(self, sql: str, logstore: str, log_enabled: bool = False) -> list[dict[str, Any]]:
"""
Execute SQL query using PostgreSQL protocol with automatic retry.
Args:
sql: SQL query string
logstore: Name of the logstore (for logging purposes)
log_enabled: Whether to enable logging
Returns:
List of result rows as dictionaries
Raises:
psycopg2.Error: If database operation fails after all retries
"""
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL | logstore=%s | project=%s | sql=%s",
logstore,
self.project_name,
sql,
)
# Retry configuration
max_retries = 3
retry_delay = 0.1 # Start with 100ms
for attempt in range(max_retries):
try:
with self._get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(sql)
# Get column names from cursor description
columns = [desc[0] for desc in cursor.description]
# Fetch all results and convert to list of dicts
result = []
for row in cursor.fetchall():
row_dict = {}
for col, val in zip(columns, row):
row_dict[col] = "" if val is None else str(val)
result.append(row_dict)
if log_enabled:
logger.info(
"[LogStore-PG] EXECUTE_SQL RESULT | logstore=%s | returned_count=%d",
logstore,
len(result),
)
return result
except psycopg2.Error as e:
# Check if error is retriable
if not self._is_retriable_error(e):
# Not a retriable error (e.g., SQL syntax error), fail immediately
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol (non-retriable error): sql=%s",
logstore,
sql,
)
raise
# Retriable error - log and retry if we have attempts left
if attempt < max_retries - 1:
logger.warning(
"Failed to execute SQL query on logstore %s via PG protocol (attempt %d/%d): %s. Retrying...",
logstore,
attempt + 1,
max_retries,
str(e),
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
else:
# Last attempt failed
logger.exception(
"Failed to execute SQL query on logstore %s via PG protocol after %d attempts: sql=%s",
logstore,
max_retries,
sql,
)
raise
# This line should never be reached due to raise above, but makes type checker happy
return []


@@ -0,0 +1,365 @@
"""
LogStore implementation of DifyAPIWorkflowNodeExecutionRepository.
This module provides the LogStore-based implementation for service-layer
WorkflowNodeExecutionModel operations using Aliyun SLS LogStore.
"""
import logging
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
logger = logging.getLogger(__name__)
def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNodeExecutionModel:
"""
Convert LogStore result dictionary to WorkflowNodeExecutionModel instance.
Args:
data: Dictionary from LogStore query result
Returns:
WorkflowNodeExecutionModel instance (detached from session)
Note:
The returned model is not attached to any SQLAlchemy session.
Relationship fields (like offload_data) are not loaded from LogStore.
"""
logger.debug("_dict_to_workflow_node_execution_model: data keys=%s", list(data.keys())[:5])
# Create model instance without session
model = WorkflowNodeExecutionModel()
# Map all required fields with validation
# Critical fields - must not be None
model.id = data.get("id") or ""
model.tenant_id = data.get("tenant_id") or ""
model.app_id = data.get("app_id") or ""
model.workflow_id = data.get("workflow_id") or ""
model.triggered_from = data.get("triggered_from") or ""
model.node_id = data.get("node_id") or ""
model.node_type = data.get("node_type") or ""
model.status = data.get("status") or "running" # Default status if missing
model.title = data.get("title") or ""
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.index = int(data.get("index") or 0)
model.elapsed_time = float(data.get("elapsed_time") or 0)
# Optional fields
model.workflow_run_id = data.get("workflow_run_id")
model.predecessor_node_id = data.get("predecessor_node_id")
model.node_execution_id = data.get("node_execution_id")
model.inputs = data.get("inputs")
model.process_data = data.get("process_data")
model.outputs = data.get("outputs")
model.error = data.get("error")
model.execution_metadata = data.get("execution_metadata")
# Handle datetime fields
created_at = data.get("created_at")
if created_at:
if isinstance(created_at, str):
model.created_at = datetime.fromisoformat(created_at)
elif isinstance(created_at, (int, float)):
model.created_at = datetime.fromtimestamp(created_at)
else:
model.created_at = created_at
else:
# Provide default created_at if missing
model.created_at = datetime.now()
finished_at = data.get("finished_at")
if finished_at:
if isinstance(finished_at, str):
model.finished_at = datetime.fromisoformat(finished_at)
elif isinstance(finished_at, (int, float)):
model.finished_at = datetime.fromtimestamp(finished_at)
else:
model.finished_at = finished_at
return model
class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
"""
LogStore implementation of DifyAPIWorkflowNodeExecutionRepository.
Provides service-layer database operations for WorkflowNodeExecutionModel
using LogStore SQL queries with optimized deduplication strategies.
"""
def __init__(self, session_maker: sessionmaker | None = None):
"""
Initialize the repository with LogStore client.
Args:
session_maker: SQLAlchemy sessionmaker (unused, for compatibility with factory pattern)
"""
logger.debug("LogstoreAPIWorkflowNodeExecutionRepository.__init__: initializing")
self.logstore_client = AliyunLogStore()
def get_node_last_execution(
self,
tenant_id: str,
app_id: str,
workflow_id: str,
node_id: str,
) -> WorkflowNodeExecutionModel | None:
"""
Get the most recent execution for a specific node.
Uses query syntax to get raw logs and selects the one with max log_version.
Returns the most recent execution ordered by created_at.
"""
logger.debug(
"get_node_last_execution: tenant_id=%s, app_id=%s, workflow_id=%s, node_id=%s",
tenant_id,
app_id,
workflow_id,
node_id,
)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_id = '{workflow_id}'
AND node_id = '{node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_node_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
query = (
f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_id: {workflow_id} and node_id: {node_id}"
)
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_node_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=100,
reverse=False,
)
if not results:
return None
# For SDK mode, group by id and select the one with max log_version for each group
# For PG mode, this is already done by the SQL query
if not self.logstore_client.supports_pg_protocol:
id_to_results: dict[str, list[dict[str, Any]]] = {}
for row in results:
row_id = row.get("id")
if row_id:
if row_id not in id_to_results:
id_to_results[row_id] = []
id_to_results[row_id].append(row)
# For each id, select the row with max log_version
deduplicated_results = []
for rows in id_to_results.values():
if len(rows) > 1:
max_row = max(rows, key=lambda x: int(x.get("log_version") or 0))
else:
max_row = rows[0]
deduplicated_results.append(max_row)
else:
# For PG mode, results are already deduplicated by the SQL query
deduplicated_results = results
# Sort by created_at DESC and return the most recent one
# created_at is typically a string in LogStore results; coerce to float for sorting
def _created_at_key(row: dict[str, Any]) -> float:
try:
return float(row.get("created_at") or 0)
except (TypeError, ValueError):
return 0.0
deduplicated_results.sort(key=_created_at_key, reverse=True)
if deduplicated_results:
return _dict_to_workflow_node_execution_model(deduplicated_results[0])
return None
except Exception:
logger.exception("Failed to get node last execution from LogStore")
raise
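The SDK-mode deduplication above (group rows by `id`, keep the row with the highest `log_version`, since writes are append-only and versioned) can be sketched with illustrative data. LogStore returns all field values as strings, hence the `int(...)` coercion:

```python
# Sketch of the group-by-id / max-log_version deduplication used in SDK mode
# (hypothetical helper and sample rows, for illustration only):
def deduplicate_by_log_version(rows: list[dict]) -> list[dict]:
    grouped: dict[str, list[dict]] = {}
    for row in rows:
        row_id = row.get("id")
        if row_id:
            grouped.setdefault(row_id, []).append(row)
    # Keep only the latest version of each record
    return [
        max(group, key=lambda r: int(r.get("log_version") or 0))
        for group in grouped.values()
    ]

rows = [
    {"id": "a", "log_version": "1", "status": "running"},
    {"id": "a", "log_version": "2", "status": "succeeded"},
    {"id": "b", "log_version": "1", "status": "failed"},
]
result = deduplicate_by_log_version(rows)
print([(r["id"], r["status"]) for r in result])  # [('a', 'succeeded'), ('b', 'failed')]
```

In PG mode the same effect is achieved server-side with `ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC)`.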
def get_executions_by_workflow_run(
self,
tenant_id: str,
app_id: str,
workflow_run_id: str,
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Get all node executions for a specific workflow run.
Uses query syntax to get raw logs and selects the one with max log_version for each node execution.
Ordered by index DESC for trace visualization.
"""
logger.debug(
"[LogStore] get_executions_by_workflow_run: tenant_id=%s, app_id=%s, workflow_run_id=%s",
tenant_id,
app_id,
workflow_run_id,
)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of each record)
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{tenant_id}'
AND app_id = '{app_id}'
AND workflow_run_id = '{workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1000
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_node_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
query = f"tenant_id: {tenant_id} and app_id: {app_id} and workflow_run_id: {workflow_run_id}"
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_node_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=1000, # Get more results for node executions
reverse=False,
)
if not results:
return []
# For SDK mode, group by id and select the one with max log_version for each group
# For PG mode, this is already done by the SQL query
models = []
if not self.logstore_client.supports_pg_protocol:
id_to_results: dict[str, list[dict[str, Any]]] = {}
for row in results:
row_id = row.get("id")
if row_id:
if row_id not in id_to_results:
id_to_results[row_id] = []
id_to_results[row_id].append(row)
# For each id, select the row with max log_version
for rows in id_to_results.values():
if len(rows) > 1:
max_row = max(rows, key=lambda x: int(x.get("log_version") or 0))
else:
max_row = rows[0]
model = _dict_to_workflow_node_execution_model(max_row)
if model and model.id: # Ensure model is valid
models.append(model)
else:
# For PG mode, results are already deduplicated by the SQL query
for row in results:
model = _dict_to_workflow_node_execution_model(row)
if model and model.id: # Ensure model is valid
models.append(model)
# Sort by index DESC for trace visualization
models.sort(key=lambda x: x.index, reverse=True)
return models
except Exception:
logger.exception("Failed to get executions by workflow run from LogStore")
raise
def get_execution_by_id(
self,
execution_id: str,
tenant_id: str | None = None,
) -> WorkflowNodeExecutionModel | None:
"""
Get a workflow node execution by its ID.
Uses query syntax to get raw logs and selects the one with max log_version.
"""
logger.debug("get_execution_by_id: execution_id=%s, tenant_id=%s", execution_id, tenant_id)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
tenant_filter = f"AND tenant_id = '{tenant_id}'" if tenant_id else ""
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE id = '{execution_id}' {tenant_filter} AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 1
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_node_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
if tenant_id:
query = f"id: {execution_id} and tenant_id: {tenant_id}"
else:
query = f"id: {execution_id}"
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_node_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=100,
reverse=False,
)
if not results:
return None
# For PG mode, result is already the latest version
# For SDK mode, if multiple results, select the one with max log_version
if self.logstore_client.supports_pg_protocol or len(results) == 1:
return _dict_to_workflow_node_execution_model(results[0])
else:
max_result = max(results, key=lambda x: int(x.get("log_version") or 0))
return _dict_to_workflow_node_execution_model(max_result)
except Exception:
logger.exception("Failed to get execution by ID from LogStore: execution_id=%s", execution_id)
raise


@@ -0,0 +1,757 @@
"""
LogStore API WorkflowRun Repository Implementation
This module provides the LogStore-based implementation of the APIWorkflowRunRepository
protocol. It handles service-layer WorkflowRun database operations using Aliyun SLS LogStore
with optimized queries for statistics and pagination.
Key Features:
- LogStore SQL queries for aggregation and statistics
- Optimized deduplication using finished_at IS NOT NULL filter
- Window functions only when necessary (running status queries)
- Multi-tenant data isolation and security
"""
import logging
import os
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any, cast
from sqlalchemy.orm import sessionmaker
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.types import (
AverageInteractionStats,
DailyRunsStats,
DailyTerminalsStats,
DailyTokenCostStats,
)
logger = logging.getLogger(__name__)
def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun:
"""
Convert LogStore result dictionary to WorkflowRun instance.
Args:
data: Dictionary from LogStore query result
Returns:
WorkflowRun instance
"""
logger.debug("_dict_to_workflow_run: data keys=%s", list(data.keys())[:5])
# Create model instance without session
model = WorkflowRun()
# Map all required fields with validation
# Critical fields - must not be None
model.id = data.get("id") or ""
model.tenant_id = data.get("tenant_id") or ""
model.app_id = data.get("app_id") or ""
model.workflow_id = data.get("workflow_id") or ""
model.type = data.get("type") or ""
model.triggered_from = data.get("triggered_from") or ""
model.version = data.get("version") or ""
model.status = data.get("status") or "running" # Default status if missing
model.created_by_role = data.get("created_by_role") or ""
model.created_by = data.get("created_by") or ""
# Numeric fields with defaults
model.total_tokens = int(data.get("total_tokens", 0))
model.total_steps = int(data.get("total_steps", 0))
model.exceptions_count = int(data.get("exceptions_count", 0))
# Optional fields
model.graph = data.get("graph")
model.inputs = data.get("inputs")
model.outputs = data.get("outputs")
model.error = data.get("error_message") or data.get("error")
# Handle datetime fields
started_at = data.get("started_at") or data.get("created_at")
if started_at:
if isinstance(started_at, str):
model.created_at = datetime.fromisoformat(started_at)
elif isinstance(started_at, (int, float)):
model.created_at = datetime.fromtimestamp(started_at)
else:
model.created_at = started_at
else:
# Provide default created_at if missing
model.created_at = datetime.now()
finished_at = data.get("finished_at")
if finished_at:
if isinstance(finished_at, str):
model.finished_at = datetime.fromisoformat(finished_at)
elif isinstance(finished_at, (int, float)):
model.finished_at = datetime.fromtimestamp(finished_at)
else:
model.finished_at = finished_at
# Compute elapsed_time from started_at and finished_at
# LogStore doesn't store elapsed_time; it's computed in the WorkflowExecution domain entity
if model.finished_at and model.created_at:
model.elapsed_time = (model.finished_at - model.created_at).total_seconds()
else:
model.elapsed_time = float(data.get("elapsed_time", 0))
return model
class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
"""
LogStore implementation of APIWorkflowRunRepository.
Provides service-layer WorkflowRun database operations using LogStore SQL
with optimized query strategies:
- Use finished_at IS NOT NULL for deduplication (10-100x faster)
- Use window functions only when running status is required
- Proper time range filtering for LogStore queries
"""
def __init__(self, session_maker: sessionmaker | None = None):
"""
Initialize the repository with LogStore client.
Args:
session_maker: SQLAlchemy sessionmaker (unused, for compatibility with factory pattern)
"""
logger.debug("LogstoreAPIWorkflowRunRepository.__init__: initializing")
self.logstore_client = AliyunLogStore()
# Control flag for dual-read (fallback to PostgreSQL when LogStore returns no results)
# Set to True to enable fallback for safe migration from PostgreSQL to LogStore
# Set to False for new deployments without legacy data in PostgreSQL
self._enable_dual_read = os.environ.get("LOGSTORE_DUAL_READ_ENABLED", "true").lower() == "true"
def get_paginated_workflow_runs(
self,
tenant_id: str,
app_id: str,
triggered_from: WorkflowRunTriggeredFrom | Sequence[WorkflowRunTriggeredFrom],
limit: int = 20,
last_id: str | None = None,
status: str | None = None,
) -> InfiniteScrollPagination:
"""
Get paginated workflow runs with filtering.
Uses window function for deduplication to support both running and finished states.
Args:
tenant_id: Tenant identifier for multi-tenant isolation
app_id: Application identifier
triggered_from: Filter by trigger source(s)
limit: Maximum number of records to return (default: 20)
last_id: Cursor for pagination - ID of the last record from previous page
status: Optional filter by status
Returns:
InfiniteScrollPagination object
"""
logger.debug(
"get_paginated_workflow_runs: tenant_id=%s, app_id=%s, limit=%d, status=%s",
tenant_id,
app_id,
limit,
status,
)
# Convert triggered_from to list if needed
if isinstance(triggered_from, WorkflowRunTriggeredFrom):
triggered_from_list = [triggered_from]
else:
triggered_from_list = list(triggered_from)
# Build triggered_from filter
triggered_from_filter = " OR ".join([f"triggered_from='{tf.value}'" for tf in triggered_from_list])
# Build status filter
status_filter = f"AND status='{status}'" if status else ""
# Build last_id filter for pagination
# Note: This is simplified. In production, you'd need to track created_at from the last record
last_id_filter = ""
if last_id:
# TODO: Implement proper cursor-based pagination with created_at
logger.warning("last_id pagination not fully implemented for LogStore")
# Use window function to get latest log_version of each workflow run
sql = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND ({triggered_from_filter})
{status_filter}
{last_id_filter}
) t
WHERE rn = 1
ORDER BY created_at DESC
LIMIT {limit + 1}
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore, from_time=None, to_time=None
)
# Check if there are more records
has_more = len(results) > limit
if has_more:
results = results[:limit]
# Convert results to WorkflowRun models
workflow_runs = [_dict_to_workflow_run(row) for row in results]
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
except Exception:
logger.exception("Failed to get paginated workflow runs from LogStore")
raise
def get_workflow_run_by_id(
self,
tenant_id: str,
app_id: str,
run_id: str,
) -> WorkflowRun | None:
"""
Get a specific workflow run by ID with tenant and app isolation.
Uses query syntax to get raw logs and selects the one with max log_version in code.
Falls back to PostgreSQL if not found in LogStore (for data consistency during migration).
"""
logger.debug("get_workflow_run_by_id: tenant_id=%s, app_id=%s, run_id=%s", tenant_id, app_id, run_id)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND tenant_id = '{tenant_id}' AND app_id = '{app_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id} and tenant_id: {tenant_id} and app_id: {app_id}"
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=100,
reverse=False,
)
if not results:
# Fallback to PostgreSQL for records created before LogStore migration
if self._enable_dual_read:
logger.debug(
"WorkflowRun not found in LogStore, falling back to PostgreSQL: "
"run_id=%s, tenant_id=%s, app_id=%s",
run_id,
tenant_id,
app_id,
)
return self._fallback_get_workflow_run_by_id_with_tenant(run_id, tenant_id, app_id)
return None
# For PG mode, results are already deduplicated by the SQL query
# For SDK mode, if multiple results, select the one with max log_version
if self.logstore_client.supports_pg_protocol or len(results) == 1:
return _dict_to_workflow_run(results[0])
else:
max_result = max(results, key=lambda x: int(x.get("log_version", 0)))
return _dict_to_workflow_run(max_result)
except Exception:
logger.exception("Failed to get workflow run by ID from LogStore: run_id=%s", run_id)
# Try PostgreSQL fallback on any error (only if dual-read is enabled)
if self._enable_dual_read:
try:
return self._fallback_get_workflow_run_by_id_with_tenant(run_id, tenant_id, app_id)
except Exception:
logger.exception(
"PostgreSQL fallback also failed: run_id=%s, tenant_id=%s, app_id=%s", run_id, tenant_id, app_id
)
raise
def _fallback_get_workflow_run_by_id_with_tenant(
self, run_id: str, tenant_id: str, app_id: str
) -> WorkflowRun | None:
"""Fallback to PostgreSQL query for records not in LogStore (with tenant isolation)."""
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
with Session(db.engine) as session:
stmt = select(WorkflowRun).where(
WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id
)
return session.scalar(stmt)
def get_workflow_run_by_id_without_tenant(
self,
run_id: str,
) -> WorkflowRun | None:
"""
Get a specific workflow run by ID without tenant/app context.
Uses query syntax to get raw logs and selects the one with max log_version.
Falls back to PostgreSQL if not found in LogStore (controlled by LOGSTORE_DUAL_READ_ENABLED).
"""
logger.debug("get_workflow_run_by_id_without_tenant: run_id=%s", run_id)
try:
# Check if PG protocol is supported
if self.logstore_client.supports_pg_protocol:
# Use PG protocol with SQL query (get latest version of record)
sql_query = f"""
SELECT * FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{run_id}' AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
"""
results = self.logstore_client.execute_sql(
sql=sql_query,
logstore=AliyunLogStore.workflow_execution_logstore,
)
else:
# Use SDK with LogStore query syntax
query = f"id: {run_id}"
from_time = 0
to_time = int(time.time()) # now
results = self.logstore_client.get_logs(
logstore=AliyunLogStore.workflow_execution_logstore,
from_time=from_time,
to_time=to_time,
query=query,
line=100,
reverse=False,
)
if not results:
# Fallback to PostgreSQL for records created before LogStore migration
if self._enable_dual_read:
logger.debug("WorkflowRun not found in LogStore, falling back to PostgreSQL: run_id=%s", run_id)
return self._fallback_get_workflow_run_by_id(run_id)
return None
# For PG mode, results are already deduplicated by the SQL query
# For SDK mode, if multiple results, select the one with max log_version
if self.logstore_client.supports_pg_protocol or len(results) == 1:
return _dict_to_workflow_run(results[0])
else:
max_result = max(results, key=lambda x: int(x.get("log_version", 0)))
return _dict_to_workflow_run(max_result)
except Exception:
logger.exception("Failed to get workflow run without tenant: run_id=%s", run_id)
# Try PostgreSQL fallback on any error (only if dual-read is enabled)
if self._enable_dual_read:
try:
return self._fallback_get_workflow_run_by_id(run_id)
except Exception:
logger.exception("PostgreSQL fallback also failed: run_id=%s", run_id)
raise
def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None:
"""Fallback to PostgreSQL query for records not in LogStore."""
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
with Session(db.engine) as session:
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
return session.scalar(stmt)
def get_workflow_runs_count(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
status: str | None = None,
time_range: str | None = None,
) -> dict[str, int]:
"""
Get workflow runs count statistics grouped by status.
Optimization: Use finished_at IS NOT NULL for completed runs (10-50x faster)
"""
logger.debug(
"get_workflow_runs_count: tenant_id=%s, app_id=%s, triggered_from=%s, status=%s",
tenant_id,
app_id,
triggered_from,
status,
)
# Build time range filter
time_filter = ""
if time_range:
# TODO: Parse time_range and convert to from_time/to_time
logger.warning("time_range filter not implemented")
# If status is provided, simple count
if status:
if status == "running":
# Running status requires window function
sql = f"""
SELECT COUNT(*) as count
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='running'
{time_filter}
) t
WHERE rn = 1
"""
else:
# Finished status uses optimized filter
sql = f"""
SELECT COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='{status}'
AND finished_at IS NOT NULL
{time_filter}
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
count = int(results[0]["count"]) if results else 0
return {
"total": count,
"running": count if status == "running" else 0,
"succeeded": count if status == "succeeded" else 0,
"failed": count if status == "failed" else 0,
"stopped": count if status == "stopped" else 0,
"partial-succeeded": count if status == "partial-succeeded" else 0,
}
except Exception:
logger.exception("Failed to get workflow runs count")
raise
# No status filter - get counts grouped by status
# Use optimized query for finished runs, separate query for running
try:
# Count finished runs grouped by status
finished_sql = f"""
SELECT status, COUNT(DISTINCT id) as count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY status
"""
# Count running runs
running_sql = f"""
SELECT COUNT(*) as count
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) AS rn
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND status='running'
{time_filter}
) t
WHERE rn = 1
"""
finished_results = self.logstore_client.execute_sql(
sql=finished_sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
running_results = self.logstore_client.execute_sql(
sql=running_sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
# Build response
status_counts = {
"running": 0,
"succeeded": 0,
"failed": 0,
"stopped": 0,
"partial-succeeded": 0,
}
total = 0
for result in finished_results:
status_val = result.get("status")
count = int(result.get("count", 0))
if status_val in status_counts:
status_counts[status_val] = count
total += count
# Add running count
running_count = int(running_results[0]["count"]) if running_results else 0
status_counts["running"] = running_count
total += running_count
return {"total": total} | status_counts
except Exception:
logger.exception("Failed to get workflow runs count")
raise
def get_daily_runs_statistics(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
timezone: str = "UTC",
) -> list[DailyRunsStats]:
"""
Get daily runs statistics using optimized query.
Optimization: Use finished_at IS NOT NULL + COUNT(DISTINCT id) (20-100x faster)
"""
logger.debug(
"get_daily_runs_statistics: tenant_id=%s, app_id=%s, triggered_from=%s", tenant_id, app_id, triggered_from
)
# Build time range filter
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
if end_date:
time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
# Optimized query: Use finished_at filter to avoid window function
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT id) as runs
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
ORDER BY date
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
response_data = []
for row in results:
response_data.append({"date": str(row.get("date", "")), "runs": int(row.get("runs", 0))})
return cast(list[DailyRunsStats], response_data)
except Exception:
logger.exception("Failed to get daily runs statistics")
raise
def get_daily_terminals_statistics(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
timezone: str = "UTC",
) -> list[DailyTerminalsStats]:
"""
Get daily terminals statistics using optimized query.
Optimization: Use finished_at IS NOT NULL + COUNT(DISTINCT created_by) (20-100x faster)
"""
logger.debug(
"get_daily_terminals_statistics: tenant_id=%s, app_id=%s, triggered_from=%s",
tenant_id,
app_id,
triggered_from,
)
# Build time range filter
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
if end_date:
time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, COUNT(DISTINCT created_by) as terminal_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
ORDER BY date
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
response_data = []
for row in results:
response_data.append({"date": str(row.get("date", "")), "terminal_count": int(row.get("terminal_count", 0))})
return cast(list[DailyTerminalsStats], response_data)
except Exception:
logger.exception("Failed to get daily terminals statistics")
raise
def get_daily_token_cost_statistics(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
timezone: str = "UTC",
) -> list[DailyTokenCostStats]:
"""
Get daily token cost statistics using optimized query.
Optimization: Use finished_at IS NOT NULL + SUM(total_tokens) (20-100x faster)
"""
logger.debug(
"get_daily_token_cost_statistics: tenant_id=%s, app_id=%s, triggered_from=%s",
tenant_id,
app_id,
triggered_from,
)
# Build time range filter
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
if end_date:
time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
sql = f"""
SELECT DATE(from_unixtime(__time__)) as date, SUM(total_tokens) as token_count
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date
ORDER BY date
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
response_data = []
for row in results:
response_data.append({"date": str(row.get("date", "")), "token_count": int(row.get("token_count", 0))})
return cast(list[DailyTokenCostStats], response_data)
except Exception:
logger.exception("Failed to get daily token cost statistics")
raise
def get_average_app_interaction_statistics(
self,
tenant_id: str,
app_id: str,
triggered_from: str,
start_date: datetime | None = None,
end_date: datetime | None = None,
timezone: str = "UTC",
) -> list[AverageInteractionStats]:
"""
Get average app interaction statistics using optimized query.
Optimization: Use finished_at IS NOT NULL + AVG (20-100x faster)
"""
logger.debug(
"get_average_app_interaction_statistics: tenant_id=%s, app_id=%s, triggered_from=%s",
tenant_id,
app_id,
triggered_from,
)
# Build time range filter
time_filter = ""
if start_date:
time_filter += f" AND __time__ >= to_unixtime(from_iso8601_timestamp('{start_date.isoformat()}'))"
if end_date:
time_filter += f" AND __time__ < to_unixtime(from_iso8601_timestamp('{end_date.isoformat()}'))"
sql = f"""
SELECT
AVG(sub.interactions) AS interactions,
sub.date
FROM (
SELECT
DATE(from_unixtime(__time__)) AS date,
created_by,
COUNT(DISTINCT id) AS interactions
FROM {AliyunLogStore.workflow_execution_logstore}
WHERE tenant_id='{tenant_id}'
AND app_id='{app_id}'
AND triggered_from='{triggered_from}'
AND finished_at IS NOT NULL
{time_filter}
GROUP BY date, created_by
) sub
GROUP BY sub.date
"""
try:
results = self.logstore_client.execute_sql(
sql=sql, query="*", logstore=AliyunLogStore.workflow_execution_logstore
)
response_data = []
for row in results:
response_data.append(
{
"date": str(row.get("date", "")),
"interactions": float(row.get("interactions", 0)),
}
)
return cast(list[AverageInteractionStats], response_data)
except Exception:
logger.exception("Failed to get average app interaction statistics")
raise
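`get_paginated_workflow_runs` above fetches `limit + 1` rows so it can report whether another page exists without a separate COUNT query. That pattern in isolation (names are illustrative, not part of the repository API):

```python
def paginate(rows: list, limit: int) -> tuple[list, bool]:
    # Caller fetched limit + 1 rows; trim the extra one and report has_more
    has_more = len(rows) > limit
    return (rows[:limit] if has_more else rows), has_more

page, has_more = paginate(list(range(25)), 20)
print(len(page), has_more)  # 20 True
```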

View File

@@ -0,0 +1,164 @@
import json
import logging
import os
import time
from typing import Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.entities import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
Account,
CreatorUserRole,
EndUser,
)
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: str | None,
triggered_from: WorkflowRunTriggeredFrom | None,
):
"""
Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
"""
logger.debug(
"LogstoreWorkflowExecutionRepository.__init__: app_id=%s, triggered_from=%s", app_id, triggered_from
)
# Initialize LogStore client
# Note: Project/logstore/index initialization is done at app startup via ext_logstore
self.logstore_client = AliyunLogStore()
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Initialize SQL repository for dual-write support
self.sql_repository = SQLAlchemyWorkflowExecutionRepository(session_factory, user, app_id, triggered_from)
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]:
"""
Convert a domain model to a logstore model (List[Tuple[str, str]]).
Args:
domain_model: The domain model to convert
Returns:
The logstore model as a list of key-value tuples
"""
logger.debug(
"_to_logstore_model: id=%s, workflow_id=%s, status=%s",
domain_model.id_,
domain_model.workflow_id,
domain_model.status.value,
)
# Use values from constructor if provided
if not self._triggered_from:
raise ValueError("triggered_from is required in repository constructor")
if not self._creator_user_id:
raise ValueError("created_by is required in repository constructor")
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
logstore_model = [
("id", domain_model.id_),
("log_version", log_version), # Add log_version field for append-only writes
("tenant_id", self._tenant_id),
("app_id", self._app_id or ""),
("workflow_id", domain_model.workflow_id),
(
"triggered_from",
self._triggered_from.value if hasattr(self._triggered_from, "value") else str(self._triggered_from),
),
("type", domain_model.workflow_type.value),
("version", domain_model.workflow_version),
("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"),
("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"),
("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"),
("status", domain_model.status.value),
("error_message", domain_model.error_message or ""),
("total_tokens", str(domain_model.total_tokens)),
("total_steps", str(domain_model.total_steps)),
("exceptions_count", str(domain_model.exceptions_count)),
(
"created_by_role",
self._creator_user_role.value
if hasattr(self._creator_user_role, "value")
else str(self._creator_user_role),
),
("created_by", self._creator_user_id),
("started_at", domain_model.started_at.isoformat() if domain_model.started_at else ""),
("finished_at", domain_model.finished_at.isoformat() if domain_model.finished_at else ""),
]
return logstore_model
def save(self, execution: WorkflowExecution) -> None:
"""
Save or update a WorkflowExecution domain entity to the logstore.
This method serves as a domain-to-logstore adapter that:
1. Converts the domain entity to its logstore representation
2. Persists the logstore model using Aliyun SLS
3. Maintains proper multi-tenancy by including tenant context during conversion
4. Optionally writes to SQL database for dual-write support (controlled by LOGSTORE_DUAL_WRITE_ENABLED)
Args:
execution: The WorkflowExecution domain entity to persist
"""
logger.debug(
"save: id=%s, workflow_id=%s, status=%s", execution.id_, execution.workflow_id, execution.status.value
)
try:
logstore_model = self._to_logstore_model(execution)
self.logstore_client.put_log(AliyunLogStore.workflow_execution_logstore, logstore_model)
logger.debug("Saved workflow execution to logstore: id=%s", execution.id_)
except Exception:
logger.exception("Failed to save workflow execution to logstore: id=%s", execution.id_)
raise
# Dual-write to SQL database if enabled (for safe migration)
if self._enable_dual_write:
try:
self.sql_repository.save(execution)
logger.debug("Dual-write: saved workflow execution to SQL database: id=%s", execution.id_)
except Exception:
logger.exception("Failed to dual-write workflow execution to SQL database: id=%s", execution.id_)
# Don't raise - LogStore write succeeded, SQL is just a backup
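`save()` above never updates in place: every call appends a new record stamped with a `log_version`, and readers keep only the highest version per id. A toy in-memory sketch of that append-only strategy (the store and a simple counter stand in for SLS and `time.time_ns()`; this is not the SLS API):

```python
import itertools

_version = itertools.count(1)  # the real code uses time.time_ns() for versions
store: list[dict[str, str]] = []  # append-only log

def put_log(fields: dict[str, str]) -> None:
    # Append a new record; never mutate existing ones
    store.append({**fields, "log_version": str(next(_version))})

def latest(run_id: str) -> dict[str, str]:
    # Readers deduplicate by taking the max log_version per id
    rows = [r for r in store if r["id"] == run_id]
    return max(rows, key=lambda r: int(r["log_version"]))

put_log({"id": "run-1", "status": "running"})
put_log({"id": "run-1", "status": "succeeded"})
print(latest("run-1")["status"])  # succeeded
```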

View File

@@ -0,0 +1,366 @@
"""
LogStore implementation of the WorkflowNodeExecutionRepository.
This module provides a LogStore-based repository for WorkflowNodeExecution entities,
using Aliyun SLS LogStore with append-only writes and version control.
"""
import json
import logging
import os
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any, Union
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.logstore.aliyun_logstore import AliyunLogStore
from libs.helper import extract_tenant_id
from models import (
Account,
CreatorUserRole,
EndUser,
WorkflowNodeExecutionTriggeredFrom,
)
logger = logging.getLogger(__name__)
def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecution:
"""
Convert LogStore result dictionary to WorkflowNodeExecution domain model.
Args:
data: Dictionary from LogStore query result
Returns:
WorkflowNodeExecution domain model instance
"""
logger.debug("_dict_to_workflow_node_execution: data keys=%s", list(data.keys())[:5])
# Parse JSON fields
# Use `or "{}"` so empty-string fields (not just missing keys) don't break json.loads
inputs = json.loads(data.get("inputs") or "{}")
process_data = json.loads(data.get("process_data") or "{}")
outputs = json.loads(data.get("outputs") or "{}")
metadata = json.loads(data.get("execution_metadata") or "{}")
# Convert metadata to domain enum keys
domain_metadata = {}
for k, v in metadata.items():
try:
domain_metadata[WorkflowNodeExecutionMetadataKey(k)] = v
except ValueError:
# Skip invalid metadata keys
continue
# Convert status to domain enum
status = WorkflowNodeExecutionStatus(data.get("status", "running"))
# Parse datetime fields
created_at = datetime.fromisoformat(data.get("created_at", "")) if data.get("created_at") else datetime.now()
finished_at = datetime.fromisoformat(data.get("finished_at", "")) if data.get("finished_at") else None
return WorkflowNodeExecution(
id=data.get("id", ""),
node_execution_id=data.get("node_execution_id"),
workflow_id=data.get("workflow_id", ""),
workflow_execution_id=data.get("workflow_run_id"),
index=int(data.get("index", 0)),
predecessor_node_id=data.get("predecessor_node_id"),
node_id=data.get("node_id", ""),
node_type=NodeType(data.get("node_type", "start")),
title=data.get("title", ""),
inputs=inputs,
process_data=process_data,
outputs=outputs,
status=status,
error=data.get("error"),
elapsed_time=float(data.get("elapsed_time", 0.0)),
metadata=domain_metadata,
created_at=created_at,
finished_at=finished_at,
)
class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
"""
LogStore implementation of the WorkflowNodeExecutionRepository interface.
This implementation uses Aliyun SLS LogStore with an append-only write strategy:
- Each save() operation appends a new record with a version timestamp
- Updates are simulated by writing new records with higher version numbers
- Queries retrieve the latest version by log_version, using a finished_at IS NOT NULL filter as an optimization
- Multi-tenancy is maintained through tenant_id filtering
Version Strategy:
version = time.time_ns() # Nanosecond timestamp for unique ordering
"""
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: str | None,
triggered_from: WorkflowNodeExecutionTriggeredFrom | None,
):
"""
Initialize the repository with a SQLAlchemy sessionmaker or engine and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
"""
logger.debug(
"LogstoreWorkflowNodeExecutionRepository.__init__: app_id=%s, triggered_from=%s", app_id, triggered_from
)
# Initialize LogStore client
self.logstore_client = AliyunLogStore()
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Initialize SQL repository for dual-write support
self.sql_repository = SQLAlchemyWorkflowNodeExecutionRepository(session_factory, user, app_id, triggered_from)
# Control flag for dual-write (write to both LogStore and SQL database)
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowNodeExecution) -> Sequence[tuple[str, str]]:
logger.debug(
"_to_logstore_model: id=%s, node_id=%s, status=%s",
domain_model.id,
domain_model.node_id,
domain_model.status.value,
)
if not self._triggered_from:
raise ValueError("triggered_from is required in repository constructor")
if not self._creator_user_id:
raise ValueError("created_by is required in repository constructor")
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
# Generate log_version as nanosecond timestamp for record versioning
log_version = str(time.time_ns())
json_converter = WorkflowRuntimeTypeConverter()
logstore_model = [
("id", domain_model.id),
("log_version", log_version), # Add log_version field for append-only writes
("tenant_id", self._tenant_id),
("app_id", self._app_id or ""),
("workflow_id", domain_model.workflow_id),
(
"triggered_from",
self._triggered_from.value if hasattr(self._triggered_from, "value") else str(self._triggered_from),
),
("workflow_run_id", domain_model.workflow_execution_id or ""),
("index", str(domain_model.index)),
("predecessor_node_id", domain_model.predecessor_node_id or ""),
("node_execution_id", domain_model.node_execution_id or ""),
("node_id", domain_model.node_id),
("node_type", domain_model.node_type.value),
("title", domain_model.title),
(
"inputs",
json.dumps(json_converter.to_json_encodable(domain_model.inputs), ensure_ascii=False)
if domain_model.inputs
else "{}",
),
(
"process_data",
json.dumps(json_converter.to_json_encodable(domain_model.process_data), ensure_ascii=False)
if domain_model.process_data
else "{}",
),
(
"outputs",
json.dumps(json_converter.to_json_encodable(domain_model.outputs), ensure_ascii=False)
if domain_model.outputs
else "{}",
),
("status", domain_model.status.value),
("error", domain_model.error or ""),
("elapsed_time", str(domain_model.elapsed_time)),
(
"execution_metadata",
json.dumps(jsonable_encoder(domain_model.metadata), ensure_ascii=False)
if domain_model.metadata
else "{}",
),
("created_at", domain_model.created_at.isoformat() if domain_model.created_at else ""),
("created_by_role", self._creator_user_role.value),
("created_by", self._creator_user_id),
("finished_at", domain_model.finished_at.isoformat() if domain_model.finished_at else ""),
]
return logstore_model
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a NodeExecution domain entity to LogStore.
This method serves as a domain-to-logstore adapter that:
1. Converts the domain entity to its logstore representation
2. Appends a new record with a log_version timestamp
3. Maintains proper multi-tenancy by including tenant context during conversion
4. Optionally writes to SQL database for dual-write support (controlled by LOGSTORE_DUAL_WRITE_ENABLED)
Each save operation creates a new record. Updates are simulated by writing
new records with higher log_version numbers.
Args:
execution: The NodeExecution domain entity to persist
"""
logger.debug(
"save: id=%s, node_execution_id=%s, status=%s",
execution.id,
execution.node_execution_id,
execution.status.value,
)
try:
logstore_model = self._to_logstore_model(execution)
self.logstore_client.put_log(AliyunLogStore.workflow_node_execution_logstore, logstore_model)
logger.debug(
"Saved node execution to LogStore: id=%s, node_execution_id=%s, status=%s",
execution.id,
execution.node_execution_id,
execution.status.value,
)
except Exception:
logger.exception(
"Failed to save node execution to LogStore: id=%s, node_execution_id=%s",
execution.id,
execution.node_execution_id,
)
raise
# Dual-write to SQL database if enabled (for safe migration)
if self._enable_dual_write:
try:
self.sql_repository.save(execution)
logger.debug("Dual-write: saved node execution to SQL database: id=%s", execution.id)
except Exception:
logger.exception("Failed to dual-write node execution to SQL database: id=%s", execution.id)
# Don't raise - LogStore write succeeded, SQL is just a backup
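The primary-then-best-effort-secondary pattern used above can be sketched in isolation (hypothetical `primary`/`secondary` callables standing in for the LogStore and SQL repositories):

```python
import logging

logger = logging.getLogger(__name__)

def dual_write(primary, secondary, record, dual_enabled: bool = True) -> None:
    # Primary write failures propagate: the caller must know the save failed
    primary(record)
    if dual_enabled:
        try:
            secondary(record)
        except Exception:
            # The secondary store is a best-effort backup; never mask a
            # successful primary write with a secondary failure
            logger.exception("secondary write failed for %r", record)
```

A secondary failure is logged and swallowed; a primary failure still raises, matching the behavior in `save()`.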
def save_execution_data(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update the inputs, process_data, or outputs associated with a specific
node_execution record.
For LogStore implementation, this is similar to save() since we always write
complete records. We append a new record with updated data fields.
Args:
execution: The NodeExecution instance with data to save
"""
logger.debug("save_execution_data: id=%s, node_execution_id=%s", execution.id, execution.node_execution_id)
# In LogStore, we simply write a new complete record with the data
# The log_version timestamp will ensure this is treated as the latest version
self.save(execution)
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: OrderConfig | None = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all NodeExecution instances for a specific workflow run.
Uses a LogStore SQL query with a finished_at IS NOT NULL filter for deduplication.
This ensures we only get the final version of each node execution.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
order_config.order_direction: Direction to order ("asc" or "desc")
Returns:
A list of NodeExecution instances
Note:
This method filters by finished_at IS NOT NULL to avoid duplicates from
version updates. For complete history including intermediate states,
a different query strategy would be needed.
"""
logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config)
# Build SQL query with deduplication using finished_at IS NOT NULL
# This optimization avoids window functions for the common case where we only
# want the final state of each node execution
# Build ORDER BY clause
order_clause = ""
if order_config and order_config.order_by:
order_fields = []
for field in order_config.order_by:
# Map domain field names to logstore field names if needed
field_name = field
if order_config.order_direction == "desc":
order_fields.append(f"{field_name} DESC")
else:
order_fields.append(f"{field_name} ASC")
if order_fields:
order_clause = "ORDER BY " + ", ".join(order_fields)
sql = f"""
SELECT *
FROM {AliyunLogStore.workflow_node_execution_logstore}
WHERE workflow_run_id='{workflow_run_id}'
AND tenant_id='{self._tenant_id}'
AND finished_at IS NOT NULL
"""
if self._app_id:
sql += f" AND app_id='{self._app_id}'"
if order_clause:
sql += f" {order_clause}"
try:
# Execute SQL query
results = self.logstore_client.execute_sql(
sql=sql,
query="*",
logstore=AliyunLogStore.workflow_node_execution_logstore,
)
# Convert LogStore results to WorkflowNodeExecution domain models
executions = []
for row in results:
try:
execution = _dict_to_workflow_node_execution(row)
executions.append(execution)
except Exception as e:
logger.warning("Failed to convert row to WorkflowNodeExecution: %s, row=%s", e, row)
continue
return executions
except Exception:
logger.exception("Failed to retrieve node executions from LogStore: workflow_run_id=%s", workflow_run_id)
raise

View File

@ -1,5 +1,4 @@
import functools
import os
from collections.abc import Callable
from typing import Any, TypeVar, cast
@ -7,22 +6,13 @@ from opentelemetry.trace import get_tracer
from configs import dify_config
from extensions.otel.decorators.handler import SpanHandler
from extensions.otel.runtime import is_instrument_flag_enabled
T = TypeVar("T", bound=Callable[..., Any])
_HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}
def _is_instrument_flag_enabled() -> bool:
"""
Check if external instrumentation is enabled via environment variable.
Third-party non-invasive instrumentation agents set this flag to coordinate
with Dify's manual OpenTelemetry instrumentation.
"""
return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true"
def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler:
"""Get or create a singleton instance of the handler class."""
if handler_class not in _HANDLER_INSTANCES:
@ -43,7 +33,7 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T],
def decorator(func: T) -> T:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if not (dify_config.ENABLE_OTEL or _is_instrument_flag_enabled()):
if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
return func(*args, **kwargs)
handler = _get_handler_instance(handler_class or SpanHandler)

View File

@ -1,4 +1,5 @@
import logging
import os
import sys
from typing import Union
@ -71,3 +72,13 @@ def init_celery_worker(*args, **kwargs):
if dify_config.DEBUG:
logger.info("Initializing OpenTelemetry for Celery worker")
CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument()
def is_instrument_flag_enabled() -> bool:
"""
Check if external instrumentation is enabled via environment variable.
Third-party non-invasive instrumentation agents set this flag to coordinate
with Dify's manual OpenTelemetry instrumentation.
"""
return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true"
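A minimal standalone sketch of the flag check above, showing its exact semantics (only the literal string "true", case-insensitively and ignoring surrounding whitespace, enables it — values like "1" do not):

```python
import os

def is_instrument_flag_enabled() -> bool:
    # Only "true" (case-insensitive, trimmed) counts as enabled
    return os.getenv("ENABLE_OTEL_FOR_INSTRUMENT", "").strip().lower() == "true"
```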

api/libs/encryption.py Normal file
View File

@ -0,0 +1,66 @@
"""
Field Encoding/Decoding Utilities
Provides Base64 decoding for sensitive fields (password, verification code)
received from the frontend.
Note: This uses Base64 encoding for obfuscation, not cryptographic encryption.
Real security relies on HTTPS for transport layer encryption.
"""
import base64
import logging
logger = logging.getLogger(__name__)
class FieldEncryption:
"""Handle decoding of sensitive fields during transmission"""
@classmethod
def decrypt_field(cls, encoded_text: str) -> str | None:
"""
Decode Base64 encoded field from frontend.
Args:
encoded_text: Base64 encoded text from frontend
Returns:
Decoded plaintext, or None if decoding fails
"""
try:
# Decode base64
decoded_bytes = base64.b64decode(encoded_text)
decoded_text = decoded_bytes.decode("utf-8")
logger.debug("Field decoding successful")
return decoded_text
except Exception:
# Decoding failed - return None to trigger error in caller
return None
@classmethod
def decrypt_password(cls, encrypted_password: str) -> str | None:
"""
Decode the Base64-obfuscated password field.
Args:
encrypted_password: Base64 encoded password from frontend
Returns:
Decoded password, or None if decoding fails
"""
return cls.decrypt_field(encrypted_password)
@classmethod
def decrypt_verification_code(cls, encrypted_code: str) -> str | None:
"""
Decode the Base64-obfuscated verification code field.
Args:
encrypted_code: Base64 encoded code from frontend
Returns:
Decoded code, or None if decoding fails
"""
return cls.decrypt_field(encrypted_code)

View File

@ -131,28 +131,12 @@ class ExternalApi(Api):
}
def __init__(self, app: Blueprint | Flask, *args, **kwargs):
import logging
import os
kwargs.setdefault("authorizations", self._authorizations)
kwargs.setdefault("security", "Bearer")
# Security: Use computed swagger_ui_enabled which respects DEPLOY_ENV
swagger_enabled = dify_config.swagger_ui_enabled
kwargs["add_specs"] = swagger_enabled
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if swagger_enabled else False
kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
# Construct and call init_app separately so the configs passed in kwargs take effect
super().__init__(app=None, *args, **kwargs)
self.init_app(app, **kwargs)
register_external_error_handlers(self)
# Security: Log warning when Swagger is enabled in production environment
deploy_env = os.environ.get("DEPLOY_ENV", "PRODUCTION")
if swagger_enabled and deploy_env.upper() == "PRODUCTION":
logger = logging.getLogger(__name__)
logger.warning(
"SECURITY WARNING: Swagger UI is ENABLED in PRODUCTION environment. "
"This may expose sensitive API documentation. "
"Set SWAGGER_UI_ENABLED=false or remove the explicit setting to disable."
)

View File

@ -184,7 +184,7 @@ def timezone(timezone_string):
def convert_datetime_to_date(field, target_timezone: str = ":tz"):
if dify_config.DB_TYPE == "postgresql":
return f"DATE(DATE_TRUNC('day', {field} AT TIME ZONE 'UTC' AT TIME ZONE {target_timezone}))"
elif dify_config.DB_TYPE == "mysql":
elif dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
return f"DATE(CONVERT_TZ({field}, 'UTC', {target_timezone}))"
else:
raise NotImplementedError(f"Unsupported database type: {dify_config.DB_TYPE}")

View File

@ -0,0 +1,31 @@
"""add type column not null default tool
Revision ID: 03ea244985ce
Revises: d57accd375ae
Create Date: 2025-12-16 18:17:12.193877
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '03ea244985ce'
down_revision = 'd57accd375ae'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op:
batch_op.add_column(sa.Column('type', sa.String(length=50), server_default=sa.text("'tool'"), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_recommended_plugins', schema=None) as batch_op:
batch_op.drop_column('type')
# ### end Alembic commands ###

View File

@ -1532,6 +1532,7 @@ class PipelineRecommendedPlugin(TypeBase):
)
plugin_id: Mapped[str] = mapped_column(LongText, nullable=False)
provider_name: Mapped[str] = mapped_column(LongText, nullable=False)
type: Mapped[str] = mapped_column(sa.String(50), nullable=False, server_default=sa.text("'tool'"))
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
created_at: Mapped[datetime] = mapped_column(

View File

@ -4,6 +4,7 @@ version = "1.11.1"
requires-python = ">=3.11,<3.13"
dependencies = [
"aliyun-log-python-sdk~=0.9.37",
"arize-phoenix-otel~=0.9.2",
"azure-identity==1.16.1",
"beautifulsoup4==4.12.2",
@ -11,7 +12,7 @@ dependencies = [
"bs4~=0.0.1",
"cachetools~=5.3.0",
"celery~=5.5.2",
"charset-normalizer>=3.4.4",
"chardet~=5.1.0",
"flask~=3.1.2",
"flask-compress>=1.17,<1.18",
"flask-cors~=6.0.0",
@ -91,7 +92,6 @@ dependencies = [
"weaviate-client==4.17.0",
"apscheduler>=3.11.0",
"weave>=0.52.16",
"jsonschema>=4.25.1",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.

View File

@ -1248,14 +1248,13 @@ class RagPipelineService:
session.commit()
return workflow_node_execution_db_model
def get_recommended_plugins(self) -> dict:
def get_recommended_plugins(self, type: str) -> dict:
# Query active recommended plugins
pipeline_recommended_plugins = (
db.session.query(PipelineRecommendedPlugin)
.where(PipelineRecommendedPlugin.active == True)
.order_by(PipelineRecommendedPlugin.position.asc())
.all()
)
query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
if type and type != "all":
query = query.where(PipelineRecommendedPlugin.type == type)
pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
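The filter semantics above — an empty string or "all" means no filtering, anything else matches exactly — can be illustrated on plain dicts (a sketch, not the SQLAlchemy query):

```python
def filter_by_type(plugins: list[dict], type_: str) -> list[dict]:
    # "" and "all" are pass-through sentinels; other values filter exactly
    if type_ and type_ != "all":
        return [p for p in plugins if p["type"] == type_]
    return plugins
```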
if not pipeline_recommended_plugins:
return {

View File

@ -33,6 +33,11 @@ from services.errors.app import QuotaExceededError
from services.trigger.app_trigger_service import AppTriggerService
from services.workflow.entities import WebhookTriggerData
try:
import magic
except ImportError:
magic = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
@ -317,7 +322,8 @@ class WebhookService:
try:
file_content = request.get_data()
if file_content:
file_obj = cls._create_file_from_binary(file_content, "application/octet-stream", webhook_trigger)
mimetype = cls._detect_binary_mimetype(file_content)
file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger)
return {"raw": file_obj.to_dict()}, {}
else:
return {"raw": None}, {}
@ -341,6 +347,18 @@ class WebhookService:
body = {"raw": ""}
return body, {}
@staticmethod
def _detect_binary_mimetype(file_content: bytes) -> str:
"""Guess MIME type for binary payloads using python-magic when available."""
if magic is not None:
try:
detected = magic.from_buffer(file_content[:1024], mime=True)
if detected:
return detected
except Exception:
logger.debug("python-magic detection failed for octet-stream payload")
return "application/octet-stream"
@classmethod
def _process_file_uploads(
cls, files: Mapping[str, FileStorage], webhook_trigger: WorkflowWebhookTrigger

View File

@ -410,9 +410,12 @@ class VariableTruncator(BaseTruncator):
@overload
def _truncate_json_primitives(self, val: None, target_size: int) -> _PartResult[None]: ...
@overload
def _truncate_json_primitives(self, val: File, target_size: int) -> _PartResult[File]: ...
def _truncate_json_primitives(
self,
val: UpdatedVariable | str | list[object] | dict[str, object] | bool | int | float | None,
val: UpdatedVariable | File | str | list[object] | dict[str, object] | bool | int | float | None,
target_size: int,
) -> _PartResult[Any]:
"""Truncate a value within an object to fit within budget."""
@ -425,6 +428,9 @@ class VariableTruncator(BaseTruncator):
return self._truncate_array(val, target_size)
elif isinstance(val, dict):
return self._truncate_object(val, target_size)
elif isinstance(val, File):
# File objects should not be truncated, return as-is
return _PartResult(val, self.calculate_json_size(val), False)
elif val is None or isinstance(val, (bool, int, float)):
return _PartResult(val, self.calculate_json_size(val), False)
else:

View File

@ -113,16 +113,31 @@ class TestShardedRedisBroadcastChannelIntegration:
topic = broadcast_channel.topic(topic_name)
producer = topic.as_producer()
subscriptions = [topic.subscribe() for _ in range(subscriber_count)]
ready_events = [threading.Event() for _ in range(subscriber_count)]
def producer_thread():
time.sleep(0.2) # Allow all subscribers to connect
deadline = time.time() + 5.0
for ev in ready_events:
remaining = deadline - time.time()
if remaining <= 0:
break
if not ev.wait(timeout=max(0.0, remaining)):
pytest.fail("subscriber did not become ready before publish deadline")
producer.publish(message)
time.sleep(0.2)
for sub in subscriptions:
sub.close()
def consumer_thread(subscription: Subscription) -> list[bytes]:
def consumer_thread(subscription: Subscription, ready_event: threading.Event) -> list[bytes]:
received_msgs = []
# Prime subscription so the underlying Pub/Sub listener thread starts before publishing
try:
_ = subscription.receive(0.01)
except SubscriptionClosedError:
return received_msgs
finally:
ready_event.set()
while True:
try:
msg = subscription.receive(0.1)
@ -137,7 +152,10 @@ class TestShardedRedisBroadcastChannelIntegration:
with ThreadPoolExecutor(max_workers=subscriber_count + 1) as executor:
producer_future = executor.submit(producer_thread)
consumer_futures = [executor.submit(consumer_thread, subscription) for subscription in subscriptions]
consumer_futures = [
executor.submit(consumer_thread, subscription, ready_events[idx])
for idx, subscription in enumerate(subscriptions)
]
producer_future.result(timeout=10.0)
msgs_by_consumers = []
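The producer's deadline loop amounts to waiting for all readiness events under one shared timeout. Sketched standalone (using `time.monotonic()` rather than the wall clock, which is a deliberate difference from the test above):

```python
import threading
import time

def wait_for_all(events: list[threading.Event], timeout: float) -> bool:
    # One shared deadline across all events, so the total wait is bounded
    # by `timeout` no matter how many events there are
    deadline = time.monotonic() + timeout
    for ev in events:
        remaining = deadline - time.monotonic()
        if remaining <= 0 or not ev.wait(timeout=remaining):
            return False
    return True
```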

View File

@ -1,5 +1,6 @@
"""Test authentication security to prevent user enumeration."""
import base64
from unittest.mock import MagicMock, patch
import pytest
@ -11,6 +12,11 @@ from controllers.console.auth.error import AuthenticationFailedError
from controllers.console.auth.login import LoginApi
def encode_password(password: str) -> str:
"""Helper to encode password as Base64 for testing."""
return base64.b64encode(password.encode("utf-8")).decode()
class TestAuthenticationSecurity:
"""Test authentication endpoints for security against user enumeration."""
@ -42,7 +48,9 @@ class TestAuthenticationSecurity:
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
"/login",
method="POST",
json={"email": "nonexistent@example.com", "password": encode_password("WrongPass123!")},
):
login_api = LoginApi()
@ -72,7 +80,9 @@ class TestAuthenticationSecurity:
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "existing@example.com", "password": "WrongPass123!"}
"/login",
method="POST",
json={"email": "existing@example.com", "password": encode_password("WrongPass123!")},
):
login_api = LoginApi()
@ -104,7 +114,9 @@ class TestAuthenticationSecurity:
# Act
with self.app.test_request_context(
"/login", method="POST", json={"email": "nonexistent@example.com", "password": "WrongPass123!"}
"/login",
method="POST",
json={"email": "nonexistent@example.com", "password": encode_password("WrongPass123!")},
):
login_api = LoginApi()

View File

@ -8,6 +8,7 @@ This module tests the email code login mechanism including:
- Workspace creation for new users
"""
import base64
from unittest.mock import MagicMock, patch
import pytest
@ -25,6 +26,11 @@ from controllers.console.error import (
from services.errors.account import AccountRegisterError
def encode_code(code: str) -> str:
"""Helper to encode verification code as Base64 for testing."""
return base64.b64encode(code.encode("utf-8")).decode()
class TestEmailCodeLoginSendEmailApi:
"""Test cases for sending email verification codes."""
@ -290,7 +296,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "valid_token"},
json={"email": "test@example.com", "code": encode_code("123456"), "token": "valid_token"},
):
api = EmailCodeLoginApi()
response = api.post()
@ -339,7 +345,12 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "newuser@example.com", "code": "123456", "token": "valid_token", "language": "en-US"},
json={
"email": "newuser@example.com",
"code": encode_code("123456"),
"token": "valid_token",
"language": "en-US",
},
):
api = EmailCodeLoginApi()
response = api.post()
@ -365,7 +376,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "invalid_token"},
json={"email": "test@example.com", "code": encode_code("123456"), "token": "invalid_token"},
):
api = EmailCodeLoginApi()
with pytest.raises(InvalidTokenError):
@ -388,7 +399,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "different@example.com", "code": "123456", "token": "token"},
json={"email": "different@example.com", "code": encode_code("123456"), "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(InvalidEmailError):
@ -411,7 +422,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "wrong_code", "token": "token"},
json={"email": "test@example.com", "code": encode_code("wrong_code"), "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(EmailCodeError):
@ -497,7 +508,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "token"},
json={"email": "test@example.com", "code": encode_code("123456"), "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(WorkspacesLimitExceeded):
@ -539,7 +550,7 @@ class TestEmailCodeLoginApi:
with app.test_request_context(
"/email-code-login/validity",
method="POST",
json={"email": "test@example.com", "code": "123456", "token": "token"},
json={"email": "test@example.com", "code": encode_code("123456"), "token": "token"},
):
api = EmailCodeLoginApi()
with pytest.raises(NotAllowedCreateWorkspace):

View File

@ -8,6 +8,7 @@ This module tests the core authentication endpoints including:
- Account status validation
"""
import base64
from unittest.mock import MagicMock, patch
import pytest
@ -28,6 +29,11 @@ from controllers.console.error import (
from services.errors.account import AccountLoginError, AccountPasswordError
def encode_password(password: str) -> str:
"""Helper to encode password as Base64 for testing."""
return base64.b64encode(password.encode("utf-8")).decode()
class TestLoginApi:
"""Test cases for the LoginApi endpoint."""
@ -106,7 +112,9 @@ class TestLoginApi:
# Act
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"}
"/login",
method="POST",
json={"email": "test@example.com", "password": encode_password("ValidPass123!")},
):
login_api = LoginApi()
response = login_api.post()
@ -158,7 +166,11 @@ class TestLoginApi:
with app.test_request_context(
"/login",
method="POST",
json={"email": "test@example.com", "password": "ValidPass123!", "invite_token": "valid_token"},
json={
"email": "test@example.com",
"password": encode_password("ValidPass123!"),
"invite_token": "valid_token",
},
):
login_api = LoginApi()
response = login_api.post()
@ -186,7 +198,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "password"}
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(EmailPasswordLoginLimitError):
@ -209,7 +221,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "frozen@example.com", "password": "password"}
"/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")}
):
login_api = LoginApi()
with pytest.raises(AccountInFreezeError):
@ -246,7 +258,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "WrongPass123!"}
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("WrongPass123!")}
):
login_api = LoginApi()
with pytest.raises(AuthenticationFailedError):
@ -277,7 +289,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "banned@example.com", "password": "ValidPass123!"}
"/login", method="POST", json={"email": "banned@example.com", "password": encode_password("ValidPass123!")}
):
login_api = LoginApi()
with pytest.raises(AccountBannedError):
@ -322,7 +334,7 @@ class TestLoginApi:
# Act & Assert
with app.test_request_context(
"/login", method="POST", json={"email": "test@example.com", "password": "ValidPass123!"}
"/login", method="POST", json={"email": "test@example.com", "password": encode_password("ValidPass123!")}
):
login_api = LoginApi()
with pytest.raises(WorkspacesLimitExceeded):
@ -349,7 +361,11 @@ class TestLoginApi:
with app.test_request_context(
"/login",
method="POST",
json={"email": "different@example.com", "password": "ValidPass123!", "invite_token": "token"},
json={
"email": "different@example.com",
"password": encode_password("ValidPass123!"),
"invite_token": "token",
},
):
login_api = LoginApi()
with pytest.raises(InvalidEmailError):

View File

@ -0,0 +1,10 @@
import tempfile
from core.rag.extractor.helpers import FileEncoding, detect_file_encodings
def test_detect_file_encodings() -> None:
with tempfile.NamedTemporaryFile(mode="w+t", suffix=".txt") as temp:
temp.write("Shared data")
temp.flush()  # ensure the data is on disk before detect_file_encodings reads the file
temp_path = temp.name
assert detect_file_encodings(temp_path) == [FileEncoding(encoding="utf_8", confidence=0.0, language="Unknown")]

View File

@ -0,0 +1,101 @@
"""
Shared fixtures for ObservabilityLayer tests.
"""
from unittest.mock import MagicMock, patch
import pytest
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.trace import set_tracer_provider
from core.workflow.enums import NodeType
@pytest.fixture
def memory_span_exporter():
"""Provide an in-memory span exporter for testing."""
return InMemorySpanExporter()
@pytest.fixture
def tracer_provider_with_memory_exporter(memory_span_exporter):
"""Provide a TracerProvider configured with memory exporter."""
import opentelemetry.trace as trace_api
trace_api._TRACER_PROVIDER = None
trace_api._TRACER_PROVIDER_SET_ONCE._done = False
provider = TracerProvider()
processor = SimpleSpanProcessor(memory_span_exporter)
provider.add_span_processor(processor)
set_tracer_provider(provider)
yield provider
provider.force_flush()
@pytest.fixture
def mock_start_node():
"""Create a mock Start Node."""
node = MagicMock()
node.id = "test-start-node-id"
node.title = "Start Node"
node.execution_id = "test-start-execution-id"
node.node_type = NodeType.START
return node
@pytest.fixture
def mock_llm_node():
"""Create a mock LLM Node."""
node = MagicMock()
node.id = "test-llm-node-id"
node.title = "LLM Node"
node.execution_id = "test-llm-execution-id"
node.node_type = NodeType.LLM
return node
@pytest.fixture
def mock_tool_node():
"""Create a mock Tool Node with tool-specific attributes."""
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.tool.entities import ToolNodeData
node = MagicMock()
node.id = "test-tool-node-id"
node.title = "Test Tool Node"
node.execution_id = "test-tool-execution-id"
node.node_type = NodeType.TOOL
tool_data = ToolNodeData(
title="Test Tool Node",
desc=None,
provider_id="test-provider-id",
provider_type=ToolProviderType.BUILT_IN,
provider_name="test-provider",
tool_name="test-tool",
tool_label="Test Tool",
tool_configurations={},
tool_parameters={},
)
node._node_data = tool_data
return node
@pytest.fixture
def mock_is_instrument_flag_enabled_false():
"""Mock is_instrument_flag_enabled to return False."""
with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=False):
yield
@pytest.fixture
def mock_is_instrument_flag_enabled_true():
"""Mock is_instrument_flag_enabled to return True."""
with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True):
yield

View File

@ -0,0 +1,219 @@
"""
Tests for ObservabilityLayer.
Test coverage:
- Initialization and enable/disable logic
- Node span lifecycle (start, end, error handling)
- Parser integration (default and tool-specific)
- Graph lifecycle management
- Disabled mode behavior
"""
from unittest.mock import patch
import pytest
from opentelemetry.trace import StatusCode
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.observability import ObservabilityLayer
class TestObservabilityLayerInitialization:
"""Test ObservabilityLayer initialization logic."""
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_initialization_when_otel_enabled(self, tracer_provider_with_memory_exporter):
"""Test that layer initializes correctly when OTel is enabled."""
layer = ObservabilityLayer()
assert not layer._is_disabled
assert layer._tracer is not None
assert NodeType.TOOL in layer._parsers
assert layer._default_parser is not None
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_true")
def test_initialization_when_instrument_flag_enabled(self, tracer_provider_with_memory_exporter):
"""Test that layer enables when instrument flag is enabled."""
layer = ObservabilityLayer()
assert not layer._is_disabled
assert layer._tracer is not None
assert NodeType.TOOL in layer._parsers
assert layer._default_parser is not None
class TestObservabilityLayerNodeSpanLifecycle:
"""Test node span creation and lifecycle management."""
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_node_span_created_and_ended(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
):
"""Test that span is created on node start and ended on node end."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_llm_node)
layer.on_node_run_end(mock_llm_node, None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
assert spans[0].name == mock_llm_node.title
assert spans[0].status.status_code == StatusCode.OK
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_node_error_recorded_in_span(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
):
"""Test that node execution errors are recorded in span."""
layer = ObservabilityLayer()
layer.on_graph_start()
error = ValueError("Test error")
layer.on_node_run_start(mock_llm_node)
layer.on_node_run_end(mock_llm_node, error)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
assert spans[0].status.status_code == StatusCode.ERROR
assert len(spans[0].events) > 0
assert any("exception" in event.name.lower() for event in spans[0].events)
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_node_end_without_start_handled_gracefully(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
):
"""Test that ending a node without start doesn't crash."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_end(mock_llm_node, None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 0
class TestObservabilityLayerParserIntegration:
"""Test parser integration for different node types."""
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_default_parser_used_for_regular_node(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node
):
"""Test that default parser is used for non-tool nodes."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_start_node)
layer.on_node_run_end(mock_start_node, None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
attrs = spans[0].attributes
assert attrs["node.id"] == mock_start_node.id
assert attrs["node.execution_id"] == mock_start_node.execution_id
assert attrs["node.type"] == mock_start_node.node_type.value
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_tool_parser_used_for_tool_node(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_tool_node
):
"""Test that tool parser is used for tool nodes."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_tool_node)
layer.on_node_run_end(mock_tool_node, None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
attrs = spans[0].attributes
assert attrs["node.id"] == mock_tool_node.id
assert attrs["tool.provider.id"] == mock_tool_node._node_data.provider_id
assert attrs["tool.provider.type"] == mock_tool_node._node_data.provider_type.value
assert attrs["tool.name"] == mock_tool_node._node_data.tool_name
class TestObservabilityLayerGraphLifecycle:
"""Test graph lifecycle management."""
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_on_graph_start_clears_contexts(self, tracer_provider_with_memory_exporter, mock_llm_node):
"""Test that on_graph_start clears node contexts."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_llm_node)
assert len(layer._node_contexts) == 1
layer.on_graph_start()
assert len(layer._node_contexts) == 0
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_on_graph_end_with_no_unfinished_spans(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node
):
"""Test that on_graph_end handles normal completion."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_llm_node)
layer.on_node_run_end(mock_llm_node, None)
layer.on_graph_end(None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_on_graph_end_with_unfinished_spans_logs_warning(
self, tracer_provider_with_memory_exporter, mock_llm_node, caplog
):
"""Test that on_graph_end logs warning for unfinished spans."""
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_llm_node)
assert len(layer._node_contexts) == 1
layer.on_graph_end(None)
assert len(layer._node_contexts) == 0
assert "node spans were not properly ended" in caplog.text
class TestObservabilityLayerDisabledMode:
"""Test behavior when layer is disabled."""
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_disabled_mode_skips_node_start(self, memory_span_exporter, mock_start_node):
"""Test that disabled layer doesn't create spans on node start."""
layer = ObservabilityLayer()
assert layer._is_disabled
layer.on_graph_start()
layer.on_node_run_start(mock_start_node)
layer.on_node_run_end(mock_start_node, None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 0
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", False)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_disabled_mode_skips_node_end(self, memory_span_exporter, mock_llm_node):
"""Test that disabled layer doesn't process node end."""
layer = ObservabilityLayer()
assert layer._is_disabled
layer.on_node_run_end(mock_llm_node, None)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 0

View File

@ -0,0 +1,150 @@
"""
Unit tests for field encoding/decoding utilities.
These tests verify Base64 encoding/decoding functionality,
error handling, and fallback behavior.
"""
import base64
from libs.encryption import FieldEncryption
class TestDecodeField:
"""Test cases for field decoding functionality."""
def test_decode_valid_base64(self):
"""Test decoding a valid Base64 encoded string."""
plaintext = "password123"
encoded = base64.b64encode(plaintext.encode("utf-8")).decode()
result = FieldEncryption.decrypt_field(encoded)
assert result == plaintext
def test_decode_non_base64_returns_none(self):
"""Test that non-base64 input returns None."""
non_base64 = "plain-password-!@#"
result = FieldEncryption.decrypt_field(non_base64)
# Should return None (decoding failed)
assert result is None
def test_decode_unicode_text(self):
"""Test decoding Base64 encoded Unicode text."""
plaintext = "密码Test123"
encoded = base64.b64encode(plaintext.encode("utf-8")).decode()
result = FieldEncryption.decrypt_field(encoded)
assert result == plaintext
def test_decode_empty_string(self):
"""Test decoding an empty string returns empty string."""
result = FieldEncryption.decrypt_field("")
# Empty string base64 decodes to empty string
assert result == ""
def test_decode_special_characters(self):
"""Test decoding with special characters."""
plaintext = "P@ssw0rd!#$%^&*()"
encoded = base64.b64encode(plaintext.encode("utf-8")).decode()
result = FieldEncryption.decrypt_field(encoded)
assert result == plaintext
class TestDecodePassword:
"""Test cases for password decoding."""
def test_decode_password_base64(self):
"""Test decoding a Base64 encoded password."""
password = "SecureP@ssw0rd!"
encoded = base64.b64encode(password.encode("utf-8")).decode()
result = FieldEncryption.decrypt_password(encoded)
assert result == password
def test_decode_password_invalid_returns_none(self):
"""Test that invalid base64 passwords return None."""
invalid = "PlainPassword!@#"
result = FieldEncryption.decrypt_password(invalid)
# Should return None (decoding failed)
assert result is None
class TestDecodeVerificationCode:
"""Test cases for verification code decoding."""
def test_decode_code_base64(self):
"""Test decoding a Base64 encoded verification code."""
code = "789012"
encoded = base64.b64encode(code.encode("utf-8")).decode()
result = FieldEncryption.decrypt_verification_code(encoded)
assert result == code
def test_decode_code_invalid_returns_none(self):
"""Test that invalid base64 codes return None."""
invalid = "123456" # Plain 6-digit code, not base64
result = FieldEncryption.decrypt_verification_code(invalid)
# Should return None (decoding failed)
assert result is None
class TestRoundTripEncodingDecoding:
"""
Integration tests for complete encoding-decoding cycle.
These tests simulate the full frontend-to-backend flow using Base64.
"""
def test_roundtrip_password(self):
"""Test encoding and decoding a password."""
original_password = "SecureP@ssw0rd!"
# Simulate frontend encoding (Base64)
encoded = base64.b64encode(original_password.encode("utf-8")).decode()
# Backend decoding
decoded = FieldEncryption.decrypt_password(encoded)
assert decoded == original_password
def test_roundtrip_verification_code(self):
"""Test encoding and decoding a verification code."""
original_code = "123456"
# Simulate frontend encoding
encoded = base64.b64encode(original_code.encode("utf-8")).decode()
# Backend decoding
decoded = FieldEncryption.decrypt_verification_code(encoded)
assert decoded == original_code
def test_roundtrip_unicode_password(self):
"""Test encoding and decoding password with Unicode characters."""
original_password = "密码Test123!@#"
# Frontend encoding
encoded = base64.b64encode(original_password.encode("utf-8")).decode()
# Backend decoding
decoded = FieldEncryption.decrypt_password(encoded)
assert decoded == original_password
def test_roundtrip_long_password(self):
"""Test encoding and decoding a long password."""
original_password = "ThisIsAVeryLongPasswordWithLotsOfCharacters123!@#$%^&*()"
encoded = base64.b64encode(original_password.encode("utf-8")).decode()
decoded = FieldEncryption.decrypt_password(encoded)
assert decoded == original_password
def test_roundtrip_with_whitespace(self):
"""Test encoding and decoding with whitespace."""
original_password = "pass word with spaces"
encoded = base64.b64encode(original_password.encode("utf-8")).decode()
decoded = FieldEncryption.decrypt_field(encoded)
assert decoded == original_password

View File

@ -518,6 +518,55 @@ class TestEdgeCases:
assert isinstance(result.result, StringSegment)
class TestTruncateJsonPrimitives:
"""Test _truncate_json_primitives method with different data types."""
@pytest.fixture
def truncator(self):
return VariableTruncator()
def test_truncate_json_primitives_file_type(self, truncator, file):
"""Test that File objects are handled correctly in _truncate_json_primitives."""
# Test File object is returned as-is without truncation
result = truncator._truncate_json_primitives(file, 1000)
assert result.value == file
assert result.truncated is False
# Size should be calculated correctly
expected_size = VariableTruncator.calculate_json_size(file)
assert result.value_size == expected_size
def test_truncate_json_primitives_file_type_small_budget(self, truncator, file):
"""Test that File objects are returned as-is even with small budget."""
# Even with a small size budget, File objects should not be truncated
result = truncator._truncate_json_primitives(file, 10)
assert result.value == file
assert result.truncated is False
def test_truncate_json_primitives_file_type_in_array(self, truncator, file):
"""Test File objects in arrays are handled correctly."""
array_with_files = [file, file]
result = truncator._truncate_json_primitives(array_with_files, 1000)
assert isinstance(result.value, list)
assert len(result.value) == 2
assert result.value[0] == file
assert result.value[1] == file
assert result.truncated is False
def test_truncate_json_primitives_file_type_in_object(self, truncator, file):
"""Test File objects in objects are handled correctly."""
obj_with_files = {"file1": file, "file2": file}
result = truncator._truncate_json_primitives(obj_with_files, 1000)
assert isinstance(result.value, dict)
assert len(result.value) == 2
assert result.value["file1"] == file
assert result.value["file2"] == file
assert result.truncated is False
class TestIntegrationScenarios:
"""Test realistic integration scenarios."""
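The pass-through guarantee the File tests above assert could be implemented along these lines; this is a hypothetical simplification of `_truncate_json_primitives`, returning a plain `(value, truncated)` tuple instead of the real result object:

```python
def truncate_json_primitives(value, budget: int):
    """Recursively truncate JSON-like values; opaque objects (e.g. File) pass through."""
    if isinstance(value, str):
        return value[:budget], len(value) > budget
    if isinstance(value, list):
        pairs = [truncate_json_primitives(item, budget) for item in value]
        return [v for v, _ in pairs], any(t for _, t in pairs)
    if isinstance(value, dict):
        pairs = {k: truncate_json_primitives(v, budget) for k, v in value.items()}
        return {k: v for k, (v, _) in pairs.items()}, any(t for _, t in pairs.values())
    # Numbers, booleans, None, and non-JSON objects such as File are returned as-is,
    # regardless of how small the budget is.
    return value, False
```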

View File

@ -110,6 +110,70 @@ class TestWebhookServiceUnit:
assert webhook_data["method"] == "POST"
assert webhook_data["body"]["raw"] == "raw text content"
def test_extract_octet_stream_body_uses_detected_mime(self):
"""Octet-stream uploads should rely on detected MIME type."""
app = Flask(__name__)
binary_content = b"plain text data"
with app.test_request_context(
"/webhook", method="POST", headers={"Content-Type": "application/octet-stream"}, data=binary_content
):
webhook_trigger = MagicMock()
mock_file = MagicMock()
mock_file.to_dict.return_value = {"file": "data"}
with (
patch.object(WebhookService, "_detect_binary_mimetype", return_value="text/plain") as mock_detect,
patch.object(WebhookService, "_create_file_from_binary") as mock_create,
):
mock_create.return_value = mock_file
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
assert body["raw"] == {"file": "data"}
assert files == {}
mock_detect.assert_called_once_with(binary_content)
mock_create.assert_called_once()
args = mock_create.call_args[0]
assert args[0] == binary_content
assert args[1] == "text/plain"
assert args[2] is webhook_trigger
def test_detect_binary_mimetype_uses_magic(self, monkeypatch):
"""python-magic output should be used when available."""
fake_magic = MagicMock()
fake_magic.from_buffer.return_value = "image/png"
monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic)
result = WebhookService._detect_binary_mimetype(b"binary data")
assert result == "image/png"
fake_magic.from_buffer.assert_called_once()
def test_detect_binary_mimetype_fallback_without_magic(self, monkeypatch):
"""Fallback MIME type should be used when python-magic is unavailable."""
monkeypatch.setattr("services.trigger.webhook_service.magic", None)
result = WebhookService._detect_binary_mimetype(b"binary data")
assert result == "application/octet-stream"
def test_detect_binary_mimetype_handles_magic_exception(self, monkeypatch):
"""Fallback MIME type should be used when python-magic raises an exception."""
try:
import magic as real_magic
except ImportError:
pytest.skip("python-magic is not installed")
fake_magic = MagicMock()
fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error")
monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic)
with patch("services.trigger.webhook_service.logger") as mock_logger:
result = WebhookService._detect_binary_mimetype(b"binary data")
assert result == "application/octet-stream"
mock_logger.debug.assert_called_once()
def test_extract_webhook_data_invalid_json(self):
"""Test webhook data extraction with invalid JSON."""
app = Flask(__name__)
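The detection-and-fallback behavior tested above can be sketched as follows; `magic` refers to the optional python-magic package, and the structure is an assumption about the service, not its actual code:

```python
try:
    import magic  # optional dependency: python-magic
except ImportError:
    magic = None

FALLBACK_MIMETYPE = "application/octet-stream"


def detect_binary_mimetype(data: bytes) -> str:
    """Best-effort MIME sniffing; fall back to octet-stream on any failure."""
    if magic is None:
        # python-magic not installed: keep the generic type from the request.
        return FALLBACK_MIMETYPE
    try:
        return magic.from_buffer(data, mime=True)
    except Exception:
        # python-magic errors (e.g. MagicException) also degrade to the generic type.
        return FALLBACK_MIMETYPE
```

This mirrors the three cases the tests cover: magic available, magic absent, and magic raising.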

api/uv.lock generated (4651 lines changed)

File diff suppressed because it is too large

View File

@ -1044,6 +1044,25 @@ WORKFLOW_LOG_RETENTION_DAYS=30
# Batch size for workflow log cleanup operations (default: 100)
WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# Aliyun SLS Logstore Configuration
# Aliyun Access Key ID
ALIYUN_SLS_ACCESS_KEY_ID=
# Aliyun Access Key Secret
ALIYUN_SLS_ACCESS_KEY_SECRET=
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
ALIYUN_SLS_ENDPOINT=
# Aliyun SLS Region (e.g., cn-hangzhou)
ALIYUN_SLS_REGION=
# Aliyun SLS Project Name
ALIYUN_SLS_PROJECT_NAME=
# Number of days to retain workflow run logs (default: 365; set to 3650 for permanent storage)
ALIYUN_SLS_LOGSTORE_TTL=365
# Enable dual-write to both SLS LogStore and SQL database (default: false)
LOGSTORE_DUAL_WRITE_ENABLED=false
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in the SQL database
LOGSTORE_DUAL_READ_ENABLED=true
# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
@ -1229,7 +1248,7 @@ NGINX_SSL_PORT=443
# and modify the env vars below accordingly.
NGINX_SSL_CERT_FILENAME=dify.crt
NGINX_SSL_CERT_KEY_FILENAME=dify.key
NGINX_SSL_PROTOCOLS=TLSv1.1 TLSv1.2 TLSv1.3
NGINX_SSL_PROTOCOLS=TLSv1.2 TLSv1.3
# Nginx performance tuning
NGINX_WORKER_PROCESSES=auto
@ -1421,7 +1440,7 @@ QUEUE_MONITOR_ALERT_EMAILS=
QUEUE_MONITOR_INTERVAL=30
# Swagger UI configuration
SWAGGER_UI_ENABLED=true
SWAGGER_UI_ENABLED=false
SWAGGER_UI_PATH=/swagger-ui.html
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
@ -1460,4 +1479,4 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
ANNOTATION_IMPORT_MAX_CONCURRENT=5
# The API key of amplitude
AMPLITUDE_API_KEY=
AMPLITUDE_API_KEY=
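The `LOGSTORE_DUAL_READ_ENABLED` fallback added earlier in this file can be sketched as follows (function and parameter names are hypothetical, not the repository's actual API):

```python
from collections.abc import Callable


def fetch_workflow_runs(
    query_logstore: Callable[[str], list],
    query_sql: Callable[[str], list],
    run_id: str,
    dual_read_enabled: bool = True,
) -> list:
    """Query the SLS LogStore first; fall back to SQL when it returns nothing."""
    results = query_logstore(run_id)
    if not results and dual_read_enabled:
        # Migration scenario: historical rows may exist only in the SQL database.
        results = query_sql(run_id)
    return results
```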

View File

@ -414,7 +414,7 @@ services:
# and modify the env vars below in .env if HTTPS_ENABLED is true.
NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt}
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}

View File

@ -455,6 +455,14 @@ x-shared-env: &shared-api-worker-env
WORKFLOW_LOG_CLEANUP_ENABLED: ${WORKFLOW_LOG_CLEANUP_ENABLED:-false}
WORKFLOW_LOG_RETENTION_DAYS: ${WORKFLOW_LOG_RETENTION_DAYS:-30}
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: ${WORKFLOW_LOG_CLEANUP_BATCH_SIZE:-100}
ALIYUN_SLS_ACCESS_KEY_ID: ${ALIYUN_SLS_ACCESS_KEY_ID:-}
ALIYUN_SLS_ACCESS_KEY_SECRET: ${ALIYUN_SLS_ACCESS_KEY_SECRET:-}
ALIYUN_SLS_ENDPOINT: ${ALIYUN_SLS_ENDPOINT:-}
ALIYUN_SLS_REGION: ${ALIYUN_SLS_REGION:-}
ALIYUN_SLS_PROJECT_NAME: ${ALIYUN_SLS_PROJECT_NAME:-}
ALIYUN_SLS_LOGSTORE_TTL: ${ALIYUN_SLS_LOGSTORE_TTL:-365}
LOGSTORE_DUAL_WRITE_ENABLED: ${LOGSTORE_DUAL_WRITE_ENABLED:-false}
LOGSTORE_DUAL_READ_ENABLED: ${LOGSTORE_DUAL_READ_ENABLED:-true}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
@ -528,7 +536,7 @@ x-shared-env: &shared-api-worker-env
NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443}
NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt}
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
@ -631,7 +639,7 @@ x-shared-env: &shared-api-worker-env
QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200}
QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-}
QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30}
SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-true}
SWAGGER_UI_ENABLED: ${SWAGGER_UI_ENABLED:-false}
SWAGGER_UI_PATH: ${SWAGGER_UI_PATH:-/swagger-ui.html}
DSL_EXPORT_ENCRYPT_DATASET_ID: ${DSL_EXPORT_ENCRYPT_DATASET_ID:-true}
DATASET_MAX_SEGMENTS_PER_REQUEST: ${DATASET_MAX_SEGMENTS_PER_REQUEST:-0}
@ -1071,7 +1079,7 @@ services:
# and modify the env vars below in .env if HTTPS_ENABLED is true.
NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt}
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}

View File

@ -213,3 +213,24 @@ PLUGIN_VOLCENGINE_TOS_ENDPOINT=
PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
PLUGIN_VOLCENGINE_TOS_REGION=
# ------------------------------
# Environment Variables for Aliyun SLS (Simple Log Service)
# ------------------------------
# Aliyun SLS Access Key ID
ALIYUN_SLS_ACCESS_KEY_ID=
# Aliyun SLS Access Key Secret
ALIYUN_SLS_ACCESS_KEY_SECRET=
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
ALIYUN_SLS_ENDPOINT=
# Aliyun SLS Region (e.g., cn-hangzhou)
ALIYUN_SLS_REGION=
# Aliyun SLS Project Name
ALIYUN_SLS_PROJECT_NAME=
# Aliyun SLS Logstore TTL (default: 365 days; set to 3650 for permanent storage)
ALIYUN_SLS_LOGSTORE_TTL=365
# Enable dual-write to both LogStore and SQL database (default: true)
LOGSTORE_DUAL_WRITE_ENABLED=true
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in the SQL database
LOGSTORE_DUAL_READ_ENABLED=true

web/.gitignore vendored (7 lines changed)
View File

@ -8,13 +8,6 @@
# testing
/coverage
# playwright e2e
/e2e/.auth/
/e2e/test-results/
/playwright-report/
/blob-report/
/test-results/
# next.js
/.next/
/out/

View File

@ -2,3 +2,4 @@
- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests.
- When proposing or saving tests, re-read that document and follow every requirement.
- All frontend tests MUST also comply with the `frontend-testing` skill. Treat the skill as a mandatory constraint, not optional guidance.

web/CLAUDE.md symbolic link (1 line changed)
View File

@ -0,0 +1 @@
AGENTS.md

View File

@ -0,0 +1,40 @@
/**
* Shared mock for react-i18next
*
* Jest automatically uses this mock when react-i18next is imported in tests.
* The default behavior returns the translation key as-is, which is suitable
* for most test scenarios.
*
* For tests that need custom translations, you can override with jest.mock():
*
* @example
* jest.mock('react-i18next', () => ({
* useTranslation: () => ({
* t: (key: string) => {
* if (key === 'some.key') return 'Custom translation'
* return key
* },
* }),
* }))
*/
export const useTranslation = () => ({
t: (key: string, options?: Record<string, unknown>) => {
if (options?.returnObjects)
return [`${key}-feature-1`, `${key}-feature-2`]
if (options)
return `${key}:${JSON.stringify(options)}`
return key
},
i18n: {
language: 'en',
changeLanguage: jest.fn(),
},
})
export const Trans = ({ children }: { children?: React.ReactNode }) => children
export const initReactI18next = {
type: '3rdParty',
init: jest.fn(),
}

View File

@ -4,12 +4,6 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import MailAndPasswordAuth from '@/app/(shareLayout)/webapp-signin/components/mail-and-password-auth'
import CheckCode from '@/app/(shareLayout)/webapp-signin/check-code/page'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
const replaceMock = jest.fn()
const backMock = jest.fn()

View File

@ -4,12 +4,6 @@ import '@testing-library/jest-dom'
import CommandSelector from '../../app/components/goto-anything/command-selector'
import type { ActionItem } from '../../app/components/goto-anything/actions/types'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
jest.mock('cmdk', () => ({
Command: {
Group: ({ children, className }: any) => <div className={className}>{children}</div>,

View File

@ -3,13 +3,6 @@ import { render } from '@testing-library/react'
import '@testing-library/jest-dom'
import { OpikIconBig } from '@/app/components/base/icons/src/public/tracing'
// Mock dependencies to isolate the SVG rendering issue
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('SVG Attribute Error Reproduction', () => {
// Capture console errors
const originalError = console.error

View File

@ -0,0 +1,53 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import EditItem, { EditItemType } from './index'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('AddAnnotationModal/EditItem', () => {
test('should render query inputs with user avatar and placeholder strings', () => {
render(
<EditItem
type={EditItemType.Query}
content="Why?"
onChange={jest.fn()}
/>,
)
expect(screen.getByText('appAnnotation.addModal.queryName')).toBeInTheDocument()
expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toBeInTheDocument()
expect(screen.getByText('Why?')).toBeInTheDocument()
})
test('should render answer name and placeholder text', () => {
render(
<EditItem
type={EditItemType.Answer}
content="Existing answer"
onChange={jest.fn()}
/>,
)
expect(screen.getByText('appAnnotation.addModal.answerName')).toBeInTheDocument()
expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toBeInTheDocument()
expect(screen.getByDisplayValue('Existing answer')).toBeInTheDocument()
})
test('should propagate changes when answer content updates', () => {
const handleChange = jest.fn()
render(
<EditItem
type={EditItemType.Answer}
content=""
onChange={handleChange}
/>,
)
fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder'), { target: { value: 'Because' } })
expect(handleChange).toHaveBeenCalledWith('Because')
})
})

View File

@ -0,0 +1,155 @@
import React from 'react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import AddAnnotationModal from './index'
import { useProviderContext } from '@/context/provider-context'
jest.mock('@/context/provider-context', () => ({
useProviderContext: jest.fn(),
}))
const mockToastNotify = jest.fn()
jest.mock('@/app/components/base/toast', () => ({
__esModule: true,
default: {
notify: jest.fn(args => mockToastNotify(args)),
},
}))
jest.mock('@/app/components/billing/annotation-full', () => () => <div data-testid="annotation-full" />)
const mockUseProviderContext = useProviderContext as jest.Mock
const getProviderContext = ({ usage = 0, total = 10, enableBilling = false } = {}) => ({
plan: {
usage: { annotatedResponse: usage },
total: { annotatedResponse: total },
},
enableBilling,
})
describe('AddAnnotationModal', () => {
const baseProps = {
isShow: true,
onHide: jest.fn(),
onAdd: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
mockUseProviderContext.mockReturnValue(getProviderContext())
})
const typeQuestion = (value: string) => {
fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder'), {
target: { value },
})
}
const typeAnswer = (value: string) => {
fireEvent.change(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder'), {
target: { value },
})
}
test('should render modal title when drawer is visible', () => {
render(<AddAnnotationModal {...baseProps} />)
expect(screen.getByText('appAnnotation.addModal.title')).toBeInTheDocument()
})
test('should capture query input text when typing', () => {
render(<AddAnnotationModal {...baseProps} />)
typeQuestion('Sample question')
expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toHaveValue('Sample question')
})
test('should capture answer input text when typing', () => {
render(<AddAnnotationModal {...baseProps} />)
typeAnswer('Sample answer')
expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toHaveValue('Sample answer')
})
test('should show annotation full notice and disable submit when quota exceeded', () => {
mockUseProviderContext.mockReturnValue(getProviderContext({ usage: 10, total: 10, enableBilling: true }))
render(<AddAnnotationModal {...baseProps} />)
expect(screen.getByTestId('annotation-full')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled()
})
test('should call onAdd with form values when create next enabled', async () => {
const onAdd = jest.fn().mockResolvedValue(undefined)
render(<AddAnnotationModal {...baseProps} onAdd={onAdd} />)
typeQuestion('Question value')
typeAnswer('Answer value')
fireEvent.click(screen.getByTestId('checkbox-create-next-checkbox'))
await act(async () => {
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
})
expect(onAdd).toHaveBeenCalledWith({ question: 'Question value', answer: 'Answer value' })
})
test('should reset fields after saving when create next enabled', async () => {
const onAdd = jest.fn().mockResolvedValue(undefined)
render(<AddAnnotationModal {...baseProps} onAdd={onAdd} />)
typeQuestion('Question value')
typeAnswer('Answer value')
const createNextToggle = screen.getByText('appAnnotation.addModal.createNext').previousElementSibling as HTMLElement
fireEvent.click(createNextToggle)
await act(async () => {
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
})
await waitFor(() => {
expect(screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')).toHaveValue('')
expect(screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')).toHaveValue('')
})
})
test('should show toast when validation fails for missing question', () => {
render(<AddAnnotationModal {...baseProps} />)
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
message: 'appAnnotation.errorMessage.queryRequired',
}))
})
test('should show toast when validation fails for missing answer', () => {
render(<AddAnnotationModal {...baseProps} />)
typeQuestion('Filled question')
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
message: 'appAnnotation.errorMessage.answerRequired',
}))
})
test('should close modal when save completes and create next unchecked', async () => {
const onAdd = jest.fn().mockResolvedValue(undefined)
render(<AddAnnotationModal {...baseProps} onAdd={onAdd} />)
typeQuestion('Q')
typeAnswer('A')
await act(async () => {
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
})
expect(baseProps.onHide).toHaveBeenCalled()
})
test('should allow cancel button to close the drawer', () => {
render(<AddAnnotationModal {...baseProps} />)
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
expect(baseProps.onHide).toHaveBeenCalled()
})
})

View File

@ -101,7 +101,7 @@ const AddAnnotationModal: FC<Props> = ({
<div
className='flex items-center space-x-2'
>
<Checkbox checked={isCreateNext} onCheck={() => setIsCreateNext(!isCreateNext)} />
<Checkbox id='create-next-checkbox' checked={isCreateNext} onCheck={() => setIsCreateNext(!isCreateNext)} />
<div>{t('appAnnotation.addModal.createNext')}</div>
</div>
<div className='mt-2 flex space-x-2'>

View File

@ -3,12 +3,6 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import CSVUploader, { type Props } from './csv-uploader'
import { ToastContext } from '@/app/components/base/toast'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('CSVUploader', () => {
const notify = jest.fn()
const updateFile = jest.fn()

View File

@ -0,0 +1,397 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import EditItem, { EditItemType, EditTitle } from './index'
describe('EditTitle', () => {
it('should render title content correctly', () => {
// Arrange
const props = { title: 'Test Title' }
// Act
render(<EditTitle {...props} />)
// Assert
expect(screen.getByText(/test title/i)).toBeInTheDocument()
// Should contain edit icon (svg element)
expect(document.querySelector('svg')).toBeInTheDocument()
})
it('should apply custom className when provided', () => {
// Arrange
const props = {
title: 'Test Title',
className: 'custom-class',
}
// Act
const { container } = render(<EditTitle {...props} />)
// Assert
expect(screen.getByText(/test title/i)).toBeInTheDocument()
expect(container.querySelector('.custom-class')).toBeInTheDocument()
})
})
describe('EditItem', () => {
const defaultProps = {
type: EditItemType.Query,
content: 'Test content',
onSave: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should render content correctly', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(/test content/i)).toBeInTheDocument()
// Should show item name (query or answer)
expect(screen.getByText('appAnnotation.editModal.queryName')).toBeInTheDocument()
})
it('should render different item types correctly', () => {
// Arrange
const props = {
...defaultProps,
type: EditItemType.Answer,
content: 'Answer content',
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(/answer content/i)).toBeInTheDocument()
expect(screen.getByText('appAnnotation.editModal.answerName')).toBeInTheDocument()
})
it('should show edit controls when not readonly', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText('common.operation.edit')).toBeInTheDocument()
})
it('should hide edit controls when readonly', () => {
// Arrange
const props = {
...defaultProps,
readonly: true,
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.queryByText('common.operation.edit')).not.toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should respect readonly prop for edit functionality', () => {
// Arrange
const props = {
...defaultProps,
readonly: true,
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(/test content/i)).toBeInTheDocument()
expect(screen.queryByText('common.operation.edit')).not.toBeInTheDocument()
})
it('should display provided content', () => {
// Arrange
const props = {
...defaultProps,
content: 'Custom content for testing',
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(/custom content for testing/i)).toBeInTheDocument()
})
it('should render appropriate content based on type', () => {
// Arrange
const props = {
...defaultProps,
type: EditItemType.Query,
content: 'Question content',
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(/question content/i)).toBeInTheDocument()
expect(screen.getByText('appAnnotation.editModal.queryName')).toBeInTheDocument()
})
})
// User Interactions
describe('User Interactions', () => {
it('should activate edit mode when edit button is clicked', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
await user.click(screen.getByText('common.operation.edit'))
// Assert
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument()
})
it('should save new content when save button is clicked', async () => {
// Arrange
const mockSave = jest.fn().mockResolvedValue(undefined)
const props = {
...defaultProps,
onSave: mockSave,
}
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
await user.click(screen.getByText('common.operation.edit'))
// Type new content
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Updated content')
// Save
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
// Assert
expect(mockSave).toHaveBeenCalledWith('Updated content')
})
it('should exit edit mode when cancel button is clicked', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
await user.click(screen.getByText('common.operation.edit'))
await user.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
// Assert
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
expect(screen.getByText(/test content/i)).toBeInTheDocument()
})
it('should show content preview while typing', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
await user.type(textarea, 'New content')
// Assert
expect(screen.getByText(/new content/i)).toBeInTheDocument()
})
it('should call onSave with correct content when saving', async () => {
// Arrange
const mockSave = jest.fn().mockResolvedValue(undefined)
const props = {
...defaultProps,
onSave: mockSave,
}
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Test save content')
// Save
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
// Assert
expect(mockSave).toHaveBeenCalledWith('Test save content')
})
it('should show delete option when content changes', async () => {
// Arrange
const mockSave = jest.fn().mockResolvedValue(undefined)
const props = {
...defaultProps,
onSave: mockSave,
}
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
// Enter edit mode and change content
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Modified content')
// Save to trigger content change
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
// Assert
expect(mockSave).toHaveBeenCalledWith('Modified content')
})
it('should handle keyboard interactions in edit mode', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
// Test typing
await user.type(textarea, 'Keyboard test')
// Assert
expect(textarea).toHaveValue('Keyboard test')
expect(screen.getByText(/keyboard test/i)).toBeInTheDocument()
})
})
// State Management
describe('State Management', () => {
it('should reset newContent when content prop changes', async () => {
// Arrange
const { rerender } = render(<EditItem {...defaultProps} />)
// Act - Enter edit mode and type something
const user = userEvent.setup()
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'New content')
// Rerender with new content prop
rerender(<EditItem {...defaultProps} content="Updated content" />)
// Assert - Textarea value should be reset due to useEffect
expect(textarea).toHaveValue('')
})
it('should preserve edit state across content changes', async () => {
// Arrange
const { rerender } = render(<EditItem {...defaultProps} />)
const user = userEvent.setup()
// Act - Enter edit mode
await user.click(screen.getByText('common.operation.edit'))
// Rerender with new content
rerender(<EditItem {...defaultProps} content="Updated content" />)
// Assert - Should still be in edit mode
expect(screen.getByRole('textbox')).toBeInTheDocument()
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle empty content', () => {
// Arrange
const props = {
...defaultProps,
content: '',
}
// Act
const { container } = render(<EditItem {...props} />)
// Assert - Should render without crashing
// Check that the component renders properly with empty content
expect(container.querySelector('.grow')).toBeInTheDocument()
// Should still show edit button
expect(screen.getByText('common.operation.edit')).toBeInTheDocument()
})
it('should handle very long content', () => {
// Arrange
const longContent = 'A'.repeat(1000)
const props = {
...defaultProps,
content: longContent,
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(longContent)).toBeInTheDocument()
})
it('should handle content with special characters', () => {
// Arrange
const specialContent = 'Content with & < > " \' characters'
const props = {
...defaultProps,
content: specialContent,
}
// Act
render(<EditItem {...props} />)
// Assert
expect(screen.getByText(specialContent)).toBeInTheDocument()
})
it('should handle rapid edit/cancel operations', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
// Rapid edit/cancel operations
await user.click(screen.getByText('common.operation.edit'))
await user.click(screen.getByText('common.operation.cancel'))
await user.click(screen.getByText('common.operation.edit'))
await user.click(screen.getByText('common.operation.cancel'))
// Assert
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
expect(screen.getByText('Test content')).toBeInTheDocument()
})
})
})
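The edit/save/cancel flow these tests exercise can be sketched as a small draft-state machine: entering edit mode seeds a draft from the current content, cancel discards the draft, and save commits it. This is a hedged sketch in plain TypeScript; `EditState`, `startEdit`, `cancelEdit`, and `saveEdit` are illustrative names, not the component's actual implementation.

```typescript
// Illustrative model of the EditItem edit-mode state (not the real component).
type EditState = { content: string, draft: string | null }

// Entering edit mode seeds the draft from the committed content.
const startEdit = (s: EditState): EditState => ({ ...s, draft: s.content })

// Cancel discards the draft and keeps the committed content.
const cancelEdit = (s: EditState): EditState => ({ ...s, draft: null })

// Save commits the draft (a no-op if there is no draft).
const saveEdit = (s: EditState): EditState =>
  s.draft === null ? s : { content: s.draft, draft: null }
```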

View File

@@ -0,0 +1,408 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import Toast, { type IToastProps, type ToastHandle } from '@/app/components/base/toast'
import EditAnnotationModal from './index'
// Mock only external dependencies
jest.mock('@/service/annotation', () => ({
addAnnotation: jest.fn(),
editAnnotation: jest.fn(),
}))
jest.mock('@/context/provider-context', () => ({
useProviderContext: () => ({
plan: {
usage: { annotatedResponse: 5 },
total: { annotatedResponse: 10 },
},
enableBilling: true,
}),
}))
jest.mock('@/hooks/use-timestamp', () => ({
__esModule: true,
default: () => ({
formatTime: () => '2023-12-01 10:30:00',
}),
}))
// Note: i18n is automatically mocked by Jest via __mocks__/react-i18next.ts
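The note above relies on Jest's manual-mock convention: for packages under `node_modules`, a file in a top-level `__mocks__/` directory is picked up automatically, so per-file `jest.mock('react-i18next', ...)` blocks become unnecessary. A minimal sketch of what such a mock file might contain (assuming the identity-key translator these tests rely on; the actual file may differ):

```typescript
// Hypothetical __mocks__/react-i18next.ts. Jest resolves manual mocks for
// node_modules packages from this directory without an explicit jest.mock call.
export const useTranslation = () => ({
  t: (key: string) => key, // identity translator: returns the i18n key itself
})
```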
jest.mock('@/app/components/billing/annotation-full', () => ({
__esModule: true,
default: () => <div data-testid="annotation-full" />,
}))
type ToastNotifyProps = Pick<IToastProps, 'type' | 'size' | 'message' | 'duration' | 'className' | 'customComponent' | 'onClose'>
type ToastWithNotify = typeof Toast & { notify: (props: ToastNotifyProps) => ToastHandle }
const toastWithNotify = Toast as unknown as ToastWithNotify
const toastNotifySpy = jest.spyOn(toastWithNotify, 'notify').mockReturnValue({ clear: jest.fn() })
const { addAnnotation: mockAddAnnotation, editAnnotation: mockEditAnnotation } = jest.requireMock('@/service/annotation') as {
addAnnotation: jest.Mock
editAnnotation: jest.Mock
}
describe('EditAnnotationModal', () => {
const defaultProps = {
isShow: true,
onHide: jest.fn(),
appId: 'test-app-id',
query: 'Test query',
answer: 'Test answer',
onEdited: jest.fn(),
onAdded: jest.fn(),
onRemove: jest.fn(),
}
afterAll(() => {
toastNotifySpy.mockRestore()
})
beforeEach(() => {
jest.clearAllMocks()
mockAddAnnotation.mockResolvedValue({
id: 'test-id',
account: { name: 'Test User' },
})
mockEditAnnotation.mockResolvedValue({})
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should render modal when isShow is true', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Check for modal title as it appears in the mock
expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument()
})
it('should not render modal when isShow is false', () => {
// Arrange
const props = { ...defaultProps, isShow: false }
// Act
render(<EditAnnotationModal {...props} />)
// Assert
expect(screen.queryByText('appAnnotation.editModal.title')).not.toBeInTheDocument()
})
it('should display query and answer sections', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Look for query and answer content
expect(screen.getByText('Test query')).toBeInTheDocument()
expect(screen.getByText('Test answer')).toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should handle different query and answer content', () => {
// Arrange
const props = {
...defaultProps,
query: 'Custom query content',
answer: 'Custom answer content',
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Check content is displayed
expect(screen.getByText('Custom query content')).toBeInTheDocument()
expect(screen.getByText('Custom answer content')).toBeInTheDocument()
})
it('should show remove option when annotationId is provided', () => {
// Arrange
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Remove option should be present (using pattern)
expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument()
})
})
// User Interactions
describe('User Interactions', () => {
it('should enable editing for query and answer sections', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Edit links should be visible (using text content)
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
expect(editLinks).toHaveLength(2)
})
it('should show remove option when annotationId is provided', () => {
// Arrange
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert
expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument()
})
it('should save content when edited', async () => {
// Arrange
const mockOnAdded = jest.fn()
const props = {
...defaultProps,
onAdded: mockOnAdded,
}
const user = userEvent.setup()
// Mock API response
mockAddAnnotation.mockResolvedValueOnce({
id: 'test-annotation-id',
account: { name: 'Test User' },
})
// Act
render(<EditAnnotationModal {...props} />)
// Find and click edit link for query
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
await user.click(editLinks[0])
// Find textarea and enter new content
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'New query content')
// Click save button
const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
await user.click(saveButton)
// Assert
expect(mockAddAnnotation).toHaveBeenCalledWith('test-app-id', {
question: 'New query content',
answer: 'Test answer',
message_id: undefined,
})
})
})
// API Calls
describe('API Calls', () => {
it('should call addAnnotation when saving new annotation', async () => {
// Arrange
const mockOnAdded = jest.fn()
const props = {
...defaultProps,
onAdded: mockOnAdded,
}
const user = userEvent.setup()
// Mock the API response
mockAddAnnotation.mockResolvedValueOnce({
id: 'test-annotation-id',
account: { name: 'Test User' },
})
// Act
render(<EditAnnotationModal {...props} />)
// Edit query content
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
await user.click(editLinks[0])
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Updated query')
const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
await user.click(saveButton)
// Assert
expect(mockAddAnnotation).toHaveBeenCalledWith('test-app-id', {
question: 'Updated query',
answer: 'Test answer',
message_id: undefined,
})
})
it('should call editAnnotation when updating existing annotation', async () => {
// Arrange
const mockOnEdited = jest.fn()
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
messageId: 'test-message-id',
onEdited: mockOnEdited,
}
const user = userEvent.setup()
// Act
render(<EditAnnotationModal {...props} />)
// Edit query content
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
await user.click(editLinks[0])
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Modified query')
const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
await user.click(saveButton)
// Assert
expect(mockEditAnnotation).toHaveBeenCalledWith(
'test-app-id',
'test-annotation-id',
{
message_id: 'test-message-id',
question: 'Modified query',
answer: 'Test answer',
},
)
})
})
// State Management
describe('State Management', () => {
it('should initialize with closed confirm modal', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Confirm dialog should not be visible initially
expect(screen.queryByText('appDebug.feature.annotation.removeConfirm')).not.toBeInTheDocument()
})
it('should show confirm modal when remove is clicked', async () => {
// Arrange
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
}
const user = userEvent.setup()
// Act
render(<EditAnnotationModal {...props} />)
await user.click(screen.getByText('appAnnotation.editModal.removeThisCache'))
// Assert - Confirmation dialog should appear
expect(screen.getByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument()
})
it('should call onRemove when removal is confirmed', async () => {
// Arrange
const mockOnRemove = jest.fn()
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
onRemove: mockOnRemove,
}
const user = userEvent.setup()
// Act
render(<EditAnnotationModal {...props} />)
// Click remove
await user.click(screen.getByText('appAnnotation.editModal.removeThisCache'))
// Click confirm
const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' })
await user.click(confirmButton)
// Assert
expect(mockOnRemove).toHaveBeenCalled()
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle empty query and answer', () => {
// Arrange
const props = {
...defaultProps,
query: '',
answer: '',
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert
expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument()
})
it('should handle very long content', () => {
// Arrange
const longQuery = 'Q'.repeat(1000)
const longAnswer = 'A'.repeat(1000)
const props = {
...defaultProps,
query: longQuery,
answer: longAnswer,
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert
expect(screen.getByText(longQuery)).toBeInTheDocument()
expect(screen.getByText(longAnswer)).toBeInTheDocument()
})
it('should handle special characters in content', () => {
// Arrange
const specialQuery = 'Query with & < > " \' characters'
const specialAnswer = 'Answer with & < > " \' characters'
const props = {
...defaultProps,
query: specialQuery,
answer: specialAnswer,
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert
expect(screen.getByText(specialQuery)).toBeInTheDocument()
expect(screen.getByText(specialAnswer)).toBeInTheDocument()
})
it('should handle onlyEditResponse prop', () => {
// Arrange
const props = {
...defaultProps,
onlyEditResponse: true,
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Query should be readonly, answer should be editable
const editLinks = screen.queryAllByText(/common\.operation\.edit/i)
expect(editLinks).toHaveLength(1) // Only answer should have edit button
})
})
})
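The API-call tests above pivot on one branch: with an `annotationId` the modal updates in place via `editAnnotation`, otherwise it creates a new record via `addAnnotation`. A hedged sketch of that branch in plain TypeScript; `saveAnnotation` and the `AnnotationApi` shape are illustrative, not the component's actual code.

```typescript
// Payload shape matching what the tests assert is sent to the service layer.
type AnnotationPayload = { question: string, answer: string, message_id?: string }

type AnnotationApi = {
  addAnnotation: (appId: string, body: AnnotationPayload) => Promise<unknown>
  editAnnotation: (appId: string, annotationId: string, body: AnnotationPayload) => Promise<unknown>
}

// Choose the edit path only when an existing annotation id is present.
const saveAnnotation = (
  api: AnnotationApi,
  appId: string,
  body: AnnotationPayload,
  annotationId?: string,
) => annotationId
  ? api.editAnnotation(appId, annotationId, body)
  : api.addAnnotation(appId, body)
```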

View File

@@ -1,12 +1,6 @@
import { fireEvent, render, screen } from '@testing-library/react'
import OperationBtn from './index'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
jest.mock('@remixicon/react', () => ({
RiAddLine: (props: { className?: string }) => (
<svg data-testid='add-icon' className={props.className} />

View File

@@ -0,0 +1,22 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import CannotQueryDataset from './cannot-query-dataset'
describe('CannotQueryDataset WarningMask', () => {
test('should render dataset warning copy and action button', () => {
const onConfirm = jest.fn()
render(<CannotQueryDataset onConfirm={onConfirm} />)
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.unableToQueryDataSet')).toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.unableToQueryDataSetTip')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'appDebug.feature.dataSet.queryVariable.ok' })).toBeInTheDocument()
})
test('should invoke onConfirm when OK button clicked', () => {
const onConfirm = jest.fn()
render(<CannotQueryDataset onConfirm={onConfirm} />)
fireEvent.click(screen.getByRole('button', { name: 'appDebug.feature.dataSet.queryVariable.ok' }))
expect(onConfirm).toHaveBeenCalledTimes(1)
})
})

View File

@@ -0,0 +1,39 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import FormattingChanged from './formatting-changed'
describe('FormattingChanged WarningMask', () => {
test('should display translation text and both actions', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
render(
<FormattingChanged
onConfirm={onConfirm}
onCancel={onCancel}
/>,
)
expect(screen.getByText('appDebug.formattingChangedTitle')).toBeInTheDocument()
expect(screen.getByText('appDebug.formattingChangedText')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: /common\.operation\.refresh/ })).toBeInTheDocument()
})
test('should call callbacks when buttons are clicked', () => {
const onConfirm = jest.fn()
const onCancel = jest.fn()
render(
<FormattingChanged
onConfirm={onConfirm}
onCancel={onCancel}
/>,
)
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.refresh/ }))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
expect(onConfirm).toHaveBeenCalledTimes(1)
expect(onCancel).toHaveBeenCalledTimes(1)
})
})

View File

@@ -0,0 +1,26 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import HasNotSetAPI from './has-not-set-api'
describe('HasNotSetAPI WarningMask', () => {
test('should show default title when trial not finished', () => {
render(<HasNotSetAPI isTrailFinished={false} onSetting={jest.fn()} />)
expect(screen.getByText('appDebug.notSetAPIKey.title')).toBeInTheDocument()
expect(screen.getByText('appDebug.notSetAPIKey.description')).toBeInTheDocument()
})
test('should show trail finished title when flag is true', () => {
render(<HasNotSetAPI isTrailFinished onSetting={jest.fn()} />)
expect(screen.getByText('appDebug.notSetAPIKey.trailFinished')).toBeInTheDocument()
})
test('should call onSetting when primary button clicked', () => {
const onSetting = jest.fn()
render(<HasNotSetAPI isTrailFinished={false} onSetting={onSetting} />)
fireEvent.click(screen.getByRole('button', { name: 'appDebug.notSetAPIKey.settingBtn' }))
expect(onSetting).toHaveBeenCalledTimes(1)
})
})

View File

@@ -0,0 +1,25 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import WarningMask from './index'
describe('WarningMask', () => {
// Rendering of title, description, and footer content
describe('Rendering', () => {
test('should display provided title, description, and footer node', () => {
const footer = <button type="button">Retry</button>
// Arrange
render(
<WarningMask
title="Access Restricted"
description="Only workspace owners may modify this section."
footer={footer}
/>,
)
// Assert
expect(screen.getByText('Access Restricted')).toBeInTheDocument()
expect(screen.getByText('Only workspace owners may modify this section.')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'Retry' })).toBeInTheDocument()
})
})
})

View File

@@ -2,12 +2,6 @@ import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import ConfirmAddVar from './index'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
jest.mock('../../base/var-highlight', () => ({
__esModule: true,
default: ({ name }: { name: string }) => <span data-testid="var-highlight">{name}</span>,

View File

@@ -3,12 +3,6 @@ import { fireEvent, render, screen } from '@testing-library/react'
import EditModal from './edit-modal'
import type { ConversationHistoriesRole } from '@/models/debug'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
jest.mock('@/app/components/base/modal', () => ({
__esModule: true,
default: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,

View File

@@ -2,12 +2,6 @@ import React from 'react'
import { render, screen } from '@testing-library/react'
import HistoryPanel from './history-panel'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
const mockDocLink = jest.fn(() => 'doc-link')
jest.mock('@/context/i18n', () => ({
useDocLink: () => mockDocLink,

View File

@@ -6,12 +6,6 @@ import { MAX_PROMPT_MESSAGE_LENGTH } from '@/config'
import { type PromptItem, PromptRole, type PromptVariable } from '@/models/debug'
import { AppModeEnum, ModelModeType } from '@/types/app'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
type DebugConfiguration = {
isAdvancedMode: boolean
currentAdvancedPrompt: PromptItem | PromptItem[]

View File

@@ -5,12 +5,6 @@ jest.mock('react-sortablejs', () => ({
ReactSortable: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
}))
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('ConfigSelect Component', () => {
const defaultProps = {
options: ['Option 1', 'Option 2'],

View File

@@ -0,0 +1,121 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import ConfigString, { type IConfigStringProps } from './index'
const renderConfigString = (props?: Partial<IConfigStringProps>) => {
const onChange = jest.fn()
const defaultProps: IConfigStringProps = {
value: 5,
maxLength: 10,
modelId: 'model-id',
onChange,
}
render(<ConfigString {...defaultProps} {...props} />)
return { onChange }
}
describe('ConfigString', () => {
beforeEach(() => {
jest.clearAllMocks()
})
describe('Rendering', () => {
it('should render numeric input with bounds', () => {
renderConfigString({ value: 3, maxLength: 8 })
const input = screen.getByRole('spinbutton')
expect(input).toHaveValue(3)
expect(input).toHaveAttribute('min', '1')
expect(input).toHaveAttribute('max', '8')
})
it('should render empty input when value is undefined', () => {
const { onChange } = renderConfigString({ value: undefined })
expect(screen.getByRole('spinbutton')).toHaveValue(null)
expect(onChange).not.toHaveBeenCalled()
})
})
describe('Effect behavior', () => {
it('should clamp initial value to maxLength when it exceeds limit', async () => {
const onChange = jest.fn()
render(
<ConfigString
value={15}
maxLength={10}
modelId="model-id"
onChange={onChange}
/>,
)
await waitFor(() => {
expect(onChange).toHaveBeenCalledWith(10)
})
expect(onChange).toHaveBeenCalledTimes(1)
})
it('should clamp when updated prop value exceeds maxLength', async () => {
const onChange = jest.fn()
const { rerender } = render(
<ConfigString
value={4}
maxLength={6}
modelId="model-id"
onChange={onChange}
/>,
)
rerender(
<ConfigString
value={9}
maxLength={6}
modelId="model-id"
onChange={onChange}
/>,
)
await waitFor(() => {
expect(onChange).toHaveBeenCalledWith(6)
})
expect(onChange).toHaveBeenCalledTimes(1)
})
})
describe('User interactions', () => {
it('should clamp entered value above maxLength', () => {
const { onChange } = renderConfigString({ maxLength: 7 })
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '12' } })
expect(onChange).toHaveBeenCalledWith(7)
})
it('should raise value below minimum to one', () => {
const { onChange } = renderConfigString()
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '0' } })
expect(onChange).toHaveBeenCalledWith(1)
})
it('should forward parsed value when within bounds', () => {
const { onChange } = renderConfigString({ maxLength: 9 })
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '7' } })
expect(onChange).toHaveBeenCalledWith(7)
})
it('should pass through NaN when input is cleared', () => {
const { onChange } = renderConfigString()
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '' } })
expect(onChange).toHaveBeenCalledTimes(1)
expect(onChange.mock.calls[0][0]).toBeNaN()
})
})
})
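The interaction tests above pin down a single clamping rule: entered values are forced into `[1, maxLength]`, while a cleared input parses to `NaN` and is forwarded unchanged. A minimal sketch of that rule in plain TypeScript; `clampLength` is an illustrative name, not the component's actual handler.

```typescript
// Illustrative clamp for ConfigString's numeric input (not the real handler).
const clampLength = (raw: string, maxLength: number): number => {
  const parsed = Number.parseInt(raw, 10)
  if (Number.isNaN(parsed))
    return Number.NaN // cleared input: forward NaN so the caller can decide
  return Math.min(Math.max(parsed, 1), maxLength)
}
```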

View File

@@ -0,0 +1,45 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import SelectTypeItem from './index'
import { InputVarType } from '@/app/components/workflow/types'
describe('SelectTypeItem', () => {
// Rendering pathways based on type and selection state
describe('Rendering', () => {
test('should render type label and icon', () => {
// Arrange
const { container } = render(
<SelectTypeItem
type={InputVarType.textInput}
selected={false}
onClick={jest.fn()}
/>,
)
// Assert
expect(screen.getByText('appDebug.variableConfig.text-input')).toBeInTheDocument()
expect(container.querySelector('svg')).not.toBeNull()
})
})
// User interaction outcomes
describe('Interactions', () => {
test('should trigger onClick when item is pressed', () => {
const handleClick = jest.fn()
// Arrange
render(
<SelectTypeItem
type={InputVarType.paragraph}
selected={false}
onClick={handleClick}
/>,
)
// Act
fireEvent.click(screen.getByText('appDebug.variableConfig.paragraph'))
// Assert
expect(handleClick).toHaveBeenCalledTimes(1)
})
})
})

View File

@@ -1,12 +1,6 @@
import { fireEvent, render, screen } from '@testing-library/react'
import ContrlBtnGroup from './index'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('ContrlBtnGroup', () => {
beforeEach(() => {
jest.clearAllMocks()

View File

@@ -0,0 +1,242 @@
import { fireEvent, render, screen, waitFor, within } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import Item from './index'
import type React from 'react'
import type { DataSet } from '@/models/datasets'
import { ChunkingMode, DataSourceType, DatasetPermission } from '@/models/datasets'
import type { IndexingType } from '@/app/components/datasets/create/step-two'
import type { RetrievalConfig } from '@/types/app'
import { RETRIEVE_METHOD } from '@/types/app'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
jest.mock('../settings-modal', () => ({
__esModule: true,
default: ({ onSave, onCancel, currentDataset }: any) => (
<div>
<div>Mock settings modal</div>
<button onClick={() => onSave({ ...currentDataset, name: 'Updated dataset' })}>Save changes</button>
<button onClick={onCancel}>Close</button>
</div>
),
}))
jest.mock('@/hooks/use-breakpoints', () => {
const actual = jest.requireActual('@/hooks/use-breakpoints')
return {
__esModule: true,
...actual,
default: jest.fn(() => actual.MediaType.pc),
}
})
const mockedUseBreakpoints = useBreakpoints as jest.MockedFunction<typeof useBreakpoints>
const baseRetrievalConfig: RetrievalConfig = {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: 'provider',
reranking_model_name: 'rerank-model',
},
top_k: 4,
score_threshold_enabled: false,
score_threshold: 0,
}
const defaultIndexingTechnique: IndexingType = 'high_quality' as IndexingType
const createDataset = (overrides: Partial<DataSet> = {}): DataSet => {
const {
retrieval_model,
retrieval_model_dict,
icon_info,
...restOverrides
} = overrides
const resolvedRetrievalModelDict = {
...baseRetrievalConfig,
...retrieval_model_dict,
}
const resolvedRetrievalModel = {
...baseRetrievalConfig,
...(retrieval_model ?? retrieval_model_dict),
}
const defaultIconInfo = {
icon: '📘',
icon_type: 'emoji',
icon_background: '#FFEAD5',
icon_url: '',
}
const resolvedIconInfo = ('icon_info' in overrides)
? icon_info
: defaultIconInfo
return {
id: 'dataset-id',
name: 'Dataset Name',
indexing_status: 'completed',
icon_info: resolvedIconInfo as DataSet['icon_info'],
description: 'A test dataset',
permission: DatasetPermission.onlyMe,
data_source_type: DataSourceType.FILE,
indexing_technique: defaultIndexingTechnique,
author_name: 'author',
created_by: 'creator',
updated_by: 'updater',
updated_at: 0,
app_count: 0,
doc_form: ChunkingMode.text,
document_count: 0,
total_document_count: 0,
total_available_documents: 0,
word_count: 0,
provider: 'dify',
embedding_model: 'text-embedding',
embedding_model_provider: 'openai',
embedding_available: true,
retrieval_model_dict: resolvedRetrievalModelDict,
retrieval_model: resolvedRetrievalModel,
tags: [],
external_knowledge_info: {
external_knowledge_id: 'external-id',
external_knowledge_api_id: 'api-id',
external_knowledge_api_name: 'api-name',
external_knowledge_api_endpoint: 'https://endpoint',
},
external_retrieval_model: {
top_k: 2,
score_threshold: 0.5,
score_threshold_enabled: true,
},
built_in_field_enabled: true,
doc_metadata: [],
keyword_number: 3,
pipeline_id: 'pipeline-id',
is_published: true,
runtime_mode: 'general',
enable_api: true,
is_multimodal: false,
...restOverrides,
}
}
const renderItem = (config: DataSet, props?: Partial<React.ComponentProps<typeof Item>>) => {
const onSave = jest.fn()
const onRemove = jest.fn()
render(
<Item
config={config}
onSave={onSave}
onRemove={onRemove}
{...props}
/>,
)
return { onSave, onRemove }
}
describe('dataset-config/card-item', () => {
beforeEach(() => {
jest.clearAllMocks()
mockedUseBreakpoints.mockReturnValue(MediaType.pc)
})
it('should render dataset details with indexing and external badges', () => {
const dataset = createDataset({
provider: 'external',
retrieval_model_dict: {
...baseRetrievalConfig,
search_method: RETRIEVE_METHOD.semantic,
},
})
renderItem(dataset)
const card = screen.getByText(dataset.name).closest('.group') as HTMLElement
const actionButtons = within(card).getAllByRole('button', { hidden: true })
expect(screen.getByText(dataset.name)).toBeInTheDocument()
expect(screen.getByText('dataset.indexingTechnique.high_quality · dataset.indexingMethod.semantic_search')).toBeInTheDocument()
expect(screen.getByText('dataset.externalTag')).toBeInTheDocument()
expect(actionButtons).toHaveLength(2)
})
it('should open settings drawer from edit action and close after saving', async () => {
const user = userEvent.setup()
const dataset = createDataset()
const { onSave } = renderItem(dataset)
const card = screen.getByText(dataset.name).closest('.group') as HTMLElement
const [editButton] = within(card).getAllByRole('button', { hidden: true })
await user.click(editButton)
expect(screen.getByText('Mock settings modal')).toBeInTheDocument()
await waitFor(() => {
expect(screen.getByRole('dialog')).toBeVisible()
})
await user.click(screen.getByText('Save changes'))
await waitFor(() => {
expect(onSave).toHaveBeenCalledWith(expect.objectContaining({ name: 'Updated dataset' }))
})
await waitFor(() => {
expect(screen.getByText('Mock settings modal')).not.toBeVisible()
})
})
it('should call onRemove and toggle destructive state on hover', async () => {
const user = userEvent.setup()
const dataset = createDataset()
const { onRemove } = renderItem(dataset)
const card = screen.getByText(dataset.name).closest('.group') as HTMLElement
const buttons = within(card).getAllByRole('button', { hidden: true })
const deleteButton = buttons[buttons.length - 1]
expect(deleteButton.className).not.toContain('action-btn-destructive')
fireEvent.mouseEnter(deleteButton)
expect(deleteButton.className).toContain('action-btn-destructive')
expect(card.className).toContain('border-state-destructive-border')
fireEvent.mouseLeave(deleteButton)
expect(deleteButton.className).not.toContain('action-btn-destructive')
await user.click(deleteButton)
expect(onRemove).toHaveBeenCalledWith(dataset.id)
})
it('should use default icon information when icon details are missing', () => {
const dataset = createDataset({ icon_info: undefined })
renderItem(dataset)
const nameElement = screen.getByText(dataset.name)
const iconElement = nameElement.parentElement?.firstElementChild as HTMLElement
expect(iconElement).toHaveStyle({ background: '#FFF4ED' })
expect(iconElement.querySelector('em-emoji')).toHaveAttribute('id', '📙')
})
it('should apply mask overlay on mobile when drawer is open', async () => {
mockedUseBreakpoints.mockReturnValue(MediaType.mobile)
const user = userEvent.setup()
const dataset = createDataset()
renderItem(dataset)
const card = screen.getByText(dataset.name).closest('.group') as HTMLElement
const [editButton] = within(card).getAllByRole('button', { hidden: true })
await user.click(editButton)
expect(screen.getByText('Mock settings modal')).toBeInTheDocument()
const overlay = Array.from(document.querySelectorAll('[class]'))
.find(element => element.className.toString().includes('bg-black/30'))
expect(overlay).toBeInTheDocument()
})
})


@ -0,0 +1,299 @@
import * as React from 'react'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import ContextVar from './index'
import type { Props } from './var-picker'
// Mock external dependencies only
jest.mock('next/navigation', () => ({
useRouter: () => ({ push: jest.fn() }),
usePathname: () => '/test',
}))
type PortalToFollowElemProps = {
children: React.ReactNode
open?: boolean
onOpenChange?: (open: boolean) => void
}
type PortalToFollowElemTriggerProps = React.HTMLAttributes<HTMLElement> & { children?: React.ReactNode; asChild?: boolean }
type PortalToFollowElemContentProps = React.HTMLAttributes<HTMLDivElement> & { children?: React.ReactNode }
jest.mock('@/app/components/base/portal-to-follow-elem', () => {
const PortalContext = React.createContext({ open: false })
const PortalToFollowElem = ({ children, open }: PortalToFollowElemProps) => {
return (
<PortalContext.Provider value={{ open: !!open }}>
<div data-testid="portal">{children}</div>
</PortalContext.Provider>
)
}
const PortalToFollowElemContent = ({ children, ...props }: PortalToFollowElemContentProps) => {
const { open } = React.useContext(PortalContext)
if (!open) return null
return (
<div data-testid="portal-content" {...props}>
{children}
</div>
)
}
const PortalToFollowElemTrigger = ({ children, asChild, ...props }: PortalToFollowElemTriggerProps) => {
if (asChild && React.isValidElement(children)) {
return React.cloneElement(children, {
...props,
'data-testid': 'portal-trigger',
} as React.HTMLAttributes<HTMLElement>)
}
return (
<div data-testid="portal-trigger" {...props}>
{children}
</div>
)
}
return {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
}
})
describe('ContextVar', () => {
const mockOptions: Props['options'] = [
{ name: 'Variable 1', value: 'var1', type: 'string' },
{ name: 'Variable 2', value: 'var2', type: 'number' },
]
const defaultProps: Props = {
value: 'var1',
options: mockOptions,
onChange: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should display query variable selector when options are provided', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.title')).toBeInTheDocument()
})
it('should show selected variable with proper formatting when value is provided', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('var1')).toBeInTheDocument()
expect(screen.getByText('{{')).toBeInTheDocument()
expect(screen.getByText('}}')).toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should display selected variable when value prop is provided', () => {
// Arrange
const props = { ...defaultProps, value: 'var2' }
// Act
render(<ContextVar {...props} />)
// Assert - Should display the selected value
expect(screen.getByText('var2')).toBeInTheDocument()
})
it('should show placeholder text when no value is selected', () => {
// Arrange
const props = {
...defaultProps,
value: undefined,
}
// Act
render(<ContextVar {...props} />)
// Assert - Should show placeholder instead of variable
expect(screen.queryByText('var1')).not.toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
})
it('should display custom tip message when notSelectedVarTip is provided', () => {
// Arrange
const props = {
...defaultProps,
value: undefined,
notSelectedVarTip: 'Select a variable',
}
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('Select a variable')).toBeInTheDocument()
})
it('should apply custom className to VarPicker when provided', () => {
// Arrange
const props = {
...defaultProps,
className: 'custom-class',
}
// Act
const { container } = render(<ContextVar {...props} />)
// Assert
expect(container.querySelector('.custom-class')).toBeInTheDocument()
})
})
// User Interactions
describe('User Interactions', () => {
it('should call onChange when user selects a different variable', async () => {
// Arrange
const onChange = jest.fn()
const props = { ...defaultProps, onChange }
const user = userEvent.setup()
// Act
render(<ContextVar {...props} />)
const triggers = screen.getAllByTestId('portal-trigger')
const varPickerTrigger = triggers[triggers.length - 1]
await user.click(varPickerTrigger)
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
// Select a different option
const options = screen.getAllByText('var2')
expect(options.length).toBeGreaterThan(0)
await user.click(options[0])
// Assert
expect(onChange).toHaveBeenCalledWith('var2')
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
it('should toggle dropdown when clicking the trigger button', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<ContextVar {...props} />)
const triggers = screen.getAllByTestId('portal-trigger')
const varPickerTrigger = triggers[triggers.length - 1]
// Open dropdown
await user.click(varPickerTrigger)
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
// Close dropdown
await user.click(varPickerTrigger)
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle undefined value gracefully', () => {
// Arrange
const props = {
...defaultProps,
value: undefined,
}
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.title')).toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
expect(screen.queryByText('var1')).not.toBeInTheDocument()
})
it('should handle empty options array', () => {
// Arrange
const props = {
...defaultProps,
options: [],
value: undefined,
}
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.title')).toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
})
it('should handle null value without crashing', () => {
// Arrange - cast to exercise a null runtime value against the optional prop
const props = {
...defaultProps,
value: null as unknown as Props['value'],
}
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.title')).toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
})
it('should handle options with different data types', () => {
// Arrange
const props = {
...defaultProps,
options: [
{ name: 'String Var', value: 'strVar', type: 'string' },
{ name: 'Number Var', value: '42', type: 'number' },
{ name: 'Boolean Var', value: 'true', type: 'boolean' },
],
value: 'strVar',
}
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('strVar')).toBeInTheDocument()
expect(screen.getByText('{{')).toBeInTheDocument()
expect(screen.getByText('}}')).toBeInTheDocument()
})
it('should render variable names with special characters safely', () => {
// Arrange
const props = {
...defaultProps,
options: [
{ name: 'Variable with & < > " \' characters', value: 'specialVar', type: 'string' },
],
value: 'specialVar',
}
// Act
render(<ContextVar {...props} />)
// Assert
expect(screen.getByText('specialVar')).toBeInTheDocument()
})
})
})


@ -0,0 +1,392 @@
import * as React from 'react'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import VarPicker, { type Props } from './var-picker'
// Mock external dependencies only
jest.mock('next/navigation', () => ({
useRouter: () => ({ push: jest.fn() }),
usePathname: () => '/test',
}))
type PortalToFollowElemProps = {
children: React.ReactNode
open?: boolean
onOpenChange?: (open: boolean) => void
}
type PortalToFollowElemTriggerProps = React.HTMLAttributes<HTMLElement> & { children?: React.ReactNode; asChild?: boolean }
type PortalToFollowElemContentProps = React.HTMLAttributes<HTMLDivElement> & { children?: React.ReactNode }
jest.mock('@/app/components/base/portal-to-follow-elem', () => {
const PortalContext = React.createContext({ open: false })
const PortalToFollowElem = ({ children, open }: PortalToFollowElemProps) => {
return (
<PortalContext.Provider value={{ open: !!open }}>
<div data-testid="portal">{children}</div>
</PortalContext.Provider>
)
}
const PortalToFollowElemContent = ({ children, ...props }: PortalToFollowElemContentProps) => {
const { open } = React.useContext(PortalContext)
if (!open) return null
return (
<div data-testid="portal-content" {...props}>
{children}
</div>
)
}
const PortalToFollowElemTrigger = ({ children, asChild, ...props }: PortalToFollowElemTriggerProps) => {
if (asChild && React.isValidElement(children)) {
return React.cloneElement(children, {
...props,
'data-testid': 'portal-trigger',
} as React.HTMLAttributes<HTMLElement>)
}
return (
<div data-testid="portal-trigger" {...props}>
{children}
</div>
)
}
return {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
}
})
describe('VarPicker', () => {
const mockOptions: Props['options'] = [
{ name: 'Variable 1', value: 'var1', type: 'string' },
{ name: 'Variable 2', value: 'var2', type: 'number' },
{ name: 'Variable 3', value: 'var3', type: 'boolean' },
]
const defaultProps: Props = {
value: 'var1',
options: mockOptions,
onChange: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should render variable picker with dropdown trigger', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
expect(screen.getByText('var1')).toBeInTheDocument()
})
it('should display selected variable with type icon when value is provided', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('var1')).toBeInTheDocument()
expect(screen.getByText('{{')).toBeInTheDocument()
expect(screen.getByText('}}')).toBeInTheDocument()
// IconTypeIcon should be rendered (check for svg icon)
expect(document.querySelector('svg')).toBeInTheDocument()
})
it('should show placeholder text when no value is selected', () => {
// Arrange
const props = {
...defaultProps,
value: undefined,
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.queryByText('var1')).not.toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
})
it('should display custom tip message when notSelectedVarTip is provided', () => {
// Arrange
const props = {
...defaultProps,
value: undefined,
notSelectedVarTip: 'Select a variable',
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('Select a variable')).toBeInTheDocument()
})
it('should render dropdown indicator icon', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<VarPicker {...props} />)
// Assert - Trigger should be present
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should apply custom className to wrapper', () => {
// Arrange
const props = {
...defaultProps,
className: 'custom-class',
}
// Act
const { container } = render(<VarPicker {...props} />)
// Assert
expect(container.querySelector('.custom-class')).toBeInTheDocument()
})
it('should apply custom triggerClassName to trigger button', () => {
// Arrange
const props = {
...defaultProps,
triggerClassName: 'custom-trigger-class',
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByTestId('portal-trigger')).toHaveClass('custom-trigger-class')
})
it('should display selected value with proper formatting', () => {
// Arrange
const props = {
...defaultProps,
value: 'customVar',
options: [
{ name: 'Custom Variable', value: 'customVar', type: 'string' },
],
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('customVar')).toBeInTheDocument()
expect(screen.getByText('{{')).toBeInTheDocument()
expect(screen.getByText('}}')).toBeInTheDocument()
})
})
// User Interactions
describe('User Interactions', () => {
it('should open dropdown when clicking the trigger button', async () => {
// Arrange
const onChange = jest.fn()
const props = { ...defaultProps, onChange }
const user = userEvent.setup()
// Act
render(<VarPicker {...props} />)
await user.click(screen.getByTestId('portal-trigger'))
// Assert
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
})
it('should call onChange and close dropdown when selecting an option', async () => {
// Arrange
const onChange = jest.fn()
const props = { ...defaultProps, onChange }
const user = userEvent.setup()
// Act
render(<VarPicker {...props} />)
// Open dropdown
await user.click(screen.getByTestId('portal-trigger'))
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
// Select a different option
const options = screen.getAllByText('var2')
expect(options.length).toBeGreaterThan(0)
await user.click(options[0])
// Assert
expect(onChange).toHaveBeenCalledWith('var2')
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
it('should toggle dropdown when clicking trigger button multiple times', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<VarPicker {...props} />)
const trigger = screen.getByTestId('portal-trigger')
// Open dropdown
await user.click(trigger)
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
// Close dropdown
await user.click(trigger)
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
})
// State Management
describe('State Management', () => {
it('should initialize with closed dropdown', () => {
// Arrange
const props = { ...defaultProps }
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
it('should toggle dropdown state on trigger click', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<VarPicker {...props} />)
const trigger = screen.getByTestId('portal-trigger')
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
// Open dropdown
await user.click(trigger)
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
// Close dropdown
await user.click(trigger)
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
it('should preserve selected value when dropdown is closed without selection', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<VarPicker {...props} />)
// Open and close dropdown without selecting anything
const trigger = screen.getByTestId('portal-trigger')
await user.click(trigger)
await user.click(trigger)
// Assert
expect(screen.getByText('var1')).toBeInTheDocument() // Original value still displayed
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle undefined value gracefully', () => {
// Arrange
const props = {
...defaultProps,
value: undefined,
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
})
it('should handle empty options array', () => {
// Arrange
const props = {
...defaultProps,
options: [],
value: undefined,
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
})
it('should handle null value without crashing', () => {
// Arrange - cast to exercise a null runtime value against the optional prop
const props = {
...defaultProps,
value: null as unknown as Props['value'],
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder')).toBeInTheDocument()
})
it('should handle variable names with special characters safely', () => {
// Arrange
const props = {
...defaultProps,
options: [
{ name: 'Variable with & < > " \' characters', value: 'specialVar', type: 'string' },
],
value: 'specialVar',
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('specialVar')).toBeInTheDocument()
})
it('should handle long variable names', () => {
// Arrange
const props = {
...defaultProps,
options: [
{ name: 'A very long variable name that should be truncated', value: 'longVar', type: 'string' },
],
value: 'longVar',
}
// Act
render(<VarPicker {...props} />)
// Assert
expect(screen.getByText('longVar')).toBeInTheDocument()
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
})
})
})


@ -51,12 +51,6 @@ const mockFiles: FileEntity[] = [
},
]
-jest.mock('react-i18next', () => ({
-useTranslation: () => ({
-t: (key: string) => key,
-}),
-}))
jest.mock('@/context/debug-configuration', () => ({
__esModule: true,
useDebugConfigurationContext: () => mockUseDebugConfigurationContext(),


@ -0,0 +1,347 @@
import { render, screen, within } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import AppCard from './index'
import type { AppIconType } from '@/types/app'
import { AppModeEnum } from '@/types/app'
import type { App } from '@/models/explore'
jest.mock('@heroicons/react/20/solid', () => ({
PlusIcon: ({ className }: { className?: string }) => <div data-testid="plus-icon" className={className} aria-label="Add icon">+</div>,
}))
const mockApp: App = {
app: {
id: 'test-app-id',
mode: AppModeEnum.CHAT,
icon_type: 'emoji' as AppIconType,
icon: '🤖',
icon_background: '#FFEAD5',
icon_url: '',
name: 'Test Chat App',
description: 'A test chat application for demonstration purposes',
use_icon_as_answer_icon: false,
},
app_id: 'test-app-id',
description: 'A comprehensive chat application template',
copyright: 'Test Corp',
privacy_policy: null,
custom_disclaimer: null,
category: 'Assistant',
position: 1,
is_listed: true,
install_count: 100,
installed: false,
editable: true,
is_agent: false,
}
describe('AppCard', () => {
const defaultProps = {
app: mockApp,
canCreate: true,
onCreate: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
})
describe('Rendering', () => {
it('should render without crashing', () => {
const { container } = render(<AppCard {...defaultProps} />)
expect(container.querySelector('em-emoji')).toBeInTheDocument()
expect(screen.getByText('Test Chat App')).toBeInTheDocument()
expect(screen.getByText(mockApp.description)).toBeInTheDocument()
})
it('should render app type icon and label', () => {
const { container } = render(<AppCard {...defaultProps} />)
expect(container.querySelector('svg')).toBeInTheDocument()
expect(screen.getByText('app.typeSelector.chatbot')).toBeInTheDocument()
})
})
describe('Props', () => {
describe('canCreate behavior', () => {
it('should show create button when canCreate is true', () => {
render(<AppCard {...defaultProps} canCreate={true} />)
const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ })
expect(button).toBeInTheDocument()
})
it('should hide create button when canCreate is false', () => {
render(<AppCard {...defaultProps} canCreate={false} />)
const button = screen.queryByRole('button', { name: /app\.newApp\.useTemplate/ })
expect(button).not.toBeInTheDocument()
})
})
it('should display app name from appBasicInfo', () => {
const customApp = {
...mockApp,
app: {
...mockApp.app,
name: 'Custom App Name',
},
}
render(<AppCard {...defaultProps} app={customApp} />)
expect(screen.getByText('Custom App Name')).toBeInTheDocument()
})
it('should display app description from app level', () => {
const customApp = {
...mockApp,
description: 'Custom description for the app',
}
render(<AppCard {...defaultProps} app={customApp} />)
expect(screen.getByText('Custom description for the app')).toBeInTheDocument()
})
it('should truncate long app names', () => {
const longNameApp = {
...mockApp,
app: {
...mockApp.app,
name: 'This is a very long app name that should be truncated with line-clamp-1',
},
}
render(<AppCard {...defaultProps} app={longNameApp} />)
const nameElement = screen.getByTitle('This is a very long app name that should be truncated with line-clamp-1')
expect(nameElement).toBeInTheDocument()
})
})
describe('App Modes - Data Driven Tests', () => {
const testCases = [
{
mode: AppModeEnum.CHAT,
expectedLabel: 'app.typeSelector.chatbot',
description: 'Chat application mode',
},
{
mode: AppModeEnum.AGENT_CHAT,
expectedLabel: 'app.typeSelector.agent',
description: 'Agent chat mode',
},
{
mode: AppModeEnum.COMPLETION,
expectedLabel: 'app.typeSelector.completion',
description: 'Completion mode',
},
{
mode: AppModeEnum.ADVANCED_CHAT,
expectedLabel: 'app.typeSelector.advanced',
description: 'Advanced chat mode',
},
{
mode: AppModeEnum.WORKFLOW,
expectedLabel: 'app.typeSelector.workflow',
description: 'Workflow mode',
},
]
testCases.forEach(({ mode, expectedLabel, description }) => {
it(`should display correct type label for ${description}`, () => {
const appWithMode = {
...mockApp,
app: {
...mockApp.app,
mode,
},
}
render(<AppCard {...defaultProps} app={appWithMode} />)
expect(screen.getByText(expectedLabel)).toBeInTheDocument()
})
})
})
describe('Icon Type Tests', () => {
it('should render emoji icon without image element', () => {
const appWithIcon = {
...mockApp,
app: {
...mockApp.app,
icon_type: 'emoji' as AppIconType,
icon: '🤖',
},
}
const { container } = render(<AppCard {...defaultProps} app={appWithIcon} />)
const card = container.firstElementChild as HTMLElement
expect(within(card).queryByRole('img', { name: 'app icon' })).not.toBeInTheDocument()
expect(card.querySelector('em-emoji')).toBeInTheDocument()
})
it('should prioritize icon_url when both icon and icon_url are provided', () => {
const appWithImageUrl = {
...mockApp,
app: {
...mockApp.app,
icon_type: 'image' as AppIconType,
icon: 'local-icon.png',
icon_url: 'https://example.com/remote-icon.png',
},
}
render(<AppCard {...defaultProps} app={appWithImageUrl} />)
expect(screen.getByRole('img', { name: 'app icon' })).toHaveAttribute('src', 'https://example.com/remote-icon.png')
})
})
describe('User Interactions', () => {
it('should call onCreate when create button is clicked', async () => {
const mockOnCreate = jest.fn()
render(<AppCard {...defaultProps} onCreate={mockOnCreate} />)
const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ })
await userEvent.click(button)
expect(mockOnCreate).toHaveBeenCalledTimes(1)
})
it('should handle click on card itself', async () => {
const mockOnCreate = jest.fn()
const { container } = render(<AppCard {...defaultProps} onCreate={mockOnCreate} />)
const card = container.firstElementChild as HTMLElement
await userEvent.click(card)
// Note: Card click doesn't trigger onCreate, only the button does
expect(mockOnCreate).not.toHaveBeenCalled()
})
})
describe('Keyboard Accessibility', () => {
it('should allow the create button to be focused', async () => {
const mockOnCreate = jest.fn()
render(<AppCard {...defaultProps} onCreate={mockOnCreate} />)
await userEvent.tab()
const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ }) as HTMLButtonElement
// Test that button can be focused
expect(button).toHaveFocus()
// Test click event works (keyboard events on buttons typically trigger click)
await userEvent.click(button)
expect(mockOnCreate).toHaveBeenCalledTimes(1)
})
})
describe('Edge Cases', () => {
it('should handle app with null icon_type', () => {
const appWithNullIcon = {
...mockApp,
app: {
...mockApp.app,
icon_type: null,
},
}
const { container } = render(<AppCard {...defaultProps} app={appWithNullIcon} />)
const appIcon = container.querySelector('em-emoji')
expect(appIcon).toBeInTheDocument()
// AppIcon component should handle null icon_type gracefully
})
it('should handle app with empty description', () => {
const appWithEmptyDesc = {
...mockApp,
description: '',
}
const { container } = render(<AppCard {...defaultProps} app={appWithEmptyDesc} />)
const descriptionContainer = container.querySelector('.line-clamp-3')
expect(descriptionContainer).toBeInTheDocument()
expect(descriptionContainer).toHaveTextContent('')
})
it('should handle app with very long description', () => {
const longDescription = 'This is a very long description that should be truncated with line-clamp-3. '.repeat(5)
const appWithLongDesc = {
...mockApp,
description: longDescription,
}
render(<AppCard {...defaultProps} app={appWithLongDesc} />)
expect(screen.getByText(/This is a very long description/)).toBeInTheDocument()
})
it('should handle app with special characters in name', () => {
const appWithSpecialChars = {
...mockApp,
app: {
...mockApp.app,
name: 'App <script>alert("test")</script> & Special "Chars"',
},
}
render(<AppCard {...defaultProps} app={appWithSpecialChars} />)
expect(screen.getByText('App <script>alert("test")</script> & Special "Chars"')).toBeInTheDocument()
})
it('should handle onCreate function throwing error', async () => {
const errorOnCreate = jest.fn(() => {
throw new Error('Create failed')
})
// Mock console.error to avoid test output noise
const consoleSpy = jest.spyOn(console, 'error').mockImplementation(jest.fn())
render(<AppCard {...defaultProps} onCreate={errorOnCreate} />)
const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ })
let capturedError: unknown
try {
await userEvent.click(button)
}
catch (err) {
capturedError = err
}
expect(errorOnCreate).toHaveBeenCalledTimes(1)
expect(consoleSpy).toHaveBeenCalled()
if (capturedError instanceof Error)
expect(capturedError.message).toContain('Create failed')
consoleSpy.mockRestore()
})
})
describe('Accessibility', () => {
it('should have proper elements for accessibility', () => {
const { container } = render(<AppCard {...defaultProps} />)
expect(container.querySelector('em-emoji')).toBeInTheDocument()
expect(container.querySelector('svg')).toBeInTheDocument()
})
it('should have title attribute for app name when truncated', () => {
render(<AppCard {...defaultProps} />)
const nameElement = screen.getByText('Test Chat App')
expect(nameElement).toHaveAttribute('title', 'Test Chat App')
})
it('should have accessible button with proper label', () => {
render(<AppCard {...defaultProps} />)
const button = screen.getByRole('button', { name: /app\.newApp\.useTemplate/ })
expect(button).toBeEnabled()
expect(button).toHaveTextContent('app.newApp.useTemplate')
})
})
describe('User-Visible Behavior Tests', () => {
it('should show plus icon in create button', () => {
render(<AppCard {...defaultProps} />)
expect(screen.getByTestId('plus-icon')).toBeInTheDocument()
})
})
})


@ -15,6 +15,7 @@ export type AppCardProps = {
const AppCard = ({
app,
canCreate,
+onCreate,
}: AppCardProps) => {
const { t } = useTranslation()
@ -45,14 +46,16 @@ const AppCard = ({
{app.description}
</div>
</div>
-<div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
-<div className={cn('flex h-8 w-full items-center space-x-2')}>
-<Button variant='primary' className='grow' onClick={() => onCreate()}>
-<PlusIcon className='mr-1 h-4 w-4' />
-<span className='text-xs'>{t('app.newApp.useTemplate')}</span>
-</Button>
-</div>
-</div>
+{canCreate && (
+<div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
+<div className={cn('flex h-8 w-full items-center space-x-2')}>
+<Button variant='primary' className='grow' onClick={() => onCreate()}>
+<PlusIcon className='mr-1 h-4 w-4' />
+<span className='text-xs'>{t('app.newApp.useTemplate')}</span>
+</Button>
+</div>
+</div>
+)}
</div>
)
}
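The gated overlay in the diff above reduces to a simple rule: no `canCreate`, no create action. A plain-TypeScript sketch of that gate (`buildCardActions` is a hypothetical helper for illustration, not code from this PR):

```typescript
// Sketch of the canCreate gate: the "use template" action only exists when
// creation is allowed. buildCardActions is an illustrative helper, not the
// component's real API.
type Action = { label: string; onClick: () => void }

function buildCardActions(canCreate: boolean, onCreate: () => void): Action[] {
  if (!canCreate)
    return [] // the hover overlay (and its button) is not rendered at all

  return [{ label: 'app.newApp.useTemplate', onClick: onCreate }]
}

const none = buildCardActions(false, () => {})
const actions = buildCardActions(true, () => {})
```

Gating the whole overlay, rather than disabling the button, means tests for the `canCreate: false` case should query with `queryByRole` and expect no match at all.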

View File

@@ -0,0 +1,287 @@
import { fireEvent, render, screen } from '@testing-library/react'
import CreateAppTemplateDialog from './index'
// Mock external dependencies (not base components)
jest.mock('./app-list', () => {
return function MockAppList({
onCreateFromBlank,
onSuccess,
}: {
onCreateFromBlank?: () => void
onSuccess: () => void
}) {
return (
<div data-testid="app-list">
<button data-testid="app-list-success" onClick={onSuccess}>
Success
</button>
{onCreateFromBlank && (
<button data-testid="create-from-blank" onClick={onCreateFromBlank}>
Create from Blank
</button>
)}
</div>
)
}
})
jest.mock('ahooks', () => ({
useKeyPress: jest.fn((_key: string, _callback: () => void) => {
// Default no-op mock; individual tests override this implementation
// to capture the registered callback
return jest.fn()
}),
}))
describe('CreateAppTemplateDialog', () => {
const defaultProps = {
show: false,
onSuccess: jest.fn(),
onClose: jest.fn(),
onCreateFromBlank: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
})
describe('Rendering', () => {
it('should not render when show is false', () => {
render(<CreateAppTemplateDialog {...defaultProps} />)
// FullScreenModal should not render any content when open is false
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
})
it('should render modal when show is true', () => {
render(<CreateAppTemplateDialog {...defaultProps} show={true} />)
// FullScreenModal renders with role="dialog"
expect(screen.getByRole('dialog')).toBeInTheDocument()
expect(screen.getByTestId('app-list')).toBeInTheDocument()
})
it('should render create from blank button when onCreateFromBlank is provided', () => {
render(<CreateAppTemplateDialog {...defaultProps} show={true} />)
expect(screen.getByTestId('create-from-blank')).toBeInTheDocument()
})
it('should not render create from blank button when onCreateFromBlank is not provided', () => {
const { onCreateFromBlank, ...propsWithoutOnCreate } = defaultProps
render(<CreateAppTemplateDialog {...propsWithoutOnCreate} show={true} />)
expect(screen.queryByTestId('create-from-blank')).not.toBeInTheDocument()
})
})
describe('Props', () => {
it('should pass show prop to FullScreenModal', () => {
const { rerender } = render(<CreateAppTemplateDialog {...defaultProps} />)
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
rerender(<CreateAppTemplateDialog {...defaultProps} show={true} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
})
it('should pass closable prop to FullScreenModal', () => {
// Since the FullScreenModal is always rendered with closable=true
// we can verify that the modal renders with the proper structure
render(<CreateAppTemplateDialog {...defaultProps} show={true} />)
// Verify that the modal has the proper dialog structure
const dialog = screen.getByRole('dialog')
expect(dialog).toBeInTheDocument()
expect(dialog).toHaveAttribute('aria-modal', 'true')
})
})
describe('User Interactions', () => {
it('should handle close interactions', () => {
const mockOnClose = jest.fn()
render(<CreateAppTemplateDialog {...defaultProps} show={true} onClose={mockOnClose} />)
// Test that the modal is rendered
const dialog = screen.getByRole('dialog')
expect(dialog).toBeInTheDocument()
// Test that AppList component renders (child component interactions)
expect(screen.getByTestId('app-list')).toBeInTheDocument()
expect(screen.getByTestId('app-list-success')).toBeInTheDocument()
})
it('should call both onSuccess and onClose when app list success is triggered', () => {
const mockOnSuccess = jest.fn()
const mockOnClose = jest.fn()
render(<CreateAppTemplateDialog
{...defaultProps}
show={true}
onSuccess={mockOnSuccess}
onClose={mockOnClose}
/>)
fireEvent.click(screen.getByTestId('app-list-success'))
expect(mockOnSuccess).toHaveBeenCalledTimes(1)
expect(mockOnClose).toHaveBeenCalledTimes(1)
})
it('should call onCreateFromBlank when create from blank is clicked', () => {
const mockOnCreateFromBlank = jest.fn()
render(<CreateAppTemplateDialog
{...defaultProps}
show={true}
onCreateFromBlank={mockOnCreateFromBlank}
/>)
fireEvent.click(screen.getByTestId('create-from-blank'))
expect(mockOnCreateFromBlank).toHaveBeenCalledTimes(1)
})
})
describe('useKeyPress Integration', () => {
it('should set up ESC key listener when modal is shown', () => {
const { useKeyPress } = require('ahooks')
render(<CreateAppTemplateDialog {...defaultProps} show={true} />)
expect(useKeyPress).toHaveBeenCalledWith('esc', expect.any(Function))
})
it('should handle ESC key press to close modal', () => {
const { useKeyPress } = require('ahooks')
let capturedCallback: (() => void) | undefined
useKeyPress.mockImplementation((key: string, callback: () => void) => {
if (key === 'esc')
capturedCallback = callback
return jest.fn()
})
const mockOnClose = jest.fn()
render(<CreateAppTemplateDialog
{...defaultProps}
show={true}
onClose={mockOnClose}
/>)
expect(capturedCallback).toBeDefined()
expect(typeof capturedCallback).toBe('function')
// Simulate ESC key press
capturedCallback?.()
expect(mockOnClose).toHaveBeenCalledTimes(1)
})
it('should not call onClose when ESC key is pressed and modal is not shown', () => {
const { useKeyPress } = require('ahooks')
let capturedCallback: (() => void) | undefined
useKeyPress.mockImplementation((key: string, callback: () => void) => {
if (key === 'esc')
capturedCallback = callback
return jest.fn()
})
const mockOnClose = jest.fn()
render(<CreateAppTemplateDialog
{...defaultProps}
show={false} // Modal not shown
onClose={mockOnClose}
/>)
// The callback should still be created but not execute onClose
expect(capturedCallback).toBeDefined()
// Simulate ESC key press
capturedCallback?.()
// onClose should not be called because modal is not shown
expect(mockOnClose).not.toHaveBeenCalled()
})
})
describe('Callback Dependencies', () => {
it('should create stable callback reference for ESC key handler', () => {
const { useKeyPress } = require('ahooks')
render(<CreateAppTemplateDialog {...defaultProps} show={true} />)
// Verify that useKeyPress was called with a function
const calls = useKeyPress.mock.calls
expect(calls.length).toBeGreaterThan(0)
expect(calls[0][0]).toBe('esc')
expect(typeof calls[0][1]).toBe('function')
})
})
describe('Edge Cases', () => {
it('should handle null props gracefully', () => {
expect(() => {
render(<CreateAppTemplateDialog
show={true}
onSuccess={jest.fn()}
onClose={jest.fn()}
// onCreateFromBlank is undefined
/>)
}).not.toThrow()
})
it('should handle undefined props gracefully', () => {
expect(() => {
render(<CreateAppTemplateDialog
show={true}
onSuccess={jest.fn()}
onClose={jest.fn()}
onCreateFromBlank={undefined}
/>)
}).not.toThrow()
})
it('should handle rapid show/hide toggles', () => {
// Test initial state
const { unmount } = render(<CreateAppTemplateDialog {...defaultProps} show={false} />)
unmount()
// Test show state
render(<CreateAppTemplateDialog {...defaultProps} show={true} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
// Test hide state
render(<CreateAppTemplateDialog {...defaultProps} show={false} />)
// Due to transition animations, we just verify the component handles the prop change
expect(() => render(<CreateAppTemplateDialog {...defaultProps} show={false} />)).not.toThrow()
})
it('should handle missing optional onCreateFromBlank prop', () => {
const { onCreateFromBlank, ...propsWithoutOnCreate } = defaultProps
expect(() => {
render(<CreateAppTemplateDialog {...propsWithoutOnCreate} show={true} />)
}).not.toThrow()
expect(screen.getByTestId('app-list')).toBeInTheDocument()
expect(screen.queryByTestId('create-from-blank')).not.toBeInTheDocument()
})
it('should work with all required props only', () => {
const requiredProps = {
show: true,
onSuccess: jest.fn(),
onClose: jest.fn(),
}
expect(() => {
render(<CreateAppTemplateDialog {...requiredProps} />)
}).not.toThrow()
expect(screen.getByRole('dialog')).toBeInTheDocument()
expect(screen.getByTestId('app-list')).toBeInTheDocument()
})
})
})
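The ESC-key tests above use a callback-capture technique: the mocked hook records whatever handler the component registers, so the test can invoke it directly instead of synthesizing a keyboard event. A self-contained sketch of that technique (no jest; `useKeyPressMock` and `registerEscClose` are illustrative stand-ins):

```typescript
// Self-contained sketch of callback capture: the mock stores each registered
// handler so a test can fire it on demand. Names are hypothetical, not the
// real component's code.
type KeyHandler = () => void
const handlers = new Map<string, KeyHandler>()

// Stand-in for the mocked useKeyPress: capture instead of binding a listener.
function useKeyPressMock(key: string, cb: KeyHandler): void {
  handlers.set(key, cb)
}

// Mirrors the guarded handler under test: only close while the modal is shown.
function registerEscClose(show: boolean, onClose: () => void): void {
  useKeyPressMock('esc', () => {
    if (show)
      onClose()
  })
}

let closeCount = 0
registerEscClose(true, () => { closeCount++ })
handlers.get('esc')?.() // simulated ESC press closes the shown modal

registerEscClose(false, () => { closeCount++ })
handlers.get('esc')?.() // guard blocks the close when the modal is hidden
```

The guard inside the handler is what the `show={false}` test exercises: the callback is still registered, but pressing ESC must not call `onClose`.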

View File

@@ -0,0 +1,209 @@
import type { RenderOptions } from '@testing-library/react'
import { fireEvent, render } from '@testing-library/react'
import { defaultPlan } from '@/app/components/billing/config'
import { noop } from 'lodash-es'
import type { ModalContextState } from '@/context/modal-context'
import APIKeyInfoPanel from './index'
// Mock the modules before importing the functions
jest.mock('@/context/provider-context', () => ({
useProviderContext: jest.fn(),
}))
jest.mock('@/context/modal-context', () => ({
useModalContext: jest.fn(),
}))
import { useProviderContext as actualUseProviderContext } from '@/context/provider-context'
import { useModalContext as actualUseModalContext } from '@/context/modal-context'
// Type casting for mocks
const mockUseProviderContext = actualUseProviderContext as jest.MockedFunction<typeof actualUseProviderContext>
const mockUseModalContext = actualUseModalContext as jest.MockedFunction<typeof actualUseModalContext>
// Default mock data
const defaultProviderContext = {
modelProviders: [],
refreshModelProviders: noop,
textGenerationModelList: [],
supportRetrievalMethods: [],
isAPIKeySet: false,
plan: defaultPlan,
isFetchedPlan: false,
enableBilling: false,
onPlanInfoChanged: noop,
enableReplaceWebAppLogo: false,
modelLoadBalancingEnabled: false,
datasetOperatorEnabled: false,
enableEducationPlan: false,
isEducationWorkspace: false,
isEducationAccount: false,
allowRefreshEducationVerify: false,
educationAccountExpireAt: null,
isLoadingEducationAccountInfo: false,
isFetchingEducationAccountInfo: false,
webappCopyrightEnabled: false,
licenseLimit: {
workspace_members: {
size: 0,
limit: 0,
},
},
refreshLicenseLimit: noop,
isAllowTransferWorkspace: false,
isAllowPublishAsCustomKnowledgePipelineTemplate: false,
}
const defaultModalContext: ModalContextState = {
setShowAccountSettingModal: noop,
setShowApiBasedExtensionModal: noop,
setShowModerationSettingModal: noop,
setShowExternalDataToolModal: noop,
setShowPricingModal: noop,
setShowAnnotationFullModal: noop,
setShowModelModal: noop,
setShowExternalKnowledgeAPIModal: noop,
setShowModelLoadBalancingModal: noop,
setShowOpeningModal: noop,
setShowUpdatePluginModal: noop,
setShowEducationExpireNoticeModal: noop,
setShowTriggerEventsLimitModal: noop,
}
export type MockOverrides = {
providerContext?: Partial<typeof defaultProviderContext>
modalContext?: Partial<typeof defaultModalContext>
}
export type APIKeyInfoPanelRenderOptions = {
mockOverrides?: MockOverrides
} & Omit<RenderOptions, 'wrapper'>
// Setup function to configure mocks
export function setupMocks(overrides: MockOverrides = {}) {
mockUseProviderContext.mockReturnValue({
...defaultProviderContext,
...overrides.providerContext,
})
mockUseModalContext.mockReturnValue({
...defaultModalContext,
...overrides.modalContext,
})
}
// Custom render function
export function renderAPIKeyInfoPanel(options: APIKeyInfoPanelRenderOptions = {}) {
const { mockOverrides, ...renderOptions } = options
setupMocks(mockOverrides)
return render(<APIKeyInfoPanel />, renderOptions)
}
// Helper functions for common test scenarios
export const scenarios = {
// Render with API key not set (default)
withAPIKeyNotSet: (overrides: MockOverrides = {}) =>
renderAPIKeyInfoPanel({
mockOverrides: {
providerContext: { isAPIKeySet: false },
...overrides,
},
}),
// Render with API key already set
withAPIKeySet: (overrides: MockOverrides = {}) =>
renderAPIKeyInfoPanel({
mockOverrides: {
providerContext: { isAPIKeySet: true },
...overrides,
},
}),
// Render with mock modal function
withMockModal: (mockSetShowAccountSettingModal: jest.Mock, overrides: MockOverrides = {}) =>
renderAPIKeyInfoPanel({
mockOverrides: {
modalContext: { setShowAccountSettingModal: mockSetShowAccountSettingModal },
...overrides,
},
}),
}
// Common test assertions
export const assertions = {
// Should render main button
shouldRenderMainButton: () => {
const button = document.querySelector('button.btn-primary')
expect(button).toBeInTheDocument()
return button
},
// Should not render at all
shouldNotRender: (container: HTMLElement) => {
expect(container.firstChild).toBeNull()
},
// Should have correct panel styling
shouldHavePanelStyling: (panel: HTMLElement) => {
expect(panel).toHaveClass(
'border-components-panel-border',
'bg-components-panel-bg',
'relative',
'mb-6',
'rounded-2xl',
'border',
'p-8',
'shadow-md',
)
},
// Should have close button
shouldHaveCloseButton: (container: HTMLElement) => {
const closeButton = container.querySelector('.absolute.right-4.top-4')
expect(closeButton).toBeInTheDocument()
expect(closeButton).toHaveClass('cursor-pointer')
return closeButton
},
}
// Common user interactions
export const interactions = {
// Click the main button
clickMainButton: () => {
const button = document.querySelector('button.btn-primary')
if (button) fireEvent.click(button)
return button
},
// Click the close button
clickCloseButton: (container: HTMLElement) => {
const closeButton = container.querySelector('.absolute.right-4.top-4')
if (closeButton) fireEvent.click(closeButton)
return closeButton
},
}
// Text content keys for assertions
export const textKeys = {
selfHost: {
titleRow1: /appOverview\.apiKeyInfo\.selfHost\.title\.row1/,
titleRow2: /appOverview\.apiKeyInfo\.selfHost\.title\.row2/,
setAPIBtn: /appOverview\.apiKeyInfo\.setAPIBtn/,
tryCloud: /appOverview\.apiKeyInfo\.tryCloud/,
},
cloud: {
trialTitle: /appOverview\.apiKeyInfo\.cloud\.trial\.title/,
trialDescription: /appOverview\.apiKeyInfo\.cloud\.trial\.description/,
setAPIBtn: /appOverview\.apiKeyInfo\.setAPIBtn/,
},
}
// Setup and cleanup utilities
export function clearAllMocks() {
jest.clearAllMocks()
}
// Export mock functions for external access
export { mockUseProviderContext, mockUseModalContext, defaultModalContext }
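The `setupMocks` helper above leans on a defaults-plus-overrides merge: a shallow spread lets each test replace only the fields it cares about while keeping every other context field at its default. A trimmed sketch of that merge (`ProviderCtx` is an illustrative two-field type, not the real context shape):

```typescript
// Sketch of the defaults + overrides merge setupMocks performs. A shallow
// spread means later keys win, so tests override only what they need.
// ProviderCtx is a trimmed illustrative type, not the real provider context.
type ProviderCtx = { isAPIKeySet: boolean; enableBilling: boolean }

const defaults: ProviderCtx = { isAPIKeySet: false, enableBilling: false }

function withOverrides(overrides: Partial<ProviderCtx> = {}): ProviderCtx {
  return { ...defaults, ...overrides } // later spread wins on conflicts
}

const ctx = withOverrides({ isAPIKeySet: true })
const untouched = withOverrides()
```

Note the merge is shallow: overriding a nested object like `licenseLimit` replaces the whole object, so overrides for nested fields must supply the complete sub-shape.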

View File

@@ -0,0 +1,122 @@
import { cleanup, screen } from '@testing-library/react'
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
import {
assertions,
clearAllMocks,
defaultModalContext,
interactions,
mockUseModalContext,
scenarios,
textKeys,
} from './apikey-info-panel.test-utils'
// Mock config for Cloud edition
jest.mock('@/config', () => ({
IS_CE_EDITION: false, // Test Cloud edition
}))
afterEach(cleanup)
describe('APIKeyInfoPanel - Cloud Edition', () => {
const mockSetShowAccountSettingModal = jest.fn()
beforeEach(() => {
clearAllMocks()
mockUseModalContext.mockReturnValue({
...defaultModalContext,
setShowAccountSettingModal: mockSetShowAccountSettingModal,
})
})
describe('Rendering', () => {
it('should render without crashing when API key is not set', () => {
scenarios.withAPIKeyNotSet()
assertions.shouldRenderMainButton()
})
it('should not render when API key is already set', () => {
const { container } = scenarios.withAPIKeySet()
assertions.shouldNotRender(container)
})
it('should not render when panel is hidden by user', () => {
const { container } = scenarios.withAPIKeyNotSet()
interactions.clickCloseButton(container)
assertions.shouldNotRender(container)
})
})
describe('Cloud Edition Content', () => {
it('should display cloud version title', () => {
scenarios.withAPIKeyNotSet()
expect(screen.getByText(textKeys.cloud.trialTitle)).toBeInTheDocument()
})
it('should display emoji for cloud version', () => {
const { container } = scenarios.withAPIKeyNotSet()
expect(container.querySelector('em-emoji')).toBeInTheDocument()
expect(container.querySelector('em-emoji')).toHaveAttribute('id', '😀')
})
it('should display cloud version description', () => {
scenarios.withAPIKeyNotSet()
expect(screen.getByText(textKeys.cloud.trialDescription)).toBeInTheDocument()
})
it('should not render external link for cloud version', () => {
const { container } = scenarios.withAPIKeyNotSet()
expect(container.querySelector('a[href="https://cloud.dify.ai/apps"]')).not.toBeInTheDocument()
})
it('should display set API button text', () => {
scenarios.withAPIKeyNotSet()
expect(screen.getByText(textKeys.cloud.setAPIBtn)).toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should call setShowAccountSettingModal when set API button is clicked', () => {
scenarios.withMockModal(mockSetShowAccountSettingModal)
interactions.clickMainButton()
expect(mockSetShowAccountSettingModal).toHaveBeenCalledWith({
payload: ACCOUNT_SETTING_TAB.PROVIDER,
})
})
it('should hide panel when close button is clicked', () => {
const { container } = scenarios.withAPIKeyNotSet()
expect(container.firstChild).toBeInTheDocument()
interactions.clickCloseButton(container)
assertions.shouldNotRender(container)
})
})
describe('Props and Styling', () => {
it('should render button with primary variant', () => {
scenarios.withAPIKeyNotSet()
const button = screen.getByRole('button')
expect(button).toHaveClass('btn-primary')
})
it('should render panel container with correct classes', () => {
const { container } = scenarios.withAPIKeyNotSet()
const panel = container.firstChild as HTMLElement
assertions.shouldHavePanelStyling(panel)
})
})
describe('Accessibility', () => {
it('should have button with proper role', () => {
scenarios.withAPIKeyNotSet()
expect(screen.getByRole('button')).toBeInTheDocument()
})
it('should have clickable close button', () => {
const { container } = scenarios.withAPIKeyNotSet()
assertions.shouldHaveCloseButton(container)
})
})
})
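Mocking `@/config` at the top of the file pins `IS_CE_EDITION` for the whole module, so each test file covers exactly one side of the edition switch. The branch itself can be sketched with the i18n keys from the test utils (`panelTitleKey` is a hypothetical helper, not the component's real API):

```typescript
// Sketch of the edition branch these test files split between them: the
// panel title key depends on whether this is a self-hosted (CE) build.
// panelTitleKey is an illustrative helper, not the component's actual code.
function panelTitleKey(isCeEdition: boolean): string {
  return isCeEdition
    ? 'appOverview.apiKeyInfo.selfHost.title.row1' // self-host build
    : 'appOverview.apiKeyInfo.cloud.trial.title' // cloud build
}

const cloudKey = panelTitleKey(false)
const selfHostKey = panelTitleKey(true)
```

Because `jest.mock` is hoisted and module-level, flipping the flag per test case is impractical; one file per edition is the simpler design.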

Some files were not shown because too many files have changed in this diff.