diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index 8db964cc27..7e80183e0c 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -222,6 +222,14 @@ export const useUpdateModelList = () => { return updateModelList } +export const useInvalidateDefaultModel = () => { + const queryClient = useQueryClient() + + return useCallback((type: ModelTypeEnum) => { + queryClient.invalidateQueries({ queryKey: commonQueryKeys.defaultModel(type) }) + }, [queryClient]) +} + export const useAnthropicBuyQuota = () => { const [loading, setLoading] = useState(false) diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx index 819bb71164..ca28698b64 100644 --- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.spec.tsx @@ -26,6 +26,7 @@ vi.mock('react-i18next', async () => { const mockNotify = vi.hoisted(() => vi.fn()) const mockUpdateModelList = vi.hoisted(() => vi.fn()) +const mockInvalidateDefaultModel = vi.hoisted(() => vi.fn()) const mockUpdateDefaultModel = vi.hoisted(() => vi.fn(() => Promise.resolve({ result: 'success' }))) let mockIsCurrentWorkspaceManager = true @@ -57,6 +58,7 @@ vi.mock('../hooks', () => ({ vi.fn(), ], useUpdateModelList: () => mockUpdateModelList, + useInvalidateDefaultModel: () => mockInvalidateDefaultModel, })) vi.mock('@/service/common', () => ({ @@ -144,6 +146,7 @@ describe('SystemModel', () => { type: 'success', message: 'Modified successfully', }) + expect(mockInvalidateDefaultModel).toHaveBeenCalledTimes(5) expect(mockUpdateModelList).toHaveBeenCalledTimes(5) }) }) diff --git a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx index 29c71e04fc..1476a2f03d 100644 --- a/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx @@ -19,6 +19,7 @@ import { useProviderContext } from '@/context/provider-context' import { updateDefaultModel } from '@/service/common' import { ModelTypeEnum } from '../declarations' import { + useInvalidateDefaultModel, useModelList, useSystemDefaultModelAndModelList, useUpdateModelList, @@ -48,6 +49,7 @@ const SystemModel: FC = ({ const { isCurrentWorkspaceManager } = useAppContext() const { textGenerationModelList } = useProviderContext() const updateModelList = useUpdateModelList() + const invalidateDefaultModel = useInvalidateDefaultModel() const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank) const { data: speech2textModelList } = useModelList(ModelTypeEnum.speech2text) @@ -106,18 +108,9 @@ const SystemModel: FC = ({ notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) }) setOpen(false) - changedModelTypes.forEach((modelType) => { - if (modelType === ModelTypeEnum.textGeneration) - updateModelList(modelType) - else if (modelType === ModelTypeEnum.textEmbedding) - updateModelList(modelType) - else if (modelType === ModelTypeEnum.rerank) - updateModelList(modelType) - else if (modelType === ModelTypeEnum.speech2text) - updateModelList(modelType) - else if (modelType === ModelTypeEnum.tts) - updateModelList(modelType) - }) + const allModelTypes = [ModelTypeEnum.textGeneration, ModelTypeEnum.textEmbedding, ModelTypeEnum.rerank, ModelTypeEnum.speech2text, ModelTypeEnum.tts] + allModelTypes.forEach(type => invalidateDefaultModel(type)) + changedModelTypes.forEach(type => updateModelList(type)) } }