Compare commits

..

1 Commits

Author SHA1 Message Date
521545d52e fix: migrate model type enum construction 2026-05-23 18:56:48 +08:00
12 changed files with 66 additions and 104 deletions

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, marshal
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
import services
@ -54,13 +54,12 @@ class CreateRagPipelineDatasetApi(Resource):
yaml_content=payload.yaml_content,
)
try:
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine).begin() as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
session.commit()
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_tenant_id,

View File

@ -1,7 +1,7 @@
from flask import request
from flask_restx import Resource, fields, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
@ -67,12 +67,10 @@ class RagPipelineImportApi(Resource):
current_user, _ = current_account_with_tenant()
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Use a plain Session so that caught exceptions inside the service
# (which return FAILED status instead of re-raising) do not leave the
# transaction in a closed state that a .begin() context manager cannot
# handle. See app_import.py for the canonical pattern.
with Session(db.engine, expire_on_commit=False) as session:
# Create service with session
with sessionmaker(db.engine).begin() as session:
import_service = RagPipelineDslService(session)
# Import app
account = current_user
result = import_service.import_rag_pipeline(
account=account,
@ -82,10 +80,6 @@ class RagPipelineImportApi(Resource):
pipeline_id=payload.pipeline_id,
dataset_name=payload.name,
)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
# Return appropriate status code based on result
status = result.status
@ -108,14 +102,12 @@ class RagPipelineImportConfirmApi(Resource):
def post(self, import_id):
current_user, _ = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
# Create service with session
with sessionmaker(db.engine).begin() as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
@ -132,7 +124,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine).begin() as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
@ -150,7 +142,7 @@ class RagPipelineExportApi(Resource):
# Add include_secret params
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine).begin() as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(
pipeline=pipeline, include_secret=query.include_secret == "true"

View File

@ -795,7 +795,7 @@ class ProviderManager:
return [
{
"model": model_key[0],
"model_type": ModelType.value_of(model_key[1]),
"model_type": ModelType(model_key[1]),
"available_model_credentials": [
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
for cred in creds

View File

@ -1,6 +1,6 @@
from typing import Any, Union
from typing import Union
from pydantic import BaseModel, field_validator
from pydantic import BaseModel
from core.rag.entities import RerankingModelConfig, WeightedScoreConfig
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
@ -101,14 +101,3 @@ class KnowledgeIndexNodeData(BaseNodeData):
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None
summary_index_setting: SummaryIndexSettingDict | None = None
@field_validator("summary_index_setting", mode="before")
@classmethod
def normalize_summary_index_setting(cls, v: Any) -> Any:
"""Treat dicts with enable=None (or missing enable) as None (#36233)."""
if v is None:
return None
if isinstance(v, dict):
if v.get("enable") is None:
return None
return v

View File

@ -66,7 +66,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Enable model load balancing
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType(model_type))
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
@ -87,7 +87,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# disable model load balancing
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType(model_type))
def get_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
@ -109,7 +109,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type_enum = ModelType.value_of(model_type)
model_type_enum = ModelType(model_type)
# Get provider model setting
provider_model_setting = provider_configuration.get_provider_model_setting(
@ -250,7 +250,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type_enum = ModelType.value_of(model_type)
model_type_enum = ModelType(model_type)
# Get load balancing configurations
load_balancing_model_config = db.session.scalar(
@ -338,7 +338,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type_enum = ModelType.value_of(model_type)
model_type_enum = ModelType(model_type)
if not isinstance(configs, list):
raise ValueError("Invalid load balancing configs")
@ -524,7 +524,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type_enum = ModelType.value_of(model_type)
model_type_enum = ModelType(model_type)
load_balancing_model_config = None
if config_id:

View File

@ -67,7 +67,7 @@ class ModelProviderService:
provider_responses = []
for provider_configuration in provider_configurations.values():
if model_type:
model_type_entity = ModelType.value_of(model_type)
model_type_entity = ModelType(model_type)
if model_type_entity not in provider_configuration.provider.supported_model_types:
continue
@ -269,7 +269,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_custom_model_credential(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
model_type=ModelType(model_type), model=model, credential_id=credential_id
)
def validate_model_credentials(
@ -287,7 +287,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.validate_custom_model_credentials(
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
model_type=ModelType(model_type), model=model, credentials=credentials
)
def create_model_credential(
@ -312,7 +312,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.create_custom_model_credential(
model_type=ModelType.value_of(model_type),
model_type=ModelType(model_type),
model=model,
credentials=credentials,
credential_name=credential_name,
@ -342,7 +342,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.update_custom_model_credential(
model_type=ModelType.value_of(model_type),
model_type=ModelType(model_type),
model=model,
credentials=credentials,
credential_id=credential_id,
@ -362,7 +362,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.delete_custom_model_credential(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
model_type=ModelType(model_type), model=model, credential_id=credential_id
)
def switch_active_custom_model_credential(
@ -380,7 +380,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.switch_custom_model_credential(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
model_type=ModelType(model_type), model=model, credential_id=credential_id
)
def add_model_credential_to_model_list(
@ -398,7 +398,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.add_model_credential_to_model(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
model_type=ModelType(model_type), model=model, credential_id=credential_id
)
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str):
@ -412,7 +412,7 @@ class ModelProviderService:
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model)
provider_configuration.delete_custom_model(model_type=ModelType(model_type), model=model)
def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
"""
@ -426,7 +426,7 @@ class ModelProviderService:
provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id)
# Get provider available models
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True)
models = provider_configurations.get_models(model_type=ModelType(model_type), only_active=True)
# Group models by provider
provider_models: dict[str, list[ModelWithProviderEntity]] = {}
@ -505,7 +505,7 @@ class ModelProviderService:
:param model_type: model type
:return:
"""
model_type_enum = ModelType.value_of(model_type)
model_type_enum = ModelType(model_type)
try:
result = self._get_provider_manager(tenant_id).get_default_model(
@ -540,7 +540,7 @@ class ModelProviderService:
:param model: model name
:return:
"""
model_type_enum = ModelType.value_of(model_type)
model_type_enum = ModelType(model_type)
self._get_provider_manager(tenant_id).update_default_model_record(
tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model
)
@ -590,7 +590,7 @@ class ModelProviderService:
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
provider_configuration.enable_model(model=model, model_type=ModelType(model_type))
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
@ -603,4 +603,4 @@ class ModelProviderService:
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))
provider_configuration.disable_model(model=model, model_type=ModelType(model_type))

View File

@ -78,9 +78,9 @@ class CheckDependenciesPendingData(BaseModel):
class RagPipelineDslService:
"""Import, export, and inspect RAG pipeline DSL using the caller-owned session.
Callers pass a plain ``Session`` (not wrapped in ``.begin()``) and are responsible for calling
``session.commit()`` on success or ``session.rollback()`` on failure. Methods here only flush
when generated IDs are needed mid-operation; they never commit or rollback.
Controllers wrap this service in a SQLAlchemy transaction context, so methods must only flush interim changes when
generated IDs are needed. Committing inside the service would close the caller's transaction and break later work in
the same context manager.
"""
def __init__(self, session: Session):

View File

@ -5,7 +5,6 @@ import { useSuspenseQuery } from '@tanstack/react-query'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import DifyLogo from '@/app/components/base/logo/dify-logo'
import Link from '@/next/link'
import { useRouter } from '@/next/navigation'
import { systemFeaturesQueryOptions } from '@/service/system-features'
import Avatar from './avatar'
@ -18,26 +17,21 @@ const Header = () => {
const goToStudio = useCallback(() => {
router.push('/apps')
}, [router])
const logoLabel = systemFeatures.branding.enabled && systemFeatures.branding.application_title ? systemFeatures.branding.application_title : 'Dify'
return (
<div className="flex flex-1 items-center justify-between px-4">
<div className="flex items-center gap-3">
<Link
href="/apps"
className="flex items-center rounded-sm hover:opacity-80 focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:outline-hidden"
aria-label={logoLabel}
>
<div className="flex cursor-pointer items-center" onClick={goToStudio}>
{systemFeatures.branding.enabled && systemFeatures.branding.login_page_logo
? (
<img
src={systemFeatures.branding.login_page_logo}
className="block h-[22px] w-auto object-contain"
alt=""
alt="Dify logo"
/>
)
: <DifyLogo alt="" />}
</Link>
: <DifyLogo />}
</div>
<div className="h-4 w-px origin-center rotate-[11.31deg] bg-divider-regular" />
<p className="relative mt-[-2px] title-3xl-semi-bold text-text-primary">{t('account.account', { ns: 'common' })}</p>
</div>

View File

@ -58,12 +58,6 @@ describe('DifyLogo', () => {
const img = screen.getByRole('img', { name: /dify logo/i })
expect(img).toHaveClass('custom-test-class')
})
it('applies custom alt text', () => {
const { container } = render(<DifyLogo alt="" />)
const img = container.querySelector('img')
expect(img).toHaveAttribute('alt', '')
})
})
describe('Theme behavior', () => {

View File

@ -23,14 +23,12 @@ type DifyLogoProps = {
style?: LogoStyle
size?: LogoSize
className?: string
alt?: string
}
const DifyLogo: FC<DifyLogoProps> = ({
style = 'default',
size = 'medium',
className,
alt = 'Dify logo',
}) => {
const { theme } = useTheme()
const themedStyle = (theme === 'dark' && style === 'default') ? 'monochromeWhite' : style
@ -39,7 +37,7 @@ const DifyLogo: FC<DifyLogoProps> = ({
<img
src={`${basePath}${logoPathMap[themedStyle]}`}
className={cn('block object-contain', logoSizeMap[size], className)}
alt={alt}
alt="Dify logo"
/>
)
}

View File

@ -1,4 +1,4 @@
import type { AnchorHTMLAttributes, ReactElement } from 'react'
import type { ReactElement } from 'react'
import { fireEvent, screen } from '@testing-library/react'
import { vi } from 'vitest'
import { renderWithSystemFeatures } from '@/__tests__/utils/mock-system-features'
@ -55,7 +55,7 @@ vi.mock('@/context/workspace-context-provider', () => ({
}))
vi.mock('@/next/link', () => ({
default: ({ children, href, ...props }: AnchorHTMLAttributes<HTMLAnchorElement> & { href?: string }) => <a href={href} {...props}>{children}</a>,
default: ({ children, href }: { children?: React.ReactNode, href?: string }) => <a href={href}>{children}</a>,
}))
let mockIsWorkspaceEditor = false
@ -122,9 +122,7 @@ describe('Header', () => {
it('should render header with main nav components', () => {
renderHeader()
expect(screen.getByRole('link', { name: 'Dify' })).toHaveAttribute('href', '/apps')
expect(screen.queryByRole('heading', { level: 1 })).not.toBeInTheDocument()
expect(screen.queryByRole('img', { name: /dify logo/i })).not.toBeInTheDocument()
expect(screen.getByRole('img', { name: /dify logo/i })).toBeInTheDocument()
expect(screen.getByTestId('workplace-selector')).toBeInTheDocument()
expect(screen.getByTestId('app-nav')).toBeInTheDocument()
expect(screen.getByTestId('account-dropdown')).toBeInTheDocument()
@ -168,7 +166,7 @@ describe('Header', () => {
mockMedia = 'mobile'
renderHeader()
expect(screen.getByRole('link', { name: 'Dify' })).toHaveAttribute('href', '/apps')
expect(screen.getByRole('img', { name: /dify logo/i })).toBeInTheDocument()
expect(screen.queryByTestId('env-nav')).not.toBeInTheDocument()
})
@ -179,8 +177,8 @@ describe('Header', () => {
renderHeader()
expect(screen.getByRole('link', { name: 'Acme Workspace' })).toHaveAttribute('href', '/apps')
expect(screen.queryByRole('img', { name: /logo/i })).not.toBeInTheDocument()
expect(screen.getByText('Acme Workspace')).toBeInTheDocument()
expect(screen.getByRole('img', { name: /logo/i })).toBeInTheDocument()
expect(screen.queryByRole('img', { name: /dify logo/i })).not.toBeInTheDocument()
})
@ -191,18 +189,18 @@ describe('Header', () => {
renderHeader()
expect(screen.getByRole('link', { name: 'Custom Title' })).toHaveAttribute('href', '/apps')
expect(screen.queryByRole('img', { name: /dify logo/i })).not.toBeInTheDocument()
expect(screen.getByText('Custom Title')).toBeInTheDocument()
expect(screen.getByRole('img', { name: /dify logo/i })).toBeInTheDocument()
})
it('should use default Dify link label when branding enabled but no application_title', () => {
it('should show default Dify text when branding enabled but no application_title', () => {
mockBrandingEnabled = true
mockBrandingTitle = null
mockBrandingLogo = null
renderHeader()
expect(screen.getByRole('link', { name: 'Dify' })).toHaveAttribute('href', '/apps')
expect(screen.getByText('Dify')).toBeInTheDocument()
})
it('should show dataset nav for editor who is not dataset operator', () => {

View File

@ -44,23 +44,21 @@ const Header = () => {
setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.BILLING })
}, [isFreePlan, setShowAccountSettingModal, setShowPricingModal])
const logoLabel = isBrandingEnabled && systemFeatures.branding.application_title ? systemFeatures.branding.application_title : 'Dify'
const renderLogo = () => (
<Link
href="/apps"
className="flex h-8 shrink-0 items-center justify-center overflow-hidden rounded-sm px-0.5 hover:opacity-80 focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:outline-hidden"
aria-label={logoLabel}
>
{systemFeatures.branding.enabled && systemFeatures.branding.workspace_logo
? (
<img
src={systemFeatures.branding.workspace_logo}
className="block h-[22px] w-auto object-contain"
alt=""
/>
)
: <DifyLogo alt="" />}
</Link>
<h1>
<Link href="/apps" className="flex h-8 shrink-0 items-center justify-center overflow-hidden px-0.5 indent-[-9999px] whitespace-nowrap">
{isBrandingEnabled && systemFeatures.branding.application_title ? systemFeatures.branding.application_title : 'Dify'}
{systemFeatures.branding.enabled && systemFeatures.branding.workspace_logo
? (
<img
src={systemFeatures.branding.workspace_logo}
className="block h-[22px] w-auto object-contain"
alt="logo"
/>
)
: <DifyLogo />}
</Link>
</h1>
)
if (isMobile) {