diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index 8db964cc27..7e80183e0c 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -222,6 +222,14 @@ export const useUpdateModelList = () => { return updateModelList } +export const useInvalidateDefaultModel = () => { + const queryClient = useQueryClient() + + return useCallback((type: ModelTypeEnum) => { + queryClient.invalidateQueries({ queryKey: commonQueryKeys.defaultModel(type) }) + }, [queryClient]) +} + export const useAnthropicBuyQuota = () => { const [loading, setLoading] = useState(false) diff --git a/web/app/components/header/account-setting/model-provider-page/index.spec.tsx b/web/app/components/header/account-setting/model-provider-page/index.spec.tsx index 50e86b6fd8..7f1ab81290 100644 --- a/web/app/components/header/account-setting/model-provider-page/index.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/index.spec.tsx @@ -161,7 +161,7 @@ describe('ModelProviderPage', () => { }) describe('system model config status', () => { - it('should show no-provider warning when no configured providers exist', () => { + it('should not show top warning when no configured providers exist (empty state card handles it)', () => { mockProviders.splice(0, mockProviders.length, { provider: 'anthropic', label: { en_US: 'Anthropic' }, @@ -174,7 +174,8 @@ describe('ModelProviderPage', () => { }) render() - expect(screen.getByText('common.modelProvider.noProviderInstalled')).toBeInTheDocument() + expect(screen.queryByText('common.modelProvider.noProviderInstalled')).not.toBeInTheDocument() + expect(screen.getByText('common.modelProvider.emptyProviderTitle')).toBeInTheDocument() }) it('should show none-configured warning when providers exist but no default models set', () => { diff --git a/web/app/components/header/account-setting/model-provider-page/index.tsx b/web/app/components/header/account-setting/model-provider-page/index.tsx index ec95dde24f..5f5e05b51a 100644 --- a/web/app/components/header/account-setting/model-provider-page/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/index.tsx @@ -23,12 +23,6 @@ import SystemModelSelector from './system-model-selector' type SystemModelConfigStatus = 'no-provider' | 'none-configured' | 'partially-configured' | 'fully-configured' -const WARNING_TEXT_KEYS = { - 'no-provider': 'modelProvider.noProviderInstalled', - 'none-configured': 'modelProvider.noneConfigured', - 'partially-configured': 'modelProvider.notConfigured', -} as const - type Props = { searchText: string } @@ -94,10 +88,13 @@ const ModelProviderPage = ({ searchText }: Props) => { return 'partially-configured' return 'fully-configured' }, [configuredProviders, textGenerationDefaultModel, embeddingsDefaultModel, rerankDefaultModel, speech2textDefaultModel, ttsDefaultModel]) - const showWarning = !isDefaultModelLoading && systemModelConfigStatus !== 'fully-configured' - const warningTextKey = systemModelConfigStatus !== 'fully-configured' - ? WARNING_TEXT_KEYS[systemModelConfigStatus] - : undefined + const warningTextKey + = systemModelConfigStatus === 'none-configured' + ? 'modelProvider.noneConfigured' + : systemModelConfigStatus === 'partially-configured' + ? 'modelProvider.notConfigured' + : null + const showWarning = !isDefaultModelLoading && !!warningTextKey const [filteredConfiguredProviders, filteredNotConfiguredProviders] = useMemo(() => { const filteredConfiguredProviders = configuredProviders.filter( diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx index 819bb71164..ca28698b64 100644 --- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx @@ -26,6 +26,7 @@ vi.mock('react-i18next', async () => { const mockNotify = vi.hoisted(() => vi.fn()) const mockUpdateModelList = vi.hoisted(() => vi.fn()) +const mockInvalidateDefaultModel = vi.hoisted(() => vi.fn()) const mockUpdateDefaultModel = vi.hoisted(() => vi.fn(() => Promise.resolve({ result: 'success' }))) let mockIsCurrentWorkspaceManager = true @@ -57,6 +58,7 @@ vi.mock('../hooks', () => ({ vi.fn(), ], useUpdateModelList: () => mockUpdateModelList, + useInvalidateDefaultModel: () => mockInvalidateDefaultModel, })) vi.mock('@/service/common', () => ({ @@ -144,6 +146,7 @@ describe('SystemModel', () => { type: 'success', message: 'Modified successfully', }) + expect(mockInvalidateDefaultModel).toHaveBeenCalledTimes(5) expect(mockUpdateModelList).toHaveBeenCalledTimes(5) }) }) diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx index 29c71e04fc..1476a2f03d 100644 --- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx @@ -19,6 +19,7 @@ import { useProviderContext } from '@/context/provider-context' import { updateDefaultModel } from '@/service/common' import { ModelTypeEnum } from '../declarations' import { + useInvalidateDefaultModel, useModelList, useSystemDefaultModelAndModelList, useUpdateModelList, @@ -48,6 +49,7 @@ const SystemModel: FC = ({ const { isCurrentWorkspaceManager } = useAppContext() const { textGenerationModelList } = useProviderContext() const updateModelList = useUpdateModelList() + const invalidateDefaultModel = useInvalidateDefaultModel() const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank) const { data: speech2textModelList } = useModelList(ModelTypeEnum.speech2text) @@ -106,18 +108,9 @@ const SystemModel: FC = ({ notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) setOpen(false) - changedModelTypes.forEach((modelType) => { - if (modelType === ModelTypeEnum.textGeneration) - updateModelList(modelType) - else if (modelType === ModelTypeEnum.textEmbedding) - updateModelList(modelType) - else if (modelType === ModelTypeEnum.rerank) - updateModelList(modelType) - else if (modelType === ModelTypeEnum.speech2text) - updateModelList(modelType) - else if (modelType === ModelTypeEnum.tts) - updateModelList(modelType) - }) + const allModelTypes = [ModelTypeEnum.textGeneration, ModelTypeEnum.textEmbedding, ModelTypeEnum.rerank, ModelTypeEnum.speech2text, ModelTypeEnum.tts] + allModelTypes.forEach(type => invalidateDefaultModel(type)) + changedModelTypes.forEach(type => updateModelList(type)) } } diff --git a/web/i18n/en-US/common.json b/web/i18n/en-US/common.json index ad3b3c7e97..20a3ea7cc0 100644 --- a/web/i18n/en-US/common.json +++ b/web/i18n/en-US/common.json @@ -390,7 +390,6 @@ "modelProvider.models": "Models", "modelProvider.modelsNum": "{{num}} Models", "modelProvider.noModelFound": "No model found for {{model}}", - "modelProvider.noProviderInstalled": "No model provider installed. Install one to configure system models.", "modelProvider.noneConfigured": "Configure a default system model to run applications", "modelProvider.notConfigured": "The system model has not yet been fully configured", "modelProvider.parameters": "PARAMETERS", diff --git a/web/i18n/zh-Hans/common.json b/web/i18n/zh-Hans/common.json index 5e59133595..150cb2ef91 100644 --- a/web/i18n/zh-Hans/common.json +++ b/web/i18n/zh-Hans/common.json @@ -390,7 +390,6 @@ "modelProvider.models": "模型列表", "modelProvider.modelsNum": "{{num}} 个模型", "modelProvider.noModelFound": "找不到模型 {{model}}", - "modelProvider.noProviderInstalled": "尚未安装模型供应商,请先安装以配置系统模型。", "modelProvider.noneConfigured": "配置默认系统模型以运行应用", "modelProvider.notConfigured": "系统模型尚未完全配置", "modelProvider.parameters": "参数",