diff --git a/api/.env.example b/api/.env.example index 40fed7403c..34be400e87 100644 --- a/api/.env.example +++ b/api/.env.example @@ -557,7 +557,7 @@ MAX_VARIABLE_SIZE=204800 # GraphEngine Worker Pool Configuration # Minimum number of workers per GraphEngine instance (default: 1) -GRAPH_ENGINE_MIN_WORKERS=1 +GRAPH_ENGINE_MIN_WORKERS=3 # Maximum number of workers per GraphEngine instance (default: 10) GRAPH_ENGINE_MAX_WORKERS=10 # Queue depth threshold that triggers worker scale up (default: 3) diff --git a/api/Dockerfile b/api/Dockerfile index 6098652573..8425578953 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -24,7 +24,8 @@ RUN apt-get update \ # Install Python dependencies (workspace members under providers/vdb/) COPY pyproject.toml uv.lock ./ COPY providers ./providers -RUN uv sync --locked --no-dev +# Trust the checked-in lock during image builds; dev-only path sources live outside the api/ context. +RUN uv sync --frozen --no-dev # production stage FROM base AS production diff --git a/api/commands/system.py b/api/commands/system.py index 39b2e991ed..7755d3b5bc 100644 --- a/api/commands/system.py +++ b/api/commands/system.py @@ -14,6 +14,7 @@ from libs.rsa import generate_key_pair from models import Tenant from models.model import App, AppMode, Conversation from models.provider import Provider, ProviderModel +from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider logger = logging.getLogger(__name__) @@ -23,13 +24,16 @@ DB_UPGRADE_LOCK_TTL_SECONDS = 60 @click.command( "reset-encrypt-key-pair", help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. " - "After the reset, all LLM credentials will become invalid, " - "requiring re-entry." + "After the reset, all LLM credentials and tool provider credentials " + "(builtin / API / MCP) will be purged, requiring re-entry. " "Only support SELF_HOSTED mode.", ) @click.confirmation_option( prompt=click.style( - "Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red" + "Are you sure you want to reset encrypt key pair? " + "This will also purge builtin / API / MCP tool provider records for every tenant. " + "This operation cannot be rolled back!", + fg="red", ) ) def reset_encrypt_key_pair(): @@ -53,6 +57,13 @@ def reset_encrypt_key_pair(): session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id)) session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id)) + # Purge tool provider records that hold credentials encrypted under the + # tenant key. Leaving them in place causes /console/api/workspaces/current/ + # tool-providers to 500 because decryption fails on stale ciphertext (#35396). + session.execute(delete(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant.id)) + session.execute(delete(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant.id)) + session.execute(delete(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant.id)) + click.echo( click.style( f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.", diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index ccb97d96ef..a752d9d103 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -761,7 +761,7 @@ class WorkflowConfig(BaseSettings): # GraphEngine Worker Pool Configuration GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field( description="Minimum number of workers per GraphEngine instance", - default=1, + default=3, ) GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field( diff --git a/api/controllers/common/human_input.py b/api/controllers/common/human_input.py index 5d6f4efb95..98fe2ce67b 100644 --- a/api/controllers/common/human_input.py +++ b/api/controllers/common/human_input.py @@ -1,6 +1,21 @@ +import json + from pydantic import BaseModel, JsonValue class HumanInputFormSubmitPayload(BaseModel): inputs: dict[str, JsonValue] action: str + + +def stringify_form_default_values(values: dict[str, object]) -> dict[str, str]: + """Serialize default values into strings expected by human-input form clients.""" + result: dict[str, str] = {} + for key, value in values.items(): + if value is None: + result[key] = "" + elif isinstance(value, (dict, list)): + result[key] = json.dumps(value, ensure_ascii=False) + else: + result[key] = str(value) + return result diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index b03d9b4a4c..6463b022b5 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -11,6 +11,7 @@ from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from extensions.ext_database import db from fields.base import ResponseModel +from libs.helper import to_timestamp from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset from models.enums import ApiTokenType @@ -21,12 +22,6 @@ from . import console_ns from .wraps import account_initialization_required, edit_permission_required, setup_required -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class ApiKeyItem(ResponseModel): id: str type: str @@ -37,7 +32,7 @@ class ApiKeyItem(ResponseModel): @field_validator("last_used_at", "created_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class ApiKeyList(ResponseModel): diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 4429039d79..045325f283 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -3,7 +3,6 @@ import re import uuid from datetime import datetime from typing import Any, Literal -from uuid import UUID from flask import request from flask_restx import Resource @@ -34,7 +33,7 @@ from core.trigger.constants import TRIGGER_NODE_TYPES from extensions.ext_database import db from fields.base import ResponseModel from graphon.enums import WorkflowExecutionStatus -from libs.helper import build_icon_url +from libs.helper import build_icon_url, to_timestamp from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow from models.model import IconType @@ -178,12 +177,6 @@ class AppTracePayload(BaseModel): type JSONValue = Any -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class Tag(ResponseModel): id: str name: str @@ -200,7 +193,7 @@ class WorkflowPartial(ResponseModel): @field_validator("created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class ModelConfigPartial(ResponseModel): @@ -214,7 +207,7 @@ class ModelConfigPartial(ResponseModel): @field_validator("created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class ModelConfig(ResponseModel): @@ -275,7 +268,7 @@ class ModelConfig(ResponseModel): @field_validator("created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class Site(ResponseModel): @@ -318,7 +311,7 @@ class Site(ResponseModel): @field_validator("created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class DeletedTool(ResponseModel): @@ -361,7 +354,7 @@ class AppPartial(ResponseModel): @field_validator("created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class AppDetail(ResponseModel): @@ -391,7 +384,7 @@ class AppDetail(ResponseModel): @field_validator("created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class AppDetailWithSite(AppDetail): @@ -856,10 +849,11 @@ class AppTraceApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id: UUID): + @get_app_model + def get(self, app_model): """Get app trace""" with session_factory.create_session() as session: - app_trace_config = OpsTraceManager.get_app_tracing_config(str(app_id), session) + app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session) return app_trace_config @@ -873,12 +867,13 @@ class AppTraceApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, app_id: UUID): + @get_app_model + def post(self, app_model): # add app trace args = AppTracePayload.model_validate(console_ns.payload) OpsTraceManager.update_app_tracing_config( - app_id=str(app_id), + app_id=app_model.id, enabled=args.enabled, tracing_provider=args.tracing_provider, ) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 60a2bfc799..5951f7405a 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -16,6 +16,7 @@ from controllers.console.wraps import account_initialization_required, setup_req from extensions.ext_database import db from fields._value_type_serializer import serialize_value_type from fields.base import ResponseModel +from libs.helper import to_timestamp from libs.login import login_required from models import ConversationVariable from models.model import AppMode @@ -25,12 +26,6 @@ class ConversationVariablesQuery(BaseModel): conversation_id: str = Field(..., description="Conversation ID to filter variables") -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class ConversationVariableResponse(ResponseModel): id: str name: str @@ -65,7 +60,7 @@ class ConversationVariableResponse(ResponseModel): @field_validator("created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class PaginatedConversationVariableResponse(ResponseModel): diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index d517f695b8..13f6e098ba 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -13,6 +13,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from extensions.ext_database import db from fields.base import ResponseModel +from libs.helper import to_timestamp from libs.login import current_account_with_tenant, login_required from models.enums import AppMCPServerStatus from models.model import AppMCPServer @@ -30,12 +31,6 @@ class MCPServerUpdatePayload(BaseModel): status: str | None = Field(default=None, description="Server status") -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class AppMCPServerResponse(ResponseModel): id: str name: str @@ -59,7 +54,7 @@ class AppMCPServerResponse(ResponseModel): @field_validator("created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse) diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 44e19b57db..4b596b992f 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -37,10 +37,9 @@ from fields.conversation_fields import ( JSONValue, MessageFile, format_files_contained, - to_timestamp, ) from graphon.model_runtime.errors.invoke import InvokeError -from libs.helper import uuid_value +from libs.helper import to_timestamp, uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import current_account_with_tenant, login_required from models.enums import FeedbackFromSource, FeedbackRating @@ -144,9 +143,7 @@ class MessageDetailResponse(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value + return to_timestamp(value) class MessageInfiniteScrollPaginationResponse(ResponseModel): diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 9227d00a21..41acf39541 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,5 +1,4 @@ from typing import Any -from uuid import UUID from flask import request from flask_restx import Resource, fields @@ -9,8 +8,10 @@ from werkzeug.exceptions import BadRequest from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist +from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required +from models import App from services.ops_service import OpsService @@ -43,11 +44,14 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id: UUID): - args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) + @get_app_model + def get(self, app_model: App): + args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore try: - trace_config = OpsService.get_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider) + trace_config = OpsService.get_tracing_app_config( + app_id=app_model.id, tracing_provider=args.tracing_provider + ) if not trace_config: return {"has_not_configured": True} return trace_config @@ -65,13 +69,14 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id: UUID): + @get_app_model + def post(self, app_model: App): """Create a new trace app configuration""" args = TraceConfigPayload.model_validate(console_ns.payload) try: result = OpsService.create_tracing_app_config( - app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config + app_id=app_model.id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config ) if not result: raise TracingConfigIsExist() @@ -90,13 +95,14 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, app_id: UUID): + @get_app_model + def patch(self, app_model: App): """Update an existing trace app configuration""" args = TraceConfigPayload.model_validate(console_ns.payload) try: result = OpsService.update_tracing_app_config( - app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config + app_id=app_model.id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config ) if not result: raise TracingConfigNotExist() @@ -113,12 +119,13 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self, app_id: UUID): + @get_app_model + def delete(self, app_model: App): """Delete an existing trace app configuration""" args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) try: - result = OpsService.delete_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider) + result = OpsService.delete_tracing_app_config(app_id=app_model.id, tracing_provider=args.tracing_provider) if not result: raise TracingConfigNotExist() return {"result": "success"}, 204 diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index ddc900eb2d..dec183a300 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -16,6 +16,7 @@ from fields.base import ResponseModel from fields.end_user_fields import SimpleEndUser from fields.member_fields import SimpleAccount from graphon.enums import WorkflowExecutionStatus +from libs.helper import to_timestamp from libs.login import login_required from models import App from models.model import AppMode @@ -82,9 +83,7 @@ class WorkflowRunForLogResponse(ResponseModel): @field_validator("created_at", "finished_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value + return to_timestamp(value) class WorkflowRunForArchivedLogResponse(ResponseModel): @@ -117,9 +116,7 @@ class WorkflowAppLogPartialResponse(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value + return to_timestamp(value) class WorkflowArchivedLogPartialResponse(ResponseModel): @@ -133,9 +130,7 @@ class WorkflowArchivedLogPartialResponse(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value + return to_timestamp(value) class WorkflowAppLogPaginationResponse(ResponseModel): diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py index c003be1303..f011f576fd 100644 --- a/api/controllers/console/app/workflow_comment.py +++ b/api/controllers/console/app/workflow_comment.py @@ -1,22 +1,16 @@ import logging +from datetime import datetime -from flask_restx import Resource, marshal_with -from pydantic import BaseModel, Field, TypeAdapter +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter, computed_field, field_validator -from controllers.common.schema import register_schema_models +from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from fields.base import ResponseModel from fields.member_fields import AccountWithRole -from fields.workflow_comment_fields import ( - workflow_comment_basic_fields, - workflow_comment_create_fields, - workflow_comment_detail_fields, - workflow_comment_reply_create_fields, - workflow_comment_reply_update_fields, - workflow_comment_resolve_fields, - workflow_comment_update_fields, -) +from libs.helper import build_avatar_url, dump_response, to_timestamp from libs.login import current_user, login_required from models import App from services.account_service import TenantService @@ -51,6 +45,138 @@ class WorkflowCommentMentionUsersPayload(BaseModel): users: list[AccountWithRole] +class WorkflowCommentAccount(ResponseModel): + id: str + name: str + email: str + avatar: str | None = Field(default=None, exclude=True) + + @computed_field(return_type=str | None) # type: ignore[prop-decorator] + @property + def avatar_url(self) -> str | None: + return build_avatar_url(self.avatar) + + +class WorkflowCommentReply(ResponseModel): + id: str + content: str + created_by: str + created_by_account: WorkflowCommentAccount | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return to_timestamp(value) + + +class WorkflowCommentMention(ResponseModel): + mentioned_user_id: str + mentioned_user_account: WorkflowCommentAccount | None = None + reply_id: str | None = None + + +class WorkflowCommentBasic(ResponseModel): + id: str + position_x: float + position_y: float + content: str + created_by: str + created_by_account: WorkflowCommentAccount | None = None + created_at: int | None = None + updated_at: int | None = None + resolved: bool + resolved_at: int | None = None + resolved_by: str | None = None + resolved_by_account: WorkflowCommentAccount | None = None + reply_count: int + mention_count: int + participants: list[WorkflowCommentAccount] + + @field_validator("created_at", "updated_at", "resolved_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return to_timestamp(value) + + +class WorkflowCommentBasicList(ResponseModel): + data: list[WorkflowCommentBasic] + + +class WorkflowCommentDetail(ResponseModel): + id: str + position_x: float + position_y: float + content: str + created_by: str + created_by_account: WorkflowCommentAccount | None = None + created_at: int | None = None + updated_at: int | None = None + resolved: bool + resolved_at: int | None = None + resolved_by: str | None = None + resolved_by_account: WorkflowCommentAccount | None = None + replies: list[WorkflowCommentReply] + mentions: list[WorkflowCommentMention] + + @field_validator("created_at", "updated_at", "resolved_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return to_timestamp(value) + + +class WorkflowCommentCreate(ResponseModel): + id: str + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return to_timestamp(value) + + +class WorkflowCommentUpdate(ResponseModel): + id: str + updated_at: int | None = None + + @field_validator("updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return to_timestamp(value) + + +class WorkflowCommentResolve(ResponseModel): + id: str + resolved: bool + resolved_at: int | None = None + resolved_by: str | None = None + + @field_validator("resolved_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return to_timestamp(value) + + +class WorkflowCommentReplyCreate(ResponseModel): + id: str + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return to_timestamp(value) + + +class WorkflowCommentReplyUpdate(ResponseModel): + id: str + updated_at: int | None = None + + @field_validator("updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return to_timestamp(value) + + register_schema_models( console_ns, AccountWithRole, @@ -59,17 +185,19 @@ register_schema_models( WorkflowCommentUpdatePayload, WorkflowCommentReplyPayload, ) - -workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields) -workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields) -workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields) -workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields) -workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields) -workflow_comment_reply_create_model = console_ns.model( - "WorkflowCommentReplyCreate", workflow_comment_reply_create_fields -) -workflow_comment_reply_update_model = console_ns.model( - "WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields +register_response_schema_models( + console_ns, + WorkflowCommentAccount, + WorkflowCommentReply, + WorkflowCommentMention, + WorkflowCommentBasic, + WorkflowCommentBasicList, + WorkflowCommentDetail, + WorkflowCommentCreate, + WorkflowCommentUpdate, + WorkflowCommentResolve, + WorkflowCommentReplyCreate, + WorkflowCommentReplyUpdate, ) @@ -80,28 +208,26 @@ class WorkflowCommentListApi(Resource): @console_ns.doc("list_workflow_comments") @console_ns.doc(description="Get all comments for a workflow") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model) + @console_ns.response(200, "Comments retrieved successfully", console_ns.models[WorkflowCommentBasicList.__name__]) @login_required @setup_required @account_initialization_required @get_app_model() - @marshal_with(workflow_comment_basic_model, envelope="data") def get(self, app_model: App): """Get all comments for a workflow.""" comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id) - return comments + return WorkflowCommentBasicList.model_validate({"data": comments}).model_dump(mode="json") @console_ns.doc("create_workflow_comment") @console_ns.doc(description="Create a new workflow comment") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__]) - @console_ns.response(201, "Comment created successfully", workflow_comment_create_model) + @console_ns.response(201, "Comment created successfully", console_ns.models[WorkflowCommentCreate.__name__]) @login_required @setup_required @account_initialization_required @get_app_model() - @marshal_with(workflow_comment_create_model) @edit_permission_required def post(self, app_model: App): """Create a new workflow comment.""" @@ -117,7 +243,7 @@ class WorkflowCommentListApi(Resource): mentioned_user_ids=payload.mentioned_user_ids, ) - return result, 201 + return dump_response(WorkflowCommentCreate, result), 201 @console_ns.route("/apps//workflow/comments/") @@ -127,30 +253,28 @@ class WorkflowCommentDetailApi(Resource): @console_ns.doc("get_workflow_comment") @console_ns.doc(description="Get a specific workflow comment") @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) - @console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model) + @console_ns.response(200, "Comment retrieved successfully", console_ns.models[WorkflowCommentDetail.__name__]) @login_required @setup_required @account_initialization_required @get_app_model() - @marshal_with(workflow_comment_detail_model) def get(self, app_model: App, comment_id: str): """Get a specific workflow comment.""" comment = WorkflowCommentService.get_comment( tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id ) - return comment + return dump_response(WorkflowCommentDetail, comment) @console_ns.doc("update_workflow_comment") @console_ns.doc(description="Update a workflow comment") @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) @console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__]) - @console_ns.response(200, "Comment updated successfully", workflow_comment_update_model) + @console_ns.response(200, "Comment updated successfully", console_ns.models[WorkflowCommentUpdate.__name__]) @login_required @setup_required @account_initialization_required @get_app_model() - @marshal_with(workflow_comment_update_model) @edit_permission_required def put(self, app_model: App, comment_id: str): """Update a workflow comment.""" @@ -167,7 +291,7 @@ class WorkflowCommentDetailApi(Resource): mentioned_user_ids=payload.mentioned_user_ids, ) - return result + return dump_response(WorkflowCommentUpdate, result) @console_ns.doc("delete_workflow_comment") @console_ns.doc(description="Delete a workflow comment") @@ -197,12 +321,11 @@ class WorkflowCommentResolveApi(Resource): @console_ns.doc("resolve_workflow_comment") @console_ns.doc(description="Resolve a workflow comment") @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) - @console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model) + @console_ns.response(200, "Comment resolved successfully", console_ns.models[WorkflowCommentResolve.__name__]) @login_required @setup_required @account_initialization_required @get_app_model() - @marshal_with(workflow_comment_resolve_model) @edit_permission_required def post(self, app_model: App, comment_id: str): """Resolve a workflow comment.""" @@ -213,7 +336,7 @@ class WorkflowCommentResolveApi(Resource): user_id=current_user.id, ) - return comment + return dump_response(WorkflowCommentResolve, comment) @console_ns.route("/apps//workflow/comments//replies") @@ -224,12 +347,11 @@ class WorkflowCommentReplyApi(Resource): @console_ns.doc(description="Add a reply to a workflow comment") @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) @console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__]) - @console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model) + @console_ns.response(201, "Reply created successfully", console_ns.models[WorkflowCommentReplyCreate.__name__]) @login_required @setup_required @account_initialization_required @get_app_model() - @marshal_with(workflow_comment_reply_create_model) @edit_permission_required def post(self, app_model: App, comment_id: str): """Add a reply to a workflow comment.""" @@ -247,7 +369,7 @@ class WorkflowCommentReplyApi(Resource): mentioned_user_ids=payload.mentioned_user_ids, ) - return result, 201 + return dump_response(WorkflowCommentReplyCreate, result), 201 @console_ns.route("/apps//workflow/comments//replies/") @@ -258,12 +380,11 @@ class WorkflowCommentReplyDetailApi(Resource): @console_ns.doc(description="Update a comment reply") @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"}) @console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__]) - @console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model) + @console_ns.response(200, "Reply updated successfully", console_ns.models[WorkflowCommentReplyUpdate.__name__]) @login_required @setup_required @account_initialization_required @get_app_model() - @marshal_with(workflow_comment_reply_update_model) @edit_permission_required def put(self, app_model: App, comment_id: str, reply_id: str): """Update a comment reply.""" @@ -284,7 +405,7 @@ class WorkflowCommentReplyDetailApi(Resource): mentioned_user_ids=payload.mentioned_user_ids, ) - return reply + return dump_response(WorkflowCommentReplyUpdate, reply) @console_ns.doc("delete_workflow_comment_reply") @console_ns.doc(description="Delete a comment reply") diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index f6b8aedf22..e1f3f0eaeb 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -3,7 +3,7 @@ from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from configs import dify_config -from constants.languages import languages +from constants.languages import get_valid_language, languages from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( @@ -15,11 +15,12 @@ from controllers.console.auth.error import ( PasswordMismatchError, ) from libs.helper import EmailStr, extract_remote_ip +from libs.helper import timezone as validate_timezone_string from libs.password import valid_password from models import Account from services.account_service import AccountService from services.billing_service import BillingService -from services.errors.account import AccountNotFoundError, AccountRegisterError +from services.errors.account import AccountRegisterError from ..error import AccountInFreezeError, EmailSendIpLimitError from ..wraps import email_password_login_enabled, email_register_enabled, setup_required @@ -40,12 +41,21 @@ class EmailRegisterResetPayload(BaseModel): token: str = Field(...) new_password: str = Field(...) password_confirm: str = Field(...) + language: str | None = Field(default=None) + timezone: str | None = Field(default=None) @field_validator("new_password", "password_confirm") @classmethod def validate_password(cls, value: str) -> str: return valid_password(value) + @field_validator("timezone") + @classmethod + def validate_timezone(cls, value: str | None) -> str | None: + if value is None: + return None + return validate_timezone_string(value) + register_schema_models(console_ns, EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload) @@ -144,26 +154,32 @@ class EmailRegisterResetApi(Resource): if account: raise EmailAlreadyInUseError() - else: - account = self._create_new_account(normalized_email, args.password_confirm) - if not account: - raise AccountNotFoundError() - token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(normalized_email) + + account = self._create_new_account( + email=normalized_email, + password=args.password_confirm, + timezone=args.timezone, + language=args.language, + ) + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(normalized_email) return {"result": "success", "data": token_pair.model_dump()} - def _create_new_account(self, email: str, password: str) -> Account | None: - # Create new account if allowed - account = None + def _create_new_account( + self, + email: str, + password: str, + timezone: str | None = None, + language: str | None = None, + ) -> Account: try: - account = AccountService.create_account_and_tenant( + return AccountService.create_account_and_tenant( email=email, name=email, password=password, - interface_language=languages[0], + interface_language=get_valid_language(language), + timezone=timezone, ) except AccountRegisterError: raise AccountInFreezeError() - - return account diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 19c98f3a1a..3121470b84 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -3,7 +3,7 @@ import logging import flask_login from flask import make_response, request from flask_restx import Resource -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import Unauthorized import services @@ -34,6 +34,7 @@ from controllers.console.wraps import ( ) from events.tenant_event import tenant_was_created from libs.helper import EmailStr, extract_remote_ip +from libs.helper import timezone as validate_timezone_string from libs.login import current_account_with_tenant from libs.token import ( clear_access_token_from_cookie, @@ -69,6 +70,14 @@ class EmailCodeLoginPayload(BaseModel): code: str = Field(...) token: str = Field(...) language: str | None = Field(default=None) + timezone: str | None = Field(default=None) + + @field_validator("timezone") + @classmethod + def validate_timezone(cls, value: str | None) -> str | None: + if value is None: + return None + return validate_timezone_string(value) register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPayload) @@ -288,6 +297,7 @@ class EmailCodeLoginApi(Resource): email=user_email, name=user_email, interface_language=get_valid_language(language), + timezone=args.timezone, ) except WorkSpaceNotAllowedCreateError: raise NotAllowedCreateWorkspace() diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index d31fb4a46c..2254fa4981 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -12,7 +12,8 @@ from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.helper import extract_remote_ip -from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo +from libs.helper import timezone as validate_timezone_string +from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state from libs.token import ( set_access_token_to_cookie, set_csrf_token_to_cookie, @@ -53,6 +54,31 @@ def get_oauth_providers(): return OAUTH_PROVIDERS +def _validated_timezone(value: str | None) -> str | None: + if not value: + return None + try: + return validate_timezone_string(value) + except ValueError: + return None + + +def _validated_language(value: str | None) -> str | None: + if value and value in languages: + return value + return None + + +def _preferred_interface_language(language: str | None = None) -> str: + if language: + return language + + preferred_lang = request.accept_languages.best_match(languages) + if preferred_lang and preferred_lang in languages: + return preferred_lang + return languages[0] + + @console_ns.route("/oauth/login/") class OAuthLogin(Resource): @console_ns.doc("oauth_login") @@ -64,13 +90,19 @@ class OAuthLogin(Resource): @console_ns.response(400, "Invalid provider") def get(self, provider: str): invite_token = request.args.get("invite_token") or None + timezone = _validated_timezone(request.args.get("timezone") or None) + language = _validated_language(request.args.get("language") or None) OAUTH_PROVIDERS = get_oauth_providers() with current_app.app_context(): oauth_provider = OAUTH_PROVIDERS.get(provider) if not oauth_provider: return {"error": "Invalid provider"}, 400 - auth_url = oauth_provider.get_authorization_url(invite_token=invite_token) + auth_url = oauth_provider.get_authorization_url( + invite_token=invite_token, + timezone=timezone, + language=language, + ) return redirect(auth_url) @@ -96,9 +128,10 @@ class OAuthCallback(Resource): code = request.args.get("code") state = request.args.get("state") - invite_token = None - if state: - invite_token = state + oauth_state = decode_oauth_state(state) + invite_token = oauth_state.get("invite_token") + timezone = _validated_timezone(oauth_state.get("timezone")) + language = _validated_language(oauth_state.get("language")) if not code: return {"error": "Authorization code is required"}, 400 @@ -129,7 +162,7 @@ class OAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") try: - account, oauth_new_user = _generate_account(provider, user_info) + account, oauth_new_user = _generate_account(provider, user_info, timezone=timezone, language=language) except AccountNotFoundError: return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.") except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError): @@ -184,7 +217,12 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> return account -def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]: +def _generate_account( + provider: str, + user_info: OAuthUserInfo, + timezone: str | None = None, + language: str | None = None, +) -> tuple[Account, bool]: # Get account by openid or email. account = _get_account_by_openid_or_email(provider, user_info) oauth_new_user = False @@ -211,26 +249,19 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, "30 days and is temporarily unavailable for new account registration" ) ) - else: - raise AccountRegisterError(description=("Invalid email or password")) + raise AccountRegisterError(description=("Invalid email or password")) account_name = user_info.name or "Dify" + interface_language = _preferred_interface_language(language) account = RegisterService.register( email=normalized_email, name=account_name, password=None, open_id=user_info.id, provider=provider, + language=interface_language, + timezone=timezone, ) - # Set interface language - preferred_lang = request.accept_languages.best_match(languages) - if preferred_lang and preferred_lang in languages: - interface_language = preferred_lang - else: - interface_language = languages[0] - account.interface_language = interface_language - db.session.commit() - # Link account AccountService.link_account_integrate(provider, user_info.id, account) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index c4e13c41a5..dfe8192b89 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -39,6 +39,7 @@ from fields.document_fields import ( from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.datetime_utils import naive_utc_now +from libs.helper import to_timestamp from libs.login import current_account_with_tenant, login_required from models import DatasetProcessRule, Document, DocumentSegment, UploadFile from models.dataset import DocumentPipelineExecutionLog @@ -71,12 +72,6 @@ from ..wraps import ( logger = logging.getLogger(__name__) -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - def _normalize_enum(value: Any) -> Any: if isinstance(value, str) or value is None: return value @@ -101,7 +96,7 @@ class DatasetResponse(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class DocumentMetadataResponse(ResponseModel): @@ -152,7 +147,7 @@ class DocumentResponse(ResponseModel): @field_validator("created_at", "disabled_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class DocumentWithSegmentsResponse(DocumentResponse): diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 36a7a4bb0e..8758f983ee 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -8,6 +8,7 @@ from pydantic import Field, field_validator from controllers.common.schema import register_schema_models from fields.base import ResponseModel +from libs.helper import to_timestamp from libs.login import login_required from .. import console_ns @@ -19,12 +20,6 @@ from ..wraps import ( ) -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class HitTestingDocument(ResponseModel): id: str | None = None data_source_type: str | None = None @@ -61,7 +56,7 @@ class HitTestingSegment(ResponseModel): @field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class HitTestingChildChunk(ResponseModel): diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 2d9a997fbf..08c72e45d5 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -16,6 +16,7 @@ from extensions.ext_database import db from fields.base import ResponseModel from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now +from libs.helper import to_timestamp from libs.login import current_account_with_tenant, login_required from models import App, InstalledApp, RecommendedApp from models.model import IconType @@ -105,9 +106,7 @@ class InstalledAppResponse(ResponseModel): @field_validator("last_used_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value + return to_timestamp(value) class InstalledAppListResponse(ResponseModel): diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 9ffc18e4c2..0c9a93c1cd 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator from constants import HIDDEN_VALUE from fields.base import ResponseModel +from libs.helper import to_timestamp from libs.login import current_account_with_tenant, login_required from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService @@ -40,12 +41,6 @@ def _mask_api_key(api_key: str) -> str: return api_key[:3] + "******" + api_key[-3:] -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class APIBasedExtensionResponse(ResponseModel): id: str name: str @@ -61,7 +56,7 @@ class APIBasedExtensionResponse(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse) diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 9fa5b0f5c1..5751026040 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -105,7 +105,8 @@ class FilePreviewApi(Resource): @account_initialization_required def get(self, file_id): file_id = str(file_id) - text = FileService(db.engine).get_file_preview(file_id) + _, tenant_id = current_account_with_tenant() + text = FileService(db.engine).get_file_preview(file_id, tenant_id) return {"content": text} diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index b9e876c906..346f572ccc 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -25,6 +25,10 @@ class TagBasePayload(BaseModel): type: TagType = Field(description="Tag type") +class TagUpdateRequestPayload(BaseModel): + name: str = Field(description="Tag name", min_length=1, max_length=50) + + class TagBindingPayload(BaseModel): tag_ids: list[str] = Field(description="Tag IDs to bind") target_id: str = Field(description="Target ID to bind tags to") @@ -68,6 +72,7 @@ class TagResponse(ResponseModel): register_schema_models( console_ns, TagBasePayload, + TagUpdateRequestPayload, TagBindingPayload, TagBindingRemovePayload, TagListQueryParam, @@ -118,7 +123,7 @@ class TagListApi(Resource): @console_ns.route("/tags/") class TagUpdateDeleteApi(Resource): - @console_ns.expect(console_ns.models[TagBasePayload.__name__]) + @console_ns.expect(console_ns.models[TagUpdateRequestPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -129,8 +134,8 @@ class TagUpdateDeleteApi(Resource): if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() - payload = TagBasePayload.model_validate(console_ns.payload or {}) - tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id) + payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {}) + tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id) binding_count = TagService.get_tag_binding_count(tag_id) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 68520e540b..b1c363433a 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -42,7 +42,7 @@ from fields.base import ResponseModel from fields.member_fields import Account as AccountResponse from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now -from libs.helper import EmailStr, extract_remote_ip, timezone +from libs.helper import EmailStr, extract_remote_ip, timezone, to_timestamp from libs.login import current_account_with_tenant, login_required from models import AccountIntegrate, InvitationCode from models.account import AccountStatus, InvitationCodeStatus @@ -185,12 +185,6 @@ def _serialize_account(account) -> dict[str, Any]: return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json") -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class AccountIntegrateResponse(ResponseModel): provider: str created_at: int | None = None @@ -200,7 +194,7 @@ class AccountIntegrateResponse(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class AccountIntegrateListResponse(ResponseModel): @@ -220,7 +214,7 @@ class EducationStatusResponse(ResponseModel): @field_validator("expire_at", mode="before") @classmethod def _normalize_expire_at(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class EducationAutocompleteResponse(ResponseModel): diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 84890f0443..1eb91c472e 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -29,7 +29,7 @@ from controllers.console.wraps import ( from enums.cloud_plan import CloudPlan from extensions.ext_database import db from fields.base import ResponseModel -from libs.helper import TimestampField +from libs.helper import TimestampField, to_timestamp from libs.login import current_account_with_tenant, login_required from models.account import Tenant, TenantCustomConfigDict, TenantStatus from services.account_service import TenantService @@ -86,9 +86,7 @@ class TenantInfoResponse(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None): - if isinstance(value, datetime): - return int(value.timestamp()) - return value + return to_timestamp(value) register_schema_models( diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index ca4b18cb5e..64b2038f9c 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -22,7 +22,7 @@ from fields.conversation_fields import ( SimpleConversation, ) from graphon.variables.types import SegmentType -from libs.helper import UUIDStrOrEmpty +from libs.helper import UUIDStrOrEmpty, to_timestamp from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService @@ -115,9 +115,7 @@ class ConversationVariableResponse(ResponseModel): @field_validator("created_at", "updated_at", mode="before") @classmethod def normalize_timestamp(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value + return to_timestamp(value) class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel): diff --git a/api/controllers/service_api/app/human_input_form.py b/api/controllers/service_api/app/human_input_form.py index 8e5003dbbf..2b38a84b0e 100644 --- a/api/controllers/service_api/app/human_input_form.py +++ b/api/controllers/service_api/app/human_input_form.py @@ -7,18 +7,18 @@ paused human input forms in workflow/chatflow runs. import json import logging -from datetime import datetime from flask import Response from flask_restx import Resource from werkzeug.exceptions import BadRequest, NotFound -from controllers.common.human_input import HumanInputFormSubmitPayload +from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface from extensions.ext_database import db +from libs.helper import to_timestamp from models.model import App, EndUser from services.human_input_service import Form, FormNotFoundError, HumanInputService @@ -28,30 +28,14 @@ logger = logging.getLogger(__name__) register_schema_models(service_api_ns, HumanInputFormSubmitPayload) -def _stringify_default_values(values: dict[str, object]) -> dict[str, str]: - result: dict[str, str] = {} - for key, value in values.items(): - if value is None: - result[key] = "" - elif isinstance(value, (dict, list)): - result[key] = json.dumps(value, ensure_ascii=False) - else: - result[key] = str(value) - return result - - -def _to_timestamp(value: datetime) -> int: - return int(value.timestamp()) - - def _jsonify_form_definition(form: Form) -> Response: definition_payload = form.get_definition().model_dump() payload = { "form_content": definition_payload["rendered_content"], "inputs": definition_payload["inputs"], - "resolved_default_values": _stringify_default_values(definition_payload["default_values"]), + "resolved_default_values": stringify_form_default_values(definition_payload["default_values"]), "user_actions": definition_payload["user_actions"], - "expiration_time": _to_timestamp(form.expiration_time), + "expiration_time": to_timestamp(form.expiration_time), } return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index cc763fa89c..45d2dda858 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -39,6 +39,7 @@ from graphon.enums import WorkflowExecutionStatus from graphon.graph_engine.manager import GraphEngineManager from graphon.model_runtime.errors.invoke import InvokeError from libs import helper +from libs.helper import to_timestamp from models.model import App, AppMode, EndUser from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory @@ -68,12 +69,6 @@ class WorkflowLogQuery(BaseModel): register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery) -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - def _enum_value(value): return getattr(value, "value", value) @@ -109,7 +104,7 @@ class WorkflowRunResponse(ResponseModel): @field_validator("created_at", "finished_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class WorkflowRunForLogResponse(ResponseModel): @@ -133,7 +128,7 @@ class WorkflowRunForLogResponse(ResponseModel): @field_validator("created_at", "finished_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class WorkflowAppLogPartialResponse(ResponseModel): @@ -154,7 +149,7 @@ class WorkflowAppLogPartialResponse(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class WorkflowAppLogPaginationResponse(ResponseModel): diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 9af66f1960..d85e46498d 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -31,7 +31,9 @@ from services.tag_service import ( TagBindingCreatePayload, TagBindingDeletePayload, TagService, - UpdateTagPayload, +) +from services.tag_service import ( + UpdateTagPayload as UpdateTagServicePayload, ) register_enum_models(service_api_ns, DatasetPermissionEnum) @@ -556,7 +558,7 @@ class DatasetTagsApi(DatasetApiResource): payload = TagUpdatePayload.model_validate(service_api_ns.payload or {}) tag_id = payload.tag_id - tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id) + tag = TagService.update_tags(UpdateTagServicePayload(name=payload.name), tag_id) binding_count = TagService.get_tag_binding_count(tag_id) diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py index 1ddf2e0717..69297450c9 100644 --- a/api/controllers/web/human_input_form.py +++ b/api/controllers/web/human_input_form.py @@ -4,7 +4,6 @@ Web App Human Input Form APIs. import json import logging -from datetime import datetime from typing import Any, NotRequired, TypedDict from flask import Response, request @@ -13,12 +12,12 @@ from sqlalchemy import select from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.common.human_input import HumanInputFormSubmitPayload +from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values from controllers.web import web_ns from controllers.web.error import NotFoundError, WebFormRateLimitExceededError from controllers.web.site import serialize_app_site_payload from extensions.ext_database import db -from libs.helper import RateLimiter, extract_remote_ip +from libs.helper import RateLimiter, extract_remote_ip, to_timestamp from models.account import TenantStatus from models.model import App, Site from services.human_input_service import Form, FormNotFoundError, HumanInputService @@ -38,22 +37,6 @@ _FORM_ACCESS_RATE_LIMITER = RateLimiter( ) -def _stringify_default_values(values: dict[str, object]) -> dict[str, str]: - result: dict[str, str] = {} - for key, value in values.items(): - if value is None: - result[key] = "" - elif isinstance(value, (dict, list)): - result[key] = json.dumps(value, ensure_ascii=False) - else: - result[key] = str(value) - return result - - -def _to_timestamp(value: datetime) -> int: - return int(value.timestamp()) - - class FormDefinitionPayload(TypedDict): form_content: Any inputs: Any @@ -69,9 +52,9 @@ def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Re payload: FormDefinitionPayload = { "form_content": definition_payload["rendered_content"], "inputs": definition_payload["inputs"], - "resolved_default_values": _stringify_default_values(definition_payload["default_values"]), + "resolved_default_values": stringify_form_default_values(definition_payload["default_values"]), "user_actions": definition_payload["user_actions"], - "expiration_time": _to_timestamp(form.expiration_time), + "expiration_time": to_timestamp(form.expiration_time), } if site_payload is not None: payload["site"] = site_payload diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 82dbf5381d..3c46f91e51 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -9,7 +9,7 @@ from datetime import datetime from threading import Thread from typing import Any, Union -from sqlalchemy import select +from sqlalchemy import select, update from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -425,11 +425,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._workflow_run_id = run_id with self._database_session() as session: - message = self._get_message(session=session) - if not message: - raise ValueError(f"Message not found: {self._message_id}") - - message.workflow_run_id = run_id + session.execute(update(Message).where(Message.id == self._message_id).values(workflow_run_id=run_id)) workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response( task_id=self._application_generate_entity.task_id, diff --git a/api/core/app/file_access/__init__.py b/api/core/app/file_access/__init__.py index a75ab9781b..e02aba102b 100644 --- a/api/core/app/file_access/__init__.py +++ b/api/core/app/file_access/__init__.py @@ -1,6 +1,13 @@ from .controller import DatabaseFileAccessController from .protocols import FileAccessControllerProtocol -from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope +from .scope import ( + FileAccessScope, + bind_file_access_scope, + get_current_file_access_scope, + grant_retriever_segment_access, + grant_upload_file_access, + is_retriever_segment_access_granted, +) __all__ = [ "DatabaseFileAccessController", @@ -8,4 +15,7 @@ __all__ = [ "FileAccessScope", "bind_file_access_scope", "get_current_file_access_scope", + "grant_retriever_segment_access", + "grant_upload_file_access", + "is_retriever_segment_access_granted", ] diff --git a/api/core/app/file_access/controller.py b/api/core/app/file_access/controller.py index 300c187083..a6c6e74f06 100644 --- a/api/core/app/file_access/controller.py +++ b/api/core/app/file_access/controller.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable -from sqlalchemy import select +from sqlalchemy import and_, or_, select from sqlalchemy.orm import Session from sqlalchemy.sql import Select @@ -18,7 +18,8 @@ class DatabaseFileAccessController(FileAccessControllerProtocol): Tenant scoping remains mandatory. When the current execution belongs to an end user, the lookup is additionally constrained to that end user's file - ownership markers. + ownership markers, plus upload files explicitly granted by the current + execution context. """ _scope_getter: Callable[[], FileAccessScope | None] @@ -47,10 +48,19 @@ class DatabaseFileAccessController(FileAccessControllerProtocol): if not resolved_scope.requires_user_ownership: return scoped_stmt - return scoped_stmt.where( + user_owned_filter = and_( UploadFile.created_by_role == CreatorUserRole.END_USER, UploadFile.created_by == resolved_scope.user_id, ) + if not resolved_scope.granted_upload_file_ids: + return scoped_stmt.where(user_owned_filter) + + return scoped_stmt.where( + or_( + user_owned_filter, + UploadFile.id.in_(resolved_scope.granted_upload_file_ids), + ) + ) def apply_tool_file_filters( self, diff --git a/api/core/app/file_access/scope.py b/api/core/app/file_access/scope.py index a583301f9b..12fe7b3840 100644 --- a/api/core/app/file_access/scope.py +++ b/api/core/app/file_access/scope.py @@ -1,9 +1,9 @@ from __future__ import annotations -from collections.abc import Generator # Changed from Iterator +from collections.abc import Generator, Iterable from contextlib import contextmanager from contextvars import ContextVar -from dataclasses import dataclass +from dataclasses import dataclass, field, replace from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -15,12 +15,23 @@ _current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar( @dataclass(frozen=True, slots=True) class FileAccessScope: - """Request-scoped ownership context used by workflow-layer file lookups.""" + """Request-scoped ownership context used by workflow-layer file lookups. + + ``granted_upload_file_ids`` is execution-local: callers may add upload files + that were returned by trusted retrieval paths without changing persistent + ownership markers. + + ``granted_retriever_segment_ids`` gates lazy attachment loading by segment + ID, so user-provided context cannot make a later LLM node load arbitrary + same-tenant knowledge attachments. + """ tenant_id: str user_id: str user_from: UserFrom invoke_from: InvokeFrom + granted_upload_file_ids: frozenset[str] = field(default_factory=frozenset) + granted_retriever_segment_ids: frozenset[str] = field(default_factory=frozenset) @property def requires_user_ownership(self) -> bool: @@ -31,8 +42,49 @@ def get_current_file_access_scope() -> FileAccessScope | None: return _current_file_access_scope.get() +def grant_upload_file_access(upload_file_ids: Iterable[str]) -> None: + scope = _current_file_access_scope.get() + if scope is None: + return + + granted_upload_file_ids = frozenset(str(file_id) for file_id in upload_file_ids if file_id) + if not granted_upload_file_ids: + return + + _current_file_access_scope.set( + replace( + scope, + granted_upload_file_ids=scope.granted_upload_file_ids | granted_upload_file_ids, + ) + ) + + +def grant_retriever_segment_access(segment_ids: Iterable[str]) -> None: + scope = _current_file_access_scope.get() + if scope is None: + return + + granted_segment_ids = frozenset(str(segment_id) for segment_id in segment_ids if segment_id) + if not granted_segment_ids: + return + + _current_file_access_scope.set( + replace( + scope, + granted_retriever_segment_ids=scope.granted_retriever_segment_ids | granted_segment_ids, + ) + ) + + +def is_retriever_segment_access_granted(segment_id: str) -> bool: + scope = _current_file_access_scope.get() + if scope is None or not scope.requires_user_ownership: + return True + return str(segment_id) in scope.granted_retriever_segment_ids + + @contextmanager -def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None] +def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: token = _current_file_access_scope.set(scope) try: yield diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index e2e07ebaff..171d5ab342 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -140,42 +140,43 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): :return: """ for stream_response in generator: - if isinstance(stream_response, ErrorStreamResponse): - raise stream_response.err - elif isinstance(stream_response, MessageEndStreamResponse): - extras = {"usage": self._task_state.llm_result.usage.model_dump()} - if self._task_state.metadata: - extras["metadata"] = self._task_state.metadata.model_dump() - response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse - if self._conversation_mode == AppMode.COMPLETION: - response = CompletionAppBlockingResponse( - task_id=self._application_generate_entity.task_id, - data=CompletionAppBlockingResponse.Data( - id=self._message_id, - mode=self._conversation_mode, - message_id=self._message_id, - answer=self._task_state.llm_result.message.get_text_content(), - created_at=self._message_created_at, - **extras, - ), - ) - else: - response = ChatbotAppBlockingResponse( - task_id=self._application_generate_entity.task_id, - data=ChatbotAppBlockingResponse.Data( - id=self._message_id, - mode=self._conversation_mode, - conversation_id=self._conversation_id, - message_id=self._message_id, - answer=self._task_state.llm_result.message.get_text_content(), - created_at=self._message_created_at, - **extras, - ), - ) + match stream_response: + case ErrorStreamResponse(): + raise stream_response.err + case MessageEndStreamResponse(): + extras = {"usage": self._task_state.llm_result.usage.model_dump()} + if self._task_state.metadata: + extras["metadata"] = self._task_state.metadata.model_dump() + response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse + if self._conversation_mode == AppMode.COMPLETION: + response = CompletionAppBlockingResponse( + task_id=self._application_generate_entity.task_id, + data=CompletionAppBlockingResponse.Data( + id=self._message_id, + mode=self._conversation_mode, + message_id=self._message_id, + answer=self._task_state.llm_result.message.get_text_content(), + created_at=self._message_created_at, + **extras, + ), + ) + else: + response = ChatbotAppBlockingResponse( + task_id=self._application_generate_entity.task_id, + data=ChatbotAppBlockingResponse.Data( + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, + answer=self._task_state.llm_result.message.get_text_content(), + created_at=self._message_created_at, + **extras, + ), + ) - return response - else: - continue + return response + case _: + continue raise RuntimeError("queue listening stopped unexpectedly.") @@ -265,104 +266,107 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): publisher.publish(message) event = message.event - if isinstance(event, QueueErrorEvent): - with sessionmaker(bind=db.engine).begin() as session: - err = self.handle_error(event=event, session=session, message_id=self._message_id) - yield self.error_to_stream_response(err) - break - elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): - if isinstance(event, QueueMessageEndEvent): - if event.llm_result: - self._task_state.llm_result = event.llm_result - else: - self._handle_stop(event) + match event: + case QueueErrorEvent(): + with sessionmaker(bind=db.engine).begin() as session: + err = self.handle_error(event=event, session=session, message_id=self._message_id) + yield self.error_to_stream_response(err) + break + case QueueStopEvent() | QueueMessageEndEvent(): + if isinstance(event, QueueMessageEndEvent): + if event.llm_result: + self._task_state.llm_result = event.llm_result + else: + self._handle_stop(event) - # handle output moderation - output_moderation_answer = self.handle_output_moderation_when_task_finished( - self._task_state.llm_result.message.get_text_content() - ) - if output_moderation_answer: - self._task_state.llm_result.message.content = output_moderation_answer - yield self._message_cycle_manager.message_replace_to_stream_response( - answer=output_moderation_answer + # handle output moderation + output_moderation_answer = self.handle_output_moderation_when_task_finished( + self._task_state.llm_result.message.get_text_content() ) - - with sessionmaker(bind=db.engine).begin() as session: - # Save message - self._save_message(session=session, trace_manager=trace_manager) - message_end_resp = self._message_end_to_stream_response() - yield message_end_resp - elif isinstance(event, QueueRetrieverResourcesEvent): - self._message_cycle_manager.handle_retriever_resources(event) - elif isinstance(event, QueueAnnotationReplyEvent): - annotation = self._message_cycle_manager.handle_annotation_reply(event) - if annotation: - self._task_state.llm_result.message.content = annotation.content - elif isinstance(event, QueueAgentThoughtEvent): - agent_thought_response = self._agent_thought_to_stream_response(event) - if agent_thought_response is not None: - yield agent_thought_response - elif isinstance(event, QueueMessageFileEvent): - response = self._message_cycle_manager.message_file_to_stream_response(event) - if response: - yield response - elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): - chunk = event.chunk - delta_text = chunk.delta.message.content - if delta_text is None: - continue - if isinstance(chunk.delta.message.content, list): - delta_text = "" - for content in chunk.delta.message.content: - logger.debug( - "The content type %s in LLM chunk delta message content.: %r", type(content), content + if output_moderation_answer: + self._task_state.llm_result.message.content = output_moderation_answer + yield self._message_cycle_manager.message_replace_to_stream_response( + answer=output_moderation_answer ) - if isinstance(content, TextPromptMessageContent): - delta_text += content.data - elif isinstance(content, str): - delta_text += content # failback to str - else: - logger.warning( - "Unsupported content type %s in LLM chunk delta message content.: %r", - type(content), - content, + + with sessionmaker(bind=db.engine).begin() as session: + # Save message + self._save_message(session=session, trace_manager=trace_manager) + message_end_resp = self._message_end_to_stream_response() + yield message_end_resp + case QueueRetrieverResourcesEvent(): + self._message_cycle_manager.handle_retriever_resources(event) + case QueueAnnotationReplyEvent(): + annotation = self._message_cycle_manager.handle_annotation_reply(event) + if annotation: + self._task_state.llm_result.message.content = annotation.content + case QueueAgentThoughtEvent(): + agent_thought_response = self._agent_thought_to_stream_response(event) + if agent_thought_response is not None: + yield agent_thought_response + case QueueMessageFileEvent(): + response = self._message_cycle_manager.message_file_to_stream_response(event) + if response: + yield response + case QueueLLMChunkEvent() | QueueAgentMessageEvent(): + chunk = event.chunk + delta_text = chunk.delta.message.content + if delta_text is None: + continue + if isinstance(chunk.delta.message.content, list): + delta_text = "" + for content in chunk.delta.message.content: + logger.debug( + "The content type %s in LLM chunk delta message content.: %r", type(content), content ) - continue + match content: + case TextPromptMessageContent(): + delta_text += content.data + case str(): + delta_text += content # failback to str + case _: + logger.warning( + "Unsupported content type %s in LLM chunk delta message content.: %r", + type(content), + content, + ) + continue - if not self._task_state.llm_result.prompt_messages: - self._task_state.llm_result.prompt_messages = chunk.prompt_messages + if not self._task_state.llm_result.prompt_messages: + self._task_state.llm_result.prompt_messages = chunk.prompt_messages - # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text)) - if should_direct_answer: + # handle output moderation chunk + should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text)) + if should_direct_answer: + continue + + current_content = cast(str, self._task_state.llm_result.message.content) + current_content += cast(str, delta_text) + self._task_state.llm_result.message.content = current_content + + match event: + case QueueLLMChunkEvent(): + # Determine the event type once, on first LLM chunk, and reuse for subsequent chunks + if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None: + self._precomputed_event_type = self._message_cycle_manager.get_message_event_type( + message_id=self._message_id + ) + yield self._message_cycle_manager.message_to_stream_response( + answer=cast(str, delta_text), + message_id=self._message_id, + event_type=self._precomputed_event_type, + ) + case _: + yield self._agent_message_to_stream_response( + answer=cast(str, delta_text), + message_id=self._message_id, + ) + case QueueMessageReplaceEvent(): + yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) + case QueuePingEvent(): + yield self.ping_stream_response() + case _: continue - - current_content = cast(str, self._task_state.llm_result.message.content) - current_content += cast(str, delta_text) - self._task_state.llm_result.message.content = current_content - - if isinstance(event, QueueLLMChunkEvent): - # Determine the event type once, on first LLM chunk, and reuse for subsequent chunks - if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None: - self._precomputed_event_type = self._message_cycle_manager.get_message_event_type( - message_id=self._message_id - ) - yield self._message_cycle_manager.message_to_stream_response( - answer=cast(str, delta_text), - message_id=self._message_id, - event_type=self._precomputed_event_type, - ) - else: - yield self._agent_message_to_stream_response( - answer=cast(str, delta_text), - message_id=self._message_id, - ) - elif isinstance(event, QueueMessageReplaceEvent): - yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) - elif isinstance(event, QueuePingEvent): - yield self.ping_stream_response() - else: - continue if publisher: publisher.publish(None) if self._conversation_name_generate_thread: diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 8cc2be8feb..904d5c843f 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -9,6 +9,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, load_only from configs import dify_config +from core.app.file_access import grant_upload_file_access from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict @@ -890,6 +891,7 @@ class RetrievalService: .limit(1) ) if attachment_binding: + grant_upload_file_access([str(upload_file.id)]) attachment_info: AttachmentInfoDict = { "id": upload_file.id, "name": upload_file.name, @@ -906,6 +908,7 @@ class RetrievalService: cls, attachment_ids: list[str], session: Session ) -> list[SegmentAttachmentInfoResult]: attachment_infos: list[SegmentAttachmentInfoResult] = [] + granted_upload_file_ids: list[str] = [] upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all() if upload_files: upload_file_ids = [upload_file.id for upload_file in upload_files] @@ -926,6 +929,7 @@ class RetrievalService: "size": upload_file.size, } if attachment_binding: + granted_upload_file_ids.append(str(upload_file.id)) attachment_infos.append( { "attachment_id": attachment_binding.attachment_id, @@ -933,4 +937,5 @@ class RetrievalService: "segment_id": attachment_binding.segment_id, } ) + grant_upload_file_access(granted_upload_file_ids) return attachment_infos diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 010566d203..039a266f44 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -19,6 +19,7 @@ from core.app.app_config.entities import ( ModelConfig, ) from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.app.file_access import grant_retriever_segment_access, grant_upload_file_access from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.db.session_factory import session_factory from core.entities.agent_entities import PlanningStrategy @@ -326,6 +327,7 @@ class DatasetRetrieval: if record.summary: source.summary = record.summary + grant_retriever_segment_access([str(segment.id)]) retrieval_resource_list.append(source) if retrieval_resource_list: @@ -515,6 +517,9 @@ class DatasetRetrieval: ) ).all() if attachments_with_bindings: + grant_upload_file_access( + str(upload_file.id) for _, upload_file in attachments_with_bindings + ) for _, upload_file in attachments_with_bindings: attachment_info = File( file_id=upload_file.id, diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index c3fbc836d6..d73ff43386 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -1,6 +1,6 @@ import importlib import pkgutil -from collections.abc import Callable, Iterator, Mapping, MutableMapping +from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence from dataclasses import dataclass from functools import lru_cache from typing import TYPE_CHECKING, Any, cast, final, override @@ -56,6 +56,7 @@ from graphon.nodes.http_request import build_http_request_config from graphon.nodes.llm.entities import LLMNodeData from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.variables.segments import ArrayObjectSegment from models.model import Conversation if TYPE_CHECKING: @@ -497,13 +498,47 @@ class DifyNodeFactory(NodeFactory): if include_prompt_message_serializer: node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer if include_retriever_attachment_loader: - node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader + node_init_kwargs["retriever_attachment_loader"] = self._build_retriever_attachment_loader( + cast(LLMNodeData, validated_node_data) + ) if include_jinja2_template_renderer: node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer if validated_node_data.type == BuiltinNodeTypes.LLM: node_init_kwargs["default_query_selector"] = system_variable_selector(SystemVariableKey.QUERY) return node_init_kwargs + def _build_retriever_attachment_loader(self, node_data: LLMNodeData) -> DifyRetrieverAttachmentLoader: + return DifyRetrieverAttachmentLoader( + file_reference_factory=self._file_reference_factory, + segment_access_checker=self._build_retriever_segment_access_checker( + node_data.context.variable_selector if node_data.context.enabled else None + ), + ) + + def _build_retriever_segment_access_checker( + self, + context_variable_selector: Sequence[str] | None, + ) -> Callable[[str], bool]: + def checker(segment_id: str) -> bool: + if not context_variable_selector: + return False + + context_value = self.graph_runtime_state.variable_pool.get(context_variable_selector) + if not isinstance(context_value, ArrayObjectSegment): + return False + + for item in context_value.value: + if not isinstance(item, Mapping): + continue + metadata = item.get("metadata") + if not isinstance(metadata, Mapping): + continue + if metadata.get("_source") == "knowledge" and str(metadata.get("segment_id")) == str(segment_id): + return True + return False + + return checker + def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance: node_data_model = node_data.model model_instance, _ = fetch_model_config( diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 85927199b0..7d6ac74791 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -8,7 +8,11 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.app.file_access import DatabaseFileAccessController +from core.app.file_access import ( + DatabaseFileAccessController, + grant_upload_file_access, + is_retriever_segment_access_granted, +) from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.helper.trace_id_helper import ParentTraceContext from core.llm_generator.output_parser.errors import OutputParserError @@ -275,10 +279,23 @@ class DifyPromptMessageSerializer(PromptMessageSerializerProtocol): class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol): """Resolve retriever attachments through Dify persistence and return graph file references.""" - def __init__(self, *, file_reference_factory: FileReferenceFactoryProtocol) -> None: + _segment_access_checker: Callable[[str], bool] | None + + def __init__( + self, + *, + file_reference_factory: FileReferenceFactoryProtocol, + segment_access_checker: Callable[[str], bool] | None = None, + ) -> None: self._file_reference_factory = file_reference_factory + self._segment_access_checker = segment_access_checker def load(self, *, segment_id: str) -> Sequence[File]: + if not is_retriever_segment_access_granted(segment_id): + return [] + if self._segment_access_checker is not None and not self._segment_access_checker(segment_id): + return [] + with Session(db.engine, expire_on_commit=False) as session: attachments_with_bindings = session.execute( select(SegmentAttachmentBinding, UploadFile) @@ -286,6 +303,7 @@ class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol): .where(SegmentAttachmentBinding.segment_id == segment_id) ).all() + grant_upload_file_access(str(upload_file.id) for _, upload_file in attachments_with_bindings) return [ self._file_reference_factory.build_from_mapping( mapping={ diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index b2a0e92c47..4546a051cc 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -5,12 +5,7 @@ from datetime import datetime from pydantic import Field, field_validator from fields.base import ResponseModel - - -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value +from libs.helper import to_timestamp class Annotation(ResponseModel): @@ -23,7 +18,7 @@ class Annotation(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class AnnotationList(ResponseModel): @@ -50,7 +45,7 @@ class AnnotationHitHistory(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class AnnotationHitHistoryList(ResponseModel): diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index bf5c9ffcb1..eb49577d59 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -7,6 +7,7 @@ from pydantic import Field, field_validator, model_validator from fields.base import ResponseModel from graphon.file import File +from libs.helper import to_timestamp type JSONValue = Any @@ -47,9 +48,7 @@ class SimpleConversation(ResponseModel): @field_validator("created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value + return to_timestamp(value) class ConversationInfiniteScrollPagination(ResponseModel): @@ -90,9 +89,7 @@ class ConversationAnnotation(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value + return to_timestamp(value) class ConversationAnnotationHitHistory(ResponseModel): @@ -103,9 +100,7 @@ class ConversationAnnotationHitHistory(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value + return to_timestamp(value) class AgentThought(ResponseModel): @@ -125,9 +120,7 @@ class AgentThought(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value + return to_timestamp(value) @model_validator(mode="after") def _fallback_chain_id(self): @@ -169,9 +162,7 @@ class MessageDetail(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value + return to_timestamp(value) class FeedbackStat(ResponseModel): @@ -237,9 +228,7 @@ class Conversation(ResponseModel): @field_validator("read_at", "created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value + return to_timestamp(value) class ConversationPagination(ResponseModel): @@ -263,9 +252,7 @@ class ConversationMessageDetail(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value + return to_timestamp(value) class ConversationWithSummary(ResponseModel): @@ -291,9 +278,7 @@ class ConversationWithSummary(ResponseModel): @field_validator("read_at", "created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value + return to_timestamp(value) class ConversationWithSummaryPagination(ResponseModel): @@ -322,15 +307,7 @@ class ConversationDetail(ResponseModel): @field_validator("created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value - - -def to_timestamp(value: datetime | None) -> int | None: - if value is None: - return None - return int(value.timestamp()) + return to_timestamp(value) def format_files_contained(value: JSONValue) -> JSONValue: diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index e4219ba1ee..05a519f3b1 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -8,7 +8,7 @@ from pydantic import field_validator from fields.base import ResponseModel from graphon.variables.types import SegmentType -from libs.helper import TimestampField +from libs.helper import TimestampField, to_timestamp from ._value_type_serializer import serialize_value_type @@ -37,12 +37,6 @@ conversation_variable_infinite_scroll_pagination_fields = { } -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class ConversationVariableResponse(ResponseModel): id: str name: str @@ -88,7 +82,7 @@ class ConversationVariableResponse(ResponseModel): @field_validator("created_at", "updated_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class PaginatedConversationVariableResponse(ResponseModel): diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index ad8b95e4dc..a3987a7e40 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -5,12 +5,7 @@ from datetime import datetime from pydantic import field_validator from fields.base import ResponseModel - - -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value +from libs.helper import to_timestamp class UploadConfig(ResponseModel): @@ -45,7 +40,7 @@ class FileResponse(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class RemoteFileInfo(ResponseModel): @@ -66,7 +61,7 @@ class FileWithSignedUrl(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) __all__ = [ diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 67b320beaa..7ae5e3b652 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -6,7 +6,7 @@ from flask_restx import fields from pydantic import computed_field, field_validator from fields.base import ResponseModel -from graphon.file import helpers as file_helpers +from libs.helper import build_avatar_url, to_timestamp simple_account_fields = { "id": fields.String, @@ -15,20 +15,6 @@ simple_account_fields = { } -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - -def _build_avatar_url(avatar: str | None) -> str | None: - if avatar is None: - return None - if avatar.startswith(("http://", "https://")): - return avatar - return file_helpers.get_signed_file_url(avatar) - - class SimpleAccount(ResponseModel): id: str name: str @@ -41,7 +27,7 @@ class _AccountAvatar(ResponseModel): @computed_field(return_type=str | None) # type: ignore[prop-decorator] @property def avatar_url(self) -> str | None: - return _build_avatar_url(self.avatar) + return build_avatar_url(self.avatar) class Account(_AccountAvatar): @@ -59,7 +45,7 @@ class Account(_AccountAvatar): @field_validator("last_login_at", "created_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class AccountWithRole(_AccountAvatar): @@ -75,7 +61,7 @@ class AccountWithRole(_AccountAvatar): @field_validator("last_login_at", "last_active_at", "created_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class AccountWithRoleList(ResponseModel): diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index ca18f1c203..e0d37dd701 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -9,6 +9,7 @@ from core.entities.execution_extra_content import ExecutionExtraContentDomainMod from fields.base import ResponseModel from fields.conversation_fields import AgentThought, JSONValue, MessageFile from graphon.file import File +from libs.helper import to_timestamp type JSONValueType = JSONValue @@ -39,9 +40,7 @@ class RetrieverResource(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value + return to_timestamp(value) class MessageListItem(ResponseModel): @@ -68,9 +67,7 @@ class MessageListItem(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value + return to_timestamp(value) class WebMessageListItem(MessageListItem): @@ -106,9 +103,7 @@ class SavedMessageItem(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_created_at(cls, value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return to_timestamp(value) - return value + return to_timestamp(value) class SavedMessageInfiniteScrollPagination(ResponseModel): @@ -121,12 +116,6 @@ class SuggestedQuestionsResponse(ResponseModel): data: list[str] -def to_timestamp(value: datetime | None) -> int | None: - if value is None: - return None - return int(value.timestamp()) - - def format_files_contained(value: JSONValueType) -> JSONValueType: if isinstance(value, File): # Response payloads must preserve legacy file keys like `related_id`/`url` diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index 1b2c71255d..a70f051807 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -17,7 +17,7 @@ from fields.workflow_run_fields import ( workflow_run_for_archived_log_fields, workflow_run_for_log_fields, ) -from libs.helper import TimestampField +from libs.helper import TimestampField, to_timestamp workflow_app_log_partial_fields = { "id": fields.String, @@ -96,12 +96,6 @@ def build_workflow_archived_log_pagination_model(api_or_ns: Namespace): return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields) -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class WorkflowAppLogPartialResponse(ResponseModel): id: str workflow_run: WorkflowRunForLogResponse | None = None @@ -115,7 +109,7 @@ class WorkflowAppLogPartialResponse(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class WorkflowArchivedLogPartialResponse(ResponseModel): @@ -129,7 +123,7 @@ class WorkflowArchivedLogPartialResponse(ResponseModel): @field_validator("created_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class WorkflowAppLogPaginationResponse(ResponseModel): diff --git a/api/fields/workflow_comment_fields.py b/api/fields/workflow_comment_fields.py deleted file mode 100644 index c708dd3460..0000000000 --- a/api/fields/workflow_comment_fields.py +++ /dev/null @@ -1,96 +0,0 @@ -from flask_restx import fields - -from libs.helper import AvatarUrlField, TimestampField - -# basic account fields for comments -account_fields = { - "id": fields.String, - "name": fields.String, - "email": fields.String, - "avatar_url": AvatarUrlField, -} - -# Comment mention fields -workflow_comment_mention_fields = { - "mentioned_user_id": fields.String, - "mentioned_user_account": fields.Nested(account_fields, allow_null=True), - "reply_id": fields.String, -} - -# Comment reply fields -workflow_comment_reply_fields = { - "id": fields.String, - "content": fields.String, - "created_by": fields.String, - "created_by_account": fields.Nested(account_fields, allow_null=True), - "created_at": TimestampField, -} - -# Basic comment fields (for list views) -workflow_comment_basic_fields = { - "id": fields.String, - "position_x": fields.Float, - "position_y": fields.Float, - "content": fields.String, - "created_by": fields.String, - "created_by_account": fields.Nested(account_fields, allow_null=True), - "created_at": TimestampField, - "updated_at": TimestampField, - "resolved": fields.Boolean, - "resolved_at": TimestampField, - "resolved_by": fields.String, - "resolved_by_account": fields.Nested(account_fields, allow_null=True), - "reply_count": fields.Integer, - "mention_count": fields.Integer, - "participants": fields.List(fields.Nested(account_fields)), -} - -# Detailed comment fields (for single comment view) -workflow_comment_detail_fields = { - "id": fields.String, - "position_x": fields.Float, - "position_y": fields.Float, - "content": fields.String, - "created_by": fields.String, - "created_by_account": fields.Nested(account_fields, allow_null=True), - "created_at": TimestampField, - "updated_at": TimestampField, - "resolved": fields.Boolean, - "resolved_at": TimestampField, - "resolved_by": fields.String, - "resolved_by_account": fields.Nested(account_fields, allow_null=True), - "replies": fields.List(fields.Nested(workflow_comment_reply_fields)), - "mentions": fields.List(fields.Nested(workflow_comment_mention_fields)), -} - -# Comment creation response fields (simplified) -workflow_comment_create_fields = { - "id": fields.String, - "created_at": TimestampField, -} - -# Comment update response fields (simplified) -workflow_comment_update_fields = { - "id": fields.String, - "updated_at": TimestampField, -} - -# Comment resolve response fields -workflow_comment_resolve_fields = { - "id": fields.String, - "resolved": fields.Boolean, - "resolved_at": TimestampField, - "resolved_by": fields.String, -} - -# Reply creation response fields (simplified) -workflow_comment_reply_create_fields = { - "id": fields.String, - "created_at": TimestampField, -} - -# Reply update response fields -workflow_comment_reply_update_fields = { - "id": fields.String, - "updated_at": TimestampField, -} diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index a852f21bb2..53cdfa234f 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -16,7 +16,7 @@ from pydantic import AliasChoices, Field, field_validator from fields.base import ResponseModel from fields.end_user_fields import SimpleEndUser from fields.member_fields import SimpleAccount -from libs.helper import TimestampField +from libs.helper import TimestampField, to_timestamp workflow_run_for_log_fields = { "id": fields.String, @@ -50,12 +50,6 @@ def build_workflow_run_for_archived_log_model(api_or_ns: Namespace): return api_or_ns.model("WorkflowRunForArchivedLog", workflow_run_for_archived_log_fields) -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class WorkflowRunForLogResponse(ResponseModel): id: str version: str | None = None @@ -79,7 +73,7 @@ class WorkflowRunForLogResponse(ResponseModel): @field_validator("created_at", "finished_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class WorkflowRunForArchivedLogResponse(ResponseModel): @@ -120,7 +114,7 @@ class WorkflowRunForListResponse(ResponseModel): @field_validator("created_at", "finished_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class AdvancedChatWorkflowRunForListResponse(WorkflowRunForListResponse): @@ -180,7 +174,7 @@ class WorkflowRunDetailResponse(ResponseModel): @field_validator("created_at", "finished_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class WorkflowRunNodeExecutionResponse(ResponseModel): @@ -217,7 +211,7 @@ class WorkflowRunNodeExecutionResponse(ResponseModel): @field_validator("created_at", "finished_at", mode="before") @classmethod def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return _to_timestamp(value) + return to_timestamp(value) class WorkflowRunNodeExecutionListResponse(ResponseModel): diff --git a/api/libs/helper.py b/api/libs/helper.py index ac69a11084..b66324a5d7 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -10,7 +10,7 @@ import uuid from collections.abc import Callable, Generator, Mapping from datetime import datetime from hashlib import sha256 -from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast +from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast, overload from uuid import UUID from zoneinfo import available_timezones @@ -136,6 +136,14 @@ def build_icon_url(icon_type: Any, icon: str | None) -> str | None: return file_helpers.get_signed_file_url(icon) +def build_avatar_url(avatar: str | None) -> str | None: + if avatar is None: + return None + if avatar.startswith(("http://", "https://")): + return avatar + return file_helpers.get_signed_file_url(avatar) + + class AvatarUrlField(fields.Raw): def output(self, key, obj, **kwargs): if obj is None: @@ -144,9 +152,7 @@ class AvatarUrlField(fields.Raw): from models import Account if isinstance(obj, Account) and obj.avatar is not None: - if obj.avatar.startswith(("http://", "https://")): - return obj.avatar - return file_helpers.get_signed_file_url(obj.avatar) + return build_avatar_url(obj.avatar) return None @@ -162,6 +168,35 @@ class OptionalTimestampField(fields.Raw): return int(value.timestamp()) +@overload +def to_timestamp(value: datetime) -> int: ... + + +@overload +def to_timestamp(value: int) -> int: ... + + +@overload +def to_timestamp(value: None) -> None: ... + + +def to_timestamp(value: datetime | int | None) -> int | None: + """Normalize API response timestamp values to epoch seconds.""" + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +def dump_response(model: type[BaseModel], data: Any) -> dict[str, Any]: + """Serialize a Pydantic response model to JSON-compatible dict output.""" + return model.model_validate(data, from_attributes=True).model_dump(mode="json") + + +def current_timestamp() -> int: + """Return the current Unix timestamp in seconds.""" + return int(time.time()) + + def email(email): # Define a regex pattern for email addresses pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$" diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 3daaa038e0..309f2aa812 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,3 +1,6 @@ +import base64 +import binascii +import json import logging import urllib.parse from dataclasses import dataclass @@ -27,6 +30,12 @@ class AccessTokenResponse(TypedDict, total=False): access_token: str +class OAuthState(TypedDict, total=False): + invite_token: str + timezone: str + language: str + + class GitHubEmailRecord(TypedDict, total=False): email: str primary: bool @@ -46,6 +55,7 @@ class GoogleRawUserInfo(TypedDict): ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse) +OAUTH_STATE_ADAPTER = TypeAdapter(OAuthState) GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo) GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord]) GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo) @@ -58,6 +68,37 @@ class OAuthUserInfo: email: str +def encode_oauth_state( + invite_token: str | None = None, + timezone: str | None = None, + language: str | None = None, +) -> str | None: + state: OAuthState = {} + if invite_token: + state["invite_token"] = invite_token + if timezone: + state["timezone"] = timezone + if language: + state["language"] = language + if not state: + return None + + raw_state = json.dumps(state, separators=(",", ":")).encode("utf-8") + return base64.urlsafe_b64encode(raw_state).decode("ascii").rstrip("=") + + +def decode_oauth_state(state: str | None) -> OAuthState: + if not state: + return {} + + try: + padded_state = state + "=" * (-len(state) % 4) + raw_state = base64.urlsafe_b64decode(padded_state.encode("ascii")).decode("utf-8") + return OAUTH_STATE_ADAPTER.validate_python(json.loads(raw_state)) + except (binascii.Error, ValueError, UnicodeDecodeError, json.JSONDecodeError, ValidationError): + return {} + + def _json_object(response: httpx.Response) -> JsonObject: return JSON_OBJECT_ADAPTER.validate_python(response.json()) @@ -76,7 +117,12 @@ class OAuth: self.client_secret = client_secret self.redirect_uri = redirect_uri - def get_authorization_url(self, invite_token: str | None = None) -> str: + def get_authorization_url( + self, + invite_token: str | None = None, + timezone: str | None = None, + language: str | None = None, + ) -> str: raise NotImplementedError() def get_access_token(self, code: str) -> str: @@ -99,14 +145,20 @@ class GitHubOAuth(OAuth): _USER_INFO_URL = "https://api.github.com/user" _EMAIL_INFO_URL = "https://api.github.com/user/emails" - def get_authorization_url(self, invite_token: str | None = None) -> str: + def get_authorization_url( + self, + invite_token: str | None = None, + timezone: str | None = None, + language: str | None = None, + ) -> str: params = { "client_id": self.client_id, "redirect_uri": self.redirect_uri, "scope": "user:email", # Request only basic user information } - if invite_token: - params["state"] = invite_token + state = encode_oauth_state(invite_token=invite_token, timezone=timezone, language=language) + if state: + params["state"] = state return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str) -> str: @@ -186,15 +238,21 @@ class GoogleOAuth(OAuth): _TOKEN_URL = "https://oauth2.googleapis.com/token" _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" - def get_authorization_url(self, invite_token: str | None = None) -> str: + def get_authorization_url( + self, + invite_token: str | None = None, + timezone: str | None = None, + language: str | None = None, + ) -> str: params = { "client_id": self.client_id, "response_type": "code", "redirect_uri": self.redirect_uri, "scope": "openid email", } - if invite_token: - params["state"] = invite_token + state = encode_oauth_state(invite_token=invite_token, timezone=timezone, language=language) + if state: + params["state"] = state return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str) -> str: diff --git a/api/openapi/markdown/console-swagger.md b/api/openapi/markdown/console-swagger.md index d5827fce3f..08642c15d1 100644 --- a/api/openapi/markdown/console-swagger.md +++ b/api/openapi/markdown/console-swagger.md @@ -2404,7 +2404,7 @@ Get all comments for a workflow | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Comments retrieved successfully | [WorkflowCommentBasic](#workflowcommentbasic) | +| 200 | Comments retrieved successfully | [WorkflowCommentBasicList](#workflowcommentbasiclist) | #### POST ##### Summary @@ -7515,7 +7515,7 @@ Remove one or more tag bindings from a target. | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | | tag_id | path | | Yes | string | -| payload | body | | Yes | [TagBasePayload](#tagbasepayload) | +| payload | body | | Yes | [TagUpdateRequestPayload](#tagupdaterequestpayload) | ##### Responses @@ -11596,6 +11596,7 @@ Request payload for bulk downloading documents as a zip archive. | code | string | | Yes | | email | string | | Yes | | language | | | No | +| timezone | | | No | | token | string | | Yes | #### EmailPayload @@ -11609,8 +11610,10 @@ Request payload for bulk downloading documents as a zip archive. | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | +| language | | | No | | new_password | string | | Yes | | password_confirm | string | | Yes | +| timezone | | | No | | token | string | | Yes | #### EmailRegisterSendPayload @@ -13495,6 +13498,12 @@ Tag type | ---- | ---- | ----------- | -------- | | TagType | string | Tag type | | +#### TagUpdateRequestPayload + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| name | string | Tag name | Yes | + #### TenantAccountRole | Name | Type | Description | Required | @@ -13975,32 +13984,47 @@ in form definiton, or a variable while the workflow is running. | trigger_metadata | | | No | | workflow_run | | | No | +#### WorkflowCommentAccount + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| avatar_url | | | Yes | +| email | string | | Yes | +| id | string | | Yes | +| name | string | | Yes | + #### WorkflowCommentBasic | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| content | string | | No | -| created_at | object | | No | -| created_by | string | | No | -| created_by_account | [_AnonymousInlineModel_6fec07cd0d85](#_anonymousinlinemodel_6fec07cd0d85) | | No | -| id | string | | No | -| mention_count | integer | | No | -| participants | [ [_AnonymousInlineModel_6fec07cd0d85](#_anonymousinlinemodel_6fec07cd0d85) ] | | No | -| position_x | number | | No | -| position_y | number | | No | -| reply_count | integer | | No | -| resolved | boolean | | No | -| resolved_at | object | | No | -| resolved_by | string | | No | -| resolved_by_account | [_AnonymousInlineModel_6fec07cd0d85](#_anonymousinlinemodel_6fec07cd0d85) | | No | -| updated_at | object | | No | +| content | string | | Yes | +| created_at | | | No | +| created_by | string | | Yes | +| created_by_account | | | No | +| id | string | | Yes | +| mention_count | integer | | Yes | +| participants | [ [WorkflowCommentAccount](#workflowcommentaccount) ] | | Yes | +| position_x | number | | Yes | +| position_y | number | | Yes | +| reply_count | integer | | Yes | +| resolved | boolean | | Yes | +| resolved_at | | | No | +| resolved_by | | | No | +| resolved_by_account | | | No | +| updated_at | | | No | + +#### WorkflowCommentBasicList + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| data | [ [WorkflowCommentBasic](#workflowcommentbasic) ] | | Yes | #### WorkflowCommentCreate | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| created_at | object | | No | -| id | string | | No | +| created_at | | | No | +| id | string | | Yes | #### WorkflowCommentCreatePayload @@ -14015,20 +14039,28 @@ in form definiton, or a variable while the workflow is running. | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| content | string | | No | -| created_at | object | | No | -| created_by | string | | No | -| created_by_account | [_AnonymousInlineModel_6fec07cd0d85](#_anonymousinlinemodel_6fec07cd0d85) | | No | -| id | string | | No | -| mentions | [ [_AnonymousInlineModel_f7ff64cce858](#_anonymousinlinemodel_f7ff64cce858) ] | | No | -| position_x | number | | No | -| position_y | number | | No | -| replies | [ [_AnonymousInlineModel_55c39c6a4b9e](#_anonymousinlinemodel_55c39c6a4b9e) ] | | No | -| resolved | boolean | | No | -| resolved_at | object | | No | -| resolved_by | string | | No | -| resolved_by_account | [_AnonymousInlineModel_6fec07cd0d85](#_anonymousinlinemodel_6fec07cd0d85) | | No | -| updated_at | object | | No | +| content | string | | Yes | +| created_at | | | No | +| created_by | string | | Yes | +| created_by_account | | | No | +| id | string | | Yes | +| mentions | [ [WorkflowCommentMention](#workflowcommentmention) ] | | Yes | +| position_x | number | | Yes | +| position_y | number | | Yes | +| replies | [ [WorkflowCommentReply](#workflowcommentreply) ] | | Yes | +| resolved | boolean | | Yes | +| resolved_at | | | No | +| resolved_by | | | No | +| resolved_by_account | | | No | +| updated_at | | | No | + +#### WorkflowCommentMention + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| mentioned_user_account | | | No | +| mentioned_user_id | string | | Yes | +| reply_id | | | No | #### WorkflowCommentMentionUsersPayload @@ -14036,12 +14068,22 @@ in form definiton, or a variable while the workflow is running. | ---- | ---- | ----------- | -------- | | users | [ [AccountWithRole](#accountwithrole) ] | | Yes | +#### WorkflowCommentReply + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| content | string | | Yes | +| created_at | | | No | +| created_by | string | | Yes | +| created_by_account | | | No | +| id | string | | Yes | + #### WorkflowCommentReplyCreate | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| created_at | object | | No | -| id | string | | No | +| created_at | | | No | +| id | string | | Yes | #### WorkflowCommentReplyPayload @@ -14054,24 +14096,24 @@ in form definiton, or a variable while the workflow is running. | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| id | string | | No | -| updated_at | object | | No | +| id | string | | Yes | +| updated_at | | | No | #### WorkflowCommentResolve | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| id | string | | No | -| resolved | boolean | | No | -| resolved_at | object | | No | -| resolved_by | string | | No | +| id | string | | Yes | +| resolved | boolean | | Yes | +| resolved_at | | | No | +| resolved_by | | | No | #### WorkflowCommentUpdate | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| id | string | | No | -| updated_at | object | | No | +| id | string | | Yes | +| updated_at | | | No | #### WorkflowCommentUpdatePayload @@ -14475,25 +14517,6 @@ Workflow tool configuration | limit | integer | | No | | page | integer | | No | -#### _AnonymousInlineModel_55c39c6a4b9e - -| Name | Type | Description | Required | -| ---- | ---- | ----------- | -------- | -| content | string | | No | -| created_at | object | | No | -| created_by | string | | No | -| created_by_account | [_AnonymousInlineModel_6fec07cd0d85](#_anonymousinlinemodel_6fec07cd0d85) | | No | -| id | string | | No | - -#### _AnonymousInlineModel_6fec07cd0d85 - -| Name | Type | Description | Required | -| ---- | ---- | ----------- | -------- | -| avatar_url | object | | No | -| email | string | | No | -| id | string | | No | -| name | string | | No | - #### _AnonymousInlineModel_b1954337d565 | Name | Type | Description | Required | @@ -14503,14 +14526,6 @@ Workflow tool configuration | model_provider_name | string | | No | | summary_prompt | string | | No | -#### _AnonymousInlineModel_f7ff64cce858 - -| Name | Type | Description | Required | -| ---- | ---- | ----------- | -------- | -| mentioned_user_account | [_AnonymousInlineModel_6fec07cd0d85](#_anonymousinlinemodel_6fec07cd0d85) | | No | -| mentioned_user_id | string | | No | -| reply_id | string | | No | - ## FastOpenAPI Preview (OpenAPI 3.0) ### Dify API (FastOpenAPI PoC) diff --git a/api/pyproject.toml b/api/pyproject.toml index 2ab9f4ca71..3ddc1e30a9 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -42,7 +42,6 @@ dependencies = [ "readabilipy>=0.3.0,<1.0.0", "resend>=2.27.0,<3.0.0", # Emerging: newer and fast-moving, use compatible pins - "dify-agent", "fastopenapi[flask]~=0.7.0", "graphon~=0.4.0", "httpx-sse~=0.4.0", @@ -115,6 +114,7 @@ override-dependencies = [ ############################################################ dev = [ "coverage>=7.13.4", + "dify-agent", "dotenv-linter>=0.7.0", "faker>=40.15.0", "lxml-stubs>=0.5.1", diff --git a/api/schedule/trigger_provider_refresh_task.py b/api/schedule/trigger_provider_refresh_task.py index df5058d70a..1fefdb9e5a 100644 --- a/api/schedule/trigger_provider_refresh_task.py +++ b/api/schedule/trigger_provider_refresh_task.py @@ -1,6 +1,5 @@ import logging import math -import time from collections.abc import Iterable, Sequence from celery import group @@ -13,16 +12,13 @@ from configs import dify_config from core.trigger.utils.locks import build_trigger_refresh_lock_keys from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.helper import current_timestamp from models.trigger import TriggerSubscription from tasks.trigger_subscription_refresh_tasks import trigger_subscription_refresh logger = logging.getLogger(__name__) -def _now_ts() -> int: - return int(time.time()) - - def _build_due_filter(now_ts: int): """Build SQLAlchemy filter for due credential or subscription refresh.""" credential_due: ColumnElement[bool] = and_( @@ -54,7 +50,7 @@ def trigger_provider_refresh() -> None: """ Scan due trigger subscriptions and enqueue refresh tasks with in-flight locks. """ - now: int = _now_ts() + now: int = current_timestamp() batch_size: int = int(dify_config.TRIGGER_PROVIDER_REFRESH_BATCH_SIZE) lock_ttl: int = max(300, int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS)) diff --git a/api/services/account_service.py b/api/services/account_service.py index 744c17d126..6533526b60 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -29,6 +29,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client, redis_fallback from libs.datetime_utils import naive_utc_now from libs.helper import RateLimiter, TokenManager +from libs.helper import timezone as validate_timezone from libs.passport import PassportService from libs.password import compare_password, hash_password, valid_password from libs.rsa import generate_key_pair @@ -271,8 +272,9 @@ class AccountService: password: str | None = None, interface_theme: str = "light", is_setup: bool | None = False, + timezone: str | None = None, ) -> Account: - """create account""" + """Create an account, preferring explicit user timezone over language-derived defaults.""" if not FeatureService.get_system_features().is_allow_register and not is_setup: from controllers.console.error import AccountNotFound @@ -302,6 +304,10 @@ class AccountService: password_to_set = base64_password_hashed salt_to_set = base64_salt + resolved_timezone = language_timezone_mapping.get(interface_language, "UTC") + if timezone is not None: + resolved_timezone = validate_timezone(timezone) + account = Account( name=name, email=email, @@ -309,7 +315,7 @@ class AccountService: password_salt=salt_to_set, interface_language=interface_language, interface_theme=interface_theme, - timezone=language_timezone_mapping.get(interface_language, "UTC"), + timezone=resolved_timezone, ) db.session.add(account) @@ -318,11 +324,15 @@ class AccountService: @staticmethod def create_account_and_tenant( - email: str, name: str, interface_language: str, password: str | None = None + email: str, name: str, interface_language: str, password: str | None = None, timezone: str | None = None ) -> Account: - """create account""" + """Create an account and owner workspace.""" account = AccountService.create_account( - email=email, name=name, interface_language=interface_language, password=password + email=email, + name=name, + interface_language=interface_language, + password=password, + timezone=timezone, ) try: @@ -1474,8 +1484,8 @@ class RegisterService: @classmethod def register( cls, - email, - name, + email: str, + name: str, password: str | None = None, open_id: str | None = None, provider: str | None = None, @@ -1483,16 +1493,19 @@ class RegisterService: status: AccountStatus | None = None, is_setup: bool | None = False, create_workspace_required: bool | None = True, + timezone: str | None = None, ) -> Account: - db.session.begin_nested() """Register account""" + db.session.begin_nested() try: + interface_language = get_valid_language(language) account = AccountService.create_account( email=email, name=name, - interface_language=get_valid_language(language), + interface_language=interface_language, password=password, is_setup=is_setup, + timezone=timezone, ) account.status = status or AccountStatus.ACTIVE account.initialized_at = naive_utc_now() diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 97aaea3395..7ba2b64c74 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -10,7 +10,6 @@ from uuid import uuid4 import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad -from packaging import version from packaging.version import parse as parse_version from pydantic import BaseModel from sqlalchemy import select @@ -40,6 +39,7 @@ from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode from models.model import AppModelConfig, AppModelConfigDict, IconType from models.workflow import Workflow +from services.dsl_version import check_version_compatibility from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.workflow_draft_variable_service import WorkflowDraftVariableService @@ -64,30 +64,6 @@ class Import(BaseModel): error: str = "" -def _check_version_compatibility(imported_version: str) -> ImportStatus: - """Determine import status based on version comparison""" - try: - current_ver = version.parse(CURRENT_DSL_VERSION) - imported_ver = version.parse(imported_version) - except version.InvalidVersion: - return ImportStatus.FAILED - - # If imported version is newer than current, always return PENDING - if imported_ver > current_ver: - return ImportStatus.PENDING - - # If imported version is older than current's major, return PENDING - if imported_ver.major < current_ver.major: - return ImportStatus.PENDING - - # If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS - if imported_ver.minor < current_ver.minor: - return ImportStatus.COMPLETED_WITH_WARNINGS - - # If imported version equals or is older than current's micro, return COMPLETED - return ImportStatus.COMPLETED - - class PendingData(BaseModel): import_mode: str yaml_content: str @@ -203,7 +179,7 @@ class AppDslService: # check if imported_version is a float-like string if not isinstance(imported_version, str): raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}") - status = _check_version_compatibility(imported_version) + status = check_version_compatibility(imported_version, CURRENT_DSL_VERSION) # Extract app data app_data = data.get("app") diff --git a/api/services/dsl_version.py b/api/services/dsl_version.py new file mode 100644 index 0000000000..cb7384df70 --- /dev/null +++ b/api/services/dsl_version.py @@ -0,0 +1,20 @@ +from packaging import version + +from services.entities.dsl_entities import ImportStatus + + +def check_version_compatibility(imported_version: str, current_version: str) -> ImportStatus: + """Determine DSL import status based on imported and current versions.""" + try: + current_ver = version.parse(current_version) + imported_ver = version.parse(imported_version) + except version.InvalidVersion: + return ImportStatus.FAILED + + if imported_ver > current_ver: + return ImportStatus.PENDING + if imported_ver.major < current_ver.major: + return ImportStatus.PENDING + if imported_ver.minor < current_ver.minor: + return ImportStatus.COMPLETED_WITH_WARNINGS + return ImportStatus.COMPLETED diff --git a/api/services/file_service.py b/api/services/file_service.py index b683a2f3d4..cd412638e0 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -172,12 +172,14 @@ class FileService: return upload_file - def get_file_preview(self, file_id: str): + def get_file_preview(self, file_id: str, tenant_id: str): """ Return a short text preview extracted from a document file. """ with self._session_maker(expire_on_commit=False) as session: - upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) + upload_file = session.scalar( + select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1) + ) if not upload_file: raise NotFound("File not found") diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index f315d053cb..37ebffbeb4 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -13,7 +13,6 @@ import yaml # type: ignore from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from flask_login import current_user -from packaging import version from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -37,6 +36,7 @@ from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.enums import CollectionBindingType, DatasetRuntimeMode from models.workflow import Workflow, WorkflowType +from services.dsl_version import check_version_compatibility from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus from services.entities.knowledge_entities.rag_pipeline_entities import ( IconInfo, @@ -64,30 +64,6 @@ class RagPipelineImportInfo(BaseModel): dataset_id: str | None = None -def _check_version_compatibility(imported_version: str) -> ImportStatus: - """Determine import status based on version comparison""" - try: - current_ver = version.parse(CURRENT_DSL_VERSION) - imported_ver = version.parse(imported_version) - except version.InvalidVersion: - return ImportStatus.FAILED - - # If imported version is newer than current, always return PENDING - if imported_ver > current_ver: - return ImportStatus.PENDING - - # If imported version is older than current's major, return PENDING - if imported_ver.major < current_ver.major: - return ImportStatus.PENDING - - # If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS - if imported_ver.minor < current_ver.minor: - return ImportStatus.COMPLETED_WITH_WARNINGS - - # If imported version equals or is older than current's micro, return COMPLETED - return ImportStatus.COMPLETED - - class RagPipelinePendingData(BaseModel): import_mode: str yaml_content: str @@ -100,6 +76,13 @@ class CheckDependenciesPendingData(BaseModel): class RagPipelineDslService: + """Import, export, and inspect RAG pipeline DSL using the caller-owned session. + + Controllers wrap this service in a SQLAlchemy transaction context, so methods must only flush interim changes when + generated IDs are needed. Committing inside the service would close the caller's transaction and break later work in + the same context manager. + """ + def __init__(self, session: Session): self._session = session @@ -195,7 +178,7 @@ class RagPipelineDslService: # check if imported_version is a float-like string if not isinstance(imported_version, str): raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}") - status = _check_version_compatibility(imported_version) + status = check_version_compatibility(imported_version, CURRENT_DSL_VERSION) # Extract app data pipeline_data = data.get("rag_pipeline") @@ -325,7 +308,7 @@ class RagPipelineDslService: type=CollectionBindingType.DATASET, ) self._session.add(dataset_collection_binding) - self._session.commit() + self._session.flush() dataset_collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model @@ -337,7 +320,7 @@ class RagPipelineDslService: dataset.summary_index_setting = knowledge_configuration.summary_index_setting dataset.pipeline_id = pipeline.id self._session.add(dataset) - self._session.commit() + self._session.flush() dataset_id = dataset.id if not dataset_id: raise ValueError("DSL is not valid, please check the Knowledge Index node.") @@ -462,7 +445,7 @@ class RagPipelineDslService: type=CollectionBindingType.DATASET, ) self._session.add(dataset_collection_binding) - self._session.commit() + self._session.flush() dataset_collection_binding_id = dataset_collection_binding.id dataset.collection_binding_id = dataset_collection_binding_id dataset.embedding_model = knowledge_configuration.embedding_model @@ -474,7 +457,7 @@ class RagPipelineDslService: dataset.summary_index_setting = knowledge_configuration.summary_index_setting dataset.pipeline_id = pipeline.id self._session.add(dataset) - self._session.commit() + self._session.flush() dataset_id = dataset.id if not dataset_id: raise ValueError("DSL is not valid, please check the Knowledge Index node.") @@ -585,7 +568,7 @@ class RagPipelineDslService: pipeline.id = str(uuid4()) self._session.add(pipeline) - self._session.commit() + self._session.flush() # save dependencies if dependencies: redis_client.setex( @@ -627,8 +610,8 @@ class RagPipelineDslService: workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables workflow.rag_pipeline_variables = rag_pipeline_variables_list - # commit db session changes - self._session.commit() + # Keep transaction ownership with the caller while materializing IDs and constraint checks before returning. + self._session.flush() return pipeline diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 8043a99be1..09d49d8b3e 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -21,7 +21,6 @@ class SaveTagPayload(BaseModel): class UpdateTagPayload(BaseModel): name: str = Field(min_length=1, max_length=50) - type: TagType class TagBindingCreatePayload(BaseModel): diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 39564bbede..ba26c20331 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -7,10 +7,12 @@ from sqlalchemy import delete, select from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from models.enums import IndexingStatus +from tasks.generate_summary_index_task import generate_summary_index_task logger = logging.getLogger(__name__) @@ -70,6 +72,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) session.execute(segment_delete_stmt) + has_error = False try: indexing_runner = IndexingRunner() indexing_runner.run([document]) @@ -77,5 +80,45 @@ def document_indexing_update_task(dataset_id: str, document_id: str): logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) except DocumentIsPausedError as ex: logger.info(click.style(str(ex), fg="yellow")) + has_error = True except Exception: logger.exception("document_indexing_update_task failed, document_id: %s", document_id) + has_error = True + + if has_error: + return + + # Trigger summary index generation for the updated document if enabled. + # Only generate for high_quality indexing technique and when summary_index_setting is enabled. + with session_factory.create_session() as session: + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) + if not dataset: + logger.warning("Dataset %s not found after update indexing", dataset_id) + return + + if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: + summary_index_setting = dataset.summary_index_setting + if summary_index_setting and summary_index_setting.get("enable"): + session.expire_all() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) + ) + if ( + document + and document.indexing_status == IndexingStatus.COMPLETED + and document.doc_form != IndexStructureType.QA_INDEX + and document.need_summary is True + ): + try: + generate_summary_index_task.delay(dataset.id, document.id, None) + logger.info( + "Queued summary index generation task for document %s in dataset %s " + "after update indexing completed", + document.id, + dataset.id, + ) + except Exception: + logger.exception( + "Failed to queue summary index generation task for document %s after update", + document.id, + ) diff --git a/api/tasks/trigger_subscription_refresh_tasks.py b/api/tasks/trigger_subscription_refresh_tasks.py index 1daf8f302c..f6552fb294 100644 --- a/api/tasks/trigger_subscription_refresh_tasks.py +++ b/api/tasks/trigger_subscription_refresh_tasks.py @@ -1,5 +1,4 @@ import logging -import time from collections.abc import Mapping from typing import Any @@ -12,16 +11,13 @@ from core.db.session_factory import session_factory from core.plugin.entities.plugin_daemon import CredentialType from core.trigger.utils.locks import build_trigger_refresh_lock_key from extensions.ext_redis import redis_client +from libs.helper import current_timestamp from models.trigger import TriggerSubscription from services.trigger.trigger_provider_service import TriggerProviderService logger = logging.getLogger(__name__) -def _now_ts() -> int: - return int(time.time()) - - def _load_subscription(session: Session, tenant_id: str, subscription_id: str) -> TriggerSubscription | None: return session.scalar( select(TriggerSubscription) @@ -96,7 +92,7 @@ def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None: logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id) try: - now: int = _now_ts() + now: int = current_timestamp() with session_factory.create_session() as session: subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id) diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index b4482674da..901da9cbe2 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -22,7 +22,7 @@ from sqlalchemy import Engine, text from sqlalchemy.orm import Session from testcontainers.core.container import DockerContainer from testcontainers.core.network import Network -from testcontainers.core.waiting_utils import wait_for_logs +from testcontainers.core.wait_strategies import LogMessageWaitStrategy from testcontainers.postgres import PostgresContainer from testcontainers.redis import RedisContainer @@ -54,6 +54,10 @@ def _auto_close[T: _CloserProtocol](closer: T) -> Generator[T, None, None]: closer.close() +def _wait_for_log_message(message: str, timeout: int) -> LogMessageWaitStrategy: + return LogMessageWaitStrategy(message).with_startup_timeout(timeout) + + class DifyTestContainers: """ Manages all test containers required for Dify integration tests. @@ -99,6 +103,7 @@ class DifyTestContainers: self.postgres = PostgresContainer( image="postgres:14-alpine", ).with_network(self.network) + self.postgres.waiting_for(_wait_for_log_message("is ready to accept connections", 30)) self.postgres.start() db_host = self.postgres.get_container_host_ip() db_port = self.postgres.get_exposed_port(5432) @@ -115,9 +120,6 @@ class DifyTestContainers: self.postgres.dbname, ) - # Wait for PostgreSQL to be ready - logger.info("Waiting for PostgreSQL to be ready to accept connections...") - wait_for_logs(self.postgres, "is ready to accept connections", timeout=30) logger.info("PostgreSQL container is ready and accepting connections") conn = psycopg2.connect( @@ -152,6 +154,7 @@ class DifyTestContainers: # Redis is used for storing session data, cache entries, and temporary data logger.info("Initializing Redis container...") self.redis = RedisContainer(image="redis:6-alpine", port=6379).with_network(self.network) + self.redis.waiting_for(_wait_for_log_message("Ready to accept connections", 30)) self.redis.start() redis_host = self.redis.get_container_host_ip() redis_port = self.redis.get_exposed_port(6379) @@ -159,9 +162,6 @@ class DifyTestContainers: os.environ["REDIS_PORT"] = str(redis_port) logger.info("Redis container started successfully - Host: %s, Port: %s", redis_host, redis_port) - # Wait for Redis to be ready - logger.info("Waiting for Redis to be ready to accept connections...") - wait_for_logs(self.redis, "Ready to accept connections", timeout=30) logger.info("Redis container is ready and accepting connections") # Start Dify Sandbox container for code execution environment. @@ -170,6 +170,7 @@ class DifyTestContainers: sandbox_image = os.getenv(SANDBOX_TEST_IMAGE_ENV, DEFAULT_SANDBOX_TEST_IMAGE) self.dify_sandbox = DockerContainer(image=sandbox_image).with_network(self.network) self.dify_sandbox.with_exposed_ports(8194) + self.dify_sandbox.waiting_for(_wait_for_log_message("config init success", 60)) self.dify_sandbox.env = { "API_KEY": "test_api_key", } @@ -185,9 +186,6 @@ class DifyTestContainers: sandbox_port, ) - # Wait for Dify Sandbox to be ready - logger.info("Waiting for Dify Sandbox to be ready to accept connections...") - wait_for_logs(self.dify_sandbox, "config init success", timeout=60) logger.info("Dify Sandbox container is ready and accepting connections") # Start Dify Plugin Daemon container for plugin management @@ -197,6 +195,7 @@ class DifyTestContainers: self.network ) self.dify_plugin_daemon.with_exposed_ports(5002) + self.dify_plugin_daemon.waiting_for(_wait_for_log_message("start plugin manager daemon", 60)) # Get container internal network addresses postgres_container_name = self.postgres.get_wrapped_container().name redis_container_name = self.redis.get_wrapped_container().name @@ -243,9 +242,6 @@ class DifyTestContainers: plugin_daemon_port, ) - # Wait for Dify Plugin Daemon to be ready - logger.info("Waiting for Dify Plugin Daemon to be ready to accept connections...") - wait_for_logs(self.dify_plugin_daemon, "start plugin manager daemon", timeout=60) logger.info("Dify Plugin Daemon container is ready and accepting connections") except Exception as e: logger.warning("Failed to start Dify Plugin Daemon container: %s", e) diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index bb737754a1..b13bdba2bc 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -8,7 +8,9 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from flask.testing import FlaskClient from pydantic import ValidationError +from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound from controllers.console import console_ns @@ -57,6 +59,12 @@ from controllers.console.app.workflow_app_log import WorkflowAppLogQuery from controllers.console.app.workflow_draft_variable import WorkflowDraftVariableUpdatePayload from controllers.console.app.workflow_statistic import WorkflowStatisticQuery from controllers.console.app.workflow_trigger import Parser, ParserEnable +from models.model import AppMode +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, + create_console_app, +) def _unwrap(func): @@ -270,6 +278,35 @@ class TestOpsTraceEndpoints: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers + @pytest.mark.parametrize( + "path_template", + [ + "/console/api/apps/{app_id}/trace-config?tracing_provider=langfuse", + "/console/api/apps/{app_id}/trace", + ], + ) + def test_trace_endpoints_hide_apps_from_other_tenants( + self, + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, + path_template: str, + ): + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_app = create_console_app( + db_session_with_containers, + tenant_id=foreign_tenant.id, + account_id=foreign_account.id, + mode=AppMode.CHAT, + ) + + response = test_client_with_containers.get( + path_template.format(app_id=foreign_app.id), + headers=authenticate_console_client(test_client_with_containers, account), + ) + + assert response.status_code == 404 + def test_ops_trace_query_basic(self): query = TraceProviderQuery(tracing_provider="langfuse") assert query.tracing_provider == "langfuse" @@ -289,7 +326,7 @@ class TestOpsTraceEndpoints: ) with app.test_request_context("/?tracing_provider=langfuse"): - result = method(app_id="app-1") + result = method(app_model=MagicMock(id="app-1")) assert result == {"has_not_configured": True} @@ -308,7 +345,7 @@ class TestOpsTraceEndpoints: json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}}, ): with pytest.raises(BadRequest): - method(app_id="app-1") + method(app_model=MagicMock(id="app-1")) def test_trace_app_config_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() @@ -322,7 +359,7 @@ class TestOpsTraceEndpoints: with app.test_request_context("/?tracing_provider=langfuse"): with pytest.raises(BadRequest): - method(app_id="app-1") + method(app_model=MagicMock(id="app-1")) class TestSiteEndpoints: diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 5a22f81a69..0efd77934e 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -6,17 +6,17 @@ import uuid from flask.testing import FlaskClient from sqlalchemy.orm import Session -from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token -from models import Account, DifySetup, Tenant, TenantAccountJoin +from models import Account, Tenant, TenantAccountJoin from models.account import AccountStatus, TenantAccountRole, TenantStatus from models.enums import ConversationFromSource, CreatorUserRole from models.model import App, AppMode, Conversation, Message from models.workflow import WorkflowRun from services.account_service import AccountService +from tests.test_containers_integration_tests.controllers.console.helpers import ensure_dify_setup def _create_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: @@ -47,9 +47,7 @@ def _create_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: account.timezone = "UTC" db_session.commit() - dify_setup = DifySetup(version=dify_config.project.version) - db_session.add(dify_setup) - db_session.commit() + ensure_dify_setup(db_session) return account, tenant diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 1fcce9ca44..bb7921a5f4 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -143,7 +143,118 @@ class TestEmailRegisterResetApi: response = EmailRegisterResetApi().post() assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}} - mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!") + mock_create_account.assert_called_once_with( + email="invitee@example.com", + password="ValidPass123!", + timezone=None, + language=None, + ) + mock_reset_login_rate.assert_called_once_with("invitee@example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_extract_ip.assert_called_once() + + @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.login") + @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") + @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") + @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1") + def test_reset_passes_timezone_to_new_account( + self, + mock_extract_ip, + mock_get_data, + mock_revoke_token, + mock_get_account, + mock_create_account, + mock_login, + mock_reset_login_rate, + app: Flask, + ): + mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"} + mock_create_account.return_value = MagicMock() + token_pair = MagicMock() + token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} + mock_login.return_value = token_pair + mock_get_account.return_value = None + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with ( + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), + ): + with app.test_request_context( + "/email-register", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + "timezone": "Asia/Shanghai", + }, + ): + response = EmailRegisterResetApi().post() + + assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}} + mock_create_account.assert_called_once_with( + email="invitee@example.com", + password="ValidPass123!", + timezone="Asia/Shanghai", + language=None, + ) + mock_reset_login_rate.assert_called_once_with("invitee@example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_extract_ip.assert_called_once() + + @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.login") + @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") + @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") + @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1") + def test_reset_passes_language_to_new_account( + self, + mock_extract_ip, + mock_get_data, + mock_revoke_token, + mock_get_account, + mock_create_account, + mock_login, + mock_reset_login_rate, + app: Flask, + ): + mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"} + mock_create_account.return_value = MagicMock() + token_pair = MagicMock() + token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} + mock_login.return_value = token_pair + mock_get_account.return_value = None + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with ( + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), + ): + with app.test_request_context( + "/email-register", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + "language": "zh-Hans", + }, + ): + response = EmailRegisterResetApi().post() + + assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}} + mock_create_account.assert_called_once_with( + email="invitee@example.com", + password="ValidPass123!", + timezone=None, + language="zh-Hans", + ) mock_reset_login_rate.assert_called_once_with("invitee@example.com") mock_revoke_token.assert_called_once_with("token-123") mock_extract_ip.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 55b6a919d8..a5ae83739c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -14,7 +14,7 @@ from controllers.console.auth.oauth import ( _get_account_by_openid_or_email, get_oauth_providers, ) -from libs.oauth import OAuthUserInfo +from libs.oauth import OAuthUserInfo, encode_oauth_state from models.account import AccountStatus from services.account_service import AccountService from services.errors.account import AccountRegisterError @@ -101,7 +101,55 @@ class TestOAuthLogin: with app.test_request_context(f"/auth/oauth/github?{query_string}"): resource.get("github") - mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token) + mock_oauth_provider.get_authorization_url.assert_called_once_with( + invite_token=expected_token, + timezone=None, + language=None, + ) + mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...") + + @patch("controllers.console.auth.oauth.get_oauth_providers") + @patch("controllers.console.auth.oauth.redirect") + def test_should_pass_timezone_to_oauth_state( + self, + mock_redirect, + mock_get_providers, + resource, + app: Flask, + mock_oauth_provider, + ): + mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None} + + with app.test_request_context("/auth/oauth/github?timezone=Asia/Shanghai"): + resource.get("github") + + mock_oauth_provider.get_authorization_url.assert_called_once_with( + invite_token=None, + timezone="Asia/Shanghai", + language=None, + ) + mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...") + + @patch("controllers.console.auth.oauth.get_oauth_providers") + @patch("controllers.console.auth.oauth.redirect") + def test_should_pass_language_to_oauth_state( + self, + mock_redirect, + mock_get_providers, + resource, + app: Flask, + mock_oauth_provider, + ): + mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None} + + with app.test_request_context("/auth/oauth/github?language=zh-Hans"): + resource.get("github") + + mock_oauth_provider.get_authorization_url.assert_called_once_with( + invite_token=None, + timezone=None, + language="zh-Hans", + ) mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...") @pytest.mark.parametrize( @@ -229,7 +277,8 @@ class TestOAuthCallback: mock_register_service.is_valid_invite_token.return_value = True mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"} - with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"): + state = encode_oauth_state(invite_token="invite123", timezone="Asia/Shanghai") + with app.test_request_context(f"/auth/oauth/github/callback?code=test_code&state={state}"): resource.get("github") mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123") @@ -488,7 +537,13 @@ class TestAccountGeneration: if should_create: mock_register_service.register.assert_called_once_with( - email="test@example.com", name="Test User", password=None, open_id="123", provider="github" + email="test@example.com", + name="Test User", + password=None, + open_id="123", + provider="github", + language="en-US", + timezone=None, ) else: mock_register_service.register.assert_not_called() @@ -515,7 +570,75 @@ class TestAccountGeneration: _generate_account("github", user_info) mock_register_service.register.assert_called_once_with( - email="upper@example.com", name="Test User", password=None, open_id="123", provider="github" + email="upper@example.com", + name="Test User", + password=None, + open_id="123", + provider="github", + language="en-US", + timezone=None, + ) + + @patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) + @patch("controllers.console.auth.oauth.FeatureService") + @patch("controllers.console.auth.oauth.RegisterService") + @patch("controllers.console.auth.oauth.AccountService") + @patch("controllers.console.auth.oauth.TenantService") + def test_should_register_with_browser_timezone( + self, + mock_tenant_service, + mock_account_service, + mock_register_service, + mock_feature_service, + mock_get_account, + app: Flask, + user_info, + ): + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_register_service.register.return_value = MagicMock() + + with app.test_request_context(headers={"Accept-Language": "zh-Hans,zh;q=0.9"}): + _generate_account("github", user_info, timezone="Asia/Shanghai") + + mock_register_service.register.assert_called_once_with( + email="test@example.com", + name="Test User", + password=None, + open_id="123", + provider="github", + language="zh-Hans", + timezone="Asia/Shanghai", + ) + + @patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) + @patch("controllers.console.auth.oauth.FeatureService") + @patch("controllers.console.auth.oauth.RegisterService") + @patch("controllers.console.auth.oauth.AccountService") + @patch("controllers.console.auth.oauth.TenantService") + def test_should_register_with_state_language( + self, + mock_tenant_service, + mock_account_service, + mock_register_service, + mock_feature_service, + mock_get_account, + app: Flask, + user_info, + ): + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_register_service.register.return_value = MagicMock() + + with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}): + _generate_account("github", user_info, language="zh-Hans") + + mock_register_service.register.assert_called_once_with( + email="test@example.com", + name="Test User", + password=None, + open_id="123", + provider="github", + language="zh-Hans", + timezone=None, ) @patch("controllers.console.auth.oauth._get_account_by_openid_or_email") diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index c77bbd3e44..ca3ae6d0cf 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -35,9 +35,9 @@ from services.app_dsl_service import ( ImportMode, ImportStatus, PendingData, - _check_version_compatibility, ) from services.app_service import AppService, CreateAppParams +from services.dsl_version import check_version_compatibility from tests.test_containers_integration_tests.helpers import generate_valid_password _DEFAULT_TENANT_ID = "00000000-0000-0000-0000-000000000001" @@ -193,22 +193,25 @@ class TestAppDslService: # ── Version Compatibility ───────────────────────────────────────── def test_check_version_compatibility_invalid_version_returns_failed(self): - assert _check_version_compatibility("not-a-version") == ImportStatus.FAILED + assert check_version_compatibility("not-a-version", app_dsl_service.CURRENT_DSL_VERSION) == ImportStatus.FAILED def test_check_version_compatibility_newer_version_returns_pending(self): - assert _check_version_compatibility("99.0.0") == ImportStatus.PENDING + assert check_version_compatibility("99.0.0", app_dsl_service.CURRENT_DSL_VERSION) == ImportStatus.PENDING def test_check_version_compatibility_major_older_returns_pending(self, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(app_dsl_service, "CURRENT_DSL_VERSION", "1.0.0") - assert _check_version_compatibility("0.9.9") == ImportStatus.PENDING + assert check_version_compatibility("0.9.9", app_dsl_service.CURRENT_DSL_VERSION) == ImportStatus.PENDING def test_check_version_compatibility_minor_older_returns_completed_with_warnings( self, ): - assert _check_version_compatibility("0.5.0") == ImportStatus.COMPLETED_WITH_WARNINGS + assert ( + check_version_compatibility("0.5.0", app_dsl_service.CURRENT_DSL_VERSION) + == ImportStatus.COMPLETED_WITH_WARNINGS + ) def test_check_version_compatibility_equal_returns_completed(self): - assert _check_version_compatibility(CURRENT_DSL_VERSION) == ImportStatus.COMPLETED + assert check_version_compatibility(CURRENT_DSL_VERSION, CURRENT_DSL_VERSION) == ImportStatus.COMPLETED # ── Import: Validation ──────────────────────────────────────────── diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py index 42dbdef1c9..4532005836 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -514,7 +514,7 @@ class TestFileService: db_session_with_containers.commit() - result = FileService(engine).get_file_preview(file_id=upload_file.id) + result = FileService(engine).get_file_preview(file_id=upload_file.id, tenant_id=upload_file.tenant_id) assert result == "extracted text content" mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once() @@ -529,7 +529,7 @@ class TestFileService: non_existent_id = str(fake.uuid4()) with pytest.raises(NotFound, match="File not found"): - FileService(engine).get_file_preview(file_id=non_existent_id) + FileService(engine).get_file_preview(file_id=non_existent_id, tenant_id=str(fake.uuid4())) def test_get_file_preview_unsupported_file_type( self, db_session_with_containers: Session, engine, mock_external_service_dependencies @@ -549,7 +549,7 @@ class TestFileService: db_session_with_containers.commit() with pytest.raises(UnsupportedFileTypeError): - FileService(engine).get_file_preview(file_id=upload_file.id) + FileService(engine).get_file_preview(file_id=upload_file.id, tenant_id=upload_file.tenant_id) def test_get_file_preview_text_truncation( self, db_session_with_containers: Session, engine, mock_external_service_dependencies @@ -572,7 +572,7 @@ class TestFileService: long_text = "x" * 5000 # Longer than PREVIEW_WORDS_LIMIT mock_external_service_dependencies["extract_processor"].load_from_upload_file.return_value = long_text - result = FileService(engine).get_file_preview(file_id=upload_file.id) + result = FileService(engine).get_file_preview(file_id=upload_file.id, tenant_id=upload_file.tenant_id) assert len(result) == 3000 # PREVIEW_WORDS_LIMIT assert result == "x" * 3000 diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 583b6128e6..f088cc964d 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -759,7 +759,7 @@ class TestTagService: tag = TagService.save_tags(tag_args) # Update args - update_args = UpdateTagPayload(name="updated_name", type="knowledge") + update_args = UpdateTagPayload(name="updated_name") # Act: Execute the method under test result = TagService.update_tags(update_args, tag.id) @@ -799,7 +799,7 @@ class TestTagService: non_existent_tag_id = str(uuid.uuid4()) - update_args = UpdateTagPayload(name="updated_name", type="knowledge") + update_args = UpdateTagPayload(name="updated_name") # Act & Assert: Verify proper error handling with pytest.raises(NotFound) as exc_info: @@ -830,7 +830,7 @@ class TestTagService: tag2 = TagService.save_tags(tag2_args) # Try to update second tag with first tag's name - update_args = UpdateTagPayload(name="first_tag", type="app") + update_args = UpdateTagPayload(name="first_tag") # Act & Assert: Verify proper error handling with pytest.raises(ValueError) as exc_info: diff --git a/api/tests/unit_tests/commands/test_reset_encrypt_key_pair.py b/api/tests/unit_tests/commands/test_reset_encrypt_key_pair.py new file mode 100644 index 0000000000..31b4d71d0f --- /dev/null +++ b/api/tests/unit_tests/commands/test_reset_encrypt_key_pair.py @@ -0,0 +1,108 @@ +"""Unit tests for the reset-encrypt-key-pair CLI command (#35396). + +The command must purge every table that stores ciphertext encrypted with the +tenant's asymmetric key, otherwise stale rows cause downstream API failures +such as `/console/api/workspaces/current/tool-providers` returning 500. +""" + +from unittest.mock import MagicMock, patch + +import commands +from commands import system as system_commands +from models.provider import Provider, ProviderModel +from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider + + +def _invoke_reset() -> int: + try: + commands.reset_encrypt_key_pair.callback() + except SystemExit as e: + return int(e.code or 0) + return 0 + + +def _delete_targets(session_mock: MagicMock) -> list: + """Extract the model class targeted by each `delete(...)` call on the session.""" + targets = [] + for call in session_mock.execute.call_args_list: + stmt = call.args[0] + # `delete(Foo)` constructs a `Delete` statement whose entity is `Foo`. + try: + targets.append(stmt.table.name) + except AttributeError: + targets.append(repr(stmt)) + return targets + + +def test_reset_aborts_when_not_self_hosted(monkeypatch, capsys): + monkeypatch.setattr(system_commands.dify_config, "EDITION", "CLOUD") + + exit_code = _invoke_reset() + captured = capsys.readouterr() + + assert exit_code == 0 + assert "only for SELF_HOSTED" in captured.out + + +def test_reset_purges_provider_and_tool_tables_for_each_tenant(monkeypatch, capsys): + """The command must purge LLM provider rows AND every tool provider table + that stores ciphertext encrypted under the tenant key (#35396).""" + monkeypatch.setattr(system_commands.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setattr(system_commands, "generate_key_pair", lambda tenant_id: f"new-key-{tenant_id}") + + fake_tenant = MagicMock(id="tenant-abc", encrypt_public_key="old-key") + session = MagicMock() + session.scalars.return_value.all.return_value = [fake_tenant] + + fake_sessionmaker = MagicMock() + fake_sessionmaker.begin.return_value.__enter__.return_value = session + fake_sessionmaker.begin.return_value.__exit__.return_value = False + + with ( + patch.object(system_commands, "db", MagicMock()), + patch.object(system_commands, "sessionmaker", return_value=fake_sessionmaker), + ): + exit_code = _invoke_reset() + + captured = capsys.readouterr() + assert exit_code == 0 + assert "tenant-abc" in captured.out + + # New key pair generated and assigned. + assert fake_tenant.encrypt_public_key == "new-key-tenant-abc" + + # Every encrypted-credential table should have been purged for this tenant. + table_names = _delete_targets(session) + expected = { + Provider.__tablename__, + ProviderModel.__tablename__, + BuiltinToolProvider.__tablename__, + ApiToolProvider.__tablename__, + MCPToolProvider.__tablename__, + } + assert expected.issubset(set(table_names)), f"missing purges: expected {expected}, got {table_names}" + + +def test_reset_iterates_all_tenants(monkeypatch, capsys): + """Multi-tenant deployments must purge every tenant, not just the first.""" + monkeypatch.setattr(system_commands.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setattr(system_commands, "generate_key_pair", lambda tenant_id: f"new-key-{tenant_id}") + + tenants = [MagicMock(id=f"tenant-{i}", encrypt_public_key="old") for i in range(3)] + session = MagicMock() + session.scalars.return_value.all.return_value = tenants + + fake_sessionmaker = MagicMock() + fake_sessionmaker.begin.return_value.__enter__.return_value = session + fake_sessionmaker.begin.return_value.__exit__.return_value = False + + with ( + patch.object(system_commands, "db", MagicMock()), + patch.object(system_commands, "sessionmaker", return_value=fake_sessionmaker), + ): + _invoke_reset() + + # Five purges per tenant × 3 tenants = 15 execute calls. + assert session.execute.call_count == 15 + for tenant in tenants: + assert tenant.encrypt_public_key == f"new-key-{tenant.id}" diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py index 85afcf0e60..baa21999f9 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py @@ -17,6 +17,15 @@ from controllers.console.app import wraps as app_wraps from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole +JAN_1_2024_NOON = datetime(2024, 1, 1, 12, 0, 0) +JAN_1_2024_NOON_TS = int(JAN_1_2024_NOON.timestamp()) +JAN_1_2024_1201 = datetime(2024, 1, 1, 12, 1, 0) +JAN_1_2024_1201_TS = int(JAN_1_2024_1201.timestamp()) +JAN_1_2024_1202 = datetime(2024, 1, 1, 12, 2, 0) +JAN_1_2024_1202_TS = int(JAN_1_2024_1202.timestamp()) +JAN_1_2024_1203 = datetime(2024, 1, 1, 12, 3, 0) +JAN_1_2024_1203_TS = int(JAN_1_2024_1203.timestamp()) + def _make_account(role: TenantAccountRole) -> Account: account = Account(name="tester", email="tester@example.com") @@ -78,6 +87,30 @@ class WriteCase: payload: dict[str, object] | None = None +@dataclass(frozen=True) +class MutationResponseCase: + resource_cls: type + method_name: str + path: str + kwargs: dict[str, str] + service_method_name: str + service_return: object + expected_response: dict[str, object] + payload: dict[str, object] | None = None + expected_status: int | None = None + + +def _unwrap_response(result: object) -> tuple[dict[str, object], int | None]: + if isinstance(result, tuple): + response, status = result + assert isinstance(response, dict) + assert isinstance(status, int) + return response, status + + assert isinstance(result, dict) + return result, None + + @pytest.mark.parametrize( "case", [ @@ -151,17 +184,20 @@ def test_create_comment_allows_editor(app: Flask, monkeypatch: pytest.MonkeyPatc create_comment_mock = MagicMock(return_value={"id": "comment-1"}) monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "create_comment", create_comment_mock) - payload = {"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []} + payload: dict[str, object] = { + "content": "hello", + "position_x": 1.0, + "position_y": 2.0, + "mentioned_user_ids": [], + } with app.test_request_context("/console/api/apps/app-123/workflow/comments", method="POST", json=payload): with _patch_payload(payload): result = workflow_comment_module.WorkflowCommentListApi().post(app_id="app-123") - if isinstance(result, tuple): - response = result[0] - else: - response = result + response, status = _unwrap_response(result) assert response["id"] == "comment-1" + assert status == 201 create_comment_mock.assert_called_once_with( tenant_id="tenant-123", app_id="app-123", @@ -181,14 +217,17 @@ def test_update_comment_omits_mentions_when_payload_does_not_include_them( app_model = _make_app() _patch_console_guards(monkeypatch, account, app_model) - update_comment_mock = MagicMock(return_value={"id": "comment-1", "updated_at": datetime(2024, 1, 1, 12, 0, 0)}) + update_comment_mock = MagicMock(return_value={"id": "comment-1", "updated_at": JAN_1_2024_NOON}) monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "update_comment", update_comment_mock) - payload = {"content": "hello", "position_x": 10.0, "position_y": 20.0} + payload: dict[str, object] = {"content": "hello", "position_x": 10.0, "position_y": 20.0} with app.test_request_context("/console/api/apps/app-123/workflow/comments/comment-1", method="PUT", json=payload): with _patch_payload(payload): - workflow_comment_module.WorkflowCommentDetailApi().put(app_id="app-123", comment_id="comment-1") + result = workflow_comment_module.WorkflowCommentDetailApi().put(app_id="app-123", comment_id="comment-1") + response, status = _unwrap_response(result) + assert response == {"id": "comment-1", "updated_at": JAN_1_2024_NOON_TS} + assert status is None update_comment_mock.assert_called_once_with( tenant_id="tenant-123", app_id="app-123", @@ -199,3 +238,254 @@ def test_update_comment_omits_mentions_when_payload_does_not_include_them( position_y=20.0, mentioned_user_ids=None, ) + + +def test_list_comments_serializes_response_model(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.NORMAL) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + + comment_author = SimpleNamespace( + id="account-123", + name="tester", + email="tester@example.com", + avatar="https://example.com/avatar.png", + ) + comment = SimpleNamespace( + id="comment-1", + position_x=1.5, + position_y=2.5, + content="hello", + created_by="account-123", + created_by_account=comment_author, + created_at=1_700_000_000, + updated_at=1_700_000_001, + resolved=False, + resolved_at=None, + resolved_by=None, + resolved_by_account=None, + reply_count=0, + mention_count=0, + participants=[comment_author], + ) + get_comments_mock = MagicMock(return_value=[comment]) + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "get_comments", get_comments_mock) + + with app.test_request_context("/console/api/apps/app-123/workflow/comments", method="GET"): + response = workflow_comment_module.WorkflowCommentListApi().get(app_id="app-123") + + assert response == { + "data": [ + { + "id": "comment-1", + "position_x": 1.5, + "position_y": 2.5, + "content": "hello", + "created_by": "account-123", + "created_by_account": { + "id": "account-123", + "name": "tester", + "email": "tester@example.com", + "avatar_url": "https://example.com/avatar.png", + }, + "created_at": 1_700_000_000, + "updated_at": 1_700_000_001, + "resolved": False, + "resolved_at": None, + "resolved_by": None, + "resolved_by_account": None, + "reply_count": 0, + "mention_count": 0, + "participants": [ + { + "id": "account-123", + "name": "tester", + "email": "tester@example.com", + "avatar_url": "https://example.com/avatar.png", + } + ], + } + ] + } + get_comments_mock.assert_called_once_with(tenant_id="tenant-123", app_id="app-123") + + +def test_get_comment_serializes_detail_response_model(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.NORMAL) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + + comment_author = SimpleNamespace( + id="account-123", + name="tester", + email="tester@example.com", + avatar="https://example.com/avatar.png", + ) + mentioned_user = SimpleNamespace( + id="account-456", + name="mentioned", + email="mentioned@example.com", + avatar=None, + ) + comment = SimpleNamespace( + id="comment-1", + position_x=1.5, + position_y=2.5, + content="hello", + created_by="account-123", + created_by_account=comment_author, + created_at=JAN_1_2024_NOON, + updated_at=JAN_1_2024_1201, + resolved=True, + resolved_at=JAN_1_2024_1202, + resolved_by="account-123", + resolved_by_account=comment_author, + replies=[ + SimpleNamespace( + id="reply-1", + content="reply", + created_by="account-456", + created_by_account=mentioned_user, + created_at=JAN_1_2024_1203, + ) + ], + mentions=[ + SimpleNamespace( + mentioned_user_id="account-456", + mentioned_user_account=mentioned_user, + reply_id="reply-1", + ) + ], + ) + get_comment_mock = MagicMock(return_value=comment) + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "get_comment", get_comment_mock) + + with app.test_request_context("/console/api/apps/app-123/workflow/comments/comment-1", method="GET"): + response = workflow_comment_module.WorkflowCommentDetailApi().get(app_id="app-123", comment_id="comment-1") + + assert response == { + "id": "comment-1", + "position_x": 1.5, + "position_y": 2.5, + "content": "hello", + "created_by": "account-123", + "created_by_account": { + "id": "account-123", + "name": "tester", + "email": "tester@example.com", + "avatar_url": "https://example.com/avatar.png", + }, + "created_at": JAN_1_2024_NOON_TS, + "updated_at": JAN_1_2024_1201_TS, + "resolved": True, + "resolved_at": JAN_1_2024_1202_TS, + "resolved_by": "account-123", + "resolved_by_account": { + "id": "account-123", + "name": "tester", + "email": "tester@example.com", + "avatar_url": "https://example.com/avatar.png", + }, + "replies": [ + { + "id": "reply-1", + "content": "reply", + "created_by": "account-456", + "created_by_account": { + "id": "account-456", + "name": "mentioned", + "email": "mentioned@example.com", + "avatar_url": None, + }, + "created_at": JAN_1_2024_1203_TS, + } + ], + "mentions": [ + { + "mentioned_user_id": "account-456", + "mentioned_user_account": { + "id": "account-456", + "name": "mentioned", + "email": "mentioned@example.com", + "avatar_url": None, + }, + "reply_id": "reply-1", + } + ], + } + get_comment_mock.assert_called_once_with(tenant_id="tenant-123", app_id="app-123", comment_id="comment-1") + + +@pytest.mark.parametrize( + "case", + [ + MutationResponseCase( + resource_cls=workflow_comment_module.WorkflowCommentResolveApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments/comment-1/resolve", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + service_method_name="resolve_comment", + service_return={ + "id": "comment-1", + "resolved": True, + "resolved_at": JAN_1_2024_NOON, + "resolved_by": "account-123", + }, + expected_response={ + "id": "comment-1", + "resolved": True, + "resolved_at": JAN_1_2024_NOON_TS, + "resolved_by": "account-123", + }, + ), + MutationResponseCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + payload={"content": "reply", "mentioned_user_ids": []}, + service_method_name="create_reply", + service_return={"id": "reply-1", "created_at": JAN_1_2024_NOON}, + expected_response={"id": "reply-1", "created_at": JAN_1_2024_NOON_TS}, + expected_status=201, + ), + MutationResponseCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyDetailApi, + method_name="put", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies/reply-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1", "reply_id": "reply-1"}, + payload={"content": "reply", "mentioned_user_ids": []}, + service_method_name="update_reply", + service_return={"id": "reply-1", "updated_at": JAN_1_2024_NOON}, + expected_response={"id": "reply-1", "updated_at": JAN_1_2024_NOON_TS}, + ), + ], +) +def test_mutation_endpoints_serialize_response_models( + app: Flask, monkeypatch: pytest.MonkeyPatch, case: MutationResponseCase +) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.EDITOR) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + _patch_write_services(monkeypatch) + monkeypatch.setattr( + workflow_comment_module.WorkflowCommentService, + case.service_method_name, + MagicMock(return_value=case.service_return), + ) + + with app.test_request_context(case.path, method=case.method_name.upper(), json=case.payload): + with _patch_payload(case.payload): + result = getattr(case.resource_cls(), case.method_name)(**case.kwargs) + + response, status = _unwrap_response(result) + assert response == case.expected_response + assert status == case.expected_status + + +def test_workflow_comment_response_schemas_are_registered() -> None: + assert workflow_comment_module.WorkflowCommentBasicList.__name__ in workflow_comment_module.console_ns.models + assert workflow_comment_module.WorkflowCommentDetail.__name__ in workflow_comment_module.console_ns.models diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py b/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py new file mode 100644 index 0000000000..df282880af --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py @@ -0,0 +1,40 @@ +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from controllers.console.auth.email_register import EmailRegisterResetApi, EmailRegisterResetPayload + + +@patch("controllers.console.auth.email_register.AccountService.create_account_and_tenant") +def test_create_new_account_uses_requested_language(mock_create_account): + account = MagicMock() + mock_create_account.return_value = account + + result = EmailRegisterResetApi()._create_new_account( + "invitee@example.com", + "ValidPass123!", + timezone="Asia/Shanghai", + language="zh-Hans", + ) + + assert result is account + mock_create_account.assert_called_once_with( + email="invitee@example.com", + name="invitee@example.com", + password="ValidPass123!", + interface_language="zh-Hans", + timezone="Asia/Shanghai", + ) + + +def test_reset_payload_rejects_invalid_timezone(): + with pytest.raises(ValidationError): + EmailRegisterResetPayload.model_validate( + { + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + "timezone": "", + } + ) diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index 102af9b250..fa23942c65 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -13,9 +13,10 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from pydantic import ValidationError from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError -from controllers.console.auth.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi +from controllers.console.auth.login import EmailCodeLoginApi, EmailCodeLoginPayload, EmailCodeLoginSendEmailApi from controllers.console.error import ( AccountInFreezeError, AccountNotFound, @@ -31,6 +32,18 @@ def encode_code(code: str) -> str: return base64.b64encode(code.encode("utf-8")).decode() +def test_email_code_login_payload_rejects_invalid_timezone(): + with pytest.raises(ValidationError): + EmailCodeLoginPayload.model_validate( + { + "email": "newuser@example.com", + "code": "123456", + "token": "token-123", + "timezone": "", + } + ) + + class TestEmailCodeLoginSendEmailApi: """Test cases for sending email verification codes.""" @@ -342,6 +355,7 @@ class TestEmailCodeLoginApi: "code": encode_code("123456"), "token": "valid_token", "language": "en-US", + "timezone": "Asia/Shanghai", }, ): api = EmailCodeLoginApi() @@ -349,7 +363,12 @@ class TestEmailCodeLoginApi: # Assert assert response.json["result"] == "success" - mock_create_account.assert_called_once() + mock_create_account.assert_called_once_with( + email="newuser@example.com", + name="newuser@example.com", + interface_language="en-US", + timezone="Asia/Shanghai", + ) @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py new file mode 100644 index 0000000000..36c707dbf9 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py @@ -0,0 +1,123 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.auth.oauth import OAuthLogin, _generate_account +from libs.oauth import OAuthUserInfo +from services.errors.account import AccountRegisterError + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +@patch("controllers.console.auth.oauth.redirect") +@patch("controllers.console.auth.oauth.get_oauth_providers") +def test_oauth_login_passes_language_and_timezone_to_authorization_url( + mock_get_oauth_providers, + mock_redirect, + app: Flask, +): + oauth_provider = MagicMock() + oauth_provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?state=..." + mock_get_oauth_providers.return_value = {"github": oauth_provider} + + with app.test_request_context("/oauth/login/github?language=zh-Hans&timezone=Asia/Shanghai"): + OAuthLogin().get("github") + + oauth_provider.get_authorization_url.assert_called_once_with( + invite_token=None, + timezone="Asia/Shanghai", + language="zh-Hans", + ) + mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?state=...") + + +@patch("controllers.console.auth.oauth.AccountService.link_account_integrate") +@patch("controllers.console.auth.oauth.RegisterService") +@patch("controllers.console.auth.oauth.FeatureService") +@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) +def test_generate_account_registers_with_browser_timezone( + mock_get_account, + mock_feature_service, + mock_register_service, + mock_link_account, + app: Flask, +): + account = MagicMock() + mock_register_service.register.return_value = account + mock_feature_service.get_system_features.return_value.is_allow_register = True + user_info = OAuthUserInfo(id="github-123", name="Test User", email="User@Example.com") + + with app.test_request_context(headers={"Accept-Language": "zh-Hans,zh;q=0.9"}): + result, oauth_new_user = _generate_account("github", user_info, timezone="Asia/Shanghai") + + assert result is account + assert oauth_new_user is True + mock_register_service.register.assert_called_once_with( + email="user@example.com", + name="Test User", + password=None, + open_id="github-123", + provider="github", + language="zh-Hans", + timezone="Asia/Shanghai", + ) + mock_link_account.assert_called_once_with("github", "github-123", account) + + +@patch("controllers.console.auth.oauth.AccountService.link_account_integrate") +@patch("controllers.console.auth.oauth.RegisterService") +@patch("controllers.console.auth.oauth.FeatureService") +@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) +def test_generate_account_prefers_state_language_over_accept_language( + mock_get_account, + mock_feature_service, + mock_register_service, + mock_link_account, + app: Flask, +): + account = MagicMock() + mock_register_service.register.return_value = account + mock_feature_service.get_system_features.return_value.is_allow_register = True + user_info = OAuthUserInfo(id="github-123", name="Test User", email="User@Example.com") + + with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}): + _generate_account("github", user_info, language="zh-Hans") + + mock_register_service.register.assert_called_once_with( + email="user@example.com", + name="Test User", + password=None, + open_id="github-123", + provider="github", + language="zh-Hans", + timezone=None, + ) + mock_link_account.assert_called_once_with("github", "github-123", account) + + +@patch("controllers.console.auth.oauth.dify_config") +@patch("controllers.console.auth.oauth.RegisterService") +@patch("controllers.console.auth.oauth.FeatureService") +@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) +def test_generate_account_rejects_new_user_when_registration_disabled( + mock_get_account, + mock_feature_service, + mock_register_service, + mock_config, + app: Flask, +): + mock_feature_service.get_system_features.return_value.is_allow_register = False + mock_config.BILLING_ENABLED = False + user_info = OAuthUserInfo(id="github-123", name="Test User", email="user@example.com") + + with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}): + with pytest.raises(AccountRegisterError): + _generate_account("github", user_info) + + mock_register_service.register.assert_not_called() diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index 8b47da25fb..b44706c566 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -14,6 +14,7 @@ from controllers.console.tag.tags import ( TagUpdateDeleteApi, ) from models.enums import TagType +from services.tag_service import UpdateTagPayload def unwrap(func): @@ -147,7 +148,7 @@ class TestTagUpdateDeleteApi: api = TagUpdateDeleteApi() method = unwrap(api.patch) - payload = {"name": "updated", "type": "knowledge"} + payload = {"name": "updated"} with app.test_request_context("/", json=payload): with ( @@ -159,7 +160,7 @@ class TestTagUpdateDeleteApi: patch( "controllers.console.tag.tags.TagService.update_tags", return_value=tag, - ), + ) as update_tags_mock, patch( "controllers.console.tag.tags.TagService.get_tag_binding_count", return_value=3, @@ -168,6 +169,9 @@ class TestTagUpdateDeleteApi: result, status = method(api, "tag-1") assert status == 200 + update_payload, tag_id = update_tags_mock.call_args.args + assert update_payload == UpdateTagPayload(name="updated") + assert tag_id == "tag-1" assert result["binding_count"] == "3" def test_patch_forbidden(self, app: Flask, readonly_user, payload_patch): diff --git a/api/tests/unit_tests/controllers/console/test_files.py b/api/tests/unit_tests/controllers/console/test_files.py index eebc6f9d60..acae081b98 100644 --- a/api/tests/unit_tests/controllers/console/test_files.py +++ b/api/tests/unit_tests/controllers/console/test_files.py @@ -278,7 +278,7 @@ class TestFileApiPost: class TestFilePreviewApi: - def test_get_preview(self, app, mock_file_service): + def test_get_preview(self, app, mock_account_context, mock_file_service): api = FilePreviewApi() get_method = unwrap(api.get) mock_file_service.get_file_preview.return_value = "preview text" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index c26011b758..9726c939e9 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -234,9 +234,19 @@ class TestAdvancedChatGenerateTaskPipeline: ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" + # Track database operations for verification + executed_statements = [] + @contextmanager def _fake_session(): - yield SimpleNamespace() + sess = SimpleNamespace() + + def _execute(stmt): + executed_statements.append(stmt) + return SimpleNamespace() + + sess.execute = _execute + yield sess monkeypatch.setattr(pipeline, "_database_session", _fake_session) monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace()) @@ -246,6 +256,14 @@ class TestAdvancedChatGenerateTaskPipeline: assert pipeline._workflow_run_id == "run-id" assert responses == ["started"] + # Verify database operation was executed + assert len(executed_statements) == 1 + # Verify the UPDATE statement targets the correct message and sets workflow_run_id + update_stmt = executed_statements[0] + stmt_str = str(update_stmt) + assert "UPDATE messages" in stmt_str + assert "WHERE messages.id" in stmt_str + def test_message_end_to_stream_response_strips_annotation_reply(self): pipeline = _make_pipeline() pipeline._task_state.metadata.annotation_reply = AnnotationReply( diff --git a/api/tests/unit_tests/core/rag/datasource/test_retrieval_attachment_access.py b/api/tests/unit_tests/core/rag/datasource/test_retrieval_attachment_access.py new file mode 100644 index 0000000000..426ffc498b --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/test_retrieval_attachment_access.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from types import SimpleNamespace +from uuid import uuid4 + +from sqlalchemy import select + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.file_access import ( + DatabaseFileAccessController, + FileAccessScope, + bind_file_access_scope, + get_current_file_access_scope, + grant_retriever_segment_access, + grant_upload_file_access, + is_retriever_segment_access_granted, +) +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.retrieval import dataset_retrieval as dataset_retrieval_module +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest +from models import UploadFile + + +class _ScalarResult: + def __init__(self, values): + self._values = values + + def all(self): + return self._values + + +class _AttachmentSession: + def __init__(self, upload_file, binding): + self._results = [ + _ScalarResult([upload_file]), + _ScalarResult([binding]), + ] + + def scalars(self, _stmt): + return self._results.pop(0) + + +class _DatasetRetrievalSession: + def __init__(self, datasets, documents): + self._results = [ + _ScalarResult(datasets), + _ScalarResult(documents), + ] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def scalars(self, _stmt): + return self._results.pop(0) + + +def test_file_access_grants_ignore_empty_inputs_and_missing_scope() -> None: + grant_upload_file_access(["upload-file-id"]) + grant_retriever_segment_access(["segment-id"]) + assert is_retriever_segment_access_granted("segment-id") is True + + scope = FileAccessScope( + tenant_id=str(uuid4()), + user_id=str(uuid4()), + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + with bind_file_access_scope(scope): + grant_upload_file_access([""]) + grant_retriever_segment_access([""]) + current_scope = get_current_file_access_scope() + + assert current_scope is not None + assert current_scope.granted_upload_file_ids == frozenset() + assert current_scope.granted_retriever_segment_ids == frozenset() + + +def test_segment_attachment_lookup_grants_returned_upload_files_to_current_scope() -> None: + tenant_id = str(uuid4()) + upload_file_id = str(uuid4()) + segment_id = str(uuid4()) + upload_file = SimpleNamespace( + id=upload_file_id, + name="chart.png", + extension="png", + mime_type="image/png", + size=1024, + ) + binding = SimpleNamespace(attachment_id=upload_file_id, segment_id=segment_id) + session = _AttachmentSession(upload_file, binding) + scope = FileAccessScope( + tenant_id=tenant_id, + user_id=str(uuid4()), + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + with bind_file_access_scope(scope): + result = RetrievalService.get_segment_attachment_infos([upload_file_id], session) # type: ignore[arg-type] + scoped_stmt = DatabaseFileAccessController().apply_upload_file_filters( + select(UploadFile).where(UploadFile.id == upload_file_id) + ) + + assert result[0]["attachment_id"] == upload_file_id + whereclause = str(scoped_stmt.whereclause) + assert "upload_files.created_by_role" in whereclause + assert "upload_files.id IN" in whereclause + + +def test_knowledge_retrieval_grants_returned_segments_to_current_scope(monkeypatch) -> None: + tenant_id = str(uuid4()) + dataset_id = str(uuid4()) + document_id = str(uuid4()) + segment_id = str(uuid4()) + segment = SimpleNamespace( + id=segment_id, + dataset_id=dataset_id, + document_id=document_id, + hit_count=1, + word_count=10, + position=1, + index_node_hash="hash", + answer=None, + get_sign_content=lambda: "segment content", + ) + record = SimpleNamespace(segment=segment, score=0.8, child_chunks=None, files=None, summary=None) + dataset = SimpleNamespace(id=dataset_id, name="Dataset") + document = SimpleNamespace(id=document_id, name="Document", data_source_type="upload_file", doc_metadata={}) + retrieval = DatasetRetrieval() + monkeypatch.setattr(retrieval, "_check_knowledge_rate_limit", lambda tenant_id: None) + monkeypatch.setattr(retrieval, "_get_available_datasets", lambda tenant_id, dataset_ids: [dataset]) + monkeypatch.setattr(retrieval, "multiple_retrieve", lambda **kwargs: [SimpleNamespace(provider="dify")]) + monkeypatch.setattr(RetrievalService, "format_retrieval_documents", lambda documents: [record]) + session = _DatasetRetrievalSession([dataset], [document]) + monkeypatch.setattr(dataset_retrieval_module.session_factory, "create_session", lambda: session) + scope = FileAccessScope( + tenant_id=tenant_id, + user_id=str(uuid4()), + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + with bind_file_access_scope(scope): + results = retrieval.knowledge_retrieval( + KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=str(uuid4()), + app_id=str(uuid4()), + user_from=UserFrom.END_USER.value, + dataset_ids=[dataset_id], + query="desktop picture", + retrieval_mode="multiple", + ) + ) + current_scope = get_current_file_access_scope() + + assert results[0].metadata.segment_id == segment_id + assert current_scope is not None + assert segment_id in current_scope.granted_retriever_segment_ids diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 7d23b63049..100b294f52 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -297,7 +297,7 @@ class TableTestRunner: max_workers: int = 4, enable_logging: bool = False, log_level: str = "INFO", - graph_engine_min_workers: int = 1, + graph_engine_min_workers: int = 3, graph_engine_max_workers: int = 1, graph_engine_scale_up_threshold: int = 5, graph_engine_scale_down_idle_time: float = 30.0, @@ -310,7 +310,7 @@ class TableTestRunner: max_workers: Maximum number of parallel workers for test execution enable_logging: Enable detailed logging log_level: Logging level (DEBUG, INFO, WARNING, ERROR) - graph_engine_min_workers: Minimum workers for GraphEngine (default: 1) + graph_engine_min_workers: Minimum workers for GraphEngine (default: 3) graph_engine_max_workers: Maximum workers for GraphEngine (default: 1) graph_engine_scale_up_threshold: Queue depth to trigger scale up graph_engine_scale_down_idle_time: Idle time before scaling down diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index ccb63f36d3..95b05a35a1 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -13,7 +13,8 @@ from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.code.entities import CodeLanguage from graphon.nodes.llm.entities import LLMNodeData from graphon.nodes.llm.node import LLMNode -from graphon.variables.segments import StringSegment +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.variables.segments import ArrayObjectSegment, StringSegment def _assert_constructor_node_data(data, *, node_id: str, node_type: NodeType, version: str = "1") -> None: @@ -430,6 +431,7 @@ class TestDifyNodeFactoryCreateNode: factory._http_request_config = sentinel.http_request_config factory._llm_credentials_provider = sentinel.credentials_provider factory._llm_model_factory = sentinel.model_factory + factory._build_retriever_attachment_loader = MagicMock(return_value=sentinel.retriever_attachment_loader) return factory def test_rejects_unknown_node_type(self, factory): @@ -777,6 +779,128 @@ class TestDifyNodeFactoryCreateNode: for key, value in expected_extra_kwargs.items(): assert constructor_kwargs[key] is value + def test_parameter_extractor_init_does_not_require_retriever_context(self, factory): + node_data = ParameterExtractorNodeData.model_validate( + { + "type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, + "title": "Parameter Extractor", + "model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}}, + "query": ["sys", "query"], + "parameters": [ + { + "name": "topic", + "type": "string", + "description": "Topic", + "required": True, + } + ], + "reasoning_mode": "prompt", + } + ) + factory._build_model_instance_for_llm_node = MagicMock(return_value=sentinel.model_instance) + factory._build_memory_for_llm_node = MagicMock(return_value=sentinel.memory) + factory._build_retriever_attachment_loader = MagicMock(side_effect=AssertionError("unexpected loader build")) + + kwargs = factory._build_llm_compatible_node_init_kwargs( + node_class=sentinel.node_class, + node_data=node_data, + wrap_model_instance=True, + include_http_client=False, + include_llm_file_saver=False, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, + ) + + assert "retriever_attachment_loader" not in kwargs + assert kwargs["prompt_message_serializer"] is sentinel.prompt_message_serializer + factory._build_retriever_attachment_loader.assert_not_called() + + +class TestDifyNodeFactoryRetrieverAttachmentAccess: + @pytest.fixture + def factory(self): + factory = object.__new__(node_factory.DifyNodeFactory) + factory.graph_runtime_state = SimpleNamespace(variable_pool=MagicMock()) + return factory + + def test_retriever_attachment_loader_is_typed_for_llm_node_data_only(self): + annotations = node_factory.DifyNodeFactory._build_retriever_attachment_loader.__annotations__ + + assert annotations["node_data"] is LLMNodeData + + def test_build_retriever_attachment_loader_uses_llm_context_selector(self, factory): + factory._file_reference_factory = sentinel.file_reference_factory + factory.graph_runtime_state.variable_pool.get.return_value = ArrayObjectSegment( + value=[ + { + "metadata": { + "_source": "knowledge", + "segment_id": "allowed-segment", + } + } + ] + ) + node_data = LLMNodeData.model_validate( + { + "type": BuiltinNodeTypes.LLM, + "title": "LLM", + "model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}}, + "prompt_template": [{"role": "system", "text": "x"}], + "context": {"enabled": True, "variable_selector": ["knowledge-node", "result"]}, + "vision": {"enabled": False}, + } + ) + + loader = factory._build_retriever_attachment_loader(node_data) + + assert loader._segment_access_checker is not None + assert loader._segment_access_checker("allowed-segment") is True + factory.graph_runtime_state.variable_pool.get.assert_called_once_with(["knowledge-node", "result"]) + + def test_checker_rejects_missing_context_selector_without_reading_variable_pool(self, factory): + checker = factory._build_retriever_segment_access_checker(None) + + assert checker("segment-id") is False + factory.graph_runtime_state.variable_pool.get.assert_not_called() + + def test_checker_rejects_non_knowledge_context_items(self, factory): + factory.graph_runtime_state.variable_pool.get.return_value = ArrayObjectSegment.model_construct( + value=[ + "plain-text", + {"metadata": "not-a-mapping"}, + ] + ) + + checker = factory._build_retriever_segment_access_checker(["knowledge-node", "result"]) + + assert checker("segment-id") is False + + def test_checker_rejects_non_array_context_value(self, factory): + factory.graph_runtime_state.variable_pool.get.return_value = StringSegment(value="not knowledge context") + + checker = factory._build_retriever_segment_access_checker(["knowledge-node", "result"]) + + assert checker("segment-id") is False + + def test_checker_allows_only_segments_from_selected_knowledge_context(self, factory): + factory.graph_runtime_state.variable_pool.get.return_value = ArrayObjectSegment( + value=[ + { + "metadata": { + "_source": "knowledge", + "segment_id": "allowed-segment", + } + } + ] + ) + + checker = factory._build_retriever_segment_access_checker(["knowledge-node", "result"]) + + assert checker("allowed-segment") is True + assert checker("other-segment") is False + factory.graph_runtime_state.variable_pool.get.assert_any_call(["knowledge-node", "result"]) + class TestDifyNodeFactoryModelInstance: @pytest.fixture diff --git a/api/tests/unit_tests/core/workflow/test_node_runtime.py b/api/tests/unit_tests/core/workflow/test_node_runtime.py index d2925fd1a8..5e83863dc2 100644 --- a/api/tests/unit_tests/core/workflow/test_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/test_node_runtime.py @@ -1,9 +1,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, sentinel +from uuid import uuid4 import pytest from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext, InvokeFrom, UserFrom +from core.app.file_access import FileAccessScope, bind_file_access_scope, grant_retriever_segment_access from core.llm_generator.output_parser.errors import OutputParserError from core.workflow import node_runtime from core.workflow.file_reference import parse_file_reference @@ -268,6 +270,114 @@ def test_dify_retriever_attachment_loader_builds_graph_files(monkeypatch: pytest assert parse_file_reference(mapping["reference"]).storage_key is None +def test_dify_retriever_attachment_loader_grants_upload_files_for_allowed_segment( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from factories.file_factory import builders as file_builders + + upload_file_id = str(uuid4()) + segment_id = str(uuid4()) + upload_file = SimpleNamespace( + id=upload_file_id, + tenant_id="tenant-id", + name="diagram.png", + extension="png", + mime_type="image/png", + source_url="https://example.com/diagram.png", + key="storage-key", + size=128, + ) + attachment_session = MagicMock() + attachment_session.execute.return_value.all.return_value = [(None, upload_file)] + + class _AttachmentSessionContext: + def __enter__(self): + return attachment_session + + def __exit__(self, exc_type, exc, tb): + return False + + upload_session = MagicMock() + upload_session.__enter__.return_value = upload_session + upload_session.__exit__.return_value = False + upload_session.scalar.return_value = upload_file + + monkeypatch.setattr(node_runtime, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(node_runtime, "Session", MagicMock(return_value=_AttachmentSessionContext())) + monkeypatch.setattr(file_builders, "session_factory", SimpleNamespace(create_session=lambda: upload_session)) + + loader = DifyRetrieverAttachmentLoader(file_reference_factory=DifyFileReferenceFactory(_build_run_context())) + scope = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + with bind_file_access_scope(scope): + grant_retriever_segment_access([segment_id]) + files = loader.load(segment_id=segment_id) + + assert files[0].related_id == upload_file_id + stmt = upload_session.scalar.call_args.args[0] + whereclause = str(stmt.whereclause) + assert "upload_files.tenant_id" in whereclause + assert "upload_files.id IN" in whereclause + + +def test_dify_retriever_attachment_loader_skips_ungranted_segment_for_end_user( + monkeypatch: pytest.MonkeyPatch, +) -> None: + build_from_mapping = MagicMock() + session_factory = MagicMock() + monkeypatch.setattr(node_runtime, "Session", session_factory) + loader = DifyRetrieverAttachmentLoader( + file_reference_factory=SimpleNamespace(build_from_mapping=build_from_mapping) + ) + scope = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + with bind_file_access_scope(scope): + files = loader.load(segment_id=str(uuid4())) + + assert files == [] + session_factory.assert_not_called() + build_from_mapping.assert_not_called() + + +def test_dify_retriever_attachment_loader_skips_segment_rejected_by_checker( + monkeypatch: pytest.MonkeyPatch, +) -> None: + segment_id = str(uuid4()) + build_from_mapping = MagicMock() + session_factory = MagicMock() + segment_access_checker = MagicMock(return_value=False) + monkeypatch.setattr(node_runtime, "Session", session_factory) + loader = DifyRetrieverAttachmentLoader( + file_reference_factory=SimpleNamespace(build_from_mapping=build_from_mapping), + segment_access_checker=segment_access_checker, + ) + scope = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + + with bind_file_access_scope(scope): + grant_retriever_segment_access([segment_id]) + files = loader.load(segment_id=segment_id) + + assert files == [] + segment_access_checker.assert_called_once_with(segment_id) + session_factory.assert_not_called() + build_from_mapping.assert_not_called() + + def test_dify_tool_file_manager_resolves_conversation_id_for_tool_files(monkeypatch: pytest.MonkeyPatch) -> None: create_file_by_raw = MagicMock(return_value=SimpleNamespace(id="tool-file-id")) manager_instance = SimpleNamespace(create_file_by_raw=create_file_by_raw) diff --git a/api/tests/unit_tests/libs/test_oauth_base.py b/api/tests/unit_tests/libs/test_oauth_base.py index 7b7f086dac..1c0066ed9a 100644 --- a/api/tests/unit_tests/libs/test_oauth_base.py +++ b/api/tests/unit_tests/libs/test_oauth_base.py @@ -1,6 +1,6 @@ import pytest -from libs.oauth import OAuth +from libs.oauth import OAuth, decode_oauth_state, encode_oauth_state def test_oauth_base_methods_raise_not_implemented(): @@ -17,3 +17,17 @@ def test_oauth_base_methods_raise_not_implemented(): with pytest.raises(NotImplementedError): oauth._transform_user_info({}) + + +def test_oauth_state_round_trips_invite_token_timezone_and_language(): + state = encode_oauth_state(invite_token="invite-123", timezone="Asia/Shanghai", language="zh-Hans") + + assert decode_oauth_state(state) == { + "invite_token": "invite-123", + "timezone": "Asia/Shanghai", + "language": "zh-Hans", + } + + +def test_oauth_state_returns_empty_payload_for_invalid_state(): + assert decode_oauth_state("invalid-state") == {} diff --git a/api/tests/unit_tests/libs/test_oauth_clients.py b/api/tests/unit_tests/libs/test_oauth_clients.py index 830284e697..b3ecc5a06d 100644 --- a/api/tests/unit_tests/libs/test_oauth_clients.py +++ b/api/tests/unit_tests/libs/test_oauth_clients.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch import httpx import pytest -from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo +from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state class BaseOAuthTest: @@ -37,15 +37,25 @@ class TestGitHubOAuth(BaseOAuthTest): return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"]) @pytest.mark.parametrize( - ("invite_token", "expected_state"), + ("invite_token", "timezone", "language", "expected_state"), [ - (None, None), - ("test_invite_token", "test_invite_token"), - ("", None), + (None, None, None, None), + ("test_invite_token", None, None, {"invite_token": "test_invite_token"}), + ("", None, None, None), + (None, "Asia/Shanghai", None, {"timezone": "Asia/Shanghai"}), + (None, None, "zh-Hans", {"language": "zh-Hans"}), + ( + "test_invite_token", + "Asia/Shanghai", + "zh-Hans", + {"invite_token": "test_invite_token", "timezone": "Asia/Shanghai", "language": "zh-Hans"}, + ), ], ) - def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state): - url = oauth.get_authorization_url(invite_token) + def test_should_generate_authorization_url_correctly( + self, oauth, oauth_config, invite_token, timezone, language, expected_state + ): + url = oauth.get_authorization_url(invite_token, timezone=timezone, language=language) parsed, params = self.parse_auth_url(url) assert parsed.scheme == "https" @@ -56,7 +66,7 @@ class TestGitHubOAuth(BaseOAuthTest): assert params["scope"][0] == "user:email" if expected_state: - assert params["state"][0] == expected_state + assert decode_oauth_state(params["state"][0]) == expected_state else: assert "state" not in params @@ -208,15 +218,25 @@ class TestGoogleOAuth(BaseOAuthTest): return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"]) @pytest.mark.parametrize( - ("invite_token", "expected_state"), + ("invite_token", "timezone", "language", "expected_state"), [ - (None, None), - ("test_invite_token", "test_invite_token"), - ("", None), + (None, None, None, None), + ("test_invite_token", None, None, {"invite_token": "test_invite_token"}), + ("", None, None, None), + (None, "Asia/Shanghai", None, {"timezone": "Asia/Shanghai"}), + (None, None, "zh-Hans", {"language": "zh-Hans"}), + ( + "test_invite_token", + "Asia/Shanghai", + "zh-Hans", + {"invite_token": "test_invite_token", "timezone": "Asia/Shanghai", "language": "zh-Hans"}, + ), ], ) - def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state): - url = oauth.get_authorization_url(invite_token) + def test_should_generate_authorization_url_correctly( + self, oauth, oauth_config, invite_token, timezone, language, expected_state + ): + url = oauth.get_authorization_url(invite_token, timezone=timezone, language=language) parsed, params = self.parse_auth_url(url) assert parsed.scheme == "https" @@ -228,7 +248,7 @@ class TestGoogleOAuth(BaseOAuthTest): assert params["scope"][0] == "openid email" if expected_state: - assert params["state"][0] == expected_state + assert decode_oauth_state(params["state"][0]) == expected_state else: assert "state" not in params diff --git a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py index 337659b15f..e72ebb4907 100644 --- a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py +++ b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py @@ -8,11 +8,12 @@ from sqlalchemy.orm import Session from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from graphon.enums import BuiltinNodeTypes +from services.dsl_version import check_version_compatibility from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity +from services.rag_pipeline import rag_pipeline_dsl_service from services.rag_pipeline.rag_pipeline_dsl_service import ( ImportStatus, RagPipelineDslService, - _check_version_compatibility, ) @@ -26,7 +27,9 @@ from services.rag_pipeline.rag_pipeline_dsl_service import ( ], ) def test_check_version_compatibility(imported_version: str, expected_status: ImportStatus) -> None: - assert _check_version_compatibility(imported_version) == expected_status + assert ( + check_version_compatibility(imported_version, rag_pipeline_dsl_service.CURRENT_DSL_VERSION) == expected_status + ) def test_encrypt_decrypt_dataset_id_roundtrip() -> None: @@ -259,6 +262,60 @@ workflow: if result.status == ImportStatus.FAILED: print(f"DEBUG: {result.error}") assert result.status == ImportStatus.COMPLETED + session.commit.assert_not_called() + session.flush.assert_called() + + +def test_import_rag_pipeline_flushes_new_collection_binding_without_commit(mocker) -> None: + yaml_content = """ +version: 0.1.0 +kind: rag_pipeline +rag_pipeline: + name: Test Pipeline +workflow: + graph: + nodes: + - data: + type: knowledge-index +""" + pipeline = Mock(id="p1", description="desc", is_published=False) + pipeline.name = "Test Pipeline" + mocker.patch.object(RagPipelineDslService, "_create_or_update_pipeline", return_value=pipeline) + + config_mock = Mock() + config_mock.indexing_technique = "high_quality" + config_mock.embedding_model = "m" + config_mock.embedding_model_provider = "p" + config_mock.chunk_structure = "text_model" + config_mock.retrieval_model.model_dump.return_value = {} + config_mock.summary_index_setting = None + mocker.patch( + "services.rag_pipeline.rag_pipeline_dsl_service.KnowledgeConfiguration.model_validate", + return_value=config_mock, + ) + + dataset_mock = Mock(id="d1") + binding_mock = Mock(id="b1") + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock) + binding_cls = mocker.patch( + "services.rag_pipeline.rag_pipeline_dsl_service.DatasetCollectionBinding", + return_value=binding_mock, + ) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) + + session = cast(MagicMock, Mock()) + session.scalar.return_value = None + session.scalars.return_value.all.return_value = [] + service = RagPipelineDslService(session=cast(Session, session)) + account = Mock(current_tenant_id="t1") + + result = service.import_rag_pipeline(account=account, import_mode="yaml-content", yaml_content=yaml_content) + + assert result.status == ImportStatus.COMPLETED + binding_cls.assert_called_once() + assert dataset_mock.collection_binding_id == "b1" + session.commit.assert_not_called() + assert session.flush.call_count >= 2 def test_import_rag_pipeline_pending_version(mocker) -> None: @@ -338,6 +395,67 @@ workflow: assert result.dataset_id == "d1" +def test_confirm_import_flushes_new_collection_binding_without_commit(mocker) -> None: + from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelinePendingData + + yaml_content = """ +version: 0.1.0 +kind: rag_pipeline +rag_pipeline: + name: Test Pipeline +workflow: + graph: + nodes: + - data: + type: knowledge-index +""" + pending = RagPipelinePendingData(import_mode="yaml-content", yaml_content=yaml_content, pipeline_id="p1") + mocker.patch( + "services.rag_pipeline.rag_pipeline_dsl_service.redis_client.get", + return_value=pending.model_dump_json(), + ) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.redis_client.delete") + + pipeline = Mock(id="p1", description="desc") + pipeline.name = "Test Pipeline" + pipeline.retrieve_dataset.return_value = None + mocker.patch.object(RagPipelineDslService, "_create_or_update_pipeline", return_value=pipeline) + + config_mock = Mock() + config_mock.indexing_technique = "high_quality" + config_mock.embedding_model = "m" + config_mock.embedding_model_provider = "p" + config_mock.chunk_structure = "text_model" + config_mock.retrieval_model.model_dump.return_value = {} + config_mock.summary_index_setting = None + mocker.patch( + "services.rag_pipeline.rag_pipeline_dsl_service.KnowledgeConfiguration.model_validate", + return_value=config_mock, + ) + + dataset_mock = Mock(id="d1") + binding_mock = Mock(id="b1") + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock) + binding_cls = mocker.patch( + "services.rag_pipeline.rag_pipeline_dsl_service.DatasetCollectionBinding", + return_value=binding_mock, + ) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) + + session = cast(MagicMock, Mock()) + session.scalar.side_effect = [pipeline, None] + service = RagPipelineDslService(session=cast(Session, session)) + account = Mock(id="u1", current_tenant_id="t1") + + result = service.confirm_import(account=account, import_id="imp-1") + + assert result.status == ImportStatus.COMPLETED + binding_cls.assert_called_once() + assert dataset_mock.collection_binding_id == "b1" + session.commit.assert_not_called() + assert session.flush.call_count >= 2 + + # --- _extract_dependencies_from_workflow_graph all types --- @@ -421,6 +539,8 @@ def test_create_or_update_pipeline_create_new(mocker) -> None: assert result == pipeline_instance session.add.assert_called() + session.commit.assert_not_called() + session.flush.assert_called() # --- export_rag_pipeline_dsl comprehensive --- @@ -984,7 +1104,7 @@ def test_extract_dependencies_from_model_config_includes_dataset_reranking_and_t def test_check_version_compatibility_hits_major_older_branch(mocker) -> None: mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.CURRENT_DSL_VERSION", "1.0.0") - status = _check_version_compatibility("0.9.0") + status = check_version_compatibility("0.9.0", rag_pipeline_dsl_service.CURRENT_DSL_VERSION) assert status == ImportStatus.PENDING diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 02013392fc..8c554e012d 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -260,7 +260,7 @@ class TestAccountService: assert result.interface_theme == "light" assert result.password is not None assert result.password_salt is not None - assert result.timezone is not None + assert result.timezone == "America/New_York" # Verify database operations mock_db_dependencies["db"].session.add.assert_called_once() @@ -271,7 +271,28 @@ class TestAccountService: assert added_account.interface_theme == "light" assert added_account.password is not None assert added_account.password_salt is not None - assert added_account.timezone is not None + assert added_account.timezone == "America/New_York" + self._assert_database_operations_called(mock_db_dependencies["db"]) + + def test_create_account_uses_explicit_timezone( + self, mock_db_dependencies, mock_password_dependencies, mock_external_service_dependencies + ): + """Test account creation prefers explicit browser timezone.""" + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_password_dependencies["hash_password"].return_value = b"hashed_password" + + result = AccountService.create_account( + email="test@example.com", + name="Test User", + interface_language="en-US", + password="password123", + timezone="Asia/Shanghai", + ) + + assert result.timezone == "Asia/Shanghai" + added_account = mock_db_dependencies["db"].session.add.call_args[0][0] + assert added_account.timezone == "Asia/Shanghai" self._assert_database_operations_called(mock_db_dependencies["db"]) def test_create_account_registration_disabled(self, mock_external_service_dependencies): @@ -1221,6 +1242,7 @@ class TestRegisterService: interface_language="en-US", password="password123", is_setup=False, + timezone=None, ) mock_create_tenant.assert_called_once_with("Test User's Workspace") mock_create_member.assert_called_once_with(mock_tenant, mock_account, role="owner") diff --git a/api/tests/unit_tests/services/test_file_service.py b/api/tests/unit_tests/services/test_file_service.py index 8e1b22886b..69bd194a68 100644 --- a/api/tests/unit_tests/services/test_file_service.py +++ b/api/tests/unit_tests/services/test_file_service.py @@ -221,7 +221,7 @@ class TestFileService: mock_extract.return_value = "Extracted text content" # Execute - result = file_service.get_file_preview("file_id") + result = file_service.get_file_preview("file_id", "tenant_id") # Assert assert result == "Extracted text content" @@ -229,7 +229,7 @@ class TestFileService: def test_get_file_preview_not_found(self, file_service, mock_db_session): mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): - file_service.get_file_preview("non_existent") + file_service.get_file_preview("non_existent", "tenant_id") def test_get_file_preview_unsupported_type(self, file_service, mock_db_session): upload_file = MagicMock(spec=UploadFile) @@ -237,7 +237,7 @@ class TestFileService: upload_file.extension = "exe" mock_db_session.scalar.return_value = upload_file with pytest.raises(UnsupportedFileTypeError): - file_service.get_file_preview("file_id") + file_service.get_file_preview("file_id", "tenant_id") def test_get_image_preview_success(self, file_service, mock_db_session): # Setup diff --git a/api/tests/unit_tests/tasks/test_document_indexing_update_task.py b/api/tests/unit_tests/tasks/test_document_indexing_update_task.py new file mode 100644 index 0000000000..b73275b97d --- /dev/null +++ b/api/tests/unit_tests/tasks/test_document_indexing_update_task.py @@ -0,0 +1,524 @@ +""" +Unit tests for document_indexing_update_task summary generation. + +After updating a document via the API, the summary index should be +regenerated under the same conditions as during initial creation: +- indexing_technique is HIGH_QUALITY +- summary_index_setting has enable=True +- document.indexing_status is COMPLETED +- document.doc_form is not QA_INDEX +- document.need_summary is True +""" + +from contextlib import nullcontext +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.indexing_runner import DocumentIsPausedError +from tasks.document_indexing_update_task import document_indexing_update_task + + +class _SessionContext: + """Minimal context manager that yields a mock session.""" + + def __init__(self, session: MagicMock) -> None: + self._session = session + + def __enter__(self) -> MagicMock: + return self._session + + def __exit__(self, exc_type, exc, tb) -> None: # type: ignore[override] + return None + + +def _make_dataset_and_documents( + *, + dataset_id: str = "ds-1", + document_id: str = "doc-1", + indexing_technique: str = "high_quality", + summary_index_setting: dict | None = None, + doc_form: str = "text_model", + need_summary: bool = True, +): + """Create mock dataset and document objects. + + Returns (dataset, doc_for_session1, doc_for_session3). + + session1 doc: before IndexingRunner runs (status irrelevant for summary). + session3 doc: re-queried after IndexingRunner completes — normally COMPLETED. + """ + dataset = SimpleNamespace( + id=dataset_id, + indexing_technique=indexing_technique, + summary_index_setting=summary_index_setting, + ) + doc_s1 = SimpleNamespace( + id=document_id, + dataset_id=dataset_id, + indexing_status="waiting", + doc_form=doc_form, + need_summary=need_summary, + ) + # After IndexingRunner.run the document status is COMPLETED in the DB + doc_s3 = SimpleNamespace( + id=document_id, + dataset_id=dataset_id, + indexing_status="completed", + doc_form=doc_form, + need_summary=need_summary, + ) + return dataset, doc_s1, doc_s3 + + +def _patch_all(monkeypatch: pytest.MonkeyPatch, *, sessions, runner, processor): + """Wire up all mocks for document_indexing_update_task.""" + monkeypatch.setattr( + "tasks.document_indexing_update_task.session_factory.create_session", + MagicMock(side_effect=sessions), + ) + monkeypatch.setattr( + "tasks.document_indexing_update_task.IndexProcessorFactory", + MagicMock(return_value=MagicMock(init_index_processor=MagicMock(return_value=processor))), + ) + monkeypatch.setattr( + "tasks.document_indexing_update_task.IndexingRunner", + MagicMock(return_value=runner), + ) + + +def _session_with_begin(): + """Create a mock session with a begin() context manager.""" + s = MagicMock() + s.begin.return_value = nullcontext() + return s + + +class TestUpdateTaskSummaryGeneration: + """Tests for summary index generation in the document update task. + + The update task creates sessions in this order: + 1. session1: fetch document + dataset + segments (uses begin()) + 2. session2: delete segments — only if segments exist (uses begin()) + 3. session3: summary check — only if indexing succeeded (no begin()) + + With empty segments (default), only sessions 1 and 3 are created. + """ + + def test_should_queue_summary_when_conditions_met(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Summary task is queued when all conditions are met.""" + dataset, doc_s1, doc_s3 = _make_dataset_and_documents( + summary_index_setting={"enable": True}, + ) + + session1 = _session_with_begin() + session1.scalar.side_effect = [doc_s1, dataset] + session1.scalars.return_value = MagicMock(all=MagicMock(return_value=[])) + + session3 = MagicMock() + session3.scalar.side_effect = [dataset, doc_s3] + + runner = MagicMock() + processor = MagicMock() + + _patch_all( + monkeypatch, + sessions=[_SessionContext(session1), _SessionContext(session3)], + runner=runner, + processor=processor, + ) + + delay_mock = MagicMock() + monkeypatch.setattr( + "tasks.document_indexing_update_task.generate_summary_index_task.delay", + delay_mock, + ) + + document_indexing_update_task("ds-1", "doc-1") + + delay_mock.assert_called_once_with("ds-1", "doc-1", None) + + def test_should_not_queue_when_not_high_quality(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Summary is skipped when indexing_technique is not high_quality.""" + dataset, doc_s1, _ = _make_dataset_and_documents( + indexing_technique="economy", + summary_index_setting={"enable": True}, + ) + + session1 = _session_with_begin() + session1.scalar.side_effect = [doc_s1, dataset] + session1.scalars.return_value = MagicMock(all=MagicMock(return_value=[])) + + session3 = MagicMock() + session3.scalar.return_value = dataset # dataset.indexing_technique == "economy" + + runner = MagicMock() + processor = MagicMock() + + _patch_all( + monkeypatch, + sessions=[_SessionContext(session1), _SessionContext(session3)], + runner=runner, + processor=processor, + ) + + delay_mock = MagicMock() + monkeypatch.setattr( + "tasks.document_indexing_update_task.generate_summary_index_task.delay", + delay_mock, + ) + + document_indexing_update_task("ds-1", "doc-1") + + delay_mock.assert_not_called() + + def test_should_not_queue_when_summary_setting_disabled(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Summary is skipped when summary_index_setting has enable=False.""" + dataset, doc_s1, _ = _make_dataset_and_documents( + summary_index_setting={"enable": False}, + ) + + session1 = _session_with_begin() + session1.scalar.side_effect = [doc_s1, dataset] + session1.scalars.return_value = MagicMock(all=MagicMock(return_value=[])) + + session3 = MagicMock() + session3.scalar.return_value = dataset + + runner = MagicMock() + processor = MagicMock() + + _patch_all( + monkeypatch, + sessions=[_SessionContext(session1), _SessionContext(session3)], + runner=runner, + processor=processor, + ) + + delay_mock = MagicMock() + monkeypatch.setattr( + "tasks.document_indexing_update_task.generate_summary_index_task.delay", + delay_mock, + ) + + document_indexing_update_task("ds-1", "doc-1") + + delay_mock.assert_not_called() + + def test_should_not_queue_when_summary_setting_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Summary is skipped when summary_index_setting is None.""" + dataset, doc_s1, _ = _make_dataset_and_documents( + summary_index_setting=None, + ) + + session1 = _session_with_begin() + session1.scalar.side_effect = [doc_s1, dataset] + session1.scalars.return_value = MagicMock(all=MagicMock(return_value=[])) + + session3 = MagicMock() + session3.scalar.return_value = dataset + + runner = MagicMock() + processor = MagicMock() + + _patch_all( + monkeypatch, + sessions=[_SessionContext(session1), _SessionContext(session3)], + runner=runner, + processor=processor, + ) + + delay_mock = MagicMock() + monkeypatch.setattr( + "tasks.document_indexing_update_task.generate_summary_index_task.delay", + delay_mock, + ) + + document_indexing_update_task("ds-1", "doc-1") + + delay_mock.assert_not_called() + + def test_should_not_queue_when_need_summary_false(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Summary is skipped when document.need_summary is False.""" + dataset, doc_s1, doc_s3 = _make_dataset_and_documents( + summary_index_setting={"enable": True}, + need_summary=False, + ) + + session1 = _session_with_begin() + session1.scalar.side_effect = [doc_s1, dataset] + session1.scalars.return_value = MagicMock(all=MagicMock(return_value=[])) + + session3 = MagicMock() + session3.scalar.side_effect = [dataset, doc_s3] + + runner = MagicMock() + processor = MagicMock() + + _patch_all( + monkeypatch, + sessions=[_SessionContext(session1), _SessionContext(session3)], + runner=runner, + processor=processor, + ) + + delay_mock = MagicMock() + monkeypatch.setattr( + "tasks.document_indexing_update_task.generate_summary_index_task.delay", + delay_mock, + ) + + document_indexing_update_task("ds-1", "doc-1") + + delay_mock.assert_not_called() + + def test_should_not_queue_when_qa_index_form(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Summary is skipped when doc_form is QA_INDEX.""" + dataset, doc_s1, doc_s3 = _make_dataset_and_documents( + summary_index_setting={"enable": True}, + doc_form="qa_model", + ) + + session1 = _session_with_begin() + session1.scalar.side_effect = [doc_s1, dataset] + session1.scalars.return_value = MagicMock(all=MagicMock(return_value=[])) + + session3 = MagicMock() + session3.scalar.side_effect = [dataset, doc_s3] + + runner = MagicMock() + processor = MagicMock() + + _patch_all( + monkeypatch, + sessions=[_SessionContext(session1), _SessionContext(session3)], + runner=runner, + processor=processor, + ) + + delay_mock = MagicMock() + monkeypatch.setattr( + "tasks.document_indexing_update_task.generate_summary_index_task.delay", + delay_mock, + ) + + document_indexing_update_task("ds-1", "doc-1") + + delay_mock.assert_not_called() + + def test_should_not_queue_when_indexing_fails(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Summary is skipped when IndexingRunner.run raises.""" + dataset, doc_s1, _ = _make_dataset_and_documents( + summary_index_setting={"enable": True}, + ) + + session1 = _session_with_begin() + session1.scalar.side_effect = [doc_s1, dataset] + session1.scalars.return_value = MagicMock(all=MagicMock(return_value=[])) + + runner = MagicMock() + runner.run.side_effect = Exception("indexing failed") + processor = MagicMock() + + # Only session1 needed — task returns early after indexing failure + _patch_all( + monkeypatch, + sessions=[_SessionContext(session1)], + runner=runner, + processor=processor, + ) + + delay_mock = MagicMock() + monkeypatch.setattr( + "tasks.document_indexing_update_task.generate_summary_index_task.delay", + delay_mock, + ) + + document_indexing_update_task("ds-1", "doc-1") + + delay_mock.assert_not_called() + + def test_should_not_queue_when_document_is_paused(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Summary is skipped when IndexingRunner raises DocumentIsPausedError.""" + + dataset, doc_s1, _ = _make_dataset_and_documents( + summary_index_setting={"enable": True}, + ) + + session1 = _session_with_begin() + session1.scalar.side_effect = [doc_s1, dataset] + session1.scalars.return_value = MagicMock(all=MagicMock(return_value=[])) + + runner = MagicMock() + runner.run.side_effect = DocumentIsPausedError("doc-1 is paused") + processor = MagicMock() + + # Only session1 needed — task returns early after paused error + _patch_all( + monkeypatch, + sessions=[_SessionContext(session1)], + runner=runner, + processor=processor, + ) + + delay_mock = MagicMock() + monkeypatch.setattr( + "tasks.document_indexing_update_task.generate_summary_index_task.delay", + delay_mock, + ) + + document_indexing_update_task("ds-1", "doc-1") + + delay_mock.assert_not_called() + + def test_should_not_queue_when_dataset_not_found_after_indexing(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Summary is skipped when the dataset disappears after indexing.""" + dataset, doc_s1, _ = _make_dataset_and_documents( + summary_index_setting={"enable": True}, + ) + + session1 = _session_with_begin() + session1.scalar.side_effect = [doc_s1, dataset] + session1.scalars.return_value = MagicMock(all=MagicMock(return_value=[])) + + # Session 3: dataset is None + session3 = MagicMock() + session3.scalar.return_value = None + + runner = MagicMock() + processor = MagicMock() + + _patch_all( + monkeypatch, + sessions=[_SessionContext(session1), _SessionContext(session3)], + runner=runner, + processor=processor, + ) + + delay_mock = MagicMock() + monkeypatch.setattr( + "tasks.document_indexing_update_task.generate_summary_index_task.delay", + delay_mock, + ) + + document_indexing_update_task("ds-1", "doc-1") + + delay_mock.assert_not_called() + + def test_should_not_queue_when_document_not_completed_after_indexing(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Summary is skipped when document indexing_status is not COMPLETED after indexing.""" + dataset, doc_s1, _ = _make_dataset_and_documents( + summary_index_setting={"enable": True}, + ) + + session1 = _session_with_begin() + session1.scalar.side_effect = [doc_s1, dataset] + session1.scalars.return_value = MagicMock(all=MagicMock(return_value=[])) + + # Document still in error status after indexing + doc_s3_error = SimpleNamespace( + id="doc-1", + dataset_id="ds-1", + indexing_status="error", + doc_form="text_model", + need_summary=True, + ) + session3 = MagicMock() + session3.scalar.side_effect = [dataset, doc_s3_error] + + runner = MagicMock() + processor = MagicMock() + + _patch_all( + monkeypatch, + sessions=[_SessionContext(session1), _SessionContext(session3)], + runner=runner, + processor=processor, + ) + + delay_mock = MagicMock() + monkeypatch.setattr( + "tasks.document_indexing_update_task.generate_summary_index_task.delay", + delay_mock, + ) + + document_indexing_update_task("ds-1", "doc-1") + + delay_mock.assert_not_called() + + def test_should_swallow_summary_queue_error(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Task should not raise when generate_summary_index_task.delay raises.""" + dataset, doc_s1, doc_s3 = _make_dataset_and_documents( + summary_index_setting={"enable": True}, + ) + + session1 = _session_with_begin() + session1.scalar.side_effect = [doc_s1, dataset] + session1.scalars.return_value = MagicMock(all=MagicMock(return_value=[])) + + session3 = MagicMock() + session3.scalar.side_effect = [dataset, doc_s3] + + runner = MagicMock() + processor = MagicMock() + + _patch_all( + monkeypatch, + sessions=[_SessionContext(session1), _SessionContext(session3)], + runner=runner, + processor=processor, + ) + + delay_mock = MagicMock(side_effect=Exception("queue full")) + monkeypatch.setattr( + "tasks.document_indexing_update_task.generate_summary_index_task.delay", + delay_mock, + ) + + # Should not raise + document_indexing_update_task("ds-1", "doc-1") + + delay_mock.assert_called_once_with("ds-1", "doc-1", None) + + def test_should_queue_summary_with_segments_and_session2(self, monkeypatch: pytest.MonkeyPatch) -> None: + """When segments exist, session2 is also created for deletion. + Verify summary generation still works correctly.""" + dataset, doc_s1, doc_s3 = _make_dataset_and_documents( + summary_index_setting={"enable": True}, + ) + + session1 = _session_with_begin() + session1.scalar.side_effect = [doc_s1, dataset] + seg = SimpleNamespace(index_node_id="node-1") + session1.scalars.return_value = MagicMock(all=MagicMock(return_value=[seg])) + + # Session 2: segment deletion + session2 = _session_with_begin() + + session3 = MagicMock() + session3.scalar.side_effect = [dataset, doc_s3] + + runner = MagicMock() + processor = MagicMock() + + _patch_all( + monkeypatch, + sessions=[ + _SessionContext(session1), + _SessionContext(session2), + _SessionContext(session3), + ], + runner=runner, + processor=processor, + ) + + delay_mock = MagicMock() + monkeypatch.setattr( + "tasks.document_indexing_update_task.generate_summary_index_task.delay", + delay_mock, + ) + + document_indexing_update_task("ds-1", "doc-1") + + delay_mock.assert_called_once_with("ds-1", "doc-1", None) diff --git a/api/uv.lock b/api/uv.lock index 0e4f3170b5..bb150e8cf3 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1332,7 +1332,6 @@ dependencies = [ { name = "boto3" }, { name = "celery" }, { name = "croniter" }, - { name = "dify-agent" }, { name = "fastopenapi", extra = ["flask"] }, { name = "flask" }, { name = "flask-compress" }, @@ -1373,6 +1372,7 @@ dev = [ { name = "boto3-stubs" }, { name = "celery-types" }, { name = "coverage" }, + { name = "dify-agent" }, { name = "dotenv-linter" }, { name = "faker" }, { name = "hypothesis" }, @@ -1615,7 +1615,6 @@ requires-dist = [ { name = "boto3", specifier = ">=1.43.6" }, { name = "celery", specifier = ">=5.6.3" }, { name = "croniter", specifier = ">=6.2.2" }, - { name = "dify-agent", directory = "../dify-agent" }, { name = "fastopenapi", extras = ["flask"], specifier = "~=0.7.0" }, { name = "flask", specifier = ">=3.1.3,<4.0.0" }, { name = "flask-compress", specifier = ">=1.24,<2.0.0" }, @@ -1656,6 +1655,7 @@ dev = [ { name = "boto3-stubs", specifier = ">=1.43.2" }, { name = "celery-types", specifier = ">=0.23.0" }, { name = "coverage", specifier = ">=7.13.4" }, + { name = "dify-agent", directory = "../dify-agent" }, { name = "dotenv-linter", specifier = ">=0.7.0" }, { name = "faker", specifier = ">=40.15.0" }, { name = "hypothesis", specifier = ">=6.152.4" }, diff --git a/docker/envs/core-services/shared.env.example b/docker/envs/core-services/shared.env.example index 80cfe42c38..fca0b57d0c 100644 --- a/docker/envs/core-services/shared.env.example +++ b/docker/envs/core-services/shared.env.example @@ -177,7 +177,7 @@ WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 MAX_VARIABLE_SIZE=204800 WORKFLOW_FILE_UPLOAD_LIMIT=10 -GRAPH_ENGINE_MIN_WORKERS=1 +GRAPH_ENGINE_MIN_WORKERS=3 GRAPH_ENGINE_MAX_WORKERS=10 GRAPH_ENGINE_SCALE_UP_THRESHOLD=3 GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0 diff --git a/eslint-suppressions.json b/eslint-suppressions.json index 1f7e82aee0..e169248299 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -1422,14 +1422,6 @@ "count": 1 } }, - "web/app/components/base/message-log-modal/index.tsx": { - "react/set-state-in-effect": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 1 - } - }, "web/app/components/base/new-audio-button/index.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2281,11 +2273,6 @@ "count": 1 } }, - "web/app/components/header/account-setting/members-page/invite-modal/index.tsx": { - "react/set-state-in-effect": { - "count": 3 - } - }, "web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx": { "erasable-syntax-only/enums": { "count": 1 diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index 699d2a4348..2467f35b7b 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -2,6 +2,7 @@ import type { ReactNode } from 'react' import * as React from 'react' import { AppInitializer } from '@/app/components/app-initializer' import InSiteMessageNotification from '@/app/components/app/in-site-message/notification' +import AmplitudeProvider from '@/app/components/base/amplitude' import GA, { GaType } from '@/app/components/base/ga' import Zendesk from '@/app/components/base/zendesk' import { GotoAnything } from '@/app/components/goto-anything' @@ -19,6 +20,7 @@ const Layout = ({ children }: { children: ReactNode }) => { return ( <> + diff --git a/web/app/account/(commonLayout)/layout.tsx b/web/app/account/(commonLayout)/layout.tsx index f116cd00f9..8fdbd8a238 100644 --- a/web/app/account/(commonLayout)/layout.tsx +++ b/web/app/account/(commonLayout)/layout.tsx @@ -1,6 +1,7 @@ import type { ReactNode } from 'react' import * as React from 'react' import { AppInitializer } from '@/app/components/app-initializer' +import AmplitudeProvider from '@/app/components/base/amplitude' import GA, { GaType } from '@/app/components/base/ga' import HeaderWrapper from '@/app/components/header/header-wrapper' import { AppContextProvider } from '@/context/app-context-provider' @@ -13,6 +14,7 @@ const Layout = ({ children }: { children: ReactNode }) => { return ( <> + diff --git a/web/app/components/app/overview/customize/index.tsx b/web/app/components/app/overview/customize/index.tsx index 89b621f32b..e27e1e7653 100644 --- a/web/app/components/app/overview/customize/index.tsx +++ b/web/app/components/app/overview/customize/index.tsx @@ -46,78 +46,80 @@ const CustomizeModal: FC = ({ return ( !open && onClose()}> - - + + {t(`${prefixCustomize}.title`, { ns: 'appOverview' })} - + {t(`${prefixCustomize}.explanation`, { ns: 'appOverview' })} -
- - {t(`${prefixCustomize}.way`, { ns: 'appOverview' })} - {' '} - 1 - -

{t(`${prefixCustomize}.way1.name`, { ns: 'appOverview' })}

-
- 1 -
-
{t(`${prefixCustomize}.way1.step1`, { ns: 'appOverview' })}
-
{t(`${prefixCustomize}.way1.step1Tip`, { ns: 'appOverview' })}
- +
+
+ + {t(`${prefixCustomize}.way`, { ns: 'appOverview' })} + {' '} + 1 + +

{t(`${prefixCustomize}.way1.name`, { ns: 'appOverview' })}

+
+ 1 +
+
{t(`${prefixCustomize}.way1.step1`, { ns: 'appOverview' })}
+
{t(`${prefixCustomize}.way1.step1Tip`, { ns: 'appOverview' })}
+ +
-
-
- 2 -
-
{t(`${prefixCustomize}.way1.step2`, { ns: 'appOverview' })}
-
{t(`${prefixCustomize}.way1.step2Tip`, { ns: 'appOverview' })}
- +
+ 2 +
+
{t(`${prefixCustomize}.way1.step2`, { ns: 'appOverview' })}
+
{t(`${prefixCustomize}.way1.step2Tip`, { ns: 'appOverview' })}
+ +
-
-
- 3 -
-
{t(`${prefixCustomize}.way1.step3`, { ns: 'appOverview' })}
-
{t(`${prefixCustomize}.way1.step3Tip`, { ns: 'appOverview' })}
-
-                NEXT_PUBLIC_APP_ID=
-                {`'${appId}'`}
-                {' '}
-                
- NEXT_PUBLIC_APP_KEY= - {'\'\''} - {' '} -
- NEXT_PUBLIC_API_URL= - {`'${api_base_url}'`} -
+
+ 3 +
+
{t(`${prefixCustomize}.way1.step3`, { ns: 'appOverview' })}
+
{t(`${prefixCustomize}.way1.step3Tip`, { ns: 'appOverview' })}
+
+                  NEXT_PUBLIC_APP_ID=
+                  {`'${appId}'`}
+                  {' '}
+                  
+ NEXT_PUBLIC_APP_KEY= + {'\'\''} + {' '} +
+ NEXT_PUBLIC_API_URL= + {`'${api_base_url}'`} +
+
-
-
-
- - {t(`${prefixCustomize}.way`, { ns: 'appOverview' })} - {' '} - 2 - -

{t(`${prefixCustomize}.way2.name`, { ns: 'appOverview' })}

- +
+
+ + {t(`${prefixCustomize}.way`, { ns: 'appOverview' })} + {' '} + 2 + +

{t(`${prefixCustomize}.way2.name`, { ns: 'appOverview' })}

+ +
diff --git a/web/app/components/app/overview/embedded/index.tsx b/web/app/components/app/overview/embedded/index.tsx index 112848760b..a0dcc1b535 100644 --- a/web/app/components/app/overview/embedded/index.tsx +++ b/web/app/components/app/overview/embedded/index.tsx @@ -319,12 +319,12 @@ const Embedded = ({ siteInfo, isShow, onClose, appBaseUrl, accessToken, hiddenIn onClose() }} > - - + + {t(`${prefixEmbedded}.title`, { ns: 'appOverview' })} -
+
{isShow && ( = ({ return ( <> !open && onHide()}> - + {/* header */} -
+
{t(`${prefixSettings}.title`, { ns: 'appOverview' })} @@ -263,7 +263,7 @@ const SettingsModal: FC = ({
{/* form body */} -
+
{/* name & icon */}
@@ -474,7 +474,7 @@ const SettingsModal: FC = ({ )}
{/* footer */} -
+
diff --git a/web/app/components/base/app-icon/__tests__/index.spec.tsx b/web/app/components/base/app-icon/__tests__/index.spec.tsx index 4bbea1e0ae..54c1b8c361 100644 --- a/web/app/components/base/app-icon/__tests__/index.spec.tsx +++ b/web/app/components/base/app-icon/__tests__/index.spec.tsx @@ -26,9 +26,8 @@ describe('AppIcon', () => { super() } - // Mock basic functionality connectedCallback() { - this.innerHTML = '🤖' + this.innerHTML = this.getAttribute('id') || '🤖' } }) } @@ -51,6 +50,15 @@ describe('AppIcon', () => { expect(emojiElement?.getAttribute('id')).toBe('smile') }) + it('updates the rendered emoji when icon changes', () => { + const { rerender } = render() + expect(document.querySelector('em-emoji')).toHaveTextContent('smile') + + rerender() + + expect(document.querySelector('em-emoji')).toHaveTextContent('robot') + }) + it('renders image when iconType is image and imageUrl is provided', () => { render() const imgElement = screen.getByAltText('app icon') diff --git a/web/app/components/base/app-icon/index.tsx b/web/app/components/base/app-icon/index.tsx index 5c15179446..b08ac5e981 100644 --- a/web/app/components/base/app-icon/index.tsx +++ b/web/app/components/base/app-icon/index.tsx @@ -104,7 +104,8 @@ const AppIcon: FC = ({ showEditIcon = false, }) => { const isValidImageIcon = iconType === 'image' && imageUrl - const Icon = (icon && icon !== '') ? : + const emojiIcon = (icon && icon !== '') ? icon : '🤖' + const Icon = const wrapperRef = useRef(null) const isHovering = useHover(wrapperRef) diff --git a/web/app/components/base/message-log-modal/__tests__/index.spec.tsx b/web/app/components/base/message-log-modal/__tests__/index.spec.tsx index 49f2970654..f7e0158896 100644 --- a/web/app/components/base/message-log-modal/__tests__/index.spec.tsx +++ b/web/app/components/base/message-log-modal/__tests__/index.spec.tsx @@ -4,11 +4,9 @@ import { useStore } from '@/app/components/app/store' import MessageLogModal from '../index' let clickAwayHandler: (() => void) | null = null -let clickAwayHandlers: (() => void)[] = [] vi.mock('ahooks', () => ({ useClickAway: (fn: () => void) => { clickAwayHandler = fn - clickAwayHandlers.push(fn) }, })) @@ -40,7 +38,6 @@ describe('MessageLogModal', () => { beforeEach(() => { vi.clearAllMocks() clickAwayHandler = null - clickAwayHandlers = [] // eslint-disable-next-line ts/no-explicit-any vi.mocked(useStore).mockImplementation((selector: any) => selector({ appDetail: { id: 'app-1' }, @@ -76,15 +73,17 @@ describe('MessageLogModal', () => { it('sets fixed style when fixedWidth is false (floating)', () => { const { container } = render() - const modal = container.firstChild as HTMLElement - expect(modal.style.position).toBe('fixed') - expect(modal.style.width).toBe('480px') + const modal = screen.getByRole('dialog') + expect(container).not.toContainElement(modal) + expect(document.body).toContainElement(modal) + expect(modal).toHaveClass('fixed', 'z-50', 'w-[480px]!', 'left-[max(8px,calc(100vw-1136px))]!') }) it('sets fixed width when fixedWidth is true', () => { const { container } = render() - const modal = container.firstChild as HTMLElement - expect(modal.style.width).toBe('1000px') + const panel = container.firstElementChild as HTMLElement + expect(panel).toHaveClass('relative', 'z-10') + expect(panel.style.width).toBe('1000px') }) }) @@ -98,16 +97,16 @@ describe('MessageLogModal', () => { }) it('calls onCancel when clicked away', () => { - render() + render() expect(clickAwayHandler).toBeTruthy() clickAwayHandler!() expect(onCancel).toHaveBeenCalledTimes(1) }) - it('does not call onCancel when clicked away if not mounted', () => { + it('does not use click away to close the floating dialog', () => { render() - expect(clickAwayHandlers.length).toBeGreaterThan(0) - clickAwayHandlers[0]!() // This is the closure from the initial render, where mounted is false + expect(clickAwayHandler).toBeTruthy() + clickAwayHandler!() expect(onCancel).not.toHaveBeenCalled() }) }) diff --git a/web/app/components/base/message-log-modal/index.tsx b/web/app/components/base/message-log-modal/index.tsx index 9a58a0213d..6b63ccdf3f 100644 --- a/web/app/components/base/message-log-modal/index.tsx +++ b/web/app/components/base/message-log-modal/index.tsx @@ -1,13 +1,18 @@ import type { FC } from 'react' import type { IChatItem } from '@/app/components/base/chat/chat/type' import { cn } from '@langgenius/dify-ui/cn' -import { RiCloseLine } from '@remixicon/react' +import { Dialog, DialogContent, DialogTitle } from '@langgenius/dify-ui/dialog' import { useClickAway } from 'ahooks' -import { useEffect, useRef, useState } from 'react' +import { useRef } from 'react' import { useTranslation } from 'react-i18next' import { useStore } from '@/app/components/app/store' import Run from '@/app/components/workflow/run' +type RunActiveTab = 'RESULT' | 'DETAIL' | 'TRACING' + +const isRunActiveTab = (tab: string): tab is RunActiveTab => + tab === 'RESULT' || tab === 'DETAIL' || tab === 'TRACING' + type MessageLogModalProps = { currentLogItem?: IChatItem defaultTab?: string @@ -24,36 +29,65 @@ const MessageLogModal: FC = ({ }) => { const { t } = useTranslation() const ref = useRef(null) - const [mounted, setMounted] = useState(false) const appDetail = useStore(state => state.appDetail) useClickAway(() => { - if (mounted) + if (fixedWidth) onCancel() }, ref) - useEffect(() => { - setMounted(true) - }, []) - if (!currentLogItem || !currentLogItem.workflow_run_id) return null + const activeTab = isRunActiveTab(defaultTab) ? defaultTab : 'DETAIL' + const modalContent = ( + <> + {t('runDetail.title', { ns: 'appLog' })} + + + + ) + + if (!fixedWidth) { + return ( + { + if (!open) + onCancel() + }} + > + + {modalContent} + + + ) + } + return (
@@ -64,11 +98,11 @@ const MessageLogModal: FC = ({ className="absolute top-4 right-3 z-20 cursor-pointer border-none bg-transparent p-1 focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:outline-hidden" onClick={onCancel} > -
) } diff --git a/web/app/components/datasets/external-api/external-api-modal/index.tsx b/web/app/components/datasets/external-api/external-api-modal/index.tsx index 0179a31b89..406ab815e6 100644 --- a/web/app/components/datasets/external-api/external-api-modal/index.tsx +++ b/web/app/components/datasets/external-api/external-api-modal/index.tsx @@ -121,9 +121,9 @@ const AddExternalAPIModal: FC = ({ data, onSave, onCan onCancel() }} > - -
-
+ +
+
{isEditMode ? t('editExternalAPIFormTitle', { ns: 'dataset' }) : t('createExternalAPI', { ns: 'dataset' })} @@ -173,8 +173,8 @@ const AddExternalAPIModal: FC = ({ data, onSave, onCan -
-
+ +
@@ -194,7 +194,7 @@ const AddExternalAPIModal: FC = ({ data, onSave, onCan {t('externalAPIForm.save', { ns: 'dataset' })}
-
diff --git a/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx b/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx index 60fcf8c094..f9d84d414b 100644 --- a/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx +++ b/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx @@ -60,7 +60,7 @@ const EditWorkspaceModal = ({ onCancel }: IEditWorkspaceModalProps) => { onCancel() }} > - + s.licenseLimit) const refreshLicenseLimit = useProviderContextSelector(s => s.refreshLicenseLimit) const [emails, setEmails] = useState([]) - const [isLimited, setIsLimited] = useState(false) - const [isLimitExceeded, setIsLimitExceeded] = useState(false) - const [usedSize, setUsedSize] = useState(licenseLimit.workspace_members.size ?? 0) - useEffect(() => { - const limited = licenseLimit.workspace_members.limit > 0 - const used = emails.length + licenseLimit.workspace_members.size - setIsLimited(limited) - setUsedSize(used) - setIsLimitExceeded(limited && (used > licenseLimit.workspace_members.limit)) - }, [licenseLimit, emails]) + const isLimited = licenseLimit.workspace_members.limit > 0 + const usedSize = emails.length + licenseLimit.workspace_members.size + const isLimitExceeded = isLimited && (usedSize > licenseLimit.workspace_members.limit) const locale = useLocale() const [role, setRole] = useState('normal') @@ -85,7 +78,7 @@ const InviteModal = ({ >
diff --git a/web/app/components/plugins/install-plugin/install-bundle/__tests__/index.spec.tsx b/web/app/components/plugins/install-plugin/install-bundle/__tests__/index.spec.tsx index dd3d63aa32..8bd187129c 100644 --- a/web/app/components/plugins/install-plugin/install-bundle/__tests__/index.spec.tsx +++ b/web/app/components/plugins/install-plugin/install-bundle/__tests__/index.spec.tsx @@ -261,6 +261,12 @@ describe('InstallBundle', () => { expect(screen.getByText('plugin.installModal.installPlugin')).toBeInTheDocument() }) + it('should constrain modal height to the viewport', () => { + render() + + expect(screen.getByText('plugin.installModal.installPlugin').parentElement?.parentElement).toHaveClass('max-h-[calc(100dvh-48px)]') + }) + it('should render ReadyToInstall component', () => { render() diff --git a/web/app/components/plugins/install-plugin/install-bundle/index.tsx b/web/app/components/plugins/install-plugin/install-bundle/index.tsx index 0a1021591a..7f07ee2150 100644 --- a/web/app/components/plugins/install-plugin/install-bundle/index.tsx +++ b/web/app/components/plugins/install-plugin/install-bundle/index.tsx @@ -57,7 +57,7 @@ const InstallBundle: FC = ({ foldAnimInto() }} > - +
diff --git a/web/app/components/plugins/install-plugin/install-bundle/steps/__tests__/install.spec.tsx b/web/app/components/plugins/install-plugin/install-bundle/steps/__tests__/install.spec.tsx index 3e848b35f4..f9d55571a7 100644 --- a/web/app/components/plugins/install-plugin/install-bundle/steps/__tests__/install.spec.tsx +++ b/web/app/components/plugins/install-plugin/install-bundle/steps/__tests__/install.spec.tsx @@ -265,6 +265,18 @@ describe('Install Component', () => { expect(screen.getByTestId('all-plugins-count')).toHaveTextContent('2') }) + it('should make the plugin list scrollable inside the modal body', () => { + render() + + expect(screen.getByTestId('install-multi').parentElement).toHaveClass('overflow-y-auto') + }) + + it('should constrain the install step so the plugin list can scroll with many items', () => { + const { container } = render() + + expect(container.firstElementChild).toHaveClass('min-h-0', 'flex-1', 'overflow-hidden') + }) + it('should show singular text when one plugin is selected', async () => { render() diff --git a/web/app/components/plugins/install-plugin/install-bundle/steps/install.tsx b/web/app/components/plugins/install-plugin/install-bundle/steps/install.tsx index a94cd8588d..aede6cbb83 100644 --- a/web/app/components/plugins/install-plugin/install-bundle/steps/install.tsx +++ b/web/app/components/plugins/install-plugin/install-bundle/steps/install.tsx @@ -170,12 +170,12 @@ const Install: FC = ({ const { canInstallPluginFromMarketplace } = useCanInstallPluginFromMarketplace() return ( - <> -
+
+

{t(`${i18nPrefix}.${selectedPluginsNum > 1 ? 'readyToInstallPackages' : 'readyToInstallPackage'}`, { ns: 'plugin', num: selectedPluginsNum })}

-
+
= ({
)} - +
) } export default React.memo(Install) diff --git a/web/app/components/plugins/install-plugin/install-from-local-package/__tests__/index.spec.tsx b/web/app/components/plugins/install-plugin/install-from-local-package/__tests__/index.spec.tsx index cac6250550..65b74bcac0 100644 --- a/web/app/components/plugins/install-plugin/install-from-local-package/__tests__/index.spec.tsx +++ b/web/app/components/plugins/install-plugin/install-from-local-package/__tests__/index.spec.tsx @@ -292,6 +292,12 @@ describe('InstallFromLocalPackage', () => { expect(screen.getByTestId('is-bundle')).toHaveTextContent('true') }) + it('should constrain dialog height so bundle dependency lists can scroll', () => { + render() + + expect(screen.getByRole('dialog')).toHaveClass('max-h-[calc(100dvh-48px)]') + }) + it('should identify package file correctly', () => { render() diff --git a/web/app/components/plugins/install-plugin/install-from-local-package/index.tsx b/web/app/components/plugins/install-plugin/install-from-local-package/index.tsx index b2a780dbdb..032c2197f1 100644 --- a/web/app/components/plugins/install-plugin/install-from-local-package/index.tsx +++ b/web/app/components/plugins/install-plugin/install-from-local-package/index.tsx @@ -93,7 +93,7 @@ const InstallFromLocalPackage: React.FC = ({ foldAnimInto() }} > - +
diff --git a/web/app/components/plugins/install-plugin/install-from-marketplace/__tests__/index.spec.tsx b/web/app/components/plugins/install-plugin/install-from-marketplace/__tests__/index.spec.tsx index 18fa634202..05d5640c96 100644 --- a/web/app/components/plugins/install-plugin/install-from-marketplace/__tests__/index.spec.tsx +++ b/web/app/components/plugins/install-plugin/install-from-marketplace/__tests__/index.spec.tsx @@ -212,6 +212,19 @@ describe('InstallFromMarketplace', () => { expect(screen.getByTestId('bundle-plugins-count')).toHaveTextContent('2') }) + it('should constrain bundle dialog height so dependency lists can scroll', () => { + const dependencies = createMockDependencies() + render( + , + ) + + expect(screen.getByRole('dialog')).toHaveClass('max-h-[calc(100dvh-48px)]') + }) + it('should pass isFromMarketPlace as true to bundle component', () => { const dependencies = createMockDependencies() render( diff --git a/web/app/components/plugins/install-plugin/install-from-marketplace/index.tsx b/web/app/components/plugins/install-plugin/install-from-marketplace/index.tsx index 523009abe1..a53b014e6e 100644 --- a/web/app/components/plugins/install-plugin/install-from-marketplace/index.tsx +++ b/web/app/components/plugins/install-plugin/install-from-marketplace/index.tsx @@ -77,7 +77,7 @@ const InstallFromMarketplace: React.FC = ({ foldAnimInto() }} > - +
diff --git a/web/app/components/share/text-generation/__tests__/info-modal.spec.tsx b/web/app/components/share/text-generation/__tests__/info-modal.spec.tsx index 972c22dfce..63a15604c2 100644 --- a/web/app/components/share/text-generation/__tests__/info-modal.spec.tsx +++ b/web/app/components/share/text-generation/__tests__/info-modal.spec.tsx @@ -72,10 +72,10 @@ describe('InfoModal', () => { expect(screen.getByText('Test App')).toBeInTheDocument() }) - it('should render copyright when provided', async () => { + it('should render copyright in the full rights reserved format when provided', async () => { const siteInfoWithCopyright: SiteInfo = { ...baseSiteInfo, - copyright: 'Dify Inc.', + copyright: 'Dify AI', } await renderModal( @@ -86,7 +86,8 @@ describe('InfoModal', () => { />, ) - expect(screen.getByText(/Dify Inc./)).toBeInTheDocument() + const currentYear = new Date().getFullYear().toString() + expect(screen.getByText(`Copyright © ${currentYear} Dify AI. All Rights Reserved.`)).toBeInTheDocument() }) it('should render current year in copyright', async () => { diff --git a/web/app/components/share/text-generation/info-modal.tsx b/web/app/components/share/text-generation/info-modal.tsx index b851610661..894ca8781e 100644 --- a/web/app/components/share/text-generation/info-modal.tsx +++ b/web/app/components/share/text-generation/info-modal.tsx @@ -16,6 +16,8 @@ const InfoModal = ({ onClose, data, }: Props) => { + const [currentYear] = React.useState(() => new Date().getFullYear()) + return ( - +
@@ -35,15 +37,20 @@ const InfoModal = ({ background={data?.icon_background || appDefaultIconBackground} imageUrl={data?.icon_url} /> -
{data?.title}
+
+
{data?.title}
+
{data?.description}
+
{/* copyright */} {data?.copyright && (
- © - {(new Date()).getFullYear()} + Copyright © + {' '} + {currentYear} {' '} {data?.copyright} + . All Rights Reserved.
)} {data?.custom_disclaimer && ( diff --git a/web/app/layout.tsx b/web/app/layout.tsx index 8bb2069aaf..4eb392fb6d 100644 --- a/web/app/layout.tsx +++ b/web/app/layout.tsx @@ -4,7 +4,6 @@ import { TooltipProvider } from '@langgenius/dify-ui/tooltip' import { Provider as JotaiProvider } from 'jotai/react' import { ThemeProvider } from 'next-themes' import { NuqsAdapter } from 'nuqs/adapters/next/app' -import AmplitudeProvider from '@/app/components/base/amplitude' import { IS_PROD } from '@/config' import { TanstackQueryInitializer } from '@/context/query-client' import { getDatasetMap } from '@/env' @@ -60,7 +59,6 @@ const LocaleLayout = async ({ {...datasetMap} >
- ({ + useSearchParams: vi.fn(), +})) + +vi.mock('@/context/i18n', () => ({ + useLocale: vi.fn(), +})) + +vi.mock('@/utils/timezone', () => ({ + getBrowserTimezone: vi.fn(), +})) + +const mockUseSearchParams = vi.mocked(useSearchParams) +const mockUseLocale = vi.mocked(useLocale) +const mockGetBrowserTimezone = vi.mocked(getBrowserTimezone) + +describe('SocialAuth', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseSearchParams.mockReturnValue(new URLSearchParams() as unknown as ReturnType) + mockUseLocale.mockReturnValue('zh-Hans') + mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai') + }) + + describe('Rendering', () => { + it('should render oauth provider links', () => { + render() + + expect(screen.getByRole('link', { name: 'login.withGitHub' })).toBeInTheDocument() + expect(screen.getByRole('link', { name: 'login.withGoogle' })).toBeInTheDocument() + }) + }) + + describe('OAuth params', () => { + it('should include browser timezone and locale in oauth links', () => { + render() + + expect(screen.getByRole('link', { name: 'login.withGitHub' })).toHaveAttribute( + 'href', + expect.stringContaining('timezone=Asia%2FShanghai'), + ) + expect(screen.getByRole('link', { name: 'login.withGitHub' })).toHaveAttribute( + 'href', + expect.stringContaining('language=zh-Hans'), + ) + expect(screen.getByRole('link', { name: 'login.withGoogle' })).toHaveAttribute( + 'href', + expect.stringContaining('timezone=Asia%2FShanghai'), + ) + expect(screen.getByRole('link', { name: 'login.withGoogle' })).toHaveAttribute( + 'href', + expect.stringContaining('language=zh-Hans'), + ) + }) + + it('should preserve invite token when adding timezone', () => { + mockUseSearchParams.mockReturnValue( + new URLSearchParams('invite_token=invite-123') as unknown as ReturnType, + ) + + render() + + const githubLink = screen.getByRole('link', { name: 'login.withGitHub' }) + expect(githubLink).toHaveAttribute('href', expect.stringContaining('invite_token=invite-123')) + expect(githubLink).toHaveAttribute('href', expect.stringContaining('timezone=Asia%2FShanghai')) + expect(githubLink).toHaveAttribute('href', expect.stringContaining('language=zh-Hans')) + }) + }) + + describe('Edge Cases', () => { + it('should omit timezone when browser timezone is unavailable', () => { + mockGetBrowserTimezone.mockReturnValue(undefined) + + render() + + expect(screen.getByRole('link', { name: 'login.withGitHub' }).getAttribute('href')).not.toContain('timezone=') + }) + }) +}) diff --git a/web/app/signin/components/social-auth.tsx b/web/app/signin/components/social-auth.tsx index 09fa528d1f..17455cebf4 100644 --- a/web/app/signin/components/social-auth.tsx +++ b/web/app/signin/components/social-auth.tsx @@ -2,8 +2,10 @@ import { Button } from '@langgenius/dify-ui/button' import { cn } from '@langgenius/dify-ui/cn' import { useTranslation } from 'react-i18next' import { API_PREFIX } from '@/config' +import { useLocale } from '@/context/i18n' import { useSearchParams } from '@/next/navigation' import { getPurifyHref } from '@/utils' +import { getBrowserTimezone } from '@/utils/timezone' import style from '../page.module.css' type SocialAuthProps = { @@ -13,11 +15,19 @@ type SocialAuthProps = { export default function SocialAuth(props: SocialAuthProps) { const { t } = useTranslation() const searchParams = useSearchParams() + const locale = useLocale() const getOAuthLink = (href: string) => { const url = getPurifyHref(`${API_PREFIX}${href}`) - if (searchParams.has('invite_token')) - return `${url}?${searchParams.toString()}` + const params = new URLSearchParams(searchParams.toString()) + const timezone = getBrowserTimezone() + if (timezone) + params.set('timezone', timezone) + params.set('language', locale) + + const query = params.toString() + if (query) + return `${url}?${query}` return url } diff --git a/web/app/signin/invite-settings/__tests__/page.spec.tsx b/web/app/signin/invite-settings/__tests__/page.spec.tsx new file mode 100644 index 0000000000..12e6f4dfb7 --- /dev/null +++ b/web/app/signin/invite-settings/__tests__/page.spec.tsx @@ -0,0 +1,139 @@ +import type { MockedFunction } from 'vitest' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { useLocale } from '@/context/i18n' +import { useRouter, useSearchParams } from '@/next/navigation' +import { activateMember } from '@/service/common' +import { useInvitationCheck } from '@/service/use-common' +import { getBrowserTimezone } from '@/utils/timezone' +import InviteSettingsPage from '../page' + +vi.mock('@tanstack/react-query', async () => { + const actual = await vi.importActual('@tanstack/react-query') + return { + ...actual, + useSuspenseQuery: vi.fn(() => ({ + data: { + branding: { + enabled: true, + }, + }, + })), + } +}) + +vi.mock('@/context/i18n', () => ({ + useLocale: vi.fn(), +})) + +vi.mock('@/i18n-config', () => ({ + i18n: { + defaultLocale: 'en-US', + }, + setLocaleOnClient: vi.fn(() => Promise.resolve()), +})) + +vi.mock('@/next/navigation', () => ({ + useRouter: vi.fn(), + useSearchParams: vi.fn(), +})) + +vi.mock('@/service/common', () => ({ + activateMember: vi.fn(), +})) + +vi.mock('@/service/use-common', () => ({ + useInvitationCheck: vi.fn(), +})) + +vi.mock('@/utils/timezone', () => ({ + getBrowserTimezone: vi.fn(), + timezones: [ + { value: 'Asia/Shanghai', name: 'Asia/Shanghai' }, + { value: 'America/Los_Angeles', name: 'America/Los_Angeles' }, + ], +})) + +vi.mock('../utils/post-login-redirect', () => ({ + resolvePostLoginRedirect: vi.fn(() => null), +})) + +const mockReplace = vi.fn() +const mockRefetch = vi.fn() + +const mockUseLocale = useLocale as unknown as MockedFunction +const mockUseRouter = useRouter as unknown as MockedFunction +const mockUseSearchParams = useSearchParams as unknown as MockedFunction +const mockActivateMember = activateMember as unknown as MockedFunction +const mockUseInvitationCheck = useInvitationCheck as unknown as MockedFunction +const mockGetBrowserTimezone = getBrowserTimezone as unknown as MockedFunction + +describe('InviteSettingsPage', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseLocale.mockReturnValue('zh-Hans') + mockUseRouter.mockReturnValue({ replace: mockReplace } as unknown as ReturnType) + mockUseSearchParams.mockReturnValue( + new URLSearchParams('invite_token=invite-token') as unknown as ReturnType, + ) + mockUseInvitationCheck.mockReturnValue({ + data: { + is_valid: true, + data: { + workspace_name: 'Acme', + workspace_id: 'workspace-id', + email: 'invitee@example.com', + }, + }, + refetch: mockRefetch, + } as unknown as ReturnType) + mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai') + mockActivateMember.mockResolvedValue({ result: 'success' }) + }) + + describe('Activation payload', () => { + it('should default language to the current UI locale', async () => { + render() + + fireEvent.change(screen.getByLabelText('login.name'), { + target: { value: 'Invitee' }, + }) + fireEvent.click(screen.getByRole('button', { name: 'login.join Acme' })) + + await waitFor(() => { + expect(mockActivateMember).toHaveBeenCalledWith({ + url: '/activate', + body: { + token: 'invite-token', + name: 'Invitee', + interface_language: 'zh-Hans', + timezone: 'Asia/Shanghai', + }, + }) + }) + }) + + it('should fall back to configured default locale when current locale is unsupported', async () => { + mockUseLocale.mockReturnValue('unsupported-locale' as ReturnType) + + render() + + fireEvent.change(screen.getByLabelText('login.name'), { + target: { value: 'Invitee' }, + }) + fireEvent.click(screen.getByRole('button', { name: 'login.join Acme' })) + + await waitFor(() => { + expect(mockActivateMember).toHaveBeenCalledWith({ + url: '/activate', + body: { + token: 'invite-token', + name: 'Invitee', + interface_language: 'en-US', + timezone: 'Asia/Shanghai', + }, + }) + }) + }) + }) +}) diff --git a/web/app/signin/invite-settings/page.tsx b/web/app/signin/invite-settings/page.tsx index af0ff5e07a..ddde11b940 100644 --- a/web/app/signin/invite-settings/page.tsx +++ b/web/app/signin/invite-settings/page.tsx @@ -11,14 +11,15 @@ import { useTranslation } from 'react-i18next' import Input from '@/app/components/base/input' import Loading from '@/app/components/base/loading' import { LICENSE_LINK } from '@/constants/link' -import { setLocaleOnClient } from '@/i18n-config' -import { languages, LanguagesSupported } from '@/i18n-config/language' +import { useLocale } from '@/context/i18n' +import { i18n, setLocaleOnClient } from '@/i18n-config' +import { languages } from '@/i18n-config/language' import Link from '@/next/link' import { useRouter, useSearchParams } from '@/next/navigation' import { activateMember } from '@/service/common' import { systemFeaturesQueryOptions } from '@/service/system-features' import { useInvitationCheck } from '@/service/use-common' -import { timezones } from '@/utils/timezone' +import { getBrowserTimezone, timezones } from '@/utils/timezone' import { resolvePostLoginRedirect } from '../utils/post-login-redirect' type LanguageSelectOption = { @@ -43,15 +44,23 @@ const TIMEZONE_OPTIONS: TimezoneSelectOption[] = timezones.map(item => ({ name: item.name, })) +const getInitialLanguage = (locale: Locale): Locale => { + if (LANGUAGE_OPTIONS.some(item => item.value === locale)) + return locale + + return i18n.defaultLocale +} + export default function InviteSettingsPage() { const { t } = useTranslation() const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions()) const router = useRouter() const searchParams = useSearchParams() const token = decodeURIComponent(searchParams.get('invite_token') as string) + const locale = useLocale() const [name, setName] = useState('') - const [language, setLanguage] = useState(LanguagesSupported[0]) - const [timezone, setTimezone] = useState(() => Intl.DateTimeFormat().resolvedOptions().timeZone || 'America/Los_Angeles') + const [language, setLanguage] = useState(() => getInitialLanguage(locale)) + const [timezone, setTimezone] = useState(() => getBrowserTimezone() || 'America/Los_Angeles') const selectedLanguage = LANGUAGE_OPTIONS.find(item => item.value === language) const selectedTimezone = TIMEZONE_OPTIONS.find(item => item.value === timezone) diff --git a/web/app/signup/set-password/__tests__/page.spec.tsx b/web/app/signup/set-password/__tests__/page.spec.tsx new file mode 100644 index 0000000000..e68694db15 --- /dev/null +++ b/web/app/signup/set-password/__tests__/page.spec.tsx @@ -0,0 +1,85 @@ +import type { MockedFunction } from 'vitest' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { useLocale } from '@/context/i18n' +import { useRouter, useSearchParams } from '@/next/navigation' +import { useMailRegister } from '@/service/use-common' +import { getBrowserTimezone } from '@/utils/timezone' +import ChangePasswordForm from '../page' + +vi.mock('@/context/i18n', () => ({ + useLocale: vi.fn(), +})) + +vi.mock('@/next/navigation', () => ({ + useRouter: vi.fn(), + useSearchParams: vi.fn(), +})) + +vi.mock('@/service/use-common', () => ({ + useMailRegister: vi.fn(), +})) + +vi.mock('@/utils/timezone', () => ({ + getBrowserTimezone: vi.fn(), +})) + +vi.mock('@/utils/gtag', () => ({ + sendGAEvent: vi.fn(), +})) + +vi.mock('@/app/components/base/amplitude', () => ({ + trackEvent: vi.fn(), +})) + +vi.mock('@/utils/create-app-tracking', () => ({ + rememberCreateAppExternalAttribution: vi.fn(), +})) + +const mockRegister = vi.fn() +const mockReplace = vi.fn() + +const mockUseLocale = useLocale as unknown as MockedFunction +const mockUseSearchParams = useSearchParams as unknown as MockedFunction +const mockUseRouter = useRouter as unknown as MockedFunction +const mockUseMailRegister = useMailRegister as unknown as MockedFunction +const mockGetBrowserTimezone = getBrowserTimezone as unknown as MockedFunction + +describe('Signup Set Password Page', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseLocale.mockReturnValue('zh-Hans') + mockUseSearchParams.mockReturnValue(new URLSearchParams('token=register-token') as unknown as ReturnType) + mockUseRouter.mockReturnValue({ replace: mockReplace } as unknown as ReturnType) + mockUseMailRegister.mockReturnValue({ + mutateAsync: mockRegister, + isPending: false, + } as unknown as ReturnType) + mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai') + mockRegister.mockResolvedValue({ result: 'fail', data: {} }) + }) + + describe('Registration payload', () => { + it('should submit locale and browser timezone when setting password', async () => { + render() + + fireEvent.change(screen.getByLabelText('common.account.newPassword'), { + target: { value: 'ValidPass123!' }, + }) + fireEvent.change(screen.getByLabelText('common.account.confirmPassword'), { + target: { value: 'ValidPass123!' }, + }) + fireEvent.click(screen.getByRole('button', { name: 'login.changePasswordBtn' })) + + await waitFor(() => { + expect(mockRegister).toHaveBeenCalledWith({ + token: 'register-token', + new_password: 'ValidPass123!', + password_confirm: 'ValidPass123!', + language: 'zh-Hans', + timezone: 'Asia/Shanghai', + }) + }) + }) + }) +}) diff --git a/web/app/signup/set-password/page.tsx b/web/app/signup/set-password/page.tsx index a8eb883078..534b6b55d4 100644 --- a/web/app/signup/set-password/page.tsx +++ b/web/app/signup/set-password/page.tsx @@ -9,10 +9,12 @@ import { useTranslation } from 'react-i18next' import { trackEvent } from '@/app/components/base/amplitude' import Input from '@/app/components/base/input' import { validPassword } from '@/config' +import { useLocale } from '@/context/i18n' import { useRouter, useSearchParams } from '@/next/navigation' import { useMailRegister } from '@/service/use-common' import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking' import { sendGAEvent } from '@/utils/gtag' +import { getBrowserTimezone } from '@/utils/timezone' const parseUtmInfo = () => { const utmInfoStr = Cookies.get('utm_info') @@ -32,6 +34,7 @@ const ChangePasswordForm = () => { const router = useRouter() const searchParams = useSearchParams() const token = decodeURIComponent(searchParams.get('token') || '') + const locale = useLocale() const [password, setPassword] = useState('') const [confirmPassword, setConfirmPassword] = useState('') @@ -65,6 +68,8 @@ const ChangePasswordForm = () => { token, new_password: password, password_confirm: confirmPassword, + language: locale, + timezone: getBrowserTimezone(), }) const { result } = res as MailRegisterResponse if (result === 'success') { @@ -88,7 +93,7 @@ const ChangePasswordForm = () => { catch (error) { console.error(error) } - }, [password, token, valid, confirmPassword, register]) + }, [password, token, valid, confirmPassword, register, locale]) return (
()) - .output(type()) + .output(type()) export const tagDeleteContract = base .route({ diff --git a/web/features/tag-management/components/tag-item-editor.tsx b/web/features/tag-management/components/tag-item-editor.tsx index 0e47eafa60..09f8dbf45e 100644 --- a/web/features/tag-management/components/tag-item-editor.tsx +++ b/web/features/tag-management/components/tag-item-editor.tsx @@ -30,7 +30,6 @@ export const TagItemEditor = ({ tag, onTagsChange }: TagItemEditorProps) => { const updateTagMutation = useMutation(consoleQuery.tags.update.mutationOptions()) const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions()) const [isEditing, setIsEditing] = useState(false) - const [name, setName] = useState(tag.name) const editTag = (tagId: string, name: string) => { if (name === tag.name) { setIsEditing(false) @@ -38,7 +37,6 @@ export const TagItemEditor = ({ tag, onTagsChange }: TagItemEditorProps) => { } if (!name) { toast.error('tag name is empty') - setName(tag.name) setIsEditing(false) return } @@ -53,13 +51,11 @@ export const TagItemEditor = ({ tag, onTagsChange }: TagItemEditorProps) => { }, { onSuccess: () => { toast.success(t('actionMsg.modifiedSuccessfully', { ns: 'common' })) - setName(name) setIsEditing(false) onTagsChange?.() }, onError: () => { toast.error(t('actionMsg.modifiedUnsuccessfully', { ns: 'common' })) - setName(tag.name) setIsEditing(false) }, }) @@ -123,7 +119,22 @@ export const TagItemEditor = ({ tag, onTagsChange }: TagItemEditorProps) => { )} - {isEditing && ( setName(e.target.value)} onKeyDown={e => e.key === 'Enter' && editTag(tag.id, name)} onBlur={() => editTag(tag.id, name)} />)} + {isEditing && ( + { + if (e.key !== 'Enter' || e.nativeEvent.isComposing) + return + + e.preventDefault() + e.currentTarget.blur() + }} + onBlur={e => editTag(tag.id, e.currentTarget.value)} + /> + )}
!open && setShowRemoveModal(false)}> diff --git a/web/i18n/ar-TN/education.json b/web/i18n/ar-TN/education.json index 250e6b7d26..bf8ae1d795 100644 --- a/web/i18n/ar-TN/education.json +++ b/web/i18n/ar-TN/education.json @@ -16,10 +16,10 @@ "currentSigned": "تم تسجيل الدخول حاليًا باسم", "educationPricingConfirm.billingPeriod.monthly": "شهري", "educationPricingConfirm.billingPeriod.yearly": "سنوي", - "educationPricingConfirm.cancel": "إلغاء", - "educationPricingConfirm.continue": "المتابعة بدون خصم", - "educationPricingConfirm.description": "خطتك {{planName}} {{billingPeriod}} لا تدعم الخصم التعليمي. فقط خطة Professional السنوية مؤهلة.", - "educationPricingConfirm.title": "الخصم التعليمي غير متاح", + "educationPricingConfirm.cancel": "الاحتفاظ بالخطة الحالية", + "educationPricingConfirm.continue": "التبديل إلى Professional السنوية", + "educationPricingConfirm.description": "ينطبق الخصم التعليمي على خطة Professional السنوية فقط. الاحتفاظ بخطتك الحالية لن يتضمن الخصم.", + "educationPricingConfirm.title": "الخطة التي اخترتها لا تدعم الخصم التعليمي", "emailLabel": "بريدك الإلكتروني الحالي", "form.schoolName.placeholder": "أدخل الاسم الرسمي الكامل لمدرستك", "form.schoolName.title": "اسم مدرستك", diff --git a/web/i18n/de-DE/education.json b/web/i18n/de-DE/education.json index 32dd76bd46..c5153c7f75 100644 --- a/web/i18n/de-DE/education.json +++ b/web/i18n/de-DE/education.json @@ -16,10 +16,10 @@ "currentSigned": "DERZEIT ANGEMELDET ALS", "educationPricingConfirm.billingPeriod.monthly": "monatlich", "educationPricingConfirm.billingPeriod.yearly": "jährlich", - "educationPricingConfirm.cancel": "Abbrechen", - "educationPricingConfirm.continue": "Ohne Rabatt fortfahren", - "educationPricingConfirm.description": "Ihr {{planName}} {{billingPeriod}} Plan unterstützt den Bildungsrabatt nicht. Nur der Professional-Jahresplan ist berechtigt.", - "educationPricingConfirm.title": "Bildungsrabatt nicht verfügbar", + "educationPricingConfirm.cancel": "Aktuellen Plan behalten", + "educationPricingConfirm.continue": "Zu Professional jährlich wechseln", + "educationPricingConfirm.description": "Der Bildungsrabatt gilt nur für den jährlichen Professional-Plan. Wenn Sie Ihren aktuellen Plan behalten, ist der Rabatt nicht enthalten.", + "educationPricingConfirm.title": "Ihr ausgewählter Plan unterstützt den Bildungsrabatt nicht", "emailLabel": "Ihre aktuelle E-Mail", "form.schoolName.placeholder": "Geben Sie den offiziellen, unabgekürzten Namen Ihrer Schule ein.", "form.schoolName.title": "Ihr Schulname", diff --git a/web/i18n/en-US/education.json b/web/i18n/en-US/education.json index e26b1cc24d..479ea5d28f 100644 --- a/web/i18n/en-US/education.json +++ b/web/i18n/en-US/education.json @@ -16,10 +16,10 @@ "currentSigned": "CURRENTLY SIGNED IN AS", "educationPricingConfirm.billingPeriod.monthly": "monthly", "educationPricingConfirm.billingPeriod.yearly": "annual", - "educationPricingConfirm.cancel": "Cancel", - "educationPricingConfirm.continue": "Continue without discount", - "educationPricingConfirm.description": "Your {{planName}} {{billingPeriod}} plan doesn't support the education discount. Only the Professional annual plan is eligible.", - "educationPricingConfirm.title": "Education discount not available", + "educationPricingConfirm.cancel": "Keep current plan", + "educationPricingConfirm.continue": "Switch to Professional Annual", + "educationPricingConfirm.description": "The education discount applies to the Professional annual plan only. Keeping your current plan won't include the discount.", + "educationPricingConfirm.title": "Your selected plan doesn't support the education discount", "emailLabel": "Your current email", "form.schoolName.placeholder": "Enter the official, unabbreviated name of your school", "form.schoolName.title": "Your School Name", diff --git a/web/i18n/es-ES/education.json b/web/i18n/es-ES/education.json index 0b2ac91b00..5784797977 100644 --- a/web/i18n/es-ES/education.json +++ b/web/i18n/es-ES/education.json @@ -16,10 +16,10 @@ "currentSigned": "ACTUALMENTE CONECTADO COMO", "educationPricingConfirm.billingPeriod.monthly": "mensual", "educationPricingConfirm.billingPeriod.yearly": "anual", - "educationPricingConfirm.cancel": "Cancelar", - "educationPricingConfirm.continue": "Continuar sin descuento", - "educationPricingConfirm.description": "Tu plan {{planName}} {{billingPeriod}} no admite el descuento educativo. Solo el plan Professional anual es elegible.", - "educationPricingConfirm.title": "Descuento educativo no disponible", + "educationPricingConfirm.cancel": "Mantener el plan actual", + "educationPricingConfirm.continue": "Cambiar a Professional anual", + "educationPricingConfirm.description": "El descuento educativo solo se aplica al plan Professional anual. Si mantienes tu plan actual, no se incluirá el descuento.", + "educationPricingConfirm.title": "El plan seleccionado no admite el descuento educativo", "emailLabel": "Tu correo electrónico actual", "form.schoolName.placeholder": "Ingrese el nombre oficial y completo de su escuela", "form.schoolName.title": "El nombre de tu escuela", diff --git a/web/i18n/fa-IR/education.json b/web/i18n/fa-IR/education.json index 63150df78b..ee1e282719 100644 --- a/web/i18n/fa-IR/education.json +++ b/web/i18n/fa-IR/education.json @@ -16,10 +16,10 @@ "currentSigned": "اکنون به عنوان", "educationPricingConfirm.billingPeriod.monthly": "ماهانه", "educationPricingConfirm.billingPeriod.yearly": "سالانه", - "educationPricingConfirm.cancel": "لغو", - "educationPricingConfirm.continue": "ادامه بدون تخفیف", - "educationPricingConfirm.description": "طرح {{planName}} {{billingPeriod}} شما از تخفیف آموزشی پشتیبانی نمی‌کند. فقط طرح سالانه Professional واجد شرایط است.", - "educationPricingConfirm.title": "تخفیف آموزشی در دسترس نیست", + "educationPricingConfirm.cancel": "حفظ طرح فعلی", + "educationPricingConfirm.continue": "تغییر به Professional سالانه", + "educationPricingConfirm.description": "تخفیف آموزشی فقط برای طرح سالانه Professional اعمال می‌شود. با حفظ طرح فعلی، این تخفیف شامل نمی‌شود.", + "educationPricingConfirm.title": "طرح انتخاب‌شده شما از تخفیف آموزشی پشتیبانی نمی‌کند", "emailLabel": "ایمیل فعلی شما", "form.schoolName.placeholder": "نام رسمی و کامل مدرسه خود را وارد کنید", "form.schoolName.title": "نام مدرسه شما", diff --git a/web/i18n/fr-FR/education.json b/web/i18n/fr-FR/education.json index d201b6f031..1757a84911 100644 --- a/web/i18n/fr-FR/education.json +++ b/web/i18n/fr-FR/education.json @@ -16,10 +16,10 @@ "currentSigned": "ACTUELLEMENT CONNECTÉ EN TANT QUE", "educationPricingConfirm.billingPeriod.monthly": "mensuel", "educationPricingConfirm.billingPeriod.yearly": "annuel", - "educationPricingConfirm.cancel": "Annuler", - "educationPricingConfirm.continue": "Continuer sans remise", - "educationPricingConfirm.description": "Votre plan {{planName}} {{billingPeriod}} ne prend pas en charge la remise éducative. Seul le plan Professional annuel est éligible.", - "educationPricingConfirm.title": "Remise éducative non disponible", + "educationPricingConfirm.cancel": "Conserver le plan actuel", + "educationPricingConfirm.continue": "Passer à Professional annuel", + "educationPricingConfirm.description": "La remise éducation s'applique uniquement au plan Professional annuel. En conservant votre plan actuel, la remise ne sera pas incluse.", + "educationPricingConfirm.title": "Le plan sélectionné ne prend pas en charge la remise éducation", "emailLabel": "Votre email actuel", "form.schoolName.placeholder": "Entrez le nom officiel et complet de votre école", "form.schoolName.title": "Le nom de votre école", diff --git a/web/i18n/hi-IN/education.json b/web/i18n/hi-IN/education.json index a580491cb7..3d4df9db51 100644 --- a/web/i18n/hi-IN/education.json +++ b/web/i18n/hi-IN/education.json @@ -16,10 +16,10 @@ "currentSigned": "वर्तमान में साइन इन किया गया है के रूप में", "educationPricingConfirm.billingPeriod.monthly": "मासिक", "educationPricingConfirm.billingPeriod.yearly": "वार्षिक", - "educationPricingConfirm.cancel": "रद्द करें", - "educationPricingConfirm.continue": "छूट के बिना जारी रखें", - "educationPricingConfirm.description": "आपका {{planName}} {{billingPeriod}} प्लान शिक्षा छूट का समर्थन नहीं करता। केवल Professional वार्षिक प्लान पात्र है।", - "educationPricingConfirm.title": "शिक्षा छूट उपलब्ध नहीं", + "educationPricingConfirm.cancel": "वर्तमान प्लान रखें", + "educationPricingConfirm.continue": "Professional वार्षिक पर स्विच करें", + "educationPricingConfirm.description": "शिक्षा छूट केवल Professional वार्षिक प्लान पर लागू होती है। अपना वर्तमान प्लान रखने पर छूट शामिल नहीं होगी।", + "educationPricingConfirm.title": "आपका चुना हुआ प्लान शिक्षा छूट का समर्थन नहीं करता", "emailLabel": "आपका वर्तमान ईमेल", "form.schoolName.placeholder": "अपनी स्कूल का आधिकारिक, बिना संक्षिप्त नाम दर्ज करें", "form.schoolName.title": "आपके स्कूल का नाम", diff --git a/web/i18n/id-ID/education.json b/web/i18n/id-ID/education.json index 3fa6d70a60..6d37be99dd 100644 --- a/web/i18n/id-ID/education.json +++ b/web/i18n/id-ID/education.json @@ -16,10 +16,10 @@ "currentSigned": "SAAT INI MASUK SEBAGAI", "educationPricingConfirm.billingPeriod.monthly": "bulanan", "educationPricingConfirm.billingPeriod.yearly": "tahunan", - "educationPricingConfirm.cancel": "Batal", - "educationPricingConfirm.continue": "Lanjutkan tanpa diskon", - "educationPricingConfirm.description": "Paket {{planName}} {{billingPeriod}} Anda tidak mendukung diskon pendidikan. Hanya paket Professional tahunan yang memenuhi syarat.", - "educationPricingConfirm.title": "Diskon pendidikan tidak tersedia", + "educationPricingConfirm.cancel": "Tetap gunakan paket saat ini", + "educationPricingConfirm.continue": "Beralih ke Professional Tahunan", + "educationPricingConfirm.description": "Diskon pendidikan hanya berlaku untuk paket Professional tahunan. Jika tetap menggunakan paket saat ini, diskon tidak akan disertakan.", + "educationPricingConfirm.title": "Paket yang Anda pilih tidak mendukung diskon pendidikan", "emailLabel": "Email Anda saat ini", "form.schoolName.placeholder": "Masukkan nama resmi sekolah Anda yang tidak disingkat", "form.schoolName.title": "Nama Sekolah Anda", diff --git a/web/i18n/it-IT/education.json b/web/i18n/it-IT/education.json index b1ccc69308..313c9b404b 100644 --- a/web/i18n/it-IT/education.json +++ b/web/i18n/it-IT/education.json @@ -16,10 +16,10 @@ "currentSigned": "ATTUALMENTE ACCEDUTO COME", "educationPricingConfirm.billingPeriod.monthly": "mensile", "educationPricingConfirm.billingPeriod.yearly": "annuale", - "educationPricingConfirm.cancel": "Annulla", - "educationPricingConfirm.continue": "Continua senza sconto", - "educationPricingConfirm.description": "Il tuo piano {{planName}} {{billingPeriod}} non supporta lo sconto educativo. Solo il piano Professional annuale è idoneo.", - "educationPricingConfirm.title": "Sconto educativo non disponibile", + "educationPricingConfirm.cancel": "Mantieni il piano attuale", + "educationPricingConfirm.continue": "Passa a Professional annuale", + "educationPricingConfirm.description": "Lo sconto Education si applica solo al piano Professional annuale. Mantenendo il piano attuale, lo sconto non verrà incluso.", + "educationPricingConfirm.title": "Il piano selezionato non supporta lo sconto Education", "emailLabel": "La tua email attuale", "form.schoolName.placeholder": "Inserisci il nome ufficiale e completo della tua scuola", "form.schoolName.title": "Il Nome della tua Scuola", diff --git a/web/i18n/ja-JP/education.json b/web/i18n/ja-JP/education.json index 978b561ff0..9473f203cd 100644 --- a/web/i18n/ja-JP/education.json +++ b/web/i18n/ja-JP/education.json @@ -16,10 +16,10 @@ "currentSigned": "現在ログイン中のアカウントは", "educationPricingConfirm.billingPeriod.monthly": "月次", "educationPricingConfirm.billingPeriod.yearly": "年次", - "educationPricingConfirm.cancel": "キャンセル", - "educationPricingConfirm.continue": "割引なしで続行", - "educationPricingConfirm.description": "{{planName}} {{billingPeriod}} プランは教育割引に対応していません。Professional 年次プランのみが対象です。", - "educationPricingConfirm.title": "教育割引は利用できません", + "educationPricingConfirm.cancel": "現在のプランを維持", + "educationPricingConfirm.continue": "Professional 年間プランに切り替える", + "educationPricingConfirm.description": "教育割引は Professional 年間プランにのみ適用されます。現在のプランを維持すると、割引は適用されません。", + "educationPricingConfirm.title": "選択したプランは教育割引に対応していません", "emailLabel": "現在のメールアドレス", "form.schoolName.placeholder": "学校の正式名称(省略不可)を入力してください。", "form.schoolName.title": "学校名", diff --git a/web/i18n/ko-KR/education.json b/web/i18n/ko-KR/education.json index c7db9a99b7..1370265ae4 100644 --- a/web/i18n/ko-KR/education.json +++ b/web/i18n/ko-KR/education.json @@ -16,10 +16,10 @@ "currentSigned": "현재 로그인 중입니다", "educationPricingConfirm.billingPeriod.monthly": "월간", "educationPricingConfirm.billingPeriod.yearly": "연간", - "educationPricingConfirm.cancel": "취소", - "educationPricingConfirm.continue": "할인 없이 계속", - "educationPricingConfirm.description": "{{planName}} {{billingPeriod}} 플랜은 교육 할인을 지원하지 않습니다. Professional 연간 플랜만 자격이 있습니다.", - "educationPricingConfirm.title": "교육 할인 불가", + "educationPricingConfirm.cancel": "현재 플랜 유지", + "educationPricingConfirm.continue": "Professional 연간으로 전환", + "educationPricingConfirm.description": "교육 할인은 Professional 연간 플랜에만 적용됩니다. 현재 플랜을 유지하면 할인이 포함되지 않습니다.", + "educationPricingConfirm.title": "선택한 플랜은 교육 할인을 지원하지 않습니다", "emailLabel": "현재 이메일", "form.schoolName.placeholder": "귀하의 학교의 공식 약어가 아닌 전체 이름을 입력하세요.", "form.schoolName.title": "당신의 학교 이름", diff --git a/web/i18n/nl-NL/education.json b/web/i18n/nl-NL/education.json index 6bf16ef619..4a6d14bf0e 100644 --- a/web/i18n/nl-NL/education.json +++ b/web/i18n/nl-NL/education.json @@ -16,10 +16,10 @@ "currentSigned": "CURRENTLY SIGNED IN AS", "educationPricingConfirm.billingPeriod.monthly": "maandelijks", "educationPricingConfirm.billingPeriod.yearly": "jaarlijks", - "educationPricingConfirm.cancel": "Annuleren", - "educationPricingConfirm.continue": "Doorgaan zonder korting", - "educationPricingConfirm.description": "Uw {{planName}} {{billingPeriod}} abonnement ondersteunt de onderwijskorting niet. Alleen het jaarlijkse Professional abonnement komt in aanmerking.", - "educationPricingConfirm.title": "Onderwijskorting niet beschikbaar", + "educationPricingConfirm.cancel": "Huidig abonnement behouden", + "educationPricingConfirm.continue": "Overschakelen naar Professional jaarlijks", + "educationPricingConfirm.description": "De onderwijskorting is alleen van toepassing op het jaarlijkse Professional-abonnement. Als u uw huidige abonnement behoudt, is de korting niet inbegrepen.", + "educationPricingConfirm.title": "Uw geselecteerde abonnement ondersteunt de onderwijskorting niet", "emailLabel": "Your current email", "form.schoolName.placeholder": "Enter the official, unabbreviated name of your school", "form.schoolName.title": "Your School Name", diff --git a/web/i18n/pl-PL/education.json b/web/i18n/pl-PL/education.json index cb71de4572..139d7912dd 100644 --- a/web/i18n/pl-PL/education.json +++ b/web/i18n/pl-PL/education.json @@ -16,10 +16,10 @@ "currentSigned": "AKTUALNIE ZALOGOWANY JAKO", "educationPricingConfirm.billingPeriod.monthly": "miesięcznie", "educationPricingConfirm.billingPeriod.yearly": "rocznie", - "educationPricingConfirm.cancel": "Anuluj", - "educationPricingConfirm.continue": "Kontynuuj bez rabatu", - "educationPricingConfirm.description": "Twój plan {{planName}} {{billingPeriod}} nie obsługuje rabatu edukacyjnego. Tylko roczny plan Professional jest uprawniony.", - "educationPricingConfirm.title": "Rabat edukacyjny niedostępny", + "educationPricingConfirm.cancel": "Zachowaj obecny plan", + "educationPricingConfirm.continue": "Przełącz na Professional roczny", + "educationPricingConfirm.description": "Zniżka edukacyjna dotyczy tylko rocznego planu Professional. Pozostanie przy obecnym planie nie obejmie zniżki.", + "educationPricingConfirm.title": "Wybrany plan nie obsługuje zniżki edukacyjnej", "emailLabel": "Twój aktualny email", "form.schoolName.placeholder": "Wpisz oficjalną, pełną nazwę swojej szkoły", "form.schoolName.title": "Nazwa Twojej Szkoły", diff --git a/web/i18n/pt-BR/education.json b/web/i18n/pt-BR/education.json index c6929f5840..9441542015 100644 --- a/web/i18n/pt-BR/education.json +++ b/web/i18n/pt-BR/education.json @@ -16,10 +16,10 @@ "currentSigned": "ATUALMENTE CONECTADO COMO", "educationPricingConfirm.billingPeriod.monthly": "mensal", "educationPricingConfirm.billingPeriod.yearly": "anual", - "educationPricingConfirm.cancel": "Cancelar", - "educationPricingConfirm.continue": "Continuar sem desconto", - "educationPricingConfirm.description": "Seu plano {{planName}} {{billingPeriod}} não suporta o desconto educacional. Apenas o plano Professional anual é elegível.", - "educationPricingConfirm.title": "Desconto educacional não disponível", + "educationPricingConfirm.cancel": "Manter plano atual", + "educationPricingConfirm.continue": "Mudar para Professional anual", + "educationPricingConfirm.description": "O desconto educacional se aplica apenas ao plano Professional anual. Manter seu plano atual não incluirá o desconto.", + "educationPricingConfirm.title": "O plano selecionado não aceita o desconto educacional", "emailLabel": "Seu e-mail atual", "form.schoolName.placeholder": "Digite o nome oficial e não abreviado da sua escola", "form.schoolName.title": "O nome da sua escola", diff --git a/web/i18n/ro-RO/education.json b/web/i18n/ro-RO/education.json index 61d257f08b..a361ec2bbe 100644 --- a/web/i18n/ro-RO/education.json +++ b/web/i18n/ro-RO/education.json @@ -16,10 +16,10 @@ "currentSigned": "CONEXIUNE ÎN PREZENT CA", "educationPricingConfirm.billingPeriod.monthly": "lunar", "educationPricingConfirm.billingPeriod.yearly": "anual", - "educationPricingConfirm.cancel": "Anulează", - "educationPricingConfirm.continue": "Continuă fără reducere", - "educationPricingConfirm.description": "Planul tău {{planName}} {{billingPeriod}} nu suportă reducerea educațională. Doar planul Professional anual este eligibil.", - "educationPricingConfirm.title": "Reducerea educațională nu este disponibilă", + "educationPricingConfirm.cancel": "Păstrează planul curent", + "educationPricingConfirm.continue": "Treci la Professional anual", + "educationPricingConfirm.description": "Reducerea educațională se aplică doar planului Professional anual. Dacă păstrezi planul curent, reducerea nu va fi inclusă.", + "educationPricingConfirm.title": "Planul selectat nu acceptă reducerea educațională", "emailLabel": "Emailul tău curent", "form.schoolName.placeholder": "Introduceți numele oficial, neabbreviat al școlii dumneavoastră", "form.schoolName.title": "Numele Școlii Tale", diff --git a/web/i18n/ru-RU/education.json b/web/i18n/ru-RU/education.json index ce9300745f..58534dd57e 100644 --- a/web/i18n/ru-RU/education.json +++ b/web/i18n/ru-RU/education.json @@ -16,10 +16,10 @@ "currentSigned": "В ДАННЫЙ МОМЕНТ ВХОД В ПРОФИЛЬ КАК", "educationPricingConfirm.billingPeriod.monthly": "ежемесячно", "educationPricingConfirm.billingPeriod.yearly": "ежегодно", - "educationPricingConfirm.cancel": "Отмена", - "educationPricingConfirm.continue": "Продолжить без скидки", - "educationPricingConfirm.description": "Ваш план {{planName}} {{billingPeriod}} не поддерживает образовательную скидку. Только годовой план Professional имеет право на скидку.", - "educationPricingConfirm.title": "Образовательная скидка недоступна", + "educationPricingConfirm.cancel": "Оставить текущий план", + "educationPricingConfirm.continue": "Перейти на Professional годовой", + "educationPricingConfirm.description": "Образовательная скидка применяется только к годовому плану Professional. Если оставить текущий план, скидка не будет включена.", + "educationPricingConfirm.title": "Выбранный план не поддерживает образовательную скидку", "emailLabel": "Ваш текущий адрес электронной почты", "form.schoolName.placeholder": "Введите официальное, полное название вашей школы", "form.schoolName.title": "Название вашей школы", diff --git a/web/i18n/sl-SI/education.json b/web/i18n/sl-SI/education.json index 94abe1f58d..7855bd3e74 100644 --- a/web/i18n/sl-SI/education.json +++ b/web/i18n/sl-SI/education.json @@ -16,10 +16,10 @@ "currentSigned": "Trenutno prijavljen kot", "educationPricingConfirm.billingPeriod.monthly": "mesečno", "educationPricingConfirm.billingPeriod.yearly": "letno", - "educationPricingConfirm.cancel": "Prekliči", - "educationPricingConfirm.continue": "Nadaljuj brez popusta", - "educationPricingConfirm.description": "Vaš načrt {{planName}} {{billingPeriod}} ne podpira izobraževalnega popusta. Do popusta je upravičen samo letni načrt Professional.", - "educationPricingConfirm.title": "Izobraževalni popust ni na voljo", + "educationPricingConfirm.cancel": "Obdrži trenutni paket", + "educationPricingConfirm.continue": "Preklopi na letni Professional", + "educationPricingConfirm.description": "Izobraževalni popust velja samo za letni paket Professional. Če obdržite trenutni paket, popust ne bo vključen.", + "educationPricingConfirm.title": "Izbrani paket ne podpira izobraževalnega popusta", "emailLabel": "Vaš trenutni elektronski naslov", "form.schoolName.placeholder": "Vpišite uradno, neokrnjeno ime vaše šole", "form.schoolName.title": "Ime vaše šole", diff --git a/web/i18n/th-TH/education.json b/web/i18n/th-TH/education.json index b6b50a9181..830440802a 100644 --- a/web/i18n/th-TH/education.json +++ b/web/i18n/th-TH/education.json @@ -16,10 +16,10 @@ "currentSigned": "ลงชื่อเข้าใช้ในฐานะ", "educationPricingConfirm.billingPeriod.monthly": "รายเดือน", "educationPricingConfirm.billingPeriod.yearly": "รายปี", - "educationPricingConfirm.cancel": "ยกเลิก", - "educationPricingConfirm.continue": "ดำเนินการต่อโดยไม่มีส่วนลด", - "educationPricingConfirm.description": "แผน {{planName}} {{billingPeriod}} ของคุณไม่รองรับส่วนลดการศึกษา เฉพาะแผน Professional รายปีเท่านั้นที่มีสิทธิ์", - "educationPricingConfirm.title": "ส่วนลดการศึกษาไม่พร้อมใช้งาน", + "educationPricingConfirm.cancel": "ใช้แผนปัจจุบันต่อ", + "educationPricingConfirm.continue": "เปลี่ยนเป็น Professional รายปี", + "educationPricingConfirm.description": "ส่วนลดการศึกษาใช้ได้เฉพาะกับแผน Professional รายปีเท่านั้น หากใช้แผนปัจจุบันต่อ จะไม่มีส่วนลดนี้รวมอยู่ด้วย", + "educationPricingConfirm.title": "แผนที่คุณเลือกไม่รองรับส่วนลดการศึกษา", "emailLabel": "อีเมลปัจจุบันของคุณ", "form.schoolName.placeholder": "กรุณาใส่ชื่อของโรงเรียนอย่างเป็นทางการที่ไม่มีการย่อ", "form.schoolName.title": "ชื่อโรงเรียนของคุณ", diff --git a/web/i18n/tr-TR/education.json b/web/i18n/tr-TR/education.json index 61e03379b4..d3b77aad33 100644 --- a/web/i18n/tr-TR/education.json +++ b/web/i18n/tr-TR/education.json @@ -16,10 +16,10 @@ "currentSigned": "ŞU ANDA GİRİŞ YAPILDIĞI KİŞİ", "educationPricingConfirm.billingPeriod.monthly": "aylık", "educationPricingConfirm.billingPeriod.yearly": "yıllık", - "educationPricingConfirm.cancel": "İptal", - "educationPricingConfirm.continue": "İndirim olmadan devam et", - "educationPricingConfirm.description": "{{planName}} {{billingPeriod}} planınız eğitim indirimini desteklemiyor. Yalnızca yıllık Professional planı uygundur.", - "educationPricingConfirm.title": "Eğitim indirimi mevcut değil", + "educationPricingConfirm.cancel": "Mevcut planı koru", + "educationPricingConfirm.continue": "Professional yıllık plana geç", + "educationPricingConfirm.description": "Eğitim indirimi yalnızca yıllık Professional planı için geçerlidir. Mevcut planınızı korursanız indirim dahil edilmez.", + "educationPricingConfirm.title": "Seçtiğiniz plan eğitim indirimini desteklemiyor", "emailLabel": "Şu anki e-posta adresin", "form.schoolName.placeholder": "Okulunuzun resmi, kısaltılmamış adını girin", "form.schoolName.title": "Okulunuzun Adı", diff --git a/web/i18n/uk-UA/education.json b/web/i18n/uk-UA/education.json index d0cf4a77de..6cfa324666 100644 --- a/web/i18n/uk-UA/education.json +++ b/web/i18n/uk-UA/education.json @@ -16,10 +16,10 @@ "currentSigned": "В даний момент ви підписані як", "educationPricingConfirm.billingPeriod.monthly": "щомісячно", "educationPricingConfirm.billingPeriod.yearly": "щорічно", - "educationPricingConfirm.cancel": "Скасувати", - "educationPricingConfirm.continue": "Продовжити без знижки", - "educationPricingConfirm.description": "Ваш план {{planName}} {{billingPeriod}} не підтримує освітню знижку. Лише річний план Professional має право на знижку.", - "educationPricingConfirm.title": "Освітня знижка недоступна", + "educationPricingConfirm.cancel": "Залишити поточний план", + "educationPricingConfirm.continue": "Перейти на Professional річний", + "educationPricingConfirm.description": "Освітня знижка застосовується лише до річного плану Professional. Якщо залишити поточний план, знижку не буде включено.", + "educationPricingConfirm.title": "Вибраний план не підтримує освітню знижку", "emailLabel": "Ваш поточний електронний лист", "form.schoolName.placeholder": "Введіть офіційну, повну назву вашої школи", "form.schoolName.title": "Ваша назва школи", diff --git a/web/i18n/vi-VN/education.json b/web/i18n/vi-VN/education.json index 2edc6965a1..65429bcd8a 100644 --- a/web/i18n/vi-VN/education.json +++ b/web/i18n/vi-VN/education.json @@ -16,10 +16,10 @@ "currentSigned": "HIỆN ĐANG ĐĂNG NHẬP VÀO", "educationPricingConfirm.billingPeriod.monthly": "hàng tháng", "educationPricingConfirm.billingPeriod.yearly": "hàng năm", - "educationPricingConfirm.cancel": "Hủy", - "educationPricingConfirm.continue": "Tiếp tục không có giảm giá", - "educationPricingConfirm.description": "Gói {{planName}} {{billingPeriod}} của bạn không hỗ trợ giảm giá giáo dục. Chỉ gói Professional hàng năm mới được áp dụng.", - "educationPricingConfirm.title": "Giảm giá giáo dục không khả dụng", + "educationPricingConfirm.cancel": "Giữ gói hiện tại", + "educationPricingConfirm.continue": "Chuyển sang Professional hằng năm", + "educationPricingConfirm.description": "Giảm giá giáo dục chỉ áp dụng cho gói Professional hằng năm. Nếu giữ gói hiện tại, giảm giá sẽ không được áp dụng.", + "educationPricingConfirm.title": "Gói bạn chọn không hỗ trợ giảm giá giáo dục", "emailLabel": "Email hiện tại của bạn", "form.schoolName.placeholder": "Nhập tên chính thức, không viết tắt của trường bạn", "form.schoolName.title": "Tên Trường Của Bạn", diff --git a/web/i18n/zh-Hans/education.json b/web/i18n/zh-Hans/education.json index 657d265424..f82c9b3405 100644 --- a/web/i18n/zh-Hans/education.json +++ b/web/i18n/zh-Hans/education.json @@ -16,10 +16,10 @@ "currentSigned": "您当前登录的账户是", "educationPricingConfirm.billingPeriod.monthly": "月付", "educationPricingConfirm.billingPeriod.yearly": "年付", - "educationPricingConfirm.cancel": "取消", - "educationPricingConfirm.continue": "不使用优惠继续", - "educationPricingConfirm.description": "你的 {{planName}} 计划{{billingPeriod}}不支持教育优惠。只有 Professional 的年付计划符合条件。", - "educationPricingConfirm.title": "教育优惠不适用于该计划", + "educationPricingConfirm.cancel": "保留当前计划", + "educationPricingConfirm.continue": "切换到 Professional 年付", + "educationPricingConfirm.description": "教育优惠仅适用于 Professional 年付计划。保留当前计划将不包含该优惠。", + "educationPricingConfirm.title": "你选择的计划不支持教育优惠", "emailLabel": "您当前的邮箱", "form.schoolName.placeholder": "请输入您的学校的官方全称(不得缩写)", "form.schoolName.title": "您的学校名称", diff --git a/web/i18n/zh-Hant/education.json b/web/i18n/zh-Hant/education.json index 7447470a4c..76ce672e14 100644 --- a/web/i18n/zh-Hant/education.json +++ b/web/i18n/zh-Hant/education.json @@ -16,10 +16,10 @@ "currentSigned": "當前以以下身份登入", "educationPricingConfirm.billingPeriod.monthly": "月付", "educationPricingConfirm.billingPeriod.yearly": "年付", - "educationPricingConfirm.cancel": "取消", - "educationPricingConfirm.continue": "不使用優惠繼續", - "educationPricingConfirm.description": "你的 {{planName}} 方案{{billingPeriod}}不支援教育優惠。只有 Professional 的年付方案符合資格。", - "educationPricingConfirm.title": "教育優惠不適用於此方案", + "educationPricingConfirm.cancel": "保留目前方案", + "educationPricingConfirm.continue": "切換到 Professional 年付", + "educationPricingConfirm.description": "教育優惠僅適用於 Professional 年付方案。保留目前方案將不包含此優惠。", + "educationPricingConfirm.title": "你選擇的方案不支援教育優惠", "emailLabel": "您當前的電子郵件", "form.schoolName.placeholder": "請輸入您學校的正式全名", "form.schoolName.title": "你的學校名稱", diff --git a/web/service/client.spec.ts b/web/service/client.spec.ts index 57ca8765e7..641cbda7c6 100644 --- a/web/service/client.spec.ts +++ b/web/service/client.spec.ts @@ -187,15 +187,20 @@ describe('consoleQuery tag mutation defaults', () => { queryClient.setQueryData(appListKey, [targetTag, otherTag]) queryClient.setQueryData(knowledgeListKey, [knowledgeTag]) + const updatedTag = createTag({ + ...targetTag, + name: 'After', + binding_count: 5, + }) const mutationOptions = consoleQuery.tags.update.mutationOptions() await mutationOptions.onSuccess?.( - undefined, + updatedTag, { params: { tagId: targetTag.id, }, body: { - name: 'After', + name: 'Ignored Client Name', }, }, undefined, @@ -203,10 +208,7 @@ describe('consoleQuery tag mutation defaults', () => { ) expect(queryClient.getQueryData(appListKey)).toEqual([ - { - ...targetTag, - name: 'After', - }, + updatedTag, otherTag, ]) expect(queryClient.getQueryData(knowledgeListKey)).toEqual([knowledgeTag]) diff --git a/web/service/client.ts b/web/service/client.ts index 1c33423295..984779e2b2 100644 --- a/web/service/client.ts +++ b/web/service/client.ts @@ -108,16 +108,13 @@ export const consoleQuery = createTanstackQueryUtils(consoleClient, { }, update: { mutationOptions: { - onSuccess: (_data, variables, _onMutateResult, context) => { + onSuccess: (updatedTag, variables, _onMutateResult, context) => { context.client.setQueriesData( { queryKey: consoleQuery.tags.list.key({ type: 'query' }), }, (oldTags: Tag[] | undefined) => oldTags?.map(tag => tag.id === variables.params.tagId - ? { - ...tag, - name: variables.body.name, - } + ? updatedTag : tag), ) }, diff --git a/web/service/common.ts b/web/service/common.ts index d2fcf8c823..2b72370086 100644 --- a/web/service/common.ts +++ b/web/service/common.ts @@ -339,7 +339,13 @@ export const uploadRemoteFileInfo = (url: string, isPublic?: boolean, silent?: b export const sendEMailLoginCode = (email: string, language = 'en-US'): Promise => post('/email-code-login', { body: { email, language } }) -export const emailLoginWithCode = (data: { email: string, code: string, token: string, language: string }): Promise => +export const emailLoginWithCode = (data: { + email: string + code: string + token: string + language: string + timezone?: string +}): Promise => post('/email-code-login/validity', { body: data }) export const sendResetPasswordCode = (email: string, language = 'en-US'): Promise => diff --git a/web/service/use-common.ts b/web/service/use-common.ts index 0154be09ff..9d17585bfa 100644 --- a/web/service/use-common.ts +++ b/web/service/use-common.ts @@ -178,7 +178,13 @@ export type MailRegisterResponse = { result: string, data: {} } export const useMailRegister = () => { return useMutation({ mutationKey: [NAME_SPACE, 'mail-register'], - mutationFn: (body: { token: string, new_password: string, password_confirm: string }) => { + mutationFn: (body: { + token: string + new_password: string + password_confirm: string + language?: string + timezone?: string + }) => { return post('/email-register', { body }) }, }) diff --git a/web/utils/timezone.ts b/web/utils/timezone.ts index e854ae7d5a..38e2e04d50 100644 --- a/web/utils/timezone.ts +++ b/web/utils/timezone.ts @@ -5,3 +5,10 @@ type Item = { name: string } export const timezones: Item[] = tz + +export const getBrowserTimezone = () => { + if (typeof Intl === 'undefined') + return undefined + + return Intl.DateTimeFormat().resolvedOptions().timeZone || undefined +}