diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 0e9a66c579..654991900d 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -3,12 +3,12 @@ from werkzeug.exceptions import Unauthorized from controllers.common.schema import register_response_schema_models from libs.login import current_account_with_tenant, current_user, login_required -from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel +from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel from . import console_ns from .wraps import account_initialization_required, cloud_utm_record, setup_required -register_response_schema_models(console_ns, FeatureModel, SystemFeatureModel) +register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel) @console_ns.route("/features") @@ -28,7 +28,32 @@ class FeatureApi(Resource): """Get feature configuration for current tenant""" _, current_tenant_id = current_account_with_tenant() - return FeatureService.get_features(current_tenant_id).model_dump() + payload = FeatureService.get_features( + current_tenant_id, + exclude_vector_space=True, + ).model_dump() + payload.pop("vector_space", None) + return payload + + +@console_ns.route("/features/vector-space") +class FeatureVectorSpaceApi(Resource): + @console_ns.doc("get_tenant_feature_vector_space") + @console_ns.doc(description="Get vector-space usage and limit for current tenant") + @console_ns.response( + 200, + "Success", + console_ns.models[LimitationModel.__name__], + ) + @setup_required + @login_required + @account_initialization_required + @cloud_utm_record + def get(self): + """Get vector-space usage and limit for current tenant""" + _, current_tenant_id = current_account_with_tenant() + + return FeatureService.get_vector_space(current_tenant_id).model_dump() @console_ns.route("/system-features") diff --git a/api/openapi/markdown/console-swagger.md b/api/openapi/markdown/console-swagger.md index b270d970be..8536cc93ae 100644 --- a/api/openapi/markdown/console-swagger.md +++ b/api/openapi/markdown/console-swagger.md @@ -5693,6 +5693,23 @@ Get feature configuration for current tenant | ---- | ----------- | ------ | | 200 | Success | [FeatureModel](#featuremodel) | +### /features/vector-space + +#### GET +##### Summary + +Get vector-space usage and limit for current tenant + +##### Description + +Get vector-space usage and limit for current tenant + +##### Responses + +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Success | [LimitationModel](#limitationmodel) | + ### /files/support-type #### GET diff --git a/api/services/billing_service.py b/api/services/billing_service.py index c0e23cdc6f..6021d46c72 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -116,7 +116,7 @@ class BillingInfo(TypedDict): subscription: _BillingSubscription members: _BillingQuota apps: _BillingQuota - vector_space: _VectorSpaceQuota + vector_space: NotRequired[_VectorSpaceQuota] knowledge_rate_limit: _KnowledgeRateLimit documents_upload_quota: _BillingQuota annotation_quota_limit: _BillingQuota @@ -128,6 +128,7 @@ class BillingInfo(TypedDict): _billing_info_adapter = TypeAdapter(BillingInfo) +_vector_space_quota_adapter = TypeAdapter(_VectorSpaceQuota) class KnowledgeRateLimitDict(TypedDict): @@ -185,12 +186,21 @@ class BillingService: _PLAN_CACHE_TTL = 600 @classmethod - def get_info(cls, tenant_id: str) -> BillingInfo: + def get_info(cls, tenant_id: str, exclude_vector_space: bool = False) -> BillingInfo: params = {"tenant_id": tenant_id} + if exclude_vector_space: + params["exclude_vector_space"] = "true" billing_info = cls._send_request("GET", "/subscription/info", params=params) return _billing_info_adapter.validate_python(billing_info) + @classmethod + def get_vector_space(cls, tenant_id: str) -> _VectorSpaceQuota: + params = {"tenant_id": tenant_id} + return _vector_space_quota_adapter.validate_python( + cls._send_request("GET", "/subscription/vector-space", params=params) + ) + @classmethod def get_tenant_feature_plan_usage_info(cls, tenant_id: str): """Deprecated: Use get_quota_info instead.""" diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 461cc8b30c..c4f723fde8 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -6,7 +6,7 @@ from configs import dify_config from constants.dsl_version import CURRENT_APP_DSL_VERSION from enums.cloud_plan import CloudPlan from enums.hosted_provider import HostedTrialProvider -from services.billing_service import BillingService +from services.billing_service import BillingInfo, BillingService from services.enterprise.enterprise_service import EnterpriseService @@ -186,13 +186,17 @@ class SystemFeatureModel(FeatureResponseModel): class FeatureService: @classmethod - def get_features(cls, tenant_id: str) -> FeatureModel: + def get_features(cls, tenant_id: str, exclude_vector_space: bool = False) -> FeatureModel: features = FeatureModel() cls._fulfill_params_from_env(features) if dify_config.BILLING_ENABLED and tenant_id: - cls._fulfill_params_from_billing_api(features, tenant_id) + cls._fulfill_params_from_billing_api( + features, + tenant_id, + exclude_vector_space=exclude_vector_space, + ) if dify_config.ENTERPRISE_ENABLED: features.webapp_copyright_enabled = True @@ -206,6 +210,18 @@ class FeatureService: return features + @classmethod + def get_vector_space(cls, tenant_id: str) -> LimitationModel: + vector_space = LimitationModel(size=0, limit=5) + if dify_config.BILLING_ENABLED and tenant_id: + billing_vector_space = BillingService.get_vector_space(tenant_id) + # NOTE: billing API returns vector_space.size as float (e.g. 0.0), + # but feature API keeps LimitationModel.size as int for compatibility. + vector_space.size = int(billing_vector_space["size"]) + vector_space.limit = billing_vector_space["limit"] + + return vector_space + @classmethod def get_knowledge_rate_limit(cls, tenant_id: str): knowledge_rate_limit = KnowledgeRateLimitModel() @@ -289,8 +305,16 @@ class FeatureService: features.workspace_members.enabled = workspace_info["WorkspaceMembers"]["enabled"] @classmethod - def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): - billing_info = BillingService.get_info(tenant_id) + def _fulfill_params_from_billing_api( + cls, + features: FeatureModel, + tenant_id: str, + exclude_vector_space: bool = False, + ): + if exclude_vector_space: + billing_info = BillingService.get_info(tenant_id, exclude_vector_space=True) + else: + billing_info = BillingService.get_info(tenant_id) features_usage_info = BillingService.get_quota_info(tenant_id) @@ -322,12 +346,8 @@ class FeatureService: features.apps.size = billing_info["apps"]["size"] features.apps.limit = billing_info["apps"]["limit"] - if "vector_space" in billing_info: - # NOTE (hj24): billing API returns vector_space.size as float (e.g. 0.0) - # but LimitationModel.size is int; truncate here for compatibility - features.vector_space.size = int(billing_info["vector_space"]["size"]) - # NOTE END - features.vector_space.limit = billing_info["vector_space"]["limit"] + if not exclude_vector_space: + cls._fulfill_vector_space_from_billing_info(features.vector_space, billing_info) if "documents_upload_quota" in billing_info: features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"] @@ -359,6 +379,16 @@ class FeatureService: if "next_credit_reset_date" in billing_info: features.next_credit_reset_date = billing_info["next_credit_reset_date"] + @classmethod + def _fulfill_vector_space_from_billing_info(cls, vector_space: LimitationModel, billing_info: BillingInfo): + if "vector_space" not in billing_info: + return + + # NOTE: billing API returns vector_space.size as float (e.g. 0.0), + # but feature API keeps LimitationModel.size as int for compatibility. + vector_space.size = int(billing_info["vector_space"]["size"]) + vector_space.limit = billing_info["vector_space"]["limit"] + @classmethod def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel, is_authenticated: bool = False): enterprise_info = EnterpriseService.get_info() diff --git a/api/tests/unit_tests/controllers/console/test_feature.py b/api/tests/unit_tests/controllers/console/test_feature.py index 1711aede61..0339c50777 100644 --- a/api/tests/unit_tests/controllers/console/test_feature.py +++ b/api/tests/unit_tests/controllers/console/test_feature.py @@ -20,8 +20,10 @@ class TestFeatureApi: return_value=("account_id", "tenant_123"), ) - mocker.patch("controllers.console.feature.FeatureService.get_features").return_value.model_dump.return_value = { - "features": {"feature_a": True} + get_features = mocker.patch("controllers.console.feature.FeatureService.get_features") + get_features.return_value.model_dump.return_value = { + "features": {"feature_a": True}, + "vector_space": {"size": 1, "limit": 2}, } api = FeatureApi() @@ -30,6 +32,28 @@ class TestFeatureApi: result = raw_get(api) assert result == {"features": {"feature_a": True}} + get_features.assert_called_once_with("tenant_123", exclude_vector_space=True) + + +class TestFeatureVectorSpaceApi: + def test_get_vector_space_success(self, mocker: MockerFixture): + from controllers.console.feature import FeatureVectorSpaceApi + + mocker.patch( + "controllers.console.feature.current_account_with_tenant", + return_value=("account_id", "tenant_123"), + ) + + get_vector_space = mocker.patch("controllers.console.feature.FeatureService.get_vector_space") + get_vector_space.return_value.model_dump.return_value = {"size": 5120, "limit": 20480} + + api = FeatureVectorSpaceApi() + + raw_get = unwrap(FeatureVectorSpaceApi.get) + result = raw_get(api) + + assert result == {"size": 5120, "limit": 20480} + get_vector_space.assert_called_once_with("tenant_123") class TestSystemFeatureApi: diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 36592196c6..e7a195a472 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -313,6 +313,54 @@ class TestBillingServiceSubscriptionInfo: assert result == expected_response mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": tenant_id}) + def test_get_info_exclude_vector_space(self, mock_send_request): + """When requested, get_info asks billing to skip vector_space.""" + # Arrange + tenant_id = "tenant-123" + expected_response = { + "enabled": True, + "subscription": {"plan": "professional", "interval": "month", "education": False}, + "members": {"size": 1, "limit": 50}, + "apps": {"size": 1, "limit": 200}, + "knowledge_rate_limit": {"limit": 1000}, + "documents_upload_quota": {"size": 0, "limit": 1000}, + "annotation_quota_limit": {"size": 0, "limit": 5000}, + "docs_processing": "top-priority", + "can_replace_logo": True, + "model_load_balancing_enabled": True, + "knowledge_pipeline_publish_enabled": True, + } + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_info(tenant_id, exclude_vector_space=True) + + # Assert + assert "vector_space" not in result + mock_send_request.assert_called_once_with( + "GET", + "/subscription/info", + params={"tenant_id": tenant_id, "exclude_vector_space": "true"}, + ) + + def test_get_vector_space_success(self, mock_send_request): + """Test successful retrieval of vector-space usage and limit.""" + # Arrange + tenant_id = "tenant-123" + expected_response = {"size": 5120.75, "limit": 20480} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_vector_space(tenant_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "GET", + "/subscription/vector-space", + params={"tenant_id": tenant_id}, + ) + def test_get_knowledge_rate_limit_with_defaults(self, mock_send_request): """Test knowledge rate limit retrieval with default values.""" # Arrange @@ -1744,8 +1792,9 @@ class TestBillingServiceSubscriptionInfoDataType: assert isinstance(result["apps"]["size"], int) assert isinstance(result["apps"]["limit"], int) - assert isinstance(result["vector_space"]["size"], float) - assert isinstance(result["vector_space"]["limit"], int) + if "vector_space" in result: + assert isinstance(result["vector_space"]["size"], float) + assert isinstance(result["vector_space"]["limit"], int) assert isinstance(result["knowledge_rate_limit"]["limit"], int) @@ -1783,11 +1832,13 @@ class TestBillingServiceSubscriptionInfoDataType: def test_get_info_without_optional_fields(self, mock_send_request, string_billing_response): """NotRequired fields can be absent without raising.""" del string_billing_response["next_credit_reset_date"] + del string_billing_response["vector_space"] mock_send_request.return_value = string_billing_response result = BillingService.get_info("tenant-type-test") assert "next_credit_reset_date" not in result + assert "vector_space" not in result self._assert_billing_info_types(result) def test_get_info_with_extra_fields(self, mock_send_request, string_billing_response): diff --git a/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py b/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py index ab141a7b2d..8614d351f1 100644 --- a/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py +++ b/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py @@ -102,3 +102,17 @@ def test_resolve_human_input_email_delivery_enabled_matrix( ) assert result is case.expected + + +def test_get_vector_space_converts_billing_float_size(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(feature_service_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + feature_service_module.BillingService, + "get_vector_space", + lambda tenant_id: {"size": 5120.75, "limit": 20480}, + ) + + result = FeatureService.get_vector_space("tenant-1") + + assert result.size == 5120 + assert result.limit == 20480 diff --git a/packages/contracts/generated/api/console/features/orpc.gen.ts b/packages/contracts/generated/api/console/features/orpc.gen.ts index e24ec3d964..3463ccb015 100644 --- a/packages/contracts/generated/api/console/features/orpc.gen.ts +++ b/packages/contracts/generated/api/console/features/orpc.gen.ts @@ -2,14 +2,35 @@ import { oc } from '@orpc/contract' -import { zGetFeaturesResponse } from './zod.gen' +import { zGetFeaturesResponse, zGetFeaturesVectorSpaceResponse } from './zod.gen' + +/** + * Get vector-space usage and limit for current tenant + * + * Get vector-space usage and limit for current tenant + */ +export const get = oc + .route({ + description: 'Get vector-space usage and limit for current tenant', + inputStructure: 'detailed', + method: 'GET', + operationId: 'getFeaturesVectorSpace', + path: '/features/vector-space', + summary: 'Get vector-space usage and limit for current tenant', + tags: ['console'], + }) + .output(zGetFeaturesVectorSpaceResponse) + +export const vectorSpace = { + get, +} /** * Get feature configuration for current tenant * * Get feature configuration for current tenant */ -export const get = oc +export const get2 = oc .route({ description: 'Get feature configuration for current tenant', inputStructure: 'detailed', @@ -22,7 +43,8 @@ export const get = oc .output(zGetFeaturesResponse) export const features = { - get, + get: get2, + vectorSpace, } export const contract = { diff --git a/packages/contracts/generated/api/console/features/types.gen.ts b/packages/contracts/generated/api/console/features/types.gen.ts index 411e062afb..68b2dc0d9e 100644 --- a/packages/contracts/generated/api/console/features/types.gen.ts +++ b/packages/contracts/generated/api/console/features/types.gen.ts @@ -75,3 +75,17 @@ export type GetFeaturesResponses = { } export type GetFeaturesResponse = GetFeaturesResponses[keyof GetFeaturesResponses] + +export type GetFeaturesVectorSpaceData = { + body?: never + path?: never + query?: never + url: '/features/vector-space' +} + +export type GetFeaturesVectorSpaceResponses = { + 200: LimitationModel +} + +export type GetFeaturesVectorSpaceResponse + = GetFeaturesVectorSpaceResponses[keyof GetFeaturesVectorSpaceResponses] diff --git a/packages/contracts/generated/api/console/features/zod.gen.ts b/packages/contracts/generated/api/console/features/zod.gen.ts index 9ace83a433..0e26f296b6 100644 --- a/packages/contracts/generated/api/console/features/zod.gen.ts +++ b/packages/contracts/generated/api/console/features/zod.gen.ts @@ -93,3 +93,8 @@ export const zFeatureModel = z.object({ * Success */ export const zGetFeaturesResponse = zFeatureModel + +/** + * Success + */ +export const zGetFeaturesVectorSpaceResponse = zLimitationModel diff --git a/web/__tests__/billing/billing-integration.test.tsx b/web/__tests__/billing/billing-integration.test.tsx index 3113e36751..f1d96ad2c5 100644 --- a/web/__tests__/billing/billing-integration.test.tsx +++ b/web/__tests__/billing/billing-integration.test.tsx @@ -53,6 +53,9 @@ vi.mock('@/service/use-billing', () => ({ refetch: mockRefetch, }), useBindPartnerStackInfo: () => ({ mutateAsync: vi.fn() }), + useCurrentPlanVectorSpace: () => ({ + data: undefined, + }), })) vi.mock('@/service/use-education', () => ({ diff --git a/web/__tests__/billing/education-verification-flow.test.tsx b/web/__tests__/billing/education-verification-flow.test.tsx index 707f1d690a..58b531661e 100644 --- a/web/__tests__/billing/education-verification-flow.test.tsx +++ b/web/__tests__/billing/education-verification-flow.test.tsx @@ -60,6 +60,9 @@ vi.mock('@/service/use-billing', () => ({ isFetching: false, refetch: vi.fn(), }), + useCurrentPlanVectorSpace: () => ({ + data: undefined, + }), })) // ─── Navigation mocks ─────────────────────────────────────────────────────── diff --git a/web/app/components/billing/plan/__tests__/index.spec.tsx b/web/app/components/billing/plan/__tests__/index.spec.tsx index e9e0fd7012..18c370e833 100644 --- a/web/app/components/billing/plan/__tests__/index.spec.tsx +++ b/web/app/components/billing/plan/__tests__/index.spec.tsx @@ -39,6 +39,12 @@ vi.mock('@/service/billing', () => ({ fetchSubscriptionUrls: vi.fn(), })) +vi.mock('@/service/use-billing', () => ({ + useCurrentPlanVectorSpace: () => ({ + data: undefined, + }), +})) + const fetchSubscriptionUrlsMock = vi.mocked(fetchSubscriptionUrls) const mutateAsyncMock = vi.fn() diff --git a/web/app/components/billing/type.ts b/web/app/components/billing/type.ts index 15eda0bbf6..e40c89f1a7 100644 --- a/web/app/components/billing/type.ts +++ b/web/app/components/billing/type.ts @@ -71,10 +71,6 @@ export type CurrentPlanInfoBackend = { size: number limit: number // total. 0 means unlimited } - vector_space: { - size: number - limit: number // total. 0 means unlimited - } annotation_quota_limit: { size: number limit: number // total. 0 means unlimited diff --git a/web/app/components/billing/usage-info/__tests__/vector-space-info.spec.tsx b/web/app/components/billing/usage-info/__tests__/vector-space-info.spec.tsx index 3422d09c2f..e379ef4a51 100644 --- a/web/app/components/billing/usage-info/__tests__/vector-space-info.spec.tsx +++ b/web/app/components/billing/usage-info/__tests__/vector-space-info.spec.tsx @@ -10,6 +10,7 @@ const queryPlaceholder = () => let mockPlanType = Plan.sandbox let mockVectorSpaceUsage = 30 let mockVectorSpaceTotal = 5120 +let mockVectorSpaceApiData: { size: number, limit: number } | undefined vi.mock('@/context/provider-context', () => ({ useProviderContext: () => ({ @@ -28,6 +29,12 @@ vi.mock('@/context/provider-context', () => ({ }), })) +vi.mock('@/service/use-billing', () => ({ + useCurrentPlanVectorSpace: () => ({ + data: mockVectorSpaceApiData, + }), +})) + describe('VectorSpaceInfo', () => { beforeEach(() => { vi.clearAllMocks() @@ -35,6 +42,7 @@ describe('VectorSpaceInfo', () => { mockPlanType = Plan.sandbox mockVectorSpaceUsage = 30 mockVectorSpaceTotal = 5120 + mockVectorSpaceApiData = undefined }) describe('Rendering', () => { @@ -252,5 +260,18 @@ describe('VectorSpaceInfo', () => { expect(screen.getByText('100')).toBeInTheDocument() expect(screen.getByText('102400MB')).toBeInTheDocument() }) + + it('should use vector space API limit directly', () => { + mockVectorSpaceApiData = { + size: 100, + limit: 0, + } + + render() + + expect(screen.getByText('100')).toBeInTheDocument() + expect(screen.getByText('0MB')).toBeInTheDocument() + expect(screen.queryByText('billing.plansCommon.unlimited')).not.toBeInTheDocument() + }) }) }) diff --git a/web/app/components/billing/usage-info/vector-space-info.tsx b/web/app/components/billing/usage-info/vector-space-info.tsx index e384ef4d9a..c5f23bf422 100644 --- a/web/app/components/billing/usage-info/vector-space-info.tsx +++ b/web/app/components/billing/usage-info/vector-space-info.tsx @@ -7,6 +7,7 @@ import { import * as React from 'react' import { useTranslation } from 'react-i18next' import { useProviderContext } from '@/context/provider-context' +import { useCurrentPlanVectorSpace } from '@/service/use-billing' import { Plan } from '../type' import UsageInfo from '../usage-info' import { getPlanVectorSpaceLimitMB } from '../utils' @@ -23,11 +24,25 @@ const VectorSpaceInfo: FC = ({ }) => { const { t } = useTranslation() const { plan } = useProviderContext() + const { data: vectorSpace } = useCurrentPlanVectorSpace() + const displayPlan = vectorSpace + ? { + ...plan, + usage: { + ...plan.usage, + vectorSpace: vectorSpace.size, + }, + total: { + ...plan.total, + vectorSpace: vectorSpace.limit, + }, + } + : plan const { type, usage, total, - } = plan + } = displayPlan // Determine total based on plan type (in MB), derived from ALL_PLANS config const getTotalInMB = () => { diff --git a/web/app/components/billing/utils/__tests__/index.spec.ts b/web/app/components/billing/utils/__tests__/index.spec.ts index 115da91db7..84818d3175 100644 --- a/web/app/components/billing/utils/__tests__/index.spec.ts +++ b/web/app/components/billing/utils/__tests__/index.spec.ts @@ -65,10 +65,6 @@ describe('billing utils', () => { size: 2, limit: 5, }, - vector_space: { - size: 10, - limit: 50, - }, annotation_quota_limit: { size: 5, limit: 10, @@ -108,7 +104,7 @@ describe('billing utils', () => { const data = createMockPlanData() const result = parseCurrentPlan(data) - expect(result.usage.vectorSpace).toBe(10) + expect(result.usage.vectorSpace).toBe(0) expect(result.usage.buildApps).toBe(2) expect(result.usage.teamMembers).toBe(1) expect(result.usage.annotatedResponse).toBe(5) @@ -125,6 +121,29 @@ describe('billing utils', () => { expect(result.total.annotatedResponse).toBe(10) }) + it('should not read vector space usage from current plan info', () => { + const data = createMockPlanData() + const result = parseCurrentPlan(data) + + expect(result.usage.vectorSpace).toBe(0) + expect(result.total.vectorSpace).toBe(50) + }) + + it('should derive vector space total from plan config', () => { + const data = createMockPlanData({ + billing: { + enabled: true, + subscription: { + plan: Plan.professional, + }, + }, + }) + const result = parseCurrentPlan(data) + + expect(result.usage.vectorSpace).toBe(0) + expect(result.total.vectorSpace).toBe(5 * 1024) + }) + it('should convert 0 limits to NUM_INFINITE (-1)', () => { const data = createMockPlanData({ documents_upload_quota: { diff --git a/web/app/components/billing/utils/index.ts b/web/app/components/billing/utils/index.ts index 2d37eecbd5..c83c0a6c52 100644 --- a/web/app/components/billing/utils/index.ts +++ b/web/app/components/billing/utils/index.ts @@ -79,6 +79,7 @@ const getResetInDaysFromDate = (resetDate?: number | null) => { export const parseCurrentPlan = (data: CurrentPlanInfoBackend) => { const planType = data.billing.subscription.plan const planPreset = ALL_PLANS[planType] + const vectorSpaceLimit = getPlanVectorSpaceLimitMB(planType) const resolveRateLimit = (limit?: number, fallback?: number) => { const value = limit ?? fallback ?? 0 return parseRateLimit(value) @@ -93,7 +94,7 @@ export const parseCurrentPlan = (data: CurrentPlanInfoBackend) => { return { type: planType, usage: { - vectorSpace: data.vector_space.size, + vectorSpace: 0, buildApps: data.apps?.size || 0, teamMembers: data.members.size, annotatedResponse: data.annotation_quota_limit.size, @@ -102,7 +103,7 @@ export const parseCurrentPlan = (data: CurrentPlanInfoBackend) => { triggerEvents: getQuotaUsage(data.trigger_event), }, total: { - vectorSpace: parseLimit(data.vector_space.limit), + vectorSpace: vectorSpaceLimit, buildApps: parseLimit(data.apps?.limit) || 0, teamMembers: parseLimit(data.members.limit), annotatedResponse: parseLimit(data.annotation_quota_limit.limit), diff --git a/web/app/components/billing/vector-space-full/__tests__/index.spec.tsx b/web/app/components/billing/vector-space-full/__tests__/index.spec.tsx index b1ef0104a0..42054df649 100644 --- a/web/app/components/billing/vector-space-full/__tests__/index.spec.tsx +++ b/web/app/components/billing/vector-space-full/__tests__/index.spec.tsx @@ -21,6 +21,12 @@ vi.mock('../../upgrade-btn', () => ({ default: () => , })) +vi.mock('@/service/use-billing', () => ({ + useCurrentPlanVectorSpace: () => ({ + data: undefined, + }), +})) + // Mock utils to control threshold and plan limits vi.mock('../../utils', () => ({ getPlanVectorSpaceLimitMB: (planType: string) => { diff --git a/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx b/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx index a6c2078836..f934ef545e 100644 --- a/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx +++ b/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx @@ -148,6 +148,30 @@ describe('DocumentPicker', () => { expect(trigger).not.toHaveFocus() }) + it('should keep focus in the search input when deleting from an empty result', async () => { + const user = userEvent.setup() + mockUseDocumentList.mockImplementation(({ query }) => ({ + data: query.keyword === 'missing' + ? { data: [] } + : mockDocumentListData, + })) + + renderDocumentPicker() + + const trigger = screen.getByRole('combobox', { name: 'Document 1' }) + await user.click(trigger) + + const searchInput = screen.getByPlaceholderText('common.operation.search') + await user.type(searchInput, 'missing') + expect(await screen.findByText('common.noData')).toBeInTheDocument() + + await user.keyboard('{Backspace}{Backspace}{Backspace}{Backspace}{Backspace}{Backspace}{Backspace}') + + expect(trigger).toHaveAttribute('aria-expanded', 'true') + expect(searchInput).toHaveFocus() + expect(trigger).not.toHaveFocus() + }) + it('should keep focus in the search input while typing quickly', async () => { const user = userEvent.setup() renderDocumentPicker() diff --git a/web/app/components/datasets/common/document-picker/index.tsx b/web/app/components/datasets/common/document-picker/index.tsx index 5a15cc4a3b..9f06727ed9 100644 --- a/web/app/components/datasets/common/document-picker/index.tsx +++ b/web/app/components/datasets/common/document-picker/index.tsx @@ -13,7 +13,8 @@ import { ComboboxValue, } from '@langgenius/dify-ui/combobox' import { RiArrowDownSLine } from '@remixicon/react' -import { useDeferredValue, useState } from 'react' +import { useDebounce } from 'ahooks' +import { useState } from 'react' import { useTranslation } from 'react-i18next' import { GeneralChunk, ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge' import Loading from '@/app/components/base/loading' @@ -106,12 +107,12 @@ export function DocumentPicker({ }: Props) { const { t } = useTranslation() const [searchValue, setSearchValue] = useState('') - const deferredSearchValue = useDeferredValue(searchValue) + const debouncedSearchValue = useDebounce(searchValue, { wait: 500 }) const { data } = useDocumentList({ datasetId, query: { - keyword: deferredSearchValue, + keyword: debouncedSearchValue, page: 1, limit: 20, }, @@ -175,19 +176,16 @@ export function DocumentPicker({ className="block h-4.5 grow px-1 py-0 text-[13px] text-text-primary" /> + {data ? ( - documentsList.length > 0 - ? ( - - ) - : ( - - {t('noData', { ns: 'common' })} - - ) + +
+ {t('noData', { ns: 'common' })} +
+
) : ( diff --git a/web/app/components/datasets/create/step-one/__tests__/index.spec.tsx b/web/app/components/datasets/create/step-one/__tests__/index.spec.tsx index 6c6c60d808..bf44c0f37d 100644 --- a/web/app/components/datasets/create/step-one/__tests__/index.spec.tsx +++ b/web/app/components/datasets/create/step-one/__tests__/index.spec.tsx @@ -36,6 +36,16 @@ vi.mock('@/context/provider-context', () => ({ }), })) +vi.mock('@/service/use-billing', () => ({ + useCurrentPlanVectorSpace: () => ({ + data: { + size: mockPlan.usage.vectorSpace, + limit: mockPlan.total.vectorSpace, + }, + isFetching: false, + }), +})) + vi.mock('../../file-uploader', () => ({ default: ({ onPreview, fileList }: { onPreview: (file: File) => void, fileList: FileItem[] }) => (
diff --git a/web/app/components/datasets/create/step-one/index.tsx b/web/app/components/datasets/create/step-one/index.tsx index e94a44ebb2..3b57390cce 100644 --- a/web/app/components/datasets/create/step-one/index.tsx +++ b/web/app/components/datasets/create/step-one/index.tsx @@ -15,6 +15,7 @@ import VectorSpaceFull from '@/app/components/billing/vector-space-full' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useProviderContext } from '@/context/provider-context' import { DataSourceType } from '@/models/datasets' +import { useCurrentPlanVectorSpace } from '@/service/use-billing' import EmptyDatasetCreationModal from '../empty-dataset-creation-modal' import FileUploader from '../file-uploader' import Website from '../website' @@ -119,7 +120,15 @@ const StepOne = ({ const allFileLoaded = files.length > 0 && files.every(file => file.file.id) const hasNotion = notionPages.length > 0 - const isVectorSpaceFull = plan.usage.vectorSpace >= plan.total.vectorSpace + const shouldCheckVectorSpace = enableBilling && (allFileLoaded || hasNotion) + const { + data: vectorSpace, + isFetching: isFetchingVectorSpacePlan, + } = useCurrentPlanVectorSpace(shouldCheckVectorSpace) + const isCheckingVectorSpace = shouldCheckVectorSpace && !vectorSpace && isFetchingVectorSpacePlan + const isVectorSpaceFull = !!vectorSpace + && vectorSpace.limit > 0 + && vectorSpace.size >= vectorSpace.limit const isShowVectorSpaceFull = (allFileLoaded || hasNotion) && isVectorSpaceFull && enableBilling const supportBatchUpload = !enableBilling || plan.type !== Plan.sandbox @@ -131,8 +140,10 @@ const StepOne = ({ return true if (files.some(file => !file.file.id)) return true + if (isCheckingVectorSpace) + return true return isShowVectorSpaceFull - }, [files, isShowVectorSpaceFull]) + }, [files, isCheckingVectorSpace, isShowVectorSpaceFull]) // Clear previews when switching data source type const handleClearPreviews = useCallback((newType: DataSourceType) => { diff --git a/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx index 1029714661..ae1eb817b4 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx @@ -42,6 +42,16 @@ vi.mock('@/context/provider-context', () => ({ selector({ plan: mockPlan, enableBilling: true }), })) +vi.mock('@/service/use-billing', () => ({ + useCurrentPlanVectorSpace: () => ({ + data: { + size: mockPlan.usage.vectorSpace, + limit: mockPlan.total.vectorSpace, + }, + isFetching: false, + }), +})) + vi.mock('@/context/dataset-detail', () => ({ useDatasetDetailContextWithSelector: (selector: (state: { dataset: { pipeline_id: string } }) => unknown) => selector({ dataset: { pipeline_id: 'test-pipeline-id' } }), diff --git a/web/app/components/datasets/documents/create-from-pipeline/hooks/use-datasource-ui-state.ts b/web/app/components/datasets/documents/create-from-pipeline/hooks/use-datasource-ui-state.ts index f4c222f652..5cfd557b3d 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/hooks/use-datasource-ui-state.ts +++ b/web/app/components/datasets/documents/create-from-pipeline/hooks/use-datasource-ui-state.ts @@ -13,6 +13,7 @@ type DatasourceUIStateParams = { selectedFileIdsLength: number onlineDriveFileList: OnlineDriveFile[] isVectorSpaceFull: boolean + isCheckingVectorSpace?: boolean enableBilling: boolean currentWorkspacePagesLength: number fileUploadConfig: { file_size_limit: number, batch_count_limit: number } @@ -30,6 +31,7 @@ export const useDatasourceUIState = ({ selectedFileIdsLength, onlineDriveFileList, isVectorSpaceFull, + isCheckingVectorSpace = false, enableBilling, currentWorkspacePagesLength, fileUploadConfig, @@ -59,14 +61,14 @@ export const useDatasourceUIState = ({ return true const disabledConditions: Record = { - [DatasourceType.localFile]: isShowVectorSpaceFull || localFileListLength === 0 || !allFileLoaded, - [DatasourceType.onlineDocument]: isShowVectorSpaceFull || onlineDocumentsLength === 0, - [DatasourceType.websiteCrawl]: isShowVectorSpaceFull || websitePagesLength === 0, - [DatasourceType.onlineDrive]: isShowVectorSpaceFull || selectedFileIdsLength === 0, + [DatasourceType.localFile]: isCheckingVectorSpace || isShowVectorSpaceFull || localFileListLength === 0 || !allFileLoaded, + [DatasourceType.onlineDocument]: isCheckingVectorSpace || isShowVectorSpaceFull || onlineDocumentsLength === 0, + [DatasourceType.websiteCrawl]: isCheckingVectorSpace || isShowVectorSpaceFull || websitePagesLength === 0, + [DatasourceType.onlineDrive]: isCheckingVectorSpace || isShowVectorSpaceFull || selectedFileIdsLength === 0, } return disabledConditions[datasourceType] ?? true - }, [datasource, datasourceType, isShowVectorSpaceFull, localFileListLength, allFileLoaded, onlineDocumentsLength, websitePagesLength, selectedFileIdsLength]) + }, [datasource, datasourceType, isCheckingVectorSpace, isShowVectorSpaceFull, localFileListLength, allFileLoaded, onlineDocumentsLength, websitePagesLength, selectedFileIdsLength]) // Check if select all should be shown const showSelect = useMemo(() => { diff --git a/web/app/components/datasets/documents/create-from-pipeline/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/index.tsx index 799f24fa2a..07843217ef 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/index.tsx @@ -12,6 +12,7 @@ import { PlanUpgradeModal } from '@/app/components/billing/plan-upgrade-modal' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useProviderContextSelector } from '@/context/provider-context' import { DatasourceType } from '@/models/pipeline' +import { useCurrentPlanVectorSpace } from '@/service/use-billing' import { useFileUploadConfig } from '@/service/use-common' import { usePublishedPipelineInfo } from '@/service/use-pipeline' import { useDataSourceStore } from './data-source/store' @@ -91,7 +92,20 @@ const CreateFormPipeline = () => { } = useOnlineDrive() // Computed values - const isVectorSpaceFull = plan.usage.vectorSpace >= plan.total.vectorSpace + const shouldCheckVectorSpace = enableBilling && ( + allFileLoaded + || onlineDocuments.length > 0 + || websitePages.length > 0 + || selectedFileIds.length > 0 + ) + const { + data: vectorSpace, + isFetching: isFetchingVectorSpacePlan, + } = useCurrentPlanVectorSpace(shouldCheckVectorSpace) + const isCheckingVectorSpace = shouldCheckVectorSpace && !vectorSpace && isFetchingVectorSpacePlan + const isVectorSpaceFull = !!vectorSpace + && vectorSpace.limit > 0 + && vectorSpace.size >= vectorSpace.limit const supportBatchUpload = !enableBilling || plan.type !== 'sandbox' // UI state @@ -112,6 +126,7 @@ const CreateFormPipeline = () => { selectedFileIdsLength: selectedFileIds.length, onlineDriveFileList, isVectorSpaceFull, + isCheckingVectorSpace, enableBilling, currentWorkspacePagesLength: currentWorkspace?.pages.length ?? 0, fileUploadConfig, diff --git a/web/app/components/workflow/block-selector/__tests__/main.spec.tsx b/web/app/components/workflow/block-selector/__tests__/main.spec.tsx index dedf64c56e..7c7f131f70 100644 --- a/web/app/components/workflow/block-selector/__tests__/main.spec.tsx +++ b/web/app/components/workflow/block-selector/__tests__/main.spec.tsx @@ -1,5 +1,6 @@ import type { ButtonHTMLAttributes } from 'react' import type { NodeDefault } from '../../types' +import { Button } from '@langgenius/dify-ui/button' import { screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { renderWorkflowComponent } from '../../__tests__/workflow-test-env' @@ -191,4 +192,28 @@ describe('NodeSelector', () => { expect(trigger.closest('[aria-haspopup="dialog"]')).toBe(trigger) expect(screen.getByPlaceholderText('workflow.tabs.searchBlock')).toBeInTheDocument() }) + + it('can render the shared Button trigger as the popover root', async () => { + const user = userEvent.setup() + + renderWorkflowComponent( + ( + + )} + />, + ) + + const trigger = screen.getByRole('button', { name: 'open-shared-button-trigger' }) + await user.click(trigger) + + expect(trigger.closest('[aria-haspopup="dialog"]')).toBe(trigger) + expect(screen.getByPlaceholderText('workflow.tabs.searchBlock')).toBeInTheDocument() + }) }) diff --git a/web/app/components/workflow/nodes/data-source-empty/index.tsx b/web/app/components/workflow/nodes/data-source-empty/index.tsx index bd03621b87..dfb3a55788 100644 --- a/web/app/components/workflow/nodes/data-source-empty/index.tsx +++ b/web/app/components/workflow/nodes/data-source-empty/index.tsx @@ -56,6 +56,7 @@ const DataSourceEmptyNode = ({ id, data }: NodeProps) => { vi.fn()) const mockUseProviderContext = vi.hoisted(() => vi.fn()) @@ -68,29 +71,62 @@ const createProps = () => ({ description: 'Semantic description', effectColor: 'purple', }, - hybridSearchModeOptions, searchMethod: RetrievalSearchMethodEnum.semantic, onRetrievalSearchMethodChange: vi.fn(), - hybridSearchMode: HybridSearchModeEnum.WeightedScore, - onHybridSearchModeChange: vi.fn(), - weightedScore, - onWeightedScoreChange: vi.fn(), - rerankingModelEnabled: false, - onRerankingModelEnabledChange: vi.fn(), - rerankingModel: { - reranking_provider_name: '', - reranking_model_name: '', + hybridSearch: { + mode: HybridSearchModeEnum.WeightedScore, + options: hybridSearchModeOptions, + onModeChange: vi.fn(), + weightedScore, + onWeightedScoreChange: vi.fn(), + }, + reranking: { + enabled: false, + onEnabledChange: vi.fn(), + rerankingModel: { + reranking_provider_name: '', + reranking_model_name: '', + }, + onRerankingModelChange: vi.fn(), + showMultiModalTip: false, + }, + retrievalParameters: { + topK: { + value: 3, + onChange: vi.fn(), + }, + scoreThreshold: { + value: 0.5, + onChange: vi.fn(), + enabled: true, + onEnabledChange: vi.fn(), + }, }, - onRerankingModelChange: vi.fn(), - topK: 3, - onTopKChange: vi.fn(), - scoreThreshold: 0.5, - onScoreThresholdChange: vi.fn(), - isScoreThresholdEnabled: true, - onScoreThresholdEnabledChange: vi.fn(), - showMultiModalTip: false, }) +function renderSearchMethodOption(props: ReturnType) { + const { + onRetrievalSearchMethodChange, + ...optionProps + } = props + + render( + + onRetrievalSearchMethodChange(value)} + /> + )} + > + Retrieval search method + + + , + ) +} + describe('SearchMethodOption', () => { beforeEach(() => { vi.clearAllMocks() @@ -116,37 +152,32 @@ describe('SearchMethodOption', () => { it('should render semantic search controls and notify retrieval and reranking changes', () => { const props = createProps() - render() + renderSearchMethodOption(props) expect(screen.getByText('Semantic title'))!.toBeInTheDocument() expect(screen.getByText('common.modelProvider.rerankModel.key'))!.toBeInTheDocument() expect(screen.getByText('plugin.detailPanel.configureModel'))!.toBeInTheDocument() expect(screen.getAllByRole('switch')).toHaveLength(2) - fireEvent.click(screen.getByText('Semantic title')) fireEvent.click(screen.getAllByRole('switch')[0]!) - expect(props.onRetrievalSearchMethodChange).toHaveBeenCalledWith(RetrievalSearchMethodEnum.semantic) - expect(props.onRerankingModelEnabledChange).toHaveBeenCalledWith(true) + expect(props.reranking.onEnabledChange).toHaveBeenCalledWith(true) }) - it('should render the reranking switch for full-text search as well', () => { + it('should notify retrieval changes when an inactive option is selected', () => { const props = createProps() + const fullTextProps = { + ...props, + option: { + ...props.option, + id: RetrievalSearchMethodEnum.fullText, + title: 'Full-text title', + }, + } - render( - , - ) + renderSearchMethodOption(fullTextProps) expect(screen.getByText('Full-text title'))!.toBeInTheDocument() - expect(screen.getByText('common.modelProvider.rerankModel.key'))!.toBeInTheDocument() fireEvent.click(screen.getByText('Full-text title')) @@ -155,20 +186,25 @@ describe('SearchMethodOption', () => { it('should render hybrid weighted-score controls without reranking model selector', () => { const props = createProps() + const hybridProps = { + ...props, + option: { + ...props.option, + id: RetrievalSearchMethodEnum.hybrid, + title: 'Hybrid title', + }, + searchMethod: RetrievalSearchMethodEnum.hybrid, + hybridSearch: { + ...props.hybridSearch, + mode: HybridSearchModeEnum.WeightedScore, + }, + reranking: { + ...props.reranking, + showMultiModalTip: true, + }, + } - render( - , - ) + renderSearchMethodOption(hybridProps) expect(screen.getByText('Weighted mode'))!.toBeInTheDocument() expect(screen.getByText('Rerank mode'))!.toBeInTheDocument() @@ -179,25 +215,30 @@ describe('SearchMethodOption', () => { fireEvent.click(screen.getByText('Rerank mode')) - expect(props.onHybridSearchModeChange).toHaveBeenCalledWith(HybridSearchModeEnum.RerankingModel) + expect(props.hybridSearch.onModeChange).toHaveBeenCalledWith(HybridSearchModeEnum.RerankingModel) }) it('should render the hybrid reranking selector when reranking mode is selected', () => { const props = createProps() + const hybridProps = { + ...props, + option: { + ...props.option, + id: RetrievalSearchMethodEnum.hybrid, + title: 'Hybrid title', + }, + searchMethod: RetrievalSearchMethodEnum.hybrid, + hybridSearch: { + ...props.hybridSearch, + mode: HybridSearchModeEnum.RerankingModel, + }, + reranking: { + ...props.reranking, + showMultiModalTip: true, + }, + } - render( - , - ) + renderSearchMethodOption(hybridProps) expect(screen.getByText('plugin.detailPanel.configureModel'))!.toBeInTheDocument() expect(screen.queryByText('common.modelProvider.rerankModel.key')).not.toBeInTheDocument() @@ -207,23 +248,22 @@ describe('SearchMethodOption', () => { it('should hide the score-threshold control for keyword search', () => { const props = createProps() + const keywordProps = { + ...props, + option: { + ...props.option, + id: RetrievalSearchMethodEnum.keywordSearch, + title: 'Keyword title', + }, + searchMethod: RetrievalSearchMethodEnum.keywordSearch, + } - render( - , - ) + renderSearchMethodOption(keywordProps) fireEvent.change(screen.getByRole('textbox'), { target: { value: '9' } }) expect(screen.getAllByRole('textbox')).toHaveLength(1) expect(screen.queryAllByRole('switch')).toHaveLength(0) - expect(props.onTopKChange).toHaveBeenCalledWith(9) + expect(props.retrievalParameters.topK.onChange).toHaveBeenCalledWith(9) }) }) diff --git a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/__tests__/top-k-and-score-threshold.spec.tsx b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/__tests__/top-k-and-score-threshold.spec.tsx index 47755ec101..840b704577 100644 --- a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/__tests__/top-k-and-score-threshold.spec.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/__tests__/top-k-and-score-threshold.spec.tsx @@ -1,40 +1,44 @@ import { fireEvent, render, screen } from '@testing-library/react' -import TopKAndScoreThreshold from '../top-k-and-score-threshold' +import { TopKAndScoreThreshold } from '../top-k-and-score-threshold' describe('TopKAndScoreThreshold', () => { + const topKLabel = /datasetConfig\.top_k/ + const scoreThresholdLabel = /datasetConfig\.score_threshold/ const defaultProps = { - topK: 3, - onTopKChange: vi.fn(), - scoreThreshold: 0.4, - onScoreThresholdChange: vi.fn(), - isScoreThresholdEnabled: true, - onScoreThresholdEnabledChange: vi.fn(), + topK: { + value: 3, + onChange: vi.fn(), + }, + scoreThreshold: { + value: 0.4, + onChange: vi.fn(), + enabled: true, + onEnabledChange: vi.fn(), + }, } beforeEach(() => { vi.clearAllMocks() }) - it('should round top-k input values before notifying the parent', () => { + it('should notify top-k input values without additional rounding', () => { render() - const [topKInput] = screen.getAllByRole('textbox') - fireEvent.change(topKInput!, { target: { value: '3.7' } }) + fireEvent.change(screen.getByRole('textbox', { name: topKLabel }), { target: { value: '3.7' } }) - expect(defaultProps.onTopKChange).toHaveBeenCalledWith(4) + expect(defaultProps.topK.onChange).toHaveBeenCalledWith(3.7) }) - it('should round score-threshold input values to two decimals', () => { + it('should notify score-threshold input values without additional rounding', () => { render() - const [, scoreThresholdInput] = screen.getAllByRole('textbox') - fireEvent.change(scoreThresholdInput!, { target: { value: '0.456' } }) + fireEvent.change(screen.getByRole('textbox', { name: scoreThresholdLabel }), { target: { value: '0.456' } }) - expect(defaultProps.onScoreThresholdChange).toHaveBeenCalledWith(0.46) + expect(defaultProps.scoreThreshold.onChange).toHaveBeenCalledWith(0.456) }) it('should hide the score-threshold column when requested', () => { - render() + render() expect(screen.getAllByRole('textbox')).toHaveLength(1) expect(screen.queryByRole('switch')).not.toBeInTheDocument() @@ -44,15 +48,18 @@ describe('TopKAndScoreThreshold', () => { render( , ) const [topKInput, scoreThresholdInput] = screen.getAllByRole('textbox') fireEvent.change(topKInput!, { target: { value: '' } }) - expect(defaultProps.onTopKChange).toHaveBeenCalledWith(0) + expect(defaultProps.topK.onChange).toHaveBeenCalledWith(0) expect(scoreThresholdInput)!.toHaveValue('') }) @@ -60,10 +67,13 @@ describe('TopKAndScoreThreshold', () => { render( , ) - expect(screen.getByRole('switch'))!.toHaveAttribute('aria-checked', 'false') + expect(screen.getByRole('switch', { name: scoreThresholdLabel }))!.toHaveAttribute('aria-checked', 'false') }) }) diff --git a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/index.tsx b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/index.tsx index e316f941ea..91a24c4e8b 100644 --- a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/index.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/index.tsx @@ -5,7 +5,13 @@ import type { WeightedScore, } from '../../types' import type { RerankingModelSelectorProps } from './reranking-model-selector' -import type { TopKAndScoreThresholdProps } from './top-k-and-score-threshold' +import type { + TopKFieldProps, + VisibleScoreThresholdFieldProps, +} from './top-k-and-score-threshold' +import { FieldRoot } from '@langgenius/dify-ui/field' +import { FieldsetLegend, FieldsetRoot } from '@langgenius/dify-ui/fieldset' +import { RadioGroup } from '@langgenius/dify-ui/radio-group' import { memo, } from 'react' @@ -13,7 +19,7 @@ import { useTranslation } from 'react-i18next' import { Field } from '@/app/components/workflow/nodes/_base/components/layout' import { useDocLink } from '@/context/i18n' import { useRetrievalSetting } from './hooks' -import SearchMethodOption from './search-method-option' +import { SearchMethodOption } from './search-method-option' type RetrievalSettingProps = { indexMethod?: IndexMethodEnum @@ -23,11 +29,18 @@ type RetrievalSettingProps = { hybridSearchMode?: HybridSearchModeEnum onHybridSearchModeChange: (value: HybridSearchModeEnum) => void rerankingModelEnabled?: boolean - onRerankingModelEnabledChange?: (value: boolean) => void + onRerankingModelEnabledChange: (value: boolean) => void weightedScore?: WeightedScore onWeightedScoreChange: (value: { value: number[] }) => void showMultiModalTip?: boolean -} & RerankingModelSelectorProps & TopKAndScoreThresholdProps +} & RerankingModelSelectorProps & { + topK: TopKFieldProps['value'] + onTopKChange: TopKFieldProps['onChange'] + scoreThreshold: VisibleScoreThresholdFieldProps['value'] + onScoreThresholdChange: VisibleScoreThresholdFieldProps['onChange'] + isScoreThresholdEnabled?: VisibleScoreThresholdFieldProps['enabled'] + onScoreThresholdEnabledChange: VisibleScoreThresholdFieldProps['onEnabledChange'] +} const RetrievalSetting = ({ indexMethod, @@ -70,35 +83,56 @@ const RetrievalSetting = ({ ), }} > -
- { - options.map(option => ( + + + value={searchMethod} + onValueChange={value => onRetrievalSearchMethodChange(value)} + disabled={readonly} + className="flex-col items-stretch gap-1" + /> + )} + > + + {t('form.retrievalSetting.title', { ns: 'datasetSettings' })} + + {options.map(option => ( - )) - } -
+ ))} + + ) } diff --git a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/search-method-option.tsx b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/search-method-option.tsx index be31b49f95..7f80b26eae 100644 --- a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/search-method-option.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/search-method-option.tsx @@ -1,221 +1,358 @@ +import type { ReactNode } from 'react' import type { WeightedScore, } from '../../types' import type { RerankingModelSelectorProps } from './reranking-model-selector' -import type { TopKAndScoreThresholdProps } from './top-k-and-score-threshold' +import type { + TopKFieldProps, + VisibleScoreThresholdFieldProps, +} from './top-k-and-score-threshold' import type { HybridSearchModeOption, Option, } from './type' import { cn } from '@langgenius/dify-ui/cn' +import { FieldItem, FieldLabel, FieldRoot } from '@langgenius/dify-ui/field' +import { FieldsetLegend, FieldsetRoot } from '@langgenius/dify-ui/fieldset' +import { RadioControl, RadioRoot } from '@langgenius/dify-ui/radio' +import { RadioGroup } from '@langgenius/dify-ui/radio-group' import { Switch } from '@langgenius/dify-ui/switch' -import { - memo, - useCallback, - useMemo, -} from 'react' import { useTranslation } from 'react-i18next' import WeightedScoreComponent from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' -import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' +import Badge from '@/app/components/base/badge' +import { + OptionCardEffectBlue, + OptionCardEffectBlueLight, + OptionCardEffectOrange, + OptionCardEffectPurple, + OptionCardEffectTeal, +} from '@/app/components/base/icons/src/public/knowledge' import { Infotip } from '@/app/components/base/infotip' import { DEFAULT_WEIGHTED_SCORE } from '@/models/datasets' import { HybridSearchModeEnum, RetrievalSearchMethodEnum, } from '../../types' -import OptionCard from '../option-card' import RerankingModelSelector from './reranking-model-selector' -import TopKAndScoreThreshold from './top-k-and-score-threshold' +import { TopKAndScoreThreshold } from './top-k-and-score-threshold' -type SearchMethodOptionProps = { - readonly?: boolean - option: Option - hybridSearchModeOptions: HybridSearchModeOption[] - searchMethod?: RetrievalSearchMethodEnum - onRetrievalSearchMethodChange: (value: RetrievalSearchMethodEnum) => void - hybridSearchMode?: HybridSearchModeEnum - onHybridSearchModeChange: (value: HybridSearchModeEnum) => void +type HybridSearchConfig = { + mode?: HybridSearchModeEnum + options: HybridSearchModeOption[] + onModeChange: (value: HybridSearchModeEnum) => void weightedScore?: WeightedScore onWeightedScoreChange: (value: { value: number[] }) => void - rerankingModelEnabled?: boolean - onRerankingModelEnabledChange?: (value: boolean) => void +} + +type RerankingConfig = RerankingModelSelectorProps & { + enabled?: boolean + onEnabledChange: (value: boolean) => void showMultiModalTip?: boolean -} & RerankingModelSelectorProps & TopKAndScoreThresholdProps -const SearchMethodOption = ({ - readonly, - option, - hybridSearchModeOptions, - searchMethod, - onRetrievalSearchMethodChange, - hybridSearchMode, - onHybridSearchModeChange, - weightedScore, - onWeightedScoreChange, - rerankingModelEnabled, - onRerankingModelEnabledChange, - rerankingModel, - onRerankingModelChange, - topK, - onTopKChange, - scoreThreshold, - onScoreThresholdChange, - isScoreThresholdEnabled, - onScoreThresholdEnabledChange, - showMultiModalTip = false, -}: SearchMethodOptionProps) => { - const { t } = useTranslation() - const Icon = option.icon - const isHybridSearch = option.id === RetrievalSearchMethodEnum.hybrid - const isHybridSearchWeightedScoreMode = hybridSearchMode === HybridSearchModeEnum.WeightedScore +} - const weightedScoreValue = useMemo(() => { - const sematicWeightedScore = weightedScore?.vector_setting.vector_weight ?? DEFAULT_WEIGHTED_SCORE.other.semantic - const keywordWeightedScore = weightedScore?.keyword_setting.keyword_weight ?? DEFAULT_WEIGHTED_SCORE.other.keyword - const mergedValue = [sematicWeightedScore, keywordWeightedScore] +type RetrievalParametersConfig = { + topK: TopKFieldProps + scoreThreshold: VisibleScoreThresholdFieldProps +} - return { - value: mergedValue, - } - }, [weightedScore]) +type SearchMethodRadioCardProps = { + option: Option + searchMethod?: RetrievalSearchMethodEnum + readonly?: boolean + isRecommended?: boolean + children?: ReactNode +} - const icon = useCallback((isActive: boolean) => { - return ( - - ) - }, [Icon]) +export type SearchMethodOptionProps = { + readonly?: boolean + option: Option + searchMethod?: RetrievalSearchMethodEnum + hybridSearch: HybridSearchConfig + reranking: RerankingConfig + retrievalParameters: RetrievalParametersConfig +} - const hybridSearchModeWrapperClassName = useCallback((isActive: boolean) => { - return isActive ? 'border-[1.5px] bg-components-option-card-option-selected-bg' : '' - }, []) +const HEADER_EFFECT_MAP: Record = { + 'blue': , + 'blue-light': , + 'orange': , + 'purple': , + 'teal': , +} - const showRerankModelSelectorSwitch = useMemo(() => { - if (searchMethod === RetrievalSearchMethodEnum.semantic) - return true +function getWeightedScoreValue(weightedScore?: WeightedScore) { + const semanticWeightedScore = weightedScore?.vector_setting.vector_weight ?? DEFAULT_WEIGHTED_SCORE.other.semantic + const keywordWeightedScore = weightedScore?.keyword_setting.keyword_weight ?? DEFAULT_WEIGHTED_SCORE.other.keyword - if (searchMethod === RetrievalSearchMethodEnum.fullText) - return true + return { + value: [semanticWeightedScore, keywordWeightedScore], + } +} - return false - }, [searchMethod]) - const showRerankModelSelector = useMemo(() => { - if (searchMethod === RetrievalSearchMethodEnum.semantic) - return true +function shouldShowRerankModelSelectorSwitch(searchMethod?: RetrievalSearchMethodEnum) { + return searchMethod === RetrievalSearchMethodEnum.semantic || searchMethod === RetrievalSearchMethodEnum.fullText +} - if (searchMethod === RetrievalSearchMethodEnum.fullText) - return true +function shouldShowRerankModelSelector(searchMethod: RetrievalSearchMethodEnum | undefined, hybridSearchMode: HybridSearchModeEnum | undefined) { + if (shouldShowRerankModelSelectorSwitch(searchMethod)) + return true - if (searchMethod === RetrievalSearchMethodEnum.hybrid && hybridSearchMode !== HybridSearchModeEnum.WeightedScore) - return true + return searchMethod === RetrievalSearchMethodEnum.hybrid && hybridSearchMode !== HybridSearchModeEnum.WeightedScore +} - return false - }, [hybridSearchMode, searchMethod]) +function getSearchMethodEffect(effectColor: string | undefined, isActive: boolean) { + const effect = effectColor ? HEADER_EFFECT_MAP[effectColor] : undefined + + if (!effect) + return null return ( -