diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx
index 51c0ebce39..28b95c891c 100644
--- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx
+++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.spec.tsx
@@ -1,17 +1,30 @@
-import type { ModelItem, ModelProvider } from '../declarations'
+import type { ReactNode } from 'react'
+import type { ModelProvider } from '../declarations'
+import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
-import { fetchModelProviderModelList } from '@/service/common'
import { ConfigurationMethodEnum } from '../declarations'
import ProviderAddedCard from './index'
let mockIsCurrentWorkspaceManager = true
+const mockFetchModelProviderModels = vi.fn()
+const mockQueryOptions = vi.fn(({ input, ...options }: { input: { params: { provider: string } }, enabled?: boolean }) => ({
+ queryKey: ['console', 'modelProviders', 'models', input.params.provider],
+ queryFn: () => mockFetchModelProviderModels(input.params.provider),
+ ...options,
+}))
const mockEventEmitter = {
useSubscription: vi.fn(),
emit: vi.fn(),
}
-vi.mock('@/service/common', () => ({
- fetchModelProviderModelList: vi.fn(),
+vi.mock('@/service/client', () => ({
+ consoleQuery: {
+ modelProviders: {
+ models: {
+ queryOptions: (options: { input: { params: { provider: string } }, enabled?: boolean }) => mockQueryOptions(options),
+ },
+ },
+ },
}))
vi.mock('@/context/app-context', () => ({
@@ -53,6 +66,21 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/model-auth'
ManageCustomModelCredentials: () =>
,
}))
+const createTestQueryClient = () => new QueryClient({
+ defaultOptions: {
+ queries: { retry: false, gcTime: 0 },
+ },
+})
+
+const renderWithQueryClient = (node: ReactNode) => {
+ const queryClient = createTestQueryClient()
+ return render(
+
+ {node}
+ ,
+ )
+}
+
describe('ProviderAddedCard', () => {
const mockProvider = {
provider: 'langgenius/openai/openai',
@@ -67,19 +95,21 @@ describe('ProviderAddedCard', () => {
})
it('should render provider added card component', () => {
- render()
+ renderWithQueryClient()
expect(screen.getByTestId('provider-added-card')).toBeInTheDocument()
expect(screen.getByTestId('provider-icon')).toBeInTheDocument()
})
it('should open, refresh and collapse model list', async () => {
- vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [{ model: 'gpt-4' }] } as unknown as { data: ModelItem[] })
- render()
+ mockFetchModelProviderModels.mockResolvedValue({ data: [{ model: 'gpt-4' }] })
+ renderWithQueryClient()
const showModelsBtn = screen.getByTestId('show-models-button')
fireEvent.click(showModelsBtn)
- expect(fetchModelProviderModelList).toHaveBeenCalledWith(`/workspaces/current/model-providers/${mockProvider.provider}/models`)
+ await waitFor(() => {
+ expect(mockFetchModelProviderModels).toHaveBeenCalledWith(mockProvider.provider)
+ })
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
// Test line 71-72: Opening when already fetched
@@ -90,13 +120,13 @@ describe('ProviderAddedCard', () => {
// Explicitly re-find and click to re-open
fireEvent.click(screen.getByTestId('show-models-button'))
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
- expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) // Should not fetch again
+ expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1) // Should not fetch again
// Refresh list from ModelList
const refreshBtn = screen.getByRole('button', { name: 'refresh list' })
fireEvent.click(refreshBtn)
await waitFor(() => {
- expect(fetchModelProviderModelList).toHaveBeenCalledTimes(2)
+ expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(2)
})
})
@@ -105,18 +135,20 @@ describe('ProviderAddedCard', () => {
const promise = new Promise((resolve) => {
resolveOuter = resolve
})
- vi.mocked(fetchModelProviderModelList).mockReturnValue(promise as unknown as ReturnType)
+ mockFetchModelProviderModels.mockReturnValue(promise)
- render()
+ renderWithQueryClient()
const showModelsBtn = screen.getByTestId('show-models-button')
// First call sets loading to true
fireEvent.click(showModelsBtn)
- expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
+ await waitFor(() => {
+ expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
+ })
// Second call should return early because loading is true
fireEvent.click(showModelsBtn)
- expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
+ expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
await act(async () => {
resolveOuter({ data: [] })
@@ -130,7 +162,7 @@ describe('ProviderAddedCard', () => {
...mockProvider,
provider: 'custom/provider',
} as unknown as ModelProvider
- render()
+ renderWithQueryClient()
expect(screen.getByText('common.modelProvider.configureTip')).toBeInTheDocument()
})
@@ -139,9 +171,9 @@ describe('ProviderAddedCard', () => {
mockEventEmitter.useSubscription.mockImplementation((handler: (v: unknown) => void) => {
capturedHandler = handler as (v: { type: string, payload: string } | null) => void
})
- vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [] } as unknown as { data: ModelItem[] })
+ mockFetchModelProviderModels.mockResolvedValue({ data: [] })
- render()
+ renderWithQueryClient()
expect(capturedHandler).toBeDefined()
act(() => {
@@ -152,7 +184,7 @@ describe('ProviderAddedCard', () => {
})
await waitFor(() => {
- expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
+ expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
})
// Should ignore non-matching events
@@ -160,7 +192,7 @@ describe('ProviderAddedCard', () => {
capturedHandler({ type: 'OTHER', payload: '' })
capturedHandler(null)
})
- expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
+ expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
})
it('should render custom model actions for workspace managers', () => {
@@ -168,13 +200,22 @@ describe('ProviderAddedCard', () => {
...mockProvider,
configurate_methods: [ConfigurationMethodEnum.customizableModel],
} as unknown as ModelProvider
- const { rerender } = render()
+ const queryClient = createTestQueryClient()
+ const { rerender } = render(
+
+
+ ,
+ )
expect(screen.getByTestId('manage-custom-model')).toBeInTheDocument()
expect(screen.getByTestId('add-custom-model')).toBeInTheDocument()
mockIsCurrentWorkspaceManager = false
- rerender()
+ rerender(
+
+
+ ,
+ )
expect(screen.queryByTestId('manage-custom-model')).not.toBeInTheDocument()
})
})
diff --git a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx
index 8361f6068d..647e47a381 100644
--- a/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx
+++ b/web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx
@@ -1,12 +1,13 @@
import type { FC } from 'react'
import type {
- ModelItem,
ModelProvider,
} from '../declarations'
import type { ModelProviderQuotaGetPaid } from '../utils'
import type { PluginDetail } from '@/app/components/plugins/types'
+import type { EventEmitterValue } from '@/context/event-emitter'
-import { useState } from 'react'
+import { useQuery } from '@tanstack/react-query'
+import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import {
AddCustomModel,
@@ -16,7 +17,7 @@ import { IS_CE_EDITION } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useProviderContext } from '@/context/provider-context'
-import { fetchModelProviderModelList } from '@/service/common'
+import { consoleQuery } from '@/service/client'
import { cn } from '@/utils/classnames'
import { ConfigurationMethodEnum } from '../declarations'
import ModelBadge from '../model-badge'
@@ -30,6 +31,21 @@ import ModelList from './model-list'
import ProviderCardActions from './provider-card-actions'
export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST'
+
+const isModelProviderCustomModelListUpdateEvent = (
+ value: EventEmitterValue,
+ providerName: string,
+): value is {
+ type: typeof UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST
+ payload: string
+} => {
+ return typeof value === 'object'
+ && value !== null
+ && value.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST
+ && typeof value.payload === 'string'
+ && value.payload === providerName
+}
+
type ProviderAddedCardProps = {
notConfigured?: boolean
provider: ModelProvider
@@ -43,52 +59,69 @@ const ProviderAddedCard: FC = ({
const { t } = useTranslation()
const { eventEmitter } = useEventEmitterContextContext()
const { refreshModelProviders } = useProviderContext()
- const [fetched, setFetched] = useState(false)
- const [loading, setLoading] = useState(false)
const [collapsed, setCollapsed] = useState(true)
- const [modelList, setModelList] = useState([])
- const configurationMethods = provider.configurate_methods.filter(method => method !== ConfigurationMethodEnum.fetchFromRemote)
+ const currentProviderName = provider.provider
+ const supportsPredefinedModel = provider.configurate_methods.includes(ConfigurationMethodEnum.predefinedModel)
+ const supportsCustomizableModel = provider.configurate_methods.includes(ConfigurationMethodEnum.customizableModel)
const systemConfig = provider.system_configuration
- const hasModelList = fetched && !!modelList.length
+ const {
+ data: modelList = [],
+ isFetching: loading,
+ isSuccess: hasFetchedModelList,
+ refetch: refetchModelList,
+ } = useQuery(consoleQuery.modelProviders.models.queryOptions({
+ input: { params: { provider: currentProviderName } },
+ enabled: !collapsed,
+ staleTime: Infinity,
+ refetchOnWindowFocus: false,
+ refetchOnReconnect: false,
+ select: response => response.data,
+ }))
+ const hasModelList = hasFetchedModelList && !!modelList.length
+ const showCollapsedSection = collapsed || !hasFetchedModelList
const { isCurrentWorkspaceManager } = useAppContext()
- const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(provider.provider as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
- const showCredential = configurationMethods.includes(ConfigurationMethodEnum.predefinedModel) && isCurrentWorkspaceManager
+ const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(currentProviderName as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
+ const showCredential = supportsPredefinedModel && isCurrentWorkspaceManager
+ const showCustomModelActions = supportsCustomizableModel && isCurrentWorkspaceManager
- const getModelList = async (providerName: string) => {
+ const refreshModelList = useCallback((targetProviderName: string) => {
+ if (targetProviderName !== currentProviderName || loading)
+ return
+
+ if (collapsed)
+ setCollapsed(false)
+
+ refetchModelList().catch(() => {})
+ }, [collapsed, currentProviderName, loading, refetchModelList])
+
+ const handleOpenModelList = useCallback(() => {
if (loading)
return
- try {
- setLoading(true)
- const modelsData = await fetchModelProviderModelList(`/workspaces/current/model-providers/${providerName}/models`)
- setModelList(modelsData.data)
- setCollapsed(false)
- setFetched(true)
- }
- finally {
- setLoading(false)
- }
- }
- const handleOpenModelList = () => {
- if (fetched) {
+
+ if (collapsed) {
setCollapsed(false)
return
}
- getModelList(provider.provider)
- }
+ refetchModelList().catch(() => {})
+ }, [collapsed, loading, refetchModelList])
- eventEmitter?.useSubscription((v: any) => {
- if (v?.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST && v.payload === provider.provider)
- getModelList(v.payload)
- })
+ const handleModelProviderCustomModelListUpdate = useCallback((value: EventEmitterValue) => {
+ if (!isModelProviderCustomModelListUpdateEvent(value, currentProviderName))
+ return
+
+ refreshModelList(currentProviderName)
+ }, [currentProviderName, refreshModelList])
+
+ eventEmitter?.useSubscription(handleModelProviderCustomModelListUpdate)
return (
@@ -117,7 +150,7 @@ const ProviderAddedCard: FC
= ({
)}
{
- collapsed && (
+ showCollapsedSection && (
{(showModelProvider || !notConfigured) && (
<>
@@ -155,7 +188,7 @@ const ProviderAddedCard: FC
= ({
)}
{
- configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && (
+ showCustomModelActions && (
= ({
)
}
{
- !collapsed && (
+ !showCollapsedSection && (
setCollapsed(true)}
- onChange={(provider: string) => getModelList(provider)}
+ onChange={refreshModelList}
/>
)
}
diff --git a/web/contract/console/model-providers.ts b/web/contract/console/model-providers.ts
new file mode 100644
index 0000000000..39cbb64914
--- /dev/null
+++ b/web/contract/console/model-providers.ts
@@ -0,0 +1,17 @@
+import type { ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
+import { type } from '@orpc/contract'
+import { base } from '../base'
+
+export const modelProvidersModelsContract = base
+ .route({
+ path: '/workspaces/current/model-providers/{provider}/models',
+ method: 'GET',
+ })
+ .input(type<{
+ params: {
+ provider: string
+ }
+ }>())
+ .output(type<{
+ data: ModelItem[]
+ }>())
diff --git a/web/contract/router.ts b/web/contract/router.ts
index 79a95be55a..560284bc3f 100644
--- a/web/contract/router.ts
+++ b/web/contract/router.ts
@@ -12,6 +12,7 @@ import {
exploreInstalledAppsContract,
exploreInstalledAppUninstallContract,
} from './console/explore'
+import { modelProvidersModelsContract } from './console/model-providers'
import { systemFeaturesContract } from './console/system'
import {
triggerOAuthConfigContract,
@@ -63,6 +64,9 @@ export const consoleRouterContract = {
parameters: trialAppParametersContract,
workflows: trialAppWorkflowsContract,
},
+ modelProviders: {
+ models: modelProvidersModelsContract,
+ },
billing: {
invoices: invoicesContract,
bindPartnerStack: bindPartnerStackContract,