mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
feat: enhance model plugin workflow checks and model provider management UX (#33289)
Signed-off-by: yyh <yuanyouhuilyz@gmail.com> Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: CodingOnStar <hanxujiang@dify.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Coding On Star <447357187@qq.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: statxc <tyleradams93226@gmail.com>
This commit is contained in:
@ -0,0 +1,77 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { ChunkStructureEnum } from '../../types'
|
||||
import ChunkStructure from './index'
|
||||
|
||||
const mockUseChunkStructure = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/layout', () => ({
|
||||
Field: ({ children, fieldTitleProps }: { children: ReactNode, fieldTitleProps: { title: string, warningDot?: boolean, operation?: ReactNode } }) => (
|
||||
<div data-testid="field" data-warning-dot={String(!!fieldTitleProps.warningDot)}>
|
||||
<div>{fieldTitleProps.title}</div>
|
||||
{fieldTitleProps.operation}
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('./hooks', () => ({
|
||||
useChunkStructure: mockUseChunkStructure,
|
||||
}))
|
||||
|
||||
vi.mock('../option-card', () => ({
|
||||
default: ({ title }: { title: string }) => <div data-testid="option-card">{title}</div>,
|
||||
}))
|
||||
|
||||
vi.mock('./selector', () => ({
|
||||
default: ({ trigger, value }: { trigger?: ReactNode, value?: string }) => (
|
||||
<div data-testid="selector">
|
||||
{value ?? 'no-value'}
|
||||
{trigger}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('./instruction', () => ({
|
||||
default: ({ className }: { className?: string }) => <div data-testid="instruction" className={className}>Instruction</div>,
|
||||
}))
|
||||
|
||||
describe('ChunkStructure', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseChunkStructure.mockReturnValue({
|
||||
options: [{ value: ChunkStructureEnum.general, label: 'General' }],
|
||||
optionMap: {
|
||||
[ChunkStructureEnum.general]: {
|
||||
title: 'General Chunk Structure',
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('should render the selected option and warning dot metadata when a chunk structure is chosen', () => {
|
||||
render(
|
||||
<ChunkStructure
|
||||
chunkStructure={ChunkStructureEnum.general}
|
||||
warningDot
|
||||
onChunkStructureChange={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('field')).toHaveAttribute('data-warning-dot', 'true')
|
||||
expect(screen.getByTestId('selector')).toHaveTextContent(ChunkStructureEnum.general)
|
||||
expect(screen.getByTestId('option-card')).toHaveTextContent('General Chunk Structure')
|
||||
expect(screen.queryByTestId('instruction')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render the add trigger and instruction when no chunk structure is selected', () => {
|
||||
render(
|
||||
<ChunkStructure
|
||||
onChunkStructureChange={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('button', { name: /chooseChunkStructure/i })).toBeInTheDocument()
|
||||
expect(screen.getByTestId('instruction')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@ -1,5 +1,4 @@
|
||||
import type { ChunkStructureEnum } from '../../types'
|
||||
import { RiAddLine } from '@remixicon/react'
|
||||
import { memo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
@ -12,11 +11,13 @@ import Selector from './selector'
|
||||
type ChunkStructureProps = {
|
||||
chunkStructure?: ChunkStructureEnum
|
||||
onChunkStructureChange: (value: ChunkStructureEnum) => void
|
||||
warningDot?: boolean
|
||||
readonly?: boolean
|
||||
}
|
||||
const ChunkStructure = ({
|
||||
chunkStructure,
|
||||
onChunkStructureChange,
|
||||
warningDot = false,
|
||||
readonly = false,
|
||||
}: ChunkStructureProps) => {
|
||||
const { t } = useTranslation()
|
||||
@ -30,6 +31,7 @@ const ChunkStructure = ({
|
||||
fieldTitleProps={{
|
||||
title: t('nodes.knowledgeBase.chunkStructure', { ns: 'workflow' }),
|
||||
tooltip: t('nodes.knowledgeBase.chunkStructureTip.message', { ns: 'workflow' }),
|
||||
warningDot,
|
||||
operation: chunkStructure && (
|
||||
<Selector
|
||||
options={options}
|
||||
@ -62,7 +64,7 @@ const ChunkStructure = ({
|
||||
className="w-full"
|
||||
variant="secondary-accent"
|
||||
>
|
||||
<RiAddLine className="mr-1 h-4 w-4" />
|
||||
<span className="i-ri-add-line mr-1 h-4 w-4" />
|
||||
{t('nodes.knowledgeBase.chooseChunkStructure', { ns: 'workflow' })}
|
||||
</Button>
|
||||
)}
|
||||
|
||||
@ -0,0 +1,62 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { render } from '@testing-library/react'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import EmbeddingModel from './embedding-model'
|
||||
|
||||
const mockUseModelList = vi.hoisted(() => vi.fn())
|
||||
const mockModelSelector = vi.hoisted(() => vi.fn(() => <div data-testid="model-selector">selector</div>))
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/layout', () => ({
|
||||
Field: ({ children, fieldTitleProps }: { children: ReactNode, fieldTitleProps: { warningDot?: boolean } }) => (
|
||||
<div data-testid="field" data-warning-dot={String(!!fieldTitleProps.warningDot)}>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
|
||||
useModelList: mockUseModelList,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => ({
|
||||
default: mockModelSelector,
|
||||
}))
|
||||
|
||||
describe('EmbeddingModel', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseModelList.mockReturnValue({ data: [{ provider: 'openai', model: 'text-embedding-3-large' }] })
|
||||
})
|
||||
|
||||
it('should pass the selected model configuration and warning state to the selector field', () => {
|
||||
const onEmbeddingModelChange = vi.fn()
|
||||
|
||||
render(
|
||||
<EmbeddingModel
|
||||
embeddingModel="text-embedding-3-large"
|
||||
embeddingModelProvider="openai"
|
||||
warningDot
|
||||
onEmbeddingModelChange={onEmbeddingModelChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(mockUseModelList).toHaveBeenCalledWith(ModelTypeEnum.textEmbedding)
|
||||
expect(mockModelSelector).toHaveBeenCalledWith(expect.objectContaining({
|
||||
defaultModel: {
|
||||
provider: 'openai',
|
||||
model: 'text-embedding-3-large',
|
||||
},
|
||||
modelList: [{ provider: 'openai', model: 'text-embedding-3-large' }],
|
||||
readonly: false,
|
||||
showDeprecatedWarnIcon: true,
|
||||
}), undefined)
|
||||
})
|
||||
|
||||
it('should pass an undefined default model when the embedding model is incomplete', () => {
|
||||
render(<EmbeddingModel embeddingModel="text-embedding-3-large" />)
|
||||
|
||||
expect(mockModelSelector).toHaveBeenCalledWith(expect.objectContaining({
|
||||
defaultModel: undefined,
|
||||
}), undefined)
|
||||
})
|
||||
})
|
||||
@ -17,12 +17,14 @@ type EmbeddingModelProps = {
|
||||
embeddingModel: string
|
||||
embeddingModelProvider: string
|
||||
}) => void
|
||||
warningDot?: boolean
|
||||
readonly?: boolean
|
||||
}
|
||||
const EmbeddingModel = ({
|
||||
embeddingModel,
|
||||
embeddingModelProvider,
|
||||
onEmbeddingModelChange,
|
||||
warningDot = false,
|
||||
readonly = false,
|
||||
}: EmbeddingModelProps) => {
|
||||
const { t } = useTranslation()
|
||||
@ -50,6 +52,7 @@ const EmbeddingModel = ({
|
||||
<Field
|
||||
fieldTitleProps={{
|
||||
title: t('form.embeddingModel', { ns: 'datasetSettings' }),
|
||||
warningDot,
|
||||
}}
|
||||
>
|
||||
<ModelSelector
|
||||
|
||||
@ -0,0 +1,121 @@
|
||||
import type {
|
||||
DefaultModel,
|
||||
Model,
|
||||
ModelItem,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
ModelStatusEnum,
|
||||
ModelTypeEnum,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import RerankingModelSelector from './reranking-model-selector'
|
||||
|
||||
type MockModelSelectorProps = {
|
||||
defaultModel?: DefaultModel
|
||||
modelList: Model[]
|
||||
onSelect?: (model: DefaultModel) => void
|
||||
}
|
||||
|
||||
const mockUseModelListAndDefaultModel = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
|
||||
useModelListAndDefaultModel: mockUseModelListAndDefaultModel,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => ({
|
||||
default: ({ defaultModel, modelList, onSelect }: MockModelSelectorProps) => (
|
||||
<div>
|
||||
<div data-testid="default-model">
|
||||
{defaultModel ? `${defaultModel.provider}/${defaultModel.model}` : 'no-default-model'}
|
||||
</div>
|
||||
<div data-testid="model-list-count">{modelList.length}</div>
|
||||
<button type="button" onClick={() => onSelect?.({ provider: 'cohere', model: 'rerank-v3' })}>
|
||||
select-model
|
||||
</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
const createModelItem = (overrides: Partial<ModelItem> = {}): ModelItem => ({
|
||||
model: 'rerank-v3',
|
||||
label: { en_US: 'Rerank V3', zh_Hans: 'Rerank V3' },
|
||||
model_type: ModelTypeEnum.rerank,
|
||||
fetch_from: ConfigurationMethodEnum.predefinedModel,
|
||||
status: ModelStatusEnum.active,
|
||||
model_properties: {},
|
||||
load_balancing_enabled: false,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
provider: 'cohere',
|
||||
icon_small: {
|
||||
en_US: 'https://example.com/cohere.png',
|
||||
zh_Hans: 'https://example.com/cohere.png',
|
||||
},
|
||||
icon_small_dark: {
|
||||
en_US: 'https://example.com/cohere-dark.png',
|
||||
zh_Hans: 'https://example.com/cohere-dark.png',
|
||||
},
|
||||
label: { en_US: 'Cohere', zh_Hans: 'Cohere' },
|
||||
models: [createModelItem()],
|
||||
status: ModelStatusEnum.active,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('RerankingModelSelector', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseModelListAndDefaultModel.mockReturnValue({
|
||||
modelList: [createModel()],
|
||||
defaultModel: undefined,
|
||||
})
|
||||
})
|
||||
|
||||
// Rendering behavior for mapped rerank model state.
|
||||
describe('Rendering', () => {
|
||||
it('should not pass a default model when reranking model fields are empty strings', () => {
|
||||
render(
|
||||
<RerankingModelSelector
|
||||
rerankingModel={{
|
||||
reranking_provider_name: '',
|
||||
reranking_model_name: '',
|
||||
}}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('default-model')).toHaveTextContent('no-default-model')
|
||||
expect(screen.getByTestId('model-list-count')).toHaveTextContent('1')
|
||||
})
|
||||
|
||||
it('should map reranking model to default model when both fields exist', () => {
|
||||
render(
|
||||
<RerankingModelSelector
|
||||
rerankingModel={{
|
||||
reranking_provider_name: 'cohere',
|
||||
reranking_model_name: 'rerank-v3',
|
||||
}}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('default-model')).toHaveTextContent('cohere/rerank-v3')
|
||||
})
|
||||
})
|
||||
|
||||
// Selection behavior should convert back to workflow reranking model shape.
|
||||
describe('Interactions', () => {
|
||||
it('should map selected model back to reranking model fields', () => {
|
||||
const onRerankingModelChange = vi.fn()
|
||||
|
||||
render(<RerankingModelSelector onRerankingModelChange={onRerankingModelChange} />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'select-model' }))
|
||||
|
||||
expect(onRerankingModelChange).toHaveBeenCalledWith({
|
||||
reranking_provider_name: 'cohere',
|
||||
reranking_model_name: 'rerank-v3',
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -22,12 +22,12 @@ const RerankingModelSelector = ({
|
||||
modelList: rerankModelList,
|
||||
} = useModelListAndDefaultModel(ModelTypeEnum.rerank)
|
||||
const rerankModel = useMemo(() => {
|
||||
if (!rerankingModel)
|
||||
if (!rerankingModel?.reranking_provider_name || !rerankingModel?.reranking_model_name)
|
||||
return undefined
|
||||
|
||||
return {
|
||||
providerName: rerankingModel.reranking_provider_name,
|
||||
modelName: rerankingModel.reranking_model_name,
|
||||
provider: rerankingModel.reranking_provider_name,
|
||||
model: rerankingModel.reranking_model_name,
|
||||
}
|
||||
}, [rerankingModel])
|
||||
|
||||
@ -40,7 +40,7 @@ const RerankingModelSelector = ({
|
||||
|
||||
return (
|
||||
<ModelSelector
|
||||
defaultModel={rerankModel && { provider: rerankModel.providerName, model: rerankModel.modelName }}
|
||||
defaultModel={rerankModel}
|
||||
modelList={rerankModelList}
|
||||
onSelect={handleRerankingModelChange}
|
||||
readonly={readonly}
|
||||
|
||||
@ -0,0 +1,74 @@
|
||||
import type { KnowledgeBaseNodeType } from './types'
|
||||
import type { Model, ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
ModelStatusEnum,
|
||||
ModelTypeEnum,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import nodeDefault from './default'
|
||||
import { ChunkStructureEnum, IndexMethodEnum, RetrievalSearchMethodEnum } from './types'
|
||||
|
||||
const t = (key: string) => key
|
||||
|
||||
const makeEmbeddingModelList = (status: ModelStatusEnum): Model[] => [{
|
||||
provider: 'openai',
|
||||
icon_small: { en_US: '', zh_Hans: '' },
|
||||
label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' },
|
||||
models: [{
|
||||
model: 'text-embedding-3-large',
|
||||
label: { en_US: 'Text Embedding 3 Large', zh_Hans: 'Text Embedding 3 Large' },
|
||||
model_type: ModelTypeEnum.textEmbedding,
|
||||
fetch_from: ConfigurationMethodEnum.predefinedModel,
|
||||
status,
|
||||
model_properties: {},
|
||||
load_balancing_enabled: false,
|
||||
}],
|
||||
status,
|
||||
}]
|
||||
|
||||
const makeEmbeddingProviderModelList = (status: ModelStatusEnum): ModelItem[] => [{
|
||||
model: 'text-embedding-3-large',
|
||||
label: { en_US: 'Text Embedding 3 Large', zh_Hans: 'Text Embedding 3 Large' },
|
||||
model_type: ModelTypeEnum.textEmbedding,
|
||||
fetch_from: ConfigurationMethodEnum.predefinedModel,
|
||||
status,
|
||||
model_properties: {},
|
||||
load_balancing_enabled: false,
|
||||
}]
|
||||
|
||||
const createPayload = (overrides: Partial<KnowledgeBaseNodeType> = {}): KnowledgeBaseNodeType => ({
|
||||
...nodeDefault.defaultValue,
|
||||
index_chunk_variable_selector: ['chunks', 'results'],
|
||||
chunk_structure: ChunkStructureEnum.general,
|
||||
indexing_technique: IndexMethodEnum.QUALIFIED,
|
||||
embedding_model: 'text-embedding-3-large',
|
||||
embedding_model_provider: 'openai',
|
||||
retrieval_model: {
|
||||
...nodeDefault.defaultValue.retrieval_model,
|
||||
search_method: RetrievalSearchMethodEnum.semantic,
|
||||
},
|
||||
_embeddingModelList: makeEmbeddingModelList(ModelStatusEnum.active),
|
||||
_embeddingProviderModelList: makeEmbeddingProviderModelList(ModelStatusEnum.active),
|
||||
_rerankModelList: [],
|
||||
...overrides,
|
||||
}) as KnowledgeBaseNodeType
|
||||
|
||||
describe('knowledge-base default node validation', () => {
|
||||
it('should return an invalid result when the payload has a validation issue', () => {
|
||||
const result = nodeDefault.checkValid(createPayload({ chunk_structure: undefined }), t)
|
||||
|
||||
expect(result).toEqual({
|
||||
isValid: false,
|
||||
errorMessage: 'nodes.knowledgeBase.chunkIsRequired',
|
||||
})
|
||||
})
|
||||
|
||||
it('should return a valid result when the payload is complete', () => {
|
||||
const result = nodeDefault.checkValid(createPayload(), t)
|
||||
|
||||
expect(result).toEqual({
|
||||
isValid: true,
|
||||
errorMessage: '',
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -1,8 +1,11 @@
|
||||
import type { NodeDefault } from '../../types'
|
||||
import type { KnowledgeBaseNodeType } from './types'
|
||||
import { IndexingType } from '@/app/components/datasets/create/step-two'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { genNodeMetaData } from '@/app/components/workflow/utils'
|
||||
import {
|
||||
getKnowledgeBaseValidationIssue,
|
||||
getKnowledgeBaseValidationMessage,
|
||||
} from './utils'
|
||||
|
||||
const metaData = genNodeMetaData({
|
||||
sort: 3.1,
|
||||
@ -24,86 +27,9 @@ const nodeDefault: NodeDefault<KnowledgeBaseNodeType> = {
|
||||
},
|
||||
},
|
||||
checkValid(payload, t) {
|
||||
const {
|
||||
chunk_structure,
|
||||
indexing_technique,
|
||||
retrieval_model,
|
||||
embedding_model,
|
||||
embedding_model_provider,
|
||||
index_chunk_variable_selector,
|
||||
_embeddingModelList,
|
||||
_rerankModelList,
|
||||
} = payload
|
||||
|
||||
const {
|
||||
search_method,
|
||||
reranking_enable,
|
||||
reranking_model,
|
||||
} = retrieval_model || {}
|
||||
|
||||
const currentEmbeddingModelProvider = _embeddingModelList?.find(provider => provider.provider === embedding_model_provider)
|
||||
const currentEmbeddingModel = currentEmbeddingModelProvider?.models.find(model => model.model === embedding_model)
|
||||
|
||||
const currentRerankingModelProvider = _rerankModelList?.find(provider => provider.provider === reranking_model?.reranking_provider_name)
|
||||
const currentRerankingModel = currentRerankingModelProvider?.models.find(model => model.model === reranking_model?.reranking_model_name)
|
||||
|
||||
if (!chunk_structure) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.chunkIsRequired', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
|
||||
if (index_chunk_variable_selector.length === 0) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.chunksVariableIsRequired', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
|
||||
if (!indexing_technique) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.indexMethodIsRequired', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
|
||||
if (indexing_technique === IndexingType.QUALIFIED) {
|
||||
if (!embedding_model || !embedding_model_provider) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.embeddingModelIsRequired', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
else if (!currentEmbeddingModel) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.embeddingModelIsInvalid', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!retrieval_model || !search_method) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.retrievalSettingIsRequired', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
|
||||
if (reranking_enable) {
|
||||
if (!reranking_model || !reranking_model.reranking_provider_name || !reranking_model.reranking_model_name) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.rerankingModelIsRequired', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
else if (!currentRerankingModel) {
|
||||
return {
|
||||
isValid: false,
|
||||
errorMessage: t('nodes.knowledgeBase.rerankingModelIsInvalid', { ns: 'workflow' }),
|
||||
}
|
||||
}
|
||||
}
|
||||
const issue = getKnowledgeBaseValidationIssue(payload)
|
||||
if (issue)
|
||||
return { isValid: false, errorMessage: getKnowledgeBaseValidationMessage(issue, t) }
|
||||
|
||||
return {
|
||||
isValid: true,
|
||||
|
||||
@ -0,0 +1,61 @@
|
||||
import type {
|
||||
Model,
|
||||
ModelItem,
|
||||
ModelProvider,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useMemo } from 'react'
|
||||
import { deriveModelStatus } from '@/app/components/header/account-setting/model-provider-page/derive-model-status'
|
||||
import { useCredentialPanelState } from '@/app/components/header/account-setting/model-provider-page/provider-added-card/use-credential-panel-state'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
|
||||
type UseEmbeddingModelStatusProps = {
|
||||
embeddingModel?: string
|
||||
embeddingModelProvider?: string
|
||||
embeddingModelList: Model[]
|
||||
}
|
||||
|
||||
type UseEmbeddingModelStatusResult = {
|
||||
providerMeta: ModelProvider | undefined
|
||||
modelProvider: Model | undefined
|
||||
currentModel: ModelItem | undefined
|
||||
status: ReturnType<typeof deriveModelStatus>
|
||||
}
|
||||
|
||||
export const useEmbeddingModelStatus = ({
|
||||
embeddingModel,
|
||||
embeddingModelProvider,
|
||||
embeddingModelList,
|
||||
}: UseEmbeddingModelStatusProps): UseEmbeddingModelStatusResult => {
|
||||
const { modelProviders } = useProviderContext()
|
||||
|
||||
const providerMeta = useMemo(() => {
|
||||
return modelProviders.find(provider => provider.provider === embeddingModelProvider)
|
||||
}, [embeddingModelProvider, modelProviders])
|
||||
|
||||
const modelProvider = useMemo(() => {
|
||||
return embeddingModelList.find(provider => provider.provider === embeddingModelProvider)
|
||||
}, [embeddingModelList, embeddingModelProvider])
|
||||
|
||||
const currentModel = useMemo(() => {
|
||||
return modelProvider?.models.find(model => model.model === embeddingModel)
|
||||
}, [embeddingModel, modelProvider])
|
||||
|
||||
const credentialState = useCredentialPanelState(providerMeta)
|
||||
|
||||
const status = useMemo(() => {
|
||||
return deriveModelStatus(
|
||||
embeddingModel,
|
||||
embeddingModelProvider,
|
||||
providerMeta,
|
||||
currentModel,
|
||||
credentialState,
|
||||
)
|
||||
}, [credentialState, currentModel, embeddingModel, embeddingModelProvider, providerMeta])
|
||||
|
||||
return {
|
||||
providerMeta,
|
||||
modelProvider,
|
||||
currentModel,
|
||||
status,
|
||||
}
|
||||
}
|
||||
233
web/app/components/workflow/nodes/knowledge-base/node.spec.tsx
Normal file
233
web/app/components/workflow/nodes/knowledge-base/node.spec.tsx
Normal file
@ -0,0 +1,233 @@
|
||||
import type { KnowledgeBaseNodeType } from './types'
|
||||
import type { ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { CommonNodeType } from '@/app/components/workflow/types'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
ModelStatusEnum,
|
||||
ModelTypeEnum,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import Node from './node'
|
||||
import {
|
||||
ChunkStructureEnum,
|
||||
IndexMethodEnum,
|
||||
RetrievalSearchMethodEnum,
|
||||
} from './types'
|
||||
|
||||
const mockUseModelList = vi.hoisted(() => vi.fn())
|
||||
const mockUseSettingsDisplay = vi.hoisted(() => vi.fn())
|
||||
const mockUseEmbeddingModelStatus = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@tanstack/react-query', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@tanstack/react-query')>()
|
||||
return {
|
||||
...actual,
|
||||
useQuery: () => ({ data: undefined }),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/app/components/header/account-setting/model-provider-page/hooks')>()
|
||||
return {
|
||||
...actual,
|
||||
useLanguage: () => 'en_US',
|
||||
useModelList: mockUseModelList,
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('./hooks/use-settings-display', () => ({
|
||||
useSettingsDisplay: mockUseSettingsDisplay,
|
||||
}))
|
||||
|
||||
vi.mock('./hooks/use-embedding-model-status', () => ({
|
||||
useEmbeddingModelStatus: mockUseEmbeddingModelStatus,
|
||||
}))
|
||||
|
||||
const createModelItem = (overrides: Partial<ModelItem> = {}): ModelItem => ({
|
||||
model: 'text-embedding-3-large',
|
||||
label: { en_US: 'Text Embedding 3 Large', zh_Hans: 'Text Embedding 3 Large' },
|
||||
model_type: ModelTypeEnum.textEmbedding,
|
||||
fetch_from: ConfigurationMethodEnum.predefinedModel,
|
||||
status: ModelStatusEnum.active,
|
||||
model_properties: {},
|
||||
load_balancing_enabled: false,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createNodeData = (overrides: Partial<CommonNodeType<KnowledgeBaseNodeType>> = {}): CommonNodeType<KnowledgeBaseNodeType> => ({
|
||||
title: 'Knowledge Base',
|
||||
desc: '',
|
||||
type: BlockEnum.KnowledgeBase,
|
||||
index_chunk_variable_selector: ['result'],
|
||||
chunk_structure: ChunkStructureEnum.general,
|
||||
indexing_technique: IndexMethodEnum.QUALIFIED,
|
||||
embedding_model: 'text-embedding-3-large',
|
||||
embedding_model_provider: 'openai',
|
||||
keyword_number: 10,
|
||||
retrieval_model: {
|
||||
top_k: 3,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0.5,
|
||||
search_method: RetrievalSearchMethodEnum.semantic,
|
||||
},
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('KnowledgeBaseNode', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseModelList.mockReturnValue({ data: [] })
|
||||
mockUseSettingsDisplay.mockReturnValue({
|
||||
[IndexMethodEnum.QUALIFIED]: 'High Quality',
|
||||
[RetrievalSearchMethodEnum.semantic]: 'Vector Search',
|
||||
})
|
||||
mockUseEmbeddingModelStatus.mockReturnValue({
|
||||
providerMeta: undefined,
|
||||
modelProvider: undefined,
|
||||
currentModel: createModelItem(),
|
||||
status: 'active',
|
||||
})
|
||||
})
|
||||
|
||||
// Embedding model row should mirror the selector status labels.
|
||||
describe('Embedding Model Status', () => {
|
||||
it('should render active embedding model label when the model is available', () => {
|
||||
render(<Node id="knowledge-base-1" data={createNodeData()} />)
|
||||
|
||||
expect(screen.getByText('Text Embedding 3 Large')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render configure required when embedding model status requires configuration', () => {
|
||||
mockUseEmbeddingModelStatus.mockReturnValue({
|
||||
providerMeta: undefined,
|
||||
modelProvider: undefined,
|
||||
currentModel: createModelItem({ status: ModelStatusEnum.noConfigure }),
|
||||
status: 'configure-required',
|
||||
})
|
||||
|
||||
render(<Node id="knowledge-base-1" data={createNodeData()} />)
|
||||
|
||||
expect(screen.getByText('common.modelProvider.selector.configureRequired')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render disabled when embedding model status is disabled', () => {
|
||||
mockUseEmbeddingModelStatus.mockReturnValue({
|
||||
providerMeta: undefined,
|
||||
modelProvider: undefined,
|
||||
currentModel: createModelItem({ status: ModelStatusEnum.disabled }),
|
||||
status: 'disabled',
|
||||
})
|
||||
|
||||
render(<Node id="knowledge-base-1" data={createNodeData()} />)
|
||||
|
||||
expect(screen.getByText('common.modelProvider.selector.disabled')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render incompatible when embedding model status is incompatible', () => {
|
||||
mockUseEmbeddingModelStatus.mockReturnValue({
|
||||
providerMeta: undefined,
|
||||
modelProvider: undefined,
|
||||
currentModel: undefined,
|
||||
status: 'incompatible',
|
||||
})
|
||||
|
||||
render(<Node id="knowledge-base-1" data={createNodeData()} />)
|
||||
|
||||
expect(screen.getByText('common.modelProvider.selector.incompatible')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render configure model prompt when no embedding model is selected', () => {
|
||||
mockUseEmbeddingModelStatus.mockReturnValue({
|
||||
providerMeta: undefined,
|
||||
modelProvider: undefined,
|
||||
currentModel: undefined,
|
||||
status: 'empty',
|
||||
})
|
||||
|
||||
render(
|
||||
<Node
|
||||
id="knowledge-base-1"
|
||||
data={createNodeData({
|
||||
embedding_model: undefined,
|
||||
embedding_model_provider: undefined,
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('plugin.detailPanel.configureModel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Validation warnings', () => {
|
||||
it('should render a warning banner when chunk structure is missing', () => {
|
||||
render(
|
||||
<Node
|
||||
id="knowledge-base-1"
|
||||
data={createNodeData({
|
||||
chunk_structure: undefined,
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText(/chunkIsRequired/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render a warning value for the chunks input row when no chunk variable is selected', () => {
|
||||
render(
|
||||
<Node
|
||||
id="knowledge-base-1"
|
||||
data={createNodeData({
|
||||
index_chunk_variable_selector: [],
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText(/chunksVariableIsRequired/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render a warning value for retrieval settings when reranking is incomplete', () => {
|
||||
mockUseModelList.mockImplementation((modelType: ModelTypeEnum) => {
|
||||
if (modelType === ModelTypeEnum.textEmbedding) {
|
||||
return {
|
||||
data: [{
|
||||
provider: 'openai',
|
||||
models: [createModelItem()],
|
||||
}],
|
||||
}
|
||||
}
|
||||
return { data: [] }
|
||||
})
|
||||
|
||||
render(
|
||||
<Node
|
||||
id="knowledge-base-1"
|
||||
data={createNodeData({
|
||||
retrieval_model: {
|
||||
top_k: 3,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0.5,
|
||||
search_method: RetrievalSearchMethodEnum.semantic,
|
||||
reranking_enable: true,
|
||||
},
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText(/rerankingModelIsRequired/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide the embedding model row when the index method is not qualified', () => {
|
||||
render(
|
||||
<Node
|
||||
id="knowledge-base-1"
|
||||
data={createNodeData({
|
||||
indexing_technique: IndexMethodEnum.ECONOMICAL,
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('Text Embedding 3 Large')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -1,38 +1,210 @@
|
||||
import type { FC } from 'react'
|
||||
import type { KnowledgeBaseNodeType } from './types'
|
||||
import type { NodeProps } from '@/app/components/workflow/types'
|
||||
import { memo } from 'react'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import {
|
||||
memo,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
ModelTypeEnum,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { DERIVED_MODEL_STATUS_BADGE_I18N } from '@/app/components/header/account-setting/model-provider-page/derive-model-status'
|
||||
import {
|
||||
useLanguage,
|
||||
useModelList,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useEmbeddingModelStatus } from './hooks/use-embedding-model-status'
|
||||
import { useSettingsDisplay } from './hooks/use-settings-display'
|
||||
import {
|
||||
IndexMethodEnum,
|
||||
} from './types'
|
||||
import {
|
||||
getKnowledgeBaseValidationIssue,
|
||||
getKnowledgeBaseValidationMessage,
|
||||
KnowledgeBaseValidationIssueCode,
|
||||
} from './utils'
|
||||
|
||||
type SettingRowProps = {
|
||||
label: string
|
||||
value: string
|
||||
warning?: boolean
|
||||
}
|
||||
|
||||
const SettingRow = memo(({
|
||||
label,
|
||||
value,
|
||||
warning = false,
|
||||
}: SettingRowProps) => {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex h-6 items-center rounded-md px-1.5',
|
||||
warning
|
||||
? 'border-[0.5px] border-state-warning-active bg-state-warning-hover'
|
||||
: 'bg-workflow-block-parma-bg',
|
||||
)}
|
||||
>
|
||||
<div className="mr-2 shrink-0 text-text-tertiary system-xs-medium-uppercase">
|
||||
{label}
|
||||
</div>
|
||||
<div
|
||||
className={cn('grow truncate text-right system-xs-medium', warning ? 'text-text-warning' : 'text-text-secondary')}
|
||||
title={value}
|
||||
>
|
||||
{value}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
const RETRIEVAL_WARNING_CODES = new Set<KnowledgeBaseValidationIssueCode>([
|
||||
KnowledgeBaseValidationIssueCode.retrievalSettingRequired,
|
||||
KnowledgeBaseValidationIssueCode.rerankingModelRequired,
|
||||
KnowledgeBaseValidationIssueCode.rerankingModelInvalid,
|
||||
])
|
||||
|
||||
const Node: FC<NodeProps<KnowledgeBaseNodeType>> = ({ data }) => {
|
||||
const { t } = useTranslation()
|
||||
const language = useLanguage()
|
||||
const settingsDisplay = useSettingsDisplay()
|
||||
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
|
||||
const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank)
|
||||
const chunkStructure = data.chunk_structure
|
||||
const indexChunkVariableSelector = data.index_chunk_variable_selector
|
||||
const indexingTechnique = data.indexing_technique
|
||||
const embeddingModel = data.embedding_model
|
||||
const retrievalModel = data.retrieval_model
|
||||
const retrievalSearchMethod = retrievalModel?.search_method
|
||||
const retrievalRerankingEnable = retrievalModel?.reranking_enable
|
||||
const embeddingModelProvider = data.embedding_model_provider
|
||||
const { data: embeddingProviderModelList } = useQuery(
|
||||
consoleQuery.modelProviders.models.queryOptions({
|
||||
input: { params: { provider: embeddingModelProvider || '' } },
|
||||
enabled: indexingTechnique === IndexMethodEnum.QUALIFIED && !!embeddingModelProvider,
|
||||
refetchOnWindowFocus: false,
|
||||
select: response => response.data,
|
||||
}),
|
||||
)
|
||||
|
||||
const validationPayload = useMemo(() => {
|
||||
return {
|
||||
chunk_structure: chunkStructure,
|
||||
index_chunk_variable_selector: indexChunkVariableSelector,
|
||||
indexing_technique: indexingTechnique,
|
||||
embedding_model: embeddingModel,
|
||||
embedding_model_provider: embeddingModelProvider,
|
||||
retrieval_model: {
|
||||
search_method: retrievalSearchMethod,
|
||||
reranking_enable: retrievalRerankingEnable,
|
||||
reranking_model: retrievalModel?.reranking_model,
|
||||
},
|
||||
_embeddingModelList: embeddingModelList,
|
||||
_embeddingProviderModelList: embeddingProviderModelList,
|
||||
_rerankModelList: rerankModelList,
|
||||
}
|
||||
}, [
|
||||
chunkStructure,
|
||||
indexChunkVariableSelector,
|
||||
indexingTechnique,
|
||||
embeddingModel,
|
||||
embeddingModelProvider,
|
||||
retrievalSearchMethod,
|
||||
retrievalRerankingEnable,
|
||||
retrievalModel?.reranking_model,
|
||||
embeddingModelList,
|
||||
embeddingProviderModelList,
|
||||
rerankModelList,
|
||||
])
|
||||
|
||||
const validationIssue = useMemo(() => {
|
||||
return getKnowledgeBaseValidationIssue({
|
||||
...validationPayload,
|
||||
})
|
||||
}, [validationPayload])
|
||||
|
||||
const validationIssueMessage = useMemo(() => {
|
||||
return getKnowledgeBaseValidationMessage(validationIssue, t)
|
||||
}, [validationIssue, t])
|
||||
const { currentModel: currentEmbeddingModel, status: embeddingModelStatus } = useEmbeddingModelStatus({
|
||||
embeddingModel: data.embedding_model,
|
||||
embeddingModelProvider: data.embedding_model_provider,
|
||||
embeddingModelList,
|
||||
})
|
||||
|
||||
const chunksDisplayValue = useMemo(() => {
|
||||
if (!data.index_chunk_variable_selector?.length)
|
||||
return '-'
|
||||
|
||||
const chunkVar = data.index_chunk_variable_selector.at(-1)
|
||||
return chunkVar || '-'
|
||||
}, [data.index_chunk_variable_selector])
|
||||
|
||||
const embeddingModelDisplay = useMemo(() => {
|
||||
if (data.indexing_technique !== IndexMethodEnum.QUALIFIED)
|
||||
return '-'
|
||||
|
||||
if (embeddingModelStatus === 'empty')
|
||||
return t('detailPanel.configureModel', { ns: 'plugin' })
|
||||
|
||||
if (embeddingModelStatus !== 'active') {
|
||||
const statusI18nKey = DERIVED_MODEL_STATUS_BADGE_I18N[embeddingModelStatus as keyof typeof DERIVED_MODEL_STATUS_BADGE_I18N]
|
||||
if (statusI18nKey)
|
||||
return t(statusI18nKey as 'modelProvider.selector.incompatible', { ns: 'common' })
|
||||
}
|
||||
|
||||
return currentEmbeddingModel?.label[language] || currentEmbeddingModel?.label.en_US || data.embedding_model || '-'
|
||||
}, [currentEmbeddingModel, data.embedding_model, data.indexing_technique, embeddingModelStatus, language, t])
|
||||
|
||||
const indexMethodDisplay = settingsDisplay[data.indexing_technique as keyof typeof settingsDisplay] || '-'
|
||||
const retrievalMethodDisplay = settingsDisplay[data.retrieval_model?.search_method as keyof typeof settingsDisplay] || '-'
|
||||
|
||||
const chunksWarning = validationIssue?.code === KnowledgeBaseValidationIssueCode.chunksVariableRequired
|
||||
const indexMethodWarning = validationIssue?.code === KnowledgeBaseValidationIssueCode.indexMethodRequired
|
||||
const embeddingWarning = data.indexing_technique === IndexMethodEnum.QUALIFIED && embeddingModelStatus !== 'active'
|
||||
const showEmbeddingModelRow = data.indexing_technique === IndexMethodEnum.QUALIFIED
|
||||
const retrievalWarning = !!(validationIssue && RETRIEVAL_WARNING_CODES.has(validationIssue.code))
|
||||
|
||||
if (!data.chunk_structure) {
|
||||
return (
|
||||
<div className="mb-1 space-y-0.5 px-3 py-1">
|
||||
<div className="flex h-6 items-center rounded-md border-[0.5px] border-state-warning-active bg-state-warning-hover px-1.5">
|
||||
<span className="mr-1 size-[4px] shrink-0 rounded-[2px] bg-text-warning-secondary" />
|
||||
<div className="grow truncate text-text-warning system-xs-medium" title={validationIssueMessage}>
|
||||
{validationIssueMessage}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mb-1 space-y-0.5 px-3 py-1">
|
||||
<div className="flex h-6 items-center rounded-md bg-workflow-block-parma-bg px-1.5">
|
||||
<div className="system-xs-medium-uppercase mr-2 shrink-0 text-text-tertiary">
|
||||
{t('stepTwo.indexMode', { ns: 'datasetCreation' })}
|
||||
</div>
|
||||
<div
|
||||
className="system-xs-medium grow truncate text-right text-text-secondary"
|
||||
title={data.indexing_technique}
|
||||
>
|
||||
{settingsDisplay[data.indexing_technique as keyof typeof settingsDisplay]}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex h-6 items-center rounded-md bg-workflow-block-parma-bg px-1.5">
|
||||
<div className="system-xs-medium-uppercase mr-2 shrink-0 text-text-tertiary">
|
||||
{t('form.retrievalSetting.title', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
<div
|
||||
className="system-xs-medium grow truncate text-right text-text-secondary"
|
||||
title={data.retrieval_model?.search_method}
|
||||
>
|
||||
{settingsDisplay[data.retrieval_model?.search_method as keyof typeof settingsDisplay]}
|
||||
</div>
|
||||
</div>
|
||||
<SettingRow
|
||||
label={t('nodes.knowledgeBase.chunksInput', { ns: 'workflow' })}
|
||||
value={chunksWarning ? validationIssueMessage : chunksDisplayValue}
|
||||
warning={chunksWarning}
|
||||
/>
|
||||
<SettingRow
|
||||
label={t('stepTwo.indexMode', { ns: 'datasetCreation' })}
|
||||
value={indexMethodWarning ? validationIssueMessage : indexMethodDisplay}
|
||||
warning={indexMethodWarning}
|
||||
/>
|
||||
{showEmbeddingModelRow && (
|
||||
<SettingRow
|
||||
label={t('form.embeddingModel', { ns: 'datasetSettings' })}
|
||||
value={embeddingModelDisplay}
|
||||
warning={embeddingWarning}
|
||||
/>
|
||||
)}
|
||||
<SettingRow
|
||||
label={t('form.retrievalSetting.method', { ns: 'datasetSettings' })}
|
||||
value={retrievalWarning ? validationIssueMessage : retrievalMethodDisplay}
|
||||
warning={retrievalWarning}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
198
web/app/components/workflow/nodes/knowledge-base/panel.spec.tsx
Normal file
198
web/app/components/workflow/nodes/knowledge-base/panel.spec.tsx
Normal file
@ -0,0 +1,198 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import type { PanelProps } from '@/types/workflow'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import Panel from './panel'
|
||||
import { ChunkStructureEnum, IndexMethodEnum, RetrievalSearchMethodEnum } from './types'
|
||||
|
||||
const mockUseModelList = vi.hoisted(() => vi.fn())
|
||||
const mockUseQuery = vi.hoisted(() => vi.fn())
|
||||
const mockUseEmbeddingModelStatus = vi.hoisted(() => vi.fn())
|
||||
const mockChunkStructure = vi.hoisted(() => vi.fn(() => <div data-testid="chunk-structure" />))
|
||||
const mockEmbeddingModel = vi.hoisted(() => vi.fn(() => <div data-testid="embedding-model" />))
|
||||
const mockSummaryIndexSetting = vi.hoisted(() => vi.fn(() => <div data-testid="summary-index-setting" />))
|
||||
const mockQueryOptions = vi.hoisted(() => vi.fn((options: unknown) => options))
|
||||
|
||||
vi.mock('@tanstack/react-query', () => ({
|
||||
useQuery: mockUseQuery,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/client', () => ({
|
||||
consoleQuery: {
|
||||
modelProviders: {
|
||||
models: {
|
||||
queryOptions: mockQueryOptions,
|
||||
},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
|
||||
useModelList: mockUseModelList,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/hooks', () => ({
|
||||
useNodesReadOnly: () => ({ nodesReadOnly: false }),
|
||||
}))
|
||||
|
||||
vi.mock('./hooks/use-config', () => ({
|
||||
useConfig: () => ({
|
||||
handleChunkStructureChange: vi.fn(),
|
||||
handleIndexMethodChange: vi.fn(),
|
||||
handleKeywordNumberChange: vi.fn(),
|
||||
handleEmbeddingModelChange: vi.fn(),
|
||||
handleRetrievalSearchMethodChange: vi.fn(),
|
||||
handleHybridSearchModeChange: vi.fn(),
|
||||
handleRerankingModelEnabledChange: vi.fn(),
|
||||
handleWeighedScoreChange: vi.fn(),
|
||||
handleRerankingModelChange: vi.fn(),
|
||||
handleTopKChange: vi.fn(),
|
||||
handleScoreThresholdChange: vi.fn(),
|
||||
handleScoreThresholdEnabledChange: vi.fn(),
|
||||
handleInputVariableChange: vi.fn(),
|
||||
handleSummaryIndexSettingChange: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('./hooks/use-embedding-model-status', () => ({
|
||||
useEmbeddingModelStatus: mockUseEmbeddingModelStatus,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/datasets/settings/utils', () => ({
|
||||
checkShowMultiModalTip: () => false,
|
||||
}))
|
||||
|
||||
vi.mock('@/config', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/config')>()
|
||||
return {
|
||||
...actual,
|
||||
IS_CE_EDITION: true,
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/layout', () => ({
|
||||
Group: ({ children }: { children: ReactNode }) => <div data-testid="group">{children}</div>,
|
||||
BoxGroup: ({ children }: { children: ReactNode }) => <div data-testid="box-group">{children}</div>,
|
||||
BoxGroupField: ({ children, fieldProps }: { children: ReactNode, fieldProps: { fieldTitleProps: { warningDot?: boolean } } }) => (
|
||||
<div data-testid="box-group-field" data-warning-dot={String(!!fieldProps.fieldTitleProps.warningDot)}>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => ({
|
||||
default: () => <div data-testid="var-reference-picker" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/split', () => ({
|
||||
default: () => <div data-testid="split" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/datasets/settings/summary-index-setting', () => ({
|
||||
default: mockSummaryIndexSetting,
|
||||
}))
|
||||
|
||||
vi.mock('./components/chunk-structure', () => ({
|
||||
default: mockChunkStructure,
|
||||
}))
|
||||
|
||||
vi.mock('./components/index-method', () => ({
|
||||
default: () => <div data-testid="index-method" />,
|
||||
}))
|
||||
|
||||
vi.mock('./components/embedding-model', () => ({
|
||||
default: mockEmbeddingModel,
|
||||
}))
|
||||
|
||||
vi.mock('./components/retrieval-setting', () => ({
|
||||
default: () => <div data-testid="retrieval-setting" />,
|
||||
}))
|
||||
|
||||
const createData = (overrides: Record<string, unknown> = {}) => ({
|
||||
index_chunk_variable_selector: ['chunks', 'results'],
|
||||
chunk_structure: ChunkStructureEnum.general,
|
||||
indexing_technique: IndexMethodEnum.QUALIFIED,
|
||||
embedding_model: 'text-embedding-3-large',
|
||||
embedding_model_provider: 'openai',
|
||||
keyword_number: 10,
|
||||
retrieval_model: {
|
||||
search_method: RetrievalSearchMethodEnum.semantic,
|
||||
reranking_enable: false,
|
||||
top_k: 3,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0.5,
|
||||
},
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const panelProps: PanelProps = {
|
||||
getInputVars: () => [],
|
||||
toVarInputs: () => [],
|
||||
runInputData: {},
|
||||
runInputDataRef: { current: {} },
|
||||
setRunInputData: vi.fn(),
|
||||
runResult: undefined,
|
||||
}
|
||||
|
||||
describe('KnowledgeBasePanel', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseQuery.mockReturnValue({ data: undefined })
|
||||
mockUseModelList.mockImplementation((modelType: ModelTypeEnum) => {
|
||||
if (modelType === ModelTypeEnum.textEmbedding) {
|
||||
return {
|
||||
data: [{
|
||||
provider: 'openai',
|
||||
models: [{ model: 'text-embedding-3-large' }],
|
||||
}],
|
||||
}
|
||||
}
|
||||
return { data: [] }
|
||||
})
|
||||
mockUseEmbeddingModelStatus.mockReturnValue({ status: 'active' })
|
||||
})
|
||||
|
||||
it('should show a warning dot on chunk structure and skip nested sections when chunk structure is missing', () => {
|
||||
render(<Panel id="knowledge-base-1" data={createData({ chunk_structure: undefined }) as never} panelProps={panelProps} />)
|
||||
|
||||
expect(mockChunkStructure).toHaveBeenCalledWith(expect.objectContaining({
|
||||
warningDot: true,
|
||||
}), undefined)
|
||||
expect(screen.queryByTestId('box-group-field')).not.toBeInTheDocument()
|
||||
expect(mockQueryOptions).toHaveBeenCalledWith(expect.objectContaining({
|
||||
enabled: true,
|
||||
}))
|
||||
})
|
||||
|
||||
it('should pass warning dots and render summary settings when the qualified configuration needs attention', () => {
|
||||
mockUseEmbeddingModelStatus.mockReturnValue({ status: 'disabled' })
|
||||
|
||||
render(<Panel id="knowledge-base-1" data={createData({ index_chunk_variable_selector: [] }) as never} panelProps={panelProps} />)
|
||||
|
||||
expect(screen.getByTestId('box-group-field')).toHaveAttribute('data-warning-dot', 'true')
|
||||
expect(mockEmbeddingModel).toHaveBeenCalledWith(expect.objectContaining({
|
||||
warningDot: true,
|
||||
}), undefined)
|
||||
expect(mockQueryOptions).toHaveBeenCalledWith(expect.objectContaining({
|
||||
input: { params: { provider: 'openai' } },
|
||||
enabled: true,
|
||||
}))
|
||||
expect(screen.getByTestId('summary-index-setting')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide embedding and summary settings for non-qualified index methods', () => {
|
||||
render(
|
||||
<Panel
|
||||
id="knowledge-base-1"
|
||||
data={createData({ indexing_technique: IndexMethodEnum.ECONOMICAL }) as never}
|
||||
panelProps={panelProps}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('embedding-model')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('summary-index-setting')).not.toBeInTheDocument()
|
||||
expect(mockQueryOptions).toHaveBeenCalledWith(expect.objectContaining({
|
||||
enabled: false,
|
||||
}))
|
||||
})
|
||||
})
|
||||
@ -1,6 +1,7 @@
|
||||
import type { FC } from 'react'
|
||||
import type { KnowledgeBaseNodeType } from './types'
|
||||
import type { NodePanelProps, Var } from '@/app/components/workflow/types'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
@ -19,16 +20,22 @@ import {
|
||||
} from '@/app/components/workflow/nodes/_base/components/layout'
|
||||
import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import Split from '../_base/components/split'
|
||||
import ChunkStructure from './components/chunk-structure'
|
||||
import EmbeddingModel from './components/embedding-model'
|
||||
import IndexMethod from './components/index-method'
|
||||
import RetrievalSetting from './components/retrieval-setting'
|
||||
import { useConfig } from './hooks/use-config'
|
||||
import { useEmbeddingModelStatus } from './hooks/use-embedding-model-status'
|
||||
import {
|
||||
ChunkStructureEnum,
|
||||
IndexMethodEnum,
|
||||
} from './types'
|
||||
import {
|
||||
getKnowledgeBaseValidationIssue,
|
||||
KnowledgeBaseValidationIssueCode,
|
||||
} from './utils'
|
||||
|
||||
const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
id,
|
||||
@ -38,6 +45,22 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
const { nodesReadOnly } = useNodesReadOnly()
|
||||
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
|
||||
const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank)
|
||||
const chunkStructure = data.chunk_structure
|
||||
const indexChunkVariableSelector = data.index_chunk_variable_selector
|
||||
const indexingTechnique = data.indexing_technique
|
||||
const embeddingModel = data.embedding_model
|
||||
const retrievalModel = data.retrieval_model
|
||||
const retrievalSearchMethod = retrievalModel?.search_method
|
||||
const retrievalRerankingEnable = retrievalModel?.reranking_enable
|
||||
const embeddingModelProvider = data.embedding_model_provider
|
||||
const { data: embeddingProviderModelList } = useQuery(
|
||||
consoleQuery.modelProviders.models.queryOptions({
|
||||
input: { params: { provider: embeddingModelProvider || '' } },
|
||||
enabled: indexingTechnique === IndexMethodEnum.QUALIFIED && !!embeddingModelProvider,
|
||||
refetchOnWindowFocus: false,
|
||||
select: response => response.data,
|
||||
}),
|
||||
)
|
||||
|
||||
const {
|
||||
handleChunkStructureChange,
|
||||
@ -108,6 +131,49 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
})
|
||||
}, [data.embedding_model_provider, data.embedding_model, data.retrieval_model?.reranking_enable, data.retrieval_model?.reranking_model, data.indexing_technique, embeddingModelList, rerankModelList])
|
||||
|
||||
const validationPayload = useMemo(() => {
|
||||
return {
|
||||
chunk_structure: chunkStructure,
|
||||
index_chunk_variable_selector: indexChunkVariableSelector,
|
||||
indexing_technique: indexingTechnique,
|
||||
embedding_model: embeddingModel,
|
||||
embedding_model_provider: embeddingModelProvider,
|
||||
retrieval_model: {
|
||||
search_method: retrievalSearchMethod,
|
||||
reranking_enable: retrievalRerankingEnable,
|
||||
reranking_model: retrievalModel?.reranking_model,
|
||||
},
|
||||
_embeddingModelList: embeddingModelList,
|
||||
_embeddingProviderModelList: embeddingProviderModelList,
|
||||
_rerankModelList: rerankModelList,
|
||||
}
|
||||
}, [
|
||||
chunkStructure,
|
||||
indexChunkVariableSelector,
|
||||
indexingTechnique,
|
||||
embeddingModel,
|
||||
embeddingModelProvider,
|
||||
retrievalSearchMethod,
|
||||
retrievalRerankingEnable,
|
||||
retrievalModel?.reranking_model,
|
||||
embeddingModelList,
|
||||
embeddingProviderModelList,
|
||||
rerankModelList,
|
||||
])
|
||||
|
||||
const validationIssue = useMemo(() => {
|
||||
return getKnowledgeBaseValidationIssue(validationPayload)
|
||||
}, [validationPayload])
|
||||
const { status: embeddingModelStatus } = useEmbeddingModelStatus({
|
||||
embeddingModel,
|
||||
embeddingModelProvider,
|
||||
embeddingModelList,
|
||||
})
|
||||
|
||||
const chunkStructureWarning = validationIssue?.code === KnowledgeBaseValidationIssueCode.chunkStructureRequired
|
||||
const chunksInputWarning = validationIssue?.code === KnowledgeBaseValidationIssueCode.chunksVariableRequired
|
||||
const embeddingModelWarning = indexingTechnique === IndexMethodEnum.QUALIFIED && embeddingModelStatus !== 'active'
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Group
|
||||
@ -117,6 +183,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
<ChunkStructure
|
||||
chunkStructure={data.chunk_structure}
|
||||
onChunkStructureChange={handleChunkStructureChange}
|
||||
warningDot={chunkStructureWarning}
|
||||
readonly={nodesReadOnly}
|
||||
/>
|
||||
</Group>
|
||||
@ -131,6 +198,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
fieldTitleProps: {
|
||||
title: t('nodes.knowledgeBase.chunksInput', { ns: 'workflow' }),
|
||||
tooltip: t('nodes.knowledgeBase.chunksInputTip', { ns: 'workflow' }),
|
||||
warningDot: chunksInputWarning,
|
||||
},
|
||||
}}
|
||||
>
|
||||
@ -163,6 +231,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
|
||||
embeddingModel={data.embedding_model}
|
||||
embeddingModelProvider={data.embedding_model_provider}
|
||||
onEmbeddingModelChange={handleEmbeddingModelChange}
|
||||
warningDot={embeddingModelWarning}
|
||||
readonly={nodesReadOnly}
|
||||
/>
|
||||
)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import type { IndexingType } from '@/app/components/datasets/create/step-two'
|
||||
import type { Model } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { Model, ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { CommonNodeType } from '@/app/components/workflow/types'
|
||||
import type { RerankingModeEnum, WeightedScoreEnum } from '@/models/datasets'
|
||||
import type { RETRIEVE_METHOD } from '@/types/app'
|
||||
@ -57,6 +57,7 @@ export type KnowledgeBaseNodeType = CommonNodeType & {
|
||||
keyword_number: number
|
||||
retrieval_model: RetrievalSetting
|
||||
_embeddingModelList?: Model[]
|
||||
_embeddingProviderModelList?: ModelItem[]
|
||||
_rerankModelList?: Model[]
|
||||
summary_index_setting?: SummaryIndexSetting
|
||||
}
|
||||
|
||||
226
web/app/components/workflow/nodes/knowledge-base/utils.spec.ts
Normal file
226
web/app/components/workflow/nodes/knowledge-base/utils.spec.ts
Normal file
@ -0,0 +1,226 @@
|
||||
import type { KnowledgeBaseNodeType } from './types'
|
||||
import type { Model, ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
ModelStatusEnum,
|
||||
ModelTypeEnum,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import {
|
||||
ChunkStructureEnum,
|
||||
IndexMethodEnum,
|
||||
RetrievalSearchMethodEnum,
|
||||
} from './types'
|
||||
import {
|
||||
getKnowledgeBaseValidationIssue,
|
||||
getKnowledgeBaseValidationMessage,
|
||||
isHighQualitySearchMethod,
|
||||
isKnowledgeBaseEmbeddingIssue,
|
||||
KnowledgeBaseValidationIssueCode,
|
||||
} from './utils'
|
||||
|
||||
const makeEmbeddingModelList = (status: ModelStatusEnum): Model[] => {
|
||||
return [
|
||||
{
|
||||
provider: 'openai',
|
||||
icon_small: { en_US: '', zh_Hans: '' },
|
||||
label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' },
|
||||
models: [{
|
||||
model: 'gpt-4o',
|
||||
label: { en_US: 'GPT-4o', zh_Hans: 'GPT-4o' },
|
||||
model_type: ModelTypeEnum.textEmbedding,
|
||||
fetch_from: ConfigurationMethodEnum.predefinedModel,
|
||||
status,
|
||||
model_properties: {},
|
||||
load_balancing_enabled: false,
|
||||
}],
|
||||
status,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
const makeEmbeddingProviderModelList = (status: ModelStatusEnum): ModelItem[] => {
|
||||
return [{
|
||||
model: 'gpt-4o',
|
||||
label: { en_US: 'GPT-4o', zh_Hans: 'GPT-4o' },
|
||||
model_type: ModelTypeEnum.textEmbedding,
|
||||
fetch_from: ConfigurationMethodEnum.predefinedModel,
|
||||
status,
|
||||
model_properties: {},
|
||||
load_balancing_enabled: false,
|
||||
}]
|
||||
}
|
||||
|
||||
const makePayload = (overrides: Partial<KnowledgeBaseNodeType> = {}): KnowledgeBaseNodeType => {
|
||||
return {
|
||||
index_chunk_variable_selector: ['general_chunks', 'results'],
|
||||
chunk_structure: ChunkStructureEnum.general,
|
||||
indexing_technique: IndexMethodEnum.QUALIFIED,
|
||||
embedding_model: 'gpt-4o',
|
||||
embedding_model_provider: 'openai',
|
||||
keyword_number: 10,
|
||||
retrieval_model: {
|
||||
top_k: 3,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0.5,
|
||||
search_method: RetrievalSearchMethodEnum.semantic,
|
||||
},
|
||||
_embeddingModelList: makeEmbeddingModelList(ModelStatusEnum.active),
|
||||
_embeddingProviderModelList: makeEmbeddingProviderModelList(ModelStatusEnum.active),
|
||||
_rerankModelList: [],
|
||||
...overrides,
|
||||
} as KnowledgeBaseNodeType
|
||||
}
|
||||
|
||||
describe('knowledge-base validation issue', () => {
|
||||
it('identifies high quality retrieval methods', () => {
|
||||
expect(isHighQualitySearchMethod(RetrievalSearchMethodEnum.semantic)).toBe(true)
|
||||
expect(isHighQualitySearchMethod(RetrievalSearchMethodEnum.hybrid)).toBe(true)
|
||||
expect(isHighQualitySearchMethod(RetrievalSearchMethodEnum.fullText)).toBe(true)
|
||||
expect(isHighQualitySearchMethod('unknown-method' as RetrievalSearchMethodEnum)).toBe(false)
|
||||
})
|
||||
|
||||
it('returns chunk structure issue when chunk structure is missing', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(makePayload({ chunk_structure: undefined }))
|
||||
expect(issue?.code).toBe(KnowledgeBaseValidationIssueCode.chunkStructureRequired)
|
||||
})
|
||||
|
||||
it('returns chunks variable issue when chunks selector is empty', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(makePayload({ index_chunk_variable_selector: [] }))
|
||||
expect(issue?.code).toBe(KnowledgeBaseValidationIssueCode.chunksVariableRequired)
|
||||
})
|
||||
|
||||
it('maps no-configure to configure required', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(
|
||||
makePayload({ _embeddingProviderModelList: makeEmbeddingProviderModelList(ModelStatusEnum.noConfigure) }),
|
||||
)
|
||||
expect(issue?.code).toBe(KnowledgeBaseValidationIssueCode.embeddingModelConfigureRequired)
|
||||
})
|
||||
|
||||
it('maps credential-removed to API key unavailable', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(
|
||||
makePayload({ _embeddingProviderModelList: makeEmbeddingProviderModelList(ModelStatusEnum.credentialRemoved) }),
|
||||
)
|
||||
expect(issue?.code).toBe(KnowledgeBaseValidationIssueCode.embeddingModelApiKeyUnavailable)
|
||||
})
|
||||
|
||||
it('maps quota-exceeded to credits exhausted', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(
|
||||
makePayload({ _embeddingProviderModelList: makeEmbeddingProviderModelList(ModelStatusEnum.quotaExceeded) }),
|
||||
)
|
||||
expect(issue?.code).toBe(KnowledgeBaseValidationIssueCode.embeddingModelCreditsExhausted)
|
||||
})
|
||||
|
||||
it('maps disabled to disabled', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(
|
||||
makePayload({ _embeddingProviderModelList: makeEmbeddingProviderModelList(ModelStatusEnum.disabled) }),
|
||||
)
|
||||
expect(issue?.code).toBe(KnowledgeBaseValidationIssueCode.embeddingModelDisabled)
|
||||
})
|
||||
|
||||
it('maps missing provider plugin to incompatible when embedding model is already configured', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(
|
||||
makePayload({
|
||||
embedding_model_provider: 'missing-provider',
|
||||
_embeddingProviderModelList: undefined,
|
||||
}),
|
||||
)
|
||||
expect(issue?.code).toBe(KnowledgeBaseValidationIssueCode.embeddingModelIncompatible)
|
||||
})
|
||||
|
||||
it('falls back to provider model list when provider scoped model list is empty', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(
|
||||
makePayload({ _embeddingProviderModelList: [] }),
|
||||
)
|
||||
expect(issue).toBeNull()
|
||||
})
|
||||
|
||||
it('returns embedding-model-not-configured when the qualified index is missing provider details', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(
|
||||
makePayload({ embedding_model: undefined }),
|
||||
)
|
||||
|
||||
expect(issue?.code).toBe(KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured)
|
||||
})
|
||||
|
||||
it('maps no-permission embedding models to incompatible', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(
|
||||
makePayload({ _embeddingProviderModelList: makeEmbeddingProviderModelList(ModelStatusEnum.noPermission) }),
|
||||
)
|
||||
|
||||
expect(issue?.code).toBe(KnowledgeBaseValidationIssueCode.embeddingModelIncompatible)
|
||||
})
|
||||
|
||||
it('returns retrieval-setting-required when retrieval search method is missing', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(
|
||||
makePayload({ retrieval_model: undefined as never }),
|
||||
)
|
||||
|
||||
expect(issue?.code).toBe(KnowledgeBaseValidationIssueCode.retrievalSettingRequired)
|
||||
})
|
||||
|
||||
it('returns reranking-model-required when reranking is enabled without a model', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(
|
||||
makePayload({
|
||||
retrieval_model: {
|
||||
...makePayload().retrieval_model,
|
||||
reranking_enable: true,
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
expect(issue?.code).toBe(KnowledgeBaseValidationIssueCode.rerankingModelRequired)
|
||||
})
|
||||
|
||||
it('returns reranking-model-invalid when the configured reranking model is unavailable', () => {
|
||||
const issue = getKnowledgeBaseValidationIssue(
|
||||
makePayload({
|
||||
retrieval_model: {
|
||||
...makePayload().retrieval_model,
|
||||
reranking_enable: true,
|
||||
reranking_model: {
|
||||
reranking_provider_name: 'missing-provider',
|
||||
reranking_model_name: 'missing-model',
|
||||
},
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
expect(issue?.code).toBe(KnowledgeBaseValidationIssueCode.rerankingModelInvalid)
|
||||
})
|
||||
})
|
||||
|
||||
describe('knowledge-base validation messaging', () => {
|
||||
const t = (key: string) => key
|
||||
|
||||
it.each([
|
||||
[KnowledgeBaseValidationIssueCode.chunkStructureRequired, 'nodes.knowledgeBase.chunkIsRequired'],
|
||||
[KnowledgeBaseValidationIssueCode.chunksVariableRequired, 'nodes.knowledgeBase.chunksVariableIsRequired'],
|
||||
[KnowledgeBaseValidationIssueCode.indexMethodRequired, 'nodes.knowledgeBase.indexMethodIsRequired'],
|
||||
[KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured, 'nodes.knowledgeBase.embeddingModelNotConfigured'],
|
||||
[KnowledgeBaseValidationIssueCode.embeddingModelConfigureRequired, 'modelProvider.selector.configureRequired'],
|
||||
[KnowledgeBaseValidationIssueCode.embeddingModelApiKeyUnavailable, 'modelProvider.selector.apiKeyUnavailable'],
|
||||
[KnowledgeBaseValidationIssueCode.embeddingModelCreditsExhausted, 'modelProvider.selector.creditsExhausted'],
|
||||
[KnowledgeBaseValidationIssueCode.embeddingModelDisabled, 'modelProvider.selector.disabled'],
|
||||
[KnowledgeBaseValidationIssueCode.embeddingModelIncompatible, 'modelProvider.selector.incompatible'],
|
||||
[KnowledgeBaseValidationIssueCode.retrievalSettingRequired, 'nodes.knowledgeBase.retrievalSettingIsRequired'],
|
||||
[KnowledgeBaseValidationIssueCode.rerankingModelRequired, 'nodes.knowledgeBase.rerankingModelIsRequired'],
|
||||
[KnowledgeBaseValidationIssueCode.rerankingModelInvalid, 'nodes.knowledgeBase.rerankingModelIsInvalid'],
|
||||
] as const)('maps %s to the expected translation key', (code, expectedKey) => {
|
||||
expect(getKnowledgeBaseValidationMessage({ code }, t as never)).toBe(expectedKey)
|
||||
})
|
||||
|
||||
it('returns an empty string when there is no issue', () => {
|
||||
expect(getKnowledgeBaseValidationMessage(undefined, t as never)).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
describe('isKnowledgeBaseEmbeddingIssue', () => {
|
||||
it('returns true for embedding-related issues', () => {
|
||||
expect(isKnowledgeBaseEmbeddingIssue({ code: KnowledgeBaseValidationIssueCode.embeddingModelDisabled })).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false for non-embedding issues and missing values', () => {
|
||||
expect(isKnowledgeBaseEmbeddingIssue({ code: KnowledgeBaseValidationIssueCode.rerankingModelInvalid })).toBe(false)
|
||||
expect(isKnowledgeBaseEmbeddingIssue(undefined)).toBe(false)
|
||||
})
|
||||
})
|
||||
@ -1,3 +1,11 @@
|
||||
import type { TFunction } from 'i18next'
|
||||
import type { KnowledgeBaseNodeType } from './types'
|
||||
import {
|
||||
IndexingType,
|
||||
} from '@/app/components/datasets/create/step-two'
|
||||
import {
|
||||
ModelStatusEnum,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import {
|
||||
RetrievalSearchMethodEnum,
|
||||
} from './types'
|
||||
@ -7,3 +15,174 @@ export const isHighQualitySearchMethod = (searchMethod: RetrievalSearchMethodEnu
|
||||
|| searchMethod === RetrievalSearchMethodEnum.hybrid
|
||||
|| searchMethod === RetrievalSearchMethodEnum.fullText
|
||||
}
|
||||
|
||||
export enum KnowledgeBaseValidationIssueCode {
|
||||
chunkStructureRequired = 'chunk-structure-required',
|
||||
chunksVariableRequired = 'chunks-variable-required',
|
||||
indexMethodRequired = 'index-method-required',
|
||||
embeddingModelNotConfigured = 'embedding-model-not-configured',
|
||||
embeddingModelConfigureRequired = 'embedding-model-configure-required',
|
||||
embeddingModelApiKeyUnavailable = 'embedding-model-api-key-unavailable',
|
||||
embeddingModelCreditsExhausted = 'embedding-model-credits-exhausted',
|
||||
embeddingModelDisabled = 'embedding-model-disabled',
|
||||
embeddingModelIncompatible = 'embedding-model-incompatible',
|
||||
retrievalSettingRequired = 'retrieval-setting-required',
|
||||
rerankingModelRequired = 'reranking-model-required',
|
||||
rerankingModelInvalid = 'reranking-model-invalid',
|
||||
}
|
||||
|
||||
type KnowledgeBaseValidationIssue = {
|
||||
code: KnowledgeBaseValidationIssueCode
|
||||
}
|
||||
|
||||
type KnowledgeBaseValidationPayload = Pick<KnowledgeBaseNodeType, 'chunk_structure' | 'index_chunk_variable_selector' | 'indexing_technique' | 'embedding_model' | 'embedding_model_provider' | '_embeddingModelList' | '_embeddingProviderModelList' | '_rerankModelList'> & {
|
||||
retrieval_model?: Pick<KnowledgeBaseNodeType['retrieval_model'], 'search_method' | 'reranking_enable' | 'reranking_model'>
|
||||
}
|
||||
|
||||
const EMBEDDING_ISSUE_CODES = new Set<KnowledgeBaseValidationIssueCode>([
|
||||
KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured,
|
||||
KnowledgeBaseValidationIssueCode.embeddingModelConfigureRequired,
|
||||
KnowledgeBaseValidationIssueCode.embeddingModelApiKeyUnavailable,
|
||||
KnowledgeBaseValidationIssueCode.embeddingModelCreditsExhausted,
|
||||
KnowledgeBaseValidationIssueCode.embeddingModelDisabled,
|
||||
KnowledgeBaseValidationIssueCode.embeddingModelIncompatible,
|
||||
])
|
||||
|
||||
const resolveIssue = (code: KnowledgeBaseValidationIssueCode): KnowledgeBaseValidationIssue => ({
|
||||
code,
|
||||
})
|
||||
|
||||
const resolveEmbeddingIssue = (payload: KnowledgeBaseValidationPayload): KnowledgeBaseValidationIssue | null => {
|
||||
const {
|
||||
embedding_model,
|
||||
embedding_model_provider,
|
||||
_embeddingModelList,
|
||||
_embeddingProviderModelList,
|
||||
} = payload
|
||||
|
||||
if (!embedding_model || !embedding_model_provider)
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured)
|
||||
|
||||
const currentEmbeddingModelProvider = _embeddingModelList?.find(provider => provider.provider === embedding_model_provider)
|
||||
const hasProviderScopedModelList = !!_embeddingProviderModelList && _embeddingProviderModelList.length > 0
|
||||
const embeddingModelCandidates = hasProviderScopedModelList
|
||||
? _embeddingProviderModelList
|
||||
: currentEmbeddingModelProvider?.models
|
||||
const currentEmbeddingModel = embeddingModelCandidates?.find(model => model.model === embedding_model)
|
||||
|
||||
if (!currentEmbeddingModel) {
|
||||
if (!currentEmbeddingModelProvider)
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelIncompatible)
|
||||
|
||||
const providerExists = hasProviderScopedModelList || currentEmbeddingModelProvider
|
||||
return resolveIssue(providerExists
|
||||
? KnowledgeBaseValidationIssueCode.embeddingModelIncompatible
|
||||
: KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured)
|
||||
}
|
||||
|
||||
switch (currentEmbeddingModel.status) {
|
||||
case ModelStatusEnum.active:
|
||||
return null
|
||||
case ModelStatusEnum.noConfigure:
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelConfigureRequired)
|
||||
case ModelStatusEnum.credentialRemoved:
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelApiKeyUnavailable)
|
||||
case ModelStatusEnum.quotaExceeded:
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelCreditsExhausted)
|
||||
case ModelStatusEnum.disabled:
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelDisabled)
|
||||
case ModelStatusEnum.noPermission:
|
||||
default:
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.embeddingModelIncompatible)
|
||||
}
|
||||
}
|
||||
|
||||
export const getKnowledgeBaseValidationIssue = (payload: KnowledgeBaseValidationPayload): KnowledgeBaseValidationIssue | null => {
|
||||
const {
|
||||
chunk_structure,
|
||||
indexing_technique,
|
||||
retrieval_model,
|
||||
index_chunk_variable_selector,
|
||||
_rerankModelList,
|
||||
} = payload
|
||||
|
||||
const {
|
||||
search_method,
|
||||
reranking_enable,
|
||||
reranking_model,
|
||||
} = retrieval_model || {}
|
||||
|
||||
if (!chunk_structure)
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.chunkStructureRequired)
|
||||
|
||||
if (index_chunk_variable_selector.length === 0)
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.chunksVariableRequired)
|
||||
|
||||
if (!indexing_technique)
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.indexMethodRequired)
|
||||
|
||||
if (indexing_technique === IndexingType.QUALIFIED) {
|
||||
const embeddingIssue = resolveEmbeddingIssue(payload)
|
||||
if (embeddingIssue)
|
||||
return embeddingIssue
|
||||
}
|
||||
|
||||
if (!retrieval_model || !search_method)
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.retrievalSettingRequired)
|
||||
|
||||
if (reranking_enable) {
|
||||
if (!reranking_model || !reranking_model.reranking_provider_name || !reranking_model.reranking_model_name)
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.rerankingModelRequired)
|
||||
|
||||
const currentRerankingModelProvider = _rerankModelList?.find(provider => provider.provider === reranking_model.reranking_provider_name)
|
||||
const currentRerankingModel = currentRerankingModelProvider?.models.find(model => model.model === reranking_model.reranking_model_name)
|
||||
if (!currentRerankingModel)
|
||||
return resolveIssue(KnowledgeBaseValidationIssueCode.rerankingModelInvalid)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
export const getKnowledgeBaseValidationMessage = (
|
||||
issue: KnowledgeBaseValidationIssue | null | undefined,
|
||||
t: TFunction,
|
||||
) => {
|
||||
if (!issue)
|
||||
return ''
|
||||
|
||||
switch (issue.code) {
|
||||
case KnowledgeBaseValidationIssueCode.chunkStructureRequired:
|
||||
return t('nodes.knowledgeBase.chunkIsRequired', { ns: 'workflow' })
|
||||
case KnowledgeBaseValidationIssueCode.chunksVariableRequired:
|
||||
return t('nodes.knowledgeBase.chunksVariableIsRequired', { ns: 'workflow' })
|
||||
case KnowledgeBaseValidationIssueCode.indexMethodRequired:
|
||||
return t('nodes.knowledgeBase.indexMethodIsRequired', { ns: 'workflow' })
|
||||
case KnowledgeBaseValidationIssueCode.embeddingModelNotConfigured:
|
||||
return t('nodes.knowledgeBase.embeddingModelNotConfigured', { ns: 'workflow' })
|
||||
case KnowledgeBaseValidationIssueCode.embeddingModelConfigureRequired:
|
||||
return t('modelProvider.selector.configureRequired', { ns: 'common' })
|
||||
case KnowledgeBaseValidationIssueCode.embeddingModelApiKeyUnavailable:
|
||||
return t('modelProvider.selector.apiKeyUnavailable', { ns: 'common' })
|
||||
case KnowledgeBaseValidationIssueCode.embeddingModelCreditsExhausted:
|
||||
return t('modelProvider.selector.creditsExhausted', { ns: 'common' })
|
||||
case KnowledgeBaseValidationIssueCode.embeddingModelDisabled:
|
||||
return t('modelProvider.selector.disabled', { ns: 'common' })
|
||||
case KnowledgeBaseValidationIssueCode.embeddingModelIncompatible:
|
||||
return t('modelProvider.selector.incompatible', { ns: 'common' })
|
||||
case KnowledgeBaseValidationIssueCode.retrievalSettingRequired:
|
||||
return t('nodes.knowledgeBase.retrievalSettingIsRequired', { ns: 'workflow' })
|
||||
case KnowledgeBaseValidationIssueCode.rerankingModelRequired:
|
||||
return t('nodes.knowledgeBase.rerankingModelIsRequired', { ns: 'workflow' })
|
||||
case KnowledgeBaseValidationIssueCode.rerankingModelInvalid:
|
||||
return t('nodes.knowledgeBase.rerankingModelIsInvalid', { ns: 'workflow' })
|
||||
default:
|
||||
return ''
|
||||
}
|
||||
}
|
||||
|
||||
export const isKnowledgeBaseEmbeddingIssue = (issue: KnowledgeBaseValidationIssue | null | undefined) => {
|
||||
if (!issue)
|
||||
return false
|
||||
|
||||
return EMBEDDING_ISSUE_CODES.has(issue.code)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user