support model params change

This commit is contained in:
JzoNg
2024-12-24 14:15:18 +08:00
parent c8fc1deca6
commit e2e2090e0c
5 changed files with 131 additions and 71 deletions

View File

@ -21,6 +21,7 @@ import {
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import LLMParamsPanel from './llm-params-panel'
import TTSParamsPanel from './tts-params-panel'
import { useProviderContext } from '@/context/provider-context'
import cn from '@/utils/classnames'
@ -28,12 +29,8 @@ export type ModelParameterModalProps = {
popupClassName?: string
portalToFollowElemContentClassName?: string
isAdvancedMode: boolean
mode: string
modelId: string
provider: string
setModel: (model: { modelId: string; provider: string; mode?: string; features?: string[] }) => void
completionParams: FormValue
onCompletionParamsChange: (newParams: FormValue) => void
value: any
setModel: (model: any) => void
renderTrigger?: (v: TriggerProps) => ReactNode
readonly?: boolean
isInWorkflow?: boolean
@ -44,15 +41,12 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
popupClassName,
portalToFollowElemContentClassName,
isAdvancedMode,
modelId,
provider,
value,
setModel,
completionParams,
onCompletionParamsChange,
renderTrigger,
readonly,
isInWorkflow,
scope = 'text-generation',
scope = ModelTypeEnum.textGeneration,
}) => {
const { t } = useTranslation()
const { isAPIKeySet } = useProviderContext()
@ -79,29 +73,29 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
...moderationList,
]
}
if (scopeArray.includes('text-generation'))
if (scopeArray.includes(ModelTypeEnum.textGeneration))
return textGenerationList
if (scopeArray.includes('embedding'))
if (scopeArray.includes(ModelTypeEnum.textEmbedding))
return textEmbeddingList
if (scopeArray.includes('rerank'))
if (scopeArray.includes(ModelTypeEnum.rerank))
return rerankList
if (scopeArray.includes('moderation'))
if (scopeArray.includes(ModelTypeEnum.moderation))
return moderationList
if (scopeArray.includes('stt'))
if (scopeArray.includes(ModelTypeEnum.speech2text))
return sttList
if (scopeArray.includes('tts'))
if (scopeArray.includes(ModelTypeEnum.tts))
return ttsList
return resultList
}, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])
const { currentProvider, currentModel } = useMemo(() => {
const currentProvider = scopedModelList.find(item => item.provider === provider)
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === modelId)
const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
return {
currentProvider,
currentModel,
}
}, [provider, modelId, scopedModelList])
}, [scopedModelList, value?.provider, value?.model])
const hasDeprecated = useMemo(() => {
return !currentProvider || !currentModel
@ -116,11 +110,33 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
const handleChangeModel = ({ provider, model }: DefaultModel) => {
const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
const model_type = targetModelItem?.model_type as string
setModel({
modelId: model,
provider,
mode: targetModelItem?.model_properties.mode as string,
features: targetModelItem?.features || [],
model,
model_type,
...(model_type === ModelTypeEnum.textGeneration ? {
mode: targetModelItem?.model_properties.mode as string,
} : {}),
})
}
const handleLLMParamsChange = (newParams: FormValue) => {
const newValue = {
...(value?.completionParams || {}),
completion_params: newParams,
}
setModel({
...value,
...newValue,
})
}
const handleTTSParamsChange = (language: string, voice: string) => {
setModel({
...value,
language,
voice,
})
}
@ -149,8 +165,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
hasDeprecated,
currentProvider,
currentModel,
providerName: provider,
modelId,
providerName: value?.provider,
modelId: value?.model,
})
: (
<Trigger
@ -160,8 +176,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
hasDeprecated={hasDeprecated}
currentProvider={currentProvider}
currentModel={currentModel}
providerName={provider}
modelId={modelId}
providerName={value?.provider}
modelId={value?.model}
/>
)
}
@ -174,7 +190,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
{t('common.modelProvider.model').toLocaleUpperCase()}
</div>
<ModelSelector
defaultModel={(provider || modelId) ? { provider, model: modelId } : undefined}
defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
modelList={scopedModelList}
scopeFeatures={scopeFeatures}
onSelect={handleChangeModel}
@ -185,13 +201,21 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
)}
{currentModel?.model_type === ModelTypeEnum.textGeneration && (
<LLMParamsPanel
provider={provider}
modelId={modelId}
completionParams={completionParams}
onCompletionParamsChange={onCompletionParamsChange}
provider={value?.provider}
modelId={value?.model}
completionParams={value?.completion_params || {}}
onCompletionParamsChange={handleLLMParamsChange}
isAdvancedMode={isAdvancedMode}
/>
)}
{currentModel?.model_type === ModelTypeEnum.tts && (
<TTSParamsPanel
currentModel={currentModel}
language={value?.language}
voice={value?.voice}
onChange={handleTTSParamsChange}
/>
)}
</div>
</div>
</PortalToFollowElemContent>

View File

@ -0,0 +1,67 @@
import React, { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { languages } from '@/i18n/language'
import { PortalSelect } from '@/app/components/base/select'
import cn from '@/utils/classnames'
type Props = {
currentModel: any
language: string
voice: string
onChange: (language: string, voice: string) => void
}
const TTSParamsPanel = ({
currentModel,
language,
voice,
onChange,
}: Props) => {
const { t } = useTranslation()
const voiceList = useMemo(() => {
if (!currentModel)
return []
return currentModel.model_properties.voices.map((item: { mode: any }) => ({
...item,
value: item.mode,
}))
}, [currentModel])
const setLanguage = (language: string) => {
onChange(language, voice)
}
const setVoice = (voice: string) => {
onChange(language, voice)
}
return (
<>
<div className='mb-3'>
<div className='mb-1 py-1 flex items-center text-text-secondary system-sm-semibold'>
{t('appDebug.voice.voiceSettings.language')}
</div>
<PortalSelect
triggerClassName='h-8'
popupClassName={cn('z-[1000]')}
popupInnerClassName={cn('w-[354px]')}
value={language}
items={languages.filter(item => item.supported)}
onSelect={item => setLanguage(item.value as string)}
/>
</div>
<div className='mb-3'>
<div className='mb-1 py-1 flex items-center text-text-secondary system-sm-semibold'>
{t('appDebug.voice.voiceSettings.voice')}
</div>
<PortalSelect
triggerClassName='h-8'
popupClassName={cn('z-[1000]')}
popupInnerClassName={cn('w-[354px]')}
value={voice}
items={voiceList}
onSelect={item => setVoice(item.value as string)}
/>
</div>
</>
)
}
export default TTSParamsPanel