feat(knowledge-base): add fine-grained embedding model validation with inline warnings

Extract validation logic from default.ts into shared utils.ts, enabling
node card, panel, and checklist to share the same validation rules.
Introduce provider-scoped model list queries to detect non-active model
states (noConfigure, quotaExceeded, credentialRemoved, incompatible).
Expand node card from 2 rows to 4 rows with per-row warning indicators,
and add warningDot support to panel field titles.
This commit is contained in:
yyh
2026-03-10 17:25:27 +08:00
parent 369e4eb7b0
commit 0b2ded3227
14 changed files with 710 additions and 167 deletions

View File

@ -12,9 +12,11 @@ import type {
Node,
ValueSelector,
} from '../types'
import type { ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { Emoji } from '@/app/components/tools/types'
import type { DataSet } from '@/models/datasets'
import type { I18nKeysWithPrefix } from '@/types/i18n'
import { useQueries, useQueryClient } from '@tanstack/react-query'
import {
useCallback,
useMemo,
@ -30,6 +32,7 @@ import useNodes from '@/app/components/workflow/store/workflow/use-nodes'
import { MAX_TREE_DEPTH } from '@/config'
import { useGetLanguage } from '@/context/i18n'
import { useProviderContextSelector } from '@/context/provider-context'
import { consoleQuery } from '@/service/client'
import { fetchDatasets } from '@/service/datasets'
import { useStrategyProviders } from '@/service/use-strategy'
import {
@ -49,6 +52,7 @@ import {
useNodesMetaData,
} from '../hooks'
import { getNodeUsedVars, isSpecialVar } from '../nodes/_base/components/variable/utils'
import { IndexMethodEnum } from '../nodes/knowledge-base/types'
import {
useStore,
useWorkflowStore,
@ -105,6 +109,43 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
const map = useNodesAvailableVarList(nodes)
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank)
const knowledgeBaseEmbeddingProviders = useMemo(() => {
const providers = new Set<string>()
nodes.forEach((node) => {
if (node.type !== CUSTOM_NODE || node.data.type !== BlockEnum.KnowledgeBase)
return
const knowledgeBaseData = node.data as CommonNodeType<KnowledgeBaseNodeType>
if (knowledgeBaseData.indexing_technique !== IndexMethodEnum.QUALIFIED)
return
const provider = knowledgeBaseData.embedding_model_provider
if (provider)
providers.add(provider)
})
return [...providers]
}, [nodes])
const knowledgeBaseProviderModelMap = useQueries({
queries: knowledgeBaseEmbeddingProviders.map(provider =>
consoleQuery.modelProviders.models.queryOptions({
input: { params: { provider } },
enabled: !!provider,
refetchOnWindowFocus: false,
select: response => response.data,
}),
),
combine: (results) => {
const modelMap: Partial<Record<string, ModelItem[]>> = {}
knowledgeBaseEmbeddingProviders.forEach((provider, index) => {
const models = results[index]?.data
if (models)
modelMap[provider] = models
})
return modelMap
},
})
const getCheckData = useCallback((data: CommonNodeType<{}>) => {
let checkData = data
@ -121,14 +162,16 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
} as CommonNodeType<KnowledgeRetrievalNodeType>
}
else if (data.type === BlockEnum.KnowledgeBase) {
const modelProviderName = (data as CommonNodeType<KnowledgeBaseNodeType>).embedding_model_provider
checkData = {
...data,
_embeddingModelList: embeddingModelList,
_embeddingProviderModelList: modelProviderName ? knowledgeBaseProviderModelMap[modelProviderName] : undefined,
_rerankModelList: rerankModelList,
} as CommonNodeType<KnowledgeBaseNodeType>
}
return checkData
}, [datasetsDetail, embeddingModelList, rerankModelList])
}, [datasetsDetail, embeddingModelList, knowledgeBaseProviderModelMap, rerankModelList])
const needWarningNodes = useMemo<ChecklistItem[]>(() => {
const list: ChecklistItem[] = []
@ -198,7 +241,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
if (isSpecialVar(variable[0]))
continue
const usedNode = availableVars.find(v => v.nodeId === variable?.[0])
if (!usedNode || !usedNode.vars.find(v => v.variable === variable?.[1]))
if (!usedNode || !usedNode.vars.some(v => v.variable === variable?.[1]))
hasInvalidVar = true
}
if (hasInvalidVar)
@ -208,7 +251,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
const isStartNodeMeta = nodesExtraData?.[node.data.type as BlockEnum]?.metaData.isStart ?? false
const canSkipConnectionCheck = shouldCheckStartNode ? isStartNodeMeta : true
const isUnconnected = !validNodes.find(n => n.id === node.id)
const isUnconnected = !validNodes.some(n => n.id === node.id)
const shouldShowError = errorMessages.length > 0 || (isUnconnected && !canSkipConnectionCheck)
if (shouldShowError) {
@ -247,7 +290,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
const isRequiredNodesType = Object.keys(nodesExtraData!).filter((key: any) => (nodesExtraData as any)[key].metaData.isRequired)
isRequiredNodesType.forEach((type: string) => {
if (!filteredNodes.find(node => node.data.type === type)) {
if (!filteredNodes.some(node => node.data.type === type)) {
list.push({
id: `${type}-need-added`,
type,
@ -269,12 +312,13 @@ export const useChecklistBeforePublish = () => {
const { t } = useTranslation()
const language = useGetLanguage()
const { notify } = useToastContext()
const queryClient = useQueryClient()
const store = useStoreApi()
const { nodesMap: nodesExtraData } = useNodesMetaData()
const { data: strategyProviders } = useStrategyProviders()
const modelProviders = useProviderContextSelector(s => s.modelProviders)
const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail)
const updateTime = useRef(0)
const updateTimeRef = useRef(0)
const workflowStore = useWorkflowStore()
const { getNodesAvailableVarList } = useGetNodesAvailableVarList()
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
@ -285,7 +329,11 @@ export const useChecklistBeforePublish = () => {
const appMode = useAppStore.getState().appDetail?.mode
const shouldCheckStartNode = appMode === AppModeEnum.WORKFLOW || appMode === AppModeEnum.ADVANCED_CHAT
const getCheckData = useCallback((data: CommonNodeType<{}>, datasets: DataSet[]) => {
const getCheckData = useCallback((
data: CommonNodeType<object>,
datasets: DataSet[],
embeddingProviderModelMap?: Partial<Record<string, ModelItem[]>>,
) => {
let checkData = data
if (data.type === BlockEnum.KnowledgeRetrieval) {
const datasetIds = (data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids
@ -304,9 +352,11 @@ export const useChecklistBeforePublish = () => {
} as CommonNodeType<KnowledgeRetrievalNodeType>
}
else if (data.type === BlockEnum.KnowledgeBase) {
const modelProviderName = (data as CommonNodeType<KnowledgeBaseNodeType>).embedding_model_provider
checkData = {
...data,
_embeddingModelList: embeddingModelList,
_embeddingProviderModelList: modelProviderName ? embeddingProviderModelMap?.[modelProviderName] : undefined,
_rerankModelList: rerankModelList,
} as CommonNodeType<KnowledgeBaseNodeType>
}
@ -329,24 +379,66 @@ export const useChecklistBeforePublish = () => {
notify({ type: 'error', message: t('common.maxTreeDepth', { ns: 'workflow', depth: MAX_TREE_DEPTH }) })
return false
}
// Before publish, we need to fetch datasets detail, in case of the settings of datasets have been changed
const knowledgeRetrievalNodes = filteredNodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval)
const allDatasetIds = knowledgeRetrievalNodes.reduce<string[]>((acc, node) => {
return Array.from(new Set([...acc, ...(node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids]))
}, [])
let datasets: DataSet[] = []
if (allDatasetIds.length > 0) {
updateTime.current = updateTime.current + 1
const currUpdateTime = updateTime.current
const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: allDatasetIds } })
if (datasetsDetail && datasetsDetail.length > 0) {
// avoid old data to overwrite the new data
if (currUpdateTime < updateTime.current)
return false
datasets = datasetsDetail
updateDatasetsDetail(datasetsDetail)
}
const knowledgeBaseEmbeddingProviders = [...new Set(
filteredNodes
.filter(node => node.data.type === BlockEnum.KnowledgeBase)
.map(node => node.data as CommonNodeType<KnowledgeBaseNodeType>)
.filter(node => node.indexing_technique === IndexMethodEnum.QUALIFIED)
.map(node => node.embedding_model_provider)
.filter((provider): provider is string => !!provider),
)]
const fetchKnowledgeBaseProviderModelMap = async () => {
const modelMap: Partial<Record<string, ModelItem[]>> = {}
await Promise.all(knowledgeBaseEmbeddingProviders.map(async (provider) => {
try {
const modelList = await queryClient.fetchQuery(
consoleQuery.modelProviders.models.queryOptions({
input: { params: { provider } },
}),
)
if (modelList.data)
modelMap[provider] = modelList.data
}
catch {
}
}))
return modelMap
}
const fetchLatestDatasets = async (): Promise<DataSet[] | null> => {
const allDatasetIds = new Set<string>()
filteredNodes.forEach((node) => {
if (node.data.type !== BlockEnum.KnowledgeRetrieval)
return
const datasetIds = (node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids
datasetIds.forEach(id => allDatasetIds.add(id))
})
if (allDatasetIds.size === 0)
return []
updateTimeRef.current = updateTimeRef.current + 1
const currUpdateTime = updateTimeRef.current
const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: [...allDatasetIds] } })
if (currUpdateTime < updateTimeRef.current)
return null
if (datasetsDetail?.length)
updateDatasetsDetail(datasetsDetail)
return datasetsDetail || []
}
const [embeddingProviderModelMap, datasets] = await Promise.all([
fetchKnowledgeBaseProviderModelMap(),
fetchLatestDatasets(),
])
if (datasets === null)
return false
const installedPluginIds = new Set(modelProviders.map(p => extractPluginId(p.provider)))
const map = getNodesAvailableVarList(nodes)
for (let i = 0; i < filteredNodes.length; i++) {
@ -383,7 +475,7 @@ export const useChecklistBeforePublish = () => {
}
}
const checkData = getCheckData(node.data, datasets)
const checkData = getCheckData(node.data, datasets, embeddingProviderModelMap)
const { errorMessage } = nodesExtraData![node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid)
if (errorMessage) {
@ -413,7 +505,7 @@ export const useChecklistBeforePublish = () => {
const isStartNodeMeta = nodesExtraData?.[node.data.type as BlockEnum]?.metaData.isStart ?? false
const canSkipConnectionCheck = shouldCheckStartNode ? isStartNodeMeta : true
const isUnconnected = !validNodes.find(n => n.id === node.id)
const isUnconnected = !validNodes.some(n => n.id === node.id)
if (isUnconnected && !canSkipConnectionCheck) {
notify({ type: 'error', message: `[${node.data.title}] ${t('common.needConnectTip', { ns: 'workflow' })}` })
@ -434,14 +526,14 @@ export const useChecklistBeforePublish = () => {
for (let i = 0; i < isRequiredNodesType.length; i++) {
const type = isRequiredNodesType[i]
if (!filteredNodes.find(node => node.data.type === type)) {
if (!filteredNodes.some(node => node.data.type === type)) {
notify({ type: 'error', message: t('common.needAdd', { ns: 'workflow', node: t(`blocks.${type}` as I18nKeysWithPrefix<'workflow', 'blocks.'>, { ns: 'workflow' }) }) })
return false
}
}
return true
}, [store, workflowStore, getNodesAvailableVarList, shouldCheckStartNode, nodesExtraData, notify, t, updateDatasetsDetail, buildInTools, customTools, workflowTools, language, getCheckData, strategyProviders, modelProviders])
}, [store, workflowStore, getNodesAvailableVarList, shouldCheckStartNode, nodesExtraData, notify, t, updateDatasetsDetail, buildInTools, customTools, workflowTools, language, getCheckData, queryClient, strategyProviders, modelProviders])
return {
handleCheckBeforePublish,