mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
support model params change
This commit is contained in:
@ -21,6 +21,7 @@ import {
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import LLMParamsPanel from './llm-params-panel'
|
||||
import TTSParamsPanel from './tts-params-panel'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
@ -28,12 +29,8 @@ export type ModelParameterModalProps = {
|
||||
popupClassName?: string
|
||||
portalToFollowElemContentClassName?: string
|
||||
isAdvancedMode: boolean
|
||||
mode: string
|
||||
modelId: string
|
||||
provider: string
|
||||
setModel: (model: { modelId: string; provider: string; mode?: string; features?: string[] }) => void
|
||||
completionParams: FormValue
|
||||
onCompletionParamsChange: (newParams: FormValue) => void
|
||||
value: any
|
||||
setModel: (model: any) => void
|
||||
renderTrigger?: (v: TriggerProps) => ReactNode
|
||||
readonly?: boolean
|
||||
isInWorkflow?: boolean
|
||||
@ -44,15 +41,12 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
popupClassName,
|
||||
portalToFollowElemContentClassName,
|
||||
isAdvancedMode,
|
||||
modelId,
|
||||
provider,
|
||||
value,
|
||||
setModel,
|
||||
completionParams,
|
||||
onCompletionParamsChange,
|
||||
renderTrigger,
|
||||
readonly,
|
||||
isInWorkflow,
|
||||
scope = 'text-generation',
|
||||
scope = ModelTypeEnum.textGeneration,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { isAPIKeySet } = useProviderContext()
|
||||
@ -79,29 +73,29 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
...moderationList,
|
||||
]
|
||||
}
|
||||
if (scopeArray.includes('text-generation'))
|
||||
if (scopeArray.includes(ModelTypeEnum.textGeneration))
|
||||
return textGenerationList
|
||||
if (scopeArray.includes('embedding'))
|
||||
if (scopeArray.includes(ModelTypeEnum.textEmbedding))
|
||||
return textEmbeddingList
|
||||
if (scopeArray.includes('rerank'))
|
||||
if (scopeArray.includes(ModelTypeEnum.rerank))
|
||||
return rerankList
|
||||
if (scopeArray.includes('moderation'))
|
||||
if (scopeArray.includes(ModelTypeEnum.moderation))
|
||||
return moderationList
|
||||
if (scopeArray.includes('stt'))
|
||||
if (scopeArray.includes(ModelTypeEnum.speech2text))
|
||||
return sttList
|
||||
if (scopeArray.includes('tts'))
|
||||
if (scopeArray.includes(ModelTypeEnum.tts))
|
||||
return ttsList
|
||||
return resultList
|
||||
}, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])
|
||||
|
||||
const { currentProvider, currentModel } = useMemo(() => {
|
||||
const currentProvider = scopedModelList.find(item => item.provider === provider)
|
||||
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === modelId)
|
||||
const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
|
||||
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
|
||||
return {
|
||||
currentProvider,
|
||||
currentModel,
|
||||
}
|
||||
}, [provider, modelId, scopedModelList])
|
||||
}, [scopedModelList, value?.provider, value?.model])
|
||||
|
||||
const hasDeprecated = useMemo(() => {
|
||||
return !currentProvider || !currentModel
|
||||
@ -116,11 +110,33 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
const handleChangeModel = ({ provider, model }: DefaultModel) => {
|
||||
const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
|
||||
const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
|
||||
const model_type = targetModelItem?.model_type as string
|
||||
setModel({
|
||||
modelId: model,
|
||||
provider,
|
||||
mode: targetModelItem?.model_properties.mode as string,
|
||||
features: targetModelItem?.features || [],
|
||||
model,
|
||||
model_type,
|
||||
...(model_type === ModelTypeEnum.textGeneration ? {
|
||||
mode: targetModelItem?.model_properties.mode as string,
|
||||
} : {}),
|
||||
})
|
||||
}
|
||||
|
||||
const handleLLMParamsChange = (newParams: FormValue) => {
|
||||
const newValue = {
|
||||
...(value?.completionParams || {}),
|
||||
completion_params: newParams,
|
||||
}
|
||||
setModel({
|
||||
...value,
|
||||
...newValue,
|
||||
})
|
||||
}
|
||||
|
||||
const handleTTSParamsChange = (language: string, voice: string) => {
|
||||
setModel({
|
||||
...value,
|
||||
language,
|
||||
voice,
|
||||
})
|
||||
}
|
||||
|
||||
@ -149,8 +165,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
hasDeprecated,
|
||||
currentProvider,
|
||||
currentModel,
|
||||
providerName: provider,
|
||||
modelId,
|
||||
providerName: value?.provider,
|
||||
modelId: value?.model,
|
||||
})
|
||||
: (
|
||||
<Trigger
|
||||
@ -160,8 +176,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
hasDeprecated={hasDeprecated}
|
||||
currentProvider={currentProvider}
|
||||
currentModel={currentModel}
|
||||
providerName={provider}
|
||||
modelId={modelId}
|
||||
providerName={value?.provider}
|
||||
modelId={value?.model}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@ -174,7 +190,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
{t('common.modelProvider.model').toLocaleUpperCase()}
|
||||
</div>
|
||||
<ModelSelector
|
||||
defaultModel={(provider || modelId) ? { provider, model: modelId } : undefined}
|
||||
defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
|
||||
modelList={scopedModelList}
|
||||
scopeFeatures={scopeFeatures}
|
||||
onSelect={handleChangeModel}
|
||||
@ -185,13 +201,21 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
)}
|
||||
{currentModel?.model_type === ModelTypeEnum.textGeneration && (
|
||||
<LLMParamsPanel
|
||||
provider={provider}
|
||||
modelId={modelId}
|
||||
completionParams={completionParams}
|
||||
onCompletionParamsChange={onCompletionParamsChange}
|
||||
provider={value?.provider}
|
||||
modelId={value?.model}
|
||||
completionParams={value?.completion_params || {}}
|
||||
onCompletionParamsChange={handleLLMParamsChange}
|
||||
isAdvancedMode={isAdvancedMode}
|
||||
/>
|
||||
)}
|
||||
{currentModel?.model_type === ModelTypeEnum.tts && (
|
||||
<TTSParamsPanel
|
||||
currentModel={currentModel}
|
||||
language={value?.language}
|
||||
voice={value?.voice}
|
||||
onChange={handleTTSParamsChange}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
|
||||
@ -0,0 +1,67 @@
|
||||
import React, { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { languages } from '@/i18n/language'
|
||||
import { PortalSelect } from '@/app/components/base/select'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
type Props = {
|
||||
currentModel: any
|
||||
language: string
|
||||
voice: string
|
||||
onChange: (language: string, voice: string) => void
|
||||
}
|
||||
|
||||
const TTSParamsPanel = ({
|
||||
currentModel,
|
||||
language,
|
||||
voice,
|
||||
onChange,
|
||||
}: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const voiceList = useMemo(() => {
|
||||
if (!currentModel)
|
||||
return []
|
||||
return currentModel.model_properties.voices.map((item: { mode: any }) => ({
|
||||
...item,
|
||||
value: item.mode,
|
||||
}))
|
||||
}, [currentModel])
|
||||
const setLanguage = (language: string) => {
|
||||
onChange(language, voice)
|
||||
}
|
||||
const setVoice = (voice: string) => {
|
||||
onChange(language, voice)
|
||||
}
|
||||
return (
|
||||
<>
|
||||
<div className='mb-3'>
|
||||
<div className='mb-1 py-1 flex items-center text-text-secondary system-sm-semibold'>
|
||||
{t('appDebug.voice.voiceSettings.language')}
|
||||
</div>
|
||||
<PortalSelect
|
||||
triggerClassName='h-8'
|
||||
popupClassName={cn('z-[1000]')}
|
||||
popupInnerClassName={cn('w-[354px]')}
|
||||
value={language}
|
||||
items={languages.filter(item => item.supported)}
|
||||
onSelect={item => setLanguage(item.value as string)}
|
||||
/>
|
||||
</div>
|
||||
<div className='mb-3'>
|
||||
<div className='mb-1 py-1 flex items-center text-text-secondary system-sm-semibold'>
|
||||
{t('appDebug.voice.voiceSettings.voice')}
|
||||
</div>
|
||||
<PortalSelect
|
||||
triggerClassName='h-8'
|
||||
popupClassName={cn('z-[1000]')}
|
||||
popupInnerClassName={cn('w-[354px]')}
|
||||
value={voice}
|
||||
items={voiceList}
|
||||
onSelect={item => setVoice(item.value as string)}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default TTSParamsPanel
|
||||
Reference in New Issue
Block a user