diff --git a/web/app/components/workflow/hooks/use-checklist.ts b/web/app/components/workflow/hooks/use-checklist.ts index 838b325dbb..529f0ee6d8 100644 --- a/web/app/components/workflow/hooks/use-checklist.ts +++ b/web/app/components/workflow/hooks/use-checklist.ts @@ -53,6 +53,7 @@ import { } from '../hooks' import { getNodeUsedVars, isSpecialVar } from '../nodes/_base/components/variable/utils' import { IndexMethodEnum } from '../nodes/knowledge-base/types' +import { getLLMModelIssue, isLLMModelProviderInstalled, LLMModelIssueCode } from '../nodes/llm/utils' import { useStore, useWorkflowStore, @@ -223,7 +224,11 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { else { if (node.data.type === BlockEnum.LLM) { const modelProvider = (node.data as CommonNodeType<{ model?: ModelConfig }>).model?.provider - if (modelProvider && !installedPluginIds.has(extractPluginId(modelProvider))) + const modelIssue = getLLMModelIssue({ + modelProvider, + isModelProviderInstalled: isLLMModelProviderInstalled(modelProvider, installedPluginIds), + }) + if (modelIssue === LLMModelIssueCode.providerPluginUnavailable) errorMessages.push(t('errorMsg.configureModel', { ns: 'workflow' })) } @@ -469,7 +474,11 @@ export const useChecklistBeforePublish = () => { if (node.data.type === BlockEnum.LLM) { const modelProvider = (node.data as CommonNodeType<{ model?: ModelConfig }>).model?.provider - if (modelProvider && !installedPluginIds.has(extractPluginId(modelProvider))) { + const modelIssue = getLLMModelIssue({ + modelProvider, + isModelProviderInstalled: isLLMModelProviderInstalled(modelProvider, installedPluginIds), + }) + if (modelIssue === LLMModelIssueCode.providerPluginUnavailable) { notify({ type: 'error', message: `[${node.data.title}] ${t('errorMsg.configureModel', { ns: 'workflow' })}` }) return false } diff --git a/web/app/components/workflow/nodes/llm/default.ts b/web/app/components/workflow/nodes/llm/default.ts index cdd5bfbe6a..bd6027ec21 100644 --- a/web/app/components/workflow/nodes/llm/default.ts +++ b/web/app/components/workflow/nodes/llm/default.ts @@ -4,6 +4,7 @@ import { genNodeMetaData } from '@/app/components/workflow/utils' // import { RETRIEVAL_OUTPUT_STRUCT } from '../../constants' import { AppModeEnum } from '@/types/app' import { BlockEnum, EditionType, PromptRole } from '../../types' +import { getLLMModelIssue, LLMModelIssueCode } from './utils' const RETRIEVAL_OUTPUT_STRUCT = `{ "content": "", @@ -60,7 +61,8 @@ const nodeDefault: NodeDefault = { }, checkValid(payload: LLMNodeType, t: any) { let errorMessages = '' - if (!errorMessages && !payload.model.provider) + const modelIssue = getLLMModelIssue({ modelProvider: payload.model.provider }) + if (!errorMessages && modelIssue === LLMModelIssueCode.providerRequired) errorMessages = t(`${i18nPrefix}.fieldRequired`, { ns: 'workflow', field: t(`${i18nPrefix}.fields.model`, { ns: 'workflow' }) }) if (!errorMessages && !payload.memory) { diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index 9ed6e81e3f..b712ceca57 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -25,6 +25,7 @@ import ConfigPrompt from './components/config-prompt' import ReasoningFormatConfig from './components/reasoning-format-config' import StructureOutput from './components/structure-output' import useConfig from './use-config' +import { getLLMModelIssue, LLMModelIssueCode } from './utils' const i18nPrefix = 'nodes.llm' @@ -33,7 +34,6 @@ const Panel: FC> = ({ data, }) => { const { t } = useTranslation() - const modelProviders = useProviderContextSelector(s => s.modelProviders) const { readOnly, inputs, @@ -70,10 +70,18 @@ const Panel: FC> = ({ } = useConfig(id, data) const model = inputs.model - const installedPluginIds = new Set(modelProviders.map(provider => extractPluginId(provider.provider))) - const hasModelWarning = !model?.provider - || !model?.name - || (Boolean(model.provider) && !installedPluginIds.has(extractPluginId(model.provider))) + const isModelProviderInstalled = useProviderContextSelector((state) => { + const modelIssue = getLLMModelIssue({ modelProvider: model?.provider }) + if (modelIssue === LLMModelIssueCode.providerRequired) + return true + + const modelProviderPluginId = extractPluginId(model.provider) + return state.modelProviders.some(provider => extractPluginId(provider.provider) === modelProviderPluginId) + }) + const hasModelWarning = getLLMModelIssue({ + modelProvider: model?.provider, + isModelProviderInstalled, + }) !== null const handleModelChange = useCallback((model: { provider: string diff --git a/web/app/components/workflow/nodes/llm/utils.ts b/web/app/components/workflow/nodes/llm/utils.ts index 31b942ee64..662a1cae26 100644 --- a/web/app/components/workflow/nodes/llm/utils.ts +++ b/web/app/components/workflow/nodes/llm/utils.ts @@ -2,12 +2,41 @@ import type { ValidationError } from 'jsonschema' import type { ArrayItems, Field, LLMNodeType } from './types' import * as z from 'zod' import { draft07Validator, forbidBooleanProperties } from '@/utils/validators' +import { extractPluginId } from '../../utils/plugin' import { ArrayType, Type } from './types' export const checkNodeValid = (_payload: LLMNodeType) => { return true } +export enum LLMModelIssueCode { + providerRequired = 'provider-required', + providerPluginUnavailable = 'provider-plugin-unavailable', +} + +export const getLLMModelIssue = ({ + modelProvider, + isModelProviderInstalled = true, +}: { + modelProvider?: string + isModelProviderInstalled?: boolean +}) => { + if (!modelProvider) + return LLMModelIssueCode.providerRequired + + if (!isModelProviderInstalled) + return LLMModelIssueCode.providerPluginUnavailable + + return null +} + +export const isLLMModelProviderInstalled = (modelProvider: string | undefined, installedPluginIds: ReadonlySet) => { + if (!modelProvider) + return true + + return installedPluginIds.has(extractPluginId(modelProvider)) +} + export const getFieldType = (field: Field) => { const { type, items, enum: enums } = field if (field.schemaType === 'file')