mirror of
https://github.com/langgenius/dify.git
synced 2026-03-05 15:47:06 +08:00
refactor(web): migrate provider-added-card model list to oRPC query-driven state
This commit is contained in:
@ -1,17 +1,30 @@
|
||||
import type { ModelItem, ModelProvider } from '../declarations'
|
||||
import type { ReactNode } from 'react'
|
||||
import type { ModelProvider } from '../declarations'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { fetchModelProviderModelList } from '@/service/common'
|
||||
import { ConfigurationMethodEnum } from '../declarations'
|
||||
import ProviderAddedCard from './index'
|
||||
|
||||
let mockIsCurrentWorkspaceManager = true
|
||||
const mockFetchModelProviderModels = vi.fn()
|
||||
const mockQueryOptions = vi.fn(({ input, ...options }: { input: { params: { provider: string } }, enabled?: boolean }) => ({
|
||||
queryKey: ['console', 'modelProviders', 'models', input.params.provider],
|
||||
queryFn: () => mockFetchModelProviderModels(input.params.provider),
|
||||
...options,
|
||||
}))
|
||||
const mockEventEmitter = {
|
||||
useSubscription: vi.fn(),
|
||||
emit: vi.fn(),
|
||||
}
|
||||
|
||||
vi.mock('@/service/common', () => ({
|
||||
fetchModelProviderModelList: vi.fn(),
|
||||
vi.mock('@/service/client', () => ({
|
||||
consoleQuery: {
|
||||
modelProviders: {
|
||||
models: {
|
||||
queryOptions: (options: { input: { params: { provider: string } }, enabled?: boolean }) => mockQueryOptions(options),
|
||||
},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
@ -53,6 +66,21 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/model-auth'
|
||||
ManageCustomModelCredentials: () => <div data-testid="manage-custom-model" />,
|
||||
}))
|
||||
|
||||
const createTestQueryClient = () => new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false, gcTime: 0 },
|
||||
},
|
||||
})
|
||||
|
||||
const renderWithQueryClient = (node: ReactNode) => {
|
||||
const queryClient = createTestQueryClient()
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{node}
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('ProviderAddedCard', () => {
|
||||
const mockProvider = {
|
||||
provider: 'langgenius/openai/openai',
|
||||
@ -67,19 +95,21 @@ describe('ProviderAddedCard', () => {
|
||||
})
|
||||
|
||||
it('should render provider added card component', () => {
|
||||
render(<ProviderAddedCard provider={mockProvider} />)
|
||||
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
|
||||
expect(screen.getByTestId('provider-added-card')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('provider-icon')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should open, refresh and collapse model list', async () => {
|
||||
vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [{ model: 'gpt-4' }] } as unknown as { data: ModelItem[] })
|
||||
render(<ProviderAddedCard provider={mockProvider} />)
|
||||
mockFetchModelProviderModels.mockResolvedValue({ data: [{ model: 'gpt-4' }] })
|
||||
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
|
||||
|
||||
const showModelsBtn = screen.getByTestId('show-models-button')
|
||||
fireEvent.click(showModelsBtn)
|
||||
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledWith(`/workspaces/current/model-providers/${mockProvider.provider}/models`)
|
||||
await waitFor(() => {
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledWith(mockProvider.provider)
|
||||
})
|
||||
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
|
||||
|
||||
// Test line 71-72: Opening when already fetched
|
||||
@ -90,13 +120,13 @@ describe('ProviderAddedCard', () => {
|
||||
// Explicitly re-find and click to re-open
|
||||
fireEvent.click(screen.getByTestId('show-models-button'))
|
||||
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) // Should not fetch again
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1) // Should not fetch again
|
||||
|
||||
// Refresh list from ModelList
|
||||
const refreshBtn = screen.getByRole('button', { name: 'refresh list' })
|
||||
fireEvent.click(refreshBtn)
|
||||
await waitFor(() => {
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(2)
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
@ -105,18 +135,20 @@ describe('ProviderAddedCard', () => {
|
||||
const promise = new Promise((resolve) => {
|
||||
resolveOuter = resolve
|
||||
})
|
||||
vi.mocked(fetchModelProviderModelList).mockReturnValue(promise as unknown as ReturnType<typeof fetchModelProviderModelList>)
|
||||
mockFetchModelProviderModels.mockReturnValue(promise)
|
||||
|
||||
render(<ProviderAddedCard provider={mockProvider} />)
|
||||
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
|
||||
const showModelsBtn = screen.getByTestId('show-models-button')
|
||||
|
||||
// First call sets loading to true
|
||||
fireEvent.click(showModelsBtn)
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
|
||||
await waitFor(() => {
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// Second call should return early because loading is true
|
||||
fireEvent.click(showModelsBtn)
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
|
||||
|
||||
await act(async () => {
|
||||
resolveOuter({ data: [] })
|
||||
@ -130,7 +162,7 @@ describe('ProviderAddedCard', () => {
|
||||
...mockProvider,
|
||||
provider: 'custom/provider',
|
||||
} as unknown as ModelProvider
|
||||
render(<ProviderAddedCard provider={providerWithoutQuota} notConfigured />)
|
||||
renderWithQueryClient(<ProviderAddedCard provider={providerWithoutQuota} notConfigured />)
|
||||
expect(screen.getByText('common.modelProvider.configureTip')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
@ -139,9 +171,9 @@ describe('ProviderAddedCard', () => {
|
||||
mockEventEmitter.useSubscription.mockImplementation((handler: (v: unknown) => void) => {
|
||||
capturedHandler = handler as (v: { type: string, payload: string } | null) => void
|
||||
})
|
||||
vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [] } as unknown as { data: ModelItem[] })
|
||||
mockFetchModelProviderModels.mockResolvedValue({ data: [] })
|
||||
|
||||
render(<ProviderAddedCard provider={mockProvider} />)
|
||||
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
|
||||
|
||||
expect(capturedHandler).toBeDefined()
|
||||
act(() => {
|
||||
@ -152,7 +184,7 @@ describe('ProviderAddedCard', () => {
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// Should ignore non-matching events
|
||||
@ -160,7 +192,7 @@ describe('ProviderAddedCard', () => {
|
||||
capturedHandler({ type: 'OTHER', payload: '' })
|
||||
capturedHandler(null)
|
||||
})
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should render custom model actions for workspace managers', () => {
|
||||
@ -168,13 +200,22 @@ describe('ProviderAddedCard', () => {
|
||||
...mockProvider,
|
||||
configurate_methods: [ConfigurationMethodEnum.customizableModel],
|
||||
} as unknown as ModelProvider
|
||||
const { rerender } = render(<ProviderAddedCard provider={customConfigProvider} />)
|
||||
const queryClient = createTestQueryClient()
|
||||
const { rerender } = render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<ProviderAddedCard provider={customConfigProvider} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('manage-custom-model')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('add-custom-model')).toBeInTheDocument()
|
||||
|
||||
mockIsCurrentWorkspaceManager = false
|
||||
rerender(<ProviderAddedCard provider={customConfigProvider} />)
|
||||
rerender(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<ProviderAddedCard provider={customConfigProvider} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
expect(screen.queryByTestId('manage-custom-model')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
import type { FC } from 'react'
|
||||
import type {
|
||||
ModelItem,
|
||||
ModelProvider,
|
||||
} from '../declarations'
|
||||
import type { ModelProviderQuotaGetPaid } from '../utils'
|
||||
import type { PluginDetail } from '@/app/components/plugins/types'
|
||||
import type { EventEmitterValue } from '@/context/event-emitter'
|
||||
|
||||
import { useState } from 'react'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
AddCustomModel,
|
||||
@ -16,7 +17,7 @@ import { IS_CE_EDITION } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { fetchModelProviderModelList } from '@/service/common'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { ConfigurationMethodEnum } from '../declarations'
|
||||
import ModelBadge from '../model-badge'
|
||||
@ -30,6 +31,21 @@ import ModelList from './model-list'
|
||||
import ProviderCardActions from './provider-card-actions'
|
||||
|
||||
export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST'
|
||||
|
||||
const isModelProviderCustomModelListUpdateEvent = (
|
||||
value: EventEmitterValue,
|
||||
providerName: string,
|
||||
): value is {
|
||||
type: typeof UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST
|
||||
payload: string
|
||||
} => {
|
||||
return typeof value === 'object'
|
||||
&& value !== null
|
||||
&& value.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST
|
||||
&& typeof value.payload === 'string'
|
||||
&& value.payload === providerName
|
||||
}
|
||||
|
||||
type ProviderAddedCardProps = {
|
||||
notConfigured?: boolean
|
||||
provider: ModelProvider
|
||||
@ -43,52 +59,69 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
|
||||
const { t } = useTranslation()
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
const { refreshModelProviders } = useProviderContext()
|
||||
const [fetched, setFetched] = useState(false)
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [collapsed, setCollapsed] = useState(true)
|
||||
const [modelList, setModelList] = useState<ModelItem[]>([])
|
||||
const configurationMethods = provider.configurate_methods.filter(method => method !== ConfigurationMethodEnum.fetchFromRemote)
|
||||
const currentProviderName = provider.provider
|
||||
const supportsPredefinedModel = provider.configurate_methods.includes(ConfigurationMethodEnum.predefinedModel)
|
||||
const supportsCustomizableModel = provider.configurate_methods.includes(ConfigurationMethodEnum.customizableModel)
|
||||
const systemConfig = provider.system_configuration
|
||||
const hasModelList = fetched && !!modelList.length
|
||||
const {
|
||||
data: modelList = [],
|
||||
isFetching: loading,
|
||||
isSuccess: hasFetchedModelList,
|
||||
refetch: refetchModelList,
|
||||
} = useQuery(consoleQuery.modelProviders.models.queryOptions({
|
||||
input: { params: { provider: currentProviderName } },
|
||||
enabled: !collapsed,
|
||||
staleTime: Infinity,
|
||||
refetchOnWindowFocus: false,
|
||||
refetchOnReconnect: false,
|
||||
select: response => response.data,
|
||||
}))
|
||||
const hasModelList = hasFetchedModelList && !!modelList.length
|
||||
const showCollapsedSection = collapsed || !hasFetchedModelList
|
||||
const { isCurrentWorkspaceManager } = useAppContext()
|
||||
const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(provider.provider as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
|
||||
const showCredential = configurationMethods.includes(ConfigurationMethodEnum.predefinedModel) && isCurrentWorkspaceManager
|
||||
const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(currentProviderName as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
|
||||
const showCredential = supportsPredefinedModel && isCurrentWorkspaceManager
|
||||
const showCustomModelActions = supportsCustomizableModel && isCurrentWorkspaceManager
|
||||
|
||||
const getModelList = async (providerName: string) => {
|
||||
const refreshModelList = useCallback((targetProviderName: string) => {
|
||||
if (targetProviderName !== currentProviderName || loading)
|
||||
return
|
||||
|
||||
if (collapsed)
|
||||
setCollapsed(false)
|
||||
|
||||
refetchModelList().catch(() => {})
|
||||
}, [collapsed, currentProviderName, loading, refetchModelList])
|
||||
|
||||
const handleOpenModelList = useCallback(() => {
|
||||
if (loading)
|
||||
return
|
||||
try {
|
||||
setLoading(true)
|
||||
const modelsData = await fetchModelProviderModelList(`/workspaces/current/model-providers/${providerName}/models`)
|
||||
setModelList(modelsData.data)
|
||||
setCollapsed(false)
|
||||
setFetched(true)
|
||||
}
|
||||
finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}
|
||||
const handleOpenModelList = () => {
|
||||
if (fetched) {
|
||||
|
||||
if (collapsed) {
|
||||
setCollapsed(false)
|
||||
return
|
||||
}
|
||||
|
||||
getModelList(provider.provider)
|
||||
}
|
||||
refetchModelList().catch(() => {})
|
||||
}, [collapsed, loading, refetchModelList])
|
||||
|
||||
eventEmitter?.useSubscription((v: any) => {
|
||||
if (v?.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST && v.payload === provider.provider)
|
||||
getModelList(v.payload)
|
||||
})
|
||||
const handleModelProviderCustomModelListUpdate = useCallback((value: EventEmitterValue) => {
|
||||
if (!isModelProviderCustomModelListUpdateEvent(value, currentProviderName))
|
||||
return
|
||||
|
||||
refreshModelList(currentProviderName)
|
||||
}, [currentProviderName, refreshModelList])
|
||||
|
||||
eventEmitter?.useSubscription(handleModelProviderCustomModelListUpdate)
|
||||
|
||||
return (
|
||||
<div
|
||||
data-testid="provider-added-card"
|
||||
className={cn(
|
||||
'mb-2 rounded-xl border-[0.5px] border-divider-regular bg-third-party-model-bg-default shadow-xs',
|
||||
provider.provider === 'langgenius/openai/openai' && 'bg-third-party-model-bg-openai',
|
||||
provider.provider === 'langgenius/anthropic/anthropic' && 'bg-third-party-model-bg-anthropic',
|
||||
currentProviderName === 'langgenius/openai/openai' && 'bg-third-party-model-bg-openai',
|
||||
currentProviderName === 'langgenius/anthropic/anthropic' && 'bg-third-party-model-bg-anthropic',
|
||||
)}
|
||||
>
|
||||
<div className="flex rounded-t-xl py-2 pl-3 pr-2">
|
||||
@ -117,7 +150,7 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
|
||||
)}
|
||||
</div>
|
||||
{
|
||||
collapsed && (
|
||||
showCollapsedSection && (
|
||||
<div className="group flex items-center justify-between border-t border-t-divider-subtle py-1.5 pl-2 pr-[11px] text-text-tertiary system-xs-medium">
|
||||
{(showModelProvider || !notConfigured) && (
|
||||
<>
|
||||
@ -155,7 +188,7 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
|
||||
</div>
|
||||
)}
|
||||
{
|
||||
configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && (
|
||||
showCustomModelActions && (
|
||||
<div className="flex grow justify-end">
|
||||
<ManageCustomModelCredentials
|
||||
provider={provider}
|
||||
@ -173,12 +206,12 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
|
||||
)
|
||||
}
|
||||
{
|
||||
!collapsed && (
|
||||
!showCollapsedSection && (
|
||||
<ModelList
|
||||
provider={provider}
|
||||
models={modelList}
|
||||
onCollapse={() => setCollapsed(true)}
|
||||
onChange={(provider: string) => getModelList(provider)}
|
||||
onChange={refreshModelList}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
17
web/contract/console/model-providers.ts
Normal file
17
web/contract/console/model-providers.ts
Normal file
@ -0,0 +1,17 @@
|
||||
import type { ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { type } from '@orpc/contract'
|
||||
import { base } from '../base'
|
||||
|
||||
export const modelProvidersModelsContract = base
|
||||
.route({
|
||||
path: '/workspaces/current/model-providers/{provider}/models',
|
||||
method: 'GET',
|
||||
})
|
||||
.input(type<{
|
||||
params: {
|
||||
provider: string
|
||||
}
|
||||
}>())
|
||||
.output(type<{
|
||||
data: ModelItem[]
|
||||
}>())
|
||||
@ -12,6 +12,7 @@ import {
|
||||
exploreInstalledAppsContract,
|
||||
exploreInstalledAppUninstallContract,
|
||||
} from './console/explore'
|
||||
import { modelProvidersModelsContract } from './console/model-providers'
|
||||
import { systemFeaturesContract } from './console/system'
|
||||
import {
|
||||
triggerOAuthConfigContract,
|
||||
@ -63,6 +64,9 @@ export const consoleRouterContract = {
|
||||
parameters: trialAppParametersContract,
|
||||
workflows: trialAppWorkflowsContract,
|
||||
},
|
||||
modelProviders: {
|
||||
models: modelProvidersModelsContract,
|
||||
},
|
||||
billing: {
|
||||
invoices: invoicesContract,
|
||||
bindPartnerStack: bindPartnerStackContract,
|
||||
|
||||
Reference in New Issue
Block a user