mirror of
https://github.com/langgenius/dify.git
synced 2026-05-23 10:29:07 +08:00
Merge remote-tracking branch 'upstream/main' into feat/cli
This commit is contained in:
@ -3,12 +3,12 @@ from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
|
||||
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
register_response_schema_models(console_ns, FeatureModel, SystemFeatureModel)
|
||||
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
|
||||
|
||||
|
||||
@console_ns.route("/features")
|
||||
@ -28,7 +28,32 @@ class FeatureApi(Resource):
|
||||
"""Get feature configuration for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||
payload = FeatureService.get_features(
|
||||
current_tenant_id,
|
||||
exclude_vector_space=True,
|
||||
).model_dump()
|
||||
payload.pop("vector_space", None)
|
||||
return payload
|
||||
|
||||
|
||||
@console_ns.route("/features/vector-space")
|
||||
class FeatureVectorSpaceApi(Resource):
|
||||
@console_ns.doc("get_tenant_feature_vector_space")
|
||||
@console_ns.doc(description="Get vector-space usage and limit for current tenant")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.models[LimitationModel.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get vector-space usage and limit for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_vector_space(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/system-features")
|
||||
|
||||
@ -5693,6 +5693,23 @@ Get feature configuration for current tenant
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Success | [FeatureModel](#featuremodel) |
|
||||
|
||||
### /features/vector-space
|
||||
|
||||
#### GET
|
||||
##### Summary
|
||||
|
||||
Get vector-space usage and limit for current tenant
|
||||
|
||||
##### Description
|
||||
|
||||
Get vector-space usage and limit for current tenant
|
||||
|
||||
##### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Success | [LimitationModel](#limitationmodel) |
|
||||
|
||||
### /files/support-type
|
||||
|
||||
#### GET
|
||||
|
||||
@ -116,7 +116,7 @@ class BillingInfo(TypedDict):
|
||||
subscription: _BillingSubscription
|
||||
members: _BillingQuota
|
||||
apps: _BillingQuota
|
||||
vector_space: _VectorSpaceQuota
|
||||
vector_space: NotRequired[_VectorSpaceQuota]
|
||||
knowledge_rate_limit: _KnowledgeRateLimit
|
||||
documents_upload_quota: _BillingQuota
|
||||
annotation_quota_limit: _BillingQuota
|
||||
@ -128,6 +128,7 @@ class BillingInfo(TypedDict):
|
||||
|
||||
|
||||
_billing_info_adapter = TypeAdapter(BillingInfo)
|
||||
_vector_space_quota_adapter = TypeAdapter(_VectorSpaceQuota)
|
||||
|
||||
|
||||
class KnowledgeRateLimitDict(TypedDict):
|
||||
@ -185,12 +186,21 @@ class BillingService:
|
||||
_PLAN_CACHE_TTL = 600
|
||||
|
||||
@classmethod
|
||||
def get_info(cls, tenant_id: str) -> BillingInfo:
|
||||
def get_info(cls, tenant_id: str, exclude_vector_space: bool = False) -> BillingInfo:
|
||||
params = {"tenant_id": tenant_id}
|
||||
if exclude_vector_space:
|
||||
params["exclude_vector_space"] = "true"
|
||||
|
||||
billing_info = cls._send_request("GET", "/subscription/info", params=params)
|
||||
return _billing_info_adapter.validate_python(billing_info)
|
||||
|
||||
@classmethod
|
||||
def get_vector_space(cls, tenant_id: str) -> _VectorSpaceQuota:
|
||||
params = {"tenant_id": tenant_id}
|
||||
return _vector_space_quota_adapter.validate_python(
|
||||
cls._send_request("GET", "/subscription/vector-space", params=params)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
|
||||
"""Deprecated: Use get_quota_info instead."""
|
||||
|
||||
@ -6,7 +6,7 @@ from configs import dify_config
|
||||
from constants.dsl_version import CURRENT_APP_DSL_VERSION
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from enums.hosted_provider import HostedTrialProvider
|
||||
from services.billing_service import BillingService
|
||||
from services.billing_service import BillingInfo, BillingService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
|
||||
@ -186,13 +186,17 @@ class SystemFeatureModel(FeatureResponseModel):
|
||||
|
||||
class FeatureService:
|
||||
@classmethod
|
||||
def get_features(cls, tenant_id: str) -> FeatureModel:
|
||||
def get_features(cls, tenant_id: str, exclude_vector_space: bool = False) -> FeatureModel:
|
||||
features = FeatureModel()
|
||||
|
||||
cls._fulfill_params_from_env(features)
|
||||
|
||||
if dify_config.BILLING_ENABLED and tenant_id:
|
||||
cls._fulfill_params_from_billing_api(features, tenant_id)
|
||||
cls._fulfill_params_from_billing_api(
|
||||
features,
|
||||
tenant_id,
|
||||
exclude_vector_space=exclude_vector_space,
|
||||
)
|
||||
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
features.webapp_copyright_enabled = True
|
||||
@ -206,6 +210,18 @@ class FeatureService:
|
||||
|
||||
return features
|
||||
|
||||
@classmethod
|
||||
def get_vector_space(cls, tenant_id: str) -> LimitationModel:
|
||||
vector_space = LimitationModel(size=0, limit=5)
|
||||
if dify_config.BILLING_ENABLED and tenant_id:
|
||||
billing_vector_space = BillingService.get_vector_space(tenant_id)
|
||||
# NOTE: billing API returns vector_space.size as float (e.g. 0.0),
|
||||
# but feature API keeps LimitationModel.size as int for compatibility.
|
||||
vector_space.size = int(billing_vector_space["size"])
|
||||
vector_space.limit = billing_vector_space["limit"]
|
||||
|
||||
return vector_space
|
||||
|
||||
@classmethod
|
||||
def get_knowledge_rate_limit(cls, tenant_id: str):
|
||||
knowledge_rate_limit = KnowledgeRateLimitModel()
|
||||
@ -289,8 +305,16 @@ class FeatureService:
|
||||
features.workspace_members.enabled = workspace_info["WorkspaceMembers"]["enabled"]
|
||||
|
||||
@classmethod
|
||||
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
|
||||
billing_info = BillingService.get_info(tenant_id)
|
||||
def _fulfill_params_from_billing_api(
|
||||
cls,
|
||||
features: FeatureModel,
|
||||
tenant_id: str,
|
||||
exclude_vector_space: bool = False,
|
||||
):
|
||||
if exclude_vector_space:
|
||||
billing_info = BillingService.get_info(tenant_id, exclude_vector_space=True)
|
||||
else:
|
||||
billing_info = BillingService.get_info(tenant_id)
|
||||
|
||||
features_usage_info = BillingService.get_quota_info(tenant_id)
|
||||
|
||||
@ -322,12 +346,8 @@ class FeatureService:
|
||||
features.apps.size = billing_info["apps"]["size"]
|
||||
features.apps.limit = billing_info["apps"]["limit"]
|
||||
|
||||
if "vector_space" in billing_info:
|
||||
# NOTE (hj24): billing API returns vector_space.size as float (e.g. 0.0)
|
||||
# but LimitationModel.size is int; truncate here for compatibility
|
||||
features.vector_space.size = int(billing_info["vector_space"]["size"])
|
||||
# NOTE END
|
||||
features.vector_space.limit = billing_info["vector_space"]["limit"]
|
||||
if not exclude_vector_space:
|
||||
cls._fulfill_vector_space_from_billing_info(features.vector_space, billing_info)
|
||||
|
||||
if "documents_upload_quota" in billing_info:
|
||||
features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"]
|
||||
@ -359,6 +379,16 @@ class FeatureService:
|
||||
if "next_credit_reset_date" in billing_info:
|
||||
features.next_credit_reset_date = billing_info["next_credit_reset_date"]
|
||||
|
||||
@classmethod
|
||||
def _fulfill_vector_space_from_billing_info(cls, vector_space: LimitationModel, billing_info: BillingInfo):
|
||||
if "vector_space" not in billing_info:
|
||||
return
|
||||
|
||||
# NOTE: billing API returns vector_space.size as float (e.g. 0.0),
|
||||
# but feature API keeps LimitationModel.size as int for compatibility.
|
||||
vector_space.size = int(billing_info["vector_space"]["size"])
|
||||
vector_space.limit = billing_info["vector_space"]["limit"]
|
||||
|
||||
@classmethod
|
||||
def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel, is_authenticated: bool = False):
|
||||
enterprise_info = EnterpriseService.get_info()
|
||||
|
||||
@ -20,8 +20,10 @@ class TestFeatureApi:
|
||||
return_value=("account_id", "tenant_123"),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.feature.FeatureService.get_features").return_value.model_dump.return_value = {
|
||||
"features": {"feature_a": True}
|
||||
get_features = mocker.patch("controllers.console.feature.FeatureService.get_features")
|
||||
get_features.return_value.model_dump.return_value = {
|
||||
"features": {"feature_a": True},
|
||||
"vector_space": {"size": 1, "limit": 2},
|
||||
}
|
||||
|
||||
api = FeatureApi()
|
||||
@ -30,6 +32,28 @@ class TestFeatureApi:
|
||||
result = raw_get(api)
|
||||
|
||||
assert result == {"features": {"feature_a": True}}
|
||||
get_features.assert_called_once_with("tenant_123", exclude_vector_space=True)
|
||||
|
||||
|
||||
class TestFeatureVectorSpaceApi:
|
||||
def test_get_vector_space_success(self, mocker: MockerFixture):
|
||||
from controllers.console.feature import FeatureVectorSpaceApi
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.current_account_with_tenant",
|
||||
return_value=("account_id", "tenant_123"),
|
||||
)
|
||||
|
||||
get_vector_space = mocker.patch("controllers.console.feature.FeatureService.get_vector_space")
|
||||
get_vector_space.return_value.model_dump.return_value = {"size": 5120, "limit": 20480}
|
||||
|
||||
api = FeatureVectorSpaceApi()
|
||||
|
||||
raw_get = unwrap(FeatureVectorSpaceApi.get)
|
||||
result = raw_get(api)
|
||||
|
||||
assert result == {"size": 5120, "limit": 20480}
|
||||
get_vector_space.assert_called_once_with("tenant_123")
|
||||
|
||||
|
||||
class TestSystemFeatureApi:
|
||||
|
||||
@ -313,6 +313,54 @@ class TestBillingServiceSubscriptionInfo:
|
||||
assert result == expected_response
|
||||
mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": tenant_id})
|
||||
|
||||
def test_get_info_exclude_vector_space(self, mock_send_request):
|
||||
"""When requested, get_info asks billing to skip vector_space."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
expected_response = {
|
||||
"enabled": True,
|
||||
"subscription": {"plan": "professional", "interval": "month", "education": False},
|
||||
"members": {"size": 1, "limit": 50},
|
||||
"apps": {"size": 1, "limit": 200},
|
||||
"knowledge_rate_limit": {"limit": 1000},
|
||||
"documents_upload_quota": {"size": 0, "limit": 1000},
|
||||
"annotation_quota_limit": {"size": 0, "limit": 5000},
|
||||
"docs_processing": "top-priority",
|
||||
"can_replace_logo": True,
|
||||
"model_load_balancing_enabled": True,
|
||||
"knowledge_pipeline_publish_enabled": True,
|
||||
}
|
||||
mock_send_request.return_value = expected_response
|
||||
|
||||
# Act
|
||||
result = BillingService.get_info(tenant_id, exclude_vector_space=True)
|
||||
|
||||
# Assert
|
||||
assert "vector_space" not in result
|
||||
mock_send_request.assert_called_once_with(
|
||||
"GET",
|
||||
"/subscription/info",
|
||||
params={"tenant_id": tenant_id, "exclude_vector_space": "true"},
|
||||
)
|
||||
|
||||
def test_get_vector_space_success(self, mock_send_request):
|
||||
"""Test successful retrieval of vector-space usage and limit."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
expected_response = {"size": 5120.75, "limit": 20480}
|
||||
mock_send_request.return_value = expected_response
|
||||
|
||||
# Act
|
||||
result = BillingService.get_vector_space(tenant_id)
|
||||
|
||||
# Assert
|
||||
assert result == expected_response
|
||||
mock_send_request.assert_called_once_with(
|
||||
"GET",
|
||||
"/subscription/vector-space",
|
||||
params={"tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
def test_get_knowledge_rate_limit_with_defaults(self, mock_send_request):
|
||||
"""Test knowledge rate limit retrieval with default values."""
|
||||
# Arrange
|
||||
@ -1744,8 +1792,9 @@ class TestBillingServiceSubscriptionInfoDataType:
|
||||
assert isinstance(result["apps"]["size"], int)
|
||||
assert isinstance(result["apps"]["limit"], int)
|
||||
|
||||
assert isinstance(result["vector_space"]["size"], float)
|
||||
assert isinstance(result["vector_space"]["limit"], int)
|
||||
if "vector_space" in result:
|
||||
assert isinstance(result["vector_space"]["size"], float)
|
||||
assert isinstance(result["vector_space"]["limit"], int)
|
||||
|
||||
assert isinstance(result["knowledge_rate_limit"]["limit"], int)
|
||||
|
||||
@ -1783,11 +1832,13 @@ class TestBillingServiceSubscriptionInfoDataType:
|
||||
def test_get_info_without_optional_fields(self, mock_send_request, string_billing_response):
|
||||
"""NotRequired fields can be absent without raising."""
|
||||
del string_billing_response["next_credit_reset_date"]
|
||||
del string_billing_response["vector_space"]
|
||||
mock_send_request.return_value = string_billing_response
|
||||
|
||||
result = BillingService.get_info("tenant-type-test")
|
||||
|
||||
assert "next_credit_reset_date" not in result
|
||||
assert "vector_space" not in result
|
||||
self._assert_billing_info_types(result)
|
||||
|
||||
def test_get_info_with_extra_fields(self, mock_send_request, string_billing_response):
|
||||
|
||||
@ -102,3 +102,17 @@ def test_resolve_human_input_email_delivery_enabled_matrix(
|
||||
)
|
||||
|
||||
assert result is case.expected
|
||||
|
||||
|
||||
def test_get_vector_space_converts_billing_float_size(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(feature_service_module.dify_config, "BILLING_ENABLED", True)
|
||||
monkeypatch.setattr(
|
||||
feature_service_module.BillingService,
|
||||
"get_vector_space",
|
||||
lambda tenant_id: {"size": 5120.75, "limit": 20480},
|
||||
)
|
||||
|
||||
result = FeatureService.get_vector_space("tenant-1")
|
||||
|
||||
assert result.size == 5120
|
||||
assert result.limit == 20480
|
||||
|
||||
@ -2,14 +2,35 @@
|
||||
|
||||
import { oc } from '@orpc/contract'
|
||||
|
||||
import { zGetFeaturesResponse } from './zod.gen'
|
||||
import { zGetFeaturesResponse, zGetFeaturesVectorSpaceResponse } from './zod.gen'
|
||||
|
||||
/**
|
||||
* Get vector-space usage and limit for current tenant
|
||||
*
|
||||
* Get vector-space usage and limit for current tenant
|
||||
*/
|
||||
export const get = oc
|
||||
.route({
|
||||
description: 'Get vector-space usage and limit for current tenant',
|
||||
inputStructure: 'detailed',
|
||||
method: 'GET',
|
||||
operationId: 'getFeaturesVectorSpace',
|
||||
path: '/features/vector-space',
|
||||
summary: 'Get vector-space usage and limit for current tenant',
|
||||
tags: ['console'],
|
||||
})
|
||||
.output(zGetFeaturesVectorSpaceResponse)
|
||||
|
||||
export const vectorSpace = {
|
||||
get,
|
||||
}
|
||||
|
||||
/**
|
||||
* Get feature configuration for current tenant
|
||||
*
|
||||
* Get feature configuration for current tenant
|
||||
*/
|
||||
export const get = oc
|
||||
export const get2 = oc
|
||||
.route({
|
||||
description: 'Get feature configuration for current tenant',
|
||||
inputStructure: 'detailed',
|
||||
@ -22,7 +43,8 @@ export const get = oc
|
||||
.output(zGetFeaturesResponse)
|
||||
|
||||
export const features = {
|
||||
get,
|
||||
get: get2,
|
||||
vectorSpace,
|
||||
}
|
||||
|
||||
export const contract = {
|
||||
|
||||
@ -75,3 +75,17 @@ export type GetFeaturesResponses = {
|
||||
}
|
||||
|
||||
export type GetFeaturesResponse = GetFeaturesResponses[keyof GetFeaturesResponses]
|
||||
|
||||
export type GetFeaturesVectorSpaceData = {
|
||||
body?: never
|
||||
path?: never
|
||||
query?: never
|
||||
url: '/features/vector-space'
|
||||
}
|
||||
|
||||
export type GetFeaturesVectorSpaceResponses = {
|
||||
200: LimitationModel
|
||||
}
|
||||
|
||||
export type GetFeaturesVectorSpaceResponse
|
||||
= GetFeaturesVectorSpaceResponses[keyof GetFeaturesVectorSpaceResponses]
|
||||
|
||||
@ -93,3 +93,8 @@ export const zFeatureModel = z.object({
|
||||
* Success
|
||||
*/
|
||||
export const zGetFeaturesResponse = zFeatureModel
|
||||
|
||||
/**
|
||||
* Success
|
||||
*/
|
||||
export const zGetFeaturesVectorSpaceResponse = zLimitationModel
|
||||
|
||||
@ -53,6 +53,9 @@ vi.mock('@/service/use-billing', () => ({
|
||||
refetch: mockRefetch,
|
||||
}),
|
||||
useBindPartnerStackInfo: () => ({ mutateAsync: vi.fn() }),
|
||||
useCurrentPlanVectorSpace: () => ({
|
||||
data: undefined,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-education', () => ({
|
||||
|
||||
@ -60,6 +60,9 @@ vi.mock('@/service/use-billing', () => ({
|
||||
isFetching: false,
|
||||
refetch: vi.fn(),
|
||||
}),
|
||||
useCurrentPlanVectorSpace: () => ({
|
||||
data: undefined,
|
||||
}),
|
||||
}))
|
||||
|
||||
// ─── Navigation mocks ───────────────────────────────────────────────────────
|
||||
|
||||
@ -39,6 +39,12 @@ vi.mock('@/service/billing', () => ({
|
||||
fetchSubscriptionUrls: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-billing', () => ({
|
||||
useCurrentPlanVectorSpace: () => ({
|
||||
data: undefined,
|
||||
}),
|
||||
}))
|
||||
|
||||
const fetchSubscriptionUrlsMock = vi.mocked(fetchSubscriptionUrls)
|
||||
|
||||
const mutateAsyncMock = vi.fn()
|
||||
|
||||
@ -71,10 +71,6 @@ export type CurrentPlanInfoBackend = {
|
||||
size: number
|
||||
limit: number // total. 0 means unlimited
|
||||
}
|
||||
vector_space: {
|
||||
size: number
|
||||
limit: number // total. 0 means unlimited
|
||||
}
|
||||
annotation_quota_limit: {
|
||||
size: number
|
||||
limit: number // total. 0 means unlimited
|
||||
|
||||
@ -10,6 +10,7 @@ const queryPlaceholder = () =>
|
||||
let mockPlanType = Plan.sandbox
|
||||
let mockVectorSpaceUsage = 30
|
||||
let mockVectorSpaceTotal = 5120
|
||||
let mockVectorSpaceApiData: { size: number, limit: number } | undefined
|
||||
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: () => ({
|
||||
@ -28,6 +29,12 @@ vi.mock('@/context/provider-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-billing', () => ({
|
||||
useCurrentPlanVectorSpace: () => ({
|
||||
data: mockVectorSpaceApiData,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('VectorSpaceInfo', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@ -35,6 +42,7 @@ describe('VectorSpaceInfo', () => {
|
||||
mockPlanType = Plan.sandbox
|
||||
mockVectorSpaceUsage = 30
|
||||
mockVectorSpaceTotal = 5120
|
||||
mockVectorSpaceApiData = undefined
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
@ -252,5 +260,18 @@ describe('VectorSpaceInfo', () => {
|
||||
expect(screen.getByText('100')).toBeInTheDocument()
|
||||
expect(screen.getByText('102400MB')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use vector space API limit directly', () => {
|
||||
mockVectorSpaceApiData = {
|
||||
size: 100,
|
||||
limit: 0,
|
||||
}
|
||||
|
||||
render(<VectorSpaceInfo />)
|
||||
|
||||
expect(screen.getByText('100')).toBeInTheDocument()
|
||||
expect(screen.getByText('0MB')).toBeInTheDocument()
|
||||
expect(screen.queryByText('billing.plansCommon.unlimited')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -7,6 +7,7 @@ import {
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useCurrentPlanVectorSpace } from '@/service/use-billing'
|
||||
import { Plan } from '../type'
|
||||
import UsageInfo from '../usage-info'
|
||||
import { getPlanVectorSpaceLimitMB } from '../utils'
|
||||
@ -23,11 +24,25 @@ const VectorSpaceInfo: FC<Props> = ({
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { plan } = useProviderContext()
|
||||
const { data: vectorSpace } = useCurrentPlanVectorSpace()
|
||||
const displayPlan = vectorSpace
|
||||
? {
|
||||
...plan,
|
||||
usage: {
|
||||
...plan.usage,
|
||||
vectorSpace: vectorSpace.size,
|
||||
},
|
||||
total: {
|
||||
...plan.total,
|
||||
vectorSpace: vectorSpace.limit,
|
||||
},
|
||||
}
|
||||
: plan
|
||||
const {
|
||||
type,
|
||||
usage,
|
||||
total,
|
||||
} = plan
|
||||
} = displayPlan
|
||||
|
||||
// Determine total based on plan type (in MB), derived from ALL_PLANS config
|
||||
const getTotalInMB = () => {
|
||||
|
||||
@ -65,10 +65,6 @@ describe('billing utils', () => {
|
||||
size: 2,
|
||||
limit: 5,
|
||||
},
|
||||
vector_space: {
|
||||
size: 10,
|
||||
limit: 50,
|
||||
},
|
||||
annotation_quota_limit: {
|
||||
size: 5,
|
||||
limit: 10,
|
||||
@ -108,7 +104,7 @@ describe('billing utils', () => {
|
||||
const data = createMockPlanData()
|
||||
const result = parseCurrentPlan(data)
|
||||
|
||||
expect(result.usage.vectorSpace).toBe(10)
|
||||
expect(result.usage.vectorSpace).toBe(0)
|
||||
expect(result.usage.buildApps).toBe(2)
|
||||
expect(result.usage.teamMembers).toBe(1)
|
||||
expect(result.usage.annotatedResponse).toBe(5)
|
||||
@ -125,6 +121,29 @@ describe('billing utils', () => {
|
||||
expect(result.total.annotatedResponse).toBe(10)
|
||||
})
|
||||
|
||||
it('should not read vector space usage from current plan info', () => {
|
||||
const data = createMockPlanData()
|
||||
const result = parseCurrentPlan(data)
|
||||
|
||||
expect(result.usage.vectorSpace).toBe(0)
|
||||
expect(result.total.vectorSpace).toBe(50)
|
||||
})
|
||||
|
||||
it('should derive vector space total from plan config', () => {
|
||||
const data = createMockPlanData({
|
||||
billing: {
|
||||
enabled: true,
|
||||
subscription: {
|
||||
plan: Plan.professional,
|
||||
},
|
||||
},
|
||||
})
|
||||
const result = parseCurrentPlan(data)
|
||||
|
||||
expect(result.usage.vectorSpace).toBe(0)
|
||||
expect(result.total.vectorSpace).toBe(5 * 1024)
|
||||
})
|
||||
|
||||
it('should convert 0 limits to NUM_INFINITE (-1)', () => {
|
||||
const data = createMockPlanData({
|
||||
documents_upload_quota: {
|
||||
|
||||
@ -79,6 +79,7 @@ const getResetInDaysFromDate = (resetDate?: number | null) => {
|
||||
export const parseCurrentPlan = (data: CurrentPlanInfoBackend) => {
|
||||
const planType = data.billing.subscription.plan
|
||||
const planPreset = ALL_PLANS[planType]
|
||||
const vectorSpaceLimit = getPlanVectorSpaceLimitMB(planType)
|
||||
const resolveRateLimit = (limit?: number, fallback?: number) => {
|
||||
const value = limit ?? fallback ?? 0
|
||||
return parseRateLimit(value)
|
||||
@ -93,7 +94,7 @@ export const parseCurrentPlan = (data: CurrentPlanInfoBackend) => {
|
||||
return {
|
||||
type: planType,
|
||||
usage: {
|
||||
vectorSpace: data.vector_space.size,
|
||||
vectorSpace: 0,
|
||||
buildApps: data.apps?.size || 0,
|
||||
teamMembers: data.members.size,
|
||||
annotatedResponse: data.annotation_quota_limit.size,
|
||||
@ -102,7 +103,7 @@ export const parseCurrentPlan = (data: CurrentPlanInfoBackend) => {
|
||||
triggerEvents: getQuotaUsage(data.trigger_event),
|
||||
},
|
||||
total: {
|
||||
vectorSpace: parseLimit(data.vector_space.limit),
|
||||
vectorSpace: vectorSpaceLimit,
|
||||
buildApps: parseLimit(data.apps?.limit) || 0,
|
||||
teamMembers: parseLimit(data.members.limit),
|
||||
annotatedResponse: parseLimit(data.annotation_quota_limit.limit),
|
||||
|
||||
@ -21,6 +21,12 @@ vi.mock('../../upgrade-btn', () => ({
|
||||
default: () => <button data-testid="vector-upgrade-btn" type="button">Upgrade</button>,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-billing', () => ({
|
||||
useCurrentPlanVectorSpace: () => ({
|
||||
data: undefined,
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock utils to control threshold and plan limits
|
||||
vi.mock('../../utils', () => ({
|
||||
getPlanVectorSpaceLimitMB: (planType: string) => {
|
||||
|
||||
@ -148,6 +148,30 @@ describe('DocumentPicker', () => {
|
||||
expect(trigger).not.toHaveFocus()
|
||||
})
|
||||
|
||||
it('should keep focus in the search input when deleting from an empty result', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockUseDocumentList.mockImplementation(({ query }) => ({
|
||||
data: query.keyword === 'missing'
|
||||
? { data: [] }
|
||||
: mockDocumentListData,
|
||||
}))
|
||||
|
||||
renderDocumentPicker()
|
||||
|
||||
const trigger = screen.getByRole('combobox', { name: 'Document 1' })
|
||||
await user.click(trigger)
|
||||
|
||||
const searchInput = screen.getByPlaceholderText('common.operation.search')
|
||||
await user.type(searchInput, 'missing')
|
||||
expect(await screen.findByText('common.noData')).toBeInTheDocument()
|
||||
|
||||
await user.keyboard('{Backspace}{Backspace}{Backspace}{Backspace}{Backspace}{Backspace}{Backspace}')
|
||||
|
||||
expect(trigger).toHaveAttribute('aria-expanded', 'true')
|
||||
expect(searchInput).toHaveFocus()
|
||||
expect(trigger).not.toHaveFocus()
|
||||
})
|
||||
|
||||
it('should keep focus in the search input while typing quickly', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderDocumentPicker()
|
||||
|
||||
@ -13,7 +13,8 @@ import {
|
||||
ComboboxValue,
|
||||
} from '@langgenius/dify-ui/combobox'
|
||||
import { RiArrowDownSLine } from '@remixicon/react'
|
||||
import { useDeferredValue, useState } from 'react'
|
||||
import { useDebounce } from 'ahooks'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { GeneralChunk, ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
@ -106,12 +107,12 @@ export function DocumentPicker({
|
||||
}: Props) {
|
||||
const { t } = useTranslation()
|
||||
const [searchValue, setSearchValue] = useState('')
|
||||
const deferredSearchValue = useDeferredValue(searchValue)
|
||||
const debouncedSearchValue = useDebounce(searchValue, { wait: 500 })
|
||||
|
||||
const { data } = useDocumentList({
|
||||
datasetId,
|
||||
query: {
|
||||
keyword: deferredSearchValue,
|
||||
keyword: debouncedSearchValue,
|
||||
page: 1,
|
||||
limit: 20,
|
||||
},
|
||||
@ -175,19 +176,16 @@ export function DocumentPicker({
|
||||
className="block h-4.5 grow px-1 py-0 text-[13px] text-text-primary"
|
||||
/>
|
||||
</ComboboxInputGroup>
|
||||
<DocumentList
|
||||
className="mt-2 data-empty:mt-0"
|
||||
/>
|
||||
{data
|
||||
? (
|
||||
documentsList.length > 0
|
||||
? (
|
||||
<DocumentList
|
||||
className="mt-2"
|
||||
/>
|
||||
)
|
||||
: (
|
||||
<ComboboxEmpty className="mt-2 flex h-[100px] w-full items-center justify-center">
|
||||
{t('noData', { ns: 'common' })}
|
||||
</ComboboxEmpty>
|
||||
)
|
||||
<ComboboxEmpty className="p-0">
|
||||
<div className="mt-2 flex h-[100px] w-full items-center justify-center px-3 py-2 system-sm-regular text-text-tertiary">
|
||||
{t('noData', { ns: 'common' })}
|
||||
</div>
|
||||
</ComboboxEmpty>
|
||||
)
|
||||
: (
|
||||
<ComboboxStatus className="mt-2 flex h-[100px] w-full items-center justify-center">
|
||||
|
||||
@ -36,6 +36,16 @@ vi.mock('@/context/provider-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-billing', () => ({
|
||||
useCurrentPlanVectorSpace: () => ({
|
||||
data: {
|
||||
size: mockPlan.usage.vectorSpace,
|
||||
limit: mockPlan.total.vectorSpace,
|
||||
},
|
||||
isFetching: false,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../file-uploader', () => ({
|
||||
default: ({ onPreview, fileList }: { onPreview: (file: File) => void, fileList: FileItem[] }) => (
|
||||
<div data-testid="file-uploader">
|
||||
|
||||
@ -15,6 +15,7 @@ import VectorSpaceFull from '@/app/components/billing/vector-space-full'
|
||||
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
import { useCurrentPlanVectorSpace } from '@/service/use-billing'
|
||||
import EmptyDatasetCreationModal from '../empty-dataset-creation-modal'
|
||||
import FileUploader from '../file-uploader'
|
||||
import Website from '../website'
|
||||
@ -119,7 +120,15 @@ const StepOne = ({
|
||||
|
||||
const allFileLoaded = files.length > 0 && files.every(file => file.file.id)
|
||||
const hasNotion = notionPages.length > 0
|
||||
const isVectorSpaceFull = plan.usage.vectorSpace >= plan.total.vectorSpace
|
||||
const shouldCheckVectorSpace = enableBilling && (allFileLoaded || hasNotion)
|
||||
const {
|
||||
data: vectorSpace,
|
||||
isFetching: isFetchingVectorSpacePlan,
|
||||
} = useCurrentPlanVectorSpace(shouldCheckVectorSpace)
|
||||
const isCheckingVectorSpace = shouldCheckVectorSpace && !vectorSpace && isFetchingVectorSpacePlan
|
||||
const isVectorSpaceFull = !!vectorSpace
|
||||
&& vectorSpace.limit > 0
|
||||
&& vectorSpace.size >= vectorSpace.limit
|
||||
const isShowVectorSpaceFull = (allFileLoaded || hasNotion) && isVectorSpaceFull && enableBilling
|
||||
const supportBatchUpload = !enableBilling || plan.type !== Plan.sandbox
|
||||
|
||||
@ -131,8 +140,10 @@ const StepOne = ({
|
||||
return true
|
||||
if (files.some(file => !file.file.id))
|
||||
return true
|
||||
if (isCheckingVectorSpace)
|
||||
return true
|
||||
return isShowVectorSpaceFull
|
||||
}, [files, isShowVectorSpaceFull])
|
||||
}, [files, isCheckingVectorSpace, isShowVectorSpaceFull])
|
||||
|
||||
// Clear previews when switching data source type
|
||||
const handleClearPreviews = useCallback((newType: DataSourceType) => {
|
||||
|
||||
@ -42,6 +42,16 @@ vi.mock('@/context/provider-context', () => ({
|
||||
selector({ plan: mockPlan, enableBilling: true }),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-billing', () => ({
|
||||
useCurrentPlanVectorSpace: () => ({
|
||||
data: {
|
||||
size: mockPlan.usage.vectorSpace,
|
||||
limit: mockPlan.total.vectorSpace,
|
||||
},
|
||||
isFetching: false,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/dataset-detail', () => ({
|
||||
useDatasetDetailContextWithSelector: (selector: (state: { dataset: { pipeline_id: string } }) => unknown) =>
|
||||
selector({ dataset: { pipeline_id: 'test-pipeline-id' } }),
|
||||
|
||||
@ -13,6 +13,7 @@ type DatasourceUIStateParams = {
|
||||
selectedFileIdsLength: number
|
||||
onlineDriveFileList: OnlineDriveFile[]
|
||||
isVectorSpaceFull: boolean
|
||||
isCheckingVectorSpace?: boolean
|
||||
enableBilling: boolean
|
||||
currentWorkspacePagesLength: number
|
||||
fileUploadConfig: { file_size_limit: number, batch_count_limit: number }
|
||||
@ -30,6 +31,7 @@ export const useDatasourceUIState = ({
|
||||
selectedFileIdsLength,
|
||||
onlineDriveFileList,
|
||||
isVectorSpaceFull,
|
||||
isCheckingVectorSpace = false,
|
||||
enableBilling,
|
||||
currentWorkspacePagesLength,
|
||||
fileUploadConfig,
|
||||
@ -59,14 +61,14 @@ export const useDatasourceUIState = ({
|
||||
return true
|
||||
|
||||
const disabledConditions: Record<string, boolean> = {
|
||||
[DatasourceType.localFile]: isShowVectorSpaceFull || localFileListLength === 0 || !allFileLoaded,
|
||||
[DatasourceType.onlineDocument]: isShowVectorSpaceFull || onlineDocumentsLength === 0,
|
||||
[DatasourceType.websiteCrawl]: isShowVectorSpaceFull || websitePagesLength === 0,
|
||||
[DatasourceType.onlineDrive]: isShowVectorSpaceFull || selectedFileIdsLength === 0,
|
||||
[DatasourceType.localFile]: isCheckingVectorSpace || isShowVectorSpaceFull || localFileListLength === 0 || !allFileLoaded,
|
||||
[DatasourceType.onlineDocument]: isCheckingVectorSpace || isShowVectorSpaceFull || onlineDocumentsLength === 0,
|
||||
[DatasourceType.websiteCrawl]: isCheckingVectorSpace || isShowVectorSpaceFull || websitePagesLength === 0,
|
||||
[DatasourceType.onlineDrive]: isCheckingVectorSpace || isShowVectorSpaceFull || selectedFileIdsLength === 0,
|
||||
}
|
||||
|
||||
return disabledConditions[datasourceType] ?? true
|
||||
}, [datasource, datasourceType, isShowVectorSpaceFull, localFileListLength, allFileLoaded, onlineDocumentsLength, websitePagesLength, selectedFileIdsLength])
|
||||
}, [datasource, datasourceType, isCheckingVectorSpace, isShowVectorSpaceFull, localFileListLength, allFileLoaded, onlineDocumentsLength, websitePagesLength, selectedFileIdsLength])
|
||||
|
||||
// Check if select all should be shown
|
||||
const showSelect = useMemo(() => {
|
||||
|
||||
@ -12,6 +12,7 @@ import { PlanUpgradeModal } from '@/app/components/billing/plan-upgrade-modal'
|
||||
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
|
||||
import { useProviderContextSelector } from '@/context/provider-context'
|
||||
import { DatasourceType } from '@/models/pipeline'
|
||||
import { useCurrentPlanVectorSpace } from '@/service/use-billing'
|
||||
import { useFileUploadConfig } from '@/service/use-common'
|
||||
import { usePublishedPipelineInfo } from '@/service/use-pipeline'
|
||||
import { useDataSourceStore } from './data-source/store'
|
||||
@ -91,7 +92,20 @@ const CreateFormPipeline = () => {
|
||||
} = useOnlineDrive()
|
||||
|
||||
// Computed values
|
||||
const isVectorSpaceFull = plan.usage.vectorSpace >= plan.total.vectorSpace
|
||||
const shouldCheckVectorSpace = enableBilling && (
|
||||
allFileLoaded
|
||||
|| onlineDocuments.length > 0
|
||||
|| websitePages.length > 0
|
||||
|| selectedFileIds.length > 0
|
||||
)
|
||||
const {
|
||||
data: vectorSpace,
|
||||
isFetching: isFetchingVectorSpacePlan,
|
||||
} = useCurrentPlanVectorSpace(shouldCheckVectorSpace)
|
||||
const isCheckingVectorSpace = shouldCheckVectorSpace && !vectorSpace && isFetchingVectorSpacePlan
|
||||
const isVectorSpaceFull = !!vectorSpace
|
||||
&& vectorSpace.limit > 0
|
||||
&& vectorSpace.size >= vectorSpace.limit
|
||||
const supportBatchUpload = !enableBilling || plan.type !== 'sandbox'
|
||||
|
||||
// UI state
|
||||
@ -112,6 +126,7 @@ const CreateFormPipeline = () => {
|
||||
selectedFileIdsLength: selectedFileIds.length,
|
||||
onlineDriveFileList,
|
||||
isVectorSpaceFull,
|
||||
isCheckingVectorSpace,
|
||||
enableBilling,
|
||||
currentWorkspacePagesLength: currentWorkspace?.pages.length ?? 0,
|
||||
fileUploadConfig,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import type { ButtonHTMLAttributes } from 'react'
|
||||
import type { NodeDefault } from '../../types'
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
import { screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
|
||||
@ -191,4 +192,28 @@ describe('NodeSelector', () => {
|
||||
expect(trigger.closest('[aria-haspopup="dialog"]')).toBe(trigger)
|
||||
expect(screen.getByPlaceholderText('workflow.tabs.searchBlock')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('can render the shared Button trigger as the popover root', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
renderWorkflowComponent(
|
||||
<NodeSelector
|
||||
onSelect={vi.fn()}
|
||||
blocks={[createBlock(BlockEnum.LLM, 'LLM')]}
|
||||
availableBlocksTypes={[BlockEnum.LLM]}
|
||||
renderTriggerAsButtonRoot
|
||||
trigger={() => (
|
||||
<Button variant="primary">
|
||||
open-shared-button-trigger
|
||||
</Button>
|
||||
)}
|
||||
/>,
|
||||
)
|
||||
|
||||
const trigger = screen.getByRole('button', { name: 'open-shared-button-trigger' })
|
||||
await user.click(trigger)
|
||||
|
||||
expect(trigger.closest('[aria-haspopup="dialog"]')).toBe(trigger)
|
||||
expect(screen.getByPlaceholderText('workflow.tabs.searchBlock')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@ -56,6 +56,7 @@ const DataSourceEmptyNode = ({ id, data }: NodeProps) => {
|
||||
<BlockSelector
|
||||
onSelect={handleReplaceNode}
|
||||
trigger={renderTrigger}
|
||||
renderTriggerAsButtonRoot
|
||||
noBlocks
|
||||
noTools
|
||||
popupClassName="w-[320px]"
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
import type { ComponentType, SVGProps } from 'react'
|
||||
import { FieldRoot } from '@langgenius/dify-ui/field'
|
||||
import { FieldsetLegend, FieldsetRoot } from '@langgenius/dify-ui/fieldset'
|
||||
import { RadioGroup } from '@langgenius/dify-ui/radio-group'
|
||||
import {
|
||||
fireEvent,
|
||||
render,
|
||||
@ -9,7 +12,7 @@ import {
|
||||
RetrievalSearchMethodEnum,
|
||||
WeightedScoreEnum,
|
||||
} from '../../../types'
|
||||
import SearchMethodOption from '../search-method-option'
|
||||
import { SearchMethodOption } from '../search-method-option'
|
||||
|
||||
const mockUseModelListAndDefaultModel = vi.hoisted(() => vi.fn())
|
||||
const mockUseProviderContext = vi.hoisted(() => vi.fn())
|
||||
@ -68,29 +71,62 @@ const createProps = () => ({
|
||||
description: 'Semantic description',
|
||||
effectColor: 'purple',
|
||||
},
|
||||
hybridSearchModeOptions,
|
||||
searchMethod: RetrievalSearchMethodEnum.semantic,
|
||||
onRetrievalSearchMethodChange: vi.fn(),
|
||||
hybridSearchMode: HybridSearchModeEnum.WeightedScore,
|
||||
onHybridSearchModeChange: vi.fn(),
|
||||
weightedScore,
|
||||
onWeightedScoreChange: vi.fn(),
|
||||
rerankingModelEnabled: false,
|
||||
onRerankingModelEnabledChange: vi.fn(),
|
||||
rerankingModel: {
|
||||
reranking_provider_name: '',
|
||||
reranking_model_name: '',
|
||||
hybridSearch: {
|
||||
mode: HybridSearchModeEnum.WeightedScore,
|
||||
options: hybridSearchModeOptions,
|
||||
onModeChange: vi.fn(),
|
||||
weightedScore,
|
||||
onWeightedScoreChange: vi.fn(),
|
||||
},
|
||||
reranking: {
|
||||
enabled: false,
|
||||
onEnabledChange: vi.fn(),
|
||||
rerankingModel: {
|
||||
reranking_provider_name: '',
|
||||
reranking_model_name: '',
|
||||
},
|
||||
onRerankingModelChange: vi.fn(),
|
||||
showMultiModalTip: false,
|
||||
},
|
||||
retrievalParameters: {
|
||||
topK: {
|
||||
value: 3,
|
||||
onChange: vi.fn(),
|
||||
},
|
||||
scoreThreshold: {
|
||||
value: 0.5,
|
||||
onChange: vi.fn(),
|
||||
enabled: true,
|
||||
onEnabledChange: vi.fn(),
|
||||
},
|
||||
},
|
||||
onRerankingModelChange: vi.fn(),
|
||||
topK: 3,
|
||||
onTopKChange: vi.fn(),
|
||||
scoreThreshold: 0.5,
|
||||
onScoreThresholdChange: vi.fn(),
|
||||
isScoreThresholdEnabled: true,
|
||||
onScoreThresholdEnabledChange: vi.fn(),
|
||||
showMultiModalTip: false,
|
||||
})
|
||||
|
||||
function renderSearchMethodOption(props: ReturnType<typeof createProps>) {
|
||||
const {
|
||||
onRetrievalSearchMethodChange,
|
||||
...optionProps
|
||||
} = props
|
||||
|
||||
render(
|
||||
<FieldRoot name="retrieval_search_method">
|
||||
<FieldsetRoot
|
||||
render={(
|
||||
<RadioGroup
|
||||
value={props.searchMethod}
|
||||
onValueChange={value => onRetrievalSearchMethodChange(value)}
|
||||
/>
|
||||
)}
|
||||
>
|
||||
<FieldsetLegend>Retrieval search method</FieldsetLegend>
|
||||
<SearchMethodOption {...optionProps} />
|
||||
</FieldsetRoot>
|
||||
</FieldRoot>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('SearchMethodOption', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@ -116,37 +152,32 @@ describe('SearchMethodOption', () => {
|
||||
it('should render semantic search controls and notify retrieval and reranking changes', () => {
|
||||
const props = createProps()
|
||||
|
||||
render(<SearchMethodOption {...props} />)
|
||||
renderSearchMethodOption(props)
|
||||
|
||||
expect(screen.getByText('Semantic title'))!.toBeInTheDocument()
|
||||
expect(screen.getByText('common.modelProvider.rerankModel.key'))!.toBeInTheDocument()
|
||||
expect(screen.getByText('plugin.detailPanel.configureModel'))!.toBeInTheDocument()
|
||||
expect(screen.getAllByRole('switch')).toHaveLength(2)
|
||||
|
||||
fireEvent.click(screen.getByText('Semantic title'))
|
||||
fireEvent.click(screen.getAllByRole('switch')[0]!)
|
||||
|
||||
expect(props.onRetrievalSearchMethodChange).toHaveBeenCalledWith(RetrievalSearchMethodEnum.semantic)
|
||||
expect(props.onRerankingModelEnabledChange).toHaveBeenCalledWith(true)
|
||||
expect(props.reranking.onEnabledChange).toHaveBeenCalledWith(true)
|
||||
})
|
||||
|
||||
it('should render the reranking switch for full-text search as well', () => {
|
||||
it('should notify retrieval changes when an inactive option is selected', () => {
|
||||
const props = createProps()
|
||||
const fullTextProps = {
|
||||
...props,
|
||||
option: {
|
||||
...props.option,
|
||||
id: RetrievalSearchMethodEnum.fullText,
|
||||
title: 'Full-text title',
|
||||
},
|
||||
}
|
||||
|
||||
render(
|
||||
<SearchMethodOption
|
||||
{...props}
|
||||
option={{
|
||||
...props.option,
|
||||
id: RetrievalSearchMethodEnum.fullText,
|
||||
title: 'Full-text title',
|
||||
}}
|
||||
searchMethod={RetrievalSearchMethodEnum.fullText}
|
||||
/>,
|
||||
)
|
||||
renderSearchMethodOption(fullTextProps)
|
||||
|
||||
expect(screen.getByText('Full-text title'))!.toBeInTheDocument()
|
||||
expect(screen.getByText('common.modelProvider.rerankModel.key'))!.toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByText('Full-text title'))
|
||||
|
||||
@ -155,20 +186,25 @@ describe('SearchMethodOption', () => {
|
||||
|
||||
it('should render hybrid weighted-score controls without reranking model selector', () => {
|
||||
const props = createProps()
|
||||
const hybridProps = {
|
||||
...props,
|
||||
option: {
|
||||
...props.option,
|
||||
id: RetrievalSearchMethodEnum.hybrid,
|
||||
title: 'Hybrid title',
|
||||
},
|
||||
searchMethod: RetrievalSearchMethodEnum.hybrid,
|
||||
hybridSearch: {
|
||||
...props.hybridSearch,
|
||||
mode: HybridSearchModeEnum.WeightedScore,
|
||||
},
|
||||
reranking: {
|
||||
...props.reranking,
|
||||
showMultiModalTip: true,
|
||||
},
|
||||
}
|
||||
|
||||
render(
|
||||
<SearchMethodOption
|
||||
{...props}
|
||||
option={{
|
||||
...props.option,
|
||||
id: RetrievalSearchMethodEnum.hybrid,
|
||||
title: 'Hybrid title',
|
||||
}}
|
||||
searchMethod={RetrievalSearchMethodEnum.hybrid}
|
||||
hybridSearchMode={HybridSearchModeEnum.WeightedScore}
|
||||
showMultiModalTip
|
||||
/>,
|
||||
)
|
||||
renderSearchMethodOption(hybridProps)
|
||||
|
||||
expect(screen.getByText('Weighted mode'))!.toBeInTheDocument()
|
||||
expect(screen.getByText('Rerank mode'))!.toBeInTheDocument()
|
||||
@ -179,25 +215,30 @@ describe('SearchMethodOption', () => {
|
||||
|
||||
fireEvent.click(screen.getByText('Rerank mode'))
|
||||
|
||||
expect(props.onHybridSearchModeChange).toHaveBeenCalledWith(HybridSearchModeEnum.RerankingModel)
|
||||
expect(props.hybridSearch.onModeChange).toHaveBeenCalledWith(HybridSearchModeEnum.RerankingModel)
|
||||
})
|
||||
|
||||
it('should render the hybrid reranking selector when reranking mode is selected', () => {
|
||||
const props = createProps()
|
||||
const hybridProps = {
|
||||
...props,
|
||||
option: {
|
||||
...props.option,
|
||||
id: RetrievalSearchMethodEnum.hybrid,
|
||||
title: 'Hybrid title',
|
||||
},
|
||||
searchMethod: RetrievalSearchMethodEnum.hybrid,
|
||||
hybridSearch: {
|
||||
...props.hybridSearch,
|
||||
mode: HybridSearchModeEnum.RerankingModel,
|
||||
},
|
||||
reranking: {
|
||||
...props.reranking,
|
||||
showMultiModalTip: true,
|
||||
},
|
||||
}
|
||||
|
||||
render(
|
||||
<SearchMethodOption
|
||||
{...props}
|
||||
option={{
|
||||
...props.option,
|
||||
id: RetrievalSearchMethodEnum.hybrid,
|
||||
title: 'Hybrid title',
|
||||
}}
|
||||
searchMethod={RetrievalSearchMethodEnum.hybrid}
|
||||
hybridSearchMode={HybridSearchModeEnum.RerankingModel}
|
||||
showMultiModalTip
|
||||
/>,
|
||||
)
|
||||
renderSearchMethodOption(hybridProps)
|
||||
|
||||
expect(screen.getByText('plugin.detailPanel.configureModel'))!.toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.rerankModel.key')).not.toBeInTheDocument()
|
||||
@ -207,23 +248,22 @@ describe('SearchMethodOption', () => {
|
||||
|
||||
it('should hide the score-threshold control for keyword search', () => {
|
||||
const props = createProps()
|
||||
const keywordProps = {
|
||||
...props,
|
||||
option: {
|
||||
...props.option,
|
||||
id: RetrievalSearchMethodEnum.keywordSearch,
|
||||
title: 'Keyword title',
|
||||
},
|
||||
searchMethod: RetrievalSearchMethodEnum.keywordSearch,
|
||||
}
|
||||
|
||||
render(
|
||||
<SearchMethodOption
|
||||
{...props}
|
||||
option={{
|
||||
...props.option,
|
||||
id: RetrievalSearchMethodEnum.keywordSearch,
|
||||
title: 'Keyword title',
|
||||
}}
|
||||
searchMethod={RetrievalSearchMethodEnum.keywordSearch}
|
||||
/>,
|
||||
)
|
||||
renderSearchMethodOption(keywordProps)
|
||||
|
||||
fireEvent.change(screen.getByRole('textbox'), { target: { value: '9' } })
|
||||
|
||||
expect(screen.getAllByRole('textbox')).toHaveLength(1)
|
||||
expect(screen.queryAllByRole('switch')).toHaveLength(0)
|
||||
expect(props.onTopKChange).toHaveBeenCalledWith(9)
|
||||
expect(props.retrievalParameters.topK.onChange).toHaveBeenCalledWith(9)
|
||||
})
|
||||
})
|
||||
|
||||
@ -1,40 +1,44 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import TopKAndScoreThreshold from '../top-k-and-score-threshold'
|
||||
import { TopKAndScoreThreshold } from '../top-k-and-score-threshold'
|
||||
|
||||
describe('TopKAndScoreThreshold', () => {
|
||||
const topKLabel = /datasetConfig\.top_k/
|
||||
const scoreThresholdLabel = /datasetConfig\.score_threshold/
|
||||
const defaultProps = {
|
||||
topK: 3,
|
||||
onTopKChange: vi.fn(),
|
||||
scoreThreshold: 0.4,
|
||||
onScoreThresholdChange: vi.fn(),
|
||||
isScoreThresholdEnabled: true,
|
||||
onScoreThresholdEnabledChange: vi.fn(),
|
||||
topK: {
|
||||
value: 3,
|
||||
onChange: vi.fn(),
|
||||
},
|
||||
scoreThreshold: {
|
||||
value: 0.4,
|
||||
onChange: vi.fn(),
|
||||
enabled: true,
|
||||
onEnabledChange: vi.fn(),
|
||||
},
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should round top-k input values before notifying the parent', () => {
|
||||
it('should notify top-k input values without additional rounding', () => {
|
||||
render(<TopKAndScoreThreshold {...defaultProps} />)
|
||||
|
||||
const [topKInput] = screen.getAllByRole('textbox')
|
||||
fireEvent.change(topKInput!, { target: { value: '3.7' } })
|
||||
fireEvent.change(screen.getByRole('textbox', { name: topKLabel }), { target: { value: '3.7' } })
|
||||
|
||||
expect(defaultProps.onTopKChange).toHaveBeenCalledWith(4)
|
||||
expect(defaultProps.topK.onChange).toHaveBeenCalledWith(3.7)
|
||||
})
|
||||
|
||||
it('should round score-threshold input values to two decimals', () => {
|
||||
it('should notify score-threshold input values without additional rounding', () => {
|
||||
render(<TopKAndScoreThreshold {...defaultProps} />)
|
||||
|
||||
const [, scoreThresholdInput] = screen.getAllByRole('textbox')
|
||||
fireEvent.change(scoreThresholdInput!, { target: { value: '0.456' } })
|
||||
fireEvent.change(screen.getByRole('textbox', { name: scoreThresholdLabel }), { target: { value: '0.456' } })
|
||||
|
||||
expect(defaultProps.onScoreThresholdChange).toHaveBeenCalledWith(0.46)
|
||||
expect(defaultProps.scoreThreshold.onChange).toHaveBeenCalledWith(0.456)
|
||||
})
|
||||
|
||||
it('should hide the score-threshold column when requested', () => {
|
||||
render(<TopKAndScoreThreshold {...defaultProps} hiddenScoreThreshold />)
|
||||
render(<TopKAndScoreThreshold {...defaultProps} scoreThreshold={{ hidden: true }} />)
|
||||
|
||||
expect(screen.getAllByRole('textbox')).toHaveLength(1)
|
||||
expect(screen.queryByRole('switch')).not.toBeInTheDocument()
|
||||
@ -44,15 +48,18 @@ describe('TopKAndScoreThreshold', () => {
|
||||
render(
|
||||
<TopKAndScoreThreshold
|
||||
{...defaultProps}
|
||||
scoreThreshold={undefined}
|
||||
isScoreThresholdEnabled
|
||||
scoreThreshold={{
|
||||
...defaultProps.scoreThreshold,
|
||||
value: undefined,
|
||||
enabled: true,
|
||||
}}
|
||||
/>,
|
||||
)
|
||||
|
||||
const [topKInput, scoreThresholdInput] = screen.getAllByRole('textbox')
|
||||
fireEvent.change(topKInput!, { target: { value: '' } })
|
||||
|
||||
expect(defaultProps.onTopKChange).toHaveBeenCalledWith(0)
|
||||
expect(defaultProps.topK.onChange).toHaveBeenCalledWith(0)
|
||||
expect(scoreThresholdInput)!.toHaveValue('')
|
||||
})
|
||||
|
||||
@ -60,10 +67,13 @@ describe('TopKAndScoreThreshold', () => {
|
||||
render(
|
||||
<TopKAndScoreThreshold
|
||||
{...defaultProps}
|
||||
isScoreThresholdEnabled={undefined}
|
||||
scoreThreshold={{
|
||||
...defaultProps.scoreThreshold,
|
||||
enabled: undefined,
|
||||
}}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('switch'))!.toHaveAttribute('aria-checked', 'false')
|
||||
expect(screen.getByRole('switch', { name: scoreThresholdLabel }))!.toHaveAttribute('aria-checked', 'false')
|
||||
})
|
||||
})
|
||||
|
||||
@ -5,7 +5,13 @@ import type {
|
||||
WeightedScore,
|
||||
} from '../../types'
|
||||
import type { RerankingModelSelectorProps } from './reranking-model-selector'
|
||||
import type { TopKAndScoreThresholdProps } from './top-k-and-score-threshold'
|
||||
import type {
|
||||
TopKFieldProps,
|
||||
VisibleScoreThresholdFieldProps,
|
||||
} from './top-k-and-score-threshold'
|
||||
import { FieldRoot } from '@langgenius/dify-ui/field'
|
||||
import { FieldsetLegend, FieldsetRoot } from '@langgenius/dify-ui/fieldset'
|
||||
import { RadioGroup } from '@langgenius/dify-ui/radio-group'
|
||||
import {
|
||||
memo,
|
||||
} from 'react'
|
||||
@ -13,7 +19,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import { Field } from '@/app/components/workflow/nodes/_base/components/layout'
|
||||
import { useDocLink } from '@/context/i18n'
|
||||
import { useRetrievalSetting } from './hooks'
|
||||
import SearchMethodOption from './search-method-option'
|
||||
import { SearchMethodOption } from './search-method-option'
|
||||
|
||||
type RetrievalSettingProps = {
|
||||
indexMethod?: IndexMethodEnum
|
||||
@ -23,11 +29,18 @@ type RetrievalSettingProps = {
|
||||
hybridSearchMode?: HybridSearchModeEnum
|
||||
onHybridSearchModeChange: (value: HybridSearchModeEnum) => void
|
||||
rerankingModelEnabled?: boolean
|
||||
onRerankingModelEnabledChange?: (value: boolean) => void
|
||||
onRerankingModelEnabledChange: (value: boolean) => void
|
||||
weightedScore?: WeightedScore
|
||||
onWeightedScoreChange: (value: { value: number[] }) => void
|
||||
showMultiModalTip?: boolean
|
||||
} & RerankingModelSelectorProps & TopKAndScoreThresholdProps
|
||||
} & RerankingModelSelectorProps & {
|
||||
topK: TopKFieldProps['value']
|
||||
onTopKChange: TopKFieldProps['onChange']
|
||||
scoreThreshold: VisibleScoreThresholdFieldProps['value']
|
||||
onScoreThresholdChange: VisibleScoreThresholdFieldProps['onChange']
|
||||
isScoreThresholdEnabled?: VisibleScoreThresholdFieldProps['enabled']
|
||||
onScoreThresholdEnabledChange: VisibleScoreThresholdFieldProps['onEnabledChange']
|
||||
}
|
||||
|
||||
const RetrievalSetting = ({
|
||||
indexMethod,
|
||||
@ -70,35 +83,56 @@ const RetrievalSetting = ({
|
||||
),
|
||||
}}
|
||||
>
|
||||
<div className="space-y-1">
|
||||
{
|
||||
options.map(option => (
|
||||
<FieldRoot name="retrieval_search_method" className="gap-0">
|
||||
<FieldsetRoot
|
||||
render={(
|
||||
<RadioGroup<RetrievalSearchMethodEnum>
|
||||
value={searchMethod}
|
||||
onValueChange={value => onRetrievalSearchMethodChange(value)}
|
||||
disabled={readonly}
|
||||
className="flex-col items-stretch gap-1"
|
||||
/>
|
||||
)}
|
||||
>
|
||||
<FieldsetLegend className="sr-only">
|
||||
{t('form.retrievalSetting.title', { ns: 'datasetSettings' })}
|
||||
</FieldsetLegend>
|
||||
{options.map(option => (
|
||||
<SearchMethodOption
|
||||
key={option.id}
|
||||
option={option}
|
||||
hybridSearchModeOptions={hybridSearchModeOptions}
|
||||
searchMethod={searchMethod}
|
||||
onRetrievalSearchMethodChange={onRetrievalSearchMethodChange}
|
||||
hybridSearchMode={hybridSearchMode}
|
||||
onHybridSearchModeChange={onHybridSearchModeChange}
|
||||
weightedScore={weightedScore}
|
||||
onWeightedScoreChange={onWeightedScoreChange}
|
||||
topK={topK}
|
||||
onTopKChange={onTopKChange}
|
||||
scoreThreshold={scoreThreshold}
|
||||
onScoreThresholdChange={onScoreThresholdChange}
|
||||
isScoreThresholdEnabled={isScoreThresholdEnabled}
|
||||
onScoreThresholdEnabledChange={onScoreThresholdEnabledChange}
|
||||
rerankingModelEnabled={rerankingModelEnabled}
|
||||
onRerankingModelEnabledChange={onRerankingModelEnabledChange}
|
||||
rerankingModel={rerankingModel}
|
||||
onRerankingModelChange={onRerankingModelChange}
|
||||
hybridSearch={{
|
||||
mode: hybridSearchMode,
|
||||
options: hybridSearchModeOptions,
|
||||
onModeChange: onHybridSearchModeChange,
|
||||
weightedScore,
|
||||
onWeightedScoreChange,
|
||||
}}
|
||||
retrievalParameters={{
|
||||
topK: {
|
||||
value: topK,
|
||||
onChange: onTopKChange,
|
||||
},
|
||||
scoreThreshold: {
|
||||
value: scoreThreshold,
|
||||
onChange: onScoreThresholdChange,
|
||||
enabled: isScoreThresholdEnabled,
|
||||
onEnabledChange: onScoreThresholdEnabledChange,
|
||||
},
|
||||
}}
|
||||
reranking={{
|
||||
enabled: rerankingModelEnabled,
|
||||
onEnabledChange: onRerankingModelEnabledChange,
|
||||
rerankingModel,
|
||||
onRerankingModelChange,
|
||||
showMultiModalTip,
|
||||
}}
|
||||
readonly={readonly}
|
||||
showMultiModalTip={showMultiModalTip}
|
||||
/>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
))}
|
||||
</FieldsetRoot>
|
||||
</FieldRoot>
|
||||
</Field>
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,221 +1,358 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import type {
|
||||
WeightedScore,
|
||||
} from '../../types'
|
||||
import type { RerankingModelSelectorProps } from './reranking-model-selector'
|
||||
import type { TopKAndScoreThresholdProps } from './top-k-and-score-threshold'
|
||||
import type {
|
||||
TopKFieldProps,
|
||||
VisibleScoreThresholdFieldProps,
|
||||
} from './top-k-and-score-threshold'
|
||||
import type {
|
||||
HybridSearchModeOption,
|
||||
Option,
|
||||
} from './type'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
import { FieldItem, FieldLabel, FieldRoot } from '@langgenius/dify-ui/field'
|
||||
import { FieldsetLegend, FieldsetRoot } from '@langgenius/dify-ui/fieldset'
|
||||
import { RadioControl, RadioRoot } from '@langgenius/dify-ui/radio'
|
||||
import { RadioGroup } from '@langgenius/dify-ui/radio-group'
|
||||
import { Switch } from '@langgenius/dify-ui/switch'
|
||||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import WeightedScoreComponent from '@/app/components/app/configuration/dataset-config/params-config/weighted-score'
|
||||
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import {
|
||||
OptionCardEffectBlue,
|
||||
OptionCardEffectBlueLight,
|
||||
OptionCardEffectOrange,
|
||||
OptionCardEffectPurple,
|
||||
OptionCardEffectTeal,
|
||||
} from '@/app/components/base/icons/src/public/knowledge'
|
||||
import { Infotip } from '@/app/components/base/infotip'
|
||||
import { DEFAULT_WEIGHTED_SCORE } from '@/models/datasets'
|
||||
import {
|
||||
HybridSearchModeEnum,
|
||||
RetrievalSearchMethodEnum,
|
||||
} from '../../types'
|
||||
import OptionCard from '../option-card'
|
||||
import RerankingModelSelector from './reranking-model-selector'
|
||||
import TopKAndScoreThreshold from './top-k-and-score-threshold'
|
||||
import { TopKAndScoreThreshold } from './top-k-and-score-threshold'
|
||||
|
||||
type SearchMethodOptionProps = {
|
||||
readonly?: boolean
|
||||
option: Option
|
||||
hybridSearchModeOptions: HybridSearchModeOption[]
|
||||
searchMethod?: RetrievalSearchMethodEnum
|
||||
onRetrievalSearchMethodChange: (value: RetrievalSearchMethodEnum) => void
|
||||
hybridSearchMode?: HybridSearchModeEnum
|
||||
onHybridSearchModeChange: (value: HybridSearchModeEnum) => void
|
||||
type HybridSearchConfig = {
|
||||
mode?: HybridSearchModeEnum
|
||||
options: HybridSearchModeOption[]
|
||||
onModeChange: (value: HybridSearchModeEnum) => void
|
||||
weightedScore?: WeightedScore
|
||||
onWeightedScoreChange: (value: { value: number[] }) => void
|
||||
rerankingModelEnabled?: boolean
|
||||
onRerankingModelEnabledChange?: (value: boolean) => void
|
||||
}
|
||||
|
||||
type RerankingConfig = RerankingModelSelectorProps & {
|
||||
enabled?: boolean
|
||||
onEnabledChange: (value: boolean) => void
|
||||
showMultiModalTip?: boolean
|
||||
} & RerankingModelSelectorProps & TopKAndScoreThresholdProps
|
||||
const SearchMethodOption = ({
|
||||
readonly,
|
||||
option,
|
||||
hybridSearchModeOptions,
|
||||
searchMethod,
|
||||
onRetrievalSearchMethodChange,
|
||||
hybridSearchMode,
|
||||
onHybridSearchModeChange,
|
||||
weightedScore,
|
||||
onWeightedScoreChange,
|
||||
rerankingModelEnabled,
|
||||
onRerankingModelEnabledChange,
|
||||
rerankingModel,
|
||||
onRerankingModelChange,
|
||||
topK,
|
||||
onTopKChange,
|
||||
scoreThreshold,
|
||||
onScoreThresholdChange,
|
||||
isScoreThresholdEnabled,
|
||||
onScoreThresholdEnabledChange,
|
||||
showMultiModalTip = false,
|
||||
}: SearchMethodOptionProps) => {
|
||||
const { t } = useTranslation()
|
||||
const Icon = option.icon
|
||||
const isHybridSearch = option.id === RetrievalSearchMethodEnum.hybrid
|
||||
const isHybridSearchWeightedScoreMode = hybridSearchMode === HybridSearchModeEnum.WeightedScore
|
||||
}
|
||||
|
||||
const weightedScoreValue = useMemo(() => {
|
||||
const sematicWeightedScore = weightedScore?.vector_setting.vector_weight ?? DEFAULT_WEIGHTED_SCORE.other.semantic
|
||||
const keywordWeightedScore = weightedScore?.keyword_setting.keyword_weight ?? DEFAULT_WEIGHTED_SCORE.other.keyword
|
||||
const mergedValue = [sematicWeightedScore, keywordWeightedScore]
|
||||
type RetrievalParametersConfig = {
|
||||
topK: TopKFieldProps
|
||||
scoreThreshold: VisibleScoreThresholdFieldProps
|
||||
}
|
||||
|
||||
return {
|
||||
value: mergedValue,
|
||||
}
|
||||
}, [weightedScore])
|
||||
type SearchMethodRadioCardProps = {
|
||||
option: Option
|
||||
searchMethod?: RetrievalSearchMethodEnum
|
||||
readonly?: boolean
|
||||
isRecommended?: boolean
|
||||
children?: ReactNode
|
||||
}
|
||||
|
||||
const icon = useCallback((isActive: boolean) => {
|
||||
return (
|
||||
<Icon
|
||||
className={cn(
|
||||
'h-[15px] w-[15px] text-text-tertiary group-hover:text-util-colors-purple-purple-600',
|
||||
isActive && 'text-util-colors-purple-purple-600',
|
||||
)}
|
||||
/>
|
||||
)
|
||||
}, [Icon])
|
||||
export type SearchMethodOptionProps = {
|
||||
readonly?: boolean
|
||||
option: Option
|
||||
searchMethod?: RetrievalSearchMethodEnum
|
||||
hybridSearch: HybridSearchConfig
|
||||
reranking: RerankingConfig
|
||||
retrievalParameters: RetrievalParametersConfig
|
||||
}
|
||||
|
||||
const hybridSearchModeWrapperClassName = useCallback((isActive: boolean) => {
|
||||
return isActive ? 'border-[1.5px] bg-components-option-card-option-selected-bg' : ''
|
||||
}, [])
|
||||
const HEADER_EFFECT_MAP: Record<string, ReactNode> = {
|
||||
'blue': <OptionCardEffectBlue />,
|
||||
'blue-light': <OptionCardEffectBlueLight />,
|
||||
'orange': <OptionCardEffectOrange />,
|
||||
'purple': <OptionCardEffectPurple />,
|
||||
'teal': <OptionCardEffectTeal />,
|
||||
}
|
||||
|
||||
const showRerankModelSelectorSwitch = useMemo(() => {
|
||||
if (searchMethod === RetrievalSearchMethodEnum.semantic)
|
||||
return true
|
||||
function getWeightedScoreValue(weightedScore?: WeightedScore) {
|
||||
const semanticWeightedScore = weightedScore?.vector_setting.vector_weight ?? DEFAULT_WEIGHTED_SCORE.other.semantic
|
||||
const keywordWeightedScore = weightedScore?.keyword_setting.keyword_weight ?? DEFAULT_WEIGHTED_SCORE.other.keyword
|
||||
|
||||
if (searchMethod === RetrievalSearchMethodEnum.fullText)
|
||||
return true
|
||||
return {
|
||||
value: [semanticWeightedScore, keywordWeightedScore],
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}, [searchMethod])
|
||||
const showRerankModelSelector = useMemo(() => {
|
||||
if (searchMethod === RetrievalSearchMethodEnum.semantic)
|
||||
return true
|
||||
function shouldShowRerankModelSelectorSwitch(searchMethod?: RetrievalSearchMethodEnum) {
|
||||
return searchMethod === RetrievalSearchMethodEnum.semantic || searchMethod === RetrievalSearchMethodEnum.fullText
|
||||
}
|
||||
|
||||
if (searchMethod === RetrievalSearchMethodEnum.fullText)
|
||||
return true
|
||||
function shouldShowRerankModelSelector(searchMethod: RetrievalSearchMethodEnum | undefined, hybridSearchMode: HybridSearchModeEnum | undefined) {
|
||||
if (shouldShowRerankModelSelectorSwitch(searchMethod))
|
||||
return true
|
||||
|
||||
if (searchMethod === RetrievalSearchMethodEnum.hybrid && hybridSearchMode !== HybridSearchModeEnum.WeightedScore)
|
||||
return true
|
||||
return searchMethod === RetrievalSearchMethodEnum.hybrid && hybridSearchMode !== HybridSearchModeEnum.WeightedScore
|
||||
}
|
||||
|
||||
return false
|
||||
}, [hybridSearchMode, searchMethod])
|
||||
function getSearchMethodEffect(effectColor: string | undefined, isActive: boolean) {
|
||||
const effect = effectColor ? HEADER_EFFECT_MAP[effectColor] : undefined
|
||||
|
||||
if (!effect)
|
||||
return null
|
||||
|
||||
return (
|
||||
<OptionCard
|
||||
key={option.id}
|
||||
id={option.id}
|
||||
selectedId={searchMethod}
|
||||
icon={icon}
|
||||
title={option.title}
|
||||
description={option.description}
|
||||
effectColor={option.effectColor}
|
||||
isRecommended={option.id === RetrievalSearchMethodEnum.hybrid}
|
||||
onClick={onRetrievalSearchMethodChange}
|
||||
readonly={readonly}
|
||||
<div
|
||||
className={cn(
|
||||
'absolute -top-0.5 -left-0.5 hidden h-14 w-14 rounded-full',
|
||||
'group-hover/search-method-radio:block',
|
||||
isActive && 'block',
|
||||
)}
|
||||
>
|
||||
<div className="space-y-3">
|
||||
{
|
||||
isHybridSearch && (
|
||||
<div className="space-y-1">
|
||||
{
|
||||
hybridSearchModeOptions.map(hybridOption => (
|
||||
<OptionCard
|
||||
key={hybridOption.id}
|
||||
id={hybridOption.id}
|
||||
selectedId={hybridSearchMode}
|
||||
enableHighlightBorder={false}
|
||||
enableRadio
|
||||
wrapperClassName={hybridSearchModeWrapperClassName}
|
||||
className="p-3"
|
||||
title={hybridOption.title}
|
||||
description={hybridOption.description}
|
||||
onClick={onHybridSearchModeChange}
|
||||
readonly={readonly}
|
||||
/>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
isHybridSearch && isHybridSearchWeightedScoreMode && (
|
||||
<WeightedScoreComponent
|
||||
value={weightedScoreValue}
|
||||
onChange={onWeightedScoreChange}
|
||||
readonly={readonly}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
showRerankModelSelector && (
|
||||
<div>
|
||||
{
|
||||
showRerankModelSelectorSwitch && (
|
||||
<div className="mb-1 flex items-center system-sm-semibold text-text-secondary">
|
||||
<Switch
|
||||
className="mr-1"
|
||||
checked={rerankingModelEnabled ?? false}
|
||||
onCheckedChange={onRerankingModelEnabledChange}
|
||||
disabled={readonly}
|
||||
/>
|
||||
{t('modelProvider.rerankModel.key', { ns: 'common' })}
|
||||
<Infotip
|
||||
aria-label={t('modelProvider.rerankModel.tip', { ns: 'common' })}
|
||||
className="ml-0.5 size-3.5 shrink-0"
|
||||
iconClassName="h-3.5 w-3.5"
|
||||
>
|
||||
{t('modelProvider.rerankModel.tip', { ns: 'common' })}
|
||||
</Infotip>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<RerankingModelSelector
|
||||
rerankingModel={rerankingModel}
|
||||
onRerankingModelChange={onRerankingModelChange}
|
||||
readonly={readonly}
|
||||
/>
|
||||
{showMultiModalTip && (
|
||||
<div className="mt-2 flex h-10 items-center gap-x-0.5 overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-2 shadow-xs backdrop-blur-[5px]">
|
||||
<div className="absolute inset-0 bg-dataset-warning-message-bg opacity-40" />
|
||||
<div className="p-1">
|
||||
<AlertTriangle className="size-4 text-text-warning-secondary" />
|
||||
</div>
|
||||
<span className="system-xs-medium text-text-primary">
|
||||
{t('form.retrievalSetting.multiModalTip', { ns: 'datasetSettings' })}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<TopKAndScoreThreshold
|
||||
topK={topK}
|
||||
onTopKChange={onTopKChange}
|
||||
scoreThreshold={scoreThreshold}
|
||||
onScoreThresholdChange={onScoreThresholdChange}
|
||||
isScoreThresholdEnabled={isScoreThresholdEnabled}
|
||||
onScoreThresholdEnabledChange={onScoreThresholdEnabledChange}
|
||||
readonly={readonly}
|
||||
hiddenScoreThreshold={searchMethod === RetrievalSearchMethodEnum.keywordSearch}
|
||||
/>
|
||||
</div>
|
||||
</OptionCard>
|
||||
{effect}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(SearchMethodOption)
|
||||
function renderSearchMethodIcon(Icon: Option['icon'], isActive: boolean) {
|
||||
return (
|
||||
<Icon
|
||||
className={cn(
|
||||
'h-3.75 w-3.75 text-text-tertiary group-hover:text-util-colors-purple-purple-600',
|
||||
isActive && 'text-util-colors-purple-purple-600',
|
||||
)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
function SearchMethodRadioCard({
|
||||
option,
|
||||
searchMethod,
|
||||
readonly,
|
||||
isRecommended,
|
||||
children,
|
||||
}: SearchMethodRadioCardProps) {
|
||||
const { t } = useTranslation()
|
||||
const isActive = option.id === searchMethod
|
||||
const Icon = option.icon
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'group/search-method-radio overflow-hidden rounded-xl border border-components-option-card-option-border bg-components-option-card-option-bg',
|
||||
'has-data-checked:border-[1.5px] has-data-checked:border-components-option-card-option-selected-border',
|
||||
!readonly && 'cursor-pointer hover:shadow-xs',
|
||||
readonly && 'cursor-not-allowed',
|
||||
)}
|
||||
>
|
||||
<RadioRoot
|
||||
value={option.id}
|
||||
variant="unstyled"
|
||||
nativeButton
|
||||
render={<button type="button" />}
|
||||
disabled={readonly}
|
||||
className={cn(
|
||||
'relative flex w-full rounded-t-xl p-2 text-left outline-hidden focus-visible:ring-1 focus-visible:ring-components-input-border-active',
|
||||
readonly ? 'cursor-not-allowed' : 'cursor-pointer',
|
||||
)}
|
||||
>
|
||||
{getSearchMethodEffect(option.effectColor, isActive)}
|
||||
<div className="mr-1 flex h-4.5 w-4.5 shrink-0 items-center justify-center">
|
||||
{renderSearchMethodIcon(Icon, isActive)}
|
||||
</div>
|
||||
<div className="grow py-1 pt-px">
|
||||
<div className="flex items-center">
|
||||
<div className="flex grow items-center system-sm-medium text-text-secondary">
|
||||
{option.title}
|
||||
{isRecommended
|
||||
? (
|
||||
<Badge className="ml-1 h-4 border-text-accent-secondary text-text-accent-secondary">
|
||||
{t('stepTwo.recommend', { ns: 'datasetCreation' })}
|
||||
</Badge>
|
||||
)
|
||||
: null}
|
||||
</div>
|
||||
</div>
|
||||
{option.description
|
||||
? (
|
||||
<div className="mt-1 system-xs-regular text-text-tertiary">
|
||||
{option.description}
|
||||
</div>
|
||||
)
|
||||
: null}
|
||||
</div>
|
||||
</RadioRoot>
|
||||
{!!(children && isActive) && (
|
||||
<div className="relative rounded-b-xl bg-components-panel-bg p-3">
|
||||
<div className="absolute -top-2.75 left-3.5 i-custom-vender-knowledge-arrow-shape h-4 w-4 text-components-panel-bg" />
|
||||
{children}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
function HybridSearchModeRadioCard({
|
||||
option,
|
||||
readonly,
|
||||
}: {
|
||||
option: HybridSearchModeOption
|
||||
readonly?: boolean
|
||||
}) {
|
||||
return (
|
||||
<FieldItem>
|
||||
<RadioRoot
|
||||
value={option.id}
|
||||
variant="unstyled"
|
||||
nativeButton
|
||||
render={<button type="button" />}
|
||||
disabled={readonly}
|
||||
className={cn(
|
||||
'w-full rounded-xl border border-components-option-card-option-border bg-components-option-card-option-bg p-3 text-left outline-hidden transition-colors',
|
||||
'data-checked:border-[1.5px] data-checked:bg-components-option-card-option-selected-bg',
|
||||
'focus-visible:ring-1 focus-visible:ring-components-input-border-active',
|
||||
readonly ? 'cursor-not-allowed' : 'cursor-pointer hover:shadow-xs',
|
||||
)}
|
||||
>
|
||||
<div className="flex items-start gap-2">
|
||||
<div className="min-w-0 grow">
|
||||
<div className="system-sm-medium text-text-secondary">
|
||||
{option.title}
|
||||
</div>
|
||||
<div className="mt-1 system-xs-regular text-text-tertiary">
|
||||
{option.description}
|
||||
</div>
|
||||
</div>
|
||||
<RadioControl className="mt-0.5" aria-hidden="true" />
|
||||
</div>
|
||||
</RadioRoot>
|
||||
</FieldItem>
|
||||
)
|
||||
}
|
||||
|
||||
export function SearchMethodOption({
|
||||
readonly,
|
||||
option,
|
||||
searchMethod,
|
||||
hybridSearch,
|
||||
reranking,
|
||||
retrievalParameters,
|
||||
}: SearchMethodOptionProps) {
|
||||
const { t } = useTranslation()
|
||||
const isHybridSearch = option.id === RetrievalSearchMethodEnum.hybrid
|
||||
const isHybridSearchWeightedScoreMode = hybridSearch.mode === HybridSearchModeEnum.WeightedScore
|
||||
const showRerankModelSelectorSwitch = shouldShowRerankModelSelectorSwitch(option.id)
|
||||
const showRerankModelSelector = shouldShowRerankModelSelector(option.id, hybridSearch.mode)
|
||||
const rerankModelLabel = t('modelProvider.rerankModel.key', { ns: 'common' })
|
||||
const rerankModelTip = t('modelProvider.rerankModel.tip', { ns: 'common' })
|
||||
const scoreThresholdHidden = option.id === RetrievalSearchMethodEnum.keywordSearch
|
||||
const config = (
|
||||
<div className="space-y-3">
|
||||
{isHybridSearch
|
||||
? (
|
||||
<FieldRoot name="hybrid_search_mode" className="gap-0">
|
||||
<FieldsetRoot
|
||||
render={(
|
||||
<RadioGroup<HybridSearchModeEnum>
|
||||
value={hybridSearch.mode}
|
||||
onValueChange={value => hybridSearch.onModeChange(value)}
|
||||
disabled={readonly}
|
||||
className="flex-col items-stretch gap-1"
|
||||
/>
|
||||
)}
|
||||
>
|
||||
<FieldsetLegend className="sr-only">Hybrid search mode</FieldsetLegend>
|
||||
{hybridSearch.options.map(hybridOption => (
|
||||
<HybridSearchModeRadioCard
|
||||
key={hybridOption.id}
|
||||
option={hybridOption}
|
||||
readonly={readonly}
|
||||
/>
|
||||
))}
|
||||
</FieldsetRoot>
|
||||
</FieldRoot>
|
||||
)
|
||||
: null}
|
||||
{isHybridSearch && isHybridSearchWeightedScoreMode
|
||||
? (
|
||||
<WeightedScoreComponent
|
||||
value={getWeightedScoreValue(hybridSearch.weightedScore)}
|
||||
onChange={hybridSearch.onWeightedScoreChange}
|
||||
readonly={readonly}
|
||||
/>
|
||||
)
|
||||
: null}
|
||||
{showRerankModelSelector
|
||||
? (
|
||||
<div>
|
||||
{showRerankModelSelectorSwitch
|
||||
? (
|
||||
<FieldRoot name="reranking_model_enabled" className="mb-1 gap-0">
|
||||
<div className="flex items-center">
|
||||
<FieldLabel className="flex min-w-0 items-center py-0 system-sm-semibold text-text-secondary">
|
||||
<Switch
|
||||
className="mr-1"
|
||||
checked={reranking.enabled ?? false}
|
||||
onCheckedChange={reranking.onEnabledChange}
|
||||
disabled={readonly}
|
||||
/>
|
||||
<span className="truncate">{rerankModelLabel}</span>
|
||||
</FieldLabel>
|
||||
<Infotip
|
||||
aria-label={rerankModelTip}
|
||||
className="ml-0.5 size-3.5 shrink-0"
|
||||
iconClassName="h-3.5 w-3.5"
|
||||
>
|
||||
{rerankModelTip}
|
||||
</Infotip>
|
||||
</div>
|
||||
</FieldRoot>
|
||||
)
|
||||
: null}
|
||||
<RerankingModelSelector
|
||||
rerankingModel={reranking.rerankingModel}
|
||||
onRerankingModelChange={reranking.onRerankingModelChange}
|
||||
readonly={readonly}
|
||||
/>
|
||||
{reranking.showMultiModalTip
|
||||
? (
|
||||
<div className="mt-2 flex h-10 items-center gap-x-0.5 overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-2 shadow-xs backdrop-blur-[5px]">
|
||||
<div className="absolute inset-0 bg-dataset-warning-message-bg opacity-40" />
|
||||
<div className="p-1">
|
||||
<div className="i-custom-vender-solid-alertsAndFeedback-alert-triangle size-4 text-text-warning-secondary" />
|
||||
</div>
|
||||
<span className="system-xs-medium text-text-primary">
|
||||
{t('form.retrievalSetting.multiModalTip', { ns: 'datasetSettings' })}
|
||||
</span>
|
||||
</div>
|
||||
)
|
||||
: null}
|
||||
</div>
|
||||
)
|
||||
: null}
|
||||
<TopKAndScoreThreshold
|
||||
topK={retrievalParameters.topK}
|
||||
scoreThreshold={scoreThresholdHidden ? { hidden: true } : retrievalParameters.scoreThreshold}
|
||||
readonly={readonly}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
|
||||
return (
|
||||
<FieldItem>
|
||||
<SearchMethodRadioCard
|
||||
option={option}
|
||||
searchMethod={searchMethod}
|
||||
isRecommended={option.id === RetrievalSearchMethodEnum.hybrid}
|
||||
readonly={readonly}
|
||||
>
|
||||
{config}
|
||||
</SearchMethodRadioCard>
|
||||
</FieldItem>
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import { FieldLabel, FieldRoot } from '@langgenius/dify-ui/field'
|
||||
import { FieldsetLegend, FieldsetRoot } from '@langgenius/dify-ui/fieldset'
|
||||
import {
|
||||
NumberField,
|
||||
NumberFieldControls,
|
||||
@ -7,27 +9,39 @@ import {
|
||||
NumberFieldInput,
|
||||
} from '@langgenius/dify-ui/number-field'
|
||||
import { Switch } from '@langgenius/dify-ui/switch'
|
||||
import { memo, useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Infotip } from '@/app/components/base/infotip'
|
||||
import { env } from '@/env'
|
||||
|
||||
export type TopKAndScoreThresholdProps = {
|
||||
topK: number
|
||||
onTopKChange: (value: number) => void
|
||||
scoreThreshold?: number
|
||||
onScoreThresholdChange?: (value: number) => void
|
||||
isScoreThresholdEnabled?: boolean
|
||||
onScoreThresholdEnabledChange?: (value: boolean) => void
|
||||
readonly?: boolean
|
||||
hiddenScoreThreshold?: boolean
|
||||
export type TopKFieldProps = {
|
||||
value: number
|
||||
onChange: (value: number) => void
|
||||
}
|
||||
|
||||
export type VisibleScoreThresholdFieldProps = {
|
||||
hidden?: false
|
||||
value?: number
|
||||
onChange: (value: number) => void
|
||||
enabled?: boolean
|
||||
onEnabledChange: (value: boolean) => void
|
||||
}
|
||||
|
||||
type ScoreThresholdFieldProps
|
||||
= | VisibleScoreThresholdFieldProps
|
||||
| {
|
||||
hidden: true
|
||||
}
|
||||
|
||||
export type TopKAndScoreThresholdProps = {
|
||||
topK: TopKFieldProps
|
||||
scoreThreshold: ScoreThresholdFieldProps
|
||||
readonly?: boolean
|
||||
}
|
||||
|
||||
const maxTopK = env.NEXT_PUBLIC_TOP_K_MAX_VALUE
|
||||
const TOP_K_VALUE_LIMIT = {
|
||||
amount: 1,
|
||||
step: 1,
|
||||
min: 1,
|
||||
max: maxTopK,
|
||||
max: env.NEXT_PUBLIC_TOP_K_MAX_VALUE,
|
||||
}
|
||||
const SCORE_THRESHOLD_VALUE_LIMIT = {
|
||||
step: 0.01,
|
||||
@ -35,99 +49,99 @@ const SCORE_THRESHOLD_VALUE_LIMIT = {
|
||||
max: 1,
|
||||
}
|
||||
|
||||
const TopKAndScoreThreshold = ({
|
||||
export function TopKAndScoreThreshold({
|
||||
topK,
|
||||
onTopKChange,
|
||||
scoreThreshold,
|
||||
onScoreThresholdChange,
|
||||
isScoreThresholdEnabled,
|
||||
onScoreThresholdEnabledChange,
|
||||
readonly,
|
||||
hiddenScoreThreshold,
|
||||
}: TopKAndScoreThresholdProps) => {
|
||||
}: TopKAndScoreThresholdProps) {
|
||||
const { t } = useTranslation()
|
||||
const topKLabel = t('datasetConfig.top_k', { ns: 'appDebug' })
|
||||
const scoreThresholdLabel = t('datasetConfig.score_threshold', { ns: 'appDebug' })
|
||||
const handleTopKChange = useCallback((value: number) => {
|
||||
onTopKChange?.(Number.parseInt(value.toFixed(0)))
|
||||
}, [onTopKChange])
|
||||
|
||||
const handleScoreThresholdChange = (value: number) => {
|
||||
onScoreThresholdChange?.(Number.parseFloat(value.toFixed(2)))
|
||||
}
|
||||
const topKTip = t('datasetConfig.top_kTip', { ns: 'appDebug' })
|
||||
const scoreThresholdTip = t('datasetConfig.score_thresholdTip', { ns: 'appDebug' })
|
||||
const scoreThresholdHidden = scoreThreshold.hidden === true
|
||||
const scoreThresholdEnabled = scoreThresholdHidden ? false : (scoreThreshold.enabled ?? false)
|
||||
|
||||
return (
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<div className="mb-0.5 flex h-6 items-center system-xs-medium text-text-secondary">
|
||||
{topKLabel}
|
||||
<FieldRoot name="top_k" className="gap-0">
|
||||
<div className="mb-0.5 flex h-6 items-center">
|
||||
<FieldLabel className="py-0 system-xs-medium text-text-secondary">
|
||||
{topKLabel}
|
||||
</FieldLabel>
|
||||
<Infotip
|
||||
aria-label={t('datasetConfig.top_kTip', { ns: 'appDebug' })}
|
||||
aria-label={topKTip}
|
||||
className="ml-0.5 size-3.5"
|
||||
iconClassName="h-3.5 w-3.5"
|
||||
>
|
||||
{t('datasetConfig.top_kTip', { ns: 'appDebug' })}
|
||||
{topKTip}
|
||||
</Infotip>
|
||||
</div>
|
||||
<NumberField
|
||||
disabled={readonly}
|
||||
step={TOP_K_VALUE_LIMIT.amount}
|
||||
step={TOP_K_VALUE_LIMIT.step}
|
||||
min={TOP_K_VALUE_LIMIT.min}
|
||||
max={TOP_K_VALUE_LIMIT.max}
|
||||
value={topK}
|
||||
onValueChange={value => handleTopKChange(value ?? 0)}
|
||||
value={topK.value}
|
||||
onValueChange={value => topK.onChange(value ?? 0)}
|
||||
>
|
||||
<NumberFieldGroup>
|
||||
<NumberFieldInput aria-label={topKLabel} />
|
||||
<NumberFieldInput />
|
||||
<NumberFieldControls>
|
||||
<NumberFieldIncrement />
|
||||
<NumberFieldDecrement />
|
||||
</NumberFieldControls>
|
||||
</NumberFieldGroup>
|
||||
</NumberField>
|
||||
</div>
|
||||
{
|
||||
!hiddenScoreThreshold && (
|
||||
<div>
|
||||
<div className="mb-0.5 flex h-6 items-center">
|
||||
<Switch
|
||||
className="mr-2"
|
||||
checked={isScoreThresholdEnabled ?? false}
|
||||
onCheckedChange={onScoreThresholdEnabledChange}
|
||||
disabled={readonly}
|
||||
/>
|
||||
<div className="grow truncate system-sm-medium text-text-secondary">
|
||||
{scoreThresholdLabel}
|
||||
</div>
|
||||
<Infotip
|
||||
aria-label={t('datasetConfig.score_thresholdTip', { ns: 'appDebug' })}
|
||||
className="ml-0.5 size-3.5"
|
||||
iconClassName="h-3.5 w-3.5"
|
||||
>
|
||||
{t('datasetConfig.score_thresholdTip', { ns: 'appDebug' })}
|
||||
</Infotip>
|
||||
</div>
|
||||
<NumberField
|
||||
disabled={readonly || !isScoreThresholdEnabled}
|
||||
step={SCORE_THRESHOLD_VALUE_LIMIT.step}
|
||||
min={SCORE_THRESHOLD_VALUE_LIMIT.min}
|
||||
max={SCORE_THRESHOLD_VALUE_LIMIT.max}
|
||||
value={scoreThreshold ?? null}
|
||||
onValueChange={value => handleScoreThresholdChange(value ?? 0)}
|
||||
>
|
||||
<NumberFieldGroup>
|
||||
<NumberFieldInput aria-label={scoreThresholdLabel} />
|
||||
<NumberFieldControls>
|
||||
<NumberFieldIncrement />
|
||||
<NumberFieldDecrement />
|
||||
</NumberFieldControls>
|
||||
</NumberFieldGroup>
|
||||
</NumberField>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</FieldRoot>
|
||||
{scoreThresholdHidden
|
||||
? null
|
||||
: (
|
||||
<FieldsetRoot className="min-w-0">
|
||||
<FieldsetLegend className="sr-only">{scoreThresholdLabel}</FieldsetLegend>
|
||||
<FieldRoot name="score_threshold_enabled" className="mb-0.5 gap-0">
|
||||
<div className="flex h-6 items-center">
|
||||
<FieldLabel className="flex w-full min-w-0 grow items-center py-0 system-sm-medium text-text-secondary">
|
||||
<Switch
|
||||
className="mr-2"
|
||||
checked={scoreThresholdEnabled}
|
||||
onCheckedChange={scoreThreshold.onEnabledChange}
|
||||
disabled={readonly}
|
||||
/>
|
||||
<span className="grow truncate">
|
||||
{scoreThresholdLabel}
|
||||
</span>
|
||||
</FieldLabel>
|
||||
<Infotip
|
||||
aria-label={scoreThresholdTip}
|
||||
className="ml-0.5 size-3.5"
|
||||
iconClassName="h-3.5 w-3.5"
|
||||
>
|
||||
{scoreThresholdTip}
|
||||
</Infotip>
|
||||
</div>
|
||||
</FieldRoot>
|
||||
<FieldRoot name="score_threshold" className="gap-0">
|
||||
<FieldLabel className="sr-only">{scoreThresholdLabel}</FieldLabel>
|
||||
<NumberField
|
||||
disabled={readonly || !scoreThresholdEnabled}
|
||||
step={SCORE_THRESHOLD_VALUE_LIMIT.step}
|
||||
min={SCORE_THRESHOLD_VALUE_LIMIT.min}
|
||||
max={SCORE_THRESHOLD_VALUE_LIMIT.max}
|
||||
value={scoreThreshold.value ?? null}
|
||||
onValueChange={value => scoreThreshold.onChange(value ?? 0)}
|
||||
>
|
||||
<NumberFieldGroup>
|
||||
<NumberFieldInput />
|
||||
<NumberFieldControls>
|
||||
<NumberFieldIncrement />
|
||||
<NumberFieldDecrement />
|
||||
</NumberFieldControls>
|
||||
</NumberFieldGroup>
|
||||
</NumberField>
|
||||
</FieldRoot>
|
||||
</FieldsetRoot>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(TopKAndScoreThreshold)
|
||||
|
||||
@ -24,7 +24,7 @@ import { isPrivateOrLocalAddress } from '@/utils/urlValidation'
|
||||
import HeaderTable from './components/header-table'
|
||||
import ParagraphInput from './components/paragraph-input'
|
||||
import ParameterTable from './components/parameter-table'
|
||||
import { DEFAULT_STATUS_CODE, MAX_STATUS_CODE, normalizeStatusCode, useConfig } from './use-config'
|
||||
import { DEFAULT_STATUS_CODE, MAX_STATUS_CODE, useConfig } from './use-config'
|
||||
import { OutputVariablesContent } from './utils/render-output-vars'
|
||||
|
||||
const i18nPrefix = 'nodes.triggerWebhook'
|
||||
@ -262,8 +262,8 @@ const Panel: FC<NodePanelProps<WebhookTriggerNodeType>> = ({
|
||||
disabled={readOnly}
|
||||
onValueChange={value => value !== null && handleStatusCodeChange(value)}
|
||||
onValueCommitted={(value, eventDetails) => {
|
||||
if (eventDetails.reason === 'input-blur' || eventDetails.reason === 'input-clear')
|
||||
handleStatusCodeChange(normalizeStatusCode(value ?? DEFAULT_STATUS_CODE))
|
||||
if (eventDetails.reason === 'input-clear')
|
||||
handleStatusCodeChange(value ?? DEFAULT_STATUS_CODE)
|
||||
}}
|
||||
>
|
||||
<NumberFieldGroup>
|
||||
|
||||
@ -29,7 +29,7 @@ const i18n = {
|
||||
placeholder: 'common.tag.placeholder',
|
||||
selectorPlaceholder: 'common.tag.selectorPlaceholder',
|
||||
operationClear: 'common.operation.clear',
|
||||
noTag: 'common.tag.noTag',
|
||||
noTag: /common\.tag\.noTag/,
|
||||
manageTags: 'common.tag.manageTags',
|
||||
}
|
||||
|
||||
@ -230,6 +230,20 @@ describe('TagFilter', () => {
|
||||
expect(screen.getByText(i18n.noTag)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should keep search input focused when search has no results', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
render(<TagFilter {...defaultProps} />)
|
||||
|
||||
await user.click(screen.getByText(i18n.placeholder))
|
||||
|
||||
const searchInput = screen.getByRole('combobox', { name: i18n.selectorPlaceholder })
|
||||
await user.type(searchInput, 'NonExistentTag')
|
||||
|
||||
expect(screen.getByText(i18n.noTag)).toBeInTheDocument()
|
||||
expect(searchInput).toHaveFocus()
|
||||
})
|
||||
|
||||
it('should clear search and show all tags when clear icon is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { useMemo, useState } from 'react'
|
||||
import { isCreateTagOption } from '../components/tag-combobox-item'
|
||||
import { TagPanel } from '../components/tag-panel'
|
||||
import { TagSearchContent } from '../components/tag-search-content'
|
||||
|
||||
const { onValueChangeSpy } = vi.hoisted(() => ({
|
||||
onValueChangeSpy: vi.fn(),
|
||||
@ -15,7 +15,7 @@ const i18n = {
|
||||
selectorPlaceholder: 'common.tag.selectorPlaceholder',
|
||||
operationClear: 'common.operation.clear',
|
||||
create: 'common.tag.create',
|
||||
noTag: 'common.tag.noTag',
|
||||
noTag: /common\.tag\.noTag/,
|
||||
manageTags: 'common.tag.manageTags',
|
||||
}
|
||||
|
||||
@ -78,7 +78,7 @@ const PanelHarness = ({
|
||||
itemToStringLabel={tagToString}
|
||||
isItemEqualToValue={isSameTag}
|
||||
>
|
||||
<TagPanel
|
||||
<TagSearchContent
|
||||
type={type}
|
||||
inputValue={inputValue}
|
||||
onInputValueChange={setInputValue}
|
||||
@ -88,7 +88,7 @@ const PanelHarness = ({
|
||||
)
|
||||
}
|
||||
|
||||
describe('TagPanel', () => {
|
||||
describe('TagSearchContent', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import Tag01Icon from '@/app/components/base/icons/src/vender/line/financeAndECommerce/Tag01'
|
||||
import XCircleIcon from '@/app/components/base/icons/src/vender/solid/general/XCircle'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { TagPanel } from './tag-panel'
|
||||
import { TagSearchContent } from './tag-search-content'
|
||||
|
||||
const tagFilterComboboxFilter: NonNullable<ComboboxRootProps<Tag, true>['filter']> = (tag, query) => tag.name.includes(query)
|
||||
const tagToString = (tag: Tag) => tag.name
|
||||
@ -114,7 +114,7 @@ export const TagFilter = ({
|
||||
sideOffset={4}
|
||||
popupClassName="w-[240px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-0 shadow-lg backdrop-blur-[5px]"
|
||||
>
|
||||
<TagPanel
|
||||
<TagSearchContent
|
||||
type={type}
|
||||
inputValue={inputValue}
|
||||
onInputValueChange={setInputValue}
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import type { TagComboboxItem } from './tag-combobox-item'
|
||||
import type { TagType } from '@/contract/console/tags'
|
||||
import { ComboboxInput, ComboboxInputGroup, ComboboxItem, ComboboxItemIndicator, ComboboxItemText, ComboboxList, ComboboxSeparator, useComboboxFilteredItems } from '@langgenius/dify-ui/combobox'
|
||||
import { ComboboxEmpty, ComboboxInput, ComboboxInputGroup, ComboboxItem, ComboboxItemIndicator, ComboboxItemText, ComboboxList, ComboboxSeparator, useComboboxFilteredItems } from '@langgenius/dify-ui/combobox'
|
||||
import { Fragment } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { isCreateTagOption } from './tag-combobox-item'
|
||||
|
||||
type TagPanelProps = {
|
||||
type TagSearchContentProps = {
|
||||
type: TagType
|
||||
inputValue: string
|
||||
onInputValueChange: (value: string) => void
|
||||
@ -13,17 +13,16 @@ type TagPanelProps = {
|
||||
onClose?: () => void
|
||||
}
|
||||
|
||||
export const TagPanel = ({
|
||||
export const TagSearchContent = ({
|
||||
type,
|
||||
inputValue,
|
||||
onInputValueChange,
|
||||
onOpenTagManagement,
|
||||
onClose,
|
||||
}: TagPanelProps) => {
|
||||
}: TagSearchContentProps) => {
|
||||
const { t } = useTranslation()
|
||||
const filteredItems = useComboboxFilteredItems<TagComboboxItem>()
|
||||
const realItemCount = filteredItems.filter(tag => !isCreateTagOption(tag)).length
|
||||
const hasCreateOption = filteredItems.some(isCreateTagOption)
|
||||
const placeholder = t('tag.selectorPlaceholder', { ns: 'common' }) || ''
|
||||
|
||||
return (
|
||||
@ -50,45 +49,41 @@ export const TagPanel = ({
|
||||
)}
|
||||
</ComboboxInputGroup>
|
||||
</div>
|
||||
{filteredItems.length > 0 && (
|
||||
<ComboboxList className="max-h-58">
|
||||
{(tag: TagComboboxItem) => {
|
||||
if (isCreateTagOption(tag)) {
|
||||
return (
|
||||
<Fragment key={tag.id}>
|
||||
<ComboboxItem
|
||||
value={tag}
|
||||
>
|
||||
<ComboboxItemText className="flex items-center gap-x-1 px-0">
|
||||
<span aria-hidden="true" className="i-ri-add-line size-4 shrink-0 text-text-tertiary" />
|
||||
<span className="min-w-0 grow truncate px-1 system-md-regular text-text-secondary">
|
||||
{`${t('tag.create', { ns: 'common' })} `}
|
||||
<span className="system-md-medium">{`'${tag.name}'`}</span>
|
||||
</span>
|
||||
</ComboboxItemText>
|
||||
</ComboboxItem>
|
||||
{realItemCount > 0 && <ComboboxSeparator />}
|
||||
</Fragment>
|
||||
)
|
||||
}
|
||||
|
||||
<ComboboxList className="max-h-58">
|
||||
{(tag: TagComboboxItem) => {
|
||||
if (isCreateTagOption(tag)) {
|
||||
return (
|
||||
<ComboboxItem key={tag.id} value={tag}>
|
||||
<ComboboxItemText title={tag.name}>{tag.name}</ComboboxItemText>
|
||||
<ComboboxItemIndicator />
|
||||
</ComboboxItem>
|
||||
<Fragment key={tag.id}>
|
||||
<ComboboxItem
|
||||
value={tag}
|
||||
>
|
||||
<ComboboxItemText className="flex items-center gap-x-1 px-0">
|
||||
<span aria-hidden="true" className="i-ri-add-line size-4 shrink-0 text-text-tertiary" />
|
||||
<span className="min-w-0 grow truncate px-1 system-md-regular text-text-secondary">
|
||||
{`${t('tag.create', { ns: 'common' })} `}
|
||||
<span className="system-md-medium">{`'${tag.name}'`}</span>
|
||||
</span>
|
||||
</ComboboxItemText>
|
||||
</ComboboxItem>
|
||||
{realItemCount > 0 && <ComboboxSeparator />}
|
||||
</Fragment>
|
||||
)
|
||||
}}
|
||||
</ComboboxList>
|
||||
)}
|
||||
{!hasCreateOption && realItemCount === 0 && (
|
||||
<div className="p-1">
|
||||
<div className="flex flex-col items-center gap-y-1 p-3">
|
||||
<span aria-hidden="true" className="i-ri-price-tag-3-line size-6 text-text-quaternary" />
|
||||
<div className="system-xs-regular text-text-tertiary">{t('tag.noTag', { ns: 'common' })}</div>
|
||||
</div>
|
||||
}
|
||||
|
||||
return (
|
||||
<ComboboxItem key={tag.id} value={tag}>
|
||||
<ComboboxItemText title={tag.name}>{tag.name}</ComboboxItemText>
|
||||
<ComboboxItemIndicator />
|
||||
</ComboboxItem>
|
||||
)
|
||||
}}
|
||||
</ComboboxList>
|
||||
<ComboboxEmpty className="p-1">
|
||||
<div className="flex flex-col items-center gap-y-1 p-3">
|
||||
<span aria-hidden="true" className="i-ri-price-tag-3-line size-6 text-text-quaternary" />
|
||||
<div className="system-xs-regular text-text-tertiary">{t('tag.noTag', { ns: 'common' })}</div>
|
||||
</div>
|
||||
)}
|
||||
</ComboboxEmpty>
|
||||
<ComboboxSeparator />
|
||||
<div className="p-1">
|
||||
<button
|
||||
@ -11,7 +11,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { useApplyTagBindingsMutation } from '../hooks/use-tag-mutations'
|
||||
import { isCreateTagOption } from './tag-combobox-item'
|
||||
import { TagPanel } from './tag-panel'
|
||||
import { TagSearchContent } from './tag-search-content'
|
||||
import { TagTrigger } from './tag-trigger'
|
||||
|
||||
const TAG_COMBOBOX_FILTER: NonNullable<ComboboxRootProps<TagComboboxItem, true>['filter']> = (tag, query) => tag.name.includes(query)
|
||||
@ -229,7 +229,7 @@ export const TagSelector = ({
|
||||
popupProps={popupProps}
|
||||
popupClassName={cn('w-(--anchor-width) min-w-60 rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-0 shadow-lg backdrop-blur-[5px]', popupClassName)}
|
||||
>
|
||||
<TagPanel
|
||||
<TagSearchContent
|
||||
type={type}
|
||||
inputValue={inputValue}
|
||||
onInputValueChange={setInputValue}
|
||||
|
||||
@ -1,10 +1,19 @@
|
||||
import type { CurrentPlanInfoBackend, SubscriptionUrlsBackend } from '@/app/components/billing/type'
|
||||
import { get } from './base'
|
||||
|
||||
export type CurrentPlanVectorSpaceBackend = {
|
||||
size: number
|
||||
limit: number
|
||||
}
|
||||
|
||||
export const fetchCurrentPlanInfo = () => {
|
||||
return get<CurrentPlanInfoBackend>('/features')
|
||||
}
|
||||
|
||||
export const fetchCurrentPlanVectorSpace = () => {
|
||||
return get<CurrentPlanVectorSpaceBackend>('/features/vector-space')
|
||||
}
|
||||
|
||||
export const fetchSubscriptionUrls = (plan: string, interval: string) => {
|
||||
return get<SubscriptionUrlsBackend>(`/billing/subscription?plan=${plan}&interval=${interval}`)
|
||||
}
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
import { useMutation, useQuery } from '@tanstack/react-query'
|
||||
import { consoleClient, consoleQuery } from '@/service/client'
|
||||
import { fetchCurrentPlanVectorSpace } from './billing'
|
||||
|
||||
const currentPlanVectorSpaceQueryKey = ['billing', 'current-plan-vector-space'] as const
|
||||
|
||||
export const useBindPartnerStackInfo = () => {
|
||||
return useMutation({
|
||||
@ -21,3 +24,11 @@ export const useBillingUrl = (enabled: boolean) => {
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
export const useCurrentPlanVectorSpace = (enabled = true) => {
|
||||
return useQuery({
|
||||
queryKey: currentPlanVectorSpaceQueryKey,
|
||||
queryFn: () => fetchCurrentPlanVectorSpace(),
|
||||
enabled,
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user