refactor(web): migrate provider-added-card model list to oRPC query-driven state

This commit is contained in:
yyh
2026-03-04 21:55:34 +08:00
parent 784bda9c86
commit 04e0ab7eda
4 changed files with 152 additions and 57 deletions

View File

@ -1,17 +1,30 @@
import type { ModelItem, ModelProvider } from '../declarations'
import type { ReactNode } from 'react'
import type { ModelProvider } from '../declarations'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { fetchModelProviderModelList } from '@/service/common'
import { ConfigurationMethodEnum } from '../declarations'
import ProviderAddedCard from './index'
let mockIsCurrentWorkspaceManager = true
const mockFetchModelProviderModels = vi.fn()
const mockQueryOptions = vi.fn(({ input, ...options }: { input: { params: { provider: string } }, enabled?: boolean }) => ({
queryKey: ['console', 'modelProviders', 'models', input.params.provider],
queryFn: () => mockFetchModelProviderModels(input.params.provider),
...options,
}))
const mockEventEmitter = {
useSubscription: vi.fn(),
emit: vi.fn(),
}
vi.mock('@/service/common', () => ({
fetchModelProviderModelList: vi.fn(),
vi.mock('@/service/client', () => ({
consoleQuery: {
modelProviders: {
models: {
queryOptions: (options: { input: { params: { provider: string } }, enabled?: boolean }) => mockQueryOptions(options),
},
},
},
}))
vi.mock('@/context/app-context', () => ({
@ -53,6 +66,21 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/model-auth'
ManageCustomModelCredentials: () => <div data-testid="manage-custom-model" />,
}))
const createTestQueryClient = () => new QueryClient({
defaultOptions: {
queries: { retry: false, gcTime: 0 },
},
})
const renderWithQueryClient = (node: ReactNode) => {
const queryClient = createTestQueryClient()
return render(
<QueryClientProvider client={queryClient}>
{node}
</QueryClientProvider>,
)
}
describe('ProviderAddedCard', () => {
const mockProvider = {
provider: 'langgenius/openai/openai',
@ -67,19 +95,21 @@ describe('ProviderAddedCard', () => {
})
it('should render provider added card component', () => {
render(<ProviderAddedCard provider={mockProvider} />)
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
expect(screen.getByTestId('provider-added-card')).toBeInTheDocument()
expect(screen.getByTestId('provider-icon')).toBeInTheDocument()
})
it('should open, refresh and collapse model list', async () => {
vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [{ model: 'gpt-4' }] } as unknown as { data: ModelItem[] })
render(<ProviderAddedCard provider={mockProvider} />)
mockFetchModelProviderModels.mockResolvedValue({ data: [{ model: 'gpt-4' }] })
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
const showModelsBtn = screen.getByTestId('show-models-button')
fireEvent.click(showModelsBtn)
expect(fetchModelProviderModelList).toHaveBeenCalledWith(`/workspaces/current/model-providers/${mockProvider.provider}/models`)
await waitFor(() => {
expect(mockFetchModelProviderModels).toHaveBeenCalledWith(mockProvider.provider)
})
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
// Test line 71-72: Opening when already fetched
@ -90,13 +120,13 @@ describe('ProviderAddedCard', () => {
// Explicitly re-find and click to re-open
fireEvent.click(screen.getByTestId('show-models-button'))
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) // Should not fetch again
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1) // Should not fetch again
// Refresh list from ModelList
const refreshBtn = screen.getByRole('button', { name: 'refresh list' })
fireEvent.click(refreshBtn)
await waitFor(() => {
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(2)
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(2)
})
})
@ -105,18 +135,20 @@ describe('ProviderAddedCard', () => {
const promise = new Promise((resolve) => {
resolveOuter = resolve
})
vi.mocked(fetchModelProviderModelList).mockReturnValue(promise as unknown as ReturnType<typeof fetchModelProviderModelList>)
mockFetchModelProviderModels.mockReturnValue(promise)
render(<ProviderAddedCard provider={mockProvider} />)
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
const showModelsBtn = screen.getByTestId('show-models-button')
// First call sets loading to true
fireEvent.click(showModelsBtn)
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
await waitFor(() => {
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
})
// Second call should return early because loading is true
fireEvent.click(showModelsBtn)
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
await act(async () => {
resolveOuter({ data: [] })
@ -130,7 +162,7 @@ describe('ProviderAddedCard', () => {
...mockProvider,
provider: 'custom/provider',
} as unknown as ModelProvider
render(<ProviderAddedCard provider={providerWithoutQuota} notConfigured />)
renderWithQueryClient(<ProviderAddedCard provider={providerWithoutQuota} notConfigured />)
expect(screen.getByText('common.modelProvider.configureTip')).toBeInTheDocument()
})
@ -139,9 +171,9 @@ describe('ProviderAddedCard', () => {
mockEventEmitter.useSubscription.mockImplementation((handler: (v: unknown) => void) => {
capturedHandler = handler as (v: { type: string, payload: string } | null) => void
})
vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [] } as unknown as { data: ModelItem[] })
mockFetchModelProviderModels.mockResolvedValue({ data: [] })
render(<ProviderAddedCard provider={mockProvider} />)
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
expect(capturedHandler).toBeDefined()
act(() => {
@ -152,7 +184,7 @@ describe('ProviderAddedCard', () => {
})
await waitFor(() => {
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
})
// Should ignore non-matching events
@ -160,7 +192,7 @@ describe('ProviderAddedCard', () => {
capturedHandler({ type: 'OTHER', payload: '' })
capturedHandler(null)
})
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
})
it('should render custom model actions for workspace managers', () => {
@ -168,13 +200,22 @@ describe('ProviderAddedCard', () => {
...mockProvider,
configurate_methods: [ConfigurationMethodEnum.customizableModel],
} as unknown as ModelProvider
const { rerender } = render(<ProviderAddedCard provider={customConfigProvider} />)
const queryClient = createTestQueryClient()
const { rerender } = render(
<QueryClientProvider client={queryClient}>
<ProviderAddedCard provider={customConfigProvider} />
</QueryClientProvider>,
)
expect(screen.getByTestId('manage-custom-model')).toBeInTheDocument()
expect(screen.getByTestId('add-custom-model')).toBeInTheDocument()
mockIsCurrentWorkspaceManager = false
rerender(<ProviderAddedCard provider={customConfigProvider} />)
rerender(
<QueryClientProvider client={queryClient}>
<ProviderAddedCard provider={customConfigProvider} />
</QueryClientProvider>,
)
expect(screen.queryByTestId('manage-custom-model')).not.toBeInTheDocument()
})
})

View File

@ -1,12 +1,13 @@
import type { FC } from 'react'
import type {
ModelItem,
ModelProvider,
} from '../declarations'
import type { ModelProviderQuotaGetPaid } from '../utils'
import type { PluginDetail } from '@/app/components/plugins/types'
import type { EventEmitterValue } from '@/context/event-emitter'
import { useState } from 'react'
import { useQuery } from '@tanstack/react-query'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import {
AddCustomModel,
@ -16,7 +17,7 @@ import { IS_CE_EDITION } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useProviderContext } from '@/context/provider-context'
import { fetchModelProviderModelList } from '@/service/common'
import { consoleQuery } from '@/service/client'
import { cn } from '@/utils/classnames'
import { ConfigurationMethodEnum } from '../declarations'
import ModelBadge from '../model-badge'
@ -30,6 +31,21 @@ import ModelList from './model-list'
import ProviderCardActions from './provider-card-actions'
export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST'
const isModelProviderCustomModelListUpdateEvent = (
value: EventEmitterValue,
providerName: string,
): value is {
type: typeof UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST
payload: string
} => {
return typeof value === 'object'
&& value !== null
&& value.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST
&& typeof value.payload === 'string'
&& value.payload === providerName
}
type ProviderAddedCardProps = {
notConfigured?: boolean
provider: ModelProvider
@ -43,52 +59,69 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
const { t } = useTranslation()
const { eventEmitter } = useEventEmitterContextContext()
const { refreshModelProviders } = useProviderContext()
const [fetched, setFetched] = useState(false)
const [loading, setLoading] = useState(false)
const [collapsed, setCollapsed] = useState(true)
const [modelList, setModelList] = useState<ModelItem[]>([])
const configurationMethods = provider.configurate_methods.filter(method => method !== ConfigurationMethodEnum.fetchFromRemote)
const currentProviderName = provider.provider
const supportsPredefinedModel = provider.configurate_methods.includes(ConfigurationMethodEnum.predefinedModel)
const supportsCustomizableModel = provider.configurate_methods.includes(ConfigurationMethodEnum.customizableModel)
const systemConfig = provider.system_configuration
const hasModelList = fetched && !!modelList.length
const {
data: modelList = [],
isFetching: loading,
isSuccess: hasFetchedModelList,
refetch: refetchModelList,
} = useQuery(consoleQuery.modelProviders.models.queryOptions({
input: { params: { provider: currentProviderName } },
enabled: !collapsed,
staleTime: Infinity,
refetchOnWindowFocus: false,
refetchOnReconnect: false,
select: response => response.data,
}))
const hasModelList = hasFetchedModelList && !!modelList.length
const showCollapsedSection = collapsed || !hasFetchedModelList
const { isCurrentWorkspaceManager } = useAppContext()
const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(provider.provider as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
const showCredential = configurationMethods.includes(ConfigurationMethodEnum.predefinedModel) && isCurrentWorkspaceManager
const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(currentProviderName as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
const showCredential = supportsPredefinedModel && isCurrentWorkspaceManager
const showCustomModelActions = supportsCustomizableModel && isCurrentWorkspaceManager
const getModelList = async (providerName: string) => {
const refreshModelList = useCallback((targetProviderName: string) => {
if (targetProviderName !== currentProviderName || loading)
return
if (collapsed)
setCollapsed(false)
refetchModelList().catch(() => {})
}, [collapsed, currentProviderName, loading, refetchModelList])
const handleOpenModelList = useCallback(() => {
if (loading)
return
try {
setLoading(true)
const modelsData = await fetchModelProviderModelList(`/workspaces/current/model-providers/${providerName}/models`)
setModelList(modelsData.data)
setCollapsed(false)
setFetched(true)
}
finally {
setLoading(false)
}
}
const handleOpenModelList = () => {
if (fetched) {
if (collapsed) {
setCollapsed(false)
return
}
getModelList(provider.provider)
}
refetchModelList().catch(() => {})
}, [collapsed, loading, refetchModelList])
eventEmitter?.useSubscription((v: any) => {
if (v?.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST && v.payload === provider.provider)
getModelList(v.payload)
})
const handleModelProviderCustomModelListUpdate = useCallback((value: EventEmitterValue) => {
if (!isModelProviderCustomModelListUpdateEvent(value, currentProviderName))
return
refreshModelList(currentProviderName)
}, [currentProviderName, refreshModelList])
eventEmitter?.useSubscription(handleModelProviderCustomModelListUpdate)
return (
<div
data-testid="provider-added-card"
className={cn(
'mb-2 rounded-xl border-[0.5px] border-divider-regular bg-third-party-model-bg-default shadow-xs',
provider.provider === 'langgenius/openai/openai' && 'bg-third-party-model-bg-openai',
provider.provider === 'langgenius/anthropic/anthropic' && 'bg-third-party-model-bg-anthropic',
currentProviderName === 'langgenius/openai/openai' && 'bg-third-party-model-bg-openai',
currentProviderName === 'langgenius/anthropic/anthropic' && 'bg-third-party-model-bg-anthropic',
)}
>
<div className="flex rounded-t-xl py-2 pl-3 pr-2">
@ -117,7 +150,7 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
)}
</div>
{
collapsed && (
showCollapsedSection && (
<div className="group flex items-center justify-between border-t border-t-divider-subtle py-1.5 pl-2 pr-[11px] text-text-tertiary system-xs-medium">
{(showModelProvider || !notConfigured) && (
<>
@ -155,7 +188,7 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
</div>
)}
{
configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && (
showCustomModelActions && (
<div className="flex grow justify-end">
<ManageCustomModelCredentials
provider={provider}
@ -173,12 +206,12 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
)
}
{
!collapsed && (
!showCollapsedSection && (
<ModelList
provider={provider}
models={modelList}
onCollapse={() => setCollapsed(true)}
onChange={(provider: string) => getModelList(provider)}
onChange={refreshModelList}
/>
)
}

View File

@ -0,0 +1,17 @@
import type { ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { type } from '@orpc/contract'
import { base } from '../base'
export const modelProvidersModelsContract = base
.route({
path: '/workspaces/current/model-providers/{provider}/models',
method: 'GET',
})
.input(type<{
params: {
provider: string
}
}>())
.output(type<{
data: ModelItem[]
}>())

View File

@ -12,6 +12,7 @@ import {
exploreInstalledAppsContract,
exploreInstalledAppUninstallContract,
} from './console/explore'
import { modelProvidersModelsContract } from './console/model-providers'
import { systemFeaturesContract } from './console/system'
import {
triggerOAuthConfigContract,
@ -63,6 +64,9 @@ export const consoleRouterContract = {
parameters: trialAppParametersContract,
workflows: trialAppWorkflowsContract,
},
modelProviders: {
models: modelProvidersModelsContract,
},
billing: {
invoices: invoicesContract,
bindPartnerStack: bindPartnerStackContract,