Compare commits

..

23 Commits

Author SHA1 Message Date
yyh
142f94e27a Merge remote-tracking branch 'origin/main' into codex/dify-ui-package-migration 2026-04-03 12:14:22 +08:00
608958de1c refactor: select in external_knowledge_service (#34493)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 03:42:16 +00:00
7eb632eb34 refactor: select in rag_pipeline (#34495)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 03:42:01 +00:00
33d4fd357c refactor: select in account_service (AccountService class) (#34496) 2026-04-03 03:41:46 +00:00
e55bd61c17 refactor: replace useContext with use in selected batch (#34450) 2026-04-03 03:37:35 +00:00
f2fc213d52 chore: update deps (#34487) 2026-04-03 03:26:49 +00:00
f814579ed2 test: migrate service_api dataset controller tests to testcontainers (#34423)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 02:28:47 +00:00
71d299d0d3 refactor(api): type hit testing retrieve responses with TypedDict (#34484) 2026-04-03 02:25:30 +00:00
e178451d04 refactor(api): type log identity dict with IdentityDict TypedDict (#34485) 2026-04-03 02:25:02 +00:00
9a6222f245 refactor(api): type webhook data extraction with RawWebhookDataDict TypedDict (#34486) 2026-04-03 02:24:17 +00:00
affe5ed30b refactor(api): type get_knowledge_rate_limit with KnowledgeRateLimitD… (#34483) 2026-04-03 02:23:32 +00:00
4cc5401d7e fix: fix import dsl failed (#34492)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 02:08:21 +00:00
36e840cd87 chore: knip fix (#34481)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-02 15:03:42 +00:00
985b41c40b fix(security): add tenant_id validation to prevent IDOR in data source binding (#34456)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-02 13:17:02 +00:00
lif
2e29ac2829 fix: remove redundant cast in MCP base session (#34461)
Signed-off-by: majiayu000 <1835304752@qq.com>
2026-04-02 12:36:21 +00:00
dbfb474eab refactor: select in workflow_tools_manage_service (#34477) 2026-04-02 12:35:04 +00:00
d243de26ec refactor: select in metadata_service (#34479) 2026-04-02 12:34:38 +00:00
894826771a chore: clean up useless tailwind reference (#34478) 2026-04-02 11:45:19 +00:00
yyh
a1bd929b3c remove 2026-04-02 18:35:02 +08:00
yyh
ffb9ee3e36 fix(web): support lint tooling package exports 2026-04-02 18:29:44 +08:00
yyh
485586f49a feat(web): extract dify ui package 2026-04-02 18:25:16 +08:00
a3386da5d6 ci: Update pyrefly version to 0.59.1 (#34452)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-02 09:48:46 +00:00
99
318a3d0308 refactor(api): tighten login and wrapper typing (#34447) 2026-04-02 09:36:58 +00:00
563 changed files with 4179 additions and 8443 deletions

View File

@ -89,6 +89,12 @@ if $web_modified; then
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
fi
echo "Running knip"
if ! pnpm run knip; then
echo "Knip check failed. Please run 'pnpm run knip' to fix the errors."
exit 1
fi
echo "Running unit tests check"
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true)

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@ -71,7 +71,7 @@ class AppImportApi(Resource):
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = current_user

View File

@ -193,7 +193,7 @@ workflow_draft_variable_list_model = console_ns.model(
)
def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -210,7 +210,7 @@ def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]:
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f)
def wrapper(*args: Any, **kwargs: Any):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
return f(*args, **kwargs)
return wrapper

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import overload
from sqlalchemy import select
@ -23,14 +23,30 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model
def get_app_model(
view: Callable[..., Any] | None = None,
@overload
def get_app_model[**P, R](
view: Callable[P, R],
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
) -> Callable[P, R]: ...
@overload
def get_app_model[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: Any, **kwargs: Any):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
@ -68,14 +84,30 @@ def get_app_model(
return decorator(view)
def get_app_model_with_trial(
view: Callable[..., Any] | None = None,
@overload
def get_app_model_with_trial[**P, R](
view: Callable[P, R],
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
) -> Callable[P, R]: ...
@overload
def get_app_model_with_trial[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model_with_trial[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: Any, **kwargs: Any):
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@ -158,10 +158,11 @@ class DataSourceApi(Resource):
@login_required
@account_initialization_required
def patch(self, binding_id, action: Literal["enable", "disable"]):
_, current_tenant_id = current_account_with_tenant()
binding_id = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
).scalar_one_or_none()
if data_source_binding is None:
raise NotFound("Data source binding not found.")

View File

@ -1,4 +1,5 @@
import logging
from collections.abc import Callable
from typing import Any, NoReturn
from flask import Response, request
@ -55,7 +56,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _api_prerequisite(f):
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -70,7 +71,7 @@ def _api_prerequisite(f):
@login_required
@account_initialization_required
@get_rag_pipeline
def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)

View File

@ -1,9 +1,10 @@
import inspect
import logging
import time
from collections.abc import Callable
from enum import StrEnum, auto
from functools import wraps
from typing import Any, cast, overload
from typing import cast, overload
from flask import current_app, request
from flask_login import user_logged_in
@ -230,94 +231,73 @@ def cloud_edition_billing_rate_limit_check[**P, R](
return interceptor
def validate_dataset_token(
view: Callable[..., Any] | None = None,
) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(view_func)
def decorated(*args: Any, **kwargs: Any) -> Any:
api_token = validate_and_get_api_token("dataset")
def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
positional_parameters = [
parameter
for parameter in inspect.signature(view).parameters.values()
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
]
expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"})
# get url path dataset_id from positional args or kwargs
# Flask passes URL path parameters as positional arguments
dataset_id = None
@wraps(view)
def decorated(*args: object, **kwargs: object) -> R:
api_token = validate_and_get_api_token("dataset")
# First try to get from kwargs (explicit parameter)
dataset_id = kwargs.get("dataset_id")
# Flask may pass URL path parameters positionally, so inspect both kwargs and args.
dataset_id = kwargs.get("dataset_id")
# If not in kwargs, try to extract from positional args
if not dataset_id and args:
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
# Check if first arg is likely a class instance (has __dict__ or __class__)
if len(args) > 1 and hasattr(args[0], "__dict__"):
# This is a class method, dataset_id should be in args[1]
potential_id = args[1]
# Validate it's a string-like UUID, not another object
try:
# Try to convert to string and check if it's a valid UUID format
str_id = str(potential_id)
# Basic check: UUIDs are 36 chars with hyphens
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from class method args")
elif len(args) > 0:
# Not a class method, check if args[0] looks like a UUID
potential_id = args[0]
try:
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from positional args")
if not dataset_id and args:
potential_id = args[0]
try:
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from positional args")
# Validate dataset if dataset_id is provided
if dataset_id:
dataset_id = str(dataset_id)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
.limit(1)
if dataset_id:
dataset_id = str(dataset_id)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
tenant_account_join = db.session.execute(
select(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
).one_or_none() # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.get(Account, ta.account_id)
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
.limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.")
tenant_account_join = db.session.execute(
select(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
).one_or_none() # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.get(Account, ta.account_id)
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
else:
raise Unauthorized("Tenant does not exist.")
if args and isinstance(args[0], Resource):
return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs)
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
return view_func(api_token.tenant_id, *args, **kwargs)
if expects_bound_instance:
if not args:
raise TypeError("validate_dataset_token expected a bound resource instance.")
return view(args[0], api_token.tenant_id, *args[1:], **kwargs)
return decorated
return view(api_token.tenant_id, *args, **kwargs)
if view:
return decorator(view)
# if view is None, it means that the decorator is used without parentheses
# use the decorator as a function for method_decorators
return decorator
return decorated
def validate_and_get_api_token(scope: str | None = None):

View File

@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound, RequestEntityTooLarge
from controllers.trigger import bp
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
from services.trigger.webhook_service import WebhookService
from services.trigger.webhook_service import RawWebhookDataDict, WebhookService
logger = logging.getLogger(__name__)
@ -23,6 +23,7 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
webhook_id, is_debug=is_debug
)
webhook_data: RawWebhookDataDict
try:
# Use new unified extraction and validation
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)

View File

@ -3,13 +3,19 @@
import logging
import traceback
from datetime import UTC, datetime
from typing import Any
from typing import Any, TypedDict
import orjson
from configs import dify_config
class IdentityDict(TypedDict, total=False):
tenant_id: str
user_id: str
user_type: str
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
@ -84,7 +90,7 @@ class StructuredJSONFormatter(logging.Formatter):
return log_dict
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
def _extract_identity(self, record: logging.LogRecord) -> IdentityDict | None:
tenant_id = getattr(record, "tenant_id", None)
user_id = getattr(record, "user_id", None)
user_type = getattr(record, "user_type", None)
@ -92,7 +98,7 @@ class StructuredJSONFormatter(logging.Formatter):
if not any([tenant_id, user_id, user_type]):
return None
identity: dict[str, str] = {}
identity: IdentityDict = {}
if tenant_id:
identity["tenant_id"] = tenant_id
if user_id:

View File

@ -4,7 +4,7 @@ from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from datetime import timedelta
from types import TracebackType
from typing import Any, Self, cast
from typing import Any, Self
from httpx import HTTPStatusError
from pydantic import BaseModel
@ -338,12 +338,11 @@ class BaseSession[
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
validated_request = cast(ReceiveRequestT, validated_request)
responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request,
request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
@ -359,15 +358,14 @@ class BaseSession[
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
notification = cast(ReceiveNotificationT, notification)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
self._in_flight[cancelled_id].cancel()
else:
self._received_notification(notification)
self._handle_incoming(notification)
self._received_notification(notification) # type: ignore[arg-type]
self._handle_incoming(notification) # type: ignore[arg-type]
except Exception as e:
# For other validation errors, log and continue
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)

View File

@ -1,5 +1,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from flask import Flask
if TYPE_CHECKING:
from extensions.ext_login import DifyLoginManager
class DifyApp(Flask):
pass
"""Flask application type with Dify-specific extension attributes."""
login_manager: DifyLoginManager

View File

@ -1,7 +1,8 @@
import json
from typing import cast
import flask_login
from flask import Response, request
from flask import Request, Response, request
from flask_login import user_loaded_from_request, user_logged_in
from sqlalchemy import select
from werkzeug.exceptions import NotFound, Unauthorized
@ -16,13 +17,35 @@ from models import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser
from services.account_service import AccountService
login_manager = flask_login.LoginManager()
type LoginUser = Account | EndUser
class DifyLoginManager(flask_login.LoginManager):
"""Project-specific Flask-Login manager with a stable unauthorized contract.
Dify registers `unauthorized_handler` below to always return a JSON `Response`.
Overriding this method lets callers rely on that narrower return type instead of
Flask-Login's broader callback contract.
"""
def unauthorized(self) -> Response:
"""Return the registered unauthorized handler result as a Flask `Response`."""
return cast(Response, super().unauthorized())
def load_user_from_request_context(self) -> None:
"""Populate Flask-Login's request-local user cache for the current request."""
self._load_user()
login_manager = DifyLoginManager()
# Flask-Login configuration
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
def load_user_from_request(request_from_flask_login: Request) -> LoginUser | None:
"""Load user based on the request."""
del request_from_flask_login
# Skip authentication for documentation endpoints
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
return None
@ -100,10 +123,12 @@ def load_user_from_request(request_from_flask_login):
raise NotFound("End user not found.")
return end_user
return None
@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_logged_in(_sender, user):
def on_user_logged_in(_sender: object, user: LoginUser) -> None:
"""Called when a user logged in.
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
@ -114,8 +139,10 @@ def on_user_logged_in(_sender, user):
@login_manager.unauthorized_handler
def unauthorized_handler():
def unauthorized_handler() -> Response:
"""Handle unauthorized requests."""
# Keep this as a concrete `Response`; `DifyLoginManager.unauthorized()` narrows
# Flask-Login's callback contract based on this override.
return Response(
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
status=401,
@ -123,5 +150,5 @@ def unauthorized_handler():
)
def init_app(app: DifyApp):
def init_app(app: DifyApp) -> None:
login_manager.init_app(app)

View File

@ -2,19 +2,19 @@ from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from flask import current_app, g, has_request_context, request
from flask import Response, current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
from werkzeug.local import LocalProxy
from configs import dify_config
from dify_app import DifyApp
from extensions.ext_login import DifyLoginManager
from libs.token import check_csrf_token
from models import Account
if TYPE_CHECKING:
from flask.typing import ResponseReturnValue
from models.model import EndUser
@ -29,7 +29,13 @@ def _resolve_current_user() -> EndUser | Account | None:
return get_current_object() if callable(get_current_object) else user_proxy # type: ignore
def current_account_with_tenant():
def _get_login_manager() -> DifyLoginManager:
"""Return the project login manager with Dify's narrowed unauthorized contract."""
app = cast(DifyApp, current_app)
return app.login_manager
def current_account_with_tenant() -> tuple[Account, str]:
"""
Resolve the underlying account for the current user proxy and ensure tenant context exists.
Allows tests to supply plain Account mocks without the LocalProxy helper.
@ -42,7 +48,7 @@ def current_account_with_tenant():
return user, user.current_tenant_id
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]:
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]:
"""
If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are
@ -77,13 +83,16 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu
"""
@wraps(func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | Response:
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
return current_app.ensure_sync(func)(*args, **kwargs)
user = _resolve_current_user()
if user is None or not user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
# `DifyLoginManager` guarantees that the registered unauthorized handler
# is surfaced here as a concrete Flask `Response`.
unauthorized_response: Response = _get_login_manager().unauthorized()
return unauthorized_response
g._login_user = user
# we put csrf validation here for less conflicts
# TODO: maybe find a better place for it.
@ -96,7 +105,7 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu
def _get_user() -> EndUser | Account | None:
if has_request_context():
if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore
_get_login_manager().load_user_from_request_context()
return g._login_user

View File

@ -171,7 +171,7 @@ dev = [
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.57.1",
"pyrefly>=0.59.1",
]
############################################################

View File

@ -144,22 +144,26 @@ class AccountService:
@staticmethod
def load_user(user_id: str) -> None | Account:
account = db.session.query(Account).filter_by(id=user_id).first()
account = db.session.get(Account, user_id)
if not account:
return None
if account.status == AccountStatus.BANNED:
raise Unauthorized("Account is banned.")
current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
current_tenant = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
.limit(1)
)
if current_tenant:
account.set_tenant_id(current_tenant.tenant_id)
else:
available_ta = (
db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
available_ta = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
.limit(1)
)
if not available_ta:
return None
@ -195,7 +199,7 @@ class AccountService:
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
"""authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first()
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
raise AccountPasswordError("Invalid email or password.")
@ -371,8 +375,10 @@ class AccountService:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: AccountIntegrate | None = (
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
account_integrate: AccountIntegrate | None = db.session.scalar(
select(AccountIntegrate)
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
.limit(1)
)
if account_integrate:
@ -416,7 +422,9 @@ class AccountService:
def update_account_email(account: Account, email: str) -> Account:
"""Update account email"""
account.email = email
account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first()
account_integrate = db.session.scalar(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
)
if account_integrate:
db.session.delete(account_integrate)
db.session.add(account)
@ -818,7 +826,7 @@ class AccountService:
)
)
account = db.session.query(Account).where(Account.email == email).first()
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
return None
@ -1018,7 +1026,7 @@ class AccountService:
@staticmethod
def check_email_unique(email: str) -> bool:
return db.session.query(Account).filter_by(email=email).first() is None
return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
class TenantService:

View File

@ -32,6 +32,11 @@ class SubscriptionPlan(TypedDict):
expiration_date: int
class KnowledgeRateLimitDict(TypedDict):
limit: int
subscription_plan: str
class BillingService:
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
@ -58,7 +63,7 @@ class BillingService:
return usage_info
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict:
params = {"tenant_id": tenant_id}
knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)

View File

@ -5,7 +5,7 @@ from urllib.parse import urlparse
import httpx
from graphon.nodes.http_request.exc import InvalidHttpMethodError
from sqlalchemy import select
from sqlalchemy import func, select
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
@ -103,8 +103,10 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str, tenant_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@ -112,8 +114,10 @@ class ExternalDatasetService:
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@ -132,8 +136,10 @@ class ExternalDatasetService:
@staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None:
raise ValueError("api template not found")
@ -144,9 +150,12 @@ class ExternalDatasetService:
@staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
db.session.scalar(
select(func.count(ExternalKnowledgeBindings.id)).where(
ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id
)
)
or 0
)
if count > 0:
return True, count
@ -154,8 +163,10 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: ExternalKnowledgeBindings | None = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
external_knowledge_binding: ExternalKnowledgeBindings | None = db.session.scalar(
select(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
.limit(1)
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
@ -163,8 +174,10 @@ class ExternalDatasetService:
@staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("api template not found")
@ -238,12 +251,17 @@ class ExternalDatasetService:
@staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
if db.session.scalar(
select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1)
):
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(
ExternalKnowledgeApis.id == args.get("external_knowledge_api_id"),
ExternalKnowledgeApis.tenant_id == tenant_id,
)
.limit(1)
)
if external_knowledge_api is None:
@ -286,16 +304,18 @@ class ExternalDatasetService:
external_retrieval_parameters: dict,
metadata_condition: MetadataCondition | None = None,
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
.limit(1)
)
if not external_knowledge_binding:
raise ValueError("external knowledge binding not found")
external_knowledge_api = (
db.session.query(ExternalKnowledgeApis)
.filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:
raise ValueError("external api template not found")

View File

@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any
from typing import Any, TypedDict
from graphon.model_runtime.entities import LLMMode
@ -18,6 +18,16 @@ from models.enums import CreatorUserRole, DatasetQuerySource
logger = logging.getLogger(__name__)
class QueryDict(TypedDict):
content: str
class RetrieveResponseDict(TypedDict):
query: QueryDict
records: list[dict[str, Any]]
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
@ -150,7 +160,7 @@ class HitTestingService:
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
@classmethod
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]:
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> RetrieveResponseDict:
records = RetrievalService.format_retrieval_documents(documents)
return {
@ -161,7 +171,7 @@ class HitTestingService:
}
@classmethod
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]:
def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> RetrieveResponseDict:
records = []
if dataset.provider == "external":
for document in documents:

View File

@ -1,6 +1,8 @@
import copy
import logging
from sqlalchemy import delete, func, select
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -25,10 +27,14 @@ class MetadataService:
raise ValueError("Metadata name cannot exceed 255 characters.")
current_user, current_tenant_id = current_account_with_tenant()
# check if metadata name already exists
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
if db.session.scalar(
select(DatasetMetadata)
.where(
DatasetMetadata.tenant_id == current_tenant_id,
DatasetMetadata.dataset_id == dataset_id,
DatasetMetadata.name == metadata_args.name,
)
.limit(1)
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
@ -54,10 +60,14 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
current_user, current_tenant_id = current_account_with_tenant()
if (
db.session.query(DatasetMetadata)
.filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=name)
.first()
if db.session.scalar(
select(DatasetMetadata)
.where(
DatasetMetadata.tenant_id == current_tenant_id,
DatasetMetadata.dataset_id == dataset_id,
DatasetMetadata.name == name,
)
.limit(1)
):
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
@ -65,7 +75,11 @@ class MetadataService:
raise ValueError("Metadata name already exists in Built-in fields.")
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
.limit(1)
)
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
@ -74,9 +88,9 @@ class MetadataService:
metadata.updated_at = naive_utc_now()
# update related documents
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
dataset_metadata_bindings = db.session.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@ -101,15 +115,19 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first()
metadata = db.session.scalar(
select(DatasetMetadata)
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
.limit(1)
)
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)
# deal related documents
dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
dataset_metadata_bindings = db.session.scalars(
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
@ -224,16 +242,23 @@ class MetadataService:
# deal metadata binding (in the same transaction as the doc_metadata update)
if not operation.partial_update:
db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
db.session.execute(
delete(DatasetMetadataBinding).where(
DatasetMetadataBinding.document_id == operation.document_id
)
)
current_user, current_tenant_id = current_account_with_tenant()
for metadata_value in operation.metadata_list:
# check if binding already exists
if operation.partial_update:
existing_binding = (
db.session.query(DatasetMetadataBinding)
.filter_by(document_id=operation.document_id, metadata_id=metadata_value.id)
.first()
existing_binding = db.session.scalar(
select(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.document_id == operation.document_id,
DatasetMetadataBinding.metadata_id == metadata_value.id,
)
.limit(1)
)
if existing_binding:
continue
@ -275,9 +300,13 @@ class MetadataService:
"id": item.get("id"),
"name": item.get("name"),
"type": item.get("type"),
"count": db.session.query(DatasetMetadataBinding)
.filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
.count(),
"count": db.session.scalar(
select(func.count(DatasetMetadataBinding.id)).where(
DatasetMetadataBinding.metadata_id == item.get("id"),
DatasetMetadataBinding.dataset_id == dataset.id,
)
)
or 0,
}
for item in dataset.doc_metadata or []
if item.get("id") != "built-in"

View File

@ -156,27 +156,27 @@ class RagPipelineService:
:param template_id: template id
:param template_info: template info
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
.limit(1)
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
# check template name is exist
template_name = template_info.name
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
template = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
PipelineCustomizedTemplate.id != template_id,
)
.first()
.limit(1)
)
if template:
raise ValueError("Template name is already exists")
@ -192,13 +192,13 @@ class RagPipelineService:
"""
Delete customized pipeline template.
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
customized_template: PipelineCustomizedTemplate | None = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
.first()
.limit(1)
)
if not customized_template:
raise ValueError("Customized pipeline template not found.")
@ -210,14 +210,14 @@ class RagPipelineService:
Get draft workflow
"""
# fetch draft workflow by rag pipeline
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
)
.first()
.limit(1)
)
# return draft workflow
@ -232,28 +232,28 @@ class RagPipelineService:
return None
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == pipeline.workflow_id,
)
.first()
.limit(1)
)
return workflow
def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""Fetch a published workflow snapshot by ID for restore operations."""
workflow = (
db.session.query(Workflow)
workflow = db.session.scalar(
select(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == workflow_id,
)
.first()
.limit(1)
)
if workflow and workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError("source workflow must be published")
@ -974,7 +974,7 @@ class RagPipelineService:
if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE:
document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
if document_id:
document = db.session.query(Document).where(Document.id == document_id.value).first()
document = db.session.get(Document, document_id.value)
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = error
@ -1178,12 +1178,12 @@ class RagPipelineService:
"""
Publish customized pipeline template
"""
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
pipeline = db.session.get(Pipeline, pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
if not pipeline.workflow_id:
raise ValueError("Pipeline workflow not found")
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
workflow = db.session.get(Workflow, pipeline.workflow_id)
if not workflow:
raise ValueError("Workflow not found")
with Session(db.engine) as session:
@ -1194,21 +1194,21 @@ class RagPipelineService:
# check template name is exist
template_name = args.get("name")
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
template = db.session.scalar(
select(PipelineCustomizedTemplate)
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
)
.first()
.limit(1)
)
if template:
raise ValueError("Template name is already exists")
max_position = (
db.session.query(func.max(PipelineCustomizedTemplate.position))
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
.scalar()
max_position = db.session.scalar(
select(func.max(PipelineCustomizedTemplate.position)).where(
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id
)
)
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@ -1239,13 +1239,14 @@ class RagPipelineService:
def is_workflow_exist(self, pipeline: Pipeline) -> bool:
return (
db.session.query(Workflow)
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
db.session.scalar(
select(func.count(Workflow.id)).where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
)
)
.count()
or 0
) > 0
def get_node_last_run(
@ -1353,11 +1354,11 @@ class RagPipelineService:
def get_recommended_plugins(self, type: str) -> dict:
# Query active recommended plugins
query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
if type and type != "all":
query = query.where(PipelineRecommendedPlugin.type == type)
stmt = stmt.where(PipelineRecommendedPlugin.type == type)
pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all()
pipeline_recommended_plugins = db.session.scalars(stmt.order_by(PipelineRecommendedPlugin.position.asc())).all()
if not pipeline_recommended_plugins:
return {
@ -1396,14 +1397,12 @@ class RagPipelineService:
"""
Retry error document
"""
document_pipeline_execution_log = (
db.session.query(DocumentPipelineExecutionLog)
.where(DocumentPipelineExecutionLog.document_id == document.id)
.first()
document_pipeline_execution_log = db.session.scalar(
select(DocumentPipelineExecutionLog).where(DocumentPipelineExecutionLog.document_id == document.id).limit(1)
)
if not document_pipeline_execution_log:
raise ValueError("Document pipeline execution log not found")
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first()
pipeline = db.session.get(Pipeline, document_pipeline_execution_log.pipeline_id)
if not pipeline:
raise ValueError("Pipeline not found")
# convert to app config
@ -1432,23 +1431,23 @@ class RagPipelineService:
"""
Get datasource plugins
"""
dataset: Dataset | None = (
db.session.query(Dataset)
dataset: Dataset | None = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
pipeline: Pipeline | None = db.session.scalar(
select(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not pipeline:
raise ValueError("Pipeline not found")
@ -1530,23 +1529,23 @@ class RagPipelineService:
"""
Get pipeline
"""
dataset: Dataset | None = (
db.session.query(Dataset)
dataset: Dataset | None = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = (
db.session.query(Pipeline)
pipeline: Pipeline | None = db.session.scalar(
select(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
.limit(1)
)
if not pipeline:
raise ValueError("Pipeline not found")

View File

@ -3,7 +3,7 @@ import logging
from datetime import datetime
from graphon.model_runtime.utils.encoders import jsonable_encoder
from sqlalchemy import or_, select
from sqlalchemy import delete, or_, select
from sqlalchemy.orm import Session
from core.tools.__base.tool_provider import ToolProviderController
@ -42,20 +42,22 @@ class WorkflowToolManageService:
labels: list[str] | None = None,
):
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
)
.first()
.limit(1)
)
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)
)
if app is None:
raise ValueError(f"App {workflow_app_id} not found")
@ -122,30 +124,30 @@ class WorkflowToolManageService:
:return: the updated tool
"""
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
existing_workflow_tool_provider = db.session.scalar(
select(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
)
.first()
.limit(1)
)
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} already exists")
workflow_tool_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
if workflow_tool_provider is None:
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App | None = (
db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
app: App | None = db.session.scalar(
select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1)
)
if app is None:
@ -234,9 +236,11 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
"""
db.session.query(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
).delete()
db.session.execute(
delete(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
)
)
db.session.commit()
@ -251,10 +255,10 @@ class WorkflowToolManageService:
:param workflow_tool_id: the workflow tool id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)
@ -267,10 +271,10 @@ class WorkflowToolManageService:
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
.limit(1)
)
return cls._get_workflow_tool(tenant_id, db_tool)
@ -284,8 +288,8 @@ class WorkflowToolManageService:
if db_tool is None:
raise ValueError("Tool not found")
workflow_app: App | None = (
db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
workflow_app: App | None = db.session.scalar(
select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1)
)
if workflow_app is None:
@ -331,10 +335,10 @@ class WorkflowToolManageService:
:param workflow_tool_id: the workflow tool id
:return: the list of tools
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
db_tool: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
.limit(1)
)
if db_tool is None:

View File

@ -3,7 +3,7 @@ import logging
import mimetypes
import secrets
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from typing import Any, TypedDict
import orjson
from flask import request
@ -50,6 +50,14 @@ logger = logging.getLogger(__name__)
_file_access_controller = DatabaseFileAccessController()
class RawWebhookDataDict(TypedDict):
method: str
headers: dict[str, str]
query_params: dict[str, str]
body: dict[str, Any]
files: dict[str, Any]
class WebhookService:
"""Service for handling webhook operations."""
@ -145,7 +153,7 @@ class WebhookService:
@classmethod
def extract_and_validate_webhook_data(
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict
) -> dict[str, Any]:
) -> RawWebhookDataDict:
"""Extract and validate webhook data in a single unified process.
Args:
@ -173,7 +181,7 @@ class WebhookService:
return processed_data
@classmethod
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> dict[str, Any]:
def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> RawWebhookDataDict:
"""Extract raw data from incoming webhook request without type conversion.
Args:
@ -189,7 +197,7 @@ class WebhookService:
"""
cls._validate_content_length()
data = {
data: RawWebhookDataDict = {
"method": request.method,
"headers": dict(request.headers),
"query_params": dict(request.args),
@ -223,7 +231,7 @@ class WebhookService:
return data
@classmethod
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
def _process_and_validate_data(cls, raw_data: RawWebhookDataDict, node_data: WebhookData) -> RawWebhookDataDict:
"""Process and validate webhook data according to node configuration.
Args:
@ -664,7 +672,7 @@ class WebhookService:
raise ValueError(f"Required header missing: {header_name}")
@classmethod
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
def _validate_http_metadata(cls, webhook_data: RawWebhookDataDict, node_data: WebhookData) -> dict[str, Any]:
"""Validate HTTP method and content-type.
Args:
@ -729,7 +737,7 @@ class WebhookService:
return False
@classmethod
def build_workflow_inputs(cls, webhook_data: dict[str, Any]) -> dict[str, Any]:
def build_workflow_inputs(cls, webhook_data: RawWebhookDataDict) -> dict[str, Any]:
"""Construct workflow inputs payload from webhook data.
Args:
@ -747,7 +755,7 @@ class WebhookService:
@classmethod
def trigger_workflow_execution(
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: RawWebhookDataDict, workflow: Workflow
) -> None:
"""Trigger workflow execution via AsyncWorkflowService.

View File

@ -20,7 +20,7 @@ def app():
app = Flask(__name__)
app.config["TESTING"] = True
app.config["RESTX_MASK_HEADER"] = "X-Fields"
app.login_manager = SimpleNamespace(_load_user=lambda: None)
app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
return app

View File

@ -12,7 +12,7 @@ from models.account import Account, TenantAccountRole
def app():
flask_app = Flask(__name__)
flask_app.config["TESTING"] = True
flask_app.login_manager = SimpleNamespace(_load_user=lambda: None)
flask_app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
return flask_app

View File

@ -0,0 +1,17 @@
import json
from flask import Response
from extensions.ext_login import unauthorized_handler
def test_unauthorized_handler_returns_json_response() -> None:
response = unauthorized_handler()
assert isinstance(response, Response)
assert response.status_code == 401
assert response.content_type == "application/json"
assert json.loads(response.get_data(as_text=True)) == {
"code": "unauthorized",
"message": "Unauthorized.",
}

View File

@ -2,11 +2,12 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask, g
from flask_login import LoginManager, UserMixin
from flask import Flask, Response, g
from flask_login import UserMixin
from pytest_mock import MockerFixture
import libs.login as login_module
from extensions.ext_login import DifyLoginManager
from libs.login import current_user
from models.account import Account
@ -39,9 +40,12 @@ def login_app(mocker: MockerFixture) -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
login_manager = LoginManager()
login_manager = DifyLoginManager()
login_manager.init_app(app)
login_manager.unauthorized = mocker.Mock(name="unauthorized", return_value="Unauthorized")
login_manager.unauthorized = mocker.Mock(
name="unauthorized",
return_value=Response("Unauthorized", status=401, content_type="application/json"),
)
@login_manager.user_loader
def load_user(_user_id: str):
@ -109,18 +113,43 @@ class TestLoginRequired:
resolved_user: MockUser | None,
description: str,
):
"""Test that missing or unauthenticated users are redirected."""
"""Test that missing or unauthenticated users return the manager response."""
resolve_user = resolve_current_user(resolved_user)
with login_app.test_request_context():
result = protected_view()
assert result == "Unauthorized", description
assert result is login_app.login_manager.unauthorized.return_value, description
assert isinstance(result, Response)
assert result.status_code == 401
resolve_user.assert_called_once_with()
login_app.login_manager.unauthorized.assert_called_once_with()
csrf_check.assert_not_called()
def test_unauthorized_access_propagates_response_object(
self,
login_app: Flask,
protected_view,
csrf_check: MagicMock,
resolve_current_user,
mocker: MockerFixture,
) -> None:
"""Test that unauthorized responses are propagated as Flask Response objects."""
resolve_user = resolve_current_user(None)
response = Response("Unauthorized", status=401, content_type="application/json")
mocker.patch.object(
login_module, "_get_login_manager", return_value=SimpleNamespace(unauthorized=lambda: response)
)
with login_app.test_request_context():
result = protected_view()
assert result is response
assert isinstance(result, Response)
resolve_user.assert_called_once_with()
csrf_check.assert_not_called()
@pytest.mark.parametrize(
("method", "login_disabled"),
[
@ -168,10 +197,14 @@ class TestGetUser:
"""Test that _get_user loads user if not already in g."""
mock_user = MockUser("test_user")
def _load_user() -> None:
def load_user_from_request_context() -> None:
g._login_user = mock_user
load_user = mocker.patch.object(login_app.login_manager, "_load_user", side_effect=_load_user)
load_user = mocker.patch.object(
login_app.login_manager,
"load_user_from_request_context",
side_effect=load_user_from_request_context,
)
with login_app.test_request_context():
user = login_module._get_user()

View File

@ -401,10 +401,7 @@ class TestMetadataServiceCreateMetadata:
metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name="category", metadata_type="string")
# Mock query to return None (no existing metadata with same name)
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = None
# Mock BuiltInField enum iteration
with patch("services.metadata_service.BuiltInField") as mock_builtin:
@ -417,10 +414,6 @@ class TestMetadataServiceCreateMetadata:
assert result is not None
assert isinstance(result, DatasetMetadata)
# Verify query was made to check for duplicates
mock_db_session.query.assert_called()
mock_query.filter_by.assert_called()
# Verify metadata was added and committed
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
@ -468,10 +461,7 @@ class TestMetadataServiceCreateMetadata:
# Mock existing metadata with same name
existing_metadata = MetadataTestDataFactory.create_metadata_mock(name="category")
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_metadata
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = existing_metadata
# Act & Assert
with pytest.raises(ValueError, match="Metadata name already exists"):
@ -500,10 +490,7 @@ class TestMetadataServiceCreateMetadata:
)
# Mock query to return None (no duplicate in database)
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = None
# Mock BuiltInField to include the conflicting name
with patch("services.metadata_service.BuiltInField") as mock_builtin:
@ -597,27 +584,11 @@ class TestMetadataServiceUpdateMetadataName:
existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category")
# Mock query for duplicate check (no duplicate)
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
# Mock metadata retrieval
def query_side_effect(model):
if model == DatasetMetadata:
mock_meta_query = Mock()
mock_meta_query.filter_by.return_value = mock_meta_query
mock_meta_query.first.return_value = existing_metadata
return mock_meta_query
return mock_query
mock_db_session.query.side_effect = query_side_effect
# Mock scalar calls: first for duplicate check (None), second for metadata retrieval
mock_db_session.scalar.side_effect = [None, existing_metadata]
# Mock no metadata bindings (no documents to update)
mock_binding_query = Mock()
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.all.return_value = []
mock_db_session.scalars.return_value.all.return_value = []
# Mock BuiltInField enum
with patch("services.metadata_service.BuiltInField") as mock_builtin:
@ -655,22 +626,8 @@ class TestMetadataServiceUpdateMetadataName:
metadata_id = "non-existent-metadata"
new_name = "updated_category"
# Mock query for duplicate check (no duplicate)
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
# Mock metadata retrieval to return None
def query_side_effect(model):
if model == DatasetMetadata:
mock_meta_query = Mock()
mock_meta_query.filter_by.return_value = mock_meta_query
mock_meta_query.first.return_value = None # Not found
return mock_meta_query
return mock_query
mock_db_session.query.side_effect = query_side_effect
# Mock scalar calls: first for duplicate check (None), second for metadata retrieval (None = not found)
mock_db_session.scalar.side_effect = [None, None]
# Mock BuiltInField enum
with patch("services.metadata_service.BuiltInField") as mock_builtin:
@ -746,15 +703,10 @@ class TestMetadataServiceDeleteMetadata:
existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category")
# Mock metadata retrieval
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_metadata
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = existing_metadata
# Mock no metadata bindings (no documents to update)
mock_binding_query = Mock()
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.all.return_value = []
mock_db_session.scalars.return_value.all.return_value = []
# Act
result = MetadataService.delete_metadata(dataset_id, metadata_id)
@ -788,10 +740,7 @@ class TestMetadataServiceDeleteMetadata:
metadata_id = "non-existent-metadata"
# Mock metadata retrieval to return None
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="Metadata not found"):
@ -1013,10 +962,7 @@ class TestMetadataServiceGetDatasetMetadatas:
)
# Mock usage count queries
mock_query = Mock()
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 5 # 5 documents use this metadata
mock_db_session.query.return_value = mock_query
mock_db_session.scalar.return_value = 5 # 5 documents use this metadata
# Act
result = MetadataService.get_dataset_metadatas(dataset)

View File

@ -292,7 +292,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
"""
api = Mock(spec=ExternalKnowledgeApis)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
mock_db_session.scalar.return_value = api
result = ExternalDatasetService.get_external_knowledge_api("api-id", "tenant-id")
assert result is api
@ -302,7 +302,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
When the record is absent, a ``ValueError`` is raised.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.get_external_knowledge_api("missing-id", "tenant-id")
@ -320,7 +320,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
existing_api = Mock(spec=ExternalKnowledgeApis)
existing_api.settings_dict = {"api_key": "stored-key"}
existing_api.settings = '{"api_key":"stored-key"}'
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_api
mock_db_session.scalar.return_value = existing_api
args = {
"name": "New Name",
@ -340,7 +340,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
Updating a nonexistent API template should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.update_external_knowledge_api(
@ -356,7 +356,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
"""
api = Mock(spec=ExternalKnowledgeApis)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = api
mock_db_session.scalar.return_value = api
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "api-1")
@ -368,7 +368,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
Deletion of a missing template should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.delete_external_knowledge_api("tenant-1", "missing")
@ -394,7 +394,7 @@ class TestExternalDatasetServiceUsageAndBindings:
When there are bindings, ``external_knowledge_api_use_check`` returns True and count.
"""
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 3
mock_db_session.scalar.return_value = 3
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
@ -406,7 +406,7 @@ class TestExternalDatasetServiceUsageAndBindings:
Zero bindings should return ``(False, 0)``.
"""
mock_db_session.query.return_value.filter_by.return_value.count.return_value = 0
mock_db_session.scalar.return_value = 0
in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1")
@ -419,7 +419,7 @@ class TestExternalDatasetServiceUsageAndBindings:
"""
binding = Mock(spec=ExternalKnowledgeBindings)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = binding
mock_db_session.scalar.return_value = binding
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
assert result is binding
@ -429,7 +429,7 @@ class TestExternalDatasetServiceUsageAndBindings:
Missing binding should result in a ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="external knowledge binding not found"):
ExternalDatasetService.get_external_knowledge_binding_with_dataset_id("tenant-1", "ds-1")
@ -460,7 +460,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
)
# Raw string; the service itself calls json.loads on it
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
mock_db_session.scalar.return_value = external_api
process_parameter = {"foo": "value", "bar": "optional"}
@ -474,7 +474,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
When the referenced API template is missing, a ``ValueError`` is raised.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="api template not found"):
ExternalDatasetService.document_create_args_validate("tenant-1", "missing", {})
@ -488,7 +488,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
external_api.settings = (
'[{"document_process_setting":[{"name":"foo","required":true},{"name":"bar","required":false}]}]'
)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = external_api
mock_db_session.scalar.return_value = external_api
process_parameter = {"bar": "present"} # missing "foo"
@ -702,7 +702,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
}
# No existing dataset with same name.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
None, # duplicatename check
Mock(spec=ExternalKnowledgeApis), # external knowledge api
]
@ -724,7 +724,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
"""
existing_dataset = Mock(spec=Dataset)
mock_db_session.query.return_value.filter_by.return_value.first.return_value = existing_dataset
mock_db_session.scalar.return_value = existing_dataset
args = {
"name": "Existing",
@ -744,7 +744,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
"""
# First call: duplicate name check not found.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
None,
None, # external knowledge api lookup
]
@ -763,8 +763,10 @@ class TestExternalDatasetServiceCreateExternalDataset:
``external_knowledge_id`` and ``external_knowledge_api_id`` are mandatory.
"""
# duplicate name check
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
# duplicate name check — two calls to create_external_dataset, each does 2 scalar calls
mock_db_session.scalar.side_effect = [
None,
Mock(spec=ExternalKnowledgeApis),
None,
Mock(spec=ExternalKnowledgeApis),
]
@ -826,7 +828,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
# First query: binding; second query: api.
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
binding,
api,
]
@ -861,7 +863,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
Missing binding should raise ``ValueError``.
"""
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(ValueError, match="external knowledge binding not found"):
ExternalDatasetService.fetch_external_knowledge_retrieval(
@ -878,7 +880,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
"""
binding = ExternalDatasetTestDataFactory.create_external_binding()
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
binding,
None,
]
@ -901,7 +903,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
api = Mock(spec=ExternalKnowledgeApis)
api.settings = '{"endpoint":"https://example.com","api_key":"secret"}'
mock_db_session.query.return_value.filter_by.return_value.first.side_effect = [
mock_db_session.scalar.side_effect = [
binding,
api,
]

View File

@ -117,9 +117,7 @@ def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_serv
def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service) -> None:
first_query = mocker.Mock()
first_query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=first_query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Dataset not found"):
rag_pipeline_service.get_pipeline("tenant-1", "dataset-1")
@ -131,12 +129,8 @@ def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service
def test_update_customized_pipeline_template_success(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
# First query finds the template, second query (duplicate check) returns None
query_mock_1 = mocker.Mock()
query_mock_1.where.return_value.first.return_value = template
query_mock_2 = mocker.Mock()
query_mock_2.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", side_effect=[query_mock_1, query_mock_2])
# First scalar finds the template, second scalar (duplicate check) returns None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, None])
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@ -152,9 +146,7 @@ def test_update_customized_pipeline_template_success(mocker) -> None:
def test_update_customized_pipeline_template_not_found(mocker) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="x", description="d", icon_info=IconInfo(icon="i"))
@ -166,9 +158,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
duplicate = SimpleNamespace(name="dup")
query_mock = mocker.Mock()
query_mock.where.return_value.first.side_effect = [template, duplicate]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, duplicate])
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
info = PipelineTemplateInfoEntity(name="dup", description="d", icon_info=IconInfo(icon="i"))
@ -181,9 +171,7 @@ def test_update_customized_pipeline_template_duplicate_name(mocker) -> None:
def test_delete_customized_pipeline_template_success(mocker) -> None:
template = SimpleNamespace(id="tpl-1")
query_mock = mocker.Mock()
query_mock.where.return_value.first.return_value = template
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
delete_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.delete")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@ -196,9 +184,7 @@ def test_delete_customized_pipeline_template_success(mocker) -> None:
def test_delete_customized_pipeline_template_not_found(mocker) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
with pytest.raises(ValueError, match="Customized pipeline template not found"):
@ -397,18 +383,14 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) -
def test_is_workflow_exist_returns_true_when_draft_exists(mocker, rag_pipeline_service) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.count.return_value = 1
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=1)
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
assert rag_pipeline_service.is_workflow_exist(pipeline) is True
def test_is_workflow_exist_returns_false_when_no_draft(mocker, rag_pipeline_service) -> None:
query_mock = mocker.Mock()
query_mock.where.return_value.count.return_value = 0
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query_mock)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=0)
pipeline = SimpleNamespace(tenant_id="t1", id="p1")
assert rag_pipeline_service.is_workflow_exist(pipeline) is False
@ -738,8 +720,7 @@ def test_get_second_step_parameters_success(mocker, rag_pipeline_service) -> Non
def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_service) -> None:
from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate
from models.workflow import Workflow
from models.dataset import Pipeline
# 1. Setup mocks
pipeline = mocker.Mock(spec=Pipeline)
@ -754,36 +735,15 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi
# Mock db itself to avoid app context errors
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
# Improved mocking for session.query
def mock_query_side_effect(model):
m = mocker.Mock()
if model == Pipeline:
m.where.return_value.first.return_value = pipeline
elif model == Workflow:
m.where.return_value.first.return_value = workflow
elif model == PipelineCustomizedTemplate:
m.where.return_value.first.return_value = None
elif model == Dataset:
m.where.return_value.first.return_value = mocker.Mock()
else:
# For func.max cases
m.where.return_value.scalar.return_value = 5
m.where.return_value.first.return_value = mocker.Mock()
return m
mock_db.session.query.side_effect = mock_query_side_effect
# Mock get() for Pipeline and Workflow PK lookups
mock_db.session.get.side_effect = [pipeline, workflow]
# Mock scalar() for template name check (None) and max position (5)
mock_db.session.scalar.side_effect = [None, 5]
# Mock retrieve_dataset
dataset = mocker.Mock()
pipeline.retrieve_dataset.return_value = dataset
# Mock max position
mocker.patch("services.rag_pipeline.rag_pipeline.func.max", return_value=1)
mocker.patch(
"services.rag_pipeline.rag_pipeline.db.session.query.return_value.where.return_value.scalar",
return_value=5,
)
# Mock RagPipelineDslService
mock_dsl_service = mocker.Mock()
mock_dsl_service.export_rag_pipeline_dsl.return_value = {"dsl": "content"}
@ -839,9 +799,7 @@ def test_get_datasource_plugins_success(mocker, rag_pipeline_service) -> None:
workflow.rag_pipeline_variables = []
# Mock queries
mock_query = mocker.Mock()
mock_query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
@ -881,11 +839,9 @@ def test_retry_error_document_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock()
# Mock queries
mock_query = mocker.Mock()
# Log lookup, then Pipeline lookup
mock_query.where.return_value.first.side_effect = [log, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=mock_query)
# Mock queries: Log lookup via scalar, Pipeline lookup via get
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
@ -913,7 +869,7 @@ def test_set_datasource_variables_success(mocker, rag_pipeline_service) -> None:
# Mock db aggressively
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.engine = mocker.Mock()
mock_db.session.query.return_value.where.return_value.first.return_value = mocker.Mock()
mock_db.session.scalar.return_value = mocker.Mock()
pipeline = mocker.Mock(spec=Pipeline)
pipeline.id = "p-1"
@ -976,7 +932,7 @@ def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock(spec=Workflow)
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value.where.return_value.first.return_value = workflow
mock_db.session.scalar.return_value = workflow
# 2. Run test
result = rag_pipeline_service.get_draft_workflow(pipeline)
@ -998,7 +954,7 @@ def test_get_published_workflow_success(mocker, rag_pipeline_service) -> None:
workflow = mocker.Mock(spec=Workflow)
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value.where.return_value.first.return_value = workflow
mock_db.session.scalar.return_value = workflow
# 2. Run test
result = rag_pipeline_service.get_published_workflow(pipeline)
@ -1319,11 +1275,8 @@ def test_get_rag_pipeline_workflow_run_node_executions_returns_sorted_executions
def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = []
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query
mock_db.session.scalars.return_value.all.return_value = []
result = rag_pipeline_service.get_recommended_plugins("all")
@ -1336,11 +1289,8 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra
def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_pipeline_service) -> None:
plugin_a = SimpleNamespace(plugin_id="plugin-a")
plugin_b = SimpleNamespace(plugin_id="plugin-b")
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = [plugin_a, plugin_b]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query
mock_db.session.scalars.return_value.all.return_value = [plugin_a, plugin_b]
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
mocker.patch(
"services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools",
@ -1568,9 +1518,7 @@ def test_get_second_step_parameters_filters_first_step_variables(mocker, rag_pip
def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Document pipeline execution log not found"):
rag_pipeline_service.retry_error_document(
@ -1581,9 +1529,7 @@ def test_retry_error_document_raises_when_execution_log_not_found(mocker, rag_pi
def test_get_datasource_plugins_raises_when_workflow_not_found(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1", tenant_id="t1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
with pytest.raises(ValueError, match="Pipeline or workflow not found"):
@ -1656,8 +1602,7 @@ def test_handle_node_run_result_marks_document_error_for_published_invoke(mocker
document = SimpleNamespace(indexing_status="waiting", error=None)
query = mocker.Mock()
query.where.return_value.first.return_value = document
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=document)
add_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.add")
commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
@ -1712,9 +1657,7 @@ def test_run_datasource_node_preview_raises_for_unsupported_provider(mocker, rag
def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
@ -1722,9 +1665,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker
def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id=None)
query = mocker.Mock()
query.where.return_value.first.return_value = pipeline
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
with pytest.raises(ValueError, match="Pipeline workflow not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {"name": "template-name"})
@ -1732,8 +1673,7 @@ def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(moc
def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
query = mocker.Mock()
query.where.return_value.first.return_value = None
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None)
with pytest.raises(ValueError, match="Dataset not found"):
rag_pipeline_service.get_pipeline("t1", "d1")
@ -1742,8 +1682,7 @@ def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service)
def test_get_pipeline_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
dataset = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, None]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, None])
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.get_pipeline("t1", "d1")
@ -1783,8 +1722,7 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None:
def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None:
template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None)
query = mocker.Mock()
query.where.return_value.first.return_value = template
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template)
commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit")
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
@ -2011,8 +1949,7 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_
def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [pipeline, None]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None])
with pytest.raises(ValueError, match="Workflow not found"):
rag_pipeline_service.publish_customized_pipeline_template("p1", {})
@ -2021,11 +1958,9 @@ def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocke
def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None:
pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1")
workflow = SimpleNamespace(id="wf-1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [pipeline, workflow]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.engine = mocker.Mock()
mock_db.session.query.return_value = query
mock_db.session.get.side_effect = [pipeline, workflow]
session_ctx = mocker.MagicMock()
session_ctx.__enter__.return_value = SimpleNamespace()
session_ctx.__exit__.return_value = False
@ -2038,11 +1973,8 @@ def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker
def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipeline_service) -> None:
plugin = SimpleNamespace(plugin_id="plugin-a")
query = mocker.Mock()
query.where.return_value = query
query.order_by.return_value.all.return_value = [plugin]
mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db")
mock_db.session.query.return_value = query
mock_db.session.scalars.return_value.all.return_value = [plugin]
mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1"))
mocker.patch("services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[])
mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids", return_value=[])
@ -2056,8 +1988,8 @@ def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipelin
def test_retry_error_document_raises_when_pipeline_missing(mocker, rag_pipeline_service) -> None:
exec_log = SimpleNamespace(pipeline_id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [exec_log, None]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None)
with pytest.raises(ValueError, match="Pipeline not found"):
rag_pipeline_service.retry_error_document(
@ -2069,8 +2001,8 @@ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_
exec_log = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [exec_log, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=exec_log)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline)
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=None)
with pytest.raises(ValueError, match="Workflow not found"):
@ -2086,8 +2018,7 @@ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, r
graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[]
)
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
assert rag_pipeline_service.get_datasource_plugins("t1", "d1", True) == []
@ -2250,8 +2181,7 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(
rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}],
)
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=workflow)
mocker.patch(
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials", return_value=[]
@ -2291,8 +2221,7 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag
],
)
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
mocker.patch.object(rag_pipeline_service, "get_published_workflow", return_value=workflow)
mocker.patch(
"services.rag_pipeline.rag_pipeline.DatasourceProviderService.list_datasource_credentials",
@ -2310,8 +2239,7 @@ def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service)
dataset = SimpleNamespace(pipeline_id="p1")
pipeline = SimpleNamespace(id="p1")
query = mocker.Mock()
query.where.return_value.first.side_effect = [dataset, pipeline]
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.query", return_value=query)
mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline])
result = rag_pipeline_service.get_pipeline("t1", "d1")

View File

@ -173,9 +173,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock
query_results = {("Account", "email", "test@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
mock_password_dependencies["compare_password"].return_value = True
@ -188,9 +186,7 @@ class TestAccountService:
def test_authenticate_account_not_found(self, mock_db_dependencies):
"""Test authentication when account does not exist."""
# Setup smart database query mock - no matching results
query_results = {("Account", "email", "notfound@example.com"): None}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = None
# Execute test and verify exception
self._assert_exception_raised(
@ -202,9 +198,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
# Setup smart database query mock
query_results = {("Account", "email", "banned@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
# Execute test and verify exception
self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
@ -214,9 +208,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock
query_results = {("Account", "email", "test@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
mock_password_dependencies["compare_password"].return_value = False
@ -230,9 +222,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending")
# Setup smart database query mock
query_results = {("Account", "email", "pending@example.com"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.scalar.return_value = mock_account
mock_password_dependencies["compare_password"].return_value = True
@ -422,12 +412,8 @@ class TestAccountService:
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock()
# Setup smart database query mock
query_results = {
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): mock_tenant_join,
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
mock_db_dependencies["db"].session.scalar.return_value = mock_tenant_join
# Mock datetime
with patch("services.account_service.datetime") as mock_datetime:
@ -444,9 +430,7 @@ class TestAccountService:
def test_load_user_not_found(self, mock_db_dependencies):
"""Test user loading when user does not exist."""
# Setup smart database query mock - no matching results
query_results = {("Account", "id", "non-existent-user"): None}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = None
# Execute test
result = AccountService.load_user("non-existent-user")
@ -459,9 +443,7 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
# Setup smart database query mock
query_results = {("Account", "id", "user-123"): mock_account}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
# Execute test and verify exception
self._assert_exception_raised(
@ -476,13 +458,9 @@ class TestAccountService:
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False)
# Setup smart database query mock for complex scenario
query_results = {
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
("TenantAccountJoin", "order_by", "first_available"): mock_available_tenant, # First available tenant
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
# First scalar: current tenant (None), second scalar: available tenant
mock_db_dependencies["db"].session.scalar.side_effect = [None, mock_available_tenant]
# Mock datetime
with patch("services.account_service.datetime") as mock_datetime:
@ -503,13 +481,9 @@ class TestAccountService:
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Setup smart database query mock for no tenants scenario
query_results = {
("Account", "id", "user-123"): mock_account,
("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
("TenantAccountJoin", "order_by", "first_available"): None, # No available tenants
}
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
mock_db_dependencies["db"].session.get.return_value = mock_account
# First scalar: current tenant (None), second scalar: available tenant (None)
mock_db_dependencies["db"].session.scalar.side_effect = [None, None]
# Mock datetime
with patch("services.account_service.datetime") as mock_datetime:

View File

@ -799,10 +799,7 @@ class TestExternalDatasetServiceGetAPI:
api_id = "api-123"
expected_api = factory.create_external_knowledge_api_mock(api_id=api_id)
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = expected_api
mock_db.session.scalar.return_value = expected_api
# Act
tenant_id = "tenant-123"
@ -810,16 +807,12 @@ class TestExternalDatasetServiceGetAPI:
# Assert
assert result.id == api_id
mock_query.filter_by.assert_called_once_with(id=api_id, tenant_id=tenant_id)
@patch("services.external_knowledge_service.db")
def test_get_external_knowledge_api_not_found(self, mock_db, factory):
"""Test error when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@ -848,10 +841,7 @@ class TestExternalDatasetServiceUpdateAPI:
"settings": {"endpoint": "https://new.example.com", "api_key": "new-key"},
}
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, user_id, api_id, args)
@ -881,10 +871,7 @@ class TestExternalDatasetServiceUpdateAPI:
"settings": {"endpoint": "https://api.example.com", "api_key": HIDDEN_VALUE},
}
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
result = ExternalDatasetService.update_external_knowledge_api(tenant_id, "user-123", api_id, args)
@ -897,10 +884,7 @@ class TestExternalDatasetServiceUpdateAPI:
def test_update_external_knowledge_api_not_found(self, mock_db, factory):
"""Test error when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
args = {"name": "Updated API"}
@ -912,10 +896,7 @@ class TestExternalDatasetServiceUpdateAPI:
def test_update_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
"""Test error when tenant ID doesn't match."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
args = {"name": "Updated API"}
@ -934,10 +915,7 @@ class TestExternalDatasetServiceUpdateAPI:
args = {"name": "New Name Only"}
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
result = ExternalDatasetService.update_external_knowledge_api("tenant-123", "user-123", "api-123", args)
@ -958,10 +936,7 @@ class TestExternalDatasetServiceDeleteAPI:
existing_api = factory.create_external_knowledge_api_mock(api_id=api_id, tenant_id=tenant_id)
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_api
mock_db.session.scalar.return_value = existing_api
# Act
ExternalDatasetService.delete_external_knowledge_api(tenant_id, api_id)
@ -974,10 +949,7 @@ class TestExternalDatasetServiceDeleteAPI:
def test_delete_external_knowledge_api_not_found(self, mock_db, factory):
"""Test error when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@ -987,10 +959,7 @@ class TestExternalDatasetServiceDeleteAPI:
def test_delete_external_knowledge_api_tenant_mismatch(self, mock_db, factory):
"""Test error when tenant ID doesn't match."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@ -1006,10 +975,7 @@ class TestExternalDatasetServiceAPIUseCheck:
# Arrange
api_id = "api-123"
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 1
mock_db.session.scalar.return_value = 1
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
@ -1024,10 +990,7 @@ class TestExternalDatasetServiceAPIUseCheck:
# Arrange
api_id = "api-123"
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 10
mock_db.session.scalar.return_value = 10
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
@ -1042,10 +1005,7 @@ class TestExternalDatasetServiceAPIUseCheck:
# Arrange
api_id = "api-123"
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.count.return_value = 0
mock_db.session.scalar.return_value = 0
# Act
in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id)
@ -1067,10 +1027,7 @@ class TestExternalDatasetServiceGetBinding:
expected_binding = factory.create_external_knowledge_binding_mock(tenant_id=tenant_id, dataset_id=dataset_id)
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = expected_binding
mock_db.session.scalar.return_value = expected_binding
# Act
result = ExternalDatasetService.get_external_knowledge_binding_with_dataset_id(tenant_id, dataset_id)
@ -1083,10 +1040,7 @@ class TestExternalDatasetServiceGetBinding:
def test_get_external_knowledge_binding_not_found(self, mock_db, factory):
"""Test error when binding is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="external knowledge binding not found"):
@ -1113,10 +1067,7 @@ class TestExternalDatasetServiceDocumentValidate:
api = factory.create_external_knowledge_api_mock(api_id=api_id, settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
process_parameter = {"param1": "value1", "param2": "value2"}
@ -1134,10 +1085,7 @@ class TestExternalDatasetServiceDocumentValidate:
api = factory.create_external_knowledge_api_mock(api_id=api_id, settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
process_parameter = {}
@ -1149,10 +1097,7 @@ class TestExternalDatasetServiceDocumentValidate:
def test_document_create_args_validate_api_not_found(self, mock_db, factory):
"""Test validation fails when API is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="api template not found"):
@ -1165,10 +1110,7 @@ class TestExternalDatasetServiceDocumentValidate:
settings = {}
api = factory.create_external_knowledge_api_mock(settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
# Act & Assert - should not raise
ExternalDatasetService.document_create_args_validate("tenant-123", "api-123", {})
@ -1186,10 +1128,7 @@ class TestExternalDatasetServiceDocumentValidate:
api = factory.create_external_knowledge_api_mock(settings=[settings])
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = api
mock_db.session.scalar.return_value = api
process_parameter = {"required_param": "value"}
@ -1498,24 +1437,7 @@ class TestExternalDatasetServiceCreateDataset:
api = factory.create_external_knowledge_api_mock(api_id="api-123")
# Mock database queries
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [None, api]
# Act
result = ExternalDatasetService.create_external_dataset(tenant_id, user_id, args)
@ -1534,10 +1456,7 @@ class TestExternalDatasetServiceCreateDataset:
# Arrange
existing_dataset = factory.create_dataset_mock(name="Duplicate Dataset")
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = existing_dataset
mock_db.session.scalar.return_value = existing_dataset
args = {"name": "Duplicate Dataset"}
@ -1549,23 +1468,7 @@ class TestExternalDatasetServiceCreateDataset:
def test_create_external_dataset_api_not_found_error(self, mock_db, factory):
"""Test error when external knowledge API is not found."""
# Arrange
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = None
mock_db.session.scalar.side_effect = [None, None]
args = {"name": "Test Dataset", "external_knowledge_api_id": "nonexistent-api"}
@ -1579,23 +1482,7 @@ class TestExternalDatasetServiceCreateDataset:
# Arrange
api = factory.create_external_knowledge_api_mock()
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [None, api]
args = {"name": "Test Dataset", "external_knowledge_api_id": "api-123"}
@ -1609,23 +1496,7 @@ class TestExternalDatasetServiceCreateDataset:
# Arrange
api = factory.create_external_knowledge_api_mock()
mock_dataset_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == Dataset:
return mock_dataset_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_dataset_query.filter_by.return_value = mock_dataset_query
mock_dataset_query.first.return_value = None
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [None, api]
args = {"name": "Test Dataset", "external_knowledge_id": "knowledge-123"}
@ -1651,23 +1522,7 @@ class TestExternalDatasetServiceFetchRetrieval:
)
api = factory.create_external_knowledge_api_mock(api_id="api-123")
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 200
@ -1695,10 +1550,7 @@ class TestExternalDatasetServiceFetchRetrieval:
def test_fetch_external_knowledge_retrieval_binding_not_found_error(self, mock_db, factory):
"""Test error when external knowledge binding is not found."""
# Arrange
mock_query = MagicMock()
mock_db.session.query.return_value = mock_query
mock_query.filter_by.return_value = mock_query
mock_query.first.return_value = None
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="external knowledge binding not found"):
@ -1712,23 +1564,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 200
@ -1751,23 +1587,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 200
@ -1799,23 +1619,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 500
@ -1856,23 +1660,7 @@ class TestExternalDatasetServiceFetchRetrieval:
)
api = factory.create_external_knowledge_api_mock(api_id="api-123")
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = status_code
@ -1891,23 +1679,7 @@ class TestExternalDatasetServiceFetchRetrieval:
binding = factory.create_external_knowledge_binding_mock()
api = factory.create_external_knowledge_api_mock()
mock_binding_query = MagicMock()
mock_api_query = MagicMock()
def query_side_effect(model):
if model == ExternalKnowledgeBindings:
return mock_binding_query
elif model == ExternalKnowledgeApis:
return mock_api_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
mock_binding_query.filter_by.return_value = mock_binding_query
mock_binding_query.first.return_value = binding
mock_api_query.filter_by.return_value = mock_api_query
mock_api_query.first.return_value = api
mock_db.session.scalar.side_effect = [binding, api]
mock_response = MagicMock()
mock_response.status_code = 503

40
api/uv.lock generated
View File

@ -53,23 +53,6 @@ dependencies = [
]
sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748, upload-time = "2026-03-28T17:19:40.6Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/7e/cb94129302d78c46662b47f9897d642fd0b33bdfef4b73b20c6ced35aa4c/aiohttp-3.13.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8ea0c64d1bcbf201b285c2246c51a0c035ba3bbd306640007bc5844a3b4658c1", size = 760027, upload-time = "2026-03-28T17:15:33.022Z" },
{ url = "https://files.pythonhosted.org/packages/5e/cd/2db3c9397c3bd24216b203dd739945b04f8b87bb036c640da7ddb63c75ef/aiohttp-3.13.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6f742e1fa45c0ed522b00ede565e18f97e4cf8d1883a712ac42d0339dfb0cce7", size = 508325, upload-time = "2026-03-28T17:15:34.714Z" },
{ url = "https://files.pythonhosted.org/packages/36/a3/d28b2722ec13107f2e37a86b8a169897308bab6a3b9e071ecead9d67bd9b/aiohttp-3.13.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dcfb50ee25b3b7a1222a9123be1f9f89e56e67636b561441f0b304e25aaef8f", size = 502402, upload-time = "2026-03-28T17:15:36.409Z" },
{ url = "https://files.pythonhosted.org/packages/fa/d6/acd47b5f17c4430e555590990a4746efbcb2079909bb865516892bf85f37/aiohttp-3.13.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3262386c4ff370849863ea93b9ea60fd59c6cf56bf8f93beac625cf4d677c04d", size = 1771224, upload-time = "2026-03-28T17:15:38.223Z" },
{ url = "https://files.pythonhosted.org/packages/98/af/af6e20113ba6a48fd1cd9e5832c4851e7613ef50c7619acdaee6ec5f1aff/aiohttp-3.13.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:473bb5aa4218dd254e9ae4834f20e31f5a0083064ac0136a01a62ddbae2eaa42", size = 1731530, upload-time = "2026-03-28T17:15:39.988Z" },
{ url = "https://files.pythonhosted.org/packages/81/16/78a2f5d9c124ad05d5ce59a9af94214b6466c3491a25fb70760e98e9f762/aiohttp-3.13.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e56423766399b4c77b965f6aaab6c9546617b8994a956821cc507d00b91d978c", size = 1827925, upload-time = "2026-03-28T17:15:41.944Z" },
{ url = "https://files.pythonhosted.org/packages/2a/1f/79acf0974ced805e0e70027389fccbb7d728e6f30fcac725fb1071e63075/aiohttp-3.13.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8af249343fafd5ad90366a16d230fc265cf1149f26075dc9fe93cfd7c7173942", size = 1923579, upload-time = "2026-03-28T17:15:44.071Z" },
{ url = "https://files.pythonhosted.org/packages/af/53/29f9e2054ea6900413f3b4c3eb9d8331f60678ec855f13ba8714c47fd48d/aiohttp-3.13.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bc0a5cf4f10ef5a2c94fdde488734b582a3a7a000b131263e27c9295bd682d9", size = 1767655, upload-time = "2026-03-28T17:15:45.911Z" },
{ url = "https://files.pythonhosted.org/packages/f3/57/462fe1d3da08109ba4aa8590e7aed57c059af2a7e80ec21f4bac5cfe1094/aiohttp-3.13.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:5c7ff1028e3c9fc5123a865ce17df1cb6424d180c503b8517afbe89aa566e6be", size = 1630439, upload-time = "2026-03-28T17:15:48.11Z" },
{ url = "https://files.pythonhosted.org/packages/d7/4b/4813344aacdb8127263e3eec343d24e973421143826364fa9fc847f6283f/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ba5cf98b5dcb9bddd857da6713a503fa6d341043258ca823f0f5ab7ab4a94ee8", size = 1745557, upload-time = "2026-03-28T17:15:50.13Z" },
{ url = "https://files.pythonhosted.org/packages/d4/01/1ef1adae1454341ec50a789f03cfafe4c4ac9c003f6a64515ecd32fe4210/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:d85965d3ba21ee4999e83e992fecb86c4614d6920e40705501c0a1f80a583c12", size = 1741796, upload-time = "2026-03-28T17:15:52.351Z" },
{ url = "https://files.pythonhosted.org/packages/22/04/8cdd99af988d2aa6922714d957d21383c559835cbd43fbf5a47ddf2e0f05/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:49f0b18a9b05d79f6f37ddd567695943fcefb834ef480f17a4211987302b2dc7", size = 1805312, upload-time = "2026-03-28T17:15:54.407Z" },
{ url = "https://files.pythonhosted.org/packages/fb/7f/b48d5577338d4b25bbdbae35c75dbfd0493cb8886dc586fbfb2e90862239/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7f78cb080c86fbf765920e5f1ef35af3f24ec4314d6675d0a21eaf41f6f2679c", size = 1621751, upload-time = "2026-03-28T17:15:56.564Z" },
{ url = "https://files.pythonhosted.org/packages/bc/89/4eecad8c1858e6d0893c05929e22343e0ebe3aec29a8a399c65c3cc38311/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:67a3ec705534a614b68bbf1c70efa777a21c3da3895d1c44510a41f5a7ae0453", size = 1826073, upload-time = "2026-03-28T17:15:58.489Z" },
{ url = "https://files.pythonhosted.org/packages/f5/5c/9dc8293ed31b46c39c9c513ac7ca152b3c3d38e0ea111a530ad12001b827/aiohttp-3.13.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d6630ec917e85c5356b2295744c8a97d40f007f96a1c76bf1928dc2e27465393", size = 1760083, upload-time = "2026-03-28T17:16:00.677Z" },
{ url = "https://files.pythonhosted.org/packages/1e/19/8bbf6a4994205d96831f97b7d21a0feed120136e6267b5b22d229c6dc4dc/aiohttp-3.13.4-cp311-cp311-win32.whl", hash = "sha256:54049021bc626f53a5394c29e8c444f726ee5a14b6e89e0ad118315b1f90f5e3", size = 439690, upload-time = "2026-03-28T17:16:02.902Z" },
{ url = "https://files.pythonhosted.org/packages/0c/f5/ac409ecd1007528d15c3e8c3a57d34f334c70d76cfb7128a28cffdebd4c1/aiohttp-3.13.4-cp311-cp311-win_amd64.whl", hash = "sha256:c033f2bc964156030772d31cbf7e5defea181238ce1f87b9455b786de7d30145", size = 463824, upload-time = "2026-03-28T17:16:05.058Z" },
{ url = "https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158, upload-time = "2026-03-28T17:16:06.901Z" },
{ url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037, upload-time = "2026-03-28T17:16:08.82Z" },
{ url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556, upload-time = "2026-03-28T17:16:10.63Z" },
@ -1586,7 +1569,7 @@ dev = [
{ name = "lxml-stubs", specifier = "~=0.5.1" },
{ name = "mypy", specifier = "~=1.19.1" },
{ name = "pandas-stubs", specifier = "~=3.0.0" },
{ name = "pyrefly", specifier = ">=0.57.1" },
{ name = "pyrefly", specifier = ">=0.59.1" },
{ name = "pytest", specifier = "~=9.0.2" },
{ name = "pytest-benchmark", specifier = "~=5.2.3" },
{ name = "pytest-cov", specifier = "~=7.1.0" },
@ -4839,18 +4822,19 @@ wheels = [
[[package]]
name = "pyrefly"
version = "0.57.1"
version = "0.59.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c9/c1/c17211e5bbd2b90a24447484713da7cc2cee4e9455e57b87016ffc69d426/pyrefly-0.57.1.tar.gz", hash = "sha256:b05f6f5ee3a6a5d502ca19d84cb9ab62d67f05083819964a48c1510f2993efc6", size = 5310800, upload-time = "2026-03-18T18:42:35.614Z" }
sdist = { url = "https://files.pythonhosted.org/packages/d5/ce/7882c2af92b2ff6505fcd3430eff8048ece6c6254cc90bdc76ecee12dfab/pyrefly-0.59.1.tar.gz", hash = "sha256:bf1675b0c38d45df2c8f8618cbdfa261a1b92430d9d31eba16e0282b551e210f", size = 5475432, upload-time = "2026-04-01T22:04:04.11Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/58/8af37856c8d45b365ece635a6728a14b0356b08d1ff1ac601d7120def1e0/pyrefly-0.57.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:91974bfbe951eebf5a7bc959c1f3921f0371c789cad84761511d695e9ab2265f", size = 12681847, upload-time = "2026-03-18T18:42:10.963Z" },
{ url = "https://files.pythonhosted.org/packages/5f/d7/fae6dd9d0355fc5b8df7793f1423b7433ca8e10b698ea934c35f0e4e6522/pyrefly-0.57.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:808087298537c70f5e7cdccb5bbaad482e7e056e947c0adf00fb612cbace9fdc", size = 12219634, upload-time = "2026-03-18T18:42:13.469Z" },
{ url = "https://files.pythonhosted.org/packages/29/8f/9511ae460f0690e837b9ba0f7e5e192079e16ff9a9ba8a272450e81f11f8/pyrefly-0.57.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b01f454fa5539e070c0cba17ddec46b3d2107d571d519bd8eca8f3142ba02a6", size = 34947757, upload-time = "2026-03-18T18:42:17.152Z" },
{ url = "https://files.pythonhosted.org/packages/07/43/f053bf9c65218f70e6a49561e9942c7233f8c3e4da8d42e5fe2aae50b3d2/pyrefly-0.57.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02ad59ea722191f51635f23e37574662116b82ca9d814529f7cb5528f041f381", size = 37621018, upload-time = "2026-03-18T18:42:20.79Z" },
{ url = "https://files.pythonhosted.org/packages/0e/76/9cea46de01665bbc125e4f215340c9365c8d56cda6198ff238a563ea8e75/pyrefly-0.57.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54bc0afe56776145e37733ff763e7e9679ee8a76c467b617dc3f227d4124a9e2", size = 40203649, upload-time = "2026-03-18T18:42:24.519Z" },
{ url = "https://files.pythonhosted.org/packages/fd/8b/2fb4a96d75e2a57df698a43e2970e441ba2704e3906cdc0386a055daa05a/pyrefly-0.57.1-py3-none-win32.whl", hash = "sha256:468e5839144b25bb0dce839bfc5fd879c9f38e68ebf5de561f30bed9ae19d8ca", size = 11732953, upload-time = "2026-03-18T18:42:27.379Z" },
{ url = "https://files.pythonhosted.org/packages/13/5a/4a197910fe2e9b102b15ae5e7687c45b7b5981275a11a564b41e185dd907/pyrefly-0.57.1-py3-none-win_amd64.whl", hash = "sha256:46db9c97093673c4fb7fab96d610e74d140661d54688a92d8e75ad885a56c141", size = 12537319, upload-time = "2026-03-18T18:42:30.196Z" },
{ url = "https://files.pythonhosted.org/packages/b5/c6/bc442874be1d9b63da1f9debb4f04b7d0c590a8dc4091921f3c288207242/pyrefly-0.57.1-py3-none-win_arm64.whl", hash = "sha256:feb1bbe3b0d8d5a70121dcdf1476e6a99cc056a26a49379a156f040729244dcb", size = 12013455, upload-time = "2026-03-18T18:42:32.928Z" },
{ url = "https://files.pythonhosted.org/packages/d0/10/04a0e05b08fc855b6fe38c3df549925fc3c2c6e750506870de7335d3e1f7/pyrefly-0.59.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:390db3cd14aa7e0268e847b60cd9ee18b04273eddfa38cf341ed3bb43f3fef2a", size = 12868133, upload-time = "2026-04-01T22:03:39.436Z" },
{ url = "https://files.pythonhosted.org/packages/c7/78/fa7be227c3e3fcacee501c1562278dd026186ffd1b5b5beb51d3941a3aed/pyrefly-0.59.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d246d417b6187c1650d7f855f61c68fbfd6d6155dc846d4e4d273a3e6b5175cb", size = 12379325, upload-time = "2026-04-01T22:03:42.046Z" },
{ url = "https://files.pythonhosted.org/packages/bb/13/6828ce1c98171b5f8388f33c4b0b9ea2ab8c49abe0ef8d793c31e30a05cb/pyrefly-0.59.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:575ac67b04412dc651a7143d27e38a40fbdd3c831c714d5520d0e9d4c8631ab4", size = 35826408, upload-time = "2026-04-01T22:03:45.067Z" },
{ url = "https://files.pythonhosted.org/packages/23/56/79ed8ece9a7ecad0113c394a06a084107db3ad8f1fefe19e7ded43c51245/pyrefly-0.59.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:062e6262ce1064d59dcad81ac0499bb7a3ad501e9bc8a677a50dc630ff0bf862", size = 38532699, upload-time = "2026-04-01T22:03:48.376Z" },
{ url = "https://files.pythonhosted.org/packages/18/7d/ecc025e0f0e3f295b497f523cc19cefaa39e57abede8fc353d29445d174b/pyrefly-0.59.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43ef4247f9e6f734feb93e1f2b75335b943629956e509f545cc9cdcccd76dd20", size = 36743570, upload-time = "2026-04-01T22:03:51.362Z" },
{ url = "https://files.pythonhosted.org/packages/2f/03/b1ce882ebcb87c673165c00451fbe4df17bf96ccfde18c75880dc87c5f5e/pyrefly-0.59.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a2d01723b84d042f4fa6ec871ffd52d0a7e83b0ea791c2e0bb0ff750abce56", size = 41236246, upload-time = "2026-04-01T22:03:54.361Z" },
{ url = "https://files.pythonhosted.org/packages/17/af/5e9c7afd510e7dd64a2204be0ed39e804089cbc4338675a28615c7176acb/pyrefly-0.59.1-py3-none-win32.whl", hash = "sha256:4ea70c780848f8376411e787643ae5d2d09da8a829362332b7b26d15ebcbaf56", size = 11884747, upload-time = "2026-04-01T22:03:56.776Z" },
{ url = "https://files.pythonhosted.org/packages/aa/c1/7db1077627453fd1068f0761f059a9512645c00c4c20acfb9f0c24ac02ec/pyrefly-0.59.1-py3-none-win_amd64.whl", hash = "sha256:67e6a08cfd129a0d2788d5e40a627f9860e0fe91a876238d93d5c63ff4af68ae", size = 12720608, upload-time = "2026-04-01T22:03:59.252Z" },
{ url = "https://files.pythonhosted.org/packages/07/16/4bb6e5fce5a9cf0992932d9435d964c33e507aaaf96fdfbb1be493078a4a/pyrefly-0.59.1-py3-none-win_arm64.whl", hash = "sha256:01179cb215cf079e8223a064f61a074f7079aa97ea705cbbc68af3d6713afd15", size = 12223158, upload-time = "2026-04-01T22:04:01.869Z" },
]
[[package]]

2
packages/dify-ui/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
dist/
node_modules/

View File

@ -0,0 +1,82 @@
{
"name": "@langgenius/dify-ui",
"private": true,
"version": "0.0.0-private",
"type": "module",
"files": [
"dist"
],
"sideEffects": [
"**/*.css"
],
"exports": {
"./context-menu": {
"types": "./dist/context-menu/index.d.ts",
"import": "./dist/context-menu/index.js",
"default": "./dist/context-menu/index.js"
},
"./dropdown-menu": {
"types": "./dist/dropdown-menu/index.d.ts",
"import": "./dist/dropdown-menu/index.js",
"default": "./dist/dropdown-menu/index.js"
},
"./tailwind-preset": {
"types": "./dist/tailwind-preset.d.ts",
"import": "./dist/tailwind-preset.js",
"default": "./dist/tailwind-preset.js"
},
"./styles.css": "./dist/styles.css",
"./markdown.css": "./dist/markdown.css",
"./themes/light.css": "./dist/themes/light.css",
"./themes/dark.css": "./dist/themes/dark.css",
"./themes/manual-light.css": "./dist/themes/manual-light.css",
"./themes/manual-dark.css": "./dist/themes/manual-dark.css",
"./themes/markdown-light.css": "./dist/themes/markdown-light.css",
"./themes/markdown-dark.css": "./dist/themes/markdown-dark.css",
"./tokens/tailwind-theme-var-define": {
"types": "./dist/tokens/tailwind-theme-var-define.d.ts",
"import": "./dist/tokens/tailwind-theme-var-define.js",
"default": "./dist/tokens/tailwind-theme-var-define.js"
},
"./package.json": "./package.json"
},
"scripts": {
"build": "node ./scripts/build.mjs",
"prepack": "pnpm build",
"test": "vp test",
"test:watch": "vp test --watch",
"type-check": "tsc -p tsconfig.json --noEmit"
},
"peerDependencies": {
"react": "catalog:",
"react-dom": "catalog:"
},
"dependencies": {
"@base-ui/react": "catalog:",
"@dify/iconify-collections": "workspace:*",
"@egoist/tailwindcss-icons": "catalog:",
"@iconify-json/heroicons": "catalog:",
"@iconify-json/ri": "catalog:",
"@remixicon/react": "catalog:",
"@tailwindcss/typography": "catalog:",
"clsx": "catalog:",
"tailwind-merge": "catalog:"
},
"devDependencies": {
"@storybook/react": "catalog:",
"@testing-library/jest-dom": "catalog:",
"@testing-library/react": "catalog:",
"@types/node": "catalog:",
"@types/react": "catalog:",
"@types/react-dom": "catalog:",
"@vitejs/plugin-react": "catalog:",
"happy-dom": "catalog:",
"react": "catalog:",
"react-dom": "catalog:",
"tailwindcss": "catalog:",
"typescript": "catalog:",
"vite": "catalog:",
"vite-plus": "catalog:",
"vitest": "catalog:"
}
}

View File

@ -0,0 +1,31 @@
import { cp, mkdir, rm } from 'node:fs/promises'
import { spawnSync } from 'node:child_process'
import { dirname, resolve } from 'node:path'
import { fileURLToPath } from 'node:url'
const packageRoot = resolve(dirname(fileURLToPath(import.meta.url)), '..')
const distDir = resolve(packageRoot, 'dist')
await rm(distDir, { recursive: true, force: true })
const tsc = spawnSync('pnpm', ['exec', 'tsc', '-p', 'tsconfig.build.json'], {
cwd: packageRoot,
stdio: 'inherit',
})
if (tsc.status !== 0)
process.exit(tsc.status ?? 1)
await mkdir(distDir, { recursive: true })
await cp(resolve(packageRoot, 'src/styles.css'), resolve(packageRoot, 'dist/styles.css'))
await cp(resolve(packageRoot, 'src/markdown.css'), resolve(packageRoot, 'dist/markdown.css'))
await cp(resolve(packageRoot, 'src/styles'), resolve(packageRoot, 'dist/styles'), {
force: true,
recursive: true,
})
await cp(resolve(packageRoot, 'src/themes'), resolve(packageRoot, 'dist/themes'), {
force: true,
recursive: true,
})

View File

@ -1,4 +1,10 @@
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
import type { Meta, StoryObj } from '@storybook/react'
import {
RiDeleteBinLine,
RiFileCopyLine,
RiPencilLine,
RiShareLine,
} from '@remixicon/react'
import { useState } from 'react'
import {
ContextMenu,
@ -17,7 +23,7 @@ import {
ContextMenuSubContent,
ContextMenuSubTrigger,
ContextMenuTrigger,
} from '.'
} from './index'
const TriggerArea = ({ label = 'Right-click inside this area' }: { label?: string }) => (
<ContextMenuTrigger
@ -185,17 +191,17 @@ export const Complex: Story = {
<TriggerArea label="Right-click to inspect all menu capabilities" />
<ContextMenuContent>
<ContextMenuItem>
<span aria-hidden className="i-ri-pencil-line size-4 shrink-0 text-text-tertiary" />
<RiPencilLine aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Rename
</ContextMenuItem>
<ContextMenuItem>
<span aria-hidden className="i-ri-file-copy-line size-4 shrink-0 text-text-tertiary" />
<RiFileCopyLine aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Duplicate
</ContextMenuItem>
<ContextMenuSeparator />
<ContextMenuSub>
<ContextMenuSubTrigger>
<span aria-hidden className="i-ri-share-line size-4 shrink-0 text-text-tertiary" />
<RiShareLine aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Share
</ContextMenuSubTrigger>
<ContextMenuSubContent>
@ -206,7 +212,7 @@ export const Complex: Story = {
</ContextMenuSub>
<ContextMenuSeparator />
<ContextMenuItem destructive>
<span aria-hidden className="i-ri-delete-bin-line size-4 shrink-0" />
<RiDeleteBinLine aria-hidden className="size-4 shrink-0" />
Delete
</ContextMenuItem>
</ContextMenuContent>

View File

@ -1,8 +1,10 @@
'use client'
import type { Placement } from '@/app/components/base/ui/placement'
import type { Placement } from '../internal/placement.js'
import { ContextMenu as BaseContextMenu } from '@base-ui/react/context-menu'
import { RiArrowRightSLine, RiCheckLine } from '@remixicon/react'
import * as React from 'react'
import { cn } from '../internal/cn.js'
import {
menuBackdropClassName,
menuGroupLabelClassName,
@ -11,14 +13,11 @@ import {
menuPopupBaseClassName,
menuRowClassName,
menuSeparatorClassName,
} from '@/app/components/base/ui/menu-shared'
import { parsePlacement } from '@/app/components/base/ui/placement'
import { cn } from '@/utils/classnames'
} from '../internal/menu-shared.js'
import { parsePlacement } from '../internal/placement.js'
export const ContextMenu = BaseContextMenu.Root
export const ContextMenuTrigger = BaseContextMenu.Trigger
export const ContextMenuPortal = BaseContextMenu.Portal
export const ContextMenuBackdrop = BaseContextMenu.Backdrop
export const ContextMenuSub = BaseContextMenu.SubmenuRoot
export const ContextMenuGroup = BaseContextMenu.Group
export const ContextMenuRadioGroup = BaseContextMenu.RadioGroup
@ -44,11 +43,11 @@ type ContextMenuPopupRenderProps = Required<Pick<ContextMenuContentProps, 'child
placement: Placement
sideOffset: number
alignOffset: number
className?: string
popupClassName?: string
positionerProps?: ContextMenuContentProps['positionerProps']
popupProps?: ContextMenuContentProps['popupProps']
withBackdrop?: boolean
className?: string | undefined
popupClassName?: string | undefined
positionerProps?: ContextMenuContentProps['positionerProps'] | undefined
popupProps?: ContextMenuContentProps['popupProps'] | undefined
withBackdrop?: boolean | undefined
}
function renderContextMenuPopup({
@ -190,11 +189,10 @@ export function ContextMenuItemIndicator({
className={cn(menuIndicatorClassName, className)}
{...props}
>
{children ?? <span aria-hidden className="i-ri-check-line h-4 w-4" />}
{children ?? <RiCheckLine aria-hidden className="h-4 w-4" />}
</span>
)
}
export function ContextMenuCheckboxItemIndicator({
className,
...props
@ -204,7 +202,7 @@ export function ContextMenuCheckboxItemIndicator({
className={cn(menuIndicatorClassName, className)}
{...props}
>
<span aria-hidden className="i-ri-check-line h-4 w-4" />
<RiCheckLine aria-hidden className="h-4 w-4" />
</BaseContextMenu.CheckboxItemIndicator>
)
}
@ -218,7 +216,7 @@ export function ContextMenuRadioItemIndicator({
className={cn(menuIndicatorClassName, className)}
{...props}
>
<span aria-hidden className="i-ri-check-line h-4 w-4" />
<RiCheckLine aria-hidden className="h-4 w-4" />
</BaseContextMenu.RadioItemIndicator>
)
}
@ -239,20 +237,20 @@ export function ContextMenuSubTrigger({
{...props}
>
{children}
<span aria-hidden className="i-ri-arrow-right-s-line ml-auto size-4 shrink-0 text-text-tertiary" />
<RiArrowRightSLine aria-hidden className="ml-auto size-4 shrink-0 text-text-tertiary" />
</BaseContextMenu.SubmenuTrigger>
)
}
type ContextMenuSubContentProps = {
children: React.ReactNode
placement?: Placement
sideOffset?: number
alignOffset?: number
className?: string
popupClassName?: string
positionerProps?: ContextMenuContentProps['positionerProps']
popupProps?: ContextMenuContentProps['popupProps']
placement?: Placement | undefined
sideOffset?: number | undefined
alignOffset?: number | undefined
className?: string | undefined
popupClassName?: string | undefined
positionerProps?: ContextMenuContentProps['positionerProps'] | undefined
popupProps?: ContextMenuContentProps['popupProps'] | undefined
}
export function ContextMenuSubContent({
@ -300,3 +298,5 @@ export function ContextMenuSeparator({
/>
)
}
export type { Placement }

View File

@ -1,7 +1,6 @@
import type { ComponentPropsWithoutRef, ReactNode } from 'react'
import { fireEvent, render, screen, within } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import Link from '@/next/link'
import {
DropdownMenu,
DropdownMenuContent,
@ -14,20 +13,20 @@ import {
DropdownMenuTrigger,
} from '../index'
vi.mock('@/next/link', () => ({
default: ({
href,
children,
...props
}: {
href: string
children?: ReactNode
} & Omit<ComponentPropsWithoutRef<'a'>, 'href'>) => (
function MockLink({
href,
children,
...props
}: {
href: string
children?: ReactNode
} & Omit<ComponentPropsWithoutRef<'a'>, 'href'>) {
return (
<a href={href} {...props}>
{children}
</a>
),
}))
)
}
describe('dropdown-menu wrapper', () => {
describe('DropdownMenuContent', () => {
@ -301,7 +300,7 @@ describe('dropdown-menu wrapper', () => {
<DropdownMenuTrigger aria-label="menu trigger">Open</DropdownMenuTrigger>
<DropdownMenuContent>
<DropdownMenuLinkItem
render={<Link href="/account" />}
render={<MockLink href="/account" />}
aria-label="account link"
>
Account settings

View File

@ -1,4 +1,15 @@
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
import type { Meta, StoryObj } from '@storybook/react'
import {
RiArchiveLine,
RiChat1Line,
RiDeleteBinLine,
RiFileCopyLine,
RiLink,
RiLockLine,
RiMailLine,
RiPencilLine,
RiShareLine,
} from '@remixicon/react'
import { useState } from 'react'
import {
DropdownMenu,
@ -17,7 +28,7 @@ import {
DropdownMenuSubContent,
DropdownMenuSubTrigger,
DropdownMenuTrigger,
} from '.'
} from './index'
const TriggerButton = ({ label = 'Open Menu' }: { label?: string }) => (
<DropdownMenuTrigger
@ -214,20 +225,20 @@ export const WithIcons: Story = {
<TriggerButton />
<DropdownMenuContent>
<DropdownMenuItem>
<span aria-hidden className="i-ri-pencil-line size-4 shrink-0 text-text-tertiary" />
<RiPencilLine aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Edit
</DropdownMenuItem>
<DropdownMenuItem>
<span aria-hidden className="i-ri-file-copy-line size-4 shrink-0 text-text-tertiary" />
<RiFileCopyLine aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Duplicate
</DropdownMenuItem>
<DropdownMenuItem>
<span aria-hidden className="i-ri-archive-line size-4 shrink-0 text-text-tertiary" />
<RiArchiveLine aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Archive
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuItem destructive>
<span aria-hidden className="i-ri-delete-bin-line size-4 shrink-0" />
<RiDeleteBinLine aria-hidden className="size-4 shrink-0" />
Delete
</DropdownMenuItem>
</DropdownMenuContent>
@ -262,35 +273,35 @@ const ComplexDemo = () => {
<DropdownMenuGroup>
<DropdownMenuGroupLabel>Edit</DropdownMenuGroupLabel>
<DropdownMenuItem>
<span aria-hidden className="i-ri-pencil-line size-4 shrink-0 text-text-tertiary" />
<RiPencilLine aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Rename
</DropdownMenuItem>
<DropdownMenuItem>
<span aria-hidden className="i-ri-file-copy-line size-4 shrink-0 text-text-tertiary" />
<RiFileCopyLine aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Duplicate
</DropdownMenuItem>
<DropdownMenuItem disabled>
<span aria-hidden className="i-ri-lock-line size-4 shrink-0 text-text-tertiary" />
<RiLockLine aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Move to Workspace
</DropdownMenuItem>
</DropdownMenuGroup>
<DropdownMenuSeparator />
<DropdownMenuSub>
<DropdownMenuSubTrigger>
<span aria-hidden className="i-ri-share-line size-4 shrink-0 text-text-tertiary" />
<RiShareLine aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Share
</DropdownMenuSubTrigger>
<DropdownMenuSubContent>
<DropdownMenuItem>
<span aria-hidden className="i-ri-mail-line size-4 shrink-0 text-text-tertiary" />
<RiMailLine aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Email
</DropdownMenuItem>
<DropdownMenuItem>
<span aria-hidden className="i-ri-chat-1-line size-4 shrink-0 text-text-tertiary" />
<RiChat1Line aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Slack
</DropdownMenuItem>
<DropdownMenuItem>
<span aria-hidden className="i-ri-link size-4 shrink-0 text-text-tertiary" />
<RiLink aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Copy Link
</DropdownMenuItem>
</DropdownMenuSubContent>
@ -315,13 +326,13 @@ const ComplexDemo = () => {
</DropdownMenuGroup>
<DropdownMenuSeparator />
<DropdownMenuCheckboxItem checked={showArchived} onCheckedChange={setShowArchived}>
<span aria-hidden className="i-ri-archive-line size-4 shrink-0 text-text-tertiary" />
<RiArchiveLine aria-hidden className="size-4 shrink-0 text-text-tertiary" />
Show Archived
<DropdownMenuCheckboxItemIndicator />
</DropdownMenuCheckboxItem>
<DropdownMenuSeparator />
<DropdownMenuItem destructive>
<span aria-hidden className="i-ri-delete-bin-line size-4 shrink-0" />
<RiDeleteBinLine aria-hidden className="size-4 shrink-0" />
Delete
</DropdownMenuItem>
</DropdownMenuContent>

View File

@ -1,8 +1,10 @@
'use client'
import type { Placement } from '@/app/components/base/ui/placement'
import type { Placement } from '../internal/placement.js'
import { Menu } from '@base-ui/react/menu'
import { RiArrowRightSLine, RiCheckLine } from '@remixicon/react'
import * as React from 'react'
import { cn } from '../internal/cn.js'
import {
menuGroupLabelClassName,
menuIndicatorClassName,
@ -10,12 +12,10 @@ import {
menuPopupBaseClassName,
menuRowClassName,
menuSeparatorClassName,
} from '@/app/components/base/ui/menu-shared'
import { parsePlacement } from '@/app/components/base/ui/placement'
import { cn } from '@/utils/classnames'
} from '../internal/menu-shared.js'
import { parsePlacement } from '../internal/placement.js'
export const DropdownMenu = Menu.Root
export const DropdownMenuPortal = Menu.Portal
export const DropdownMenuTrigger = Menu.Trigger
export const DropdownMenuSub = Menu.SubmenuRoot
export const DropdownMenuGroup = Menu.Group
@ -42,7 +42,7 @@ export function DropdownMenuRadioItemIndicator({
className={cn(menuIndicatorClassName, className)}
{...props}
>
<span aria-hidden className="i-ri-check-line h-4 w-4" />
<RiCheckLine aria-hidden className="h-4 w-4" />
</Menu.RadioItemIndicator>
)
}
@ -68,7 +68,7 @@ export function DropdownMenuCheckboxItemIndicator({
className={cn(menuIndicatorClassName, className)}
{...props}
>
<span aria-hidden className="i-ri-check-line h-4 w-4" />
<RiCheckLine aria-hidden className="h-4 w-4" />
</Menu.CheckboxItemIndicator>
)
}
@ -106,10 +106,10 @@ type DropdownMenuPopupRenderProps = Required<Pick<DropdownMenuContentProps, 'chi
placement: Placement
sideOffset: number
alignOffset: number
className?: string
popupClassName?: string
positionerProps?: DropdownMenuContentProps['positionerProps']
popupProps?: DropdownMenuContentProps['popupProps']
className?: string | undefined
popupClassName?: string | undefined
positionerProps?: DropdownMenuContentProps['positionerProps'] | undefined
popupProps?: DropdownMenuContentProps['popupProps'] | undefined
}
function renderDropdownMenuPopup({
@ -187,20 +187,20 @@ export function DropdownMenuSubTrigger({
{...props}
>
{children}
<span aria-hidden className="i-ri-arrow-right-s-line ml-auto size-4 shrink-0 text-text-tertiary" />
<RiArrowRightSLine aria-hidden className="ml-auto size-4 shrink-0 text-text-tertiary" />
</Menu.SubmenuTrigger>
)
}
type DropdownMenuSubContentProps = {
children: React.ReactNode
placement?: Placement
sideOffset?: number
alignOffset?: number
className?: string
popupClassName?: string
positionerProps?: DropdownMenuContentProps['positionerProps']
popupProps?: DropdownMenuContentProps['popupProps']
placement?: Placement | undefined
sideOffset?: number | undefined
alignOffset?: number | undefined
className?: string | undefined
popupClassName?: string | undefined
positionerProps?: DropdownMenuContentProps['positionerProps'] | undefined
popupProps?: DropdownMenuContentProps['popupProps'] | undefined
}
export function DropdownMenuSubContent({
@ -272,3 +272,5 @@ export function DropdownMenuSeparator({
/>
)
}
export type { Placement }

View File

@ -0,0 +1,7 @@
import type { ClassValue } from 'clsx'
import { clsx } from 'clsx'
import { twMerge } from 'tailwind-merge'
export function cn(...inputs: ClassValue[]) {
return twMerge(clsx(inputs))
}

View File

@ -0,0 +1,25 @@
type Side = 'top' | 'bottom' | 'left' | 'right'
type Align = 'start' | 'center' | 'end'
export type Placement
= 'top'
| 'top-start'
| 'top-end'
| 'right'
| 'right-start'
| 'right-end'
| 'bottom'
| 'bottom-start'
| 'bottom-end'
| 'left'
| 'left-start'
| 'left-end'
export function parsePlacement(placement: Placement): { side: Side, align: Align } {
const [side, align] = placement.split('-') as [Side, Align | undefined]
return {
side,
align: align ?? 'center',
}
}

View File

@ -0,0 +1,2 @@
@import './themes/markdown-light.css';
@import './themes/markdown-dark.css';

View File

@ -0,0 +1,7 @@
@import './themes/light.css' layer(base);
@import './themes/dark.css' layer(base);
@import './themes/manual-light.css' layer(base);
@import './themes/manual-dark.css' layer(base);
@import './styles/tokens.css';
@source './**/*.{js,mjs}';

View File

@ -0,0 +1,713 @@
@layer base {
*,
::after,
::before,
::backdrop,
::file-selector-button {
border-color: var(--color-gray-200, currentcolor);
}
}
@utility system-kbd {
/* font define start */
font-size: 12px;
font-weight: 500;
line-height: 16px;
/* border radius end */
}
@utility system-2xs-regular-uppercase {
font-size: 10px;
font-weight: 400;
text-transform: uppercase;
line-height: 12px;
/* border radius end */
}
@utility system-2xs-regular {
font-size: 10px;
font-weight: 400;
line-height: 12px;
/* border radius end */
}
@utility system-2xs-medium {
font-size: 10px;
font-weight: 500;
line-height: 12px;
/* border radius end */
}
@utility system-2xs-medium-uppercase {
font-size: 10px;
font-weight: 500;
text-transform: uppercase;
line-height: 12px;
/* border radius end */
}
@utility system-2xs-semibold-uppercase {
font-size: 10px;
font-weight: 600;
text-transform: uppercase;
line-height: 12px;
/* border radius end */
}
@utility system-xs-regular {
font-size: 12px;
font-weight: 400;
line-height: 16px;
/* border radius end */
}
@utility system-xs-regular-uppercase {
font-size: 12px;
font-weight: 400;
text-transform: uppercase;
line-height: 16px;
/* border radius end */
}
@utility system-xs-medium {
font-size: 12px;
font-weight: 500;
line-height: 16px;
/* border radius end */
}
@utility system-xs-medium-uppercase {
font-size: 12px;
font-weight: 500;
text-transform: uppercase;
line-height: 16px;
/* border radius end */
}
@utility system-xs-semibold {
font-size: 12px;
font-weight: 600;
line-height: 16px;
/* border radius end */
}
@utility system-xs-semibold-uppercase {
font-size: 12px;
font-weight: 600;
text-transform: uppercase;
line-height: 16px;
/* border radius end */
}
@utility system-sm-regular {
font-size: 13px;
font-weight: 400;
line-height: 16px;
/* border radius end */
}
@utility system-sm-medium {
font-size: 13px;
font-weight: 500;
line-height: 16px;
/* border radius end */
}
@utility system-sm-medium-uppercase {
font-size: 13px;
font-weight: 500;
text-transform: uppercase;
line-height: 16px;
/* border radius end */
}
@utility system-sm-semibold {
font-size: 13px;
font-weight: 600;
line-height: 16px;
/* border radius end */
}
@utility system-sm-semibold-uppercase {
font-size: 13px;
font-weight: 600;
text-transform: uppercase;
line-height: 16px;
/* border radius end */
}
@utility system-md-regular {
font-size: 14px;
font-weight: 400;
line-height: 20px;
/* border radius end */
}
@utility system-md-medium {
font-size: 14px;
font-weight: 500;
line-height: 20px;
/* border radius end */
}
@utility system-md-semibold {
font-size: 14px;
font-weight: 600;
line-height: 20px;
/* border radius end */
}
@utility system-md-semibold-uppercase {
font-size: 14px;
font-weight: 600;
text-transform: uppercase;
line-height: 20px;
/* border radius end */
}
@utility system-xl-regular {
font-size: 16px;
font-weight: 400;
line-height: 24px;
/* border radius end */
}
@utility system-xl-medium {
font-size: 16px;
font-weight: 500;
line-height: 24px;
/* border radius end */
}
@utility system-xl-semibold {
font-size: 16px;
font-weight: 600;
line-height: 24px;
/* border radius end */
}
@utility code-xs-regular {
font-size: 12px;
font-weight: 400;
line-height: 1.5;
/* border radius end */
}
@utility code-xs-semibold {
font-size: 12px;
font-weight: 600;
line-height: 1.5;
/* border radius end */
}
@utility code-sm-regular {
font-size: 13px;
font-weight: 400;
line-height: 1.5;
/* border radius end */
}
@utility code-sm-semibold {
font-size: 13px;
font-weight: 600;
line-height: 1.5;
/* border radius end */
}
@utility code-md-regular {
font-size: 14px;
font-weight: 400;
line-height: 1.5;
/* border radius end */
}
@utility code-md-semibold {
font-size: 14px;
font-weight: 600;
line-height: 1.5;
/* border radius end */
}
@utility body-xs-light {
font-size: 12px;
font-weight: 300;
line-height: 16px;
/* border radius end */
}
@utility body-xs-regular {
font-size: 12px;
font-weight: 400;
line-height: 16px;
/* border radius end */
}
@utility body-xs-medium {
font-size: 12px;
font-weight: 500;
line-height: 16px;
/* border radius end */
}
@utility body-sm-light {
font-size: 13px;
font-weight: 300;
line-height: 16px;
/* border radius end */
}
@utility body-sm-regular {
font-size: 13px;
font-weight: 400;
line-height: 16px;
/* border radius end */
}
@utility body-sm-medium {
font-size: 13px;
font-weight: 500;
line-height: 16px;
/* border radius end */
}
@utility body-md-light {
font-size: 14px;
font-weight: 300;
line-height: 20px;
/* border radius end */
}
@utility body-md-regular {
font-size: 14px;
font-weight: 400;
line-height: 20px;
/* border radius end */
}
@utility body-md-medium {
font-size: 14px;
font-weight: 500;
line-height: 20px;
/* border radius end */
}
@utility body-lg-light {
font-size: 15px;
font-weight: 300;
line-height: 20px;
/* border radius end */
}
@utility body-lg-regular {
font-size: 15px;
font-weight: 400;
line-height: 20px;
/* border radius end */
}
@utility body-lg-medium {
font-size: 15px;
font-weight: 500;
line-height: 20px;
/* border radius end */
}
@utility body-xl-regular {
font-size: 16px;
font-weight: 400;
line-height: 24px;
/* border radius end */
}
@utility body-xl-medium {
font-size: 16px;
font-weight: 500;
line-height: 24px;
/* border radius end */
}
@utility body-xl-light {
font-size: 16px;
font-weight: 300;
line-height: 24px;
/* border radius end */
}
@utility body-2xl-light {
font-size: 18px;
font-weight: 300;
line-height: 1.5;
/* border radius end */
}
@utility body-2xl-regular {
font-size: 18px;
font-weight: 400;
line-height: 1.5;
/* border radius end */
}
@utility body-2xl-medium {
font-size: 18px;
font-weight: 500;
line-height: 1.5;
/* border radius end */
}
@utility title-xs-semi-bold {
font-size: 12px;
font-weight: 600;
line-height: 16px;
/* border radius end */
}
@utility title-xs-bold {
font-size: 12px;
font-weight: 700;
line-height: 16px;
/* border radius end */
}
@utility title-sm-semi-bold {
font-size: 13px;
font-weight: 600;
line-height: 16px;
/* border radius end */
}
@utility title-sm-bold {
font-size: 13px;
font-weight: 700;
line-height: 16px;
/* border radius end */
}
@utility title-md-semi-bold {
font-size: 14px;
font-weight: 600;
line-height: 20px;
/* border radius end */
}
@utility title-md-bold {
font-size: 14px;
font-weight: 700;
line-height: 20px;
/* border radius end */
}
@utility title-lg-semi-bold {
font-size: 15px;
font-weight: 600;
line-height: 1.2;
/* border radius end */
}
@utility title-lg-bold {
font-size: 15px;
font-weight: 700;
line-height: 1.2;
/* border radius end */
}
@utility title-xl-semi-bold {
font-size: 16px;
font-weight: 600;
line-height: 1.2;
/* border radius end */
}
@utility title-xl-bold {
font-size: 16px;
font-weight: 700;
line-height: 1.2;
/* border radius end */
}
@utility title-2xl-semi-bold {
font-size: 18px;
font-weight: 600;
line-height: 1.2;
/* border radius end */
}
@utility title-2xl-bold {
font-size: 18px;
font-weight: 700;
line-height: 1.2;
/* border radius end */
}
@utility title-3xl-semi-bold {
font-size: 20px;
font-weight: 600;
line-height: 1.2;
/* border radius end */
}
@utility title-3xl-bold {
font-size: 20px;
font-weight: 700;
line-height: 1.2;
/* border radius end */
}
@utility title-4xl-semi-bold {
font-size: 24px;
font-weight: 600;
line-height: 1.2;
/* border radius end */
}
@utility title-4xl-bold {
font-size: 24px;
font-weight: 700;
line-height: 1.2;
/* border radius end */
}
@utility title-5xl-semi-bold {
font-size: 30px;
font-weight: 600;
line-height: 1.2;
/* border radius end */
}
@utility title-5xl-bold {
font-size: 30px;
font-weight: 700;
line-height: 1.2;
/* border radius end */
}
@utility title-6xl-semi-bold {
font-size: 36px;
font-weight: 600;
line-height: 1.2;
/* border radius end */
}
@utility title-6xl-bold {
font-size: 36px;
font-weight: 700;
line-height: 1.2;
/* border radius end */
}
@utility title-7xl-semi-bold {
font-size: 48px;
font-weight: 600;
line-height: 1.2;
/* border radius end */
}
@utility title-7xl-bold {
font-size: 48px;
font-weight: 700;
line-height: 1.2;
/* border radius end */
}
@utility title-8xl-semi-bold {
font-size: 60px;
font-weight: 600;
line-height: 1.2;
/* border radius end */
}
@utility title-8xl-bold {
font-size: 60px;
font-weight: 700;
line-height: 1.2;
/* border radius end */
}
@utility radius-2xs {
/* font define end */
/* border radius start */
border-radius: 2px;
/* border radius end */
}
@utility radius-xs {
border-radius: 4px;
/* border radius end */
}
@utility radius-sm {
border-radius: 6px;
/* border radius end */
}
@utility radius-md {
border-radius: 8px;
/* border radius end */
}
@utility radius-lg {
border-radius: 10px;
/* border radius end */
}
@utility radius-xl {
border-radius: 12px;
/* border radius end */
}
@utility radius-2xl {
border-radius: 16px;
/* border radius end */
}
@utility radius-3xl {
border-radius: 20px;
/* border radius end */
}
@utility radius-4xl {
border-radius: 24px;
/* border radius end */
}
@utility radius-5xl {
border-radius: 24px;
/* border radius end */
}
@utility radius-6xl {
border-radius: 28px;
/* border radius end */
}
@utility radius-7xl {
border-radius: 32px;
/* border radius end */
}
@utility radius-8xl {
border-radius: 40px;
/* border radius end */
}
@utility radius-9xl {
border-radius: 48px;
/* border radius end */
}
@utility radius-full {
border-radius: 64px;
/* border radius end */
}
@utility no-scrollbar {
/* Hide scrollbar for Chrome, Safari and Opera */
&::-webkit-scrollbar {
display: none;
}
/* Hide scrollbar for IE, Edge and Firefox */
-ms-overflow-style: none;
scrollbar-width: none;
}
@utility no-spinner {
/* Hide arrows from number input */
&::-webkit-outer-spin-button {
-webkit-appearance: none;
margin: 0;
}
&::-webkit-inner-spin-button {
-webkit-appearance: none;
margin: 0;
}
-moz-appearance: textfield;
}

View File

@ -1,8 +1,10 @@
import { icons as customPublicIcons } from '@dify/iconify-collections/custom-public'
import { icons as customVenderIcons } from '@dify/iconify-collections/custom-vender'
import { getIconCollections, iconsPlugin } from '@egoist/tailwindcss-icons'
import { icons as heroicons } from '@iconify-json/heroicons'
import { icons as remixIcons } from '@iconify-json/ri'
import { iconsPlugin } from '@egoist/tailwindcss-icons'
import tailwindTypography from '@tailwindcss/typography'
import tailwindThemeVarDefine from './themes/tailwind-theme-var-define'
import tailwindThemeVarDefine from './tokens/tailwind-theme-var-define.js'
import typography from './typography.js'
const config = {
@ -151,7 +153,8 @@ const config = {
tailwindTypography,
iconsPlugin({
collections: {
...getIconCollections(['heroicons', 'ri']),
heroicons,
ri: remixIcons,
'custom-public': customPublicIcons,
'custom-vender': customVenderIcons,
},

3
packages/dify-ui/src/typography.d.ts vendored Normal file
View File

@ -0,0 +1,3 @@
declare const typography: (helpers: { theme: (path: string) => unknown }) => Record<string, unknown>
export default typography

View File

@ -0,0 +1,8 @@
import difyUiTailwindPreset from './src/tailwind-preset'
const config = {
content: [],
...difyUiTailwindPreset,
}
export default config

View File

@ -0,0 +1,21 @@
{
"extends": "./tsconfig.json",
"compilerOptions": {
"allowJs": true,
"noEmit": false,
"declaration": true,
"declarationMap": true,
"sourceMap": true,
"outDir": "./dist",
"rootDir": "./src"
},
"include": [
"src/**/*.ts",
"src/**/*.tsx",
"src/**/*.js"
],
"exclude": [
"src/**/*.stories.tsx",
"src/**/__tests__/**"
]
}

View File

@ -0,0 +1,38 @@
{
"compilerOptions": {
"target": "es2022",
"jsx": "react-jsx",
"lib": [
"dom",
"dom.iterable",
"es2022"
],
"module": "esnext",
"moduleResolution": "bundler",
"moduleDetection": "force",
"resolveJsonModule": true,
"allowJs": true,
"strict": true,
"noUncheckedIndexedAccess": true,
"exactOptionalPropertyTypes": true,
"noEmit": true,
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"isolatedModules": true,
"verbatimModuleSyntax": true,
"skipLibCheck": true,
"types": [
"node",
"vitest/globals",
"@testing-library/jest-dom"
]
},
"include": [
"src/**/*.ts",
"src/**/*.tsx",
"src/**/*.js",
"scripts/**/*.mjs",
"vite.config.ts",
"vitest.setup.ts"
]
}

View File

@ -0,0 +1,11 @@
import react from '@vitejs/plugin-react'
import { defineConfig } from 'vite-plus'
export default defineConfig({
plugins: [react()],
test: {
environment: 'happy-dom',
globals: true,
setupFiles: ['./vitest.setup.ts'],
},
})

View File

@ -0,0 +1,39 @@
import { cleanup } from '@testing-library/react'
import '@testing-library/jest-dom/vitest'
import { afterEach } from 'vitest'
if (typeof globalThis.ResizeObserver === 'undefined') {
globalThis.ResizeObserver = class {
observe() {
return undefined
}
unobserve() {
return undefined
}
disconnect() {
return undefined
}
}
}
if (typeof globalThis.IntersectionObserver === 'undefined') {
globalThis.IntersectionObserver = class {
readonly root: Element | Document | null = null
readonly rootMargin = ''
readonly scrollMargin = ''
readonly thresholds: ReadonlyArray<number> = []
constructor(_callback: IntersectionObserverCallback, _options?: IntersectionObserverInit) {}
observe(_target: Element) {}
unobserve(_target: Element) {}
disconnect() {}
takeRecords(): IntersectionObserverEntry[] {
return []
}
}
}
afterEach(() => {
cleanup()
})

2692
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,9 @@
catalogMode: prefer
trustPolicy: no-downgrade
minimumReleaseAge: 2880
trustPolicyExclude:
- chokidar@4.0.3
- reselect@5.1.1
- semver@6.3.1
blockExoticSubdeps: true
strictDepBuilds: true
allowBuilds:
@ -23,7 +27,7 @@ overrides:
array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1.0.44
array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44
assert: npm:@nolyfill/assert@^1.0.26
brace-expansion@<2.0.2: 2.0.2
brace-expansion@>=2.0.0 <2.0.3: 2.0.3
canvas: ^3.2.2
devalue@<5.3.2: 5.3.2
dompurify@>=3.1.3 <=3.3.1: 3.3.2
@ -37,6 +41,7 @@ overrides:
is-generator-function: npm:@nolyfill/is-generator-function@^1.0.44
is-typed-array: npm:@nolyfill/is-typed-array@^1.0.44
isarray: npm:@nolyfill/isarray@^1.0.44
lodash-es@>=4.0.0 <= 4.17.23: 4.18.0
object.assign: npm:@nolyfill/object.assign@^1.0.44
object.entries: npm:@nolyfill/object.entries@^1.0.44
object.fromentries: npm:@nolyfill/object.fromentries@^1.0.44
@ -64,15 +69,15 @@ overrides:
tar@<=7.5.10: 7.5.11
typed-array-buffer: npm:@nolyfill/typed-array-buffer@^1.0.44
undici@>=7.0.0 <7.24.0: 7.24.0
vite: npm:@voidzero-dev/vite-plus-core@0.1.14
vitest: npm:@voidzero-dev/vite-plus-test@0.1.14
vite: npm:@voidzero-dev/vite-plus-core@0.1.15
vitest: npm:@voidzero-dev/vite-plus-test@0.1.15
which-typed-array: npm:@nolyfill/which-typed-array@^1.0.44
yaml@>=2.0.0 <2.8.3: 2.8.3
yauzl@<3.2.1: 3.2.1
catalog:
"@amplitude/analytics-browser": 2.38.0
"@amplitude/plugin-session-replay-browser": 1.27.5
"@antfu/eslint-config": 7.7.3
"@amplitude/analytics-browser": 2.38.1
"@amplitude/plugin-session-replay-browser": 1.27.6
"@antfu/eslint-config": 8.0.0
"@base-ui/react": 1.3.0
"@chromatic-com/storybook": 5.1.1
"@cucumber/cucumber": 12.7.0
@ -84,7 +89,7 @@ catalog:
"@formatjs/intl-localematcher": 0.8.2
"@headlessui/react": 2.2.9
"@heroicons/react": 2.2.0
"@hono/node-server": 1.19.11
"@hono/node-server": 1.19.12
"@iconify-json/heroicons": 1.2.3
"@iconify-json/ri": 1.2.10
"@lexical/code": 0.42.0
@ -98,34 +103,34 @@ catalog:
"@mdx-js/react": 3.1.1
"@mdx-js/rollup": 3.1.1
"@monaco-editor/react": 4.7.0
"@next/eslint-plugin-next": 16.2.1
"@next/mdx": 16.2.1
"@next/eslint-plugin-next": 16.2.2
"@next/mdx": 16.2.2
"@orpc/client": 1.13.13
"@orpc/contract": 1.13.13
"@orpc/openapi-client": 1.13.13
"@orpc/tanstack-query": 1.13.13
"@playwright/test": 1.58.2
"@playwright/test": 1.59.1
"@remixicon/react": 4.9.0
"@rgrove/parse-xml": 4.2.0
"@sentry/react": 10.46.0
"@storybook/addon-docs": 10.3.3
"@storybook/addon-links": 10.3.3
"@storybook/addon-onboarding": 10.3.3
"@storybook/addon-themes": 10.3.3
"@storybook/nextjs-vite": 10.3.3
"@storybook/react": 10.3.3
"@sentry/react": 10.47.0
"@storybook/addon-docs": 10.3.4
"@storybook/addon-links": 10.3.4
"@storybook/addon-onboarding": 10.3.4
"@storybook/addon-themes": 10.3.4
"@storybook/nextjs-vite": 10.3.4
"@storybook/react": 10.3.4
"@streamdown/math": 1.0.2
"@svgdotjs/svg.js": 3.2.5
"@t3-oss/env-nextjs": 0.13.11
"@tailwindcss/postcss": 4.2.2
"@tailwindcss/typography": 0.5.19
"@tailwindcss/vite": 4.2.2
"@tanstack/eslint-plugin-query": 5.95.2
"@tanstack/react-devtools": 0.10.0
"@tanstack/react-form": 1.28.5
"@tanstack/react-form-devtools": 0.2.19
"@tanstack/react-query": 5.95.2
"@tanstack/react-query-devtools": 5.95.2
"@tanstack/eslint-plugin-query": 5.96.1
"@tanstack/react-devtools": 0.10.1
"@tanstack/react-form": 1.28.6
"@tanstack/react-form-devtools": 0.2.20
"@tanstack/react-query": 5.96.1
"@tanstack/react-query-devtools": 5.96.1
"@testing-library/dom": 10.4.1
"@testing-library/jest-dom": 6.9.1
"@testing-library/react": 16.3.2
@ -144,12 +149,12 @@ catalog:
"@types/react-syntax-highlighter": 15.5.13
"@types/react-window": 1.8.8
"@types/sortablejs": 1.15.9
"@typescript-eslint/eslint-plugin": 8.57.2
"@typescript-eslint/parser": 8.57.2
"@typescript/native-preview": 7.0.0-dev.20260329.1
"@typescript-eslint/eslint-plugin": 8.58.0
"@typescript-eslint/parser": 8.58.0
"@typescript/native-preview": 7.0.0-dev.20260401.1
"@vitejs/plugin-react": 6.0.1
"@vitejs/plugin-rsc": 0.5.21
"@vitest/coverage-v8": 4.1.1
"@vitest/coverage-v8": 4.1.2
abcjs: 6.6.2
agentation: 3.0.2
ahooks: 3.9.7
@ -157,7 +162,7 @@ catalog:
class-variance-authority: 0.7.1
clsx: 2.1.1
cmdk: 1.1.1
code-inspector-plugin: 1.4.5
code-inspector-plugin: 1.5.0
copy-to-clipboard: 3.3.3
cron-parser: 5.5.0
dayjs: 1.11.20
@ -174,19 +179,18 @@ catalog:
eslint-markdown: 0.6.0
eslint-plugin-better-tailwindcss: 4.3.2
eslint-plugin-hyoban: 0.14.1
eslint-plugin-markdown-preferences: 0.40.3
eslint-plugin-markdown-preferences: 0.41.0
eslint-plugin-no-barrel-files: 1.2.2
eslint-plugin-react-hooks: 7.0.1
eslint-plugin-react-refresh: 0.5.2
eslint-plugin-sonarjs: 4.0.2
eslint-plugin-storybook: 10.3.3
eslint-plugin-storybook: 10.3.4
fast-deep-equal: 3.1.3
foxact: 0.3.0
happy-dom: 20.8.9
hono: 4.12.9
hono: 4.12.10
html-entities: 2.6.0
html-to-image: 1.11.13
i18next: 25.10.10
i18next: 26.0.3
i18next-resources-to-backend: 1.2.1
iconify-import-svg: 0.1.2
immer: 11.1.4
@ -196,15 +200,15 @@ catalog:
js-yaml: 4.1.1
jsonschema: 1.5.0
katex: 0.16.44
knip: 6.1.0
knip: 6.2.0
ky: 1.14.3
lamejs: 1.2.1
lexical: 0.42.0
mermaid: 11.13.0
mermaid: 11.14.0
mime: 4.1.0
mitt: 3.0.1
negotiator: 1.0.0
next: 16.2.1
next: 16.2.2
next-themes: 0.4.6
nuqs: 2.8.9
pinyin-pro: 3.28.0
@ -217,7 +221,7 @@ catalog:
react-dom: 19.2.4
react-easy-crop: 5.5.7
react-hotkeys-hook: 5.2.4
react-i18next: 16.6.6
react-i18next: 17.0.2
react-multi-email: 1.0.25
react-papaparse: 4.4.0
react-pdf-highlighter: 8.0.0-rc.0
@ -229,30 +233,29 @@ catalog:
reactflow: 11.11.4
remark-breaks: 4.0.0
remark-directive: 4.0.0
sass: 1.98.0
scheduler: 0.27.0
sharp: 0.34.5
sortablejs: 1.15.7
std-semver: 1.0.8
storybook: 10.3.3
storybook: 10.3.4
streamdown: 2.5.0
string-ts: 2.3.1
tailwind-merge: 3.5.0
tailwindcss: 4.2.2
taze: 19.10.0
taze: 19.11.0
tldts: 7.0.27
tsup: ^8.5.1
tsdown: 0.21.7
tsx: 4.21.0
typescript: 5.9.3
typescript: 6.0.2
uglify-js: 3.19.3
unist-util-visit: 5.1.0
use-context-selector: 2.0.0
uuid: 13.0.0
vinext: 0.0.38
vite: npm:@voidzero-dev/vite-plus-core@0.1.14
vinext: 0.0.39
vite: npm:@voidzero-dev/vite-plus-core@0.1.15
vite-plugin-inspect: 12.0.0-beta.1
vite-plus: 0.1.14
vitest: npm:@voidzero-dev/vite-plus-test@0.1.14
vite-plus: 0.1.15
vitest: npm:@voidzero-dev/vite-plus-test@0.1.15
vitest-canvas-mock: 1.1.4
zod: 4.3.6
zundo: 2.3.0

View File

@ -45,12 +45,12 @@
"homepage": "https://dify.ai",
"license": "MIT",
"scripts": {
"build": "tsup",
"build": "vp pack",
"lint": "eslint",
"lint:fix": "eslint --fix",
"type-check": "tsc -p tsconfig.json --noEmit",
"test": "vitest run",
"test:coverage": "vitest run --coverage",
"test": "vp test",
"test:coverage": "vp test --coverage",
"publish:check": "./scripts/publish.sh --dry-run",
"publish:npm": "./scripts/publish.sh"
},
@ -61,8 +61,8 @@
"@typescript-eslint/parser": "catalog:",
"@vitest/coverage-v8": "catalog:",
"eslint": "catalog:",
"tsup": "catalog:",
"typescript": "catalog:",
"vite-plus": "catalog:",
"vitest": "catalog:"
}
}

View File

@ -11,7 +11,8 @@
"strict": true,
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"skipLibCheck": true
"skipLibCheck": true,
"types": ["node"]
},
"include": ["src/**/*.ts", "tests/**/*.ts"]
}

View File

@ -1,12 +0,0 @@
import { defineConfig } from "tsup";
export default defineConfig({
entry: ["src/index.ts"],
format: ["esm"],
dts: true,
clean: true,
sourcemap: true,
splitting: false,
treeshake: true,
outDir: "dist",
});

View File

@ -1,6 +1,17 @@
import { defineConfig } from "vitest/config";
import { defineConfig } from "vite-plus";
export default defineConfig({
pack: {
entry: ["src/index.ts"],
format: ["esm"],
dts: true,
clean: true,
sourcemap: true,
// splitting: false,
treeshake: true,
outDir: "dist",
target: false,
},
test: {
environment: "node",
include: ["**/*.test.ts"],

View File

@ -6,10 +6,5 @@ export default defineConfig({
'react-syntax-highlighter',
'react-window',
'@types/react-window',
// We can not upgrade these yet
'typescript',
],
maturityPeriod: 2,
})

View File

@ -1,7 +1,10 @@
import type { StorybookConfig } from '@storybook/nextjs-vite'
const config: StorybookConfig = {
stories: ['../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)'],
stories: [
'../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)',
'../../packages/dify-ui/src/**/*.stories.@(js|jsx|mjs|ts|tsx)',
],
addons: [
// Not working with Storybook Vite framework
// '@storybook/addon-onboarding',

View File

@ -35,7 +35,7 @@ const TagManagementModal = dynamic(() => import('@/app/components/base/tag-manag
ssr: false,
})
export type IAppDetailLayoutProps = {
type IAppDetailLayoutProps = {
children: React.ReactNode
appId: string
}

View File

@ -25,7 +25,7 @@ import { useAppWorkflow } from '@/service/use-workflow'
import { AppModeEnum } from '@/types/app'
import { asyncRunSafe } from '@/utils'
export type ICardViewProps = {
type ICardViewProps = {
appId: string
isInPanel?: boolean
className?: string

View File

@ -27,7 +27,7 @@ const TIME_PERIOD_MAPPING: { value: number, name: TimePeriodName }[] = [
const queryDateFormat = 'YYYY-MM-DD HH:mm'
export type IChartViewProps = {
type IChartViewProps = {
appId: string
headerRight: React.ReactNode
}

View File

@ -1,5 +1,3 @@
@reference "../../../../styles/globals.css";
.app {
flex-grow: 1;
height: 0;

View File

@ -1,11 +0,0 @@
@reference "../../../../../styles/globals.css";
.logTable td {
padding: 7px 8px;
box-sizing: border-box;
max-width: 200px;
}
.pagination li {
list-style: none;
}

View File

@ -26,7 +26,7 @@ import { usePathname } from '@/next/navigation'
import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use-dataset'
import { cn } from '@/utils/classnames'
export type IAppDetailLayoutProps = {
type IAppDetailLayoutProps = {
children: React.ReactNode
datasetId: string
}

View File

@ -55,7 +55,7 @@ describe('DatasetsLayout', () => {
setAppContext()
})
it('should keep rendering children when workspace is still loading', () => {
it('should render loading when workspace is still loading', () => {
setAppContext({
isLoadingCurrentWorkspace: true,
currentWorkspace: { id: '' },
@ -67,7 +67,8 @@ describe('DatasetsLayout', () => {
</DatasetsLayout>
))
expect(screen.getByTestId('datasets-content')).toBeInTheDocument()
expect(screen.getByRole('status')).toBeInTheDocument()
expect(screen.queryByTestId('datasets-content')).not.toBeInTheDocument()
expect(mockReplace).not.toHaveBeenCalled()
})

View File

@ -1,18 +1,29 @@
'use client'
import { useEffect } from 'react'
import Loading from '@/app/components/base/loading'
import { useAppContext } from '@/context/app-context'
import { ExternalApiPanelProvider } from '@/context/external-api-panel-context'
import { ExternalKnowledgeApiProvider } from '@/context/external-knowledge-api-context'
import { redirect } from '@/next/navigation'
import { useRouter } from '@/next/navigation'
export default function DatasetsLayout({ children }: { children: React.ReactNode }) {
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, currentWorkspace, isLoadingCurrentWorkspace } = useAppContext()
const router = useRouter()
const shouldRedirect = !isLoadingCurrentWorkspace
&& currentWorkspace.id
&& !(isCurrentWorkspaceEditor || isCurrentWorkspaceDatasetOperator)
useEffect(() => {
if (shouldRedirect)
router.replace('/apps')
}, [shouldRedirect, router])
if (isLoadingCurrentWorkspace || !currentWorkspace.id)
return <Loading type="app" />
if (shouldRedirect) {
return redirect('/apps')
return null
}
return (

View File

@ -1,4 +1,5 @@
import type { ReactNode } from 'react'
import * as React from 'react'
import { AppInitializer } from '@/app/components/app-initializer'
import InSiteMessageNotification from '@/app/components/app/in-site-message/notification'
import AmplitudeProvider from '@/app/components/base/amplitude'
@ -13,6 +14,7 @@ import { EventEmitterContextProvider } from '@/context/event-emitter-provider'
import { ModalContextProvider } from '@/context/modal-context-provider'
import { ProviderContextProvider } from '@/context/provider-context-provider'
import PartnerStack from '../components/billing/partner-stack'
import Splash from '../components/splash'
import RoleRouteGuard from './role-route-guard'
const Layout = ({ children }: { children: ReactNode }) => {
@ -35,6 +37,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
<PartnerStack />
<ReadmePanel />
<GotoAnything />
<Splash />
</ModalContextProvider>
</ProviderContextProvider>
</EventEmitterContextProvider>

View File

@ -41,7 +41,7 @@ describe('RoleRouteGuard', () => {
setAppContext()
})
it('should hide guarded content while workspace is loading', () => {
it('should render loading while workspace is loading', () => {
setAppContext({
isLoadingCurrentWorkspace: true,
})
@ -52,6 +52,7 @@ describe('RoleRouteGuard', () => {
</RoleRouteGuard>
))
expect(screen.getByRole('status')).toBeInTheDocument()
expect(screen.queryByTestId('guarded-content')).not.toBeInTheDocument()
expect(mockReplace).not.toHaveBeenCalled()
})

View File

@ -1,8 +1,10 @@
'use client'
import type { ReactNode } from 'react'
import { useEffect } from 'react'
import Loading from '@/app/components/base/loading'
import { useAppContext } from '@/context/app-context'
import { redirect, usePathname } from '@/next/navigation'
import { usePathname, useRouter } from '@/next/navigation'
const datasetOperatorRedirectRoutes = ['/apps', '/app', '/explore', '/tools'] as const
@ -11,11 +13,21 @@ const isPathUnderRoute = (pathname: string, route: string) => pathname === route
export default function RoleRouteGuard({ children }: { children: ReactNode }) {
const { isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext()
const pathname = usePathname()
const router = useRouter()
const shouldGuardRoute = datasetOperatorRedirectRoutes.some(route => isPathUnderRoute(pathname, route))
const shouldRedirect = shouldGuardRoute && !isLoadingCurrentWorkspace && isCurrentWorkspaceDatasetOperator
useEffect(() => {
if (shouldRedirect)
router.replace('/datasets')
}, [shouldRedirect, router])
// Block rendering only for guarded routes to avoid permission flicker.
if (shouldGuardRoute && isLoadingCurrentWorkspace)
return <Loading type="app" />
if (shouldRedirect)
return redirect('/datasets')
return null
return <>{children}</>
}

View File

@ -1,158 +0,0 @@
import type { Mock } from 'vitest'
import { render, screen } from '@testing-library/react'
import { useWebAppStore } from '@/context/web-app-context'
import { useGetUserCanAccessApp } from '@/service/access-control'
import { useGetWebAppInfo, useGetWebAppMeta, useGetWebAppParams } from '@/service/use-share'
import AuthenticatedLayout from '../authenticated-layout'
const mockReplace = vi.fn()
const mockShareCode = 'share-code'
const mockUpdateAppInfo = vi.fn()
const mockUpdateAppParams = vi.fn()
const mockUpdateWebAppMeta = vi.fn()
const mockUpdateUserCanAccessApp = vi.fn()
const mockAppInfo = {
app_id: 'app-123',
}
vi.mock('@/next/navigation', () => ({
useRouter: () => ({
replace: mockReplace,
}),
usePathname: () => '/chat/test-share-code',
useSearchParams: () => new URLSearchParams(),
}))
vi.mock('@/context/web-app-context', () => ({
useWebAppStore: vi.fn(),
}))
vi.mock('@/service/access-control', () => ({
useGetUserCanAccessApp: vi.fn(),
}))
vi.mock('@/service/use-share', () => ({
useGetWebAppInfo: vi.fn(),
useGetWebAppParams: vi.fn(),
useGetWebAppMeta: vi.fn(),
}))
vi.mock('@/service/webapp-auth', () => ({
webAppLogout: vi.fn(),
}))
describe('AuthenticatedLayout', () => {
beforeEach(() => {
vi.clearAllMocks()
;(useWebAppStore as unknown as Mock).mockImplementation((selector: (state: Record<string, unknown>) => unknown) => {
const state = {
shareCode: mockShareCode,
updateAppInfo: mockUpdateAppInfo,
updateAppParams: mockUpdateAppParams,
updateWebAppMeta: mockUpdateWebAppMeta,
updateUserCanAccessApp: mockUpdateUserCanAccessApp,
}
return selector(state)
})
;(useGetWebAppInfo as Mock).mockReturnValue({
data: mockAppInfo,
error: null,
isPending: false,
})
;(useGetWebAppParams as Mock).mockReturnValue({
data: { user_input_form: [] },
error: null,
isPending: false,
})
;(useGetWebAppMeta as Mock).mockReturnValue({
data: { tool_icons: {} },
error: null,
isPending: false,
})
;(useGetUserCanAccessApp as Mock).mockReturnValue({
data: { result: true },
error: null,
isPending: false,
})
})
describe('Permission Gating', () => {
it('should not render children while the app info needed for permission is still pending', () => {
;(useGetWebAppInfo as Mock).mockReturnValue({
data: undefined,
error: null,
isPending: true,
})
render(
<AuthenticatedLayout>
<div>protected child</div>
</AuthenticatedLayout>,
)
expect(screen.queryByText('protected child')).not.toBeInTheDocument()
})
it('should not render children while the access check is still pending', () => {
;(useGetUserCanAccessApp as Mock).mockReturnValue({
data: undefined,
error: null,
isPending: true,
})
render(
<AuthenticatedLayout>
<div>protected child</div>
</AuthenticatedLayout>,
)
expect(screen.queryByText('protected child')).not.toBeInTheDocument()
})
it('should render children once access is allowed even if metadata queries are still pending', () => {
;(useGetWebAppParams as Mock).mockReturnValue({
data: undefined,
error: null,
isPending: true,
})
;(useGetWebAppMeta as Mock).mockReturnValue({
data: undefined,
error: null,
isPending: true,
})
render(
<AuthenticatedLayout>
<div>protected child</div>
</AuthenticatedLayout>,
)
expect(screen.getByText('protected child')).toBeInTheDocument()
})
it('should render the no permission state when access is denied', () => {
;(useGetUserCanAccessApp as Mock).mockReturnValue({
data: { result: false },
error: null,
isPending: false,
})
render(
<AuthenticatedLayout>
<div>protected child</div>
</AuthenticatedLayout>,
)
expect(screen.queryByText('protected child')).not.toBeInTheDocument()
expect(screen.getByText(/403/)).toBeInTheDocument()
expect(screen.getByText(/no permission/i)).toBeInTheDocument()
})
})
})

View File

@ -1,123 +0,0 @@
import { render, screen, waitFor } from '@testing-library/react'
import Splash from '../splash'
const mockReplace = vi.fn()
const mockWebAppLoginStatus = vi.fn()
const mockFetchAccessToken = vi.fn()
const mockSetWebAppAccessToken = vi.fn()
const mockSetWebAppPassport = vi.fn()
const mockWebAppLogout = vi.fn()
let mockShareCode: string | null = null
let mockEmbeddedUserId: string | null = null
let mockMessage: string | null = null
let mockRedirectUrl: string | null = '/chat/test-share-code'
let mockCode: string | null = null
let mockTokenFromUrl: string | null = null
vi.mock('@/context/web-app-context', () => ({
useWebAppStore: (selector: (state: { shareCode: string | null, embeddedUserId: string | null }) => unknown) =>
selector({
shareCode: mockShareCode,
embeddedUserId: mockEmbeddedUserId,
}),
}))
vi.mock('@/next/navigation', () => ({
useRouter: () => ({
replace: mockReplace,
}),
useSearchParams: () => ({
get: (key: string) => {
if (key === 'redirect_url')
return mockRedirectUrl
if (key === 'message')
return mockMessage
if (key === 'code')
return mockCode
if (key === 'web_sso_token')
return mockTokenFromUrl
return null
},
toString: () => {
const params = new URLSearchParams()
if (mockRedirectUrl)
params.set('redirect_url', mockRedirectUrl)
if (mockMessage)
params.set('message', mockMessage)
if (mockCode)
params.set('code', mockCode)
if (mockTokenFromUrl)
params.set('web_sso_token', mockTokenFromUrl)
return params.toString()
},
* [Symbol.iterator]() {
const params = new URLSearchParams(this.toString())
yield* params.entries()
},
}),
}))
vi.mock('@/service/share', () => ({
fetchAccessToken: (...args: unknown[]) => mockFetchAccessToken(...args),
}))
vi.mock('@/service/webapp-auth', () => ({
setWebAppAccessToken: (...args: unknown[]) => mockSetWebAppAccessToken(...args),
setWebAppPassport: (...args: unknown[]) => mockSetWebAppPassport(...args),
webAppLoginStatus: (...args: unknown[]) => mockWebAppLoginStatus(...args),
webAppLogout: (...args: unknown[]) => mockWebAppLogout(...args),
}))
describe('Share Splash', () => {
beforeEach(() => {
vi.clearAllMocks()
mockShareCode = null
mockEmbeddedUserId = null
mockMessage = null
mockRedirectUrl = '/chat/test-share-code'
mockCode = null
mockTokenFromUrl = null
mockWebAppLoginStatus.mockResolvedValue({
userLoggedIn: true,
appLoggedIn: true,
})
mockFetchAccessToken.mockResolvedValue({ access_token: 'token' })
})
describe('Share Code Guard', () => {
it('should skip login-status checks until the share code is available', async () => {
render(
<Splash>
<div>share child</div>
</Splash>,
)
expect(screen.getByText('share child')).toBeInTheDocument()
await waitFor(() => {
expect(mockWebAppLoginStatus).not.toHaveBeenCalled()
})
expect(mockFetchAccessToken).not.toHaveBeenCalled()
})
it('should resume the auth flow after the share code becomes available', async () => {
const { rerender } = render(
<Splash>
<div>share child</div>
</Splash>,
)
mockShareCode = 'share-code'
rerender(
<Splash>
<div>share child</div>
</Splash>,
)
await waitFor(() => {
expect(mockWebAppLoginStatus).toHaveBeenCalledWith('share-code', undefined)
})
expect(mockReplace).toHaveBeenCalledWith('/chat/test-share-code')
})
})
})

View File

@ -4,6 +4,7 @@ import * as React from 'react'
import { useCallback, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import AppUnavailable from '@/app/components/base/app-unavailable'
import Loading from '@/app/components/base/loading'
import { useWebAppStore } from '@/context/web-app-context'
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
import { useGetUserCanAccessApp } from '@/service/access-control'
@ -17,9 +18,9 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => {
const updateAppParams = useWebAppStore(s => s.updateAppParams)
const updateWebAppMeta = useWebAppStore(s => s.updateWebAppMeta)
const updateUserCanAccessApp = useWebAppStore(s => s.updateUserCanAccessApp)
const { data: appParams, error: appParamsError } = useGetWebAppParams()
const { data: appInfo, error: appInfoError } = useGetWebAppInfo()
const { data: appMeta, error: appMetaError } = useGetWebAppMeta()
const { isFetching: isFetchingAppParams, data: appParams, error: appParamsError } = useGetWebAppParams()
const { isFetching: isFetchingAppInfo, data: appInfo, error: appInfoError } = useGetWebAppInfo()
const { isFetching: isFetchingAppMeta, data: appMeta, error: appMetaError } = useGetWebAppMeta()
const { data: userCanAccessApp, error: useCanAccessAppError } = useGetUserCanAccessApp({ appId: appInfo?.app_id, isInstalledApp: false })
useEffect(() => {
@ -80,11 +81,17 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => {
return (
<div className="flex h-full flex-col items-center justify-center gap-y-2">
<AppUnavailable className="h-auto w-auto" code={403} unknownReason="no permission." />
<span className="cursor-pointer system-sm-regular text-text-tertiary" onClick={backToHome}>{t('userProfile.logout', { ns: 'common' })}</span>
<span className="system-sm-regular cursor-pointer text-text-tertiary" onClick={backToHome}>{t('userProfile.logout', { ns: 'common' })}</span>
</div>
)
}
if (isFetchingAppInfo || isFetchingAppParams || isFetchingAppMeta) {
return (
<div className="flex h-full items-center justify-center">
<Loading />
</div>
)
}
return <>{children}</>
}

View File

@ -1,8 +1,9 @@
'use client'
import type { FC, PropsWithChildren } from 'react'
import { useCallback, useEffect } from 'react'
import { useCallback, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import AppUnavailable from '@/app/components/base/app-unavailable'
import Loading from '@/app/components/base/loading'
import { useWebAppStore } from '@/context/web-app-context'
import { useRouter, useSearchParams } from '@/next/navigation'
import { fetchAccessToken } from '@/service/share'
@ -31,8 +32,11 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
const url = getSigninUrl()
router.replace(url)
}, [getSigninUrl, router, shareCode])
const [isLoading, setIsLoading] = useState(true)
useEffect(() => {
if (message) {
setIsLoading(false)
return
}
@ -42,6 +46,12 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
const redirectOrFinish = () => {
if (redirectUrl)
router.replace(decodeURIComponent(redirectUrl))
else
setIsLoading(false)
}
const proceedToAuth = () => {
setIsLoading(false)
}
(async () => {
@ -50,6 +60,9 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
if (userLoggedIn && appLoggedIn) {
redirectOrFinish()
}
else if (!userLoggedIn && !appLoggedIn) {
proceedToAuth()
}
else if (!userLoggedIn && appLoggedIn) {
redirectOrFinish()
}
@ -64,6 +77,7 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
}
catch {
await webAppLogout(shareCode!)
proceedToAuth()
}
}
})()
@ -81,11 +95,18 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
return (
<div className="flex h-full flex-col items-center justify-center gap-y-4">
<AppUnavailable className="h-auto w-auto" code={code || t('common.appUnavailable', { ns: 'share' })} unknownReason={message} />
<span className="cursor-pointer system-sm-regular text-text-tertiary" onClick={backToHome}>{code === '403' ? t('userProfile.logout', { ns: 'common' }) : t('login.backToHome', { ns: 'share' })}</span>
<span className="system-sm-regular cursor-pointer text-text-tertiary" onClick={backToHome}>{code === '403' ? t('userProfile.logout', { ns: 'common' }) : t('login.backToHome', { ns: 'share' })}</span>
</div>
)
}
if (isLoading) {
return (
<div className="flex h-full items-center justify-center">
<Loading />
</div>
)
}
return <>{children}</>
}

View File

@ -1,75 +0,0 @@
import { render, screen } from '@testing-library/react'
import { AccessMode } from '@/models/access-control'
import WebSSOForm from '../page'
const mockReplace = vi.fn()
let mockRedirectUrl = '/share/test-share-code'
let mockWebAppAccessMode: AccessMode | null = null
let mockSystemFeaturesEnabled = true
vi.mock('@/next/navigation', () => ({
useRouter: () => ({
replace: mockReplace,
}),
useSearchParams: () => ({
get: (key: string) => key === 'redirect_url' ? mockRedirectUrl : null,
}),
}))
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: (selector: (state: { systemFeatures: { webapp_auth: { enabled: boolean } } }) => unknown) =>
selector({
systemFeatures: {
webapp_auth: {
enabled: mockSystemFeaturesEnabled,
},
},
}),
}))
vi.mock('@/context/web-app-context', () => ({
useWebAppStore: (selector: (state: { webAppAccessMode: AccessMode | null, shareCode: string | null }) => unknown) =>
selector({
webAppAccessMode: mockWebAppAccessMode,
shareCode: 'test-share-code',
}),
}))
vi.mock('@/service/webapp-auth', () => ({
webAppLogout: vi.fn(),
}))
vi.mock('../normalForm', () => ({
default: () => <div data-testid="normal-form" />,
}))
vi.mock('../components/external-member-sso-auth', () => ({
default: () => <div data-testid="external-member-sso-auth" />,
}))
describe('WebSSOForm', () => {
beforeEach(() => {
vi.clearAllMocks()
mockRedirectUrl = '/share/test-share-code'
mockWebAppAccessMode = null
mockSystemFeaturesEnabled = true
})
describe('Access Mode Resolution', () => {
it('should avoid rendering auth variants before the access mode query resolves', () => {
render(<WebSSOForm />)
expect(screen.queryByTestId('normal-form')).not.toBeInTheDocument()
expect(screen.queryByTestId('external-member-sso-auth')).not.toBeInTheDocument()
expect(screen.queryByText('share.login.backToHome')).not.toBeInTheDocument()
})
it('should render the normal form for organization-backed access modes', () => {
mockWebAppAccessMode = AccessMode.ORGANIZATION
render(<WebSSOForm />)
expect(screen.getByTestId('normal-form')).toBeInTheDocument()
})
})
})

View File

@ -63,7 +63,7 @@ const WebSSOForm: FC = () => {
return (
<div className="flex h-full flex-col items-center justify-center gap-y-4">
<AppUnavailable className="h-auto w-auto" isUnknownReason={true} />
<span className="cursor-pointer system-sm-regular text-text-tertiary" onClick={backToHome}>{t('login.backToHome', { ns: 'share' })}</span>
<span className="system-sm-regular cursor-pointer text-text-tertiary" onClick={backToHome}>{t('login.backToHome', { ns: 'share' })}</span>
</div>
)
}

View File

@ -13,10 +13,6 @@ import { useProviderContext } from '@/context/provider-context'
import { useRouter } from '@/next/navigation'
import { useLogout, useUserProfile } from '@/service/use-common'
export type IAppSelector = {
isMobile: boolean
}
export default function AppSelector() {
const router = useRouter()
const { t } = useTranslation()

View File

@ -1,4 +1,6 @@
'use client'
import Loading from '@/app/components/base/loading'
import Header from '@/app/signin/_header'
import { AppContextProvider } from '@/context/app-context-provider'
import { useGlobalPublicStore } from '@/context/global-public-context'
@ -9,8 +11,16 @@ import { cn } from '@/utils/classnames'
export default function SignInLayout({ children }: any) {
const { systemFeatures } = useGlobalPublicStore()
useDocumentTitle('')
const { data: loginData } = useIsLogin()
const { isLoading, data: loginData } = useIsLogin()
const isLoggedIn = loginData?.logged_in
if (isLoading) {
return (
<div className="flex min-h-screen w-full justify-center bg-background-default-burn">
<Loading />
</div>
)
}
return (
<>
<div className={cn('flex min-h-screen w-full justify-center bg-background-default-burn p-6')}>

View File

@ -12,6 +12,7 @@ import { useEffect, useRef } from 'react'
import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Button from '@/app/components/base/button'
import Loading from '@/app/components/base/loading'
import { toast } from '@/app/components/base/ui/toast'
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
@ -70,8 +71,6 @@ export default function OAuthAuthorize() {
const isLoggedIn = loginData?.logged_in
const isLoading = isOAuthLoading || isIsLoginLoading
const onLoginSwitchClick = () => {
if (isLoading)
return
try {
const returnUrl = buildReturnUrl('/account/oauth/authorize', `?client_id=${encodeURIComponent(client_id)}&redirect_uri=${encodeURIComponent(redirect_uri)}`)
setPostLoginRedirect(returnUrl)
@ -107,6 +106,14 @@ export default function OAuthAuthorize() {
}
}, [client_id, redirect_uri, isError])
if (isLoading) {
return (
<div className="bg-background-default-subtle">
<Loading type="app" />
</div>
)
}
return (
<div className="bg-background-default-subtle">
{authAppInfo?.app_icon && (
@ -115,13 +122,13 @@ export default function OAuthAuthorize() {
</div>
)}
<div className={`mt-5 mb-4 flex flex-col gap-2 ${isLoggedIn ? 'pb-2' : ''}`}>
<div className={`mb-4 mt-5 flex flex-col gap-2 ${isLoggedIn ? 'pb-2' : ''}`}>
<div className="title-4xl-semi-bold">
{isLoggedIn && <div className="text-text-primary">{t('connect', { ns: 'oauth' })}</div>}
<div className="text-(--color-saas-dify-blue-inverted)">{authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('unknownApp', { ns: 'oauth' })}</div>
{!isLoggedIn && <div className="text-text-primary">{t('tips.notLoggedIn', { ns: 'oauth' })}</div>}
</div>
<div className="body-md-regular text-text-secondary">{isLoggedIn ? `${authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('unknownApp', { ns: 'oauth' })} ${t('tips.loggedIn', { ns: 'oauth' })}` : t('tips.needLogin', { ns: 'oauth' })}</div>
<div className="text-text-secondary body-md-regular">{isLoggedIn ? `${authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('unknownApp', { ns: 'oauth' })} ${t('tips.loggedIn', { ns: 'oauth' })}` : t('tips.needLogin', { ns: 'oauth' })}</div>
</div>
{isLoggedIn && userProfile && (
@ -130,7 +137,7 @@ export default function OAuthAuthorize() {
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size="lg" />
<div>
<div className="system-md-semi-bold text-text-secondary">{userProfile.name}</div>
<div className="system-xs-regular text-text-tertiary">{userProfile.email}</div>
<div className="text-text-tertiary system-xs-regular">{userProfile.email}</div>
</div>
</div>
<Button variant="tertiary" size="small" onClick={onLoginSwitchClick}>{t('switchAccount', { ns: 'oauth' })}</Button>
@ -142,7 +149,7 @@ export default function OAuthAuthorize() {
{authAppInfo!.scope.split(/\s+/).filter(Boolean).map((scope: string) => {
const Icon = SCOPE_INFO_MAP[scope]
return (
<div key={scope} className="flex items-center gap-2 body-sm-medium text-text-secondary">
<div key={scope} className="flex items-center gap-2 text-text-secondary body-sm-medium">
{Icon ? <Icon.icon className="h-4 w-4" /> : <RiAccountCircleLine className="h-4 w-4" />}
{Icon.label}
</div>
@ -175,7 +182,7 @@ export default function OAuthAuthorize() {
</defs>
</svg>
</div>
<div className="mt-3 system-xs-regular text-text-tertiary">{t('tips.common', { ns: 'oauth' })}</div>
<div className="mt-3 text-text-tertiary system-xs-regular">{t('tips.common', { ns: 'oauth' })}</div>
</div>
)
}

View File

@ -3,7 +3,7 @@
import type { ReactNode } from 'react'
import Cookies from 'js-cookie'
import { parseAsBoolean, useQueryState } from 'nuqs'
import { useCallback, useEffect } from 'react'
import { useCallback, useEffect, useState } from 'react'
import {
EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION,
EDUCATION_VERIFYING_LOCALSTORAGE_ITEM,
@ -25,6 +25,7 @@ export const AppInitializer = ({
const searchParams = useSearchParams()
// Tokens are now stored in cookies, no need to check localStorage
const pathname = usePathname()
const [init, setInit] = useState(false)
const [oauthNewUser] = useQueryState(
'oauth_new_user',
parseAsBoolean.withOptions({ history: 'replace' }),
@ -86,7 +87,10 @@ export const AppInitializer = ({
const redirectUrl = resolvePostLoginRedirect()
if (redirectUrl) {
location.replace(redirectUrl)
return
}
setInit(true)
}
catch {
router.replace('/signin')
@ -94,5 +98,5 @@ export const AppInitializer = ({
})()
}, [isSetupFinished, router, pathname, searchParams, oauthNewUser])
return children
return init ? children : null
}

View File

@ -5,7 +5,7 @@ import AppInfoModals from './app-info-modals'
import AppInfoTrigger from './app-info-trigger'
import { useAppInfoActions } from './use-app-info-actions'
export type IAppInfoProps = {
type IAppInfoProps = {
expand: boolean
onlyShowDetail?: boolean
openState?: boolean

View File

@ -7,7 +7,7 @@ import {
import Tooltip from '@/app/components/base/tooltip'
import AppIcon from '../base/app-icon'
export type IAppBasicProps = {
type IAppBasicProps = {
iconType?: 'app' | 'api' | 'dataset' | 'webapp' | 'notion'
icon?: string
icon_background?: string | null

View File

@ -17,7 +17,7 @@ import DatasetSidebarDropdown from './dataset-sidebar-dropdown'
import NavLink from './nav-link'
import ToggleButton from './toggle-button'
export type IAppDetailNavProps = {
type IAppDetailNavProps = {
iconType?: 'app' | 'dataset'
navigation: Array<{
name: string

View File

@ -39,7 +39,5 @@ export enum AnnotationEnableStatus {
}
export enum JobStatus {
waiting = 'waiting',
processing = 'processing',
completed = 'completed',
}

View File

@ -2,7 +2,7 @@ import type { HTMLProps, PropsWithChildren } from 'react'
import { RiArrowRightUpLine } from '@remixicon/react'
import { cn } from '@/utils/classnames'
export type SuggestedActionProps = PropsWithChildren<HTMLProps<HTMLAnchorElement> & {
type SuggestedActionProps = PropsWithChildren<HTMLProps<HTMLAnchorElement> & {
icon?: React.ReactNode
link?: string
disabled?: boolean

View File

@ -3,7 +3,7 @@ import type { FC, ReactNode } from 'react'
import * as React from 'react'
import { cn } from '@/utils/classnames'
export type IFeaturePanelProps = {
type IFeaturePanelProps = {
className?: string
headerIcon?: ReactNode
title: ReactNode

Some files were not shown because too many files have changed in this diff Show More