From 04e0ab7eda87eab82e53d3bf1bc026d43067fdcb Mon Sep 17 00:00:00 2001 From: yyh Date: Wed, 4 Mar 2026 21:55:34 +0800 Subject: [PATCH] refactor(web): migrate provider-added-card model list to oRPC query-driven state --- .../provider-added-card/index.spec.tsx | 83 ++++++++++---- .../provider-added-card/index.tsx | 105 ++++++++++++------ web/contract/console/model-providers.ts | 17 +++ web/contract/router.ts | 4 + 4 files changed, 152 insertions(+), 57 deletions(-) create mode 100644 web/contract/console/model-providers.ts diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx index 51c0ebce39..28b95c891c 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx @@ -1,17 +1,30 @@ -import type { ModelItem, ModelProvider } from '../declarations' +import type { ReactNode } from 'react' +import type { ModelProvider } from '../declarations' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' -import { fetchModelProviderModelList } from '@/service/common' import { ConfigurationMethodEnum } from '../declarations' import ProviderAddedCard from './index' let mockIsCurrentWorkspaceManager = true +const mockFetchModelProviderModels = vi.fn() +const mockQueryOptions = vi.fn(({ input, ...options }: { input: { params: { provider: string } }, enabled?: boolean }) => ({ + queryKey: ['console', 'modelProviders', 'models', input.params.provider], + queryFn: () => mockFetchModelProviderModels(input.params.provider), + ...options, +})) const mockEventEmitter = { useSubscription: vi.fn(), emit: vi.fn(), } -vi.mock('@/service/common', () => ({ - fetchModelProviderModelList: vi.fn(), +vi.mock('@/service/client', () => ({ + consoleQuery: { + modelProviders: { + models: { + queryOptions: (options: { input: { params: { provider: string } }, enabled?: boolean }) => mockQueryOptions(options), + }, + }, + }, })) vi.mock('@/context/app-context', () => ({ @@ -53,6 +66,21 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/model-auth' ManageCustomModelCredentials: () =>
, })) +const createTestQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { retry: false, gcTime: 0 }, + }, +}) + +const renderWithQueryClient = (node: ReactNode) => { + const queryClient = createTestQueryClient() + return render( + + {node} + , + ) +} + describe('ProviderAddedCard', () => { const mockProvider = { provider: 'langgenius/openai/openai', @@ -67,19 +95,21 @@ describe('ProviderAddedCard', () => { }) it('should render provider added card component', () => { - render() + renderWithQueryClient() expect(screen.getByTestId('provider-added-card')).toBeInTheDocument() expect(screen.getByTestId('provider-icon')).toBeInTheDocument() }) it('should open, refresh and collapse model list', async () => { - vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [{ model: 'gpt-4' }] } as unknown as { data: ModelItem[] }) - render() + mockFetchModelProviderModels.mockResolvedValue({ data: [{ model: 'gpt-4' }] }) + renderWithQueryClient() const showModelsBtn = screen.getByTestId('show-models-button') fireEvent.click(showModelsBtn) - expect(fetchModelProviderModelList).toHaveBeenCalledWith(`/workspaces/current/model-providers/${mockProvider.provider}/models`) + await waitFor(() => { + expect(mockFetchModelProviderModels).toHaveBeenCalledWith(mockProvider.provider) + }) expect(await screen.findByTestId('model-list')).toBeInTheDocument() // Test line 71-72: Opening when already fetched @@ -90,13 +120,13 @@ describe('ProviderAddedCard', () => { // Explicitly re-find and click to re-open fireEvent.click(screen.getByTestId('show-models-button')) expect(await screen.findByTestId('model-list')).toBeInTheDocument() - expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) // Should not fetch again + expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1) // Should not fetch again // Refresh list from ModelList const refreshBtn = screen.getByRole('button', { name: 'refresh list' }) fireEvent.click(refreshBtn) await waitFor(() => { - expect(fetchModelProviderModelList).toHaveBeenCalledTimes(2) + expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(2) }) }) @@ -105,18 +135,20 @@ describe('ProviderAddedCard', () => { const promise = new Promise((resolve) => { resolveOuter = resolve }) - vi.mocked(fetchModelProviderModelList).mockReturnValue(promise as unknown as ReturnType) + mockFetchModelProviderModels.mockReturnValue(promise) - render() + renderWithQueryClient() const showModelsBtn = screen.getByTestId('show-models-button') // First call sets loading to true fireEvent.click(showModelsBtn) - expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) + await waitFor(() => { + expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1) + }) // Second call should return early because loading is true fireEvent.click(showModelsBtn) - expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) + expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1) await act(async () => { resolveOuter({ data: [] }) @@ -130,7 +162,7 @@ describe('ProviderAddedCard', () => { ...mockProvider, provider: 'custom/provider', } as unknown as ModelProvider - render() + renderWithQueryClient() expect(screen.getByText('common.modelProvider.configureTip')).toBeInTheDocument() }) @@ -139,9 +171,9 @@ describe('ProviderAddedCard', () => { mockEventEmitter.useSubscription.mockImplementation((handler: (v: unknown) => void) => { capturedHandler = handler as (v: { type: string, payload: string } | null) => void }) - vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [] } as unknown as { data: ModelItem[] }) + mockFetchModelProviderModels.mockResolvedValue({ data: [] }) - render() + renderWithQueryClient() expect(capturedHandler).toBeDefined() act(() => { @@ -152,7 +184,7 @@ describe('ProviderAddedCard', () => { }) await waitFor(() => { - expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) + expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1) }) // Should ignore non-matching events @@ -160,7 +192,7 @@ describe('ProviderAddedCard', () => { capturedHandler({ type: 'OTHER', payload: '' }) capturedHandler(null) }) - expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) + expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1) }) it('should render custom model actions for workspace managers', () => { @@ -168,13 +200,22 @@ describe('ProviderAddedCard', () => { ...mockProvider, configurate_methods: [ConfigurationMethodEnum.customizableModel], } as unknown as ModelProvider - const { rerender } = render() + const queryClient = createTestQueryClient() + const { rerender } = render( + + + , + ) expect(screen.getByTestId('manage-custom-model')).toBeInTheDocument() expect(screen.getByTestId('add-custom-model')).toBeInTheDocument() mockIsCurrentWorkspaceManager = false - rerender() + rerender( + + + , + ) expect(screen.queryByTestId('manage-custom-model')).not.toBeInTheDocument() }) }) diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx index 8361f6068d..647e47a381 100644 --- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx @@ -1,12 +1,13 @@ import type { FC } from 'react' import type { - ModelItem, ModelProvider, } from '../declarations' import type { ModelProviderQuotaGetPaid } from '../utils' import type { PluginDetail } from '@/app/components/plugins/types' +import type { EventEmitterValue } from '@/context/event-emitter' -import { useState } from 'react' +import { useQuery } from '@tanstack/react-query' +import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { AddCustomModel, @@ -16,7 +17,7 @@ import { IS_CE_EDITION } from '@/config' import { useAppContext } from '@/context/app-context' import { useEventEmitterContextContext } from '@/context/event-emitter' import { useProviderContext } from '@/context/provider-context' -import { fetchModelProviderModelList } from '@/service/common' +import { consoleQuery } from '@/service/client' import { cn } from '@/utils/classnames' import { ConfigurationMethodEnum } from '../declarations' import ModelBadge from '../model-badge' @@ -30,6 +31,21 @@ import ModelList from './model-list' import ProviderCardActions from './provider-card-actions' export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST' + +const isModelProviderCustomModelListUpdateEvent = ( + value: EventEmitterValue, + providerName: string, +): value is { + type: typeof UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST + payload: string +} => { + return typeof value === 'object' + && value !== null + && value.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST + && typeof value.payload === 'string' + && value.payload === providerName +} + type ProviderAddedCardProps = { notConfigured?: boolean provider: ModelProvider @@ -43,52 +59,69 @@ const ProviderAddedCard: FC = ({ const { t } = useTranslation() const { eventEmitter } = useEventEmitterContextContext() const { refreshModelProviders } = useProviderContext() - const [fetched, setFetched] = useState(false) - const [loading, setLoading] = useState(false) const [collapsed, setCollapsed] = useState(true) - const [modelList, setModelList] = useState([]) - const configurationMethods = provider.configurate_methods.filter(method => method !== ConfigurationMethodEnum.fetchFromRemote) + const currentProviderName = provider.provider + const supportsPredefinedModel = provider.configurate_methods.includes(ConfigurationMethodEnum.predefinedModel) + const supportsCustomizableModel = provider.configurate_methods.includes(ConfigurationMethodEnum.customizableModel) const systemConfig = provider.system_configuration - const hasModelList = fetched && !!modelList.length + const { + data: modelList = [], + isFetching: loading, + isSuccess: hasFetchedModelList, + refetch: refetchModelList, + } = useQuery(consoleQuery.modelProviders.models.queryOptions({ + input: { params: { provider: currentProviderName } }, + enabled: !collapsed, + staleTime: Infinity, + refetchOnWindowFocus: false, + refetchOnReconnect: false, + select: response => response.data, + })) + const hasModelList = hasFetchedModelList && !!modelList.length + const showCollapsedSection = collapsed || !hasFetchedModelList const { isCurrentWorkspaceManager } = useAppContext() - const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(provider.provider as ModelProviderQuotaGetPaid) && !IS_CE_EDITION - const showCredential = configurationMethods.includes(ConfigurationMethodEnum.predefinedModel) && isCurrentWorkspaceManager + const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(currentProviderName as ModelProviderQuotaGetPaid) && !IS_CE_EDITION + const showCredential = supportsPredefinedModel && isCurrentWorkspaceManager + const showCustomModelActions = supportsCustomizableModel && isCurrentWorkspaceManager - const getModelList = async (providerName: string) => { + const refreshModelList = useCallback((targetProviderName: string) => { + if (targetProviderName !== currentProviderName || loading) + return + + if (collapsed) + setCollapsed(false) + + refetchModelList().catch(() => {}) + }, [collapsed, currentProviderName, loading, refetchModelList]) + + const handleOpenModelList = useCallback(() => { if (loading) return - try { - setLoading(true) - const modelsData = await fetchModelProviderModelList(`/workspaces/current/model-providers/${providerName}/models`) - setModelList(modelsData.data) - setCollapsed(false) - setFetched(true) - } - finally { - setLoading(false) - } - } - const handleOpenModelList = () => { - if (fetched) { + + if (collapsed) { setCollapsed(false) return } - getModelList(provider.provider) - } + refetchModelList().catch(() => {}) + }, [collapsed, loading, refetchModelList]) - eventEmitter?.useSubscription((v: any) => { - if (v?.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST && v.payload === provider.provider) - getModelList(v.payload) - }) + const handleModelProviderCustomModelListUpdate = useCallback((value: EventEmitterValue) => { + if (!isModelProviderCustomModelListUpdateEvent(value, currentProviderName)) + return + + refreshModelList(currentProviderName) + }, [currentProviderName, refreshModelList]) + + eventEmitter?.useSubscription(handleModelProviderCustomModelListUpdate) return (
@@ -117,7 +150,7 @@ const ProviderAddedCard: FC = ({ )}
{ - collapsed && ( + showCollapsedSection && (
{(showModelProvider || !notConfigured) && ( <> @@ -155,7 +188,7 @@ const ProviderAddedCard: FC = ({
)} { - configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && ( + showCustomModelActions && (
= ({ ) } { - !collapsed && ( + !showCollapsedSection && ( setCollapsed(true)} - onChange={(provider: string) => getModelList(provider)} + onChange={refreshModelList} /> ) } diff --git a/web/contract/console/model-providers.ts b/web/contract/console/model-providers.ts new file mode 100644 index 0000000000..39cbb64914 --- /dev/null +++ b/web/contract/console/model-providers.ts @@ -0,0 +1,17 @@ +import type { ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { type } from '@orpc/contract' +import { base } from '../base' + +export const modelProvidersModelsContract = base + .route({ + path: '/workspaces/current/model-providers/{provider}/models', + method: 'GET', + }) + .input(type<{ + params: { + provider: string + } + }>()) + .output(type<{ + data: ModelItem[] + }>()) diff --git a/web/contract/router.ts b/web/contract/router.ts index 79a95be55a..560284bc3f 100644 --- a/web/contract/router.ts +++ b/web/contract/router.ts @@ -12,6 +12,7 @@ import { exploreInstalledAppsContract, exploreInstalledAppUninstallContract, } from './console/explore' +import { modelProvidersModelsContract } from './console/model-providers' import { systemFeaturesContract } from './console/system' import { triggerOAuthConfigContract, @@ -63,6 +64,9 @@ export const consoleRouterContract = { parameters: trialAppParametersContract, workflows: trialAppWorkflowsContract, }, + modelProviders: { + models: modelProvidersModelsContract, + }, billing: { invoices: invoicesContract, bindPartnerStack: bindPartnerStackContract,