mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
feat(knowledge-base): add fine-grained embedding model validation with inline warnings
Extract validation logic from default.ts into shared utils.ts, enabling node card, panel, and checklist to share the same validation rules. Introduce provider-scoped model list queries to detect non-active model states (noConfigure, quotaExceeded, credentialRemoved, incompatible). Expand node card from 2 rows to 4 rows with per-row warning indicators, and add warningDot support to panel field titles.
This commit is contained in:
@ -65,6 +65,7 @@ import type { Shape as HooksStoreShape } from '../hooks-store/store'
|
||||
import type { Shape } from '../store/workflow'
|
||||
import type { Edge, Node, WorkflowRunningData } from '../types'
|
||||
import type { WorkflowHistoryStoreApi } from '../workflow-history-store'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { render, renderHook } from '@testing-library/react'
|
||||
import isDeepEqual from 'fast-deep-equal'
|
||||
import * as React from 'react'
|
||||
@ -154,6 +155,13 @@ function createWorkflowWrapper(
|
||||
const historyCtxValue = historyConfig
|
||||
? createTestHistoryStoreContext(historyConfig)
|
||||
: undefined
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
return ({ children }: { children: React.ReactNode }) => {
|
||||
let inner: React.ReactNode = children
|
||||
@ -175,9 +183,13 @@ function createWorkflowWrapper(
|
||||
}
|
||||
|
||||
return React.createElement(
|
||||
WorkflowContext.Provider,
|
||||
{ value: stores.store },
|
||||
inner,
|
||||
QueryClientProvider,
|
||||
{ client: queryClient },
|
||||
React.createElement(
|
||||
WorkflowContext.Provider,
|
||||
{ value: stores.store },
|
||||
inner,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,9 +12,11 @@ import type {
|
||||
Node,
|
||||
ValueSelector,
|
||||
} from '../types'
|
||||
import type { ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { Emoji } from '@/app/components/tools/types'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import type { I18nKeysWithPrefix } from '@/types/i18n'
|
||||
import { useQueries, useQueryClient } from '@tanstack/react-query'
|
||||
import {
|
||||
useCallback,
|
||||
useMemo,
|
||||
@ -30,6 +32,7 @@ import useNodes from '@/app/components/workflow/store/workflow/use-nodes'
|
||||
import { MAX_TREE_DEPTH } from '@/config'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
import { useProviderContextSelector } from '@/context/provider-context'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
import { useStrategyProviders } from '@/service/use-strategy'
|
||||
import {
|
||||
@ -49,6 +52,7 @@ import {
|
||||
useNodesMetaData,
|
||||
} from '../hooks'
|
||||
import { getNodeUsedVars, isSpecialVar } from '../nodes/_base/components/variable/utils'
|
||||
import { IndexMethodEnum } from '../nodes/knowledge-base/types'
|
||||
import {
|
||||
useStore,
|
||||
useWorkflowStore,
|
||||
@ -105,6 +109,43 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
const map = useNodesAvailableVarList(nodes)
|
||||
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
|
||||
const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank)
|
||||
// De-duplicated embedding providers referenced by knowledge base nodes whose
// indexing technique is IndexMethodEnum.QUALIFIED; recomputed when `nodes` changes.
const knowledgeBaseEmbeddingProviders = useMemo(() => {
  const uniqueProviders = new Set<string>()

  for (const node of nodes) {
    const isKnowledgeBaseNode = node.type === CUSTOM_NODE && node.data.type === BlockEnum.KnowledgeBase
    if (!isKnowledgeBaseNode)
      continue

    const { indexing_technique, embedding_model_provider } = node.data as CommonNodeType<KnowledgeBaseNodeType>
    // Nodes using any other indexing technique, or with no provider set, contribute nothing.
    if (indexing_technique === IndexMethodEnum.QUALIFIED && embedding_model_provider)
      uniqueProviders.add(embedding_model_provider)
  }

  return Array.from(uniqueProviders)
}, [nodes])
|
||||
// Runs one provider-scoped model query per embedding provider and folds the
// results into a provider -> models lookup. Providers whose query has no data
// (yet) are simply absent from the lookup.
const knowledgeBaseProviderModelMap = useQueries({
  queries: knowledgeBaseEmbeddingProviders.map((providerName) => {
    return consoleQuery.modelProviders.models.queryOptions({
      input: { params: { provider: providerName } },
      enabled: !!providerName,
      refetchOnWindowFocus: false,
      select: response => response.data,
    })
  }),
  combine: (queryResults) => {
    const modelsByProvider: Partial<Record<string, ModelItem[]>> = {}

    // Query results are positional: index i belongs to provider i.
    for (const [position, providerName] of knowledgeBaseEmbeddingProviders.entries()) {
      const fetchedModels = queryResults[position]?.data
      if (fetchedModels)
        modelsByProvider[providerName] = fetchedModels
    }

    return modelsByProvider
  },
})
|
||||
|
||||
const getCheckData = useCallback((data: CommonNodeType<{}>) => {
|
||||
let checkData = data
|
||||
@ -121,14 +162,16 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
} as CommonNodeType<KnowledgeRetrievalNodeType>
|
||||
}
|
||||
else if (data.type === BlockEnum.KnowledgeBase) {
|
||||
const modelProviderName = (data as CommonNodeType<KnowledgeBaseNodeType>).embedding_model_provider
|
||||
checkData = {
|
||||
...data,
|
||||
_embeddingModelList: embeddingModelList,
|
||||
_embeddingProviderModelList: modelProviderName ? knowledgeBaseProviderModelMap[modelProviderName] : undefined,
|
||||
_rerankModelList: rerankModelList,
|
||||
} as CommonNodeType<KnowledgeBaseNodeType>
|
||||
}
|
||||
return checkData
|
||||
}, [datasetsDetail, embeddingModelList, rerankModelList])
|
||||
}, [datasetsDetail, embeddingModelList, knowledgeBaseProviderModelMap, rerankModelList])
|
||||
|
||||
const needWarningNodes = useMemo<ChecklistItem[]>(() => {
|
||||
const list: ChecklistItem[] = []
|
||||
@ -198,7 +241,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
if (isSpecialVar(variable[0]))
|
||||
continue
|
||||
const usedNode = availableVars.find(v => v.nodeId === variable?.[0])
|
||||
if (!usedNode || !usedNode.vars.find(v => v.variable === variable?.[1]))
|
||||
if (!usedNode || !usedNode.vars.some(v => v.variable === variable?.[1]))
|
||||
hasInvalidVar = true
|
||||
}
|
||||
if (hasInvalidVar)
|
||||
@ -208,7 +251,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
const isStartNodeMeta = nodesExtraData?.[node.data.type as BlockEnum]?.metaData.isStart ?? false
|
||||
const canSkipConnectionCheck = shouldCheckStartNode ? isStartNodeMeta : true
|
||||
|
||||
const isUnconnected = !validNodes.find(n => n.id === node.id)
|
||||
const isUnconnected = !validNodes.some(n => n.id === node.id)
|
||||
const shouldShowError = errorMessages.length > 0 || (isUnconnected && !canSkipConnectionCheck)
|
||||
|
||||
if (shouldShowError) {
|
||||
@ -247,7 +290,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
const isRequiredNodesType = Object.keys(nodesExtraData!).filter((key: any) => (nodesExtraData as any)[key].metaData.isRequired)
|
||||
|
||||
isRequiredNodesType.forEach((type: string) => {
|
||||
if (!filteredNodes.find(node => node.data.type === type)) {
|
||||
if (!filteredNodes.some(node => node.data.type === type)) {
|
||||
list.push({
|
||||
id: `${type}-need-added`,
|
||||
type,
|
||||
@ -269,12 +312,13 @@ export const useChecklistBeforePublish = () => {
|
||||
const { t } = useTranslation()
|
||||
const language = useGetLanguage()
|
||||
const { notify } = useToastContext()
|
||||
const queryClient = useQueryClient()
|
||||
const store = useStoreApi()
|
||||
const { nodesMap: nodesExtraData } = useNodesMetaData()
|
||||
const { data: strategyProviders } = useStrategyProviders()
|
||||
const modelProviders = useProviderContextSelector(s => s.modelProviders)
|
||||
const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail)
|
||||
const updateTime = useRef(0)
|
||||
const updateTimeRef = useRef(0)
|
||||
const workflowStore = useWorkflowStore()
|
||||
const { getNodesAvailableVarList } = useGetNodesAvailableVarList()
|
||||
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
|
||||
@ -285,7 +329,11 @@ export const useChecklistBeforePublish = () => {
|
||||
const appMode = useAppStore.getState().appDetail?.mode
|
||||
const shouldCheckStartNode = appMode === AppModeEnum.WORKFLOW || appMode === AppModeEnum.ADVANCED_CHAT
|
||||
|
||||
const getCheckData = useCallback((data: CommonNodeType<{}>, datasets: DataSet[]) => {
|
||||
const getCheckData = useCallback((
|
||||
data: CommonNodeType<object>,
|
||||
datasets: DataSet[],
|
||||
embeddingProviderModelMap?: Partial<Record<string, ModelItem[]>>,
|
||||
) => {
|
||||
let checkData = data
|
||||
if (data.type === BlockEnum.KnowledgeRetrieval) {
|
||||
const datasetIds = (data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids
|
||||
@ -304,9 +352,11 @@ export const useChecklistBeforePublish = () => {
|
||||
} as CommonNodeType<KnowledgeRetrievalNodeType>
|
||||
}
|
||||
else if (data.type === BlockEnum.KnowledgeBase) {
|
||||
const modelProviderName = (data as CommonNodeType<KnowledgeBaseNodeType>).embedding_model_provider
|
||||
checkData = {
|
||||
...data,
|
||||
_embeddingModelList: embeddingModelList,
|
||||
_embeddingProviderModelList: modelProviderName ? embeddingProviderModelMap?.[modelProviderName] : undefined,
|
||||
_rerankModelList: rerankModelList,
|
||||
} as CommonNodeType<KnowledgeBaseNodeType>
|
||||
}
|
||||
@ -329,24 +379,66 @@ export const useChecklistBeforePublish = () => {
|
||||
notify({ type: 'error', message: t('common.maxTreeDepth', { ns: 'workflow', depth: MAX_TREE_DEPTH }) })
|
||||
return false
|
||||
}
|
||||
// Before publish, we need to fetch datasets detail, in case of the settings of datasets have been changed
|
||||
const knowledgeRetrievalNodes = filteredNodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval)
|
||||
const allDatasetIds = knowledgeRetrievalNodes.reduce<string[]>((acc, node) => {
|
||||
return Array.from(new Set([...acc, ...(node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids]))
|
||||
}, [])
|
||||
let datasets: DataSet[] = []
|
||||
if (allDatasetIds.length > 0) {
|
||||
updateTime.current = updateTime.current + 1
|
||||
const currUpdateTime = updateTime.current
|
||||
const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: allDatasetIds } })
|
||||
if (datasetsDetail && datasetsDetail.length > 0) {
|
||||
// avoid old data to overwrite the new data
|
||||
if (currUpdateTime < updateTime.current)
|
||||
return false
|
||||
datasets = datasetsDetail
|
||||
updateDatasetsDetail(datasetsDetail)
|
||||
}
|
||||
|
||||
// Collect, in first-seen order and without duplicates, the embedding providers
// of knowledge base nodes whose indexing technique is IndexMethodEnum.QUALIFIED.
const qualifiedEmbeddingProviders = new Set<string>()
for (const node of filteredNodes) {
  if (node.data.type !== BlockEnum.KnowledgeBase)
    continue

  const knowledgeBaseData = node.data as CommonNodeType<KnowledgeBaseNodeType>
  if (knowledgeBaseData.indexing_technique !== IndexMethodEnum.QUALIFIED)
    continue

  // Nodes without a provider configured are ignored.
  if (knowledgeBaseData.embedding_model_provider)
    qualifiedEmbeddingProviders.add(knowledgeBaseData.embedding_model_provider)
}
const knowledgeBaseEmbeddingProviders = [...qualifiedEmbeddingProviders]
|
||||
|
||||
/**
 * Fetches the model list of every embedding provider in
 * `knowledgeBaseEmbeddingProviders` and returns a provider -> models lookup.
 *
 * This is deliberately best-effort: a provider whose request fails, or whose
 * response carries no `data`, is left out of the lookup instead of failing the
 * whole operation. The returned promise therefore never rejects.
 */
const fetchKnowledgeBaseProviderModelMap = async () => {
  // `allSettled` replaces a per-request try/catch with an empty catch block, so
  // the "ignore individual failures" intent is explicit rather than swallowed.
  const outcomes = await Promise.allSettled(
    knowledgeBaseEmbeddingProviders.map(async (provider) => {
      const modelList = await queryClient.fetchQuery(
        consoleQuery.modelProviders.models.queryOptions({
          input: { params: { provider } },
        }),
      )
      return [provider, modelList.data] as const
    }),
  )

  const modelMap: Partial<Record<string, ModelItem[]>> = {}
  for (const outcome of outcomes) {
    // Rejected request: keep the provider absent from the lookup.
    if (outcome.status !== 'fulfilled')
      continue

    const [provider, models] = outcome.value
    if (models)
      modelMap[provider] = models
  }
  return modelMap
}
|
||||
|
||||
/**
 * Re-fetches the details of every dataset referenced by knowledge retrieval
 * nodes in `filteredNodes`.
 *
 * @returns `[]` when no node references a dataset; `null` when
 * `updateTimeRef` was bumped again while the request was in flight (this
 * result is stale); otherwise the fetched list, which is also handed to
 * `updateDatasetsDetail` when it is non-empty.
 */
const fetchLatestDatasets = async (): Promise<DataSet[] | null> => {
  const referencedIds = new Set<string>()
  for (const node of filteredNodes) {
    if (node.data.type === BlockEnum.KnowledgeRetrieval) {
      const { dataset_ids } = node.data as CommonNodeType<KnowledgeRetrievalNodeType>
      for (const datasetId of dataset_ids)
        referencedIds.add(datasetId)
    }
  }

  if (!referencedIds.size)
    return []

  // Tag this request so a later call can invalidate it.
  updateTimeRef.current += 1
  const requestToken = updateTimeRef.current

  const { data: latestDatasets } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: Array.from(referencedIds) } })

  // A newer request started meanwhile; do not let older data overwrite it.
  if (updateTimeRef.current > requestToken)
    return null

  if (latestDatasets?.length)
    updateDatasetsDetail(latestDatasets)

  return latestDatasets || []
}
|
||||
|
||||
const [embeddingProviderModelMap, datasets] = await Promise.all([
|
||||
fetchKnowledgeBaseProviderModelMap(),
|
||||
fetchLatestDatasets(),
|
||||
])
|
||||
|
||||
if (datasets === null)
|
||||
return false
|
||||
|
||||
const installedPluginIds = new Set(modelProviders.map(p => extractPluginId(p.provider)))
|
||||
const map = getNodesAvailableVarList(nodes)
|
||||
for (let i = 0; i < filteredNodes.length; i++) {
|
||||
@ -383,7 +475,7 @@ export const useChecklistBeforePublish = () => {
|
||||
}
|
||||
}
|
||||
|
||||
const checkData = getCheckData(node.data, datasets)
|
||||
const checkData = getCheckData(node.data, datasets, embeddingProviderModelMap)
|
||||
const { errorMessage } = nodesExtraData![node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid)
|
||||
|
||||
if (errorMessage) {
|
||||
@ -413,7 +505,7 @@ export const useChecklistBeforePublish = () => {
|
||||
|
||||
const isStartNodeMeta = nodesExtraData?.[node.data.type as BlockEnum]?.metaData.isStart ?? false
|
||||
const canSkipConnectionCheck = shouldCheckStartNode ? isStartNodeMeta : true
|
||||
const isUnconnected = !validNodes.find(n => n.id === node.id)
|
||||
const isUnconnected = !validNodes.some(n => n.id === node.id)
|
||||
|
||||
if (isUnconnected && !canSkipConnectionCheck) {
|
||||
notify({ type: 'error', message: `[${node.data.title}] ${t('common.needConnectTip', { ns: 'workflow' })}` })
|
||||
@ -434,14 +526,14 @@ export const useChecklistBeforePublish = () => {
|
||||
for (let i = 0; i < isRequiredNodesType.length; i++) {
|
||||
const type = isRequiredNodesType[i]
|
||||
|
||||
if (!filteredNodes.find(node => node.data.type === type)) {
|
||||
if (!filteredNodes.some(node => node.data.type === type)) {
|
||||
notify({ type: 'error', message: t('common.needAdd', { ns: 'workflow', node: t(`blocks.${type}` as I18nKeysWithPrefix<'workflow', 'blocks.'>, { ns: 'workflow' }) }) })
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}, [store, workflowStore, getNodesAvailableVarList, shouldCheckStartNode, nodesExtraData, notify, t, updateDatasetsDetail, buildInTools, customTools, workflowTools, language, getCheckData, strategyProviders, modelProviders])
|
||||
}, [store, workflowStore, getNodesAvailableVarList, shouldCheckStartNode, nodesExtraData, notify, t, updateDatasetsDetail, buildInTools, customTools, workflowTools, language, getCheckData, queryClient, strategyProviders, modelProviders])
|
||||
|
||||
return {
|
||||
handleCheckBeforePublish,
|
||||
|
||||
@ -3,8 +3,7 @@ import {
|
||||
memo,
|
||||
useState,
|
||||
} from 'react'
|
||||
import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
export type FieldTitleProps = {
|
||||
@ -12,6 +11,7 @@ export type FieldTitleProps = {
|
||||
operation?: ReactNode
|
||||
subTitle?: string | ReactNode
|
||||
tooltip?: string
|
||||
warningDot?: boolean
|
||||
showArrow?: boolean
|
||||
disabled?: boolean
|
||||
collapsed?: boolean
|
||||
@ -22,6 +22,7 @@ export const FieldTitle = memo(({
|
||||
operation,
|
||||
subTitle,
|
||||
tooltip,
|
||||
warningDot,
|
||||
showArrow,
|
||||
disabled,
|
||||
collapsed,
|
||||
@ -41,13 +42,19 @@ export const FieldTitle = memo(({
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="system-sm-semibold-uppercase flex items-center text-text-secondary">
|
||||
{title}
|
||||
<div className="flex items-center text-text-secondary system-sm-semibold-uppercase">
|
||||
<span className="relative">
|
||||
{warningDot && (
|
||||
<span className="absolute -left-[9px] top-1/2 size-[5px] -translate-y-1/2 rounded-full bg-text-warning-secondary" />
|
||||
)}
|
||||
{title}
|
||||
</span>
|
||||
{
|
||||
showArrow && (
|
||||
<ArrowDownRoundFill
|
||||
<span
|
||||
aria-hidden
|
||||
className={cn(
|
||||
'h-4 w-4 cursor-pointer text-text-quaternary group-hover/collapse:text-text-secondary',
|
||||
'i-custom-vender-solid-general-arrow-down-round-fill h-4 w-4 cursor-pointer text-text-quaternary group-hover/collapse:text-text-secondary',
|
||||
collapsedMerged && 'rotate-[270deg]',
|
||||
)}
|
||||
/>
|
||||
@ -55,10 +62,19 @@ export const FieldTitle = memo(({
|
||||
}
|
||||
{
|
||||
tooltip && (
|
||||
<Tooltip
|
||||
popupContent={tooltip}
|
||||
triggerClassName="w-4 h-4 ml-1"
|
||||
/>
|
||||
<Tooltip>
|
||||
<TooltipTrigger
|
||||
delay={0}
|
||||
render={(
|
||||
<span className="ml-1 flex h-4 w-4 shrink-0 items-center justify-center">
|
||||
<span aria-hidden className="i-ri-question-line h-3.5 w-3.5 text-text-quaternary hover:text-text-tertiary" />
|
||||
</span>
|
||||
)}
|
||||
/>
|
||||
<TooltipContent>
|
||||
{tooltip}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import type { ChunkStructureEnum } from '../../types'
|
||||
import { RiAddLine } from '@remixicon/react'
|
||||
import { memo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
@ -12,11 +11,13 @@ import Selector from './selector'
|
||||
type ChunkStructureProps = {
|
||||
chunkStructure?: ChunkStructureEnum
|
||||
onChunkStructureChange: (value: ChunkStructureEnum) => void
|
||||
warningDot?: boolean
|
||||
readonly?: boolean
|
||||
}
|
||||
const ChunkStructure = ({
|
||||
chunkStructure,
|
||||
onChunkStructureChange,
|
||||
warningDot = false,
|
||||
readonly = false,
|
||||
}: ChunkStructureProps) => {
|
||||
const { t } = useTranslation()
|
||||
@ -30,6 +31,7 @@ const ChunkStructure = ({
|
||||
fieldTitleProps={{
|
||||
title: t('nodes.knowledgeBase.chunkStructure', { ns: 'workflow' }),
|
||||
tooltip: t('nodes.knowledgeBase.chunkStructureTip.message', { ns: 'workflow' }),
|
||||
warningDot,
|
||||
operation: chunkStructure && (
|
||||
<Selector
|
||||
options={options}
|
||||
@ -62,7 +64,7 @@ const ChunkStructure = ({
|
||||
className="w-full"
|
||||
variant="secondary-accent"
|
||||
>
|
||||
<RiAddLine className="mr-1 h-4 w-4" />
|
||||
<span className="i-ri-add-line mr-1 h-4 w-4" />
|
||||
{t('nodes.knowledgeBase.chooseChunkStructure', { ns: 'workflow' })}
|
||||
</Button>
|
||||
)}
|
||||
|
||||
@ -17,12 +17,14 @@ type EmbeddingModelProps = {
|
||||
embeddingModel: string
|
||||
embeddingModelProvider: string
|
||||
}) => void
|
||||
warningDot?: boolean
|
||||
readonly?: boolean
|
||||
}
|
||||
const EmbeddingModel = ({
|
||||
embeddingModel,
|
||||
embeddingModelProvider,
|
||||
onEmbeddingModelChange,
|
||||
warningDot = false,
|
||||
readonly = false,
|
||||
}: EmbeddingModelProps) => {
|
||||
const { t } = useTranslation()
|
||||
@ -50,6 +52,7 @@ const EmbeddingModel = ({
|
||||
<Field
|
||||
fieldTitleProps={{
|
||||
title: t('form.embeddingModel', { ns: 'datasetSettings' }),
|
||||
warningDot,
|
||||
}}
|
||||
>
|
||||
<ModelSelector
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
import type { NodeDefault } from '../../types'
|
||||
import type { KnowledgeBaseNodeType } from './types'
|
||||
import { IndexingType } from '@/app/components/datasets/create/step-two'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { genNodeMetaData } from '@/app/components/workflow/utils'
|
||||
import {
|
||||
getKnowledgeBaseValidationIssue,
|
||||
getKnowledgeBaseValidationMessage,
|
||||
} from './utils'
|
||||
|
||||
const metaData = genNodeMetaData({
|
||||
sort: 3.1,
|
||||
@ -24,86 +27,9 @@ const nodeDefault: NodeDefault<KnowledgeBaseNodeType> = {
|
||||
},
|
||||
},
|
||||
checkValid(payload, t) {
|
||||
const {
|
||||
chunk_structure,
|
||||
indexing_technique,
|
||||
retrieval_model,
|
||||
embedding_model,
|
||||
embedding_model_provider,
|
||||
index_chunk_variable_selector,
|
||||
_embeddingModelList,
|
||||
_rerankModelList,
|
||||
} = payload
|
||||
|
||||
const {
|
||||
search_method,
|
||||
reranking_enable,
|
||||
reranking_model,
|
||||
} = retrieval_model || {}
|
||||
|
||||
const currentEmbeddingModelProvider = _embeddingModelList?.find(provider => provider.provider === embedding_model_provider)
|
||||
const currentEmbeddingModel = currentEmbeddingModelProvider?.models.find(model => model.model === embedding_model)
|
||||
|
||||
const currentRerankingModelProvider = _rerankModelList?.find(provider => provider.provider === reranking_model?.reranking_provider_name)
|
||||
const currentRerankingModel = currentRerankingModelProvider?.models.find(model => model.model === reranking_model?.reranking_model_name)
|
||||
|
||||
if (!chunk_structure) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.chunkIsRequired', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
|
||||
if (index_chunk_variable_selector.length === 0) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.chunksVariableIsRequired', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
|
||||
if (!indexing_technique) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.indexMethodIsRequired', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
|
||||
if (indexing_technique === IndexingType.QUALIFIED) {
|
||||
if (!embedding_model || !embedding_model_provider) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.embeddingModelIsRequired', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
else if (!currentEmbeddingModel) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.embeddingModelIsInvalid', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!retrieval_model || !search_method) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.retrievalSettingIsRequired', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
|
||||
if (reranking_enable) {
|
||||
if (!reranking_model || !reranking_model.reranking_provider_name || !reranking_model.reranking_model_name) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.rerankingModelIsRequired', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
else if (!currentRerankingModel) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.rerankingModelIsInvalid', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
}
|
||||
const issue = getKnowledgeBaseValidationIssue(payload)
|
||||
if (issue)
|
||||
return { isValid: false, errorMessage: getKnowledgeBaseValidationMessage(issue, t) }
|
||||
|
||||
return {
|
||||
isValid: true,
|
||||
|
||||
@ -1,38 +1,200 @@
|
||||
import type { FC } from 'react'
|
||||
import type { KnowledgeBaseNodeType } from './types'
|
||||
import type { NodeProps } from '@/app/components/workflow/types'
|
||||
import { memo } from 'react'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import {
|
||||
memo,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
ModelTypeEnum,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import {
|
||||
useLanguage,
|
||||
useModelList,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useSettingsDisplay } from './hooks/use-settings-display'
|
||||
import {
|
||||
IndexMethodEnum,
|
||||
} from './types'
|
||||
import {
|
||||
getKnowledgeBaseValidationIssue,
|
||||
getKnowledgeBaseValidationMessage,
|
||||
isKnowledgeBaseEmbeddingIssue,
|
||||
KnowledgeBaseValidationIssueCode,
|
||||
} from './utils'
|
||||
|
||||
type SettingRowProps = {
  label: string
  value: string
  warning?: boolean
}

/**
 * One label/value line of the node card. When `warning` is set the row gets
 * the warning border/background and both texts use the warning text color.
 * The value is also exposed through `title` because it is truncated.
 */
const SettingRow = memo(({ label, value, warning = false }: SettingRowProps) => {
  const containerTone = warning
    ? 'border-[0.5px] border-state-warning-active bg-state-warning-hover'
    : 'bg-workflow-block-parma-bg'
  const labelTone = warning ? 'text-text-warning' : 'text-text-tertiary'
  const valueTone = warning ? 'text-text-warning' : 'text-text-secondary'

  return (
    <div className={cn('flex h-6 items-center rounded-md px-1.5', containerTone)}>
      <div className={cn('mr-2 shrink-0 system-xs-medium-uppercase', labelTone)}>
        {label}
      </div>
      <div className={cn('grow truncate text-right system-xs-medium', valueTone)} title={value}>
        {value}
      </div>
    </div>
  )
})
|
||||
|
||||
// Validation issue codes that are reported on the retrieval-setting row.
const RETRIEVAL_WARNING_CODES = new Set<KnowledgeBaseValidationIssueCode>()
  .add(KnowledgeBaseValidationIssueCode.retrievalSettingRequired)
  .add(KnowledgeBaseValidationIssueCode.rerankingModelRequired)
  .add(KnowledgeBaseValidationIssueCode.rerankingModelInvalid)
|
||||
|
||||
const Node: FC<NodeProps<KnowledgeBaseNodeType>> = ({ data }) => {
|
||||
const { t } = useTranslation()
|
||||
const language = useLanguage()
|
||||
const settingsDisplay = useSettingsDisplay()
|
||||
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
|
||||
const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank)
|
||||
const chunkStructure = data.chunk_structure
|
||||
const indexChunkVariableSelector = data.index_chunk_variable_selector
|
||||
const indexingTechnique = data.indexing_technique
|
||||
const embeddingModel = data.embedding_model
|
||||
const retrievalModel = data.retrieval_model
|
||||
const retrievalSearchMethod = retrievalModel?.search_method
|
||||
const retrievalRerankingEnable = retrievalModel?.reranking_enable
|
||||
const embeddingModelProvider = data.embedding_model_provider
|
||||
const { data: embeddingProviderModelList } = useQuery(
|
||||
consoleQuery.modelProviders.models.queryOptions({
|
||||
input: { params: { provider: embeddingModelProvider || '' } },
|
||||
enabled: indexingTechnique === IndexMethodEnum.QUALIFIED && !!embeddingModelProvider,
|
||||
refetchOnWindowFocus: false,
|
||||
select: response => response.data,
|
||||
}),
|
||||
)
|
||||
|
||||
const validationPayload = useMemo(() => {
|
||||
return {
|
||||
chunk_structure: chunkStructure,
|
||||
index_chunk_variable_selector: indexChunkVariableSelector,
|
||||
indexing_technique: indexingTechnique,
|
||||
embedding_model: embeddingModel,
|
||||
embedding_model_provider: embeddingModelProvider,
|
||||
retrieval_model: {
|
||||
search_method: retrievalSearchMethod,
|
||||
reranking_enable: retrievalRerankingEnable,
|
||||
reranking_model: retrievalModel?.reranking_model,
|
||||
},
|
||||
_embeddingModelList: embeddingModelList,
|
||||
_embeddingProviderModelList: embeddingProviderModelList,
|
||||
_rerankModelList: rerankModelList,
|
||||
}
|
||||
}, [
|
||||
chunkStructure,
|
||||
indexChunkVariableSelector,
|
||||
indexingTechnique,
|
||||
embeddingModel,
|
||||
embeddingModelProvider,
|
||||
retrievalSearchMethod,
|
||||
retrievalRerankingEnable,
|
||||
retrievalModel?.reranking_model,
|
||||
embeddingModelList,
|
||||
embeddingProviderModelList,
|
||||
rerankModelList,
|
||||
])
|
||||
|
||||
const validationIssue = useMemo(() => {
|
||||
return getKnowledgeBaseValidationIssue({
|
||||
...validationPayload,
|
||||
})
|
||||
}, [validationPayload])
|
||||
|
||||
const validationIssueMessage = useMemo(() => {
|
||||
return getKnowledgeBaseValidationMessage(validationIssue, t)
|
||||
}, [validationIssue, t])
|
||||
|
||||
const chunksDisplayValue = useMemo(() => {
|
||||
if (!data.index_chunk_variable_selector?.length)
|
||||
return '-'
|
||||
|
||||
const chunkVar = data.index_chunk_variable_selector.at(-1)
|
||||
return chunkVar || '-'
|
||||
}, [data.index_chunk_variable_selector])
|
||||
|
||||
const embeddingModelDisplay = useMemo(() => {
|
||||
if (data.indexing_technique !== IndexMethodEnum.QUALIFIED)
|
||||
return '-'
|
||||
|
||||
if (isKnowledgeBaseEmbeddingIssue(validationIssue))
|
||||
return validationIssueMessage
|
||||
|
||||
const currentEmbeddingModelProvider = embeddingModelList.find(provider => provider.provider === data.embedding_model_provider)
|
||||
const currentEmbeddingModel = currentEmbeddingModelProvider?.models.find(model => model.model === data.embedding_model)
|
||||
return currentEmbeddingModel?.label[language] || currentEmbeddingModel?.label.en_US || data.embedding_model || '-'
|
||||
}, [data.embedding_model, data.embedding_model_provider, data.indexing_technique, embeddingModelList, language, validationIssue, validationIssueMessage])
|
||||
|
||||
const indexMethodDisplay = settingsDisplay[data.indexing_technique as keyof typeof settingsDisplay] || '-'
|
||||
const retrievalMethodDisplay = settingsDisplay[data.retrieval_model?.search_method as keyof typeof settingsDisplay] || '-'
|
||||
|
||||
const chunksWarning = validationIssue?.code === KnowledgeBaseValidationIssueCode.chunksVariableRequired
|
||||
const indexMethodWarning = validationIssue?.code === KnowledgeBaseValidationIssueCode.indexMethodRequired
|
||||
const embeddingWarning = isKnowledgeBaseEmbeddingIssue(validationIssue)
|
||||
const showEmbeddingModelRow = data.indexing_technique === IndexMethodEnum.QUALIFIED
|
||||
const retrievalWarning = !!(validationIssue && RETRIEVAL_WARNING_CODES.has(validationIssue.code))
|
||||
|
||||
if (!data.chunk_structure) {
|
||||
return (
|
||||
<div className="mb-1 space-y-0.5 px-3 py-1">
|
||||
<div className="flex h-6 items-center rounded-md border-[0.5px] border-state-warning-active bg-state-warning-hover px-1.5">
|
||||
<span className="mr-1 size-[4px] shrink-0 rounded-[2px] bg-text-warning-secondary" />
|
||||
<div className="grow truncate text-text-warning system-xs-medium" title={validationIssueMessage}>
|
||||
{validationIssueMessage}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mb-1 space-y-0.5 px-3 py-1">
|
||||
<div className="flex h-6 items-center rounded-md bg-workflow-block-parma-bg px-1.5">
|
||||
<div className="system-xs-medium-uppercase mr-2 shrink-0 text-text-tertiary">
|
||||
{t('stepTwo.indexMode', { ns: 'datasetCreation' })}
|
||||
</div>
|
||||
<div
|
||||
className="system-xs-medium grow truncate text-right text-text-secondary"
|
||||
title={data.indexing_technique}
|
||||
>
|
||||
{settingsDisplay[data.indexing_technique as keyof typeof settingsDisplay]}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex h-6 items-center rounded-md bg-workflow-block-parma-bg px-1.5">
|
||||
<div className="system-xs-medium-uppercase mr-2 shrink-0 text-text-tertiary">
|
||||
{t('form.retrievalSetting.title', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
<div
|
||||
className="system-xs-medium grow truncate text-right text-text-secondary"
|
||||
title={data.retrieval_model?.search_method}
|
||||
>
|
||||
{settingsDisplay[data.retrieval_model?.search_method as keyof typeof settingsDisplay]}
|
||||
</div>
|
||||
</div>
|
||||
<SettingRow
|
||||
label={t('nodes.knowledgeBase.chunksInput', { ns: 'workflow' })}
|
||||
value={chunksWarning ? validationIssueMessage : chunksDisplayValue}
|
||||
warning={chunksWarning}
|
||||
/>
|
||||
<SettingRow
|
||||
label={t('stepTwo.indexMode', { ns: 'datasetCreation' })}
|
||||
value={indexMethodWarning ? validationIssueMessage : indexMethodDisplay}
|
||||
warning={indexMethodWarning}
|
||||
/>
|
||||
{showEmbeddingModelRow && (
|
||||
<SettingRow
|
||||
label={t('form.embeddingModel', { ns: 'datasetSettings' })}
|
||||
value={embeddingModelDisplay}
|
||||
warning={embeddingWarning}
|
||||
/>
|
||||
)}
|
||||
<SettingRow
|
||||
label={t('form.retrievalSetting.method', { ns: 'datasetSettings' })}
|
||||
value={retrievalWarning ? validationIssueMessage : retrievalMethodDisplay}
|
||||
warning={retrievalWarning}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import type { FC } from 'react'
|
||||
import type { KnowledgeBaseNodeType } from './types'
|
||||
import type { NodePanelProps, Var } from '@/app/components/workflow/types'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
@ -19,6 +20,7 @@ import {
|
||||
} from '@/app/components/workflow/nodes/_base/components/layout'
|
||||
import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import Split from '../_base/components/split'
|
||||
import ChunkStructure from './components/chunk-structure'
|
||||
import EmbeddingModel from './components/embedding-model'
|
||||
@ -29,6 +31,11 @@ import {
|
||||
ChunkStructureEnum,
|
||||
IndexMethodEnum,
|
||||
} from './types'
|
||||
import {
|
||||
getKnowledgeBaseValidationIssue,
|
||||
isKnowledgeBaseEmbeddingIssue,
|
||||
KnowledgeBaseValidationIssueCode,
|
||||
} from './utils'
|
||||
|
||||
const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
id,
|
||||
@ -38,6 +45,22 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
const { nodesReadOnly } = useNodesReadOnly()
|
||||
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
|
||||
const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank)
|
||||
const chunkStructure = data.chunk_structure
|
||||
const indexChunkVariableSelector = data.index_chunk_variable_selector
|
||||
const indexingTechnique = data.indexing_technique
|
||||
const embeddingModel = data.embedding_model
|
||||
const retrievalModel = data.retrieval_model
|
||||
const retrievalSearchMethod = retrievalModel?.search_method
|
||||
const retrievalRerankingEnable = retrievalModel?.reranking_enable
|
||||
const embeddingModelProvider = data.embedding_model_provider
|
||||
const { data: embeddingProviderModelList } = useQuery(
|
||||
consoleQuery.modelProviders.models.queryOptions({
|
||||
input: { params: { provider: embeddingModelProvider || '' } },
|
||||
enabled: indexingTechnique === IndexMethodEnum.QUALIFIED && !!embeddingModelProvider,
|
||||
refetchOnWindowFocus: false,
|
||||
select: response => response.data,
|
||||
}),
|
||||
)
|
||||
|
||||
const {
|
||||
handleChunkStructureChange,
|
||||
@ -108,6 +131,44 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
})
|
||||
}, [data.embedding_model_provider, data.embedding_model, data.retrieval_model?.reranking_enable, data.retrieval_model?.reranking_model, data.indexing_technique, embeddingModelList, rerankModelList])
|
||||
|
||||
const validationPayload = useMemo(() => {
|
||||
return {
|
||||
chunk_structure: chunkStructure,
|
||||
index_chunk_variable_selector: indexChunkVariableSelector,
|
||||
indexing_technique: indexingTechnique,
|
||||
embedding_model: embeddingModel,
|
||||
embedding_model_provider: embeddingModelProvider,
|
||||
retrieval_model: {
|
||||
search_method: retrievalSearchMethod,
|
||||
reranking_enable: retrievalRerankingEnable,
|
||||
reranking_model: retrievalModel?.reranking_model,
|
||||
},
|
||||
_embeddingModelList: embeddingModelList,
|
||||
_embeddingProviderModelList: embeddingProviderModelList,
|
||||
_rerankModelList: rerankModelList,
|
||||
}
|
||||
}, [
|
||||
chunkStructure,
|
||||
indexChunkVariableSelector,
|
||||
indexingTechnique,
|
||||
embeddingModel,
|
||||
embeddingModelProvider,
|
||||
retrievalSearchMethod,
|
||||
retrievalRerankingEnable,
|
||||
retrievalModel?.reranking_model,
|
||||
embeddingModelList,
|
||||
embeddingProviderModelList,
|
||||
rerankModelList,
|
||||
])
|
||||
|
||||
const validationIssue = useMemo(() => {
|
||||
return getKnowledgeBaseValidationIssue(validationPayload)
|
||||
}, [validationPayload])
|
||||
|
||||
const chunkStructureWarning = validationIssue?.code === KnowledgeBaseValidationIssueCode.chunkStructureRequired
|
||||
const chunksInputWarning = validationIssue?.code === KnowledgeBaseValidationIssueCode.chunksVariableRequired
|
||||
const embeddingModelWarning = isKnowledgeBaseEmbeddingIssue(validationIssue)
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Group
|
||||
@ -117,6 +178,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
<ChunkStructure
|
||||
chunkStructure={data.chunk_structure}
|
||||
onChunkStructureChange={handleChunkStructureChange}
|
||||
warningDot={chunkStructureWarning}
|
||||
readonly={nodesReadOnly}
|
||||
/>
|
||||
</Group>
|
||||
@ -131,6 +193,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
fieldTitleProps: {
|
||||
title: t('nodes.knowledgeBase.chunksInput', { ns: 'workflow' }),
|
||||
tooltip: t('nodes.knowledgeBase.chunksInputTip', { ns: 'workflow' }),
|
||||
warningDot: chunksInputWarning,
|
||||
},
|
||||
}}
|
||||
>
|
||||
@ -163,6 +226,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
embeddingModel={data.embedding_model}
|
||||
embeddingModelProvider={data.embedding_model_provider}
|
||||
onEmbeddingModelChange={handleEmbeddingModelChange}
|
||||
warningDot={embeddingModelWarning}
|
||||
readonly={nodesReadOnly}
|
||||
/>
|
||||
)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import type { IndexingType } from '@/app/components/datasets/create/step-two'
|
||||
import type { Model } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { Model, ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { CommonNodeType } from '@/app/components/workflow/types'
|
||||
import type { RerankingModeEnum, WeightedScoreEnum } from '@/models/datasets'
|
||||
import type { RETRIEVE_METHOD } from '@/types/app'
|
||||
@ -57,6 +57,7 @@ export type KnowledgeBaseNodeType = CommonNodeType & {
|
||||
keyword_number: number
|
||||
retrieval_model: RetrievalSetting
|
||||
_embeddingModelList?: Model[]
|
||||
_embeddingProviderModelList?: ModelItem[]
|
||||
_rerankModelList?: Model[]
|
||||
summary_index_setting?: SummaryIndexSetting
|
||||
}
|
||||
|
||||
116
web/app/components/workflow/nodes/knowledge-base/utils.spec.ts
Normal file
116
web/app/components/workflow/nodes/knowledge-base/utils.spec.ts
Normal file
@ -0,0 +1,116 @@
|
||||
import type { KnowledgeBaseNodeType } from './types'
|
||||
import type { Model, ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
ModelStatusEnum,
|
||||
ModelTypeEnum,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import {
|
||||
ChunkStructureEnum,
|
||||
IndexMethodEnum,
|
||||
RetrievalSearchMethodEnum,
|
||||
} from './types'
|
||||
import {
|
||||
getKnowledgeBaseValidationIssue,
|
||||
KnowledgeBaseValidationIssueCode,
|
||||
} from './utils'
|
||||
|
||||
/**
 * Builds a one-provider embedding model list ('openai' with a single 'gpt-4o'
 * model). The given status is applied to both the provider and its model.
 */
const makeEmbeddingModelList = (status: ModelStatusEnum): Model[] => {
  const openAiProvider: Model = {
    provider: 'openai',
    icon_small: { en_US: '', zh_Hans: '' },
    label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' },
    models: [
      {
        model: 'gpt-4o',
        label: { en_US: 'GPT-4o', zh_Hans: 'GPT-4o' },
        model_type: ModelTypeEnum.textEmbedding,
        fetch_from: ConfigurationMethodEnum.predefinedModel,
        status,
        model_properties: {},
        load_balancing_enabled: false,
      },
    ],
    status,
  }

  return [openAiProvider]
}
|
||||
|
||||
/**
 * Builds a provider-scoped model list holding a single 'gpt-4o' embedding
 * model that carries the given status.
 */
const makeEmbeddingProviderModelList = (status: ModelStatusEnum): ModelItem[] => {
  const gpt4o: ModelItem = {
    model: 'gpt-4o',
    label: { en_US: 'GPT-4o', zh_Hans: 'GPT-4o' },
    model_type: ModelTypeEnum.textEmbedding,
    fetch_from: ConfigurationMethodEnum.predefinedModel,
    status,
    model_properties: {},
    load_balancing_enabled: false,
  }

  return [gpt4o]
}
|
||||
|
||||
/**
 * Creates a node payload that passes validation by default (high-quality
 * indexing, active 'openai'/'gpt-4o' embedding model, semantic search).
 * Any field supplied in `overrides` replaces the corresponding default.
 */
const makePayload = (overrides: Partial<KnowledgeBaseNodeType> = {}): KnowledgeBaseNodeType => {
  const defaults = {
    index_chunk_variable_selector: ['general_chunks', 'results'],
    chunk_structure: ChunkStructureEnum.general,
    indexing_technique: IndexMethodEnum.QUALIFIED,
    embedding_model: 'gpt-4o',
    embedding_model_provider: 'openai',
    keyword_number: 10,
    retrieval_model: {
      top_k: 3,
      score_threshold_enabled: false,
      score_threshold: 0.5,
      search_method: RetrievalSearchMethodEnum.semantic,
    },
    _embeddingModelList: makeEmbeddingModelList(ModelStatusEnum.active),
    _embeddingProviderModelList: makeEmbeddingProviderModelList(ModelStatusEnum.active),
    _rerankModelList: [],
  }

  // Overrides are spread last so they win over the defaults.
  return { ...defaults, ...overrides } as KnowledgeBaseNodeType
}
|
||||
|
||||
describe('knowledge-base validation issue', () => {
  it('returns chunk structure issue when chunk structure is missing', () => {
    const payload = makePayload({ chunk_structure: undefined })
    expect(getKnowledgeBaseValidationIssue(payload)?.code)
      .toBe(KnowledgeBaseValidationIssueCode.chunkStructureRequired)
  })

  it('returns chunks variable issue when chunks selector is empty', () => {
    const payload = makePayload({ index_chunk_variable_selector: [] })
    expect(getKnowledgeBaseValidationIssue(payload)?.code)
      .toBe(KnowledgeBaseValidationIssueCode.chunksVariableRequired)
  })

  // Each row: test title, status of the provider-scoped embedding model, expected issue code.
  const statusCases: Array<[string, ModelStatusEnum, KnowledgeBaseValidationIssueCode]> = [
    ['maps no-configure to not configured', ModelStatusEnum.noConfigure, KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured],
    ['maps credential-removed to API key unavailable', ModelStatusEnum.credentialRemoved, KnowledgeBaseValidationIssueCode.embeddingModelApiKeyUnavailable],
    ['maps quota-exceeded to credits exhausted', ModelStatusEnum.quotaExceeded, KnowledgeBaseValidationIssueCode.embeddingModelCreditsExhausted],
    ['maps disabled to incompatible', ModelStatusEnum.disabled, KnowledgeBaseValidationIssueCode.embeddingModelIncompatible],
  ]

  statusCases.forEach(([title, status, expectedCode]) => {
    it(title, () => {
      const payload = makePayload({
        _embeddingProviderModelList: makeEmbeddingProviderModelList(status),
      })
      expect(getKnowledgeBaseValidationIssue(payload)?.code).toBe(expectedCode)
    })
  })

  it('falls back to provider model list when provider scoped model list is empty', () => {
    // With an empty provider-scoped list, the active model in _embeddingModelList is used.
    const payload = makePayload({ _embeddingProviderModelList: [] })
    expect(getKnowledgeBaseValidationIssue(payload)).toBeNull()
  })
})
|
||||
@ -1,3 +1,12 @@
|
||||
import type { TFunction } from 'i18next'
|
||||
import type { KnowledgeBaseNodeType } from './types'
|
||||
import type { I18nKeysWithPrefix } from '@/types/i18n'
|
||||
import {
|
||||
IndexingType,
|
||||
} from '@/app/components/datasets/create/step-two'
|
||||
import {
|
||||
ModelStatusEnum,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import {
|
||||
RetrievalSearchMethodEnum,
|
||||
} from './types'
|
||||
@ -7,3 +16,154 @@ export const isHighQualitySearchMethod = (searchMethod: RetrievalSearchMethodEnu
|
||||
|| searchMethod === RetrievalSearchMethodEnum.hybrid
|
||||
|| searchMethod === RetrievalSearchMethodEnum.fullText
|
||||
}
|
||||
|
||||
/**
 * Codes identifying each validation problem a knowledge-base node
 * configuration can report.
 */
export enum KnowledgeBaseValidationIssueCode {
  // Required inputs
  chunkStructureRequired = 'chunk-structure-required',
  chunksVariableRequired = 'chunks-variable-required',
  indexMethodRequired = 'index-method-required',

  // Embedding model
  embeddingModelNotConfigured = 'embedding-model-not-configured',
  embeddingModelApiKeyUnavailable = 'embedding-model-api-key-unavailable',
  embeddingModelCreditsExhausted = 'embedding-model-credits-exhausted',
  embeddingModelIncompatible = 'embedding-model-incompatible',

  // Retrieval and reranking
  retrievalSettingRequired = 'retrieval-setting-required',
  rerankingModelRequired = 'reranking-model-required',
  rerankingModelInvalid = 'reranking-model-invalid',
}
|
||||
|
||||
/**
 * A detected validation problem: its code plus the i18n key (under
 * `nodes.knowledgeBase.` in the `workflow` namespace) of its message.
 */
type KnowledgeBaseValidationIssue = {
  code: KnowledgeBaseValidationIssueCode
  i18nKey: I18nKeysWithPrefix<'workflow', 'nodes.knowledgeBase.'>
}
|
||||
|
||||
/**
 * The subset of knowledge-base node data that validation reads. The retrieval
 * model is optional and narrowed to the fields validation inspects.
 */
type KnowledgeBaseValidationPayload = Pick<
  KnowledgeBaseNodeType,
  | 'chunk_structure'
  | 'index_chunk_variable_selector'
  | 'indexing_technique'
  | 'embedding_model'
  | 'embedding_model_provider'
  | '_embeddingModelList'
  | '_embeddingProviderModelList'
  | '_rerankModelList'
> & {
  retrieval_model?: Pick<
    KnowledgeBaseNodeType['retrieval_model'],
    'search_method' | 'reranking_enable' | 'reranking_model'
  >
}
|
||||
|
||||
/**
 * Maps every validation issue code to the i18n key of its message.
 * `satisfies` makes the compiler reject a missing code or an unknown key,
 * while `as const` keeps each key's literal type.
 */
const ISSUE_I18N_KEY_MAP = {
  // Required inputs
  [KnowledgeBaseValidationIssueCode.chunkStructureRequired]: 'nodes.knowledgeBase.chunkIsRequired',
  [KnowledgeBaseValidationIssueCode.chunksVariableRequired]: 'nodes.knowledgeBase.chunksVariableIsRequired',
  [KnowledgeBaseValidationIssueCode.indexMethodRequired]: 'nodes.knowledgeBase.indexMethodIsRequired',

  // Embedding model
  [KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured]: 'nodes.knowledgeBase.embeddingModelNotConfigured',
  [KnowledgeBaseValidationIssueCode.embeddingModelApiKeyUnavailable]: 'nodes.knowledgeBase.embeddingModelApiKeyUnavailable',
  [KnowledgeBaseValidationIssueCode.embeddingModelCreditsExhausted]: 'nodes.knowledgeBase.embeddingModelCreditsExhausted',
  [KnowledgeBaseValidationIssueCode.embeddingModelIncompatible]: 'nodes.knowledgeBase.embeddingModelIncompatible',

  // Retrieval and reranking
  [KnowledgeBaseValidationIssueCode.retrievalSettingRequired]: 'nodes.knowledgeBase.retrievalSettingIsRequired',
  [KnowledgeBaseValidationIssueCode.rerankingModelRequired]: 'nodes.knowledgeBase.rerankingModelIsRequired',
  [KnowledgeBaseValidationIssueCode.rerankingModelInvalid]: 'nodes.knowledgeBase.rerankingModelIsInvalid',
} as const satisfies Record<KnowledgeBaseValidationIssueCode, I18nKeysWithPrefix<'workflow', 'nodes.knowledgeBase.'>>
|
||||
|
||||
/** Issue codes that concern the embedding model; used by isKnowledgeBaseEmbeddingIssue. */
const EMBEDDING_ISSUE_CODES = new Set<KnowledgeBaseValidationIssueCode>([
  KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured,
  KnowledgeBaseValidationIssueCode.embeddingModelApiKeyUnavailable,
  KnowledgeBaseValidationIssueCode.embeddingModelCreditsExhausted,
  KnowledgeBaseValidationIssueCode.embeddingModelIncompatible,
])
|
||||
|
||||
/** Builds the issue object for a code by looking up its i18n key. */
const resolveIssue = (code: KnowledgeBaseValidationIssueCode): KnowledgeBaseValidationIssue => {
  const i18nKey = ISSUE_I18N_KEY_MAP[code]
  return { code, i18nKey }
}
|
||||
|
||||
/**
 * Checks the selected embedding model and returns the matching issue, or
 * `null` when the model is found with an active status.
 *
 * The model is looked up in the provider-scoped list when that list is
 * non-empty; otherwise in the models of the matching provider from the
 * general embedding model list.
 */
const resolveEmbeddingIssue = (payload: KnowledgeBaseValidationPayload): KnowledgeBaseValidationIssue | null => {
  const modelName = payload.embedding_model
  const providerName = payload.embedding_model_provider

  if (!modelName || !providerName)
    return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured)

  const providerEntry = payload._embeddingModelList?.find(item => item.provider === providerName)
  const scopedModels = payload._embeddingProviderModelList
  const useScopedModels = !!scopedModels && scopedModels.length > 0
  const candidates = useScopedModels ? scopedModels : providerEntry?.models
  const matchedModel = candidates?.find(item => item.model === modelName)

  if (!matchedModel) {
    // A known provider without the selected model means the model is
    // incompatible; an unknown provider means nothing is configured.
    if (useScopedModels || providerEntry)
      return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelIncompatible)
    return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured)
  }

  const { status } = matchedModel
  if (status === ModelStatusEnum.active)
    return null
  if (status === ModelStatusEnum.noConfigure)
    return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured)
  if (status === ModelStatusEnum.credentialRemoved)
    return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelApiKeyUnavailable)
  if (status === ModelStatusEnum.quotaExceeded)
    return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelCreditsExhausted)

  // noPermission, disabled and any other status are reported as incompatible.
  return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelIncompatible)
}
|
||||
|
||||
/**
 * Validates a knowledge-base node configuration and returns the first issue
 * found, or `null` when no issue is detected.
 *
 * Checks run in this order: chunk structure, chunks variable selector, index
 * method, embedding model (only for the QUALIFIED index method), retrieval
 * setting, and finally the reranking model when reranking is enabled.
 *
 * @param payload - Node fields plus the model lists used for lookups.
 * @returns The first validation issue, or `null`.
 */
export const getKnowledgeBaseValidationIssue = (payload: KnowledgeBaseValidationPayload): KnowledgeBaseValidationIssue | null => {
  const {
    chunk_structure,
    indexing_technique,
    retrieval_model,
    index_chunk_variable_selector,
    _rerankModelList,
  } = payload

  const {
    search_method,
    reranking_enable,
    reranking_model,
  } = retrieval_model || {}

  if (!chunk_structure)
    return resolveIssue(KnowledgeBaseValidationIssueCode.chunkStructureRequired)

  // Optional chaining guards against a selector that is missing at runtime,
  // which would otherwise throw on `.length` instead of reporting the issue.
  if (!index_chunk_variable_selector?.length)
    return resolveIssue(KnowledgeBaseValidationIssueCode.chunksVariableRequired)

  if (!indexing_technique)
    return resolveIssue(KnowledgeBaseValidationIssueCode.indexMethodRequired)

  // The embedding model is only validated for the QUALIFIED index method.
  if (indexing_technique === IndexingType.QUALIFIED) {
    const embeddingIssue = resolveEmbeddingIssue(payload)
    if (embeddingIssue)
      return embeddingIssue
  }

  if (!retrieval_model || !search_method)
    return resolveIssue(KnowledgeBaseValidationIssueCode.retrievalSettingRequired)

  if (reranking_enable) {
    if (!reranking_model || !reranking_model.reranking_provider_name || !reranking_model.reranking_model_name)
      return resolveIssue(KnowledgeBaseValidationIssueCode.rerankingModelRequired)

    // The selected reranking model must exist in the rerank model list.
    const currentRerankingModelProvider = _rerankModelList?.find(provider => provider.provider === reranking_model.reranking_provider_name)
    const currentRerankingModel = currentRerankingModelProvider?.models.find(model => model.model === reranking_model.reranking_model_name)
    if (!currentRerankingModel)
      return resolveIssue(KnowledgeBaseValidationIssueCode.rerankingModelInvalid)
  }

  return null
}
|
||||
|
||||
/**
 * Translates a validation issue into its message using the `workflow`
 * namespace. Returns an empty string when there is no issue.
 */
export const getKnowledgeBaseValidationMessage = (
  issue: KnowledgeBaseValidationIssue | null | undefined,
  t: TFunction,
) => issue ? t(issue.i18nKey, { ns: 'workflow' }) : ''
|
||||
|
||||
/**
 * Tells whether the issue is one of the embedding-model issue codes.
 * A missing issue yields `false`.
 */
export const isKnowledgeBaseEmbeddingIssue = (issue: KnowledgeBaseValidationIssue | null | undefined) =>
  !!issue && EMBEDDING_ISSUE_CODES.has(issue.code)
|
||||
|
||||
Reference in New Issue
Block a user