diff --git a/api/services/legacy_model_type_migration.py b/api/services/legacy_model_type_migration.py index 8a0b98a270..2de5e7f7f3 100644 --- a/api/services/legacy_model_type_migration.py +++ b/api/services/legacy_model_type_migration.py @@ -11,6 +11,14 @@ rewriting credential references in provider models and load-balancing configs be removing loser credential rows. `load_balancing_model_configs` stays mostly row-level, but it first deduplicates `name="__inherit__"` rows by business key before it canonicalizes the remaining legacy rows independently with row-level cache cleanup. + +Tenant scheduling has two modes. When callers provide an explicit tenant list, the +service preserves the original tenant-scoped execution model and runs all selected tables +for each tenant. When callers omit `tenant_ids`, the service discovers tenant +ids per table and then runs only that table for the discovered tenants. Most +tables keep the active `model_types` filter in the discovery query, while +`load_balancing_model_configs` deliberately uses a whole-table tenant scan so +that query stays easy to understand. """ from __future__ import annotations @@ -21,7 +29,7 @@ import sys import threading import traceback import uuid -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Iterable, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict, dataclass from datetime import datetime @@ -36,7 +44,7 @@ from sqlalchemy.sql import select from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from graphon.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now -from models import LoadBalancingModelConfig, ProviderModel, ProviderModelSetting, Tenant, TenantDefaultModel +from models import LoadBalancingModelConfig, ProviderModel, ProviderModelSetting, TenantDefaultModel from models.base import TypeBase from models.provider import ProviderModelCredential @@ -291,6 +299,21 @@ _LOCK_TIMEOUT_FALLBACK_MESSAGES: tuple[str, ...] = ( _RAW_MODEL_TYPE_COLUMN = "_raw_model_type" +def _selected_legacy_values(model_types: Sequence[ModelType]) -> list[str]: + legacy_values: list[str] = [] + for model_type in model_types: + legacy_values.extend(_CANONICAL_TO_LEGACY[model_type]) + return legacy_values + + +def _selected_model_type_values(model_types: Sequence[ModelType]) -> list[str]: + model_type_values: list[str] = [] + for model_type in model_types: + model_type_values.append(model_type.value) + model_type_values.extend(_CANONICAL_TO_LEGACY[model_type]) + return list(dict.fromkeys(model_type_values)) + + def _session_factory(engine: sa.Engine) -> Session: return Session(bind=engine, expire_on_commit=False) @@ -363,6 +386,12 @@ class LegacyModelTypeMigrationService: `provider_model_credentials` is selected, that migration also rewrites references in `provider_models` and `load_balancing_model_configs`. Tenant migrations can run in a thread pool; JSONL output remains line-safe through a shared synchronized writer. + + If `tenant_ids` is omitted, tenant discovery becomes table-scoped: each selected ORM + model loads its own tenant ids, then only that table is dispatched for those tenants. + Most tables keep the active model-type filter in discovery, while + `load_balancing_model_configs` intentionally uses the whole table so the tenant query + stays simple. This still avoids merging tenant ids across unrelated tables. """ _engine: sa.Engine @@ -426,22 +455,51 @@ class LegacyModelTypeMigrationService: return tuple(ordered_models) def migrate(self) -> None: - tenant_ids = tuple(self._iter_tenant_ids()) + output = _ThreadSafeLineWriter(self._output) + if self._tenant_ids is not None: + self._migrate_explicit_tenants(output) + return + + self._migrate_tables_with_discovered_tenants(output) + + def _migrate_explicit_tenants(self, output: io.TextIOBase) -> None: + tenant_ids = self._tenant_ids if not tenant_ids: return - output = _ThreadSafeLineWriter(self._output) + self._run_migrations_for_tenants(tenant_ids, self._orm_models, output) + + def _migrate_tables_with_discovered_tenants(self, output: io.TextIOBase) -> None: + for orm_model in self._orm_models: + tenant_ids = self._load_tenant_ids_for_model(orm_model) + if not tenant_ids: + continue + self._run_migrations_for_tenants(tenant_ids, (orm_model,), output) + + def _run_migrations_for_tenants( + self, + tenant_ids: Sequence[str], + orm_models: Sequence[ORMModel], + output: io.TextIOBase, + ) -> None: if self._concurrency == 1 or len(tenant_ids) == 1: for tenant_id in tenant_ids: - self._run_tenant_migration(tenant_id, output) + self._run_tenant_migration(tenant_id, orm_models, output) return with ThreadPoolExecutor(max_workers=min(self._concurrency, len(tenant_ids))) as executor: - futures = [executor.submit(self._run_tenant_migration, tenant_id, output) for tenant_id in tenant_ids] + futures = [ + executor.submit(self._run_tenant_migration, tenant_id, orm_models, output) for tenant_id in tenant_ids + ] for future in as_completed(futures): future.result() - def _run_tenant_migration(self, tenant_id: str, output: io.TextIOBase) -> None: + def _run_tenant_migration( + self, + tenant_id: str, + orm_models: Sequence[ORMModel], + output: io.TextIOBase, + ) -> None: """ Execute one tenant migration with the shared, line-synchronized output stream. """ @@ -452,18 +510,88 @@ class LegacyModelTypeMigrationService: apply=self._apply, output=output, model_types=self._model_types, - orm_models=self._orm_models, + orm_models=orm_models, ).run() - def _iter_tenant_ids(self) -> Iterator[str]: - if self._tenant_ids is not None: - yield from self._tenant_ids - return + def _load_tenant_ids_for_model(self, orm_model: ORMModel) -> tuple[str, ...]: + """ + Discover only the tenants that have candidate rows for the current table. + In automatic tenant mode we keep discovery table-scoped so large shared tenant + populations do not force empty work for unrelated tables. Most table queries + still apply the active `model_types` filter before scheduling migrations, while + `load_balancing_model_configs` intentionally trades a wider tenant set for a + simpler discovery query. + """ + + legacy_model_type_values = _selected_legacy_values(self._model_types) with _session_factory(self._engine) as session: - tenant_ids = session.execute(select(Tenant.id).order_by(Tenant.id.asc())).scalars().all() + if orm_model is ProviderModel: + tenant_ids = ( + session.execute( + select(ProviderModel.tenant_id) + .where(sa.type_coerce(ProviderModel.model_type, sa.String()).in_(legacy_model_type_values)) + .distinct() + .order_by(ProviderModel.tenant_id.asc()) + ) + .scalars() + .all() + ) + elif orm_model is TenantDefaultModel: + tenant_ids = ( + session.execute( + select(TenantDefaultModel.tenant_id) + .where(sa.type_coerce(TenantDefaultModel.model_type, sa.String()).in_(legacy_model_type_values)) + .distinct() + .order_by(TenantDefaultModel.tenant_id.asc()) + ) + .scalars() + .all() + ) + elif orm_model is ProviderModelSetting: + tenant_ids = ( + session.execute( + select(ProviderModelSetting.tenant_id) + .where( + sa.type_coerce(ProviderModelSetting.model_type, sa.String()).in_(legacy_model_type_values) + ) + .distinct() + .order_by(ProviderModelSetting.tenant_id.asc()) + ) + .scalars() + .all() + ) + elif orm_model is LoadBalancingModelConfig: + # Deliberately discover tenants from the whole table so the query stays + # easier to understand than the legacy/canonical mixed-row filter. + tenant_ids = ( + session.execute( + select(LoadBalancingModelConfig.tenant_id) + .distinct() + .order_by(LoadBalancingModelConfig.tenant_id.asc()) + ) + .scalars() + .all() + ) + elif orm_model is ProviderModelCredential: + tenant_ids = ( + session.execute( + select(ProviderModelCredential.tenant_id) + .where( + sa.type_coerce(ProviderModelCredential.model_type, sa.String()).in_( + legacy_model_type_values + ) + ) + .distinct() + .order_by(ProviderModelCredential.tenant_id.asc()) + ) + .scalars() + .all() + ) + else: + raise ValueError(f"unsupported orm model: {orm_model}") - yield from tenant_ids + return tuple(tenant_ids) class Migration: @@ -532,17 +660,10 @@ class Migration: ) def _selected_legacy_values(self) -> list[str]: - legacy_values: list[str] = [] - for model_type in self._model_types: - legacy_values.extend(_CANONICAL_TO_LEGACY[model_type]) - return legacy_values + return _selected_legacy_values(self._model_types) def _selected_model_type_values(self) -> list[str]: - model_type_values: list[str] = [] - for model_type in self._model_types: - model_type_values.append(model_type.value) - model_type_values.extend(_CANONICAL_TO_LEGACY[model_type]) - return list(dict.fromkeys(model_type_values)) + return _selected_model_type_values(self._model_types) def _allowed_values_for_canonical_model_type(self, canonical_model_type: ModelType) -> tuple[str, ...]: return (*_CANONICAL_TO_LEGACY[canonical_model_type], canonical_model_type.value) diff --git a/api/tests/unit_tests/commands/test_legacy_model_type_migration.py b/api/tests/unit_tests/commands/test_legacy_model_type_migration.py index 80b3e96b87..7eead948c1 100644 --- a/api/tests/unit_tests/commands/test_legacy_model_type_migration.py +++ b/api/tests/unit_tests/commands/test_legacy_model_type_migration.py @@ -17,6 +17,7 @@ from click.testing import CliRunner from sqlalchemy.exc import OperationalError from graphon.model_runtime.entities.model_entities import ModelType +from models.account import Tenant from models.enums import CredentialSourceType from models.provider import ProviderModel from tests.helpers.legacy_model_type_migration import ( @@ -24,6 +25,7 @@ from tests.helpers.legacy_model_type_migration import ( LEGACY_TO_CANONICAL, assert_tenant_rows_use_only_canonical_model_types, count_rows, + create_minimal_legacy_model_type_schema, fetch_table_rows, seed_legacy_model_type_dirty_data, snapshot_legacy_model_type_state, @@ -196,6 +198,18 @@ def _insert_provider_model( ) +def _insert_tenant(engine: sa.Engine, *, tenant_id: str) -> None: + with engine.begin() as conn: + conn.execute( + Tenant.__table__.insert().values( + id=tenant_id, + name=f"Tenant {tenant_id}", + plan="basic", + status="normal", + ) + ) + + def _insert_tenant_default_model( engine: sa.Engine, *, @@ -509,7 +523,7 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve migration_module, sqlite_engine: sa.Engine, ) -> None: - seen_runs: list[dict[str, object]] = [] + seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = [] class FakeMigration: def __init__( @@ -522,18 +536,12 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve model_types: tuple[ModelType, ...], orm_models: tuple[type[object], ...], ) -> None: - seen_runs.append( - { - "tenant_id": tenant_id, - "engine": engine, - "apply": apply, - "model_types": model_types, - "table_names": tuple(model.__table__.name for model in orm_models), - } - ) + assert engine is sqlite_engine + assert apply is False + seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types)) def run(self) -> None: - seen_runs.append({"run": True}) + return None monkeypatch = pytest.MonkeyPatch() try: @@ -542,7 +550,7 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve engine=sqlite_engine, apply=False, concurrency=1, - tables=("provider_models",), + tables=("provider_models", "tenant_default_models"), model_types=(ModelType.LLM,), tenant_ids=("tenant-alpha", "tenant-beta"), ) @@ -551,11 +559,267 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve finally: monkeypatch.undo() - init_calls = [call for call in seen_runs if "tenant_id" in call] - assert [call["tenant_id"] for call in init_calls] == ["tenant-alpha", "tenant-beta"] - for call in init_calls: - assert tuple(cast(tuple[str, ...], call["table_names"])) == ("provider_models",) - assert call["model_types"] == (ModelType.LLM,) + assert seen_runs == [ + ("tenant-alpha", ("provider_models", "tenant_default_models"), (ModelType.LLM,)), + ("tenant-beta", ("provider_models", "tenant_default_models"), (ModelType.LLM,)), + ] + + +def test_service_migrate_without_tenant_ids_discovers_tenants_per_selected_table_without_querying_tenants( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + create_minimal_legacy_model_type_schema(sqlite_engine) + provider_tenant_id = "00000000-0000-0000-0000-000000000111" + default_tenant_id = "00000000-0000-0000-0000-000000000222" + empty_tenant_id = "00000000-0000-0000-0000-000000000333" + for tenant_id in (provider_tenant_id, default_tenant_id, empty_tenant_id): + _insert_tenant(sqlite_engine, tenant_id=tenant_id) + + created_at = datetime(2025, 1, 1, 12, 0, 0) + updated_at = created_at + timedelta(minutes=1) + _insert_provider_model( + sqlite_engine, + row_id="10000000-0000-0000-0000-000000000111", + tenant_id=provider_tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + credential_id=None, + created_at=created_at, + updated_at=updated_at, + ) + _insert_tenant_default_model( + sqlite_engine, + row_id="20000000-0000-0000-0000-000000000222", + tenant_id=default_tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + created_at=created_at, + updated_at=updated_at, + ) + + seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = [] + executed_sql: list[str] = [] + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + assert engine is sqlite_engine + assert apply is False + seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types)) + + def run(self) -> None: + return None + + def _record_sql( + conn: sa.engine.Connection, + cursor: object, + statement: str, + parameters: object, + context: object, + executemany: bool, + ) -> None: + del conn, cursor, parameters, context, executemany + executed_sql.append(statement) + + sa.event.listen(sqlite_engine, "before_cursor_execute", _record_sql) + try: + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + tables=("provider_models", "tenant_default_models"), + model_types=(ModelType.LLM,), + ) + + service.migrate() + finally: + sa.event.remove(sqlite_engine, "before_cursor_execute", _record_sql) + + assert seen_runs == [ + (provider_tenant_id, ("provider_models",), (ModelType.LLM,)), + (default_tenant_id, ("tenant_default_models",), (ModelType.LLM,)), + ] + normalized_statements = [" ".join(statement.lower().split()) for statement in executed_sql] + discovery_statements = [statement for statement in normalized_statements if statement.startswith("select")] + table_names = ("provider_models", "tenant_default_models") + table_discovery_statements = [ + statement + for statement in discovery_statements + if any(f" from {table_name} " in f" {statement} " for table_name in table_names) + ] + + assert [statement for statement in discovery_statements if " from tenants " in f" {statement} "] == [] + assert [statement for statement in discovery_statements if " union " in f" {statement} "] == [] + assert [ + next(table_name for table_name in table_names if f" from {table_name} " in f" {statement} ") + for statement in table_discovery_statements + ] == list(table_names) + + +def test_service_migrate_without_tenant_ids_filters_provider_model_tenants_by_selected_model_types( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + create_minimal_legacy_model_type_schema(sqlite_engine) + llm_tenant_id = "00000000-0000-0000-0000-000000000411" + embedding_tenant_id = "00000000-0000-0000-0000-000000000422" + empty_tenant_id = "00000000-0000-0000-0000-000000000433" + for tenant_id in (llm_tenant_id, embedding_tenant_id, empty_tenant_id): + _insert_tenant(sqlite_engine, tenant_id=tenant_id) + + created_at = datetime(2025, 1, 2, 12, 0, 0) + updated_at = created_at + timedelta(minutes=1) + _insert_provider_model( + sqlite_engine, + row_id="30000000-0000-0000-0000-000000000411", + tenant_id=llm_tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type="text-generation", + credential_id=None, + created_at=created_at, + updated_at=updated_at, + ) + _insert_provider_model( + sqlite_engine, + row_id="30000000-0000-0000-0000-000000000422", + tenant_id=embedding_tenant_id, + provider_name="openai", + model_name="text-embedding-3-large", + model_type="embeddings", + credential_id=None, + created_at=created_at, + updated_at=updated_at, + ) + + seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = [] + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + assert engine is sqlite_engine + assert apply is False + seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types)) + + def run(self) -> None: + return None + + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + tables=("provider_models",), + model_types=(ModelType.LLM,), + ) + + service.migrate() + + assert seen_runs == [ + (llm_tenant_id, ("provider_models",), (ModelType.LLM,)), + ] + + +def test_service_migrate_without_tenant_ids_discovers_all_load_balancing_tenants_for_simpler_table_scoped_query( + migration_module, + sqlite_engine: sa.Engine, + monkeypatch: pytest.MonkeyPatch, +) -> None: + create_minimal_legacy_model_type_schema(sqlite_engine) + inherit_llm_tenant_id = "00000000-0000-0000-0000-000000000511" + inherit_embedding_tenant_id = "00000000-0000-0000-0000-000000000522" + empty_tenant_id = "00000000-0000-0000-0000-000000000533" + for tenant_id in (inherit_llm_tenant_id, inherit_embedding_tenant_id, empty_tenant_id): + _insert_tenant(sqlite_engine, tenant_id=tenant_id) + + created_at = datetime(2025, 1, 3, 12, 0, 0) + updated_at = created_at + timedelta(minutes=1) + _insert_load_balancing_model_config( + sqlite_engine, + row_id="40000000-0000-0000-0000-000000000511", + tenant_id=inherit_llm_tenant_id, + provider_name="openai", + model_name="gpt-4o-mini", + model_type=ModelType.LLM.value, + name="__inherit__", + encrypted_config=json.dumps({"api_key": "inherit-llm"}), + credential_id="50000000-0000-0000-0000-000000000511", + enabled=True, + created_at=created_at, + updated_at=updated_at, + ) + _insert_load_balancing_model_config( + sqlite_engine, + row_id="40000000-0000-0000-0000-000000000522", + tenant_id=inherit_embedding_tenant_id, + provider_name="openai", + model_name="text-embedding-3-large", + model_type=ModelType.TEXT_EMBEDDING.value, + name="__inherit__", + encrypted_config=json.dumps({"api_key": "inherit-embedding"}), + credential_id="50000000-0000-0000-0000-000000000522", + enabled=True, + created_at=created_at, + updated_at=updated_at, + ) + + seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = [] + + class FakeMigration: + def __init__( + self, + *, + tenant_id: str, + engine: sa.Engine, + apply: bool, + output: io.TextIOBase, + model_types: tuple[ModelType, ...], + orm_models: tuple[type[object], ...], + ) -> None: + assert engine is sqlite_engine + assert apply is False + seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types)) + + def run(self) -> None: + return None + + monkeypatch.setattr(migration_module, "Migration", FakeMigration) + # Load-balancing tenant discovery is a deliberate exception: it scans the + # whole table so the discovery query stays easy to understand, even when + # the scheduled tenant set is wider than the selected model types. + service = migration_module.LegacyModelTypeMigrationService( + engine=sqlite_engine, + apply=False, + tables=("load_balancing_model_configs",), + model_types=(ModelType.LLM,), + ) + + service.migrate() + + assert seen_runs == [ + (inherit_llm_tenant_id, ("load_balancing_model_configs",), (ModelType.LLM,)), + (inherit_embedding_tenant_id, ("load_balancing_model_configs",), (ModelType.LLM,)), + ] def test_service_migrate_with_concurrency_greater_than_one_runs_tenants_in_parallel_without_changing_migration_scope(