feat(api): only load tenant required for each table

This commit is contained in:
QuantumGhost
2026-05-26 22:57:27 +08:00
parent 5da4160cb8
commit 1a9fdef2cf
2 changed files with 425 additions and 40 deletions

View File

@ -11,6 +11,14 @@ rewriting credential references in provider models and load-balancing configs be
removing loser credential rows. `load_balancing_model_configs` stays mostly row-level,
but it first deduplicates `name="__inherit__"` rows by business key before it
canonicalizes the remaining legacy rows independently with row-level cache cleanup.
Tenant scheduling has two modes. When callers provide an explicit tenant list, the
service preserves the original tenant-scoped execution model and runs all selected tables
for each tenant. When callers omit `tenant_ids`, the service discovers tenant
ids per table and then runs only that table for the discovered tenants. Most
tables keep the active `model_types` filter in the discovery query, while
`load_balancing_model_configs` deliberately uses a whole-table tenant scan so
that query stays easy to understand.
"""
from __future__ import annotations
@ -21,7 +29,7 @@ import sys
import threading
import traceback
import uuid
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Iterable, Sequence
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict, dataclass
from datetime import datetime
@ -36,7 +44,7 @@ from sqlalchemy.sql import select
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from graphon.model_runtime.entities.model_entities import ModelType
from libs.datetime_utils import naive_utc_now
from models import LoadBalancingModelConfig, ProviderModel, ProviderModelSetting, Tenant, TenantDefaultModel
from models import LoadBalancingModelConfig, ProviderModel, ProviderModelSetting, TenantDefaultModel
from models.base import TypeBase
from models.provider import ProviderModelCredential
@ -291,6 +299,21 @@ _LOCK_TIMEOUT_FALLBACK_MESSAGES: tuple[str, ...] = (
_RAW_MODEL_TYPE_COLUMN = "_raw_model_type"
def _selected_legacy_values(model_types: Sequence[ModelType]) -> list[str]:
legacy_values: list[str] = []
for model_type in model_types:
legacy_values.extend(_CANONICAL_TO_LEGACY[model_type])
return legacy_values
def _selected_model_type_values(model_types: Sequence[ModelType]) -> list[str]:
model_type_values: list[str] = []
for model_type in model_types:
model_type_values.append(model_type.value)
model_type_values.extend(_CANONICAL_TO_LEGACY[model_type])
return list(dict.fromkeys(model_type_values))
def _session_factory(engine: sa.Engine) -> Session:
return Session(bind=engine, expire_on_commit=False)
@ -363,6 +386,12 @@ class LegacyModelTypeMigrationService:
`provider_model_credentials` is selected, that migration also rewrites references in
`provider_models` and `load_balancing_model_configs`. Tenant migrations can run in a
thread pool; JSONL output remains line-safe through a shared synchronized writer.
If `tenant_ids` is omitted, tenant discovery becomes table-scoped: each selected ORM
model loads its own tenant ids, then only that table is dispatched for those tenants.
Most tables keep the active model-type filter in discovery, while
`load_balancing_model_configs` intentionally uses the whole table so the tenant query
stays simple. This still avoids merging tenant ids across unrelated tables.
"""
_engine: sa.Engine
@ -426,22 +455,51 @@ class LegacyModelTypeMigrationService:
return tuple(ordered_models)
def migrate(self) -> None:
tenant_ids = tuple(self._iter_tenant_ids())
output = _ThreadSafeLineWriter(self._output)
if self._tenant_ids is not None:
self._migrate_explicit_tenants(output)
return
self._migrate_tables_with_discovered_tenants(output)
def _migrate_explicit_tenants(self, output: io.TextIOBase) -> None:
tenant_ids = self._tenant_ids
if not tenant_ids:
return
output = _ThreadSafeLineWriter(self._output)
self._run_migrations_for_tenants(tenant_ids, self._orm_models, output)
def _migrate_tables_with_discovered_tenants(self, output: io.TextIOBase) -> None:
for orm_model in self._orm_models:
tenant_ids = self._load_tenant_ids_for_model(orm_model)
if not tenant_ids:
continue
self._run_migrations_for_tenants(tenant_ids, (orm_model,), output)
def _run_migrations_for_tenants(
self,
tenant_ids: Sequence[str],
orm_models: Sequence[ORMModel],
output: io.TextIOBase,
) -> None:
if self._concurrency == 1 or len(tenant_ids) == 1:
for tenant_id in tenant_ids:
self._run_tenant_migration(tenant_id, output)
self._run_tenant_migration(tenant_id, orm_models, output)
return
with ThreadPoolExecutor(max_workers=min(self._concurrency, len(tenant_ids))) as executor:
futures = [executor.submit(self._run_tenant_migration, tenant_id, output) for tenant_id in tenant_ids]
futures = [
executor.submit(self._run_tenant_migration, tenant_id, orm_models, output) for tenant_id in tenant_ids
]
for future in as_completed(futures):
future.result()
def _run_tenant_migration(self, tenant_id: str, output: io.TextIOBase) -> None:
def _run_tenant_migration(
self,
tenant_id: str,
orm_models: Sequence[ORMModel],
output: io.TextIOBase,
) -> None:
"""
Execute one tenant migration with the shared, line-synchronized output stream.
"""
@ -452,18 +510,88 @@ class LegacyModelTypeMigrationService:
apply=self._apply,
output=output,
model_types=self._model_types,
orm_models=self._orm_models,
orm_models=orm_models,
).run()
def _iter_tenant_ids(self) -> Iterator[str]:
if self._tenant_ids is not None:
yield from self._tenant_ids
return
def _load_tenant_ids_for_model(self, orm_model: ORMModel) -> tuple[str, ...]:
"""
Discover only the tenants that have candidate rows for the current table.
In automatic tenant mode we keep discovery table-scoped so large shared tenant
populations do not force empty work for unrelated tables. Most table queries
still apply the active `model_types` filter before scheduling migrations, while
`load_balancing_model_configs` intentionally trades a wider tenant set for a
simpler discovery query.
"""
legacy_model_type_values = _selected_legacy_values(self._model_types)
with _session_factory(self._engine) as session:
tenant_ids = session.execute(select(Tenant.id).order_by(Tenant.id.asc())).scalars().all()
if orm_model is ProviderModel:
tenant_ids = (
session.execute(
select(ProviderModel.tenant_id)
.where(sa.type_coerce(ProviderModel.model_type, sa.String()).in_(legacy_model_type_values))
.distinct()
.order_by(ProviderModel.tenant_id.asc())
)
.scalars()
.all()
)
elif orm_model is TenantDefaultModel:
tenant_ids = (
session.execute(
select(TenantDefaultModel.tenant_id)
.where(sa.type_coerce(TenantDefaultModel.model_type, sa.String()).in_(legacy_model_type_values))
.distinct()
.order_by(TenantDefaultModel.tenant_id.asc())
)
.scalars()
.all()
)
elif orm_model is ProviderModelSetting:
tenant_ids = (
session.execute(
select(ProviderModelSetting.tenant_id)
.where(
sa.type_coerce(ProviderModelSetting.model_type, sa.String()).in_(legacy_model_type_values)
)
.distinct()
.order_by(ProviderModelSetting.tenant_id.asc())
)
.scalars()
.all()
)
elif orm_model is LoadBalancingModelConfig:
# Deliberately discover tenants from the whole table so the query stays
# easier to understand than the legacy/canonical mixed-row filter.
tenant_ids = (
session.execute(
select(LoadBalancingModelConfig.tenant_id)
.distinct()
.order_by(LoadBalancingModelConfig.tenant_id.asc())
)
.scalars()
.all()
)
elif orm_model is ProviderModelCredential:
tenant_ids = (
session.execute(
select(ProviderModelCredential.tenant_id)
.where(
sa.type_coerce(ProviderModelCredential.model_type, sa.String()).in_(
legacy_model_type_values
)
)
.distinct()
.order_by(ProviderModelCredential.tenant_id.asc())
)
.scalars()
.all()
)
else:
raise ValueError(f"unsupported orm model: {orm_model}")
yield from tenant_ids
return tuple(tenant_ids)
class Migration:
@ -532,17 +660,10 @@ class Migration:
)
def _selected_legacy_values(self) -> list[str]:
legacy_values: list[str] = []
for model_type in self._model_types:
legacy_values.extend(_CANONICAL_TO_LEGACY[model_type])
return legacy_values
return _selected_legacy_values(self._model_types)
def _selected_model_type_values(self) -> list[str]:
model_type_values: list[str] = []
for model_type in self._model_types:
model_type_values.append(model_type.value)
model_type_values.extend(_CANONICAL_TO_LEGACY[model_type])
return list(dict.fromkeys(model_type_values))
return _selected_model_type_values(self._model_types)
def _allowed_values_for_canonical_model_type(self, canonical_model_type: ModelType) -> tuple[str, ...]:
return (*_CANONICAL_TO_LEGACY[canonical_model_type], canonical_model_type.value)

View File

@ -17,6 +17,7 @@ from click.testing import CliRunner
from sqlalchemy.exc import OperationalError
from graphon.model_runtime.entities.model_entities import ModelType
from models.account import Tenant
from models.enums import CredentialSourceType
from models.provider import ProviderModel
from tests.helpers.legacy_model_type_migration import (
@ -24,6 +25,7 @@ from tests.helpers.legacy_model_type_migration import (
LEGACY_TO_CANONICAL,
assert_tenant_rows_use_only_canonical_model_types,
count_rows,
create_minimal_legacy_model_type_schema,
fetch_table_rows,
seed_legacy_model_type_dirty_data,
snapshot_legacy_model_type_state,
@ -196,6 +198,18 @@ def _insert_provider_model(
)
def _insert_tenant(engine: sa.Engine, *, tenant_id: str) -> None:
with engine.begin() as conn:
conn.execute(
Tenant.__table__.insert().values(
id=tenant_id,
name=f"Tenant {tenant_id}",
plan="basic",
status="normal",
)
)
def _insert_tenant_default_model(
engine: sa.Engine,
*,
@ -509,7 +523,7 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve
migration_module,
sqlite_engine: sa.Engine,
) -> None:
seen_runs: list[dict[str, object]] = []
seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = []
class FakeMigration:
def __init__(
@ -522,18 +536,12 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve
model_types: tuple[ModelType, ...],
orm_models: tuple[type[object], ...],
) -> None:
seen_runs.append(
{
"tenant_id": tenant_id,
"engine": engine,
"apply": apply,
"model_types": model_types,
"table_names": tuple(model.__table__.name for model in orm_models),
}
)
assert engine is sqlite_engine
assert apply is False
seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types))
def run(self) -> None:
seen_runs.append({"run": True})
return None
monkeypatch = pytest.MonkeyPatch()
try:
@ -542,7 +550,7 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve
engine=sqlite_engine,
apply=False,
concurrency=1,
tables=("provider_models",),
tables=("provider_models", "tenant_default_models"),
model_types=(ModelType.LLM,),
tenant_ids=("tenant-alpha", "tenant-beta"),
)
@ -551,11 +559,267 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve
finally:
monkeypatch.undo()
init_calls = [call for call in seen_runs if "tenant_id" in call]
assert [call["tenant_id"] for call in init_calls] == ["tenant-alpha", "tenant-beta"]
for call in init_calls:
assert tuple(cast(tuple[str, ...], call["table_names"])) == ("provider_models",)
assert call["model_types"] == (ModelType.LLM,)
assert seen_runs == [
("tenant-alpha", ("provider_models", "tenant_default_models"), (ModelType.LLM,)),
("tenant-beta", ("provider_models", "tenant_default_models"), (ModelType.LLM,)),
]
def test_service_migrate_without_tenant_ids_discovers_tenants_per_selected_table_without_querying_tenants(
migration_module,
sqlite_engine: sa.Engine,
monkeypatch: pytest.MonkeyPatch,
) -> None:
create_minimal_legacy_model_type_schema(sqlite_engine)
provider_tenant_id = "00000000-0000-0000-0000-000000000111"
default_tenant_id = "00000000-0000-0000-0000-000000000222"
empty_tenant_id = "00000000-0000-0000-0000-000000000333"
for tenant_id in (provider_tenant_id, default_tenant_id, empty_tenant_id):
_insert_tenant(sqlite_engine, tenant_id=tenant_id)
created_at = datetime(2025, 1, 1, 12, 0, 0)
updated_at = created_at + timedelta(minutes=1)
_insert_provider_model(
sqlite_engine,
row_id="10000000-0000-0000-0000-000000000111",
tenant_id=provider_tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="text-generation",
credential_id=None,
created_at=created_at,
updated_at=updated_at,
)
_insert_tenant_default_model(
sqlite_engine,
row_id="20000000-0000-0000-0000-000000000222",
tenant_id=default_tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="text-generation",
created_at=created_at,
updated_at=updated_at,
)
seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = []
executed_sql: list[str] = []
class FakeMigration:
def __init__(
self,
*,
tenant_id: str,
engine: sa.Engine,
apply: bool,
output: io.TextIOBase,
model_types: tuple[ModelType, ...],
orm_models: tuple[type[object], ...],
) -> None:
assert engine is sqlite_engine
assert apply is False
seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types))
def run(self) -> None:
return None
def _record_sql(
conn: sa.engine.Connection,
cursor: object,
statement: str,
parameters: object,
context: object,
executemany: bool,
) -> None:
del conn, cursor, parameters, context, executemany
executed_sql.append(statement)
sa.event.listen(sqlite_engine, "before_cursor_execute", _record_sql)
try:
monkeypatch.setattr(migration_module, "Migration", FakeMigration)
service = migration_module.LegacyModelTypeMigrationService(
engine=sqlite_engine,
apply=False,
tables=("provider_models", "tenant_default_models"),
model_types=(ModelType.LLM,),
)
service.migrate()
finally:
sa.event.remove(sqlite_engine, "before_cursor_execute", _record_sql)
assert seen_runs == [
(provider_tenant_id, ("provider_models",), (ModelType.LLM,)),
(default_tenant_id, ("tenant_default_models",), (ModelType.LLM,)),
]
normalized_statements = [" ".join(statement.lower().split()) for statement in executed_sql]
discovery_statements = [statement for statement in normalized_statements if statement.startswith("select")]
table_names = ("provider_models", "tenant_default_models")
table_discovery_statements = [
statement
for statement in discovery_statements
if any(f" from {table_name} " in f" {statement} " for table_name in table_names)
]
assert [statement for statement in discovery_statements if " from tenants " in f" {statement} "] == []
assert [statement for statement in discovery_statements if " union " in f" {statement} "] == []
assert [
next(table_name for table_name in table_names if f" from {table_name} " in f" {statement} ")
for statement in table_discovery_statements
] == list(table_names)
def test_service_migrate_without_tenant_ids_filters_provider_model_tenants_by_selected_model_types(
migration_module,
sqlite_engine: sa.Engine,
monkeypatch: pytest.MonkeyPatch,
) -> None:
create_minimal_legacy_model_type_schema(sqlite_engine)
llm_tenant_id = "00000000-0000-0000-0000-000000000411"
embedding_tenant_id = "00000000-0000-0000-0000-000000000422"
empty_tenant_id = "00000000-0000-0000-0000-000000000433"
for tenant_id in (llm_tenant_id, embedding_tenant_id, empty_tenant_id):
_insert_tenant(sqlite_engine, tenant_id=tenant_id)
created_at = datetime(2025, 1, 2, 12, 0, 0)
updated_at = created_at + timedelta(minutes=1)
_insert_provider_model(
sqlite_engine,
row_id="30000000-0000-0000-0000-000000000411",
tenant_id=llm_tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type="text-generation",
credential_id=None,
created_at=created_at,
updated_at=updated_at,
)
_insert_provider_model(
sqlite_engine,
row_id="30000000-0000-0000-0000-000000000422",
tenant_id=embedding_tenant_id,
provider_name="openai",
model_name="text-embedding-3-large",
model_type="embeddings",
credential_id=None,
created_at=created_at,
updated_at=updated_at,
)
seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = []
class FakeMigration:
def __init__(
self,
*,
tenant_id: str,
engine: sa.Engine,
apply: bool,
output: io.TextIOBase,
model_types: tuple[ModelType, ...],
orm_models: tuple[type[object], ...],
) -> None:
assert engine is sqlite_engine
assert apply is False
seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types))
def run(self) -> None:
return None
monkeypatch.setattr(migration_module, "Migration", FakeMigration)
service = migration_module.LegacyModelTypeMigrationService(
engine=sqlite_engine,
apply=False,
tables=("provider_models",),
model_types=(ModelType.LLM,),
)
service.migrate()
assert seen_runs == [
(llm_tenant_id, ("provider_models",), (ModelType.LLM,)),
]
def test_service_migrate_without_tenant_ids_discovers_all_load_balancing_tenants_for_simpler_table_scoped_query(
migration_module,
sqlite_engine: sa.Engine,
monkeypatch: pytest.MonkeyPatch,
) -> None:
create_minimal_legacy_model_type_schema(sqlite_engine)
inherit_llm_tenant_id = "00000000-0000-0000-0000-000000000511"
inherit_embedding_tenant_id = "00000000-0000-0000-0000-000000000522"
empty_tenant_id = "00000000-0000-0000-0000-000000000533"
for tenant_id in (inherit_llm_tenant_id, inherit_embedding_tenant_id, empty_tenant_id):
_insert_tenant(sqlite_engine, tenant_id=tenant_id)
created_at = datetime(2025, 1, 3, 12, 0, 0)
updated_at = created_at + timedelta(minutes=1)
_insert_load_balancing_model_config(
sqlite_engine,
row_id="40000000-0000-0000-0000-000000000511",
tenant_id=inherit_llm_tenant_id,
provider_name="openai",
model_name="gpt-4o-mini",
model_type=ModelType.LLM.value,
name="__inherit__",
encrypted_config=json.dumps({"api_key": "inherit-llm"}),
credential_id="50000000-0000-0000-0000-000000000511",
enabled=True,
created_at=created_at,
updated_at=updated_at,
)
_insert_load_balancing_model_config(
sqlite_engine,
row_id="40000000-0000-0000-0000-000000000522",
tenant_id=inherit_embedding_tenant_id,
provider_name="openai",
model_name="text-embedding-3-large",
model_type=ModelType.TEXT_EMBEDDING.value,
name="__inherit__",
encrypted_config=json.dumps({"api_key": "inherit-embedding"}),
credential_id="50000000-0000-0000-0000-000000000522",
enabled=True,
created_at=created_at,
updated_at=updated_at,
)
seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = []
class FakeMigration:
def __init__(
self,
*,
tenant_id: str,
engine: sa.Engine,
apply: bool,
output: io.TextIOBase,
model_types: tuple[ModelType, ...],
orm_models: tuple[type[object], ...],
) -> None:
assert engine is sqlite_engine
assert apply is False
seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types))
def run(self) -> None:
return None
monkeypatch.setattr(migration_module, "Migration", FakeMigration)
# Load-balancing tenant discovery is a deliberate exception: it scans the
# whole table so the discovery query stays easy to understand, even when
# the scheduled tenant set is wider than the selected model types.
service = migration_module.LegacyModelTypeMigrationService(
engine=sqlite_engine,
apply=False,
tables=("load_balancing_model_configs",),
model_types=(ModelType.LLM,),
)
service.migrate()
assert seen_runs == [
(inherit_llm_tenant_id, ("load_balancing_model_configs",), (ModelType.LLM,)),
(inherit_embedding_tenant_id, ("load_balancing_model_configs",), (ModelType.LLM,)),
]
def test_service_migrate_with_concurrency_greater_than_one_runs_tenants_in_parallel_without_changing_migration_scope(