diff --git a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/reranking-model-selector.spec.tsx b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/reranking-model-selector.spec.tsx new file mode 100644 index 0000000000..300de76c2e --- /dev/null +++ b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/reranking-model-selector.spec.tsx @@ -0,0 +1,121 @@ +import type { + DefaultModel, + Model, + ModelItem, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import { fireEvent, render, screen } from '@testing-library/react' +import { + ConfigurationMethodEnum, + ModelStatusEnum, + ModelTypeEnum, +} from '@/app/components/header/account-setting/model-provider-page/declarations' +import RerankingModelSelector from './reranking-model-selector' + +type MockModelSelectorProps = { + defaultModel?: DefaultModel + modelList: Model[] + onSelect?: (model: DefaultModel) => void +} + +const mockUseModelListAndDefaultModel = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelListAndDefaultModel: mockUseModelListAndDefaultModel, +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => ({ + default: ({ defaultModel, modelList, onSelect }: MockModelSelectorProps) => ( +
+
+ {defaultModel ? `${defaultModel.provider}/${defaultModel.model}` : 'no-default-model'} +
+
{modelList.length}
+ +
+ ), +})) + +const createModelItem = (overrides: Partial = {}): ModelItem => ({ + model: 'rerank-v3', + label: { en_US: 'Rerank V3', zh_Hans: 'Rerank V3' }, + model_type: ModelTypeEnum.rerank, + fetch_from: ConfigurationMethodEnum.predefinedModel, + status: ModelStatusEnum.active, + model_properties: {}, + load_balancing_enabled: false, + ...overrides, +}) + +const createModel = (overrides: Partial = {}): Model => ({ + provider: 'cohere', + icon_small: { + en_US: 'https://example.com/cohere.png', + zh_Hans: 'https://example.com/cohere.png', + }, + icon_small_dark: { + en_US: 'https://example.com/cohere-dark.png', + zh_Hans: 'https://example.com/cohere-dark.png', + }, + label: { en_US: 'Cohere', zh_Hans: 'Cohere' }, + models: [createModelItem()], + status: ModelStatusEnum.active, + ...overrides, +}) + +describe('RerankingModelSelector', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseModelListAndDefaultModel.mockReturnValue({ + modelList: [createModel()], + defaultModel: undefined, + }) + }) + + // Rendering behavior for mapped rerank model state. + describe('Rendering', () => { + it('should not pass a default model when reranking model fields are empty strings', () => { + render( + , + ) + + expect(screen.getByTestId('default-model')).toHaveTextContent('no-default-model') + expect(screen.getByTestId('model-list-count')).toHaveTextContent('1') + }) + + it('should map reranking model to default model when both fields exist', () => { + render( + , + ) + + expect(screen.getByTestId('default-model')).toHaveTextContent('cohere/rerank-v3') + }) + }) + + // Selection behavior should convert back to workflow reranking model shape. + describe('Interactions', () => { + it('should map selected model back to reranking model fields', () => { + const onRerankingModelChange = vi.fn() + + render() + + fireEvent.click(screen.getByRole('button', { name: 'select-model' })) + + expect(onRerankingModelChange).toHaveBeenCalledWith({ + reranking_provider_name: 'cohere', + reranking_model_name: 'rerank-v3', + }) + }) + }) +}) diff --git a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/reranking-model-selector.tsx b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/reranking-model-selector.tsx index 3e0bea2b28..7ee18a82e5 100644 --- a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/reranking-model-selector.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/reranking-model-selector.tsx @@ -22,12 +22,12 @@ const RerankingModelSelector = ({ modelList: rerankModelList, } = useModelListAndDefaultModel(ModelTypeEnum.rerank) const rerankModel = useMemo(() => { - if (!rerankingModel) + if (!rerankingModel?.reranking_provider_name || !rerankingModel?.reranking_model_name) return undefined return { - providerName: rerankingModel.reranking_provider_name, - modelName: rerankingModel.reranking_model_name, + provider: rerankingModel.reranking_provider_name, + model: rerankingModel.reranking_model_name, } }, [rerankingModel]) @@ -40,7 +40,7 @@ const RerankingModelSelector = ({ return (