Merge remote-tracking branch 'origin/main' into feat/ui-onboarding-rewrite

This commit is contained in:
yyh
2026-06-03 14:27:46 +08:00
15 changed files with 99 additions and 80 deletions

View File

@ -193,7 +193,7 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, tenant_id: str, provider):
def get(self, tenant_id: str, provider: str):
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)

View File

@ -278,7 +278,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(
@ -294,7 +294,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@ -307,7 +307,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {})
@ -325,7 +325,7 @@ class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@ -350,7 +350,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@ -372,7 +372,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
user, tenant_id = current_account_with_tenant()
# Optional list of credential IDs to include even if visibility would hide them
# (used when a workflow/agent node still references another member's only_me credential).
@ -393,7 +393,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider):
def get(self, provider: str):
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@ -793,7 +793,7 @@ class ToolPluginOAuthApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
@ -831,7 +831,7 @@ class ToolPluginOAuthApi(Resource):
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
class ToolOAuthCallback(Resource):
@setup_required
def get(self, provider):
def get(self, provider: str):
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
@ -888,7 +888,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
return BuiltinToolManageService.set_default_provider(
@ -920,7 +920,7 @@ class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
@ -929,7 +929,7 @@ class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
@ -941,7 +941,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
@ -955,7 +955,7 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
user, tenant_id = current_account_with_tenant()
include_credential_ids = request.args.getlist("include_credential_ids") or [
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
@ -1166,7 +1166,7 @@ class ToolMCPDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id):
def get(self, provider_id: str):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
@ -1195,7 +1195,7 @@ class ToolMCPUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id):
def get(self, provider_id: str):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)

View File

@ -77,7 +77,7 @@ class TriggerProviderIconApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
@ -103,7 +103,7 @@ class TriggerProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
"""Get info for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -119,7 +119,7 @@ class TriggerSubscriptionListApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
"""List all trigger subscriptions for the current tenant's provider"""
user = current_user
assert isinstance(user, Account)
@ -149,7 +149,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
"""Add a new subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -178,7 +178,7 @@ class TriggerSubscriptionBuilderGetApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
def get(self, provider: str, subscription_builder_id: str):
"""Get a subscription instance for a trigger provider"""
return jsonable_encoder(
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
@ -194,7 +194,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
def post(self, provider: str, subscription_builder_id: str):
"""Verify and update a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -226,7 +226,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
def post(self, provider: str, subscription_builder_id: str):
"""Update a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -260,7 +260,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def get(self, provider, subscription_builder_id):
def get(self, provider: str, subscription_builder_id: str):
"""Get the request logs for a subscription instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -283,7 +283,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_builder_id):
def post(self, provider: str, subscription_builder_id: str):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
@ -407,7 +407,7 @@ class TriggerOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
"""Initiate OAuth authorization flow for a trigger provider"""
user = current_user
assert isinstance(user, Account)
@ -489,7 +489,7 @@ class TriggerOAuthAuthorizeApi(Resource):
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
class TriggerOAuthCallbackApi(Resource):
@setup_required
def get(self, provider):
def get(self, provider: str):
"""Handle OAuth callback for trigger provider"""
context_id = request.cookies.get("context_id")
if not context_id:
@ -557,7 +557,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider):
def get(self, provider: str):
"""Get OAuth client configuration for a provider"""
user = current_user
assert user.current_tenant_id is not None
@ -603,7 +603,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider):
def post(self, provider: str):
"""Configure custom OAuth client for a provider"""
user = current_user
assert user.current_tenant_id is not None
@ -629,7 +629,7 @@ class TriggerOAuthClientManageApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider):
def delete(self, provider: str):
"""Remove custom OAuth client configuration"""
user = current_user
assert user.current_tenant_id is not None
@ -657,7 +657,7 @@ class TriggerSubscriptionVerifyApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
def post(self, provider, subscription_id):
def post(self, provider: str, subscription_id: str):
"""Verify credentials for an existing subscription (edit mode only)"""
user = current_user
assert user.current_tenant_id is not None

View File

@ -4,6 +4,7 @@ import logging
import os
import sys
from logging.handlers import RotatingFileHandler
from typing import override
from configs import dify_config
from dify_app import DifyApp
@ -92,6 +93,7 @@ def _apply_timezone(handlers: list[logging.Handler]):
class _TextFormatter(logging.Formatter):
"""Text formatter that ensures trace_id and req_id are always present."""
@override
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
@ -116,6 +118,7 @@ def get_request_id() -> str:
class RequestIdFilter(logging.Filter):
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
@override
def filter(self, record: logging.LogRecord) -> bool:
from core.logging.context import get_request_id as _get_request_id
from core.logging.context import get_trace_id as _get_trace_id
@ -128,6 +131,7 @@ class RequestIdFilter(logging.Filter):
class RequestIdFormatter(logging.Formatter):
"""Deprecated: Use _TextFormatter instead."""
@override
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""

View File

@ -1,5 +1,5 @@
import json
from typing import cast
from typing import cast, override
import flask_login
from flask import Request, Response, request
@ -28,6 +28,7 @@ class DifyLoginManager(flask_login.LoginManager):
Flask-Login's broader callback contract.
"""
@override
def unauthorized(self) -> Response:
"""Return the registered unauthorized handler result as a Flask `Response`."""
return cast(Response, super().unauthorized())

View File

@ -9,7 +9,7 @@ import logging
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from typing import Any, override
from sqlalchemy.orm import sessionmaker
@ -128,6 +128,7 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
logger.debug("LogstoreAPIWorkflowNodeExecutionRepository.__init__: initializing")
self.logstore_client = AliyunLogStore()
@override
def get_node_last_execution(
self,
tenant_id: str,
@ -160,12 +161,12 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
# Use PG protocol with SQL query (get latest version of each record)
sql_query = f"""
SELECT * FROM (
SELECT *,
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_id = '{escaped_workflow_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_id = '{escaped_workflow_id}'
AND node_id = '{escaped_node_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
@ -236,6 +237,7 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
logger.exception("Failed to get node last execution from LogStore")
raise
@override
def get_executions_by_workflow_run(
self,
tenant_id: str,
@ -265,11 +267,11 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
# Use PG protocol with SQL query (get latest version of each record)
sql_query = f"""
SELECT * FROM (
SELECT *,
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
WHERE tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND workflow_run_id = '{escaped_workflow_run_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
@ -340,6 +342,7 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
logger.exception("Failed to get executions by workflow run from LogStore")
raise
@override
def get_execution_by_id(
self,
execution_id: str,
@ -365,7 +368,7 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
sql_query = f"""
SELECT * FROM (
SELECT *,
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_node_execution_logstore}"
WHERE id = '{escaped_execution_id}' {tenant_filter} AND __time__ > 0

View File

@ -18,7 +18,7 @@ import os
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any, cast
from typing import Any, cast, override
from sqlalchemy.orm import sessionmaker
@ -162,6 +162,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# Set to False for new deployments without legacy data in PostgreSQL
self._enable_dual_read = os.environ.get("LOGSTORE_DUAL_READ_ENABLED", "true").lower() == "true"
@override
def get_paginated_workflow_runs(
self,
tenant_id: str,
@ -257,6 +258,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.exception("Failed to get paginated workflow runs from LogStore")
raise
@override
def get_workflow_run_by_id(
self,
tenant_id: str,
@ -282,12 +284,12 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# Use PG protocol with SQL query (get latest version of record)
sql_query = f"""
SELECT * FROM (
SELECT *,
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{escaped_run_id}'
AND tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
WHERE id = '{escaped_run_id}'
AND tenant_id = '{escaped_tenant_id}'
AND app_id = '{escaped_app_id}'
AND __time__ > 0
) AS subquery WHERE rn = 1
LIMIT 100
@ -364,6 +366,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
)
return session.scalar(stmt)
@override
def get_workflow_run_by_id_without_tenant(
self,
run_id: str,
@ -384,7 +387,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
# Use PG protocol with SQL query (get latest version of record)
sql_query = f"""
SELECT * FROM (
SELECT *,
SELECT *,
ROW_NUMBER() OVER (PARTITION BY id ORDER BY log_version DESC) as rn
FROM "{AliyunLogStore.workflow_execution_logstore}"
WHERE id = '{escaped_run_id}' AND __time__ > 0
@ -447,6 +450,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
stmt = select(WorkflowRun).where(WorkflowRun.id == run_id)
return session.scalar(stmt)
@override
def get_workflow_runs_count(
self,
tenant_id: str,
@ -594,6 +598,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.exception("Failed to get workflow runs count")
raise
@override
def get_daily_runs_statistics(
self,
tenant_id: str,
@ -652,6 +657,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.exception("Failed to get daily runs statistics")
raise
@override
def get_daily_terminals_statistics(
self,
tenant_id: str,
@ -712,6 +718,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.exception("Failed to get daily terminals statistics")
raise
@override
def get_daily_token_cost_statistics(
self,
tenant_id: str,
@ -772,6 +779,7 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository):
logger.exception("Failed to get daily token cost statistics")
raise
@override
def get_average_app_interaction_statistics(
self,
tenant_id: str,

View File

@ -2,6 +2,7 @@ import json
import logging
import os
import time
from typing import override
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@ -152,6 +153,7 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
return logstore_model
@override
def save(self, execution: WorkflowExecution) -> None:
"""
Save or update a WorkflowExecution domain entity to the logstore.

View File

@ -11,7 +11,7 @@ import os
import time
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from typing import Any, override
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@ -222,6 +222,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
return logstore_model
@override
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a NodeExecution domain entity to LogStore.
@ -271,6 +272,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
logger.exception("Failed to dual-write node execution to SQL database: id=%s", execution.id)
# Don't raise - LogStore write succeeded, SQL is just a backup
@override
def save_execution_data(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update the inputs, process_data, or outputs associated with a specific
@ -305,6 +307,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id)
# Don't raise - LogStore write succeeded, SQL is just a backup
@override
def get_by_workflow_execution(
self,
workflow_execution_id: str,

View File

@ -1,5 +1,6 @@
import logging
from collections.abc import Callable
from typing import override
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue
@ -14,6 +15,7 @@ logger = logging.getLogger(__name__)
class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``."""
@override
def wrapper[**P, R](
self,
tracer: Tracer,

View File

@ -1,5 +1,6 @@
import logging
from collections.abc import Callable
from typing import override
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer
from opentelemetry.util.types import AttributeValue
@ -13,6 +14,7 @@ logger = logging.getLogger(__name__)
class WorkflowAppRunnerHandler(SpanHandler):
"""Span handler for ``WorkflowAppRunner.run``."""
@override
def wrapper[**P, R](
self,
tracer: Tracer,

View File

@ -1,7 +1,7 @@
import contextlib
import logging
from collections.abc import Callable
from typing import Protocol, cast
from typing import Protocol, cast, override
import flask
from opentelemetry.instrumentation.celery import CeleryInstrumentor
@ -63,7 +63,8 @@ class ExceptionLoggingHandler(logging.Handler):
to maintain trace context consistency throughout the request lifecycle.
"""
def emit(self, record: logging.LogRecord):
@override
def emit(self, record: logging.LogRecord) -> None:
with contextlib.suppress(Exception):
if not record.exc_info:
return

View File

@ -2362,11 +2362,6 @@
"count": 2
}
},
"web/app/components/explore/create-app-modal/index.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/explore/try-app/tab.tsx": {
"erasable-syntax-only/enums": {
"count": 1

View File

@ -1,6 +1,6 @@
import type { CreateAppModalProps } from '../index'
import type { UsagePlanInfo } from '@/app/components/billing/type'
import { act, fireEvent, render, screen } from '@testing-library/react'
import { act, fireEvent, render, screen, waitFor, within } from '@testing-library/react'
import * as React from 'react'
import { createMockPlan, createMockPlanTotal, createMockPlanUsage } from '@/__mocks__/provider-context'
import { Plan } from '@/app/components/billing/type'
@ -107,13 +107,19 @@ const setup = async (overrides: Partial<CreateAppModalProps> = {}) => {
const getAppIconTrigger = (): HTMLElement => {
const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder')
const iconRow = nameInput.parentElement?.parentElement
const iconRow = nameInput.parentElement
const iconTrigger = iconRow?.firstElementChild
if (!(iconTrigger instanceof HTMLElement))
throw new Error('Failed to locate app icon trigger')
return iconTrigger
}
const openAppIconPicker = () => {
fireEvent.click(getAppIconTrigger())
return screen.getByRole('dialog', { name: 'app.iconPicker.emoji' })
}
describe('CreateAppModal', () => {
beforeEach(() => {
vi.clearAllMocks()
@ -322,13 +328,15 @@ describe('CreateAppModal', () => {
appIconUrl: 'https://example.com/icon.png',
})
fireEvent.click(getAppIconTrigger())
const pickerDialog = openAppIconPicker()
expect(screen.getByRole('button', { name: 'app.iconPicker.cancel' }))!.toBeInTheDocument()
expect(within(pickerDialog).getByRole('button', { name: 'app.iconPicker.cancel' }))!.toBeInTheDocument()
fireEvent.click(screen.getByRole('button', { name: 'app.iconPicker.cancel' }))
fireEvent.click(within(pickerDialog).getByRole('button', { name: 'app.iconPicker.cancel' }))
expect(screen.queryByRole('button', { name: 'app.iconPicker.cancel' })).not.toBeInTheDocument()
await waitFor(() => {
expect(screen.queryByRole('button', { name: 'app.iconPicker.cancel' })).not.toBeInTheDocument()
})
})
it('should update icon payload when selecting emoji and confirming', async () => {
@ -340,16 +348,11 @@ describe('CreateAppModal', () => {
appIconUrl: 'https://example.com/icon.png',
})
fireEvent.click(getAppIconTrigger())
const pickerDialog = openAppIconPicker()
const categoryLabel = screen.getByText('people')
const emojiGrid = categoryLabel.nextElementSibling
const clickableEmojiWrapper = emojiGrid?.firstElementChild
if (!(clickableEmojiWrapper instanceof HTMLElement))
throw new Error('Failed to locate emoji wrapper')
fireEvent.click(clickableEmojiWrapper)
fireEvent.click(within(pickerDialog).getByRole('button', { name: '😀' }))
fireEvent.click(screen.getByRole('button', { name: 'app.iconPicker.ok' }))
fireEvent.click(within(pickerDialog).getByRole('button', { name: 'app.iconPicker.ok' }))
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
await act(async () => {
@ -378,15 +381,10 @@ describe('CreateAppModal', () => {
appIconBackground: '#FFEAD5',
})
fireEvent.click(getAppIconTrigger())
const pickerDialog = openAppIconPicker()
const colorOption = Array.from(document.querySelectorAll('[style^="background:"]'))
.find(element => element.getAttribute('style')?.includes('#E4FBCC'))
if (!(colorOption instanceof HTMLElement) || !(colorOption.parentElement instanceof HTMLElement))
throw new Error('Failed to locate background color option')
fireEvent.click(colorOption.parentElement)
fireEvent.click(screen.getByRole('button', { name: 'app.iconPicker.ok' }))
fireEvent.click(within(pickerDialog).getByRole('button', { name: '#E4FBCC' }))
fireEvent.click(within(pickerDialog).getByRole('button', { name: 'app.iconPicker.ok' }))
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ }))
await act(async () => {

View File

@ -2,6 +2,7 @@
import type { AppIconType } from '@/types/app'
import { Button } from '@langgenius/dify-ui/button'
import { Dialog, DialogCloseButton, DialogContent, DialogTitle } from '@langgenius/dify-ui/dialog'
import { Input } from '@langgenius/dify-ui/input'
import { Kbd, KbdGroup } from '@langgenius/dify-ui/kbd'
import { Switch } from '@langgenius/dify-ui/switch'
import { Textarea } from '@langgenius/dify-ui/textarea'
@ -12,7 +13,6 @@ import * as React from 'react'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import AppIcon from '@/app/components/base/app-icon'
import Input from '@/app/components/base/input'
import AppsFull from '@/app/components/billing/apps-full-in-dialog'
import { useProviderContext } from '@/context/provider-context'
import { AppModeEnum } from '@/types/app'
@ -114,7 +114,7 @@ const CreateAppModal = ({
return (
<>
<Dialog open={show} onOpenChange={open => !open && onHide()} disablePointerDismissal>
<DialogContent className="px-8">
<DialogContent backdropProps={{ forceRender: true }} className="px-8">
<DialogCloseButton />
{isEditModal && (
<DialogTitle className="mb-9 text-xl leading-[30px] font-semibold text-text-primary">{t('editAppTitle', { ns: 'app' })}</DialogTitle>