refactor(web): migrate provider-added-card model list to oRPC query-driven state

This commit is contained in:
yyh
2026-03-04 21:55:34 +08:00
parent 784bda9c86
commit 04e0ab7eda
4 changed files with 152 additions and 57 deletions

View File

@ -1,17 +1,30 @@
import type { ModelItem, ModelProvider } from '../declarations' import type { ReactNode } from 'react'
import type { ModelProvider } from '../declarations'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { fetchModelProviderModelList } from '@/service/common'
import { ConfigurationMethodEnum } from '../declarations' import { ConfigurationMethodEnum } from '../declarations'
import ProviderAddedCard from './index' import ProviderAddedCard from './index'
let mockIsCurrentWorkspaceManager = true let mockIsCurrentWorkspaceManager = true
const mockFetchModelProviderModels = vi.fn()
const mockQueryOptions = vi.fn(({ input, ...options }: { input: { params: { provider: string } }, enabled?: boolean }) => ({
queryKey: ['console', 'modelProviders', 'models', input.params.provider],
queryFn: () => mockFetchModelProviderModels(input.params.provider),
...options,
}))
const mockEventEmitter = { const mockEventEmitter = {
useSubscription: vi.fn(), useSubscription: vi.fn(),
emit: vi.fn(), emit: vi.fn(),
} }
vi.mock('@/service/common', () => ({ vi.mock('@/service/client', () => ({
fetchModelProviderModelList: vi.fn(), consoleQuery: {
modelProviders: {
models: {
queryOptions: (options: { input: { params: { provider: string } }, enabled?: boolean }) => mockQueryOptions(options),
},
},
},
})) }))
vi.mock('@/context/app-context', () => ({ vi.mock('@/context/app-context', () => ({
@ -53,6 +66,21 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/model-auth'
ManageCustomModelCredentials: () => <div data-testid="manage-custom-model" />, ManageCustomModelCredentials: () => <div data-testid="manage-custom-model" />,
})) }))
const createTestQueryClient = () => new QueryClient({
defaultOptions: {
queries: { retry: false, gcTime: 0 },
},
})
const renderWithQueryClient = (node: ReactNode) => {
const queryClient = createTestQueryClient()
return render(
<QueryClientProvider client={queryClient}>
{node}
</QueryClientProvider>,
)
}
describe('ProviderAddedCard', () => { describe('ProviderAddedCard', () => {
const mockProvider = { const mockProvider = {
provider: 'langgenius/openai/openai', provider: 'langgenius/openai/openai',
@ -67,19 +95,21 @@ describe('ProviderAddedCard', () => {
}) })
it('should render provider added card component', () => { it('should render provider added card component', () => {
render(<ProviderAddedCard provider={mockProvider} />) renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
expect(screen.getByTestId('provider-added-card')).toBeInTheDocument() expect(screen.getByTestId('provider-added-card')).toBeInTheDocument()
expect(screen.getByTestId('provider-icon')).toBeInTheDocument() expect(screen.getByTestId('provider-icon')).toBeInTheDocument()
}) })
it('should open, refresh and collapse model list', async () => { it('should open, refresh and collapse model list', async () => {
vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [{ model: 'gpt-4' }] } as unknown as { data: ModelItem[] }) mockFetchModelProviderModels.mockResolvedValue({ data: [{ model: 'gpt-4' }] })
render(<ProviderAddedCard provider={mockProvider} />) renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
const showModelsBtn = screen.getByTestId('show-models-button') const showModelsBtn = screen.getByTestId('show-models-button')
fireEvent.click(showModelsBtn) fireEvent.click(showModelsBtn)
expect(fetchModelProviderModelList).toHaveBeenCalledWith(`/workspaces/current/model-providers/${mockProvider.provider}/models`) await waitFor(() => {
expect(mockFetchModelProviderModels).toHaveBeenCalledWith(mockProvider.provider)
})
expect(await screen.findByTestId('model-list')).toBeInTheDocument() expect(await screen.findByTestId('model-list')).toBeInTheDocument()
// Test line 71-72: Opening when already fetched // Test line 71-72: Opening when already fetched
@ -90,13 +120,13 @@ describe('ProviderAddedCard', () => {
// Explicitly re-find and click to re-open // Explicitly re-find and click to re-open
fireEvent.click(screen.getByTestId('show-models-button')) fireEvent.click(screen.getByTestId('show-models-button'))
expect(await screen.findByTestId('model-list')).toBeInTheDocument() expect(await screen.findByTestId('model-list')).toBeInTheDocument()
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) // Should not fetch again expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1) // Should not fetch again
// Refresh list from ModelList // Refresh list from ModelList
const refreshBtn = screen.getByRole('button', { name: 'refresh list' }) const refreshBtn = screen.getByRole('button', { name: 'refresh list' })
fireEvent.click(refreshBtn) fireEvent.click(refreshBtn)
await waitFor(() => { await waitFor(() => {
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(2) expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(2)
}) })
}) })
@ -105,18 +135,20 @@ describe('ProviderAddedCard', () => {
const promise = new Promise((resolve) => { const promise = new Promise((resolve) => {
resolveOuter = resolve resolveOuter = resolve
}) })
vi.mocked(fetchModelProviderModelList).mockReturnValue(promise as unknown as ReturnType<typeof fetchModelProviderModelList>) mockFetchModelProviderModels.mockReturnValue(promise)
render(<ProviderAddedCard provider={mockProvider} />) renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
const showModelsBtn = screen.getByTestId('show-models-button') const showModelsBtn = screen.getByTestId('show-models-button')
// First call sets loading to true // First call sets loading to true
fireEvent.click(showModelsBtn) fireEvent.click(showModelsBtn)
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) await waitFor(() => {
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
})
// Second call should return early because loading is true // Second call should return early because loading is true
fireEvent.click(showModelsBtn) fireEvent.click(showModelsBtn)
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
await act(async () => { await act(async () => {
resolveOuter({ data: [] }) resolveOuter({ data: [] })
@ -130,7 +162,7 @@ describe('ProviderAddedCard', () => {
...mockProvider, ...mockProvider,
provider: 'custom/provider', provider: 'custom/provider',
} as unknown as ModelProvider } as unknown as ModelProvider
render(<ProviderAddedCard provider={providerWithoutQuota} notConfigured />) renderWithQueryClient(<ProviderAddedCard provider={providerWithoutQuota} notConfigured />)
expect(screen.getByText('common.modelProvider.configureTip')).toBeInTheDocument() expect(screen.getByText('common.modelProvider.configureTip')).toBeInTheDocument()
}) })
@ -139,9 +171,9 @@ describe('ProviderAddedCard', () => {
mockEventEmitter.useSubscription.mockImplementation((handler: (v: unknown) => void) => { mockEventEmitter.useSubscription.mockImplementation((handler: (v: unknown) => void) => {
capturedHandler = handler as (v: { type: string, payload: string } | null) => void capturedHandler = handler as (v: { type: string, payload: string } | null) => void
}) })
vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [] } as unknown as { data: ModelItem[] }) mockFetchModelProviderModels.mockResolvedValue({ data: [] })
render(<ProviderAddedCard provider={mockProvider} />) renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
expect(capturedHandler).toBeDefined() expect(capturedHandler).toBeDefined()
act(() => { act(() => {
@ -152,7 +184,7 @@ describe('ProviderAddedCard', () => {
}) })
await waitFor(() => { await waitFor(() => {
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
}) })
// Should ignore non-matching events // Should ignore non-matching events
@ -160,7 +192,7 @@ describe('ProviderAddedCard', () => {
capturedHandler({ type: 'OTHER', payload: '' }) capturedHandler({ type: 'OTHER', payload: '' })
capturedHandler(null) capturedHandler(null)
}) })
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
}) })
it('should render custom model actions for workspace managers', () => { it('should render custom model actions for workspace managers', () => {
@ -168,13 +200,22 @@ describe('ProviderAddedCard', () => {
...mockProvider, ...mockProvider,
configurate_methods: [ConfigurationMethodEnum.customizableModel], configurate_methods: [ConfigurationMethodEnum.customizableModel],
} as unknown as ModelProvider } as unknown as ModelProvider
const { rerender } = render(<ProviderAddedCard provider={customConfigProvider} />) const queryClient = createTestQueryClient()
const { rerender } = render(
<QueryClientProvider client={queryClient}>
<ProviderAddedCard provider={customConfigProvider} />
</QueryClientProvider>,
)
expect(screen.getByTestId('manage-custom-model')).toBeInTheDocument() expect(screen.getByTestId('manage-custom-model')).toBeInTheDocument()
expect(screen.getByTestId('add-custom-model')).toBeInTheDocument() expect(screen.getByTestId('add-custom-model')).toBeInTheDocument()
mockIsCurrentWorkspaceManager = false mockIsCurrentWorkspaceManager = false
rerender(<ProviderAddedCard provider={customConfigProvider} />) rerender(
<QueryClientProvider client={queryClient}>
<ProviderAddedCard provider={customConfigProvider} />
</QueryClientProvider>,
)
expect(screen.queryByTestId('manage-custom-model')).not.toBeInTheDocument() expect(screen.queryByTestId('manage-custom-model')).not.toBeInTheDocument()
}) })
}) })

View File

@ -1,12 +1,13 @@
import type { FC } from 'react' import type { FC } from 'react'
import type { import type {
ModelItem,
ModelProvider, ModelProvider,
} from '../declarations' } from '../declarations'
import type { ModelProviderQuotaGetPaid } from '../utils' import type { ModelProviderQuotaGetPaid } from '../utils'
import type { PluginDetail } from '@/app/components/plugins/types' import type { PluginDetail } from '@/app/components/plugins/types'
import type { EventEmitterValue } from '@/context/event-emitter'
import { useState } from 'react' import { useQuery } from '@tanstack/react-query'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { import {
AddCustomModel, AddCustomModel,
@ -16,7 +17,7 @@ import { IS_CE_EDITION } from '@/config'
import { useAppContext } from '@/context/app-context' import { useAppContext } from '@/context/app-context'
import { useEventEmitterContextContext } from '@/context/event-emitter' import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useProviderContext } from '@/context/provider-context' import { useProviderContext } from '@/context/provider-context'
import { fetchModelProviderModelList } from '@/service/common' import { consoleQuery } from '@/service/client'
import { cn } from '@/utils/classnames' import { cn } from '@/utils/classnames'
import { ConfigurationMethodEnum } from '../declarations' import { ConfigurationMethodEnum } from '../declarations'
import ModelBadge from '../model-badge' import ModelBadge from '../model-badge'
@ -30,6 +31,21 @@ import ModelList from './model-list'
import ProviderCardActions from './provider-card-actions' import ProviderCardActions from './provider-card-actions'
export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST' export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST'
const isModelProviderCustomModelListUpdateEvent = (
value: EventEmitterValue,
providerName: string,
): value is {
type: typeof UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST
payload: string
} => {
return typeof value === 'object'
&& value !== null
&& value.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST
&& typeof value.payload === 'string'
&& value.payload === providerName
}
type ProviderAddedCardProps = { type ProviderAddedCardProps = {
notConfigured?: boolean notConfigured?: boolean
provider: ModelProvider provider: ModelProvider
@ -43,52 +59,69 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
const { t } = useTranslation() const { t } = useTranslation()
const { eventEmitter } = useEventEmitterContextContext() const { eventEmitter } = useEventEmitterContextContext()
const { refreshModelProviders } = useProviderContext() const { refreshModelProviders } = useProviderContext()
const [fetched, setFetched] = useState(false)
const [loading, setLoading] = useState(false)
const [collapsed, setCollapsed] = useState(true) const [collapsed, setCollapsed] = useState(true)
const [modelList, setModelList] = useState<ModelItem[]>([]) const currentProviderName = provider.provider
const configurationMethods = provider.configurate_methods.filter(method => method !== ConfigurationMethodEnum.fetchFromRemote) const supportsPredefinedModel = provider.configurate_methods.includes(ConfigurationMethodEnum.predefinedModel)
const supportsCustomizableModel = provider.configurate_methods.includes(ConfigurationMethodEnum.customizableModel)
const systemConfig = provider.system_configuration const systemConfig = provider.system_configuration
const hasModelList = fetched && !!modelList.length const {
data: modelList = [],
isFetching: loading,
isSuccess: hasFetchedModelList,
refetch: refetchModelList,
} = useQuery(consoleQuery.modelProviders.models.queryOptions({
input: { params: { provider: currentProviderName } },
enabled: !collapsed,
staleTime: Infinity,
refetchOnWindowFocus: false,
refetchOnReconnect: false,
select: response => response.data,
}))
const hasModelList = hasFetchedModelList && !!modelList.length
const showCollapsedSection = collapsed || !hasFetchedModelList
const { isCurrentWorkspaceManager } = useAppContext() const { isCurrentWorkspaceManager } = useAppContext()
const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(provider.provider as ModelProviderQuotaGetPaid) && !IS_CE_EDITION const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(currentProviderName as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
const showCredential = configurationMethods.includes(ConfigurationMethodEnum.predefinedModel) && isCurrentWorkspaceManager const showCredential = supportsPredefinedModel && isCurrentWorkspaceManager
const showCustomModelActions = supportsCustomizableModel && isCurrentWorkspaceManager
const getModelList = async (providerName: string) => { const refreshModelList = useCallback((targetProviderName: string) => {
if (targetProviderName !== currentProviderName || loading)
return
if (collapsed)
setCollapsed(false)
refetchModelList().catch(() => {})
}, [collapsed, currentProviderName, loading, refetchModelList])
const handleOpenModelList = useCallback(() => {
if (loading) if (loading)
return return
try {
setLoading(true) if (collapsed) {
const modelsData = await fetchModelProviderModelList(`/workspaces/current/model-providers/${providerName}/models`)
setModelList(modelsData.data)
setCollapsed(false)
setFetched(true)
}
finally {
setLoading(false)
}
}
const handleOpenModelList = () => {
if (fetched) {
setCollapsed(false) setCollapsed(false)
return return
} }
getModelList(provider.provider) refetchModelList().catch(() => {})
} }, [collapsed, loading, refetchModelList])
eventEmitter?.useSubscription((v: any) => { const handleModelProviderCustomModelListUpdate = useCallback((value: EventEmitterValue) => {
if (v?.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST && v.payload === provider.provider) if (!isModelProviderCustomModelListUpdateEvent(value, currentProviderName))
getModelList(v.payload) return
})
refreshModelList(currentProviderName)
}, [currentProviderName, refreshModelList])
eventEmitter?.useSubscription(handleModelProviderCustomModelListUpdate)
return ( return (
<div <div
data-testid="provider-added-card" data-testid="provider-added-card"
className={cn( className={cn(
'mb-2 rounded-xl border-[0.5px] border-divider-regular bg-third-party-model-bg-default shadow-xs', 'mb-2 rounded-xl border-[0.5px] border-divider-regular bg-third-party-model-bg-default shadow-xs',
provider.provider === 'langgenius/openai/openai' && 'bg-third-party-model-bg-openai', currentProviderName === 'langgenius/openai/openai' && 'bg-third-party-model-bg-openai',
provider.provider === 'langgenius/anthropic/anthropic' && 'bg-third-party-model-bg-anthropic', currentProviderName === 'langgenius/anthropic/anthropic' && 'bg-third-party-model-bg-anthropic',
)} )}
> >
<div className="flex rounded-t-xl py-2 pl-3 pr-2"> <div className="flex rounded-t-xl py-2 pl-3 pr-2">
@ -117,7 +150,7 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
)} )}
</div> </div>
{ {
collapsed && ( showCollapsedSection && (
<div className="group flex items-center justify-between border-t border-t-divider-subtle py-1.5 pl-2 pr-[11px] text-text-tertiary system-xs-medium"> <div className="group flex items-center justify-between border-t border-t-divider-subtle py-1.5 pl-2 pr-[11px] text-text-tertiary system-xs-medium">
{(showModelProvider || !notConfigured) && ( {(showModelProvider || !notConfigured) && (
<> <>
@ -155,7 +188,7 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
</div> </div>
)} )}
{ {
configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && ( showCustomModelActions && (
<div className="flex grow justify-end"> <div className="flex grow justify-end">
<ManageCustomModelCredentials <ManageCustomModelCredentials
provider={provider} provider={provider}
@ -173,12 +206,12 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
) )
} }
{ {
!collapsed && ( !showCollapsedSection && (
<ModelList <ModelList
provider={provider} provider={provider}
models={modelList} models={modelList}
onCollapse={() => setCollapsed(true)} onCollapse={() => setCollapsed(true)}
onChange={(provider: string) => getModelList(provider)} onChange={refreshModelList}
/> />
) )
} }

View File

@ -0,0 +1,17 @@
import type { ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { type } from '@orpc/contract'
import { base } from '../base'
export const modelProvidersModelsContract = base
.route({
path: '/workspaces/current/model-providers/{provider}/models',
method: 'GET',
})
.input(type<{
params: {
provider: string
}
}>())
.output(type<{
data: ModelItem[]
}>())

View File

@ -12,6 +12,7 @@ import {
exploreInstalledAppsContract, exploreInstalledAppsContract,
exploreInstalledAppUninstallContract, exploreInstalledAppUninstallContract,
} from './console/explore' } from './console/explore'
import { modelProvidersModelsContract } from './console/model-providers'
import { systemFeaturesContract } from './console/system' import { systemFeaturesContract } from './console/system'
import { import {
triggerOAuthConfigContract, triggerOAuthConfigContract,
@ -63,6 +64,9 @@ export const consoleRouterContract = {
parameters: trialAppParametersContract, parameters: trialAppParametersContract,
workflows: trialAppWorkflowsContract, workflows: trialAppWorkflowsContract,
}, },
modelProviders: {
models: modelProvidersModelsContract,
},
billing: { billing: {
invoices: invoicesContract, invoices: invoicesContract,
bindPartnerStack: bindPartnerStackContract, bindPartnerStack: bindPartnerStackContract,