mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
feat(knowledge-base): add fine-grained embedding model validation with inline warnings
Extract validation logic from default.ts into shared utils.ts, enabling node card, panel, and checklist to share the same validation rules. Introduce provider-scoped model list queries to detect non-active model states (noConfigure, quotaExceeded, credentialRemoved, incompatible). Expand node card from 2 rows to 4 rows with per-row warning indicators, and add warningDot support to panel field titles.
This commit is contained in:
@ -12,9 +12,11 @@ import type {
|
||||
Node,
|
||||
ValueSelector,
|
||||
} from '../types'
|
||||
import type { ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { Emoji } from '@/app/components/tools/types'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import type { I18nKeysWithPrefix } from '@/types/i18n'
|
||||
import { useQueries, useQueryClient } from '@tanstack/react-query'
|
||||
import {
|
||||
useCallback,
|
||||
useMemo,
|
||||
@ -30,6 +32,7 @@ import useNodes from '@/app/components/workflow/store/workflow/use-nodes'
|
||||
import { MAX_TREE_DEPTH } from '@/config'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
import { useProviderContextSelector } from '@/context/provider-context'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
import { useStrategyProviders } from '@/service/use-strategy'
|
||||
import {
|
||||
@ -49,6 +52,7 @@ import {
|
||||
useNodesMetaData,
|
||||
} from '../hooks'
|
||||
import { getNodeUsedVars, isSpecialVar } from '../nodes/_base/components/variable/utils'
|
||||
import { IndexMethodEnum } from '../nodes/knowledge-base/types'
|
||||
import {
|
||||
useStore,
|
||||
useWorkflowStore,
|
||||
@ -105,6 +109,43 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
const map = useNodesAvailableVarList(nodes)
|
||||
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
|
||||
const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank)
|
||||
const knowledgeBaseEmbeddingProviders = useMemo(() => {
|
||||
const providers = new Set<string>()
|
||||
|
||||
nodes.forEach((node) => {
|
||||
if (node.type !== CUSTOM_NODE || node.data.type !== BlockEnum.KnowledgeBase)
|
||||
return
|
||||
|
||||
const knowledgeBaseData = node.data as CommonNodeType<KnowledgeBaseNodeType>
|
||||
if (knowledgeBaseData.indexing_technique !== IndexMethodEnum.QUALIFIED)
|
||||
return
|
||||
|
||||
const provider = knowledgeBaseData.embedding_model_provider
|
||||
if (provider)
|
||||
providers.add(provider)
|
||||
})
|
||||
|
||||
return [...providers]
|
||||
}, [nodes])
|
||||
const knowledgeBaseProviderModelMap = useQueries({
|
||||
queries: knowledgeBaseEmbeddingProviders.map(provider =>
|
||||
consoleQuery.modelProviders.models.queryOptions({
|
||||
input: { params: { provider } },
|
||||
enabled: !!provider,
|
||||
refetchOnWindowFocus: false,
|
||||
select: response => response.data,
|
||||
}),
|
||||
),
|
||||
combine: (results) => {
|
||||
const modelMap: Partial<Record<string, ModelItem[]>> = {}
|
||||
knowledgeBaseEmbeddingProviders.forEach((provider, index) => {
|
||||
const models = results[index]?.data
|
||||
if (models)
|
||||
modelMap[provider] = models
|
||||
})
|
||||
return modelMap
|
||||
},
|
||||
})
|
||||
|
||||
const getCheckData = useCallback((data: CommonNodeType<{}>) => {
|
||||
let checkData = data
|
||||
@ -121,14 +162,16 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
} as CommonNodeType<KnowledgeRetrievalNodeType>
|
||||
}
|
||||
else if (data.type === BlockEnum.KnowledgeBase) {
|
||||
const modelProviderName = (data as CommonNodeType<KnowledgeBaseNodeType>).embedding_model_provider
|
||||
checkData = {
|
||||
...data,
|
||||
_embeddingModelList: embeddingModelList,
|
||||
_embeddingProviderModelList: modelProviderName ? knowledgeBaseProviderModelMap[modelProviderName] : undefined,
|
||||
_rerankModelList: rerankModelList,
|
||||
} as CommonNodeType<KnowledgeBaseNodeType>
|
||||
}
|
||||
return checkData
|
||||
}, [datasetsDetail, embeddingModelList, rerankModelList])
|
||||
}, [datasetsDetail, embeddingModelList, knowledgeBaseProviderModelMap, rerankModelList])
|
||||
|
||||
const needWarningNodes = useMemo<ChecklistItem[]>(() => {
|
||||
const list: ChecklistItem[] = []
|
||||
@ -198,7 +241,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
if (isSpecialVar(variable[0]))
|
||||
continue
|
||||
const usedNode = availableVars.find(v => v.nodeId === variable?.[0])
|
||||
if (!usedNode || !usedNode.vars.find(v => v.variable === variable?.[1]))
|
||||
if (!usedNode || !usedNode.vars.some(v => v.variable === variable?.[1]))
|
||||
hasInvalidVar = true
|
||||
}
|
||||
if (hasInvalidVar)
|
||||
@ -208,7 +251,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
const isStartNodeMeta = nodesExtraData?.[node.data.type as BlockEnum]?.metaData.isStart ?? false
|
||||
const canSkipConnectionCheck = shouldCheckStartNode ? isStartNodeMeta : true
|
||||
|
||||
const isUnconnected = !validNodes.find(n => n.id === node.id)
|
||||
const isUnconnected = !validNodes.some(n => n.id === node.id)
|
||||
const shouldShowError = errorMessages.length > 0 || (isUnconnected && !canSkipConnectionCheck)
|
||||
|
||||
if (shouldShowError) {
|
||||
@ -247,7 +290,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
const isRequiredNodesType = Object.keys(nodesExtraData!).filter((key: any) => (nodesExtraData as any)[key].metaData.isRequired)
|
||||
|
||||
isRequiredNodesType.forEach((type: string) => {
|
||||
if (!filteredNodes.find(node => node.data.type === type)) {
|
||||
if (!filteredNodes.some(node => node.data.type === type)) {
|
||||
list.push({
|
||||
id: `${type}-need-added`,
|
||||
type,
|
||||
@ -269,12 +312,13 @@ export const useChecklistBeforePublish = () => {
|
||||
const { t } = useTranslation()
|
||||
const language = useGetLanguage()
|
||||
const { notify } = useToastContext()
|
||||
const queryClient = useQueryClient()
|
||||
const store = useStoreApi()
|
||||
const { nodesMap: nodesExtraData } = useNodesMetaData()
|
||||
const { data: strategyProviders } = useStrategyProviders()
|
||||
const modelProviders = useProviderContextSelector(s => s.modelProviders)
|
||||
const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail)
|
||||
const updateTime = useRef(0)
|
||||
const updateTimeRef = useRef(0)
|
||||
const workflowStore = useWorkflowStore()
|
||||
const { getNodesAvailableVarList } = useGetNodesAvailableVarList()
|
||||
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
|
||||
@ -285,7 +329,11 @@ export const useChecklistBeforePublish = () => {
|
||||
const appMode = useAppStore.getState().appDetail?.mode
|
||||
const shouldCheckStartNode = appMode === AppModeEnum.WORKFLOW || appMode === AppModeEnum.ADVANCED_CHAT
|
||||
|
||||
const getCheckData = useCallback((data: CommonNodeType<{}>, datasets: DataSet[]) => {
|
||||
const getCheckData = useCallback((
|
||||
data: CommonNodeType<object>,
|
||||
datasets: DataSet[],
|
||||
embeddingProviderModelMap?: Partial<Record<string, ModelItem[]>>,
|
||||
) => {
|
||||
let checkData = data
|
||||
if (data.type === BlockEnum.KnowledgeRetrieval) {
|
||||
const datasetIds = (data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids
|
||||
@ -304,9 +352,11 @@ export const useChecklistBeforePublish = () => {
|
||||
} as CommonNodeType<KnowledgeRetrievalNodeType>
|
||||
}
|
||||
else if (data.type === BlockEnum.KnowledgeBase) {
|
||||
const modelProviderName = (data as CommonNodeType<KnowledgeBaseNodeType>).embedding_model_provider
|
||||
checkData = {
|
||||
...data,
|
||||
_embeddingModelList: embeddingModelList,
|
||||
_embeddingProviderModelList: modelProviderName ? embeddingProviderModelMap?.[modelProviderName] : undefined,
|
||||
_rerankModelList: rerankModelList,
|
||||
} as CommonNodeType<KnowledgeBaseNodeType>
|
||||
}
|
||||
@ -329,24 +379,66 @@ export const useChecklistBeforePublish = () => {
|
||||
notify({ type: 'error', message: t('common.maxTreeDepth', { ns: 'workflow', depth: MAX_TREE_DEPTH }) })
|
||||
return false
|
||||
}
|
||||
// Before publish, we need to fetch datasets detail, in case of the settings of datasets have been changed
|
||||
const knowledgeRetrievalNodes = filteredNodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval)
|
||||
const allDatasetIds = knowledgeRetrievalNodes.reduce<string[]>((acc, node) => {
|
||||
return Array.from(new Set([...acc, ...(node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids]))
|
||||
}, [])
|
||||
let datasets: DataSet[] = []
|
||||
if (allDatasetIds.length > 0) {
|
||||
updateTime.current = updateTime.current + 1
|
||||
const currUpdateTime = updateTime.current
|
||||
const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: allDatasetIds } })
|
||||
if (datasetsDetail && datasetsDetail.length > 0) {
|
||||
// avoid old data to overwrite the new data
|
||||
if (currUpdateTime < updateTime.current)
|
||||
return false
|
||||
datasets = datasetsDetail
|
||||
updateDatasetsDetail(datasetsDetail)
|
||||
}
|
||||
|
||||
const knowledgeBaseEmbeddingProviders = [...new Set(
|
||||
filteredNodes
|
||||
.filter(node => node.data.type === BlockEnum.KnowledgeBase)
|
||||
.map(node => node.data as CommonNodeType<KnowledgeBaseNodeType>)
|
||||
.filter(node => node.indexing_technique === IndexMethodEnum.QUALIFIED)
|
||||
.map(node => node.embedding_model_provider)
|
||||
.filter((provider): provider is string => !!provider),
|
||||
)]
|
||||
|
||||
const fetchKnowledgeBaseProviderModelMap = async () => {
|
||||
const modelMap: Partial<Record<string, ModelItem[]>> = {}
|
||||
await Promise.all(knowledgeBaseEmbeddingProviders.map(async (provider) => {
|
||||
try {
|
||||
const modelList = await queryClient.fetchQuery(
|
||||
consoleQuery.modelProviders.models.queryOptions({
|
||||
input: { params: { provider } },
|
||||
}),
|
||||
)
|
||||
|
||||
if (modelList.data)
|
||||
modelMap[provider] = modelList.data
|
||||
}
|
||||
catch {
|
||||
}
|
||||
}))
|
||||
return modelMap
|
||||
}
|
||||
|
||||
const fetchLatestDatasets = async (): Promise<DataSet[] | null> => {
|
||||
const allDatasetIds = new Set<string>()
|
||||
filteredNodes.forEach((node) => {
|
||||
if (node.data.type !== BlockEnum.KnowledgeRetrieval)
|
||||
return
|
||||
|
||||
const datasetIds = (node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids
|
||||
datasetIds.forEach(id => allDatasetIds.add(id))
|
||||
})
|
||||
|
||||
if (allDatasetIds.size === 0)
|
||||
return []
|
||||
|
||||
updateTimeRef.current = updateTimeRef.current + 1
|
||||
const currUpdateTime = updateTimeRef.current
|
||||
const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: [...allDatasetIds] } })
|
||||
if (currUpdateTime < updateTimeRef.current)
|
||||
return null
|
||||
if (datasetsDetail?.length)
|
||||
updateDatasetsDetail(datasetsDetail)
|
||||
return datasetsDetail || []
|
||||
}
|
||||
|
||||
const [embeddingProviderModelMap, datasets] = await Promise.all([
|
||||
fetchKnowledgeBaseProviderModelMap(),
|
||||
fetchLatestDatasets(),
|
||||
])
|
||||
|
||||
if (datasets === null)
|
||||
return false
|
||||
|
||||
const installedPluginIds = new Set(modelProviders.map(p => extractPluginId(p.provider)))
|
||||
const map = getNodesAvailableVarList(nodes)
|
||||
for (let i = 0; i < filteredNodes.length; i++) {
|
||||
@ -383,7 +475,7 @@ export const useChecklistBeforePublish = () => {
|
||||
}
|
||||
}
|
||||
|
||||
const checkData = getCheckData(node.data, datasets)
|
||||
const checkData = getCheckData(node.data, datasets, embeddingProviderModelMap)
|
||||
const { errorMessage } = nodesExtraData![node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid)
|
||||
|
||||
if (errorMessage) {
|
||||
@ -413,7 +505,7 @@ export const useChecklistBeforePublish = () => {
|
||||
|
||||
const isStartNodeMeta = nodesExtraData?.[node.data.type as BlockEnum]?.metaData.isStart ?? false
|
||||
const canSkipConnectionCheck = shouldCheckStartNode ? isStartNodeMeta : true
|
||||
const isUnconnected = !validNodes.find(n => n.id === node.id)
|
||||
const isUnconnected = !validNodes.some(n => n.id === node.id)
|
||||
|
||||
if (isUnconnected && !canSkipConnectionCheck) {
|
||||
notify({ type: 'error', message: `[${node.data.title}] ${t('common.needConnectTip', { ns: 'workflow' })}` })
|
||||
@ -434,14 +526,14 @@ export const useChecklistBeforePublish = () => {
|
||||
for (let i = 0; i < isRequiredNodesType.length; i++) {
|
||||
const type = isRequiredNodesType[i]
|
||||
|
||||
if (!filteredNodes.find(node => node.data.type === type)) {
|
||||
if (!filteredNodes.some(node => node.data.type === type)) {
|
||||
notify({ type: 'error', message: t('common.needAdd', { ns: 'workflow', node: t(`blocks.${type}` as I18nKeysWithPrefix<'workflow', 'blocks.'>, { ns: 'workflow' }) }) })
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}, [store, workflowStore, getNodesAvailableVarList, shouldCheckStartNode, nodesExtraData, notify, t, updateDatasetsDetail, buildInTools, customTools, workflowTools, language, getCheckData, strategyProviders, modelProviders])
|
||||
}, [store, workflowStore, getNodesAvailableVarList, shouldCheckStartNode, nodesExtraData, notify, t, updateDatasetsDetail, buildInTools, customTools, workflowTools, language, getCheckData, queryClient, strategyProviders, modelProviders])
|
||||
|
||||
return {
|
||||
handleCheckBeforePublish,
|
||||
|
||||
Reference in New Issue
Block a user