From bdd8d5b470b9df7683317a9cb12a9add87cb214d Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Thu, 15 Jan 2026 10:56:02 +0800 Subject: [PATCH 1/8] test: add unit tests for PluginPage and related components (#30908) Co-authored-by: CodingOnStar --- .../components/plugins/card/index.spec.tsx | 52 + .../plugins/plugin-page/context.spec.tsx | 123 ++ .../plugins/plugin-page/index.spec.tsx | 1041 +++++++++++++++++ .../components/plugins/plugin-page/index.tsx | 1 + .../components/plugin-task-list.tsx | 219 ++++ .../components/task-status-indicator.tsx | 96 ++ .../plugin-page/plugin-tasks/index.spec.tsx | 856 ++++++++++++++ .../plugin-page/plugin-tasks/index.tsx | 300 +---- .../plugin-page/use-reference-setting.spec.ts | 388 ++++++ .../plugins/plugin-page/use-uploader.spec.ts | 487 ++++++++ .../components/rag-pipeline/index.spec.tsx | 550 +++++++++ 11 files changed, 3870 insertions(+), 243 deletions(-) create mode 100644 web/app/components/plugins/plugin-page/context.spec.tsx create mode 100644 web/app/components/plugins/plugin-page/index.spec.tsx create mode 100644 web/app/components/plugins/plugin-page/plugin-tasks/components/plugin-task-list.tsx create mode 100644 web/app/components/plugins/plugin-page/plugin-tasks/components/task-status-indicator.tsx create mode 100644 web/app/components/plugins/plugin-page/plugin-tasks/index.spec.tsx create mode 100644 web/app/components/plugins/plugin-page/use-reference-setting.spec.ts create mode 100644 web/app/components/plugins/plugin-page/use-uploader.spec.ts create mode 100644 web/app/components/rag-pipeline/index.spec.tsx diff --git a/web/app/components/plugins/card/index.spec.tsx b/web/app/components/plugins/card/index.spec.tsx index fd97534ec4..8dd7e67d69 100644 --- a/web/app/components/plugins/card/index.spec.tsx +++ b/web/app/components/plugins/card/index.spec.tsx @@ -897,6 +897,58 @@ describe('Icon', () => { const iconDiv = container.firstChild as HTMLElement expect(iconDiv).toHaveStyle({ 
backgroundImage: 'url(/icon?name=test&size=large)' }) }) + + it('should not render status indicators when src is object with installed=true', () => { + render() + + // Status indicators should not render for object src + expect(screen.queryByTestId('ri-check-line')).not.toBeInTheDocument() + }) + + it('should not render status indicators when src is object with installFailed=true', () => { + render() + + // Status indicators should not render for object src + expect(screen.queryByTestId('ri-close-line')).not.toBeInTheDocument() + }) + + it('should render object src with all size variants', () => { + const sizes: Array<'xs' | 'tiny' | 'small' | 'medium' | 'large'> = ['xs', 'tiny', 'small', 'medium', 'large'] + + sizes.forEach((size) => { + const { unmount } = render() + expect(screen.getByTestId('app-icon')).toHaveAttribute('data-size', size) + unmount() + }) + }) + + it('should render object src with custom className', () => { + const { container } = render( + , + ) + + expect(container.querySelector('.custom-object-icon')).toBeInTheDocument() + }) + + it('should pass correct props to AppIcon for object src', () => { + render() + + const appIcon = screen.getByTestId('app-icon') + expect(appIcon).toHaveAttribute('data-icon', '๐Ÿ˜€') + expect(appIcon).toHaveAttribute('data-background', '#123456') + expect(appIcon).toHaveAttribute('data-icon-type', 'emoji') + }) + + it('should render inner icon only when shouldUseMcpIcon returns true', () => { + // Test with MCP icon content + const { unmount } = render() + expect(screen.getByTestId('inner-icon')).toBeInTheDocument() + unmount() + + // Test without MCP icon content + render() + expect(screen.queryByTestId('inner-icon')).not.toBeInTheDocument() + }) }) }) diff --git a/web/app/components/plugins/plugin-page/context.spec.tsx b/web/app/components/plugins/plugin-page/context.spec.tsx new file mode 100644 index 0000000000..ea52ae1dbd --- /dev/null +++ b/web/app/components/plugins/plugin-page/context.spec.tsx @@ -0,0 +1,123 
@@ +import { render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +// Import mocks +import { useGlobalPublicStore } from '@/context/global-public-context' + +import { PluginPageContext, PluginPageContextProvider, usePluginPageContext } from './context' + +// Mock dependencies +vi.mock('nuqs', () => ({ + useQueryState: vi.fn(() => ['plugins', vi.fn()]), +})) + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: vi.fn(), +})) + +vi.mock('../hooks', () => ({ + PLUGIN_PAGE_TABS_MAP: { + plugins: 'plugins', + marketplace: 'discover', + }, + usePluginPageTabs: () => [ + { value: 'plugins', text: 'Plugins' }, + { value: 'discover', text: 'Explore Marketplace' }, + ], +})) + +// Helper function to mock useGlobalPublicStore with marketplace setting +const mockGlobalPublicStore = (enableMarketplace: boolean) => { + vi.mocked(useGlobalPublicStore).mockImplementation((selector) => { + const state = { systemFeatures: { enable_marketplace: enableMarketplace } } + return selector(state as Parameters[0]) + }) +} + +// Test component that uses the context +const TestConsumer = () => { + const containerRef = usePluginPageContext(v => v.containerRef) + const options = usePluginPageContext(v => v.options) + const activeTab = usePluginPageContext(v => v.activeTab) + + return ( +
+ {containerRef ? 'true' : 'false'} + {options.length} + {activeTab} + {options.map((opt: { value: string, text: string }) => ( + {opt.text} + ))} +
+ ) +} + +describe('PluginPageContext', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('PluginPageContextProvider', () => { + it('should provide context values to children', () => { + mockGlobalPublicStore(true) + + render( + + + , + ) + + expect(screen.getByTestId('has-container-ref')).toHaveTextContent('true') + expect(screen.getByTestId('options-count')).toHaveTextContent('2') + }) + + it('should include marketplace tab when enable_marketplace is true', () => { + mockGlobalPublicStore(true) + + render( + + + , + ) + + expect(screen.getByTestId('option-plugins')).toBeInTheDocument() + expect(screen.getByTestId('option-discover')).toBeInTheDocument() + }) + + it('should filter out marketplace tab when enable_marketplace is false', () => { + mockGlobalPublicStore(false) + + render( + + + , + ) + + expect(screen.getByTestId('option-plugins')).toBeInTheDocument() + expect(screen.queryByTestId('option-discover')).not.toBeInTheDocument() + expect(screen.getByTestId('options-count')).toHaveTextContent('1') + }) + }) + + describe('usePluginPageContext', () => { + it('should select specific context values', () => { + mockGlobalPublicStore(true) + + render( + + + , + ) + + // activeTab should be 'plugins' from the mock + expect(screen.getByTestId('active-tab')).toHaveTextContent('plugins') + }) + }) + + describe('Default Context Values', () => { + it('should have empty options by default from context', () => { + // Test that the context has proper default values by checking the exported constant + // The PluginPageContext is created with default values including empty options array + expect(PluginPageContext).toBeDefined() + }) + }) +}) diff --git a/web/app/components/plugins/plugin-page/index.spec.tsx b/web/app/components/plugins/plugin-page/index.spec.tsx new file mode 100644 index 0000000000..a3ea7f7125 --- /dev/null +++ b/web/app/components/plugins/plugin-page/index.spec.tsx @@ -0,0 +1,1041 @@ +import type { PluginPageProps } from './index' +import 
{ act, fireEvent, render, screen, waitFor } from '@testing-library/react' +import { useQueryState } from 'nuqs' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { usePluginInstallation } from '@/hooks/use-query-params' +// Import mocked modules for assertions +import { fetchBundleInfoFromMarketPlace, fetchManifestFromMarketPlace } from '@/service/plugins' +import PluginPageWithContext from './index' + +// Mock external dependencies +vi.mock('@/service/plugins', () => ({ + fetchManifestFromMarketPlace: vi.fn(), + fetchBundleInfoFromMarketPlace: vi.fn(), +})) + +vi.mock('@/hooks/use-query-params', () => ({ + usePluginInstallation: vi.fn(() => [{ packageId: null, bundleInfo: null }, vi.fn()]), +})) + +vi.mock('@/hooks/use-document-title', () => ({ + default: vi.fn(), +})) + +vi.mock('@/context/i18n', () => ({ + useLocale: () => 'en-US', +})) + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: vi.fn((selector) => { + const state = { + systemFeatures: { + enable_marketplace: true, + }, + } + return selector(state) + }), +})) + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceManager: true, + isCurrentWorkspaceOwner: false, + }), +})) + +vi.mock('@/service/use-plugins', () => ({ + useReferenceSettings: () => ({ + data: { + permission: { + install_permission: 'everyone', + debug_permission: 'admins', + }, + }, + }), + useMutationReferenceSettings: () => ({ + mutate: vi.fn(), + isPending: false, + }), + useInvalidateReferenceSettings: () => vi.fn(), + usePluginTaskList: () => ({ + pluginTasks: [], + handleRefetch: vi.fn(), + }), + useMutationClearTaskPlugin: () => ({ + mutateAsync: vi.fn(), + }), + useInstalledPluginList: () => ({ + data: [], + isLoading: false, + isFetching: false, + isLastPage: true, + loadNextPage: vi.fn(), + }), + useInstalledLatestVersion: () => ({ + data: {}, + }), + useInvalidateInstalledPluginList: () => vi.fn(), +})) + +vi.mock('nuqs', () => ({ + 
useQueryState: vi.fn(() => ['plugins', vi.fn()]), +})) + +vi.mock('./plugin-tasks', () => ({ + default: () =>
PluginTasks
, +})) + +vi.mock('./debug-info', () => ({ + default: () =>
DebugInfo
, +})) + +vi.mock('./install-plugin-dropdown', () => ({ + default: ({ onSwitchToMarketplaceTab }: { onSwitchToMarketplaceTab: () => void }) => ( + + ), +})) + +vi.mock('../install-plugin/install-from-local-package', () => ({ + default: ({ onClose }: { onClose: () => void }) => ( +
+ +
+ ), +})) + +vi.mock('../install-plugin/install-from-marketplace', () => ({ + default: ({ onClose }: { onClose: () => void }) => ( +
+ +
+ ), +})) + +vi.mock('@/app/components/plugins/reference-setting-modal', () => ({ + default: ({ onHide }: { onHide: () => void }) => ( +
+ +
+ ), +})) + +// Helper to create default props +const createDefaultProps = (): PluginPageProps => ({ + plugins:
Plugins Content
, + marketplace:
Marketplace Content
, +}) + +// ============================================================================ +// PluginPage Component Tests +// ============================================================================ +describe('PluginPage Component', () => { + beforeEach(() => { + vi.clearAllMocks() + // Reset to default mock values + vi.mocked(usePluginInstallation).mockReturnValue([ + { packageId: null, bundleInfo: null }, + vi.fn(), + ]) + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + }) + + // ============================================================================ + // Rendering Tests + // ============================================================================ + describe('Rendering', () => { + it('should render without crashing', () => { + render() + expect(document.getElementById('marketplace-container')).toBeInTheDocument() + }) + + it('should render with correct container id', () => { + render() + const container = document.getElementById('marketplace-container') + expect(container).toBeInTheDocument() + }) + + it('should render PluginTasks component', () => { + render() + expect(screen.getByTestId('plugin-tasks')).toBeInTheDocument() + }) + + it('should render plugins content when on plugins tab', () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + expect(screen.getByTestId('plugins-content')).toBeInTheDocument() + }) + + it('should render marketplace content when on marketplace tab', () => { + vi.mocked(useQueryState).mockReturnValue(['discover', vi.fn()]) + + render() + // The marketplace content should be visible when enable_marketplace is true and on discover tab + const container = document.getElementById('marketplace-container') + expect(container).toBeInTheDocument() + // Check that marketplace-specific links are shown + expect(screen.getByText(/requestAPlugin/i)).toBeInTheDocument() + }) + + it('should render TabSlider', () => { + render() + // TabSlider renders tab options + 
expect(document.querySelector('.flex-1')).toBeInTheDocument() + }) + + it('should render drag and drop hint when on plugins tab', () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + expect(screen.getByText(/dropPluginToInstall/i)).toBeInTheDocument() + }) + + it('should render file input for plugin upload', () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + const fileInput = document.getElementById('fileUploader') + expect(fileInput).toBeInTheDocument() + expect(fileInput).toHaveAttribute('type', 'file') + }) + }) + + // ============================================================================ + // Tab Navigation Tests + // ============================================================================ + describe('Tab Navigation', () => { + it('should display plugins tab as active by default', () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + expect(screen.getByTestId('plugins-content')).toBeInTheDocument() + }) + + it('should show marketplace links when on marketplace tab', () => { + vi.mocked(useQueryState).mockReturnValue(['discover', vi.fn()]) + + render() + // Check for marketplace-specific buttons + expect(screen.getByText(/requestAPlugin/i)).toBeInTheDocument() + expect(screen.getByText(/publishPlugins/i)).toBeInTheDocument() + }) + + it('should not show marketplace links when on plugins tab', () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + expect(screen.queryByText(/requestAPlugin/i)).not.toBeInTheDocument() + }) + }) + + // ============================================================================ + // Permission-based Rendering Tests + // ============================================================================ + describe('Permission-based Rendering', () => { + it('should render InstallPluginDropdown when canManagement is true', () => { + render() + 
expect(screen.getByTestId('install-dropdown')).toBeInTheDocument() + }) + + it('should render DebugInfo when canDebugger is true', () => { + render() + expect(screen.getByTestId('debug-info')).toBeInTheDocument() + }) + + it('should render settings button when canSetPermissions is true', () => { + render() + // Settings button with RiEqualizer2Line icon + const settingsButtons = document.querySelectorAll('button') + expect(settingsButtons.length).toBeGreaterThan(0) + }) + + it('should call setActiveTab when onSwitchToMarketplaceTab is called', async () => { + const mockSetActiveTab = vi.fn() + vi.mocked(useQueryState).mockReturnValue(['plugins', mockSetActiveTab]) + + render() + + // Click the install dropdown button which triggers onSwitchToMarketplaceTab + fireEvent.click(screen.getByTestId('install-dropdown')) + + // The mock onSwitchToMarketplaceTab calls setActiveTab('discover') + // Since our mock InstallPluginDropdown calls onSwitchToMarketplaceTab on click + // we verify that setActiveTab was called with 'discover'. 
+ expect(mockSetActiveTab).toHaveBeenCalledWith('discover') + }) + + it('should use noop for file handlers when canManagement is false', () => { + // Override mock to disable management permission + vi.doMock('@/service/use-plugins', () => ({ + useReferenceSettings: () => ({ + data: { + permission: { + install_permission: 'noone', + debug_permission: 'noone', + }, + }, + }), + useMutationReferenceSettings: () => ({ + mutate: vi.fn(), + isPending: false, + }), + useInvalidateReferenceSettings: () => vi.fn(), + usePluginTaskList: () => ({ + pluginTasks: [], + handleRefetch: vi.fn(), + }), + useMutationClearTaskPlugin: () => ({ + mutateAsync: vi.fn(), + }), + useInstalledPluginList: () => ({ + data: [], + isLoading: false, + isFetching: false, + isLastPage: true, + loadNextPage: vi.fn(), + }), + useInstalledLatestVersion: () => ({ + data: {}, + }), + useInvalidateInstalledPluginList: () => vi.fn(), + })) + + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + + // File input should still be in the document (even if handlers are noop) + const fileInput = document.getElementById('fileUploader') + expect(fileInput).toBeInTheDocument() + }) + }) + + // ============================================================================ + // File Upload Tests + // ============================================================================ + describe('File Upload', () => { + it('should have hidden file input', () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + const fileInput = document.getElementById('fileUploader') as HTMLInputElement + expect(fileInput).toHaveClass('hidden') + }) + + it('should accept .difypkg files', () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + const fileInput = document.getElementById('fileUploader') as HTMLInputElement + expect(fileInput.accept).toContain('.difypkg') + }) + + it('should show InstallFromLocalPackage modal when valid file is selected', 
async () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + const fileInput = document.getElementById('fileUploader') as HTMLInputElement + + const file = new File(['content'], 'plugin.difypkg', { type: 'application/octet-stream' }) + Object.defineProperty(fileInput, 'files', { + value: [file], + }) + + fireEvent.change(fileInput) + + await waitFor(() => { + expect(screen.getByTestId('install-local-modal')).toBeInTheDocument() + }) + }) + + it('should not show modal for non-.difypkg files', async () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + const fileInput = document.getElementById('fileUploader') as HTMLInputElement + + const file = new File(['content'], 'plugin.txt', { type: 'text/plain' }) + Object.defineProperty(fileInput, 'files', { + value: [file], + }) + + fireEvent.change(fileInput) + + await waitFor(() => { + expect(screen.queryByTestId('install-local-modal')).not.toBeInTheDocument() + }) + }) + }) + + // ============================================================================ + // Marketplace Installation Tests + // ============================================================================ + describe('Marketplace Installation', () => { + it('should fetch manifest when packageId is provided', async () => { + const mockSetInstallState = vi.fn() + vi.mocked(usePluginInstallation).mockReturnValue([ + { packageId: 'test-package-id', bundleInfo: null }, + mockSetInstallState, + ]) + + vi.mocked(fetchManifestFromMarketPlace).mockResolvedValue({ + data: { + plugin: { org: 'test-org', name: 'test-plugin', category: 'tool' }, + version: { version: '1.0.0' }, + }, + } as Awaited>) + + render() + + await waitFor(() => { + expect(fetchManifestFromMarketPlace).toHaveBeenCalledWith('test-package-id') + }) + }) + + it('should fetch bundle info when bundleInfo is provided', async () => { + const mockSetInstallState = vi.fn() + vi.mocked(usePluginInstallation).mockReturnValue([ + { 
packageId: null, bundleInfo: 'test-bundle-info' as unknown }, + mockSetInstallState, + ] as ReturnType) + + vi.mocked(fetchBundleInfoFromMarketPlace).mockResolvedValue({ + data: { version: { dependencies: [] } }, + } as unknown as Awaited>) + + render() + + await waitFor(() => { + expect(fetchBundleInfoFromMarketPlace).toHaveBeenCalledWith('test-bundle-info') + }) + }) + + it('should show InstallFromMarketplace modal after fetching manifest', async () => { + const mockSetInstallState = vi.fn() + vi.mocked(usePluginInstallation).mockReturnValue([ + { packageId: 'test-package-id', bundleInfo: null }, + mockSetInstallState, + ]) + + vi.mocked(fetchManifestFromMarketPlace).mockResolvedValue({ + data: { + plugin: { org: 'test-org', name: 'test-plugin', category: 'tool' }, + version: { version: '1.0.0' }, + }, + } as Awaited>) + + render() + + await waitFor(() => { + expect(screen.getByTestId('install-marketplace-modal')).toBeInTheDocument() + }, { timeout: 3000 }) + }) + + it('should handle fetch error gracefully', async () => { + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + vi.mocked(usePluginInstallation).mockReturnValue([ + { packageId: null, bundleInfo: 'invalid-bundle' as unknown }, + vi.fn(), + ] as ReturnType) + + vi.mocked(fetchBundleInfoFromMarketPlace).mockRejectedValue(new Error('Network error')) + + render() + + await waitFor(() => { + expect(consoleSpy).toHaveBeenCalledWith('Failed to load bundle info:', expect.any(Error)) + }) + + consoleSpy.mockRestore() + }) + }) + + // ============================================================================ + // Settings Modal Tests + // ============================================================================ + describe('Settings Modal', () => { + it('should open settings modal when settings button is clicked', async () => { + render() + + fireEvent.click(screen.getByTestId('plugin-settings-button')) + + await waitFor(() => { + 
expect(screen.getByTestId('reference-setting-modal')).toBeInTheDocument() + }) + }) + + it('should close settings modal when onHide is called', async () => { + render() + + // Open modal + fireEvent.click(screen.getByTestId('plugin-settings-button')) + + await waitFor(() => { + expect(screen.getByTestId('reference-setting-modal')).toBeInTheDocument() + }) + + // Close modal + fireEvent.click(screen.getByText('Close Settings')) + + await waitFor(() => { + expect(screen.queryByTestId('reference-setting-modal')).not.toBeInTheDocument() + }) + }) + }) + + // ============================================================================ + // Drag and Drop Tests + // ============================================================================ + describe('Drag and Drop', () => { + it('should show dragging overlay when dragging files over container', () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + const container = document.getElementById('marketplace-container')! 
+ + // Simulate drag enter + const dragEnterEvent = new Event('dragenter', { bubbles: true }) + Object.defineProperty(dragEnterEvent, 'dataTransfer', { + value: { types: ['Files'] }, + }) + container.dispatchEvent(dragEnterEvent) + + // Check for dragging overlay styles + expect(container).toBeInTheDocument() + }) + + it('should highlight drop zone text when dragging', () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + + // The drag hint should be visible + const dragHint = screen.getByText(/dropPluginToInstall/i) + expect(dragHint).toBeInTheDocument() + }) + }) + + // ============================================================================ + // Memoization Tests + // ============================================================================ + describe('Memoization', () => { + it('should memoize isPluginsTab correctly', () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + const { rerender } = render() + + // Should show plugins content + expect(screen.getByTestId('plugins-content')).toBeInTheDocument() + + // Rerender with same props - memoized value should be same + rerender() + expect(screen.getByTestId('plugins-content')).toBeInTheDocument() + }) + + it('should memoize isExploringMarketplace correctly', () => { + vi.mocked(useQueryState).mockReturnValue(['discover', vi.fn()]) + + const { rerender } = render() + + // Should show marketplace links when on discover tab + expect(screen.getByText(/requestAPlugin/i)).toBeInTheDocument() + + // Rerender with same props + rerender() + expect(screen.getByText(/requestAPlugin/i)).toBeInTheDocument() + }) + + it('should recognize plugin type tabs as marketplace', () => { + // Test with a plugin type tab like 'tool' + vi.mocked(useQueryState).mockReturnValue(['tool', vi.fn()]) + + render() + + // Should show marketplace links when on a plugin type tab + expect(screen.getByText(/requestAPlugin/i)).toBeInTheDocument() + 
expect(screen.getByText(/publishPlugins/i)).toBeInTheDocument() + }) + + it('should render marketplace content when isExploringMarketplace and enable_marketplace are true', () => { + vi.mocked(useQueryState).mockReturnValue(['discover', vi.fn()]) + + render() + + // The marketplace prop content should be rendered + // Since we mock the marketplace as a div, check it's not hidden + const container = document.getElementById('marketplace-container') + expect(container).toBeInTheDocument() + expect(container).toHaveClass('bg-background-body') + }) + }) + + // ============================================================================ + // Context Provider Tests + // ============================================================================ + describe('Context Provider', () => { + it('should wrap component with PluginPageContextProvider', () => { + render() + + // The component should render, indicating context is working + expect(document.getElementById('marketplace-container')).toBeInTheDocument() + }) + + it('should filter out marketplace tab when enable_marketplace is false', () => { + // This tests line 69 in context.tsx - the false branch of enable_marketplace + // The marketplace tab should be filtered out from options + render() + // Component should still work without marketplace + expect(document.getElementById('marketplace-container')).toBeInTheDocument() + }) + }) + + // ============================================================================ + // Edge Cases and Error Handling + // ============================================================================ + describe('Edge Cases', () => { + it('should handle null plugins prop', () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + expect(document.getElementById('marketplace-container')).toBeInTheDocument() + }) + + it('should handle empty marketplace prop', () => { + vi.mocked(useQueryState).mockReturnValue(['discover', vi.fn()]) + + render() + 
expect(document.getElementById('marketplace-container')).toBeInTheDocument() + }) + + it('should handle rapid tab switches', async () => { + const mockSetActiveTab = vi.fn() + vi.mocked(useQueryState).mockReturnValue(['plugins', mockSetActiveTab]) + + render() + + // Simulate rapid switches by updating state + act(() => { + vi.mocked(useQueryState).mockReturnValue(['discover', mockSetActiveTab]) + }) + + expect(document.getElementById('marketplace-container')).toBeInTheDocument() + }) + + it('should handle marketplace disabled', () => { + // Mock marketplace disabled + vi.mock('@/context/global-public-context', async () => ({ + useGlobalPublicStore: vi.fn((selector) => { + const state = { + systemFeatures: { + enable_marketplace: false, + }, + } + return selector(state) + }), + })) + + vi.mocked(useQueryState).mockReturnValue(['discover', vi.fn()]) + + render() + + // Component should still render but without marketplace content when disabled + expect(document.getElementById('marketplace-container')).toBeInTheDocument() + }) + + it('should handle file with empty name', async () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + const fileInput = document.getElementById('fileUploader') as HTMLInputElement + + const file = new File(['content'], '', { type: 'application/octet-stream' }) + Object.defineProperty(fileInput, 'files', { + value: [file], + }) + + fireEvent.change(fileInput) + + // Should not show modal for file without proper extension + await waitFor(() => { + expect(screen.queryByTestId('install-local-modal')).not.toBeInTheDocument() + }) + }) + + it('should handle no files selected', async () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + const fileInput = document.getElementById('fileUploader') as HTMLInputElement + + Object.defineProperty(fileInput, 'files', { + value: [], + }) + + fireEvent.change(fileInput) + + // Should not show modal + 
expect(screen.queryByTestId('install-local-modal')).not.toBeInTheDocument() + }) + }) + + // ============================================================================ + // Cleanup Tests + // ============================================================================ + describe('Cleanup', () => { + it('should reset install state when hiding marketplace modal', async () => { + const mockSetInstallState = vi.fn() + vi.mocked(usePluginInstallation).mockReturnValue([ + { packageId: 'test-package', bundleInfo: null }, + mockSetInstallState, + ]) + + vi.mocked(fetchManifestFromMarketPlace).mockResolvedValue({ + data: { + plugin: { org: 'test-org', name: 'test-plugin', category: 'tool' }, + version: { version: '1.0.0' }, + }, + } as Awaited>) + + render() + + // Wait for modal to appear + await waitFor(() => { + expect(screen.getByTestId('install-marketplace-modal')).toBeInTheDocument() + }, { timeout: 3000 }) + + // Close modal + fireEvent.click(screen.getByText('Close')) + + await waitFor(() => { + expect(mockSetInstallState).toHaveBeenCalledWith(null) + }) + }) + }) + + // ============================================================================ + // Styling Tests + // ============================================================================ + describe('Styling', () => { + it('should apply correct background for plugins tab', () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + const container = document.getElementById('marketplace-container') + + expect(container).toHaveClass('bg-components-panel-bg') + }) + + it('should apply correct background for marketplace tab', () => { + vi.mocked(useQueryState).mockReturnValue(['discover', vi.fn()]) + + render() + const container = document.getElementById('marketplace-container') + + expect(container).toHaveClass('bg-background-body') + }) + + it('should have scrollbar-gutter stable style', () => { + render() + const container = document.getElementById('marketplace-container') + + 
expect(container).toHaveStyle({ scrollbarGutter: 'stable' }) + }) + }) +}) + +// ============================================================================ +// Uploader Hook Integration Tests +// ============================================================================ +describe('Uploader Hook Integration', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + }) + + describe('Drag Events', () => { + it('should handle dragover event', async () => { + render() + const container = document.getElementById('marketplace-container')! + + const dragOverEvent = new Event('dragover', { bubbles: true, cancelable: true }) + Object.defineProperty(dragOverEvent, 'dataTransfer', { + value: { types: ['Files'] }, + }) + + act(() => { + container.dispatchEvent(dragOverEvent) + }) + + expect(container).toBeInTheDocument() + }) + + it('should handle dragleave event when leaving container', async () => { + render() + const container = document.getElementById('marketplace-container')! + + const dragEnterEvent = new Event('dragenter', { bubbles: true }) + Object.defineProperty(dragEnterEvent, 'dataTransfer', { + value: { types: ['Files'] }, + }) + act(() => { + container.dispatchEvent(dragEnterEvent) + }) + + const dragLeaveEvent = new Event('dragleave', { bubbles: true }) + Object.defineProperty(dragLeaveEvent, 'relatedTarget', { + value: null, + }) + act(() => { + container.dispatchEvent(dragLeaveEvent) + }) + + expect(container).toBeInTheDocument() + }) + + it('should handle dragleave event when moving to element outside container', async () => { + render() + const container = document.getElementById('marketplace-container')! 
+ + const dragEnterEvent = new Event('dragenter', { bubbles: true }) + Object.defineProperty(dragEnterEvent, 'dataTransfer', { + value: { types: ['Files'] }, + }) + act(() => { + container.dispatchEvent(dragEnterEvent) + }) + + const outsideElement = document.createElement('div') + document.body.appendChild(outsideElement) + + const dragLeaveEvent = new Event('dragleave', { bubbles: true }) + Object.defineProperty(dragLeaveEvent, 'relatedTarget', { + value: outsideElement, + }) + act(() => { + container.dispatchEvent(dragLeaveEvent) + }) + + expect(container).toBeInTheDocument() + document.body.removeChild(outsideElement) + }) + + it('should handle drop event with files', async () => { + render() + const container = document.getElementById('marketplace-container')! + + const dragEnterEvent = new Event('dragenter', { bubbles: true }) + Object.defineProperty(dragEnterEvent, 'dataTransfer', { + value: { types: ['Files'] }, + }) + act(() => { + container.dispatchEvent(dragEnterEvent) + }) + + const file = new File(['content'], 'test-plugin.difypkg', { type: 'application/octet-stream' }) + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { files: [file] }, + }) + + act(() => { + container.dispatchEvent(dropEvent) + }) + + await waitFor(() => { + expect(screen.getByTestId('install-local-modal')).toBeInTheDocument() + }) + }) + + it('should handle drop event without dataTransfer', async () => { + render() + const container = document.getElementById('marketplace-container')! + + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) + + act(() => { + container.dispatchEvent(dropEvent) + }) + + expect(screen.queryByTestId('install-local-modal')).not.toBeInTheDocument() + }) + + it('should handle drop event with empty files array', async () => { + render() + const container = document.getElementById('marketplace-container')! 
+ + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { files: [] }, + }) + + act(() => { + container.dispatchEvent(dropEvent) + }) + + expect(screen.queryByTestId('install-local-modal')).not.toBeInTheDocument() + }) + }) + + describe('File Change Handler', () => { + it('should handle file change with null file', async () => { + render() + const fileInput = document.getElementById('fileUploader') as HTMLInputElement + + Object.defineProperty(fileInput, 'files', { value: null }) + + fireEvent.change(fileInput) + + expect(screen.queryByTestId('install-local-modal')).not.toBeInTheDocument() + }) + }) + + describe('Remove File', () => { + it('should clear file input when removeFile is called', async () => { + render() + const fileInput = document.getElementById('fileUploader') as HTMLInputElement + + const file = new File(['content'], 'plugin.difypkg', { type: 'application/octet-stream' }) + Object.defineProperty(fileInput, 'files', { value: [file] }) + fireEvent.change(fileInput) + + await waitFor(() => { + expect(screen.getByTestId('install-local-modal')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByText('Close')) + + await waitFor(() => { + expect(screen.queryByTestId('install-local-modal')).not.toBeInTheDocument() + }) + }) + }) +}) + +// ============================================================================ +// Reference Setting Hook Integration Tests +// ============================================================================ +describe('Reference Setting Hook Integration', () => { + describe('Permission Handling', () => { + it('should render InstallPluginDropdown when permission is everyone', () => { + render() + expect(screen.getByTestId('install-dropdown')).toBeInTheDocument() + }) + + it('should render DebugInfo when permission is admins and user is manager', () => { + render() + expect(screen.getByTestId('debug-info')).toBeInTheDocument() + }) + }) +}) + 
+// ============================================================================ +// Marketplace Installation Permission Tests +// ============================================================================ +describe('Marketplace Installation Permission', () => { + it('should show InstallPluginDropdown when marketplace is enabled and has permission', () => { + render() + expect(screen.getByTestId('install-dropdown')).toBeInTheDocument() + }) +}) + +// ============================================================================ +// Integration Tests +// ============================================================================ +describe('PluginPage Integration', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.mocked(usePluginInstallation).mockReturnValue([ + { packageId: null, bundleInfo: null }, + vi.fn(), + ]) + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + }) + + it('should render complete plugin page with all features', () => { + render() + + // Check all major elements are present + expect(document.getElementById('marketplace-container')).toBeInTheDocument() + expect(screen.getByTestId('plugin-tasks')).toBeInTheDocument() + expect(screen.getByTestId('install-dropdown')).toBeInTheDocument() + expect(screen.getByTestId('debug-info')).toBeInTheDocument() + expect(screen.getByTestId('plugins-content')).toBeInTheDocument() + }) + + it('should handle full install from marketplace flow', async () => { + const mockSetInstallState = vi.fn() + vi.mocked(usePluginInstallation).mockReturnValue([ + { packageId: 'test-package', bundleInfo: null }, + mockSetInstallState, + ]) + + vi.mocked(fetchManifestFromMarketPlace).mockResolvedValue({ + data: { + plugin: { org: 'langgenius', name: 'test-plugin', category: 'tool' }, + version: { version: '1.0.0' }, + }, + } as Awaited>) + + render() + + // Wait for API call + await waitFor(() => { + expect(fetchManifestFromMarketPlace).toHaveBeenCalled() + }) + + // Wait for modal + await waitFor(() => { + 
expect(screen.getByTestId('install-marketplace-modal')).toBeInTheDocument() + }, { timeout: 3000 }) + + // Close modal + fireEvent.click(screen.getByText('Close')) + + // Verify state reset + await waitFor(() => { + expect(mockSetInstallState).toHaveBeenCalledWith(null) + }) + }) + + it('should handle full local plugin install flow', async () => { + vi.mocked(useQueryState).mockReturnValue(['plugins', vi.fn()]) + + render() + + const fileInput = document.getElementById('fileUploader') as HTMLInputElement + const file = new File(['plugin content'], 'my-plugin.difypkg', { + type: 'application/octet-stream', + }) + + Object.defineProperty(fileInput, 'files', { value: [file] }) + fireEvent.change(fileInput) + + await waitFor(() => { + expect(screen.getByTestId('install-local-modal')).toBeInTheDocument() + }) + + // Close modal (triggers removeFile via onClose) + fireEvent.click(screen.getByText('Close')) + + await waitFor(() => { + expect(screen.queryByTestId('install-local-modal')).not.toBeInTheDocument() + }) + }) + + it('should render marketplace content only when enable_marketplace is true', () => { + vi.mocked(useQueryState).mockReturnValue(['discover', vi.fn()]) + + const { rerender } = render() + + // With enable_marketplace: true (default mock), marketplace links should show + expect(screen.getByText(/requestAPlugin/i)).toBeInTheDocument() + + // Rerender to verify consistent behavior + rerender() + expect(screen.getByText(/publishPlugins/i)).toBeInTheDocument() + }) +}) diff --git a/web/app/components/plugins/plugin-page/index.tsx b/web/app/components/plugins/plugin-page/index.tsx index 1f88f691ef..d852e4d0b8 100644 --- a/web/app/components/plugins/plugin-page/index.tsx +++ b/web/app/components/plugins/plugin-page/index.tsx @@ -207,6 +207,7 @@ const PluginPage = ({ popupContent={t('privilege.title', { ns: 'plugin' })} > + )} + /> + )} + + {/* Error Plugins Section */} + {errorPlugins.length > 0 && ( + + } + defaultStatusText={t('task.installError', { ns: 
'plugin', errorLength: errorPlugins.length })} + statusClassName="text-text-destructive break-all" + headerAction={( + + )} + renderItemAction={plugin => ( + + )} + /> + )} + + ) +} + +export default PluginTaskList diff --git a/web/app/components/plugins/plugin-page/plugin-tasks/components/task-status-indicator.tsx b/web/app/components/plugins/plugin-page/plugin-tasks/components/task-status-indicator.tsx new file mode 100644 index 0000000000..084c8f90f9 --- /dev/null +++ b/web/app/components/plugins/plugin-page/plugin-tasks/components/task-status-indicator.tsx @@ -0,0 +1,96 @@ +import type { FC } from 'react' +import { + RiCheckboxCircleFill, + RiErrorWarningFill, + RiInstallLine, +} from '@remixicon/react' +import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' +import Tooltip from '@/app/components/base/tooltip' +import DownloadingIcon from '@/app/components/header/plugins-nav/downloading-icon' +import { cn } from '@/utils/classnames' + +export type TaskStatusIndicatorProps = { + tip: string + isInstalling: boolean + isInstallingWithSuccess: boolean + isInstallingWithError: boolean + isSuccess: boolean + isFailed: boolean + successPluginsLength: number + runningPluginsLength: number + totalPluginsLength: number + onClick: () => void +} + +const TaskStatusIndicator: FC = ({ + tip, + isInstalling, + isInstallingWithSuccess, + isInstallingWithError, + isSuccess, + isFailed, + successPluginsLength, + runningPluginsLength, + totalPluginsLength, + onClick, +}) => { + const showDownloadingIcon = isInstalling || isInstallingWithError + const showErrorStyle = isInstallingWithError || isFailed + const showSuccessIcon = isSuccess || (successPluginsLength > 0 && runningPluginsLength === 0) + + return ( + +
+ {/* Main Icon */} + {showDownloadingIcon + ? + : ( + + )} + + {/* Status Indicator Badge */} +
+ {(isInstalling || isInstallingWithSuccess) && ( + 0 ? successPluginsLength / totalPluginsLength : 0) * 100} + circleFillColor="fill-components-progress-brand-bg" + /> + )} + {isInstallingWithError && ( + 0 ? runningPluginsLength / totalPluginsLength : 0) * 100} + circleFillColor="fill-components-progress-brand-bg" + sectorFillColor="fill-components-progress-error-border" + circleStrokeColor="stroke-components-progress-error-border" + /> + )} + {showSuccessIcon && !isInstalling && !isInstallingWithSuccess && !isInstallingWithError && ( + + )} + {isFailed && ( + + )} +
+
+
+ ) +} + +export default TaskStatusIndicator diff --git a/web/app/components/plugins/plugin-page/plugin-tasks/index.spec.tsx b/web/app/components/plugins/plugin-page/plugin-tasks/index.spec.tsx new file mode 100644 index 0000000000..32892cbe28 --- /dev/null +++ b/web/app/components/plugins/plugin-page/plugin-tasks/index.spec.tsx @@ -0,0 +1,856 @@ +import type { PluginStatus } from '@/app/components/plugins/types' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { TaskStatus } from '@/app/components/plugins/types' +// Import mocked modules +import { useMutationClearTaskPlugin, usePluginTaskList } from '@/service/use-plugins' +import PluginTaskList from './components/plugin-task-list' +import TaskStatusIndicator from './components/task-status-indicator' +import { usePluginTaskStatus } from './hooks' + +import PluginTasks from './index' + +// Mock external dependencies +vi.mock('@/service/use-plugins', () => ({ + usePluginTaskList: vi.fn(), + useMutationClearTaskPlugin: vi.fn(), +})) + +vi.mock('@/app/components/plugins/install-plugin/base/use-get-icon', () => ({ + default: () => ({ + getIconUrl: (icon: string) => `https://example.com/${icon}`, + }), +})) + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: () => 'en_US', +})) + +// Helper to create mock plugin +const createMockPlugin = (overrides: Partial = {}): PluginStatus => ({ + plugin_unique_identifier: `plugin-${Math.random().toString(36).substr(2, 9)}`, + plugin_id: 'test-plugin', + status: TaskStatus.running, + message: '', + icon: 'test-icon.png', + labels: { + en_US: 'Test Plugin', + zh_Hans: 'ๆต‹่ฏ•ๆ’ไปถ', + } as Record, + taskId: 'task-1', + ...overrides, +}) + +// Helper to setup mock hook returns +const setupMocks = (plugins: PluginStatus[] = []) => { + const mockMutateAsync = vi.fn().mockResolvedValue({}) + const mockHandleRefetch = vi.fn() + + vi.mocked(usePluginTaskList).mockReturnValue({ + 
pluginTasks: plugins.length > 0 + ? [{ id: 'task-1', plugins, created_at: '', updated_at: '', status: 'running', total_plugins: plugins.length, completed_plugins: 0 }] + : [], + handleRefetch: mockHandleRefetch, + } as any) + + vi.mocked(useMutationClearTaskPlugin).mockReturnValue({ + mutateAsync: mockMutateAsync, + } as any) + + return { mockMutateAsync, mockHandleRefetch } +} + +// ============================================================================ +// usePluginTaskStatus Hook Tests +// ============================================================================ +describe('usePluginTaskStatus Hook', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Plugin categorization', () => { + it('should categorize running plugins correctly', () => { + const runningPlugin = createMockPlugin({ status: TaskStatus.running }) + setupMocks([runningPlugin]) + + const TestComponent = () => { + const { runningPlugins, runningPluginsLength } = usePluginTaskStatus() + return ( +
+            <span data-testid='running-count'>{runningPluginsLength}</span>
+            <span data-testid='running-id'>{runningPlugins[0]?.plugin_unique_identifier}</span>
+
+ ) + } + + render() + + expect(screen.getByTestId('running-count')).toHaveTextContent('1') + expect(screen.getByTestId('running-id')).toHaveTextContent(runningPlugin.plugin_unique_identifier) + }) + + it('should categorize success plugins correctly', () => { + const successPlugin = createMockPlugin({ status: TaskStatus.success }) + setupMocks([successPlugin]) + + const TestComponent = () => { + const { successPlugins, successPluginsLength } = usePluginTaskStatus() + return ( +
+            <span data-testid='success-count'>{successPluginsLength}</span>
+            <span data-testid='success-id'>{successPlugins[0]?.plugin_unique_identifier}</span>
+
+ ) + } + + render() + + expect(screen.getByTestId('success-count')).toHaveTextContent('1') + expect(screen.getByTestId('success-id')).toHaveTextContent(successPlugin.plugin_unique_identifier) + }) + + it('should categorize error plugins correctly', () => { + const errorPlugin = createMockPlugin({ status: TaskStatus.failed, message: 'Install failed' }) + setupMocks([errorPlugin]) + + const TestComponent = () => { + const { errorPlugins, errorPluginsLength } = usePluginTaskStatus() + return ( +
+            <span data-testid='error-count'>{errorPluginsLength}</span>
+            <span data-testid='error-id'>{errorPlugins[0]?.plugin_unique_identifier}</span>
+
+ ) + } + + render() + + expect(screen.getByTestId('error-count')).toHaveTextContent('1') + expect(screen.getByTestId('error-id')).toHaveTextContent(errorPlugin.plugin_unique_identifier) + }) + + it('should categorize mixed plugins correctly', () => { + const plugins = [ + createMockPlugin({ status: TaskStatus.running, plugin_unique_identifier: 'running-1' }), + createMockPlugin({ status: TaskStatus.success, plugin_unique_identifier: 'success-1' }), + createMockPlugin({ status: TaskStatus.failed, plugin_unique_identifier: 'error-1' }), + ] + setupMocks(plugins) + + const TestComponent = () => { + const { runningPluginsLength, successPluginsLength, errorPluginsLength, totalPluginsLength } = usePluginTaskStatus() + return ( +
+            <span data-testid='running'>{runningPluginsLength}</span>
+            <span data-testid='success'>{successPluginsLength}</span>
+            <span data-testid='error'>{errorPluginsLength}</span>
+            <span data-testid='total'>{totalPluginsLength}</span>
+
+ ) + } + + render() + + expect(screen.getByTestId('running')).toHaveTextContent('1') + expect(screen.getByTestId('success')).toHaveTextContent('1') + expect(screen.getByTestId('error')).toHaveTextContent('1') + expect(screen.getByTestId('total')).toHaveTextContent('3') + }) + }) + + describe('Status flags', () => { + it('should set isInstalling when only running plugins exist', () => { + setupMocks([createMockPlugin({ status: TaskStatus.running })]) + + const TestComponent = () => { + const { isInstalling, isInstallingWithSuccess, isInstallingWithError, isSuccess, isFailed } = usePluginTaskStatus() + return ( +
+            <span data-testid='isInstalling'>{String(isInstalling)}</span>
+            <span data-testid='isInstallingWithSuccess'>{String(isInstallingWithSuccess)}</span>
+            <span data-testid='isInstallingWithError'>{String(isInstallingWithError)}</span>
+            <span data-testid='isSuccess'>{String(isSuccess)}</span>
+            <span data-testid='isFailed'>{String(isFailed)}</span>
+
+ ) + } + + render() + + expect(screen.getByTestId('isInstalling')).toHaveTextContent('true') + expect(screen.getByTestId('isInstallingWithSuccess')).toHaveTextContent('false') + expect(screen.getByTestId('isInstallingWithError')).toHaveTextContent('false') + expect(screen.getByTestId('isSuccess')).toHaveTextContent('false') + expect(screen.getByTestId('isFailed')).toHaveTextContent('false') + }) + + it('should set isInstallingWithSuccess when running and success plugins exist', () => { + setupMocks([ + createMockPlugin({ status: TaskStatus.running }), + createMockPlugin({ status: TaskStatus.success }), + ]) + + const TestComponent = () => { + const { isInstallingWithSuccess } = usePluginTaskStatus() + return {String(isInstallingWithSuccess)} + } + + render() + expect(screen.getByTestId('flag')).toHaveTextContent('true') + }) + + it('should set isInstallingWithError when running and error plugins exist', () => { + setupMocks([ + createMockPlugin({ status: TaskStatus.running }), + createMockPlugin({ status: TaskStatus.failed }), + ]) + + const TestComponent = () => { + const { isInstallingWithError } = usePluginTaskStatus() + return {String(isInstallingWithError)} + } + + render() + expect(screen.getByTestId('flag')).toHaveTextContent('true') + }) + + it('should set isSuccess when all plugins succeeded', () => { + setupMocks([ + createMockPlugin({ status: TaskStatus.success }), + createMockPlugin({ status: TaskStatus.success }), + ]) + + const TestComponent = () => { + const { isSuccess } = usePluginTaskStatus() + return {String(isSuccess)} + } + + render() + expect(screen.getByTestId('flag')).toHaveTextContent('true') + }) + + it('should set isFailed when no running plugins and some failed', () => { + setupMocks([ + createMockPlugin({ status: TaskStatus.success }), + createMockPlugin({ status: TaskStatus.failed }), + ]) + + const TestComponent = () => { + const { isFailed } = usePluginTaskStatus() + return {String(isFailed)} + } + + render() + 
expect(screen.getByTestId('flag')).toHaveTextContent('true') + }) + }) + + describe('handleClearErrorPlugin', () => { + it('should call mutateAsync and handleRefetch', async () => { + const { mockMutateAsync, mockHandleRefetch } = setupMocks([ + createMockPlugin({ status: TaskStatus.failed }), + ]) + + const TestComponent = () => { + const { handleClearErrorPlugin } = usePluginTaskStatus() + return ( + + ) + } + + render() + fireEvent.click(screen.getByRole('button')) + + await waitFor(() => { + expect(mockMutateAsync).toHaveBeenCalledWith({ + taskId: 'task-1', + pluginId: 'plugin-1', + }) + expect(mockHandleRefetch).toHaveBeenCalled() + }) + }) + }) +}) + +// ============================================================================ +// TaskStatusIndicator Component Tests +// ============================================================================ +describe('TaskStatusIndicator Component', () => { + const defaultProps = { + tip: 'Test tooltip', + isInstalling: false, + isInstallingWithSuccess: false, + isInstallingWithError: false, + isSuccess: false, + isFailed: false, + successPluginsLength: 0, + runningPluginsLength: 0, + totalPluginsLength: 1, + onClick: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing', () => { + render() + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should render with correct id', () => { + render() + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + }) + + describe('Icon display', () => { + it('should show downloading icon when installing', () => { + render() + // DownloadingIcon is rendered when isInstalling is true + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should show downloading icon when installing with error', () => { + render() + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should 
show install icon when not installing', () => { + render() + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + }) + + describe('Status badge', () => { + it('should show progress circle when installing', () => { + render( + , + ) + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should show progress circle when installing with success', () => { + render( + , + ) + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should show error progress circle when installing with error', () => { + render( + , + ) + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should show success icon when all completed successfully', () => { + render( + , + ) + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should show error icon when failed', () => { + render() + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + }) + + describe('Styling', () => { + it('should apply error styles when installing with error', () => { + render() + const trigger = document.getElementById('plugin-task-trigger') + expect(trigger).toHaveClass('bg-state-destructive-hover') + }) + + it('should apply error styles when failed', () => { + render() + const trigger = document.getElementById('plugin-task-trigger') + expect(trigger).toHaveClass('bg-state-destructive-hover') + }) + + it('should apply cursor-pointer when clickable', () => { + render() + const trigger = document.getElementById('plugin-task-trigger') + expect(trigger).toHaveClass('cursor-pointer') + }) + }) + + describe('User interactions', () => { + it('should call onClick when clicked', () => { + const handleClick = vi.fn() + render() + + fireEvent.click(document.getElementById('plugin-task-trigger')!) 
+ + expect(handleClick).toHaveBeenCalledTimes(1) + }) + }) +}) + +// ============================================================================ +// PluginTaskList Component Tests +// ============================================================================ +describe('PluginTaskList Component', () => { + const defaultProps = { + runningPlugins: [] as PluginStatus[], + successPlugins: [] as PluginStatus[], + errorPlugins: [] as PluginStatus[], + getIconUrl: (icon: string) => `https://example.com/${icon}`, + onClearAll: vi.fn(), + onClearErrors: vi.fn(), + onClearSingle: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should render without crashing with empty lists', () => { + render() + expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() + }) + + it('should render running plugins section when plugins exist', () => { + const runningPlugins = [createMockPlugin({ status: TaskStatus.running })] + render() + + // Translation key is returned as text in tests, multiple matches expected (title + status) + expect(screen.getAllByText(/task\.installing/i).length).toBeGreaterThan(0) + // Verify section container is rendered + expect(document.querySelector('.max-h-\\[200px\\]')).toBeInTheDocument() + }) + + it('should render success plugins section when plugins exist', () => { + const successPlugins = [createMockPlugin({ status: TaskStatus.success })] + render() + + // Translation key is returned as text in tests, multiple matches expected + expect(screen.getAllByText(/task\.installed/i).length).toBeGreaterThan(0) + }) + + it('should render error plugins section when plugins exist', () => { + const errorPlugins = [createMockPlugin({ status: TaskStatus.failed, message: 'Error occurred' })] + render() + + expect(screen.getByText('Error occurred')).toBeInTheDocument() + }) + + it('should render all sections when all types exist', () => { + render( + , + ) + + // All sections should be present + 
expect(document.querySelectorAll('.max-h-\\[200px\\]').length).toBe(3) + }) + }) + + describe('User interactions', () => { + it('should call onClearAll when clear all button is clicked in success section', () => { + const handleClearAll = vi.fn() + const successPlugins = [createMockPlugin({ status: TaskStatus.success })] + + render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: /task\.clearAll/i })) + + expect(handleClearAll).toHaveBeenCalledTimes(1) + }) + + it('should call onClearErrors when clear all button is clicked in error section', () => { + const handleClearErrors = vi.fn() + const errorPlugins = [createMockPlugin({ status: TaskStatus.failed })] + + render( + , + ) + + const clearButtons = screen.getAllByRole('button') + fireEvent.click(clearButtons.find(btn => btn.textContent?.includes('task.clearAll'))!) + + expect(handleClearErrors).toHaveBeenCalledTimes(1) + }) + + it('should call onClearSingle with correct args when individual clear is clicked', () => { + const handleClearSingle = vi.fn() + const errorPlugin = createMockPlugin({ + status: TaskStatus.failed, + plugin_unique_identifier: 'error-plugin-1', + taskId: 'task-123', + }) + + render( + , + ) + + // The individual clear button has the text 'operation.clear' + fireEvent.click(screen.getByRole('button', { name: /operation\.clear/i })) + + expect(handleClearSingle).toHaveBeenCalledWith('task-123', 'error-plugin-1') + }) + }) + + describe('Plugin display', () => { + it('should display plugin name from labels', () => { + const plugin = createMockPlugin({ + status: TaskStatus.running, + labels: { en_US: 'My Test Plugin' } as Record, + }) + + render() + + expect(screen.getByText('My Test Plugin')).toBeInTheDocument() + }) + + it('should display plugin message when available', () => { + const plugin = createMockPlugin({ + status: TaskStatus.success, + message: 'Successfully installed!', + }) + + render() + + expect(screen.getByText('Successfully installed!')).toBeInTheDocument() + }) + 
+ it('should display multiple plugins in each section', () => { + const runningPlugins = [ + createMockPlugin({ status: TaskStatus.running, labels: { en_US: 'Plugin A' } as Record }), + createMockPlugin({ status: TaskStatus.running, labels: { en_US: 'Plugin B' } as Record }), + ] + + render() + + expect(screen.getByText('Plugin A')).toBeInTheDocument() + expect(screen.getByText('Plugin B')).toBeInTheDocument() + // Count is rendered, verify multiple items are in list + expect(document.querySelectorAll('.hover\\:bg-state-base-hover').length).toBe(2) + }) + }) +}) + +// ============================================================================ +// PluginTasks Main Component Tests +// ============================================================================ +describe('PluginTasks Component', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should return null when no plugins exist', () => { + setupMocks([]) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + + it('should render when plugins exist', () => { + setupMocks([createMockPlugin({ status: TaskStatus.running })]) + + render() + + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + }) + + describe('Tooltip text (tip memoization)', () => { + it('should show installing tip when isInstalling', () => { + setupMocks([createMockPlugin({ status: TaskStatus.running })]) + + render() + + // The component renders with a tooltip, we verify it exists + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should show success tip when all succeeded', () => { + setupMocks([createMockPlugin({ status: TaskStatus.success })]) + + render() + + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should show error tip when some failed', () => { + setupMocks([ + createMockPlugin({ status: TaskStatus.success }), + createMockPlugin({ status: 
TaskStatus.failed }), + ]) + + render() + + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + }) + + describe('Popover interaction', () => { + it('should toggle popover when trigger is clicked and status allows', () => { + setupMocks([createMockPlugin({ status: TaskStatus.running })]) + + render() + + // Click to open + fireEvent.click(document.getElementById('plugin-task-trigger')!) + + // The popover content should be visible (PluginTaskList) + expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() + }) + + it('should not toggle when status does not allow', () => { + // Setup with no actionable status (edge case - should not happen in practice) + setupMocks([createMockPlugin({ status: TaskStatus.running })]) + + render() + + // Component should still render + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + }) + + describe('Clear handlers', () => { + it('should clear all completed plugins when onClearAll is called', async () => { + const { mockMutateAsync } = setupMocks([ + createMockPlugin({ status: TaskStatus.success, plugin_unique_identifier: 'success-1' }), + createMockPlugin({ status: TaskStatus.failed, plugin_unique_identifier: 'error-1' }), + ]) + + render() + + // Open popover + fireEvent.click(document.getElementById('plugin-task-trigger')!) 
+ + // Wait for popover content to render + await waitFor(() => { + expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() + }) + + // Find and click clear all button + const clearButtons = screen.getAllByRole('button') + const clearAllButton = clearButtons.find(btn => btn.textContent?.includes('clearAll')) + if (clearAllButton) + fireEvent.click(clearAllButton) + + // Verify mutateAsync was called for each completed plugin + await waitFor(() => { + expect(mockMutateAsync).toHaveBeenCalled() + }) + }) + + it('should clear only error plugins when onClearErrors is called', async () => { + const { mockMutateAsync } = setupMocks([ + createMockPlugin({ status: TaskStatus.failed, plugin_unique_identifier: 'error-1' }), + ]) + + render() + + // Open popover + fireEvent.click(document.getElementById('plugin-task-trigger')!) + + await waitFor(() => { + expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() + }) + + // Find and click the clear all button in error section + const clearButtons = screen.getAllByRole('button') + if (clearButtons.length > 0) + fireEvent.click(clearButtons[0]) + + await waitFor(() => { + expect(mockMutateAsync).toHaveBeenCalled() + }) + }) + + it('should clear single plugin when onClearSingle is called', async () => { + const { mockMutateAsync } = setupMocks([ + createMockPlugin({ + status: TaskStatus.failed, + plugin_unique_identifier: 'error-plugin', + taskId: 'task-1', + }), + ]) + + render() + + // Open popover + fireEvent.click(document.getElementById('plugin-task-trigger')!) 
+ + await waitFor(() => { + expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() + }) + + // Find and click individual clear button (usually the last one) + const clearButtons = screen.getAllByRole('button') + const individualClearButton = clearButtons[clearButtons.length - 1] + fireEvent.click(individualClearButton) + + await waitFor(() => { + expect(mockMutateAsync).toHaveBeenCalledWith({ + taskId: 'task-1', + pluginId: 'error-plugin', + }) + }) + }) + }) + + describe('Edge cases', () => { + it('should handle empty plugin tasks array', () => { + setupMocks([]) + + const { container } = render() + + expect(container.firstChild).toBeNull() + }) + + it('should handle single running plugin', () => { + setupMocks([createMockPlugin({ status: TaskStatus.running })]) + + render() + + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should handle many plugins', () => { + const manyPlugins = Array.from({ length: 10 }, (_, i) => + createMockPlugin({ + status: i % 3 === 0 ? TaskStatus.running : i % 3 === 1 ? TaskStatus.success : TaskStatus.failed, + plugin_unique_identifier: `plugin-${i}`, + })) + setupMocks(manyPlugins) + + render() + + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should handle plugins with empty labels', () => { + const plugin = createMockPlugin({ + status: TaskStatus.running, + labels: {} as Record, + }) + setupMocks([plugin]) + + render() + + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should handle plugins with long messages', () => { + const plugin = createMockPlugin({ + status: TaskStatus.failed, + message: 'A'.repeat(500), + }) + setupMocks([plugin]) + + render() + + // Open popover + fireEvent.click(document.getElementById('plugin-task-trigger')!) 
+ + expect(document.querySelector('.w-\\[360px\\]')).toBeInTheDocument() + }) + }) +}) + +// ============================================================================ +// Integration Tests +// ============================================================================ +describe('PluginTasks Integration', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should show correct UI flow from installing to success', async () => { + // Start with installing state + setupMocks([createMockPlugin({ status: TaskStatus.running })]) + + const { rerender } = render() + + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + + // Simulate completion by re-rendering with success + setupMocks([createMockPlugin({ status: TaskStatus.success })]) + rerender() + + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should show correct UI flow from installing to failure', async () => { + // Start with installing state + setupMocks([createMockPlugin({ status: TaskStatus.running })]) + + const { rerender } = render() + + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + + // Simulate failure by re-rendering with failed + setupMocks([createMockPlugin({ status: TaskStatus.failed, message: 'Network error' })]) + rerender() + + expect(document.getElementById('plugin-task-trigger')).toBeInTheDocument() + }) + + it('should handle mixed status during installation', () => { + setupMocks([ + createMockPlugin({ status: TaskStatus.running, plugin_unique_identifier: 'p1' }), + createMockPlugin({ status: TaskStatus.success, plugin_unique_identifier: 'p2' }), + createMockPlugin({ status: TaskStatus.failed, plugin_unique_identifier: 'p3' }), + ]) + + render() + + // Open popover + fireEvent.click(document.getElementById('plugin-task-trigger')!) 
+ + // All sections should be visible + const sections = document.querySelectorAll('.max-h-\\[200px\\]') + expect(sections.length).toBe(3) + }) +}) diff --git a/web/app/components/plugins/plugin-page/plugin-tasks/index.tsx b/web/app/components/plugins/plugin-page/plugin-tasks/index.tsx index 40dd4fedb1..45f1dce86b 100644 --- a/web/app/components/plugins/plugin-page/plugin-tasks/index.tsx +++ b/web/app/components/plugins/plugin-page/plugin-tasks/index.tsx @@ -1,33 +1,21 @@ -import { - RiCheckboxCircleFill, - RiErrorWarningFill, - RiInstallLine, - RiLoaderLine, -} from '@remixicon/react' import { useCallback, useMemo, useState, } from 'react' import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import ProgressCircle from '@/app/components/base/progress-bar/progress-circle' -import Tooltip from '@/app/components/base/tooltip' -import DownloadingIcon from '@/app/components/header/plugins-nav/downloading-icon' -import CardIcon from '@/app/components/plugins/card/base/card-icon' import useGetIcon from '@/app/components/plugins/install-plugin/base/use-get-icon' -import { useGetLanguage } from '@/context/i18n' -import { cn } from '@/utils/classnames' +import PluginTaskList from './components/plugin-task-list' +import TaskStatusIndicator from './components/task-status-indicator' import { usePluginTaskStatus } from './hooks' const PluginTasks = () => { const { t } = useTranslation() - const language = useGetLanguage() const [open, setOpen] = useState(false) const { errorPlugins, @@ -46,35 +34,7 @@ const PluginTasks = () => { } = usePluginTaskStatus() const { getIconUrl } = useGetIcon() - const handleClearAllWithModal = useCallback(async () => { - // Clear all completed plugins (success and error) but keep running ones - const completedPlugins = [...successPlugins, ...errorPlugins] - - // Clear 
all completed plugins individually - for (const plugin of completedPlugins) - await handleClearErrorPlugin(plugin.taskId, plugin.plugin_unique_identifier) - - // Only close modal if no plugins are still installing - if (runningPluginsLength === 0) - setOpen(false) - }, [successPlugins, errorPlugins, handleClearErrorPlugin, runningPluginsLength]) - - const handleClearErrorsWithModal = useCallback(async () => { - // Clear only error plugins, not all plugins - for (const plugin of errorPlugins) - await handleClearErrorPlugin(plugin.taskId, plugin.plugin_unique_identifier) - // Only close modal if no plugins are still installing - if (runningPluginsLength === 0) - setOpen(false) - }, [errorPlugins, handleClearErrorPlugin, runningPluginsLength]) - - const handleClearSingleWithModal = useCallback(async (taskId: string, pluginId: string) => { - await handleClearErrorPlugin(taskId, pluginId) - // Only close modal if no plugins are still installing - if (runningPluginsLength === 0) - setOpen(false) - }, [handleClearErrorPlugin, runningPluginsLength]) - + // Generate tooltip text based on status const tip = useMemo(() => { if (isInstallingWithError) return t('task.installingWithError', { ns: 'plugin', installingLength: runningPluginsLength, successLength: successPluginsLength, errorLength: errorPluginsLength }) @@ -99,8 +59,38 @@ const PluginTasks = () => { t, ]) - // Show icon if there are any plugin tasks (completed, running, or failed) - // Only hide when there are absolutely no plugin tasks + // Generic clear function that handles clearing and modal closing + const clearPluginsAndClose = useCallback(async ( + plugins: Array<{ taskId: string, plugin_unique_identifier: string }>, + ) => { + for (const plugin of plugins) + await handleClearErrorPlugin(plugin.taskId, plugin.plugin_unique_identifier) + if (runningPluginsLength === 0) + setOpen(false) + }, [handleClearErrorPlugin, runningPluginsLength]) + + // Clear handlers using the generic function + const handleClearAll = 
useCallback( + () => clearPluginsAndClose([...successPlugins, ...errorPlugins]), + [clearPluginsAndClose, successPlugins, errorPlugins], + ) + + const handleClearErrors = useCallback( + () => clearPluginsAndClose(errorPlugins), + [clearPluginsAndClose, errorPlugins], + ) + + const handleClearSingle = useCallback( + (taskId: string, pluginId: string) => clearPluginsAndClose([{ taskId, plugin_unique_identifier: pluginId }]), + [clearPluginsAndClose], + ) + + const handleTriggerClick = useCallback(() => { + if (isFailed || isInstalling || isInstallingWithSuccess || isInstallingWithError || isSuccess) + setOpen(v => !v) + }, [isFailed, isInstalling, isInstallingWithSuccess, isInstallingWithError, isSuccess]) + + // Hide when no plugin tasks if (totalPluginsLength === 0) return null @@ -115,206 +105,30 @@ const PluginTasks = () => { crossAxis: 79, }} > - { - if (isFailed || isInstalling || isInstallingWithSuccess || isInstallingWithError || isSuccess) - setOpen(v => !v) - }} - > - -
- { - (isInstalling || isInstallingWithError) && ( - - ) - } - { - !(isInstalling || isInstallingWithError) && ( - - ) - } -
- { - (isInstalling || isInstallingWithSuccess) && ( - - ) - } - { - isInstallingWithError && ( - - ) - } - { - (isSuccess || (successPluginsLength > 0 && runningPluginsLength === 0 && errorPluginsLength === 0)) && ( - - ) - } - { - isFailed && ( - - ) - } -
-
-
+ + {}} + /> -
- {/* Running Plugins */} - {runningPlugins.length > 0 && ( - <> -
- {t('task.installing', { ns: 'plugin' })} - {' '} - ( - {runningPlugins.length} - ) -
-
- {runningPlugins.map(runningPlugin => ( -
-
- - -
-
-
- {runningPlugin.labels[language]} -
-
- {t('task.installing', { ns: 'plugin' })} -
-
-
- ))} -
- - )} - - {/* Success Plugins */} - {successPlugins.length > 0 && ( - <> -
- {t('task.installed', { ns: 'plugin' })} - {' '} - ( - {successPlugins.length} - ) - -
-
- {successPlugins.map(successPlugin => ( -
-
- - -
-
-
- {successPlugin.labels[language]} -
-
- {successPlugin.message || t('task.installed', { ns: 'plugin' })} -
-
-
- ))} -
- - )} - - {/* Error Plugins */} - {errorPlugins.length > 0 && ( - <> -
- {t('task.installError', { ns: 'plugin', errorLength: errorPlugins.length })} - -
-
- {errorPlugins.map(errorPlugin => ( -
-
- - -
-
-
- {errorPlugin.labels[language]} -
-
- {errorPlugin.message} -
-
- -
- ))} -
- - )} -
+
diff --git a/web/app/components/plugins/plugin-page/use-reference-setting.spec.ts b/web/app/components/plugins/plugin-page/use-reference-setting.spec.ts new file mode 100644 index 0000000000..9f64d3fac5 --- /dev/null +++ b/web/app/components/plugins/plugin-page/use-reference-setting.spec.ts @@ -0,0 +1,388 @@ +import { renderHook, waitFor } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +// Import mocks for assertions +import { useAppContext } from '@/context/app-context' +import { useGlobalPublicStore } from '@/context/global-public-context' + +import { useInvalidateReferenceSettings, useMutationReferenceSettings, useReferenceSettings } from '@/service/use-plugins' +import Toast from '../../base/toast' +import { PermissionType } from '../types' +import useReferenceSetting, { useCanInstallPluginFromMarketplace } from './use-reference-setting' + +// Mock dependencies +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +vi.mock('@/context/app-context', () => ({ + useAppContext: vi.fn(), +})) + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: vi.fn(), +})) + +vi.mock('@/service/use-plugins', () => ({ + useReferenceSettings: vi.fn(), + useMutationReferenceSettings: vi.fn(), + useInvalidateReferenceSettings: vi.fn(), +})) + +vi.mock('../../base/toast', () => ({ + default: { + notify: vi.fn(), + }, +})) + +describe('useReferenceSetting Hook', () => { + beforeEach(() => { + vi.clearAllMocks() + + // Default mocks + vi.mocked(useAppContext).mockReturnValue({ + isCurrentWorkspaceManager: false, + isCurrentWorkspaceOwner: false, + } as ReturnType) + + vi.mocked(useReferenceSettings).mockReturnValue({ + data: { + permission: { + install_permission: PermissionType.everyone, + debug_permission: PermissionType.everyone, + }, + }, + } as ReturnType) + + vi.mocked(useMutationReferenceSettings).mockReturnValue({ + mutate: vi.fn(), + isPending: false, + } as 
unknown as ReturnType) + + vi.mocked(useInvalidateReferenceSettings).mockReturnValue(vi.fn()) + }) + + describe('hasPermission logic', () => { + it('should return false when permission is undefined', () => { + vi.mocked(useReferenceSettings).mockReturnValue({ + data: { + permission: { + install_permission: undefined, + debug_permission: undefined, + }, + }, + } as unknown as ReturnType) + + const { result } = renderHook(() => useReferenceSetting()) + + expect(result.current.canManagement).toBe(false) + expect(result.current.canDebugger).toBe(false) + }) + + it('should return false when permission is noOne', () => { + vi.mocked(useReferenceSettings).mockReturnValue({ + data: { + permission: { + install_permission: PermissionType.noOne, + debug_permission: PermissionType.noOne, + }, + }, + } as ReturnType) + + const { result } = renderHook(() => useReferenceSetting()) + + expect(result.current.canManagement).toBe(false) + expect(result.current.canDebugger).toBe(false) + }) + + it('should return true when permission is everyone', () => { + vi.mocked(useReferenceSettings).mockReturnValue({ + data: { + permission: { + install_permission: PermissionType.everyone, + debug_permission: PermissionType.everyone, + }, + }, + } as ReturnType) + + const { result } = renderHook(() => useReferenceSetting()) + + expect(result.current.canManagement).toBe(true) + expect(result.current.canDebugger).toBe(true) + }) + + it('should return isAdmin when permission is admin and user is manager', () => { + vi.mocked(useAppContext).mockReturnValue({ + isCurrentWorkspaceManager: true, + isCurrentWorkspaceOwner: false, + } as ReturnType) + + vi.mocked(useReferenceSettings).mockReturnValue({ + data: { + permission: { + install_permission: PermissionType.admin, + debug_permission: PermissionType.admin, + }, + }, + } as ReturnType) + + const { result } = renderHook(() => useReferenceSetting()) + + expect(result.current.canManagement).toBe(true) + expect(result.current.canDebugger).toBe(true) + }) 
+ + it('should return isAdmin when permission is admin and user is owner', () => { + vi.mocked(useAppContext).mockReturnValue({ + isCurrentWorkspaceManager: false, + isCurrentWorkspaceOwner: true, + } as ReturnType) + + vi.mocked(useReferenceSettings).mockReturnValue({ + data: { + permission: { + install_permission: PermissionType.admin, + debug_permission: PermissionType.admin, + }, + }, + } as ReturnType) + + const { result } = renderHook(() => useReferenceSetting()) + + expect(result.current.canManagement).toBe(true) + expect(result.current.canDebugger).toBe(true) + }) + + it('should return false when permission is admin and user is not admin', () => { + vi.mocked(useAppContext).mockReturnValue({ + isCurrentWorkspaceManager: false, + isCurrentWorkspaceOwner: false, + } as ReturnType) + + vi.mocked(useReferenceSettings).mockReturnValue({ + data: { + permission: { + install_permission: PermissionType.admin, + debug_permission: PermissionType.admin, + }, + }, + } as ReturnType) + + const { result } = renderHook(() => useReferenceSetting()) + + expect(result.current.canManagement).toBe(false) + expect(result.current.canDebugger).toBe(false) + }) + }) + + describe('canSetPermissions', () => { + it('should be true when user is workspace manager', () => { + vi.mocked(useAppContext).mockReturnValue({ + isCurrentWorkspaceManager: true, + isCurrentWorkspaceOwner: false, + } as ReturnType) + + const { result } = renderHook(() => useReferenceSetting()) + + expect(result.current.canSetPermissions).toBe(true) + }) + + it('should be true when user is workspace owner', () => { + vi.mocked(useAppContext).mockReturnValue({ + isCurrentWorkspaceManager: false, + isCurrentWorkspaceOwner: true, + } as ReturnType) + + const { result } = renderHook(() => useReferenceSetting()) + + expect(result.current.canSetPermissions).toBe(true) + }) + + it('should be false when user is neither manager nor owner', () => { + vi.mocked(useAppContext).mockReturnValue({ + isCurrentWorkspaceManager: 
false, + isCurrentWorkspaceOwner: false, + } as ReturnType) + + const { result } = renderHook(() => useReferenceSetting()) + + expect(result.current.canSetPermissions).toBe(false) + }) + }) + + describe('setReferenceSettings callback', () => { + it('should call invalidateReferenceSettings and show toast on success', async () => { + const mockInvalidate = vi.fn() + vi.mocked(useInvalidateReferenceSettings).mockReturnValue(mockInvalidate) + + let onSuccessCallback: (() => void) | undefined + vi.mocked(useMutationReferenceSettings).mockImplementation((options) => { + onSuccessCallback = options?.onSuccess as () => void + return { + mutate: vi.fn(), + isPending: false, + } as unknown as ReturnType + }) + + renderHook(() => useReferenceSetting()) + + // Trigger the onSuccess callback + if (onSuccessCallback) + onSuccessCallback() + + await waitFor(() => { + expect(mockInvalidate).toHaveBeenCalled() + expect(Toast.notify).toHaveBeenCalledWith({ + type: 'success', + message: 'api.actionSuccess', + }) + }) + }) + }) + + describe('returned values', () => { + it('should return referenceSetting data', () => { + const mockData = { + permission: { + install_permission: PermissionType.everyone, + debug_permission: PermissionType.everyone, + }, + } + vi.mocked(useReferenceSettings).mockReturnValue({ + data: mockData, + } as ReturnType) + + const { result } = renderHook(() => useReferenceSetting()) + + expect(result.current.referenceSetting).toEqual(mockData) + }) + + it('should return isUpdatePending from mutation', () => { + vi.mocked(useMutationReferenceSettings).mockReturnValue({ + mutate: vi.fn(), + isPending: true, + } as unknown as ReturnType) + + const { result } = renderHook(() => useReferenceSetting()) + + expect(result.current.isUpdatePending).toBe(true) + }) + + it('should handle null data', () => { + vi.mocked(useReferenceSettings).mockReturnValue({ + data: null, + } as unknown as ReturnType) + + const { result } = renderHook(() => useReferenceSetting()) + + 
expect(result.current.canManagement).toBe(false) + expect(result.current.canDebugger).toBe(false) + }) + }) +}) + +describe('useCanInstallPluginFromMarketplace Hook', () => { + beforeEach(() => { + vi.clearAllMocks() + + vi.mocked(useAppContext).mockReturnValue({ + isCurrentWorkspaceManager: true, + isCurrentWorkspaceOwner: false, + } as ReturnType) + + vi.mocked(useReferenceSettings).mockReturnValue({ + data: { + permission: { + install_permission: PermissionType.everyone, + debug_permission: PermissionType.everyone, + }, + }, + } as ReturnType) + + vi.mocked(useMutationReferenceSettings).mockReturnValue({ + mutate: vi.fn(), + isPending: false, + } as unknown as ReturnType) + + vi.mocked(useInvalidateReferenceSettings).mockReturnValue(vi.fn()) + }) + + it('should return true when marketplace is enabled and canManagement is true', () => { + vi.mocked(useGlobalPublicStore).mockImplementation((selector) => { + const state = { + systemFeatures: { + enable_marketplace: true, + }, + } + return selector(state as Parameters[0]) + }) + + const { result } = renderHook(() => useCanInstallPluginFromMarketplace()) + + expect(result.current.canInstallPluginFromMarketplace).toBe(true) + }) + + it('should return false when marketplace is disabled', () => { + vi.mocked(useGlobalPublicStore).mockImplementation((selector) => { + const state = { + systemFeatures: { + enable_marketplace: false, + }, + } + return selector(state as Parameters[0]) + }) + + const { result } = renderHook(() => useCanInstallPluginFromMarketplace()) + + expect(result.current.canInstallPluginFromMarketplace).toBe(false) + }) + + it('should return false when canManagement is false', () => { + vi.mocked(useGlobalPublicStore).mockImplementation((selector) => { + const state = { + systemFeatures: { + enable_marketplace: true, + }, + } + return selector(state as Parameters[0]) + }) + + vi.mocked(useReferenceSettings).mockReturnValue({ + data: { + permission: { + install_permission: PermissionType.noOne, + 
debug_permission: PermissionType.noOne, + }, + }, + } as ReturnType) + + const { result } = renderHook(() => useCanInstallPluginFromMarketplace()) + + expect(result.current.canInstallPluginFromMarketplace).toBe(false) + }) + + it('should return false when both marketplace is disabled and canManagement is false', () => { + vi.mocked(useGlobalPublicStore).mockImplementation((selector) => { + const state = { + systemFeatures: { + enable_marketplace: false, + }, + } + return selector(state as Parameters[0]) + }) + + vi.mocked(useReferenceSettings).mockReturnValue({ + data: { + permission: { + install_permission: PermissionType.noOne, + debug_permission: PermissionType.noOne, + }, + }, + } as ReturnType) + + const { result } = renderHook(() => useCanInstallPluginFromMarketplace()) + + expect(result.current.canInstallPluginFromMarketplace).toBe(false) + }) +}) diff --git a/web/app/components/plugins/plugin-page/use-uploader.spec.ts b/web/app/components/plugins/plugin-page/use-uploader.spec.ts new file mode 100644 index 0000000000..fa9463b7c0 --- /dev/null +++ b/web/app/components/plugins/plugin-page/use-uploader.spec.ts @@ -0,0 +1,487 @@ +import type { RefObject } from 'react' +import { act, renderHook } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { useUploader } from './use-uploader' + +describe('useUploader Hook', () => { + let mockContainerRef: RefObject + let mockOnFileChange: (file: File | null) => void + let mockContainer: HTMLDivElement + + beforeEach(() => { + vi.clearAllMocks() + + mockContainer = document.createElement('div') + document.body.appendChild(mockContainer) + + mockContainerRef = { current: mockContainer } + mockOnFileChange = vi.fn() + }) + + afterEach(() => { + if (mockContainer.parentNode) + document.body.removeChild(mockContainer) + }) + + describe('Initial State', () => { + it('should return initial state with dragging false', () => { + const { result } = renderHook(() => + useUploader({ + 
onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + expect(result.current.dragging).toBe(false) + expect(result.current.fileUploader.current).toBeNull() + expect(result.current.fileChangeHandle).not.toBeNull() + expect(result.current.removeFile).not.toBeNull() + }) + + it('should return null handlers when disabled', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + enabled: false, + }), + ) + + expect(result.current.dragging).toBe(false) + expect(result.current.fileChangeHandle).toBeNull() + expect(result.current.removeFile).toBeNull() + }) + }) + + describe('Drag Events', () => { + it('should handle dragenter and set dragging to true', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + Object.defineProperty(dragEnterEvent, 'dataTransfer', { + value: { types: ['Files'] }, + }) + + act(() => { + mockContainer.dispatchEvent(dragEnterEvent) + }) + + expect(result.current.dragging).toBe(true) + }) + + it('should not set dragging when dragenter without Files type', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + Object.defineProperty(dragEnterEvent, 'dataTransfer', { + value: { types: ['text/plain'] }, + }) + + act(() => { + mockContainer.dispatchEvent(dragEnterEvent) + }) + + expect(result.current.dragging).toBe(false) + }) + + it('should handle dragover event', () => { + renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + const dragOverEvent = new Event('dragover', { bubbles: true, cancelable: true }) + + act(() => { + 
mockContainer.dispatchEvent(dragOverEvent) + }) + + // dragover should prevent default and stop propagation + expect(mockContainer).toBeInTheDocument() + }) + + it('should handle dragleave when relatedTarget is null', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + // First set dragging to true + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + Object.defineProperty(dragEnterEvent, 'dataTransfer', { + value: { types: ['Files'] }, + }) + act(() => { + mockContainer.dispatchEvent(dragEnterEvent) + }) + expect(result.current.dragging).toBe(true) + + // Then trigger dragleave with null relatedTarget + const dragLeaveEvent = new Event('dragleave', { bubbles: true, cancelable: true }) + Object.defineProperty(dragLeaveEvent, 'relatedTarget', { + value: null, + }) + + act(() => { + mockContainer.dispatchEvent(dragLeaveEvent) + }) + + expect(result.current.dragging).toBe(false) + }) + + it('should handle dragleave when relatedTarget is outside container', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + // First set dragging to true + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + Object.defineProperty(dragEnterEvent, 'dataTransfer', { + value: { types: ['Files'] }, + }) + act(() => { + mockContainer.dispatchEvent(dragEnterEvent) + }) + expect(result.current.dragging).toBe(true) + + // Create element outside container + const outsideElement = document.createElement('div') + document.body.appendChild(outsideElement) + + // Trigger dragleave with relatedTarget outside container + const dragLeaveEvent = new Event('dragleave', { bubbles: true, cancelable: true }) + Object.defineProperty(dragLeaveEvent, 'relatedTarget', { + value: outsideElement, + }) + + act(() => { + mockContainer.dispatchEvent(dragLeaveEvent) + }) + 
+ expect(result.current.dragging).toBe(false) + document.body.removeChild(outsideElement) + }) + + it('should not set dragging to false when relatedTarget is inside container', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + // First set dragging to true + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + Object.defineProperty(dragEnterEvent, 'dataTransfer', { + value: { types: ['Files'] }, + }) + act(() => { + mockContainer.dispatchEvent(dragEnterEvent) + }) + expect(result.current.dragging).toBe(true) + + // Create element inside container + const insideElement = document.createElement('div') + mockContainer.appendChild(insideElement) + + // Trigger dragleave with relatedTarget inside container + const dragLeaveEvent = new Event('dragleave', { bubbles: true, cancelable: true }) + Object.defineProperty(dragLeaveEvent, 'relatedTarget', { + value: insideElement, + }) + + act(() => { + mockContainer.dispatchEvent(dragLeaveEvent) + }) + + // Should still be dragging since relatedTarget is inside container + expect(result.current.dragging).toBe(true) + }) + }) + + describe('Drop Events', () => { + it('should handle drop event with files', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + // First set dragging to true + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + Object.defineProperty(dragEnterEvent, 'dataTransfer', { + value: { types: ['Files'] }, + }) + act(() => { + mockContainer.dispatchEvent(dragEnterEvent) + }) + + // Create mock file + const file = new File(['content'], 'test.difypkg', { type: 'application/octet-stream' }) + + // Trigger drop event + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { files: [file] }, 
+ }) + + act(() => { + mockContainer.dispatchEvent(dropEvent) + }) + + expect(result.current.dragging).toBe(false) + expect(mockOnFileChange).toHaveBeenCalledWith(file) + }) + + it('should not call onFileChange when drop has no dataTransfer', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + // Set dragging first + const dragEnterEvent = new Event('dragenter', { bubbles: true, cancelable: true }) + Object.defineProperty(dragEnterEvent, 'dataTransfer', { + value: { types: ['Files'] }, + }) + act(() => { + mockContainer.dispatchEvent(dragEnterEvent) + }) + + // Drop without dataTransfer + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) + // No dataTransfer property + + act(() => { + mockContainer.dispatchEvent(dropEvent) + }) + + expect(result.current.dragging).toBe(false) + expect(mockOnFileChange).not.toHaveBeenCalled() + }) + + it('should not call onFileChange when drop has empty files array', () => { + renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + const dropEvent = new Event('drop', { bubbles: true, cancelable: true }) + Object.defineProperty(dropEvent, 'dataTransfer', { + value: { files: [] }, + }) + + act(() => { + mockContainer.dispatchEvent(dropEvent) + }) + + expect(mockOnFileChange).not.toHaveBeenCalled() + }) + }) + + describe('File Change Handler', () => { + it('should call onFileChange with file from input', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + const file = new File(['content'], 'test.difypkg', { type: 'application/octet-stream' }) + const mockEvent = { + target: { + files: [file], + }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle?.(mockEvent) + }) + + expect(mockOnFileChange).toHaveBeenCalledWith(file) + }) + + it('should 
call onFileChange with null when no files', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + const mockEvent = { + target: { + files: null, + }, + } as unknown as React.ChangeEvent + + act(() => { + result.current.fileChangeHandle?.(mockEvent) + }) + + expect(mockOnFileChange).toHaveBeenCalledWith(null) + }) + }) + + describe('Remove File', () => { + it('should call onFileChange with null', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + act(() => { + result.current.removeFile?.() + }) + + expect(mockOnFileChange).toHaveBeenCalledWith(null) + }) + + it('should handle removeFile when fileUploader has a value', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + // Create a mock input element with value property + const mockInput = { + value: 'test.difypkg', + } + + // Override the fileUploader ref + Object.defineProperty(result.current.fileUploader, 'current', { + value: mockInput, + writable: true, + }) + + act(() => { + result.current.removeFile?.() + }) + + expect(mockOnFileChange).toHaveBeenCalledWith(null) + expect(mockInput.value).toBe('') + }) + + it('should handle removeFile when fileUploader is null', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + }), + ) + + // fileUploader.current is null by default + act(() => { + result.current.removeFile?.() + }) + + expect(mockOnFileChange).toHaveBeenCalledWith(null) + }) + }) + + describe('Enabled/Disabled State', () => { + it('should not add event listeners when disabled', () => { + const addEventListenerSpy = vi.spyOn(mockContainer, 'addEventListener') + + renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: 
mockContainerRef, + enabled: false, + }), + ) + + expect(addEventListenerSpy).not.toHaveBeenCalled() + }) + + it('should add event listeners when enabled', () => { + const addEventListenerSpy = vi.spyOn(mockContainer, 'addEventListener') + + renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + enabled: true, + }), + ) + + expect(addEventListenerSpy).toHaveBeenCalledWith('dragenter', expect.any(Function)) + expect(addEventListenerSpy).toHaveBeenCalledWith('dragover', expect.any(Function)) + expect(addEventListenerSpy).toHaveBeenCalledWith('dragleave', expect.any(Function)) + expect(addEventListenerSpy).toHaveBeenCalledWith('drop', expect.any(Function)) + }) + + it('should remove event listeners on cleanup', () => { + const removeEventListenerSpy = vi.spyOn(mockContainer, 'removeEventListener') + + const { unmount } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + enabled: true, + }), + ) + + unmount() + + expect(removeEventListenerSpy).toHaveBeenCalledWith('dragenter', expect.any(Function)) + expect(removeEventListenerSpy).toHaveBeenCalledWith('dragover', expect.any(Function)) + expect(removeEventListenerSpy).toHaveBeenCalledWith('dragleave', expect.any(Function)) + expect(removeEventListenerSpy).toHaveBeenCalledWith('drop', expect.any(Function)) + }) + + it('should return false for dragging when disabled', () => { + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: mockContainerRef, + enabled: false, + }), + ) + + expect(result.current.dragging).toBe(false) + }) + }) + + describe('Container Ref Edge Cases', () => { + it('should handle null containerRef.current', () => { + const nullRef: RefObject = { current: null } + + const { result } = renderHook(() => + useUploader({ + onFileChange: mockOnFileChange, + containerRef: nullRef, + }), + ) + + expect(result.current.dragging).toBe(false) + }) + }) +}) diff --git 
a/web/app/components/rag-pipeline/index.spec.tsx b/web/app/components/rag-pipeline/index.spec.tsx new file mode 100644 index 0000000000..5adfc828cf --- /dev/null +++ b/web/app/components/rag-pipeline/index.spec.tsx @@ -0,0 +1,550 @@ +import type { FetchWorkflowDraftResponse } from '@/types/workflow' +import { cleanup, render, screen } from '@testing-library/react' +import * as React from 'react' +import { BlockEnum } from '@/app/components/workflow/types' + +// Import real utility functions (pure functions, no side effects) + +// Import mocked modules for manipulation +import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' +import { usePipelineInit } from './hooks' +import RagPipelineWrapper from './index' +import { processNodesWithoutDataSource } from './utils' + +// Mock: Context - need to control return values +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: vi.fn(), +})) + +// Mock: Hook with API calls +vi.mock('./hooks', () => ({ + usePipelineInit: vi.fn(), +})) + +// Mock: Store creator +vi.mock('./store', () => ({ + createRagPipelineSliceSlice: vi.fn(() => ({})), +})) + +// Mock: Utility with complex workflow dependencies (generateNewNode, etc.) +vi.mock('./utils', () => ({ + processNodesWithoutDataSource: vi.fn((nodes, viewport) => ({ + nodes, + viewport, + })), +})) + +// Mock: Complex component with useParams, Toast, API calls +vi.mock('./components/conversion', () => ({ + default: () =>
Conversion Component
, +})) + +// Mock: Complex component with many hooks and workflow dependencies +vi.mock('./components/rag-pipeline-main', () => ({ + default: ({ nodes, edges, viewport }: any) => ( +
+ {nodes?.length ?? 0} + {edges?.length ?? 0} + {viewport?.zoom ?? 'none'} +
+ ), +})) + +// Mock: Complex component with ReactFlow and many providers +vi.mock('@/app/components/workflow', () => ({ + default: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})) + +// Mock: Context provider +vi.mock('@/app/components/workflow/context', () => ({ + WorkflowContextProvider: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})) + +// Type assertions for mocked functions +const mockUseDatasetDetailContextWithSelector = vi.mocked(useDatasetDetailContextWithSelector) +const mockUsePipelineInit = vi.mocked(usePipelineInit) +const mockProcessNodesWithoutDataSource = vi.mocked(processNodesWithoutDataSource) + +// Helper to mock selector with actual execution (increases function coverage) +// This executes the real selector function: s => s.dataset?.pipeline_id +const mockSelectorWithDataset = (pipelineId: string | null | undefined) => { + mockUseDatasetDetailContextWithSelector.mockImplementation((selector: (state: any) => any) => { + const mockState = { dataset: pipelineId ? { pipeline_id: pipelineId } : null } + return selector(mockState) + }) +} + +// Test data factory +const createMockWorkflowData = (overrides?: Partial): FetchWorkflowDraftResponse => ({ + graph: { + nodes: [ + { id: 'node-1', type: 'custom', data: { type: BlockEnum.Start, title: 'Start' }, position: { x: 100, y: 100 } }, + { id: 'node-2', type: 'custom', data: { type: BlockEnum.End, title: 'End' }, position: { x: 300, y: 100 } }, + ], + edges: [ + { id: 'edge-1', source: 'node-1', target: 'node-2', type: 'custom' }, + ], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'test-hash-123', + updated_at: 1234567890, + tool_published: false, + environment_variables: [], + ...overrides, +} as FetchWorkflowDraftResponse) + +afterEach(() => { + cleanup() + vi.clearAllMocks() +}) + +describe('RagPipelineWrapper', () => { + describe('Rendering', () => { + it('should render Conversion component when pipelineId is null', () => { + mockSelectorWithDataset(null) + mockUsePipelineInit.mockReturnValue({ data: undefined, isLoading: false }) + + render() + + expect(screen.getByTestId('conversion-component')).toBeInTheDocument() + expect(screen.queryByTestId('workflow-context-provider')).not.toBeInTheDocument() + }) + + it('should render Conversion component when pipelineId is undefined', () => { + 
mockSelectorWithDataset(undefined) + mockUsePipelineInit.mockReturnValue({ data: undefined, isLoading: false }) + + render() + + expect(screen.getByTestId('conversion-component')).toBeInTheDocument() + }) + + it('should render Conversion component when pipelineId is empty string', () => { + mockSelectorWithDataset('') + mockUsePipelineInit.mockReturnValue({ data: undefined, isLoading: false }) + + render() + + expect(screen.getByTestId('conversion-component')).toBeInTheDocument() + }) + + it('should render WorkflowContextProvider when pipelineId exists', () => { + mockSelectorWithDataset('pipeline-123') + mockUsePipelineInit.mockReturnValue({ data: undefined, isLoading: true }) + + render() + + expect(screen.getByTestId('workflow-context-provider')).toBeInTheDocument() + expect(screen.queryByTestId('conversion-component')).not.toBeInTheDocument() + }) + }) + + describe('Props Variations', () => { + it('should pass injectWorkflowStoreSliceFn to WorkflowContextProvider', () => { + mockSelectorWithDataset('pipeline-456') + mockUsePipelineInit.mockReturnValue({ data: undefined, isLoading: true }) + + render() + + expect(screen.getByTestId('workflow-context-provider')).toBeInTheDocument() + }) + }) +}) + +describe('RagPipeline', () => { + beforeEach(() => { + // Default setup for RagPipeline tests - execute real selector function + mockSelectorWithDataset('pipeline-123') + }) + + describe('Loading State', () => { + it('should render Loading component when isLoading is true', () => { + mockUsePipelineInit.mockReturnValue({ data: undefined, isLoading: true }) + + render() + + // Real Loading component has role="status" + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('should render Loading component when data is undefined', () => { + mockUsePipelineInit.mockReturnValue({ data: undefined, isLoading: false }) + + render() + + expect(screen.getByRole('status')).toBeInTheDocument() + }) + + it('should render Loading component when both data is undefined 
and isLoading is true', () => { + mockUsePipelineInit.mockReturnValue({ data: undefined, isLoading: true }) + + render() + + expect(screen.getByRole('status')).toBeInTheDocument() + }) + }) + + describe('Data Loaded State', () => { + it('should render RagPipelineMain when data is loaded', () => { + const mockData = createMockWorkflowData() + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + expect(screen.getByTestId('rag-pipeline-main')).toBeInTheDocument() + expect(screen.queryByTestId('loading-component')).not.toBeInTheDocument() + }) + + it('should pass processed nodes to RagPipelineMain', () => { + const mockData = createMockWorkflowData() + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + expect(screen.getByTestId('nodes-count').textContent).toBe('2') + }) + + it('should pass edges to RagPipelineMain', () => { + const mockData = createMockWorkflowData() + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + expect(screen.getByTestId('edges-count').textContent).toBe('1') + }) + + it('should pass viewport to RagPipelineMain', () => { + const mockData = createMockWorkflowData({ + graph: { + nodes: [], + edges: [], + viewport: { x: 100, y: 200, zoom: 1.5 }, + }, + }) + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + expect(screen.getByTestId('viewport-zoom').textContent).toBe('1.5') + }) + }) + + describe('Memoization Logic', () => { + it('should process nodes through initialNodes when data is loaded', () => { + const mockData = createMockWorkflowData() + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + // initialNodes is a real function - verify nodes are rendered + // The real initialNodes processes nodes and adds position data + expect(screen.getByTestId('rag-pipeline-main')).toBeInTheDocument() + }) + + it('should process edges through initialEdges when data 
is loaded', () => { + const mockData = createMockWorkflowData() + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + // initialEdges is a real function - verify component renders with edges + expect(screen.getByTestId('edges-count').textContent).toBe('1') + }) + + it('should call processNodesWithoutDataSource with nodesData and viewport', () => { + const mockData = createMockWorkflowData() + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + expect(mockProcessNodesWithoutDataSource).toHaveBeenCalled() + }) + + it('should not process nodes when data is undefined', () => { + mockUsePipelineInit.mockReturnValue({ data: undefined, isLoading: false }) + + render() + + // When data is undefined, Loading is shown, processNodesWithoutDataSource is not called + expect(mockProcessNodesWithoutDataSource).not.toHaveBeenCalled() + }) + + it('should use memoized values when data reference is same', () => { + const mockData = createMockWorkflowData() + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + const { rerender } = render() + + // Clear mock call count after initial render + mockProcessNodesWithoutDataSource.mockClear() + + // Rerender with same data reference (no change to mockUsePipelineInit) + rerender() + + // processNodesWithoutDataSource should not be called again due to useMemo + // Note: React strict mode may cause double render, so we check it's not excessive + expect(mockProcessNodesWithoutDataSource.mock.calls.length).toBeLessThanOrEqual(1) + }) + }) + + describe('Edge Cases', () => { + it('should handle empty nodes array', () => { + const mockData = createMockWorkflowData({ + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + }) + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + expect(screen.getByTestId('nodes-count').textContent).toBe('0') + }) + + it('should handle empty edges 
array', () => { + const mockData = createMockWorkflowData({ + graph: { + nodes: [{ id: 'node-1', type: 'custom', data: { type: BlockEnum.Start, title: 'Start', desc: '' }, position: { x: 0, y: 0 } }], + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + }) + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + expect(screen.getByTestId('edges-count').textContent).toBe('0') + }) + + it('should handle undefined viewport', () => { + const mockData = createMockWorkflowData({ + graph: { + nodes: [], + edges: [], + viewport: undefined as any, + }, + }) + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + expect(screen.getByTestId('rag-pipeline-main')).toBeInTheDocument() + }) + + it('should handle null viewport', () => { + const mockData = createMockWorkflowData({ + graph: { + nodes: [], + edges: [], + viewport: null as any, + }, + }) + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + expect(screen.getByTestId('rag-pipeline-main')).toBeInTheDocument() + }) + + it('should handle large number of nodes', () => { + const largeNodesArray = Array.from({ length: 100 }, (_, i) => ({ + id: `node-${i}`, + type: 'custom', + data: { type: BlockEnum.Start, title: `Node ${i}`, desc: '' }, + position: { x: i * 100, y: 0 }, + })) + + const mockData = createMockWorkflowData({ + graph: { + nodes: largeNodesArray, + edges: [], + viewport: { x: 0, y: 0, zoom: 1 }, + }, + }) + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + expect(screen.getByTestId('nodes-count').textContent).toBe('100') + }) + + it('should handle viewport with edge case zoom values', () => { + const mockData = createMockWorkflowData({ + graph: { + nodes: [], + edges: [], + viewport: { x: -1000, y: -1000, zoom: 0.25 }, + }, + }) + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + 
expect(screen.getByTestId('viewport-zoom').textContent).toBe('0.25') + }) + + it('should handle viewport with maximum zoom', () => { + const mockData = createMockWorkflowData({ + graph: { + nodes: [], + edges: [], + viewport: { x: 0, y: 0, zoom: 4 }, + }, + }) + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + expect(screen.getByTestId('viewport-zoom').textContent).toBe('4') + }) + }) + + describe('Component Integration', () => { + it('should render WorkflowWithDefaultContext as wrapper', () => { + const mockData = createMockWorkflowData() + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + expect(screen.getByTestId('workflow-default-context')).toBeInTheDocument() + }) + + it('should nest RagPipelineMain inside WorkflowWithDefaultContext', () => { + const mockData = createMockWorkflowData() + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + render() + + const workflowContext = screen.getByTestId('workflow-default-context') + const ragPipelineMain = screen.getByTestId('rag-pipeline-main') + + expect(workflowContext).toContainElement(ragPipelineMain) + }) + }) +}) + +describe('processNodesWithoutDataSource utility integration', () => { + beforeEach(() => { + mockSelectorWithDataset('pipeline-123') + }) + + it('should process nodes through processNodesWithoutDataSource', () => { + const mockData = createMockWorkflowData() + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + mockProcessNodesWithoutDataSource.mockReturnValue({ + nodes: [{ id: 'processed-node', type: 'custom', data: { type: BlockEnum.Start, title: 'Processed', desc: '' }, position: { x: 0, y: 0 } }] as any, + viewport: { x: 0, y: 0, zoom: 2 }, + }) + + render() + + expect(mockProcessNodesWithoutDataSource).toHaveBeenCalled() + expect(screen.getByTestId('nodes-count').textContent).toBe('1') + expect(screen.getByTestId('viewport-zoom').textContent).toBe('2') + }) + + 
it('should handle processNodesWithoutDataSource returning modified viewport', () => { + const mockData = createMockWorkflowData() + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + mockProcessNodesWithoutDataSource.mockReturnValue({ + nodes: [], + viewport: { x: 500, y: 500, zoom: 0.5 }, + }) + + render() + + expect(screen.getByTestId('viewport-zoom').textContent).toBe('0.5') + }) +}) + +describe('Conditional Rendering Flow', () => { + it('should transition from loading to loaded state', () => { + mockSelectorWithDataset('pipeline-123') + + // Start with loading state + mockUsePipelineInit.mockReturnValue({ data: undefined, isLoading: true }) + const { rerender } = render() + + // Real Loading component has role="status" + expect(screen.getByRole('status')).toBeInTheDocument() + + // Transition to loaded state + const mockData = createMockWorkflowData() + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + rerender() + + expect(screen.getByTestId('rag-pipeline-main')).toBeInTheDocument() + }) + + it('should switch from Conversion to Pipeline when pipelineId becomes available', () => { + // Start without pipelineId + mockSelectorWithDataset(null) + mockUsePipelineInit.mockReturnValue({ data: undefined, isLoading: false }) + + const { rerender } = render() + + expect(screen.getByTestId('conversion-component')).toBeInTheDocument() + + // PipelineId becomes available + mockSelectorWithDataset('new-pipeline-id') + mockUsePipelineInit.mockReturnValue({ data: undefined, isLoading: true }) + rerender() + + expect(screen.queryByTestId('conversion-component')).not.toBeInTheDocument() + // Real Loading component has role="status" + expect(screen.getByRole('status')).toBeInTheDocument() + }) +}) + +describe('Error Handling', () => { + beforeEach(() => { + mockSelectorWithDataset('pipeline-123') + }) + + it('should throw when graph nodes is null', () => { + const mockData = { + graph: { + nodes: null as any, + edges: null as 
any, + viewport: { x: 0, y: 0, zoom: 1 }, + }, + hash: 'test', + updated_at: 123, + } as FetchWorkflowDraftResponse + + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + // Suppress console.error for expected error + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + // Real initialNodes will throw when nodes is null + // This documents the component's current behavior - it requires valid nodes array + expect(() => render()).toThrow() + + consoleSpy.mockRestore() + }) + + it('should throw when graph property is missing', () => { + const mockData = { + hash: 'test', + updated_at: 123, + } as unknown as FetchWorkflowDraftResponse + + mockUsePipelineInit.mockReturnValue({ data: mockData, isLoading: false }) + + // Suppress console.error for expected error + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + // When graph is undefined, component throws because data.graph.nodes is accessed + // This documents the component's current behavior - it requires graph to be present + expect(() => render()).toThrow() + + consoleSpy.mockRestore() + }) +}) From 2bfc54314eeede0352b2abe017a1138cd488e56a Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Thu, 15 Jan 2026 11:10:55 +0800 Subject: [PATCH 2/8] feat: single run add opentelemetry (#31020) --- api/core/workflow/workflow_entry.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index fd3fc02f62..ee37314721 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -189,8 +189,7 @@ class WorkflowEntry: ) try: - # run node - generator = node.run() + generator = cls._traced_node_run(node) except Exception as e: logger.exception( "error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s", @@ -323,8 +322,7 @@ class WorkflowEntry: tenant_id=tenant_id, ) - # run node - 
generator = node.run() + generator = cls._traced_node_run(node) return node, generator except Exception as e: @@ -430,3 +428,26 @@ class WorkflowEntry: input_value = current_variable.value | input_value variable_pool.add([variable_node_id] + variable_key_list, input_value) + + @staticmethod + def _traced_node_run(node: Node) -> Generator[GraphNodeEventBase, None, None]: + """ + Wraps a node's run method with OpenTelemetry tracing and returns a generator. + """ + # Wrap node.run() with ObservabilityLayer hooks to produce node-level spans + layer = ObservabilityLayer() + layer.on_graph_start() + node.ensure_execution_id() + + def _gen(): + error: Exception | None = None + layer.on_node_run_start(node) + try: + yield from node.run() + except Exception as exc: + error = exc + raise + finally: + layer.on_node_run_end(node, error) + + return _gen() From 0ef8b5a0ca664ab2be14d4aa74c2b4caa13011c8 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 15 Jan 2026 11:36:15 +0800 Subject: [PATCH 3/8] chore: bump version to 1.11.4 (#30961) --- api/pyproject.toml | 2 +- api/uv.lock | 2 +- docker/docker-compose-template.yaml | 8 ++++---- docker/docker-compose.yaml | 8 ++++---- web/package.json | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 28bd591d17..d025a92846 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.11.3" +version = "1.11.4" requires-python = ">=3.11,<3.13" dependencies = [ diff --git a/api/uv.lock b/api/uv.lock index 792340599d..83aa89072c 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1368,7 +1368,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.11.3" +version = "1.11.4" source = { virtual = "." 
} dependencies = [ { name = "aliyun-log-python-sdk" }, diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index aada39569e..9659990383 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.3 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.3 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.11.3 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.11.3 + image: langgenius/dify-web:1.11.4 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 6439cccf47..429667e75f 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -705,7 +705,7 @@ services: # API service api: - image: langgenius/dify-api:1.11.3 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -747,7 +747,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.11.3 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -786,7 +786,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. 
worker_beat: - image: langgenius/dify-api:1.11.3 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -816,7 +816,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.11.3 + image: langgenius/dify-web:1.11.4 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/web/package.json b/web/package.json index bdbac2af83..000862204b 100644 --- a/web/package.json +++ b/web/package.json @@ -1,7 +1,7 @@ { "name": "dify-web", "type": "module", - "version": "1.11.3", + "version": "1.11.4", "private": true, "packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a", "imports": { From 52af829f1fadf2afeb3c62e57137c160e8179050 Mon Sep 17 00:00:00 2001 From: hj24 Date: Thu, 15 Jan 2026 14:03:17 +0800 Subject: [PATCH 4/8] refactor: enhance clean messages task (#29638) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: ้žๆณ•ๆ“ไฝœ Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/commands.py | 79 ++ api/extensions/ext_commands.py | 2 + ...eat_add_created_at_id_index_to_messages.py | 33 + api/models/model.py | 1 + api/schedule/clean_messages.py | 126 +- .../conversation/messages_clean_policy.py | 216 ++++ .../conversation/messages_clean_service.py | 334 +++++ .../services/test_messages_clean_service.py | 1070 +++++++++++++++++ .../services/test_messages_clean_service.py | 627 ++++++++++ 9 files changed, 2411 insertions(+), 77 deletions(-) create mode 100644 api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py create mode 100644 api/services/retention/conversation/messages_clean_policy.py create mode 100644 
api/services/retention/conversation/messages_clean_service.py create mode 100644 api/tests/test_containers_integration_tests/services/test_messages_clean_service.py create mode 100644 api/tests/unit_tests/services/test_messages_clean_service.py diff --git a/api/commands.py b/api/commands.py index 20ce22a6c7..e223df74d4 100644 --- a/api/commands.py +++ b/api/commands.py @@ -3,6 +3,7 @@ import datetime import json import logging import secrets +import time from typing import Any import click @@ -46,6 +47,8 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi from services.plugin.data_migration import PluginDataMigration from services.plugin.plugin_migration import PluginMigration from services.plugin.plugin_service import PluginService +from services.retention.conversation.messages_clean_policy import create_message_clean_policy +from services.retention.conversation.messages_clean_service import MessagesCleanService from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup from tasks.remove_app_and_related_data_task import delete_draft_variables_batch @@ -2172,3 +2175,79 @@ def migrate_oss( except Exception as e: db.session.rollback() click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red")) + + +@click.command("clean-expired-messages", help="Clean expired messages.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=True, + help="Lower bound (inclusive) for created_at.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=True, + help="Upper bound (exclusive) for created_at.", +) +@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.") +@click.option( + "--graceful-period", + default=21, + show_default=True, + help="Graceful period in days after subscription expiration, will be ignored when 
billing is disabled.", +) +@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting") +def clean_expired_messages( + batch_size: int, + graceful_period: int, + start_from: datetime.datetime, + end_before: datetime.datetime, + dry_run: bool, +): + """ + Clean expired messages and related data for tenants based on clean policy. + """ + click.echo(click.style("clean_messages: start clean messages.", fg="green")) + + start_at = time.perf_counter() + + try: + # Create policy based on billing configuration + # NOTE: graceful_period will be ignored when billing is disabled. + policy = create_message_clean_policy(graceful_period_days=graceful_period) + + # Create and run the cleanup service + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=batch_size, + dry_run=dry_run, + ) + stats = service.run() + + end_at = time.perf_counter() + click.echo( + click.style( + f"clean_messages: completed successfully\n" + f" - Latency: {end_at - start_at:.2f}s\n" + f" - Batches processed: {stats['batches']}\n" + f" - Total messages scanned: {stats['total_messages']}\n" + f" - Messages filtered: {stats['filtered_messages']}\n" + f" - Messages deleted: {stats['total_deleted']}", + fg="green", + ) + ) + except Exception as e: + end_at = time.perf_counter() + logger.exception("clean_messages failed") + click.echo( + click.style( + f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", + fg="red", + ) + ) + raise + + click.echo(click.style("messages cleanup completed.", fg="green")) diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index c32130d377..51e2c6cdd5 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -4,6 +4,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): from commands import ( add_qdrant_index, + clean_expired_messages, clean_workflow_runs, 
cleanup_orphaned_draft_variables, clear_free_plan_tenant_expired_logs, @@ -58,6 +59,7 @@ def init_app(app: DifyApp): transform_datasource_credentials, install_rag_pipeline_plugins, clean_workflow_runs, + clean_expired_messages, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py new file mode 100644 index 0000000000..758369ba99 --- /dev/null +++ b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py @@ -0,0 +1,33 @@ +"""feat: add created_at id index to messages + +Revision ID: 3334862ee907 +Revises: 905527cc8fd3 +Create Date: 2026-01-12 17:29:44.846544 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '3334862ee907' +down_revision = '905527cc8fd3' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_index('message_created_at_id_idx') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 68903e86eb..d6a0aa3bb3 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -968,6 +968,7 @@ class Message(Base): Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), Index("message_created_at_idx", "created_at"), Index("message_app_mode_idx", "app_mode"), + Index("message_created_at_id_idx", "created_at", "id"), ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 352a84b592..e85bba8823 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -1,90 +1,62 @@ -import datetime import logging import time import click -from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config -from enums.cloud_plan import CloudPlan -from extensions.ext_database import db -from extensions.ext_redis import redis_client -from models.model import ( - App, - Message, - MessageAgentThought, - MessageAnnotation, - MessageChain, - MessageFeedback, - MessageFile, -) -from models.web import SavedMessage -from services.feature_service import FeatureService +from services.retention.conversation.messages_clean_policy import create_message_clean_policy +from services.retention.conversation.messages_clean_service import MessagesCleanService logger = logging.getLogger(__name__) -@app.celery.task(queue="dataset") +@app.celery.task(queue="retention") def clean_messages(): - click.echo(click.style("Start clean messages.", fg="green")) - start_at = time.perf_counter() - plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta( - days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING - ) - while True: - try: - # Main query with join and filter - messages = ( - 
db.session.query(Message) - .where(Message.created_at < plan_sandbox_clean_message_day) - .order_by(Message.created_at.desc()) - .limit(100) - .all() - ) + """ + Clean expired messages based on clean policy. - except SQLAlchemyError: - raise - if not messages: - break - for message in messages: - app = db.session.query(App).filter_by(id=message.app_id).first() - if not app: - logger.warning( - "Expected App record to exist, but none was found, app_id=%s, message_id=%s", - message.app_id, - message.id, - ) - continue - features_cache_key = f"features:{app.tenant_id}" - plan_cache = redis_client.get(features_cache_key) - if plan_cache is None: - features = FeatureService.get_features(app.tenant_id) - redis_client.setex(features_cache_key, 600, features.billing.subscription.plan) - plan = features.billing.subscription.plan - else: - plan = plan_cache.decode() - if plan == CloudPlan.SANDBOX: - # clean related message - db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(Message).where(Message.id == message.id).delete() - db.session.commit() - end_at = time.perf_counter() - click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green")) + This task uses MessagesCleanService to efficiently clean messages in batches. 
+ The behavior depends on BILLING_ENABLED configuration: + - BILLING_ENABLED=True: only delete messages from sandbox tenants (with whitelist/grace period) + - BILLING_ENABLED=False: delete all messages within the time range + """ + click.echo(click.style("clean_messages: start clean messages.", fg="green")) + start_at = time.perf_counter() + + try: + # Create policy based on billing configuration + policy = create_message_clean_policy( + graceful_period_days=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD, + ) + + # Create and run the cleanup service + service = MessagesCleanService.from_days( + policy=policy, + days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS, + batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE, + ) + stats = service.run() + + end_at = time.perf_counter() + click.echo( + click.style( + f"clean_messages: completed successfully\n" + f" - Latency: {end_at - start_at:.2f}s\n" + f" - Batches processed: {stats['batches']}\n" + f" - Total messages scanned: {stats['total_messages']}\n" + f" - Messages filtered: {stats['filtered_messages']}\n" + f" - Messages deleted: {stats['total_deleted']}", + fg="green", + ) + ) + except Exception as e: + end_at = time.perf_counter() + logger.exception("clean_messages failed") + click.echo( + click.style( + f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", + fg="red", + ) + ) + raise diff --git a/api/services/retention/conversation/messages_clean_policy.py b/api/services/retention/conversation/messages_clean_policy.py new file mode 100644 index 0000000000..6e647b983b --- /dev/null +++ b/api/services/retention/conversation/messages_clean_policy.py @@ -0,0 +1,216 @@ +import datetime +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass + +from configs import dify_config +from enums.cloud_plan import CloudPlan +from services.billing_service import BillingService, SubscriptionPlan + +logger = 
logging.getLogger(__name__) + + +@dataclass +class SimpleMessage: + id: str + app_id: str + created_at: datetime.datetime + + +class MessagesCleanPolicy(ABC): + """ + Abstract base class for message cleanup policies. + + A policy determines which messages from a batch should be deleted. + """ + + @abstractmethod + def filter_message_ids( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + ) -> Sequence[str]: + """ + Filter messages and return IDs of messages that should be deleted. + + Args: + messages: Batch of messages to evaluate + app_to_tenant: Mapping from app_id to tenant_id + + Returns: + List of message IDs that should be deleted + """ + ... + + +class BillingDisabledPolicy(MessagesCleanPolicy): + """ + Policy for community or enterpriseedition (billing disabled). + + No special filter logic, just return all message ids. + """ + + def filter_message_ids( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + ) -> Sequence[str]: + return [msg.id for msg in messages] + + +class BillingSandboxPolicy(MessagesCleanPolicy): + """ + Policy for sandbox plan tenants in cloud edition (billing enabled). 
+ + Filters messages based on sandbox plan expiration rules: + - Skip tenants in the whitelist + - Only delete messages from sandbox plan tenants + - Respect grace period after subscription expiration + - Safe default: if tenant mapping or plan is missing, do NOT delete + """ + + def __init__( + self, + plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]], + graceful_period_days: int = 21, + tenant_whitelist: Sequence[str] | None = None, + current_timestamp: int | None = None, + ) -> None: + self._graceful_period_days = graceful_period_days + self._tenant_whitelist: Sequence[str] = tenant_whitelist or [] + self._plan_provider = plan_provider + self._current_timestamp = current_timestamp + + def filter_message_ids( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + ) -> Sequence[str]: + """ + Filter messages based on sandbox plan expiration rules. + + Args: + messages: Batch of messages to evaluate + app_to_tenant: Mapping from app_id to tenant_id + + Returns: + List of message IDs that should be deleted + """ + if not messages or not app_to_tenant: + return [] + + # Get unique tenant_ids and fetch subscription plans + tenant_ids = list(set(app_to_tenant.values())) + tenant_plans = self._plan_provider(tenant_ids) + + if not tenant_plans: + return [] + + # Apply sandbox deletion rules + return self._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + ) + + def _filter_expired_sandbox_messages( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + tenant_plans: dict[str, SubscriptionPlan], + ) -> list[str]: + """ + Filter messages that should be deleted based on sandbox plan expiration. + + A message should be deleted if: + 1. It belongs to a sandbox tenant AND + 2. 
Either: + a) The tenant has no previous subscription (expiration_date == -1), OR + b) The subscription expired more than graceful_period_days ago + + Args: + messages: List of message objects with id and app_id attributes + app_to_tenant: Mapping from app_id to tenant_id + tenant_plans: Mapping from tenant_id to subscription plan info + + Returns: + List of message IDs that should be deleted + """ + current_timestamp = self._current_timestamp + if current_timestamp is None: + current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + sandbox_message_ids: list[str] = [] + graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60 + + for msg in messages: + # Get tenant_id for this message's app + tenant_id = app_to_tenant.get(msg.app_id) + if not tenant_id: + continue + + # Skip tenant messages in whitelist + if tenant_id in self._tenant_whitelist: + continue + + # Get subscription plan for this tenant + tenant_plan = tenant_plans.get(tenant_id) + if not tenant_plan: + continue + + plan = str(tenant_plan["plan"]) + expiration_date = int(tenant_plan["expiration_date"]) + + # Only process sandbox plans + if plan != CloudPlan.SANDBOX: + continue + + # Case 1: No previous subscription (-1 means never had a paid subscription) + if expiration_date == -1: + sandbox_message_ids.append(msg.id) + continue + + # Case 2: Subscription expired beyond grace period + if current_timestamp - expiration_date > graceful_period_seconds: + sandbox_message_ids.append(msg.id) + + return sandbox_message_ids + + +def create_message_clean_policy( + graceful_period_days: int = 21, + current_timestamp: int | None = None, +) -> MessagesCleanPolicy: + """ + Factory function to create the appropriate message clean policy. 
+ + Determines which policy to use based on BILLING_ENABLED configuration: + - If BILLING_ENABLED is True: returns BillingSandboxPolicy + - If BILLING_ENABLED is False: returns BillingDisabledPolicy + + Args: + graceful_period_days: Grace period in days after subscription expiration (default: 21) + current_timestamp: Current Unix timestamp for testing (default: None, uses current time) + """ + if not dify_config.BILLING_ENABLED: + logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy") + return BillingDisabledPolicy() + + # Billing enabled - fetch whitelist from BillingService + tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist() + plan_provider = BillingService.get_plan_bulk_with_cache + + logger.info( + "create_message_clean_policy: billing enabled, using BillingSandboxPolicy " + "(graceful_period_days=%s, whitelist=%s)", + graceful_period_days, + tenant_whitelist, + ) + + return BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=graceful_period_days, + tenant_whitelist=tenant_whitelist, + current_timestamp=current_timestamp, + ) diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py new file mode 100644 index 0000000000..3ca5d82860 --- /dev/null +++ b/api/services/retention/conversation/messages_clean_service.py @@ -0,0 +1,334 @@ +import datetime +import logging +import random +from collections.abc import Sequence +from typing import cast + +from sqlalchemy import delete, select +from sqlalchemy.engine import CursorResult +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from models.model import ( + App, + AppAnnotationHitHistory, + DatasetRetrieverResource, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from services.retention.conversation.messages_clean_policy import 
( + MessagesCleanPolicy, + SimpleMessage, +) + +logger = logging.getLogger(__name__) + + +class MessagesCleanService: + """ + Service for cleaning expired messages based on retention policies. + + Compatible with non cloud edition (billing disabled): all messages in the time range will be deleted. + If billing is enabled: only sandbox plan tenant messages are deleted (with whitelist and grace period support). + """ + + def __init__( + self, + policy: MessagesCleanPolicy, + end_before: datetime.datetime, + start_from: datetime.datetime | None = None, + batch_size: int = 1000, + dry_run: bool = False, + ) -> None: + """ + Initialize the service with cleanup parameters. + + Args: + policy: The policy that determines which messages to delete + end_before: End time (exclusive) of the range + start_from: Optional start time (inclusive) of the range + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + """ + self._policy = policy + self._end_before = end_before + self._start_from = start_from + self._batch_size = batch_size + self._dry_run = dry_run + + @classmethod + def from_time_range( + cls, + policy: MessagesCleanPolicy, + start_from: datetime.datetime, + end_before: datetime.datetime, + batch_size: int = 1000, + dry_run: bool = False, + ) -> "MessagesCleanService": + """ + Create a service instance for cleaning messages within a specific time range. + + Time range is [start_from, end_before). 
+ + Args: + policy: The policy that determines which messages to delete + start_from: Start time (inclusive) of the range + end_before: End time (exclusive) of the range + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + + Returns: + MessagesCleanService instance + + Raises: + ValueError: If start_from >= end_before or invalid parameters + """ + if start_from >= end_before: + raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})") + + if batch_size <= 0: + raise ValueError(f"batch_size ({batch_size}) must be greater than 0") + + logger.info( + "clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s", + start_from, + end_before, + batch_size, + policy.__class__.__name__, + ) + + return cls( + policy=policy, + end_before=end_before, + start_from=start_from, + batch_size=batch_size, + dry_run=dry_run, + ) + + @classmethod + def from_days( + cls, + policy: MessagesCleanPolicy, + days: int = 30, + batch_size: int = 1000, + dry_run: bool = False, + ) -> "MessagesCleanService": + """ + Create a service instance for cleaning messages older than specified days. 
+ + Args: + policy: The policy that determines which messages to delete + days: Number of days to look back from now + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + + Returns: + MessagesCleanService instance + + Raises: + ValueError: If invalid parameters + """ + if days < 0: + raise ValueError(f"days ({days}) must be greater than or equal to 0") + + if batch_size <= 0: + raise ValueError(f"batch_size ({batch_size}) must be greater than 0") + + end_before = datetime.datetime.now() - datetime.timedelta(days=days) + + logger.info( + "clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s", + days, + end_before, + batch_size, + policy.__class__.__name__, + ) + + return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run) + + def run(self) -> dict[str, int]: + """ + Execute the message cleanup operation. + + Returns: + Dict with statistics: batches, filtered_messages, total_deleted + """ + return self._clean_messages_by_time_range() + + def _clean_messages_by_time_range(self) -> dict[str, int]: + """ + Clean messages within a time range using cursor-based pagination. + + Time range is [start_from, end_before) + + Steps: + 1. Iterate messages using cursor pagination (by created_at, id) + 2. Query app_id -> tenant_id mapping + 3. Delegate to policy to determine which messages to delete + 4. 
Batch delete messages and their relations + + Returns: + Dict with statistics: batches, filtered_messages, total_deleted + """ + stats = { + "batches": 0, + "total_messages": 0, + "filtered_messages": 0, + "total_deleted": 0, + } + + # Cursor-based pagination using (created_at, id) to avoid infinite loops + # and ensure proper ordering with time-based filtering + _cursor: tuple[datetime.datetime, str] | None = None + + logger.info( + "clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s", + self._dry_run, + self._start_from, + self._end_before, + ) + + while True: + stats["batches"] += 1 + + # Step 1: Fetch a batch of messages using cursor + with Session(db.engine, expire_on_commit=False) as session: + msg_stmt = ( + select(Message.id, Message.app_id, Message.created_at) + .where(Message.created_at < self._end_before) + .order_by(Message.created_at, Message.id) + .limit(self._batch_size) + ) + + if self._start_from: + msg_stmt = msg_stmt.where(Message.created_at >= self._start_from) + + # Apply cursor condition: (created_at, id) > (last_created_at, last_message_id) + # This translates to: + # created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id) + if _cursor: + # Continuing from previous batch + msg_stmt = msg_stmt.where( + (Message.created_at > _cursor[0]) + | ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1])) + ) + + raw_messages = list(session.execute(msg_stmt).all()) + messages = [ + SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at) + for msg_id, app_id, msg_created_at in raw_messages + ] + + # Track total messages fetched across all batches + stats["total_messages"] += len(messages) + + if not messages: + logger.info("clean_messages (batch %s): no more messages to process", stats["batches"]) + break + + # Update cursor to the last message's (created_at, id) + _cursor = (messages[-1].created_at, messages[-1].id) + + # Step 2: Extract app_ids and query tenant_ids + 
app_ids = list({msg.app_id for msg in messages}) + + if not app_ids: + logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"]) + continue + + app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids)) + apps = list(session.execute(app_stmt).all()) + + if not apps: + logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"]) + continue + + # Build app_id -> tenant_id mapping + app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps} + + # Step 3: Delegate to policy to determine which messages to delete + message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant) + + if not message_ids_to_delete: + logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"]) + continue + + stats["filtered_messages"] += len(message_ids_to_delete) + + # Step 4: Batch delete messages and their relations + if not self._dry_run: + with Session(db.engine, expire_on_commit=False) as session: + # Delete related records first + self._batch_delete_message_relations(session, message_ids_to_delete) + + # Delete messages + delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete)) + delete_result = cast(CursorResult, session.execute(delete_stmt)) + messages_deleted = delete_result.rowcount + session.commit() + + stats["total_deleted"] += messages_deleted + + logger.info( + "clean_messages (batch %s): processed %s messages, deleted %s messages", + stats["batches"], + len(messages), + messages_deleted, + ) + else: + # Log random sample of message IDs that would be deleted (up to 10) + sample_size = min(10, len(message_ids_to_delete)) + sampled_ids = random.sample(list(message_ids_to_delete), sample_size) + + logger.info( + "clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:", + stats["batches"], + len(message_ids_to_delete), + sample_size, + ) + for msg_id in sampled_ids: + logger.info("clean_messages (batch %s, dry_run) sample: 
message_id=%s", stats["batches"], msg_id) + + logger.info( + "clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s", + stats["batches"], + stats["total_messages"], + stats["filtered_messages"], + stats["total_deleted"], + ) + + return stats + + @staticmethod + def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None: + """ + Batch delete all related records for given message IDs. + + Args: + session: Database session + message_ids: List of message IDs to delete relations for + """ + if not message_ids: + return + + # Delete all related records in batch + session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids))) + + session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids))) + + session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids))) + + session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids))) + + session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids))) + + session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids))) + + session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids))) + + session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids))) diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py new file mode 100644 index 0000000000..29baa4d94f --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -0,0 +1,1070 @@ +import datetime +import json +import uuid +from decimal import Decimal +from unittest.mock import patch + +import pytest +from faker import Faker + +from enums.cloud_plan import CloudPlan +from 
extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import ( + App, + AppAnnotationHitHistory, + Conversation, + DatasetRetrieverResource, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from services.billing_service import BillingService +from services.retention.conversation.messages_clean_policy import ( + BillingDisabledPolicy, + BillingSandboxPolicy, + create_message_clean_policy, +) +from services.retention.conversation.messages_clean_service import MessagesCleanService + + +class TestMessagesCleanServiceIntegration: + """Integration tests for MessagesCleanService.run() and _clean_messages_by_time_range().""" + + # Redis cache key prefix from BillingService + PLAN_CACHE_KEY_PREFIX = BillingService._PLAN_CACHE_KEY_PREFIX # "tenant_plan:" + + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers): + """Clean up database before and after each test to ensure isolation.""" + yield + # Clear all test data in correct order (respecting foreign key constraints) + db.session.query(DatasetRetrieverResource).delete() + db.session.query(AppAnnotationHitHistory).delete() + db.session.query(SavedMessage).delete() + db.session.query(MessageFile).delete() + db.session.query(MessageAgentThought).delete() + db.session.query(MessageChain).delete() + db.session.query(MessageAnnotation).delete() + db.session.query(MessageFeedback).delete() + db.session.query(Message).delete() + db.session.query(Conversation).delete() + db.session.query(App).delete() + db.session.query(TenantAccountJoin).delete() + db.session.query(Tenant).delete() + db.session.query(Account).delete() + db.session.commit() + + @pytest.fixture(autouse=True) + def cleanup_redis(self): + """Clean up Redis cache before each test.""" + # Clear tenant plan cache 
using BillingService key prefix + try: + keys = redis_client.keys(f"{self.PLAN_CACHE_KEY_PREFIX}*") + if keys: + redis_client.delete(*keys) + except Exception: + pass # Redis might not be available in some test environments + yield + # Clean up after test + try: + keys = redis_client.keys(f"{self.PLAN_CACHE_KEY_PREFIX}*") + if keys: + redis_client.delete(*keys) + except Exception: + pass + + @pytest.fixture + def mock_whitelist(self): + """Mock whitelist to return empty list by default.""" + with patch( + "services.retention.conversation.messages_clean_policy.BillingService.get_expired_subscription_cleanup_whitelist" + ) as mock: + mock.return_value = [] + yield mock + + @pytest.fixture + def mock_billing_enabled(self): + """Mock BILLING_ENABLED to be True.""" + with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", True): + yield + + @pytest.fixture + def mock_billing_disabled(self): + """Mock BILLING_ENABLED to be False.""" + with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", False): + yield + + def _create_account_and_tenant(self, plan: str = CloudPlan.SANDBOX): + """Helper to create account and tenant.""" + fake = Faker() + + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.flush() + + tenant = Tenant( + name=fake.company(), + plan=str(plan), + status="normal", + ) + db.session.add(tenant) + db.session.flush() + + tenant_account_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + ) + db.session.add(tenant_account_join) + db.session.commit() + + return account, tenant + + def _create_app(self, tenant, account): + """Helper to create an app.""" + fake = Faker() + + app = App( + tenant_id=tenant.id, + name=fake.company(), + description="Test app", + mode="chat", + enable_site=True, + enable_api=True, + api_rpm=60, + 
api_rph=3600, + is_demo=False, + is_public=False, + created_by=account.id, + updated_by=account.id, + ) + db.session.add(app) + db.session.commit() + + return app + + def _create_conversation(self, app): + """Helper to create a conversation.""" + conversation = Conversation( + app_id=app.id, + app_model_config_id=str(uuid.uuid4()), + model_provider="openai", + model_id="gpt-3.5-turbo", + mode="chat", + name="Test conversation", + inputs={}, + status="normal", + from_source="api", + from_end_user_id=str(uuid.uuid4()), + ) + db.session.add(conversation) + db.session.commit() + + return conversation + + def _create_message(self, app, conversation, created_at=None, with_relations=True): + """Helper to create a message with optional related records.""" + if created_at is None: + created_at = datetime.datetime.now() + + message = Message( + app_id=app.id, + conversation_id=conversation.id, + model_provider="openai", + model_id="gpt-3.5-turbo", + inputs={}, + query="Test query", + answer="Test answer", + message=[{"role": "user", "text": "Test message"}], + message_tokens=10, + message_unit_price=Decimal("0.001"), + answer_tokens=20, + answer_unit_price=Decimal("0.002"), + total_price=Decimal("0.003"), + currency="USD", + from_source="api", + from_account_id=conversation.from_end_user_id, + created_at=created_at, + ) + db.session.add(message) + db.session.flush() + + if with_relations: + self._create_message_relations(message) + + db.session.commit() + return message + + def _create_message_relations(self, message): + """Helper to create all message-related records.""" + # MessageFeedback + feedback = MessageFeedback( + app_id=message.app_id, + conversation_id=message.conversation_id, + message_id=message.id, + rating="like", + from_source="api", + from_end_user_id=str(uuid.uuid4()), + ) + db.session.add(feedback) + + # MessageAnnotation + annotation = MessageAnnotation( + app_id=message.app_id, + conversation_id=message.conversation_id, + message_id=message.id, + 
question="Test question", + content="Test annotation", + account_id=message.from_account_id, + ) + db.session.add(annotation) + + # MessageChain + chain = MessageChain( + message_id=message.id, + type="system", + input=json.dumps({"test": "input"}), + output=json.dumps({"test": "output"}), + ) + db.session.add(chain) + db.session.flush() + + # MessageFile + file = MessageFile( + message_id=message.id, + type="image", + transfer_method="local_file", + url="http://example.com/test.jpg", + belongs_to="user", + created_by_role="end_user", + created_by=str(uuid.uuid4()), + ) + db.session.add(file) + + # SavedMessage + saved = SavedMessage( + app_id=message.app_id, + message_id=message.id, + created_by_role="end_user", + created_by=str(uuid.uuid4()), + ) + db.session.add(saved) + + db.session.flush() + + # AppAnnotationHitHistory + hit = AppAnnotationHitHistory( + app_id=message.app_id, + annotation_id=annotation.id, + message_id=message.id, + source="annotation", + question="Test question", + account_id=message.from_account_id, + annotation_question="Test annotation question", + annotation_content="Test annotation content", + ) + db.session.add(hit) + + # DatasetRetrieverResource + resource = DatasetRetrieverResource( + message_id=message.id, + position=1, + dataset_id=str(uuid.uuid4()), + dataset_name="Test dataset", + document_id=str(uuid.uuid4()), + document_name="Test document", + data_source_type="upload_file", + segment_id=str(uuid.uuid4()), + score=0.9, + content="Test content", + hit_count=1, + word_count=10, + segment_position=1, + index_node_hash="test_hash", + retriever_from="dataset", + created_by=message.from_account_id, + ) + db.session.add(resource) + + def test_billing_disabled_deletes_all_messages_in_time_range( + self, db_session_with_containers, mock_billing_disabled + ): + """Test that BillingDisabledPolicy deletes all messages within time range regardless of tenant plan.""" + # Arrange - Create tenant with messages (plan doesn't matter for billing 
disabled) + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create messages: in-range (should be deleted) and out-of-range (should be kept) + in_range_date = datetime.datetime(2024, 1, 15, 12, 0, 0) + out_of_range_date = datetime.datetime(2024, 1, 25, 12, 0, 0) + + in_range_msg = self._create_message(app, conv, created_at=in_range_date, with_relations=True) + in_range_msg_id = in_range_msg.id + + out_of_range_msg = self._create_message(app, conv, created_at=out_of_range_date, with_relations=True) + out_of_range_msg_id = out_of_range_msg.id + + # Act - create_message_clean_policy should return BillingDisabledPolicy + policy = create_message_clean_policy() + + assert isinstance(policy, BillingDisabledPolicy) + + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime(2024, 1, 10, 0, 0, 0), + end_before=datetime.datetime(2024, 1, 20, 0, 0, 0), + batch_size=100, + ) + stats = service.run() + + # Assert + assert stats["total_messages"] == 1 # Only in-range message fetched + assert stats["filtered_messages"] == 1 + assert stats["total_deleted"] == 1 + + # In-range message deleted + assert db.session.query(Message).where(Message.id == in_range_msg_id).count() == 0 + # Out-of-range message kept + assert db.session.query(Message).where(Message.id == out_of_range_msg_id).count() == 1 + + # Related records of in-range message deleted + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == in_range_msg_id).count() == 0 + assert db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == in_range_msg_id).count() == 0 + # Related records of out-of-range message kept + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == out_of_range_msg_id).count() == 1 + + def test_no_messages_returns_empty_stats(self, db_session_with_containers, mock_billing_enabled, 
mock_whitelist): + """Test cleaning when there are no messages to delete (B1).""" + # Arrange + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + start_from = datetime.datetime.now() - datetime.timedelta(days=60) + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = {} + + # Act + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - loop runs once to check, finds nothing + assert stats["batches"] == 1 + assert stats["total_messages"] == 0 + assert stats["filtered_messages"] == 0 + assert stats["total_deleted"] == 0 + + def test_mixed_sandbox_and_paid_tenants(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test cleaning with mixed sandbox and paid tenants (B2).""" + # Arrange - Create sandbox tenants with expired messages + sandbox_tenants = [] + sandbox_message_ids = [] + for i in range(2): + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + sandbox_tenants.append(tenant) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create 3 expired messages per sandbox tenant + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + for j in range(3): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + sandbox_message_ids.append(msg.id) + + # Create paid tenants with expired messages (should NOT be deleted) + paid_tenants = [] + paid_message_ids = [] + for i in range(2): + account, tenant = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) + paid_tenants.append(tenant) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create 2 expired messages per paid tenant + expired_date = datetime.datetime.now() - 
datetime.timedelta(days=35) + for j in range(2): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + paid_message_ids.append(msg.id) + + # Mock billing service - return plan and expiration_date + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + expired_15_days_ago = now_timestamp - (15 * 24 * 60 * 60) # Beyond 7-day grace period + + plan_map = {} + for tenant in sandbox_tenants: + plan_map[tenant.id] = { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_15_days_ago, + } + for tenant in paid_tenants: + plan_map[tenant.id] = { + "plan": CloudPlan.PROFESSIONAL, + "expiration_date": now_timestamp + (365 * 24 * 60 * 60), # Active for 1 year + } + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=7) + + assert isinstance(policy, BillingSandboxPolicy) + + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert + assert stats["total_messages"] == 10 # 2 sandbox * 3 + 2 paid * 2 + assert stats["filtered_messages"] == 6 # 2 sandbox tenants * 3 messages + assert stats["total_deleted"] == 6 + + # Only sandbox messages should be deleted + assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 + # Paid messages should remain + assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 + + # Related records of sandbox messages should be deleted + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0 + assert ( + db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count() + == 
0 + ) + + def test_cursor_pagination_multiple_batches(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test cursor pagination works correctly across multiple batches (B3).""" + # Arrange - Create sandbox tenant with messages that will span multiple batches + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create 10 expired messages with different timestamps + base_date = datetime.datetime.now() - datetime.timedelta(days=35) + message_ids = [] + for i in range(10): + msg = self._create_message( + app, + conv, + created_at=base_date + datetime.timedelta(hours=i), + with_relations=False, # Skip relations for speed + ) + message_ids.append(msg.id) + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Act - Use small batch size to trigger multiple batches + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=3, # Small batch size to test pagination + ) + stats = service.run() + + # 5 batches for 10 messages with batch_size=3, the last batch is empty + assert stats["batches"] == 5 + assert stats["total_messages"] == 10 + assert stats["filtered_messages"] == 10 + assert stats["total_deleted"] == 10 + + # All messages should be deleted + assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0 + + def test_dry_run_does_not_delete(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test dry_run mode does not delete messages (B4).""" + # Arrange + account, tenant = 
self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create expired messages + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + message_ids = [] + for i in range(3): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + message_ids.append(msg.id) + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + dry_run=True, # Dry run mode + ) + stats = service.run() + + # Assert + assert stats["total_messages"] == 3 + assert stats["filtered_messages"] == 3 # Messages identified + assert stats["total_deleted"] == 0 # But NOT deleted + + # All messages should still exist + assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3 + # Related records should also still exist + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3 + + def test_partial_plan_data_safe_default(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test when billing returns partial data, unknown tenants are preserved (B5).""" + # Arrange - Create 3 tenants + tenants_data = [] + for i in range(3): + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg = self._create_message(app, conv, 
created_at=expired_date) + + tenants_data.append( + { + "tenant": tenant, + "message_id": msg.id, + } + ) + + # Mock billing service to return partial data + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + # Only tenant[0] is confirmed as sandbox, tenant[1] is professional, tenant[2] is missing + partial_plan_map = { + tenants_data[0]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + }, + tenants_data[1]["tenant"].id: { + "plan": CloudPlan.PROFESSIONAL, + "expiration_date": now_timestamp + (365 * 24 * 60 * 60), # Active for 1 year + }, + # tenants_data[2] is missing from response + } + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = partial_plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - Only tenant[0]'s message should be deleted + assert stats["total_messages"] == 3 # 3 tenants * 1 message + assert stats["filtered_messages"] == 1 + assert stats["total_deleted"] == 1 + + # Check which messages were deleted + assert ( + db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 + ) # Sandbox tenant's message deleted + + assert ( + db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + ) # Professional tenant's message preserved + + assert ( + db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 + ) # Unknown tenant's message preserved (safe default) + + def test_empty_plan_data_skips_deletion(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test when billing returns empty 
data, skip deletion entirely (B6).""" + # Arrange + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg = self._create_message(app, conv, created_at=expired_date) + msg_id = msg.id + db.session.commit() + + # Mock billing service to return empty data (simulating failure/no data scenario) + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = {} # Empty response, tenant plan unknown + + # Act - Should not raise exception, just skip deletion + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - No messages should be deleted when plan is unknown + assert stats["total_messages"] == 1 + assert stats["filtered_messages"] == 0 # Cannot determine sandbox messages + assert stats["total_deleted"] == 0 + + # Message should still exist (safe default - don't delete if plan is unknown) + assert db.session.query(Message).where(Message.id == msg_id).count() == 1 + + def test_time_range_boundary_behavior(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test that messages are correctly filtered by [start_from, end_before) time range (B7).""" + # Arrange + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create messages: before range, in range, after range + msg_before = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from + with_relations=False, + ) + 
msg_before_id = msg_before.id + + msg_at_start = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive) + with_relations=False, + ) + msg_at_start_id = msg_at_start.id + + msg_in_range = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range + with_relations=False, + ) + msg_in_range_id = msg_in_range.id + + msg_at_end = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive) + with_relations=False, + ) + msg_at_end_id = msg_at_end.id + + msg_after = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before + with_relations=False, + ) + msg_after_id = msg_after.id + + db.session.commit() + + # Mock billing service + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Act - Clean with specific time range [2024-01-10, 2024-01-20) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime(2024, 1, 10, 12, 0, 0), + end_before=datetime.datetime(2024, 1, 20, 12, 0, 0), + batch_size=100, + ) + stats = service.run() + + # Assert - Only messages in [start_from, end_before) should be deleted + assert stats["total_messages"] == 2 # Only in-range messages fetched + assert stats["filtered_messages"] == 2 # msg_at_start and msg_in_range + assert stats["total_deleted"] == 2 + + # Verify specific messages using stored IDs + # Before range, kept + assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1 + # At start (inclusive), deleted + assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0 + # In range, deleted + assert 
db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0 + # At end (exclusive), kept + assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1 + # After range, kept + assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1 + + def test_grace_period_scenarios(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test cleaning with different graceful period scenarios (B8).""" + # Arrange - Create 5 different tenants with different plan and expiration scenarios + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + graceful_period = 8 # Use 8 days for this test + + # Scenario 1: Sandbox plan with expiration within graceful period (5 days ago) + # Should NOT be deleted + account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app1 = self._create_app(tenant1, account1) + conv1 = self._create_conversation(app1) + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + msg1_id = msg1.id + expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period + + # Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago) + # Should be deleted + account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app2 = self._create_app(tenant2, account2) + conv2 = self._create_conversation(app2) + msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + msg2_id = msg2.id + expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period + + # Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription) + # Should be deleted + account3, tenant3 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app3 = self._create_app(tenant3, account3) + conv3 = self._create_conversation(app3) + msg3 = self._create_message(app3, conv3, 
created_at=expired_date, with_relations=False) + msg3_id = msg3.id + + # Scenario 4: Non-sandbox plan (professional) with no expiration (future date) + # Should NOT be deleted + account4, tenant4 = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) + app4 = self._create_app(tenant4, account4) + conv4 = self._create_conversation(app4) + msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False) + msg4_id = msg4.id + future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year + + # Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago) + # Should NOT be deleted (boundary is exclusive: > graceful_period) + account5, tenant5 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app5 = self._create_app(tenant5, account5) + conv5 = self._create_conversation(app5) + msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False) + msg5_id = msg5.id + expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary + + db.session.commit() + + # Mock billing service with all scenarios + plan_map = { + tenant1.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_5_days_ago, + }, + tenant2.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_10_days_ago, + }, + tenant3.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + tenant4.id: { + "plan": CloudPlan.PROFESSIONAL, + "expiration_date": future_expiration, + }, + tenant5.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_exactly_8_days_ago, + }, + } + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy( + graceful_period_days=graceful_period, + current_timestamp=now_timestamp, # Use fixed timestamp for deterministic behavior + ) + service = 
MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - Only messages from scenario 2 and 3 should be deleted + assert stats["total_messages"] == 5 # 5 tenants * 1 message + assert stats["filtered_messages"] == 2 + assert stats["total_deleted"] == 2 + + # Verify each scenario using saved IDs + assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept + assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted + assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted + assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept + assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept + + def test_tenant_whitelist(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test that whitelisted tenants' messages are not deleted (B9).""" + # Arrange - Create 3 sandbox tenants with expired messages + tenants_data = [] + for i in range(3): + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg = self._create_message(app, conv, created_at=expired_date, with_relations=False) + + tenants_data.append( + { + "tenant": tenant, + "message_id": msg.id, + } + ) + + # Mock billing service - all tenants are sandbox with no subscription + plan_map = { + tenants_data[0]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + tenants_data[1]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + tenants_data[2]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + } + + 
# Setup whitelist - tenant0 and tenant1 are whitelisted, tenant2 is not + whitelist = [tenants_data[0]["tenant"].id, tenants_data[1]["tenant"].id] + mock_whitelist.return_value = whitelist + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - Only tenant2's message should be deleted (not whitelisted) + assert stats["total_messages"] == 3 # 3 tenants * 1 message + assert stats["filtered_messages"] == 1 + assert stats["total_deleted"] == 1 + + # Verify tenant0's message still exists (whitelisted) + assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 + + # Verify tenant1's message still exists (whitelisted) + assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + + # Verify tenant2's message was deleted (not whitelisted) + assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 + + def test_from_days_cleans_old_messages(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test from_days correctly cleans messages older than N days (B11).""" + # Arrange + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create old messages (should be deleted - older than 30 days) + old_date = datetime.datetime.now() - datetime.timedelta(days=45) + old_msg_ids = [] + for i in range(3): + msg = self._create_message( + app, conv, created_at=old_date - datetime.timedelta(hours=i), 
with_relations=False + ) + old_msg_ids.append(msg.id) + + # Create recent messages (should be kept - newer than 30 days) + recent_date = datetime.datetime.now() - datetime.timedelta(days=15) + recent_msg_ids = [] + for i in range(2): + msg = self._create_message( + app, conv, created_at=recent_date - datetime.timedelta(hours=i), with_relations=False + ) + recent_msg_ids.append(msg.id) + + db.session.commit() + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + } + } + + # Act - Use from_days to clean messages older than 30 days + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_days( + policy=policy, + days=30, + batch_size=100, + ) + stats = service.run() + + # Assert + assert stats["total_messages"] == 3 # Only old messages in range + assert stats["filtered_messages"] == 3 # Only old messages + assert stats["total_deleted"] == 3 + + # Old messages deleted + assert db.session.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0 + # Recent messages kept + assert db.session.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2 + + def test_whitelist_precedence_over_grace_period( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test that whitelist takes precedence over grace period logic.""" + # Arrange - Create 2 sandbox tenants + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + # Tenant1: whitelisted, expired beyond grace period + account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app1 = self._create_app(tenant1, account1) + conv1 = self._create_conversation(app1) + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + expired_30_days_ago = now_timestamp - (30 * 24 * 
60 * 60) # Well beyond 21-day grace + + # Tenant2: not whitelisted, within grace period + account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app2 = self._create_app(tenant2, account2) + conv2 = self._create_conversation(app2) + msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace + + # Mock billing service + plan_map = { + tenant1.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_30_days_ago, # Beyond grace period + }, + tenant2.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_10_days_ago, # Within grace period + }, + } + + # Setup whitelist - only tenant1 is whitelisted + whitelist = [tenant1.id] + mock_whitelist.return_value = whitelist + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - No messages should be deleted + # tenant1: whitelisted (protected even though beyond grace period) + # tenant2: within grace period (not eligible for deletion) + assert stats["total_messages"] == 2 # 2 tenants * 1 message + assert stats["filtered_messages"] == 0 + assert stats["total_deleted"] == 0 + + # Verify both messages still exist + assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted + assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period + + def test_empty_whitelist_deletes_eligible_messages( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test that empty 
whitelist behaves as no whitelist (all eligible messages deleted).""" + # Arrange - Create sandbox tenant with expired messages + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg_ids = [] + for i in range(3): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + msg_ids.append(msg.id) + + # Mock billing service + plan_map = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + } + } + + # Setup empty whitelist (default behavior from fixture) + mock_whitelist.return_value = [] + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - All messages should be deleted (no whitelist protection) + assert stats["total_messages"] == 3 + assert stats["filtered_messages"] == 3 + assert stats["total_deleted"] == 3 + + # Verify all messages were deleted + assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0 diff --git a/api/tests/unit_tests/services/test_messages_clean_service.py b/api/tests/unit_tests/services/test_messages_clean_service.py new file mode 100644 index 0000000000..3b619195c7 --- /dev/null +++ b/api/tests/unit_tests/services/test_messages_clean_service.py @@ -0,0 +1,627 @@ +import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from enums.cloud_plan import CloudPlan +from services.retention.conversation.messages_clean_policy import ( + 
BillingDisabledPolicy, + BillingSandboxPolicy, + SimpleMessage, + create_message_clean_policy, +) +from services.retention.conversation.messages_clean_service import MessagesCleanService + + +def make_simple_message(msg_id: str, app_id: str) -> SimpleMessage: + """Helper to create a SimpleMessage with a fixed created_at timestamp.""" + return SimpleMessage(id=msg_id, app_id=app_id, created_at=datetime.datetime(2024, 1, 1)) + + +def make_plan_provider(tenant_plans: dict) -> MagicMock: + """Helper to create a mock plan_provider that returns the given tenant_plans.""" + provider = MagicMock() + provider.return_value = tenant_plans + return provider + + +class TestBillingSandboxPolicyFilterMessageIds: + """Unit tests for BillingSandboxPolicy.filter_message_ids method.""" + + # Fixed timestamp for deterministic tests + CURRENT_TIMESTAMP = 1000000 + GRACEFUL_PERIOD_DAYS = 8 + GRACEFUL_PERIOD_SECONDS = GRACEFUL_PERIOD_DAYS * 24 * 60 * 60 + + def test_missing_tenant_mapping_excluded(self): + """Test that messages with missing app-to-tenant mapping are excluded.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {} # No mapping + tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}} + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + def test_missing_tenant_plan_excluded(self): + """Test that messages with missing tenant plan are excluded (safe default).""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + tenant_plans = {} # No plans + plan_provider = 
make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + def test_non_sandbox_plan_excluded(self): + """Test that messages from non-sandbox plans (PROFESSIONAL/TEAM) are excluded.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.TEAM, "expiration_date": -1}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, # Only this one + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - only msg3 (sandbox tenant) should be included + assert set(result) == {"msg3"} + + def test_whitelist_skip(self): + """Test that whitelisted tenants are excluded even if sandbox + expired.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), # Whitelisted - excluded + make_simple_message("msg2", "app2"), # Not whitelisted - included + make_simple_message("msg3", "app3"), # Whitelisted - excluded + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + plan_provider = make_plan_provider(tenant_plans) + tenant_whitelist = ["tenant1", 
"tenant3"] + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + tenant_whitelist=tenant_whitelist, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - only msg2 should be included + assert set(result) == {"msg2"} + + def test_no_previous_subscription_included(self): + """Test that messages with expiration_date=-1 (no previous subscription) are included.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all messages should be included + assert set(result) == {"msg1", "msg2"} + + def test_within_grace_period_excluded(self): + """Test that messages within grace period are excluded.""" + # Arrange + now = self.CURRENT_TIMESTAMP + expired_1_day_ago = now - (1 * 24 * 60 * 60) + expired_5_days_ago = now - (5 * 24 * 60 * 60) + expired_7_days_ago = now - (7 * 24 * 60 * 60) + + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_1_day_ago}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_5_days_ago}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_7_days_ago}, + } + plan_provider = 
make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, # 8 days + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all within 8-day grace period, none should be included + assert list(result) == [] + + def test_exactly_at_boundary_excluded(self): + """Test that messages exactly at grace period boundary are excluded (code uses >).""" + # Arrange + now = self.CURRENT_TIMESTAMP + expired_exactly_8_days_ago = now - self.GRACEFUL_PERIOD_SECONDS # Exactly at boundary + + messages = [make_simple_message("msg1", "app1")] + app_to_tenant = {"app1": "tenant1"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_exactly_8_days_ago}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - exactly at boundary (==) should be excluded (code uses >) + assert list(result) == [] + + def test_beyond_grace_period_included(self): + """Test that messages beyond grace period are included.""" + # Arrange + now = self.CURRENT_TIMESTAMP + expired_9_days_ago = now - (9 * 24 * 60 * 60) # Just beyond 8-day grace + expired_30_days_ago = now - (30 * 24 * 60 * 60) # Well beyond + + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_9_days_ago}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_30_days_ago}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + 
current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - both beyond grace period, should be included + assert set(result) == {"msg1", "msg2"} + + def test_empty_messages_returns_empty(self): + """Test that empty messages returns empty list.""" + # Arrange + messages: list[SimpleMessage] = [] + app_to_tenant = {"app1": "tenant1"} + plan_provider = make_plan_provider({"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + def test_plan_provider_called_with_correct_tenant_ids(self): + """Test that plan_provider is called with correct tenant_ids.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant1"} # tenant1 appears twice + plan_provider = make_plan_provider({}) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + policy.filter_message_ids(messages, app_to_tenant) + + # Assert - plan_provider should be called once with unique tenant_ids + plan_provider.assert_called_once() + called_tenant_ids = set(plan_provider.call_args[0][0]) + assert called_tenant_ids == {"tenant1", "tenant2"} + + def test_complex_mixed_scenario(self): + """Test complex scenario with mixed plans, expirations, whitelist, and missing mappings.""" + # Arrange + now = self.CURRENT_TIMESTAMP + sandbox_expired_old = now - (15 * 24 * 60 * 60) # Beyond grace + sandbox_expired_recent = now - (3 * 24 * 60 * 60) # Within grace + future_expiration = now + (30 * 24 * 60 * 60) + 
+ messages = [ + make_simple_message("msg1", "app1"), # Sandbox, no subscription - included + make_simple_message("msg2", "app2"), # Sandbox, expired old - included + make_simple_message("msg3", "app3"), # Sandbox, within grace - excluded + make_simple_message("msg4", "app4"), # Team plan, active - excluded + make_simple_message("msg5", "app5"), # No tenant mapping - excluded + make_simple_message("msg6", "app6"), # No plan info - excluded + make_simple_message("msg7", "app7"), # Sandbox, expired old, whitelisted - excluded + ] + app_to_tenant = { + "app1": "tenant1", + "app2": "tenant2", + "app3": "tenant3", + "app4": "tenant4", + "app6": "tenant6", # Has mapping but no plan + "app7": "tenant7", + # app5 has no mapping + } + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_recent}, + "tenant4": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration}, + "tenant7": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old}, + # tenant6 has no plan + } + plan_provider = make_plan_provider(tenant_plans) + tenant_whitelist = ["tenant7"] + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + tenant_whitelist=tenant_whitelist, + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - only msg1 and msg2 should be included + assert set(result) == {"msg1", "msg2"} + + +class TestBillingDisabledPolicyFilterMessageIds: + """Unit tests for BillingDisabledPolicy.filter_message_ids method.""" + + def test_returns_all_message_ids(self): + """Test that all message IDs are returned (order-preserving).""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + 
app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + + policy = BillingDisabledPolicy() + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all message IDs returned in order + assert list(result) == ["msg1", "msg2", "msg3"] + + def test_ignores_app_to_tenant(self): + """Test that app_to_tenant mapping is ignored.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant: dict[str, str] = {} # Empty - should be ignored + + policy = BillingDisabledPolicy() + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all message IDs still returned + assert list(result) == ["msg1", "msg2"] + + def test_empty_messages_returns_empty(self): + """Test that empty messages returns empty list.""" + # Arrange + messages: list[SimpleMessage] = [] + app_to_tenant = {"app1": "tenant1"} + + policy = BillingDisabledPolicy() + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + +class TestCreateMessageCleanPolicy: + """Unit tests for create_message_clean_policy factory function.""" + + @patch("services.retention.conversation.messages_clean_policy.dify_config") + def test_billing_disabled_returns_billing_disabled_policy(self, mock_config): + """Test that BILLING_ENABLED=False returns BillingDisabledPolicy.""" + # Arrange + mock_config.BILLING_ENABLED = False + + # Act + policy = create_message_clean_policy(graceful_period_days=21) + + # Assert + assert isinstance(policy, BillingDisabledPolicy) + + @patch("services.retention.conversation.messages_clean_policy.BillingService") + @patch("services.retention.conversation.messages_clean_policy.dify_config") + def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service): + """Test that BillingSandboxPolicy is created with correct internal values.""" + # Arrange + mock_config.BILLING_ENABLED = True + whitelist = 
["tenant1", "tenant2"] + mock_billing_service.get_expired_subscription_cleanup_whitelist.return_value = whitelist + mock_plan_provider = MagicMock() + mock_billing_service.get_plan_bulk_with_cache = mock_plan_provider + + # Act + policy = create_message_clean_policy(graceful_period_days=14, current_timestamp=1234567) + + # Assert + mock_billing_service.get_expired_subscription_cleanup_whitelist.assert_called_once() + assert isinstance(policy, BillingSandboxPolicy) + assert policy._graceful_period_days == 14 + assert list(policy._tenant_whitelist) == whitelist + assert policy._plan_provider == mock_plan_provider + assert policy._current_timestamp == 1234567 + + +class TestMessagesCleanServiceFromTimeRange: + """Unit tests for MessagesCleanService.from_time_range factory method.""" + + def test_start_from_end_before_raises_value_error(self): + """Test that start_from == end_before raises ValueError.""" + policy = BillingDisabledPolicy() + + # Arrange + same_time = datetime.datetime(2024, 1, 1, 12, 0, 0) + + # Act & Assert + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=same_time, + end_before=same_time, + ) + + # Arrange + start_from = datetime.datetime(2024, 12, 31) + end_before = datetime.datetime(2024, 1, 1) + + # Act & Assert + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + ) + + def test_batch_size_raises_value_error(self): + """Test that batch_size=0 raises ValueError.""" + # Arrange + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 2, 1) + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + 
batch_size=0, + ) + + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 2, 1) + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=-100, + ) + + def test_valid_params_creates_instance(self): + """Test that valid parameters create a correctly configured instance.""" + # Arrange + start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 12, 31, 23, 59, 59) + policy = BillingDisabledPolicy() + batch_size = 500 + dry_run = True + + # Act + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=batch_size, + dry_run=dry_run, + ) + + # Assert + assert isinstance(service, MessagesCleanService) + assert service._policy is policy + assert service._start_from == start_from + assert service._end_before == end_before + assert service._batch_size == batch_size + assert service._dry_run == dry_run + + def test_default_params(self): + """Test that default parameters are applied correctly.""" + # Arrange + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 2, 1) + policy = BillingDisabledPolicy() + + # Act + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + ) + + # Assert + assert service._batch_size == 1000 # default + assert service._dry_run is False # default + + +class TestMessagesCleanServiceFromDays: + """Unit tests for MessagesCleanService.from_days factory method.""" + + def test_days_raises_value_error(self): + """Test that days < 0 raises ValueError.""" + # Arrange + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): + 
MessagesCleanService.from_days(policy=policy, days=-1) + + # Act + with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0) + mock_datetime.datetime.now.return_value = fixed_now + mock_datetime.timedelta = datetime.timedelta + + service = MessagesCleanService.from_days(policy=policy, days=0) + + # Assert + assert service._end_before == fixed_now + + def test_batch_size_raises_value_error(self): + """Test that batch_size=0 raises ValueError.""" + # Arrange + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(policy=policy, days=30, batch_size=0) + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(policy=policy, days=30, batch_size=-500) + + def test_valid_params_creates_instance(self): + """Test that valid parameters create a correctly configured instance.""" + # Arrange + policy = BillingDisabledPolicy() + days = 90 + batch_size = 500 + dry_run = True + + # Act + with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) + mock_datetime.datetime.now.return_value = fixed_now + mock_datetime.timedelta = datetime.timedelta + + service = MessagesCleanService.from_days( + policy=policy, + days=days, + batch_size=batch_size, + dry_run=dry_run, + ) + + # Assert + expected_end_before = fixed_now - datetime.timedelta(days=days) + assert isinstance(service, MessagesCleanService) + assert service._policy is policy + assert service._start_from is None + assert service._end_before == expected_end_before + assert service._batch_size == batch_size + assert service._dry_run == dry_run + + def test_default_params(self): + """Test that default parameters are applied correctly.""" + # Arrange + policy = 
BillingDisabledPolicy() + + # Act + with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) + mock_datetime.datetime.now.return_value = fixed_now + mock_datetime.timedelta = datetime.timedelta + + service = MessagesCleanService.from_days(policy=policy) + + # Assert + expected_end_before = fixed_now - datetime.timedelta(days=30) # default days=30 + assert service._end_before == expected_end_before + assert service._batch_size == 1000 # default + assert service._dry_run is False # default From 33e99f069bec8dfe6338597e47045bac7975d511 Mon Sep 17 00:00:00 2001 From: hj24 Date: Thu, 15 Jan 2026 15:13:25 +0800 Subject: [PATCH 5/8] fix: message clean service ut (#31038) --- .../services/test_messages_clean_service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 29baa4d94f..5b6db64c09 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -271,6 +271,7 @@ class TestMessagesCleanServiceIntegration: source="annotation", question="Test question", account_id=message.from_account_id, + score=0.9, annotation_question="Test annotation question", annotation_content="Test annotation content", ) From ab1c5a202737b997878f7b6451543e9d768aa52f Mon Sep 17 00:00:00 2001 From: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Date: Thu, 15 Jan 2026 15:25:43 +0800 Subject: [PATCH 6/8] refactor: remove manual set query logic (#31039) --- web/app/components/apps/list.tsx | 30 ------------------------- web/app/components/header/nav/index.tsx | 14 +++--------- 2 files changed, 3 insertions(+), 41 deletions(-) diff --git a/web/app/components/apps/list.tsx 
b/web/app/components/apps/list.tsx index 095ed3f696..84150ad480 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -12,7 +12,6 @@ import { useDebounceFn } from 'ahooks' import dynamic from 'next/dynamic' import { useRouter, - useSearchParams, } from 'next/navigation' import { parseAsString, useQueryState } from 'nuqs' import { useCallback, useEffect, useRef, useState } from 'react' @@ -29,7 +28,6 @@ import { CheckModal } from '@/hooks/use-pay' import { useInfiniteAppList } from '@/service/use-apps' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' -import { isServer } from '@/utils/client' import AppCard from './app-card' import { AppCardSkeleton } from './app-card-skeleton' import Empty from './empty' @@ -59,7 +57,6 @@ const List = () => { const { t } = useTranslation() const { systemFeatures } = useGlobalPublicStore() const router = useRouter() - const searchParams = useSearchParams() const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext() const showTagManagementModal = useTagStore(s => s.showTagManagementModal) const [activeTab, setActiveTab] = useQueryState( @@ -67,33 +64,6 @@ const List = () => { parseAsString.withDefault('all').withOptions({ history: 'push' }), ) - // valid tabs for apps list; anything else should fallback to 'all' - - // 1) Normalize legacy/incorrect query params like ?mode=discover -> ?category=all - useEffect(() => { - // avoid running on server - if (isServer) - return - const mode = searchParams.get('mode') - if (!mode) - return - const url = new URL(window.location.href) - url.searchParams.delete('mode') - if (validTabs.has(mode)) { - // migrate to category key - url.searchParams.set('category', mode) - } - else { - url.searchParams.set('category', 'all') - } - router.replace(url.pathname + url.search) - }, [router, searchParams]) - - // 2) If category has an invalid value (e.g., 'discover'), reset to 'all' - 
useEffect(() => { - if (!validTabs.has(activeTab)) - setActiveTab('all') - }, [activeTab, setActiveTab]) const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState() const [isCreatedByMe, setIsCreatedByMe] = useState(queryIsCreatedByMe) const [tagFilterValue, setTagFilterValue] = useState(tagIDs) diff --git a/web/app/components/header/nav/index.tsx b/web/app/components/header/nav/index.tsx index 83e75b8513..2edc64486e 100644 --- a/web/app/components/header/nav/index.tsx +++ b/web/app/components/header/nav/index.tsx @@ -2,9 +2,9 @@ import type { INavSelectorProps } from './nav-selector' import Link from 'next/link' -import { usePathname, useSearchParams, useSelectedLayoutSegment } from 'next/navigation' +import { useSelectedLayoutSegment } from 'next/navigation' import * as React from 'react' -import { useEffect, useState } from 'react' +import { useState } from 'react' import { useStore as useAppStore } from '@/app/components/app/store' import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows' import { cn } from '@/utils/classnames' @@ -36,14 +36,6 @@ const Nav = ({ const [hovered, setHovered] = useState(false) const segment = useSelectedLayoutSegment() const isActivated = Array.isArray(activeSegment) ? activeSegment.includes(segment!) : segment === activeSegment - const pathname = usePathname() - const searchParams = useSearchParams() - const [linkLastSearchParams, setLinkLastSearchParams] = useState('') - - useEffect(() => { - if (pathname === link) - setLinkLastSearchParams(searchParams.toString()) - }, [pathname, searchParams]) return (
- +
{ // Don't clear state if opening in new tab/window From 772ff636ec92b5b35635769a17ab0e49c3576df9 Mon Sep 17 00:00:00 2001 From: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> Date: Wed, 14 Jan 2026 23:33:24 -0800 Subject: [PATCH 7/8] feat: credential sync fix for enterprise edition (#30626) --- api/events/event_handlers/__init__.py | 2 + ...eue_credential_sync_when_tenant_created.py | 19 ++++++ api/services/enterprise/workspace_sync.py | 58 +++++++++++++++++++ 3 files changed, 79 insertions(+) create mode 100644 api/events/event_handlers/queue_credential_sync_when_tenant_created.py create mode 100644 api/services/enterprise/workspace_sync.py diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index c79764983b..d37217e168 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle as handle_create_site_re from .delete_tool_parameters_cache_when_sync_draft_workflow import ( handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow, ) +from .queue_credential_sync_when_tenant_created import handle as handle_queue_credential_sync_when_tenant_created from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published @@ -30,6 +31,7 @@ __all__ = [ "handle_create_installed_app_when_app_created", "handle_create_site_record_when_app_created", "handle_delete_tool_parameters_cache_when_sync_draft_workflow", + "handle_queue_credential_sync_when_tenant_created", "handle_sync_plugin_trigger_when_app_created", "handle_sync_webhook_when_app_created", "handle_sync_workflow_schedule_when_app_published", diff --git 
a/api/events/event_handlers/queue_credential_sync_when_tenant_created.py b/api/events/event_handlers/queue_credential_sync_when_tenant_created.py new file mode 100644 index 0000000000..6566c214b0 --- /dev/null +++ b/api/events/event_handlers/queue_credential_sync_when_tenant_created.py @@ -0,0 +1,19 @@ +from configs import dify_config +from events.tenant_event import tenant_was_created +from services.enterprise.workspace_sync import WorkspaceSyncService + + +@tenant_was_created.connect +def handle(sender, **kwargs): + """Queue credential sync when a tenant/workspace is created.""" + # Only queue sync tasks if plugin manager (enterprise feature) is enabled + if not dify_config.ENTERPRISE_ENABLED: + return + + tenant = sender + + # Determine source from kwargs if available, otherwise use generic + source = kwargs.get("source", "tenant_created") + + # Queue credential sync task to Redis for enterprise backend to process + WorkspaceSyncService.queue_credential_sync(tenant.id, source=source) diff --git a/api/services/enterprise/workspace_sync.py b/api/services/enterprise/workspace_sync.py new file mode 100644 index 0000000000..acfe325397 --- /dev/null +++ b/api/services/enterprise/workspace_sync.py @@ -0,0 +1,58 @@ +import json +import logging +import uuid +from datetime import UTC, datetime + +from redis import RedisError + +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + +WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue" +WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing" + + +class WorkspaceSyncService: + """Service to publish workspace sync tasks to Redis queue for enterprise backend consumption""" + + @staticmethod + def queue_credential_sync(workspace_id: str, *, source: str) -> bool: + """ + Queue a credential sync task for a newly created workspace. + + This publishes a task to Redis that will be consumed by the enterprise backend + worker to sync credentials with the plugin-manager. 
+ + Args: + workspace_id: The workspace/tenant ID to sync credentials for + source: Source of the sync request (for debugging/tracking) + + Returns: + bool: True if task was queued successfully, False otherwise + """ + try: + task = { + "task_id": str(uuid.uuid4()), + "workspace_id": workspace_id, + "retry_count": 0, + "created_at": datetime.now(UTC).isoformat(), + "source": source, + } + + # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP + redis_client.lpush(WORKSPACE_SYNC_QUEUE, json.dumps(task)) + + logger.info( + "Queued credential sync task for workspace %s, task_id: %s, source: %s", + workspace_id, + task["task_id"], + source, + ) + return True + + except (RedisError, TypeError) as e: + logger.error("Failed to queue credential sync for workspace %s: %s", workspace_id, str(e), exc_info=True) + # Don't raise - we don't want to fail workspace creation if queueing fails + # The scheduled task will catch it later + return False From 4a197b94585a25c0f79bce4227b7ef3f89680ad0 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Thu, 15 Jan 2026 15:42:46 +0800 Subject: [PATCH 8/8] fix: fix log updated_at is refreshed (#31045)