mirror of
https://github.com/langgenius/dify.git
synced 2026-05-27 12:26:15 +08:00
feat(api): only load tenant required for each table
This commit is contained in:
@ -11,6 +11,14 @@ rewriting credential references in provider models and load-balancing configs be
|
||||
removing loser credential rows. `load_balancing_model_configs` stays mostly row-level,
|
||||
but it first deduplicates `name="__inherit__"` rows by business key before it
|
||||
canonicalizes the remaining legacy rows independently with row-level cache cleanup.
|
||||
|
||||
Tenant scheduling has two modes. When callers provide an explicit tenant list, the
|
||||
service preserves the original tenant-scoped execution model and runs all selected tables
|
||||
for each tenant. When callers omit `tenant_ids`, the service discovers tenant
|
||||
ids per table and then runs only that table for the discovered tenants. Most
|
||||
tables keep the active `model_types` filter in the discovery query, while
|
||||
`load_balancing_model_configs` deliberately uses a whole-table tenant scan so
|
||||
that query stays easy to understand.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -21,7 +29,7 @@ import sys
|
||||
import threading
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import Iterable, Iterator, Sequence
|
||||
from collections.abc import Iterable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
@ -36,7 +44,7 @@ from sqlalchemy.sql import select
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import LoadBalancingModelConfig, ProviderModel, ProviderModelSetting, Tenant, TenantDefaultModel
|
||||
from models import LoadBalancingModelConfig, ProviderModel, ProviderModelSetting, TenantDefaultModel
|
||||
from models.base import TypeBase
|
||||
from models.provider import ProviderModelCredential
|
||||
|
||||
@ -291,6 +299,21 @@ _LOCK_TIMEOUT_FALLBACK_MESSAGES: tuple[str, ...] = (
|
||||
_RAW_MODEL_TYPE_COLUMN = "_raw_model_type"
|
||||
|
||||
|
||||
def _selected_legacy_values(model_types: Sequence[ModelType]) -> list[str]:
    """Collect the legacy spellings of every selected canonical model type.

    The result follows the order of ``model_types`` and, within each type, the
    order stored in ``_CANONICAL_TO_LEGACY``. Values are not de-duplicated.
    """
    return [
        legacy_value
        for model_type in model_types
        for legacy_value in _CANONICAL_TO_LEGACY[model_type]
    ]
|
||||
|
||||
|
||||
def _selected_model_type_values(model_types: Sequence[ModelType]) -> list[str]:
    """Return canonical plus legacy values for the selected model types.

    For each model type the canonical ``value`` comes first, followed by its
    entries in ``_CANONICAL_TO_LEGACY``. Duplicates are dropped while the
    first-seen order is preserved.
    """
    # A dict is used as an insertion-ordered set.
    ordered_unique: dict[str, None] = {}
    for model_type in model_types:
        for candidate in (model_type.value, *_CANONICAL_TO_LEGACY[model_type]):
            ordered_unique.setdefault(candidate)
    return list(ordered_unique)
|
||||
|
||||
|
||||
def _session_factory(engine: sa.Engine) -> Session:
    """Create a new ORM session bound to ``engine`` with ``expire_on_commit`` disabled."""
    session = Session(bind=engine, expire_on_commit=False)
    return session
|
||||
|
||||
@ -363,6 +386,12 @@ class LegacyModelTypeMigrationService:
|
||||
`provider_model_credentials` is selected, that migration also rewrites references in
|
||||
`provider_models` and `load_balancing_model_configs`. Tenant migrations can run in a
|
||||
thread pool; JSONL output remains line-safe through a shared synchronized writer.
|
||||
|
||||
If `tenant_ids` is omitted, tenant discovery becomes table-scoped: each selected ORM
|
||||
model loads its own tenant ids, then only that table is dispatched for those tenants.
|
||||
Most tables keep the active model-type filter in discovery, while
|
||||
`load_balancing_model_configs` intentionally uses the whole table so the tenant query
|
||||
stays simple. This still avoids merging tenant ids across unrelated tables.
|
||||
"""
|
||||
|
||||
_engine: sa.Engine
|
||||
@ -426,22 +455,51 @@ class LegacyModelTypeMigrationService:
|
||||
return tuple(ordered_models)
|
||||
|
||||
def migrate(self) -> None:
    """Run the migration in explicit-tenant or discovered-tenant mode.

    When the service was given tenant ids, all selected tables run for each of
    those tenants. Otherwise tenants are discovered per table and only that
    table is dispatched for them.
    """
    # A single wrapped writer is created here and handed to every tenant run.
    output = _ThreadSafeLineWriter(self._output)
    if self._tenant_ids is None:
        self._migrate_tables_with_discovered_tenants(output)
    else:
        self._migrate_explicit_tenants(output)
|
||||
|
||||
def _migrate_explicit_tenants(self, output: io.TextIOBase) -> None:
|
||||
tenant_ids = self._tenant_ids
|
||||
if not tenant_ids:
|
||||
return
|
||||
|
||||
output = _ThreadSafeLineWriter(self._output)
|
||||
self._run_migrations_for_tenants(tenant_ids, self._orm_models, output)
|
||||
|
||||
def _migrate_tables_with_discovered_tenants(self, output: io.TextIOBase) -> None:
|
||||
for orm_model in self._orm_models:
|
||||
tenant_ids = self._load_tenant_ids_for_model(orm_model)
|
||||
if not tenant_ids:
|
||||
continue
|
||||
self._run_migrations_for_tenants(tenant_ids, (orm_model,), output)
|
||||
|
||||
def _run_migrations_for_tenants(
|
||||
self,
|
||||
tenant_ids: Sequence[str],
|
||||
orm_models: Sequence[ORMModel],
|
||||
output: io.TextIOBase,
|
||||
) -> None:
|
||||
if self._concurrency == 1 or len(tenant_ids) == 1:
|
||||
for tenant_id in tenant_ids:
|
||||
self._run_tenant_migration(tenant_id, output)
|
||||
self._run_tenant_migration(tenant_id, orm_models, output)
|
||||
return
|
||||
|
||||
with ThreadPoolExecutor(max_workers=min(self._concurrency, len(tenant_ids))) as executor:
|
||||
futures = [executor.submit(self._run_tenant_migration, tenant_id, output) for tenant_id in tenant_ids]
|
||||
futures = [
|
||||
executor.submit(self._run_tenant_migration, tenant_id, orm_models, output) for tenant_id in tenant_ids
|
||||
]
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
|
||||
def _run_tenant_migration(self, tenant_id: str, output: io.TextIOBase) -> None:
|
||||
def _run_tenant_migration(
|
||||
self,
|
||||
tenant_id: str,
|
||||
orm_models: Sequence[ORMModel],
|
||||
output: io.TextIOBase,
|
||||
) -> None:
|
||||
"""
|
||||
Execute one tenant migration with the shared, line-synchronized output stream.
|
||||
"""
|
||||
@ -452,18 +510,88 @@ class LegacyModelTypeMigrationService:
|
||||
apply=self._apply,
|
||||
output=output,
|
||||
model_types=self._model_types,
|
||||
orm_models=self._orm_models,
|
||||
orm_models=orm_models,
|
||||
).run()
|
||||
|
||||
def _iter_tenant_ids(self) -> Iterator[str]:
|
||||
if self._tenant_ids is not None:
|
||||
yield from self._tenant_ids
|
||||
return
|
||||
def _load_tenant_ids_for_model(self, orm_model: ORMModel) -> tuple[str, ...]:
    """
    Discover only the tenants that have candidate rows for the current table.

    In automatic tenant mode we keep discovery table-scoped so large shared tenant
    populations do not force empty work for unrelated tables. Most table queries
    still apply the active `model_types` filter before scheduling migrations, while
    `load_balancing_model_configs` intentionally trades a wider tenant set for a
    simpler discovery query.

    Returns:
        The distinct tenant ids found in the table, in ascending order.

    Raises:
        ValueError: If `orm_model` is not one of the supported tables.
    """
    # Tables whose discovery query filters on the selected legacy model-type values.
    model_type_filtered_models = (
        ProviderModel,
        TenantDefaultModel,
        ProviderModelSetting,
        ProviderModelCredential,
    )

    if orm_model is LoadBalancingModelConfig:
        # Deliberately discover tenants from the whole table so the query stays
        # easier to understand than the legacy/canonical mixed-row filter.
        statement = select(orm_model.tenant_id)
    elif any(orm_model is candidate for candidate in model_type_filtered_models):
        legacy_model_type_values = _selected_legacy_values(self._model_types)
        # Compare the column as a plain string so raw legacy values can be matched.
        statement = select(orm_model.tenant_id).where(
            sa.type_coerce(orm_model.model_type, sa.String()).in_(legacy_model_type_values)
        )
    else:
        # Reject unknown tables before a session is opened.
        raise ValueError(f"unsupported orm model: {orm_model}")

    statement = statement.distinct().order_by(orm_model.tenant_id.asc())
    with _session_factory(self._engine) as session:
        tenant_ids = session.execute(statement).scalars().all()

    return tuple(tenant_ids)
|
||||
|
||||
|
||||
class Migration:
|
||||
@ -532,17 +660,10 @@ class Migration:
|
||||
)
|
||||
|
||||
def _selected_legacy_values(self) -> list[str]:
    """Return the legacy values for this migration's selected model types."""
    # Inside a method body the bare name resolves to the module-level helper,
    # not to this method, so the shared implementation is reused.
    return _selected_legacy_values(self._model_types)
|
||||
|
||||
def _selected_model_type_values(self) -> list[str]:
    """Return the de-duplicated canonical and legacy values for this migration's model types."""
    # Inside a method body the bare name resolves to the module-level helper,
    # not to this method, so the shared implementation is reused.
    return _selected_model_type_values(self._model_types)
|
||||
|
||||
def _allowed_values_for_canonical_model_type(self, canonical_model_type: ModelType) -> tuple[str, ...]:
    """Return the legacy values of a canonical model type followed by its canonical value."""
    legacy_values = tuple(_CANONICAL_TO_LEGACY[canonical_model_type])
    return legacy_values + (canonical_model_type.value,)
|
||||
|
||||
@ -17,6 +17,7 @@ from click.testing import CliRunner
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Tenant
|
||||
from models.enums import CredentialSourceType
|
||||
from models.provider import ProviderModel
|
||||
from tests.helpers.legacy_model_type_migration import (
|
||||
@ -24,6 +25,7 @@ from tests.helpers.legacy_model_type_migration import (
|
||||
LEGACY_TO_CANONICAL,
|
||||
assert_tenant_rows_use_only_canonical_model_types,
|
||||
count_rows,
|
||||
create_minimal_legacy_model_type_schema,
|
||||
fetch_table_rows,
|
||||
seed_legacy_model_type_dirty_data,
|
||||
snapshot_legacy_model_type_state,
|
||||
@ -196,6 +198,18 @@ def _insert_provider_model(
|
||||
)
|
||||
|
||||
|
||||
def _insert_tenant(engine: sa.Engine, *, tenant_id: str) -> None:
    """Insert one tenant row, with a name derived from ``tenant_id`` and fixed plan/status values."""
    insert_statement = Tenant.__table__.insert().values(
        id=tenant_id,
        name=f"Tenant {tenant_id}",
        plan="basic",
        status="normal",
    )
    with engine.begin() as conn:
        conn.execute(insert_statement)
|
||||
|
||||
|
||||
def _insert_tenant_default_model(
|
||||
engine: sa.Engine,
|
||||
*,
|
||||
@ -509,7 +523,7 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve
|
||||
migration_module,
|
||||
sqlite_engine: sa.Engine,
|
||||
) -> None:
|
||||
seen_runs: list[dict[str, object]] = []
|
||||
seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = []
|
||||
|
||||
class FakeMigration:
|
||||
def __init__(
|
||||
@ -522,18 +536,12 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve
|
||||
model_types: tuple[ModelType, ...],
|
||||
orm_models: tuple[type[object], ...],
|
||||
) -> None:
|
||||
seen_runs.append(
|
||||
{
|
||||
"tenant_id": tenant_id,
|
||||
"engine": engine,
|
||||
"apply": apply,
|
||||
"model_types": model_types,
|
||||
"table_names": tuple(model.__table__.name for model in orm_models),
|
||||
}
|
||||
)
|
||||
assert engine is sqlite_engine
|
||||
assert apply is False
|
||||
seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types))
|
||||
|
||||
def run(self) -> None:
|
||||
seen_runs.append({"run": True})
|
||||
return None
|
||||
|
||||
monkeypatch = pytest.MonkeyPatch()
|
||||
try:
|
||||
@ -542,7 +550,7 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve
|
||||
engine=sqlite_engine,
|
||||
apply=False,
|
||||
concurrency=1,
|
||||
tables=("provider_models",),
|
||||
tables=("provider_models", "tenant_default_models"),
|
||||
model_types=(ModelType.LLM,),
|
||||
tenant_ids=("tenant-alpha", "tenant-beta"),
|
||||
)
|
||||
@ -551,11 +559,267 @@ def test_service_migrate_batches_by_tenant_respects_selected_tables_without_reve
|
||||
finally:
|
||||
monkeypatch.undo()
|
||||
|
||||
init_calls = [call for call in seen_runs if "tenant_id" in call]
|
||||
assert [call["tenant_id"] for call in init_calls] == ["tenant-alpha", "tenant-beta"]
|
||||
for call in init_calls:
|
||||
assert tuple(cast(tuple[str, ...], call["table_names"])) == ("provider_models",)
|
||||
assert call["model_types"] == (ModelType.LLM,)
|
||||
assert seen_runs == [
|
||||
("tenant-alpha", ("provider_models", "tenant_default_models"), (ModelType.LLM,)),
|
||||
("tenant-beta", ("provider_models", "tenant_default_models"), (ModelType.LLM,)),
|
||||
]
|
||||
|
||||
|
||||
def test_service_migrate_without_tenant_ids_discovers_tenants_per_selected_table_without_querying_tenants(
    migration_module,
    sqlite_engine: sa.Engine,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    Without `tenant_ids`, each selected table is dispatched only to its own tenants.

    One tenant owns a legacy `provider_models` row, another owns a legacy
    `tenant_default_models` row, and a third has no rows. The test checks the
    scheduled runs and asserts that discovery never selects from `tenants` and
    never uses a UNION.
    """
    create_minimal_legacy_model_type_schema(sqlite_engine)
    provider_tenant_id = "00000000-0000-0000-0000-000000000111"
    default_tenant_id = "00000000-0000-0000-0000-000000000222"
    empty_tenant_id = "00000000-0000-0000-0000-000000000333"
    for tenant_id in (provider_tenant_id, default_tenant_id, empty_tenant_id):
        _insert_tenant(sqlite_engine, tenant_id=tenant_id)

    created_at = datetime(2025, 1, 1, 12, 0, 0)
    updated_at = created_at + timedelta(minutes=1)
    _insert_provider_model(
        sqlite_engine,
        row_id="10000000-0000-0000-0000-000000000111",
        tenant_id=provider_tenant_id,
        provider_name="openai",
        model_name="gpt-4o-mini",
        model_type="text-generation",
        credential_id=None,
        created_at=created_at,
        updated_at=updated_at,
    )
    _insert_tenant_default_model(
        sqlite_engine,
        row_id="20000000-0000-0000-0000-000000000222",
        tenant_id=default_tenant_id,
        provider_name="openai",
        model_name="gpt-4o-mini",
        model_type="text-generation",
        created_at=created_at,
        updated_at=updated_at,
    )

    # Each entry is (tenant_id, table names, model types) for one constructed migration.
    seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = []
    executed_sql: list[str] = []

    class FakeMigration:
        # Stand-in for the real Migration: records constructor arguments and does no work.
        def __init__(
            self,
            *,
            tenant_id: str,
            engine: sa.Engine,
            apply: bool,
            output: io.TextIOBase,
            model_types: tuple[ModelType, ...],
            orm_models: tuple[type[object], ...],
        ) -> None:
            assert engine is sqlite_engine
            assert apply is False
            seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types))

        def run(self) -> None:
            return None

    def _record_sql(
        conn: sa.engine.Connection,
        cursor: object,
        statement: str,
        parameters: object,
        context: object,
        executemany: bool,
    ) -> None:
        # Only the SQL text matters; the other hook arguments are discarded.
        del conn, cursor, parameters, context, executemany
        executed_sql.append(statement)

    sa.event.listen(sqlite_engine, "before_cursor_execute", _record_sql)
    try:
        monkeypatch.setattr(migration_module, "Migration", FakeMigration)
        service = migration_module.LegacyModelTypeMigrationService(
            engine=sqlite_engine,
            apply=False,
            tables=("provider_models", "tenant_default_models"),
            model_types=(ModelType.LLM,),
        )

        service.migrate()
    finally:
        # Always detach the listener so it does not keep recording after this test.
        sa.event.remove(sqlite_engine, "before_cursor_execute", _record_sql)

    assert seen_runs == [
        (provider_tenant_id, ("provider_models",), (ModelType.LLM,)),
        (default_tenant_id, ("tenant_default_models",), (ModelType.LLM,)),
    ]
    # Lower-case and collapse whitespace so substring checks ignore SQL formatting.
    normalized_statements = [" ".join(statement.lower().split()) for statement in executed_sql]
    discovery_statements = [statement for statement in normalized_statements if statement.startswith("select")]
    table_names = ("provider_models", "tenant_default_models")
    table_discovery_statements = [
        statement
        for statement in discovery_statements
        if any(f" from {table_name} " in f" {statement} " for table_name in table_names)
    ]

    assert [statement for statement in discovery_statements if " from tenants " in f" {statement} "] == []
    assert [statement for statement in discovery_statements if " union " in f" {statement} "] == []
    # Exactly one discovery query per selected table, issued in table order.
    assert [
        next(table_name for table_name in table_names if f" from {table_name} " in f" {statement} ")
        for statement in table_discovery_statements
    ] == list(table_names)
|
||||
|
||||
|
||||
def test_service_migrate_without_tenant_ids_filters_provider_model_tenants_by_selected_model_types(
    migration_module,
    sqlite_engine: sa.Engine,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    `provider_models` discovery only schedules tenants with rows of the selected model types.

    One tenant has a legacy LLM row ("text-generation"), another has a legacy
    embedding row ("embeddings"), and a third has no rows. With only
    `ModelType.LLM` selected, just the LLM tenant is expected to be scheduled.
    """
    create_minimal_legacy_model_type_schema(sqlite_engine)
    llm_tenant_id = "00000000-0000-0000-0000-000000000411"
    embedding_tenant_id = "00000000-0000-0000-0000-000000000422"
    empty_tenant_id = "00000000-0000-0000-0000-000000000433"
    for tenant_id in (llm_tenant_id, embedding_tenant_id, empty_tenant_id):
        _insert_tenant(sqlite_engine, tenant_id=tenant_id)

    created_at = datetime(2025, 1, 2, 12, 0, 0)
    updated_at = created_at + timedelta(minutes=1)
    _insert_provider_model(
        sqlite_engine,
        row_id="30000000-0000-0000-0000-000000000411",
        tenant_id=llm_tenant_id,
        provider_name="openai",
        model_name="gpt-4o-mini",
        model_type="text-generation",
        credential_id=None,
        created_at=created_at,
        updated_at=updated_at,
    )
    _insert_provider_model(
        sqlite_engine,
        row_id="30000000-0000-0000-0000-000000000422",
        tenant_id=embedding_tenant_id,
        provider_name="openai",
        model_name="text-embedding-3-large",
        model_type="embeddings",
        credential_id=None,
        created_at=created_at,
        updated_at=updated_at,
    )

    # Each entry is (tenant_id, table names, model types) for one constructed migration.
    seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = []

    class FakeMigration:
        # Stand-in for the real Migration: records constructor arguments and does no work.
        def __init__(
            self,
            *,
            tenant_id: str,
            engine: sa.Engine,
            apply: bool,
            output: io.TextIOBase,
            model_types: tuple[ModelType, ...],
            orm_models: tuple[type[object], ...],
        ) -> None:
            assert engine is sqlite_engine
            assert apply is False
            seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types))

        def run(self) -> None:
            return None

    monkeypatch.setattr(migration_module, "Migration", FakeMigration)
    service = migration_module.LegacyModelTypeMigrationService(
        engine=sqlite_engine,
        apply=False,
        tables=("provider_models",),
        model_types=(ModelType.LLM,),
    )

    service.migrate()

    # The embedding-only tenant and the empty tenant must not be scheduled.
    assert seen_runs == [
        (llm_tenant_id, ("provider_models",), (ModelType.LLM,)),
    ]
|
||||
|
||||
|
||||
def test_service_migrate_without_tenant_ids_discovers_all_load_balancing_tenants_for_simpler_table_scoped_query(
    migration_module,
    sqlite_engine: sa.Engine,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    `load_balancing_model_configs` discovery schedules every tenant that has rows in the table.

    Two tenants each own an `__inherit__` row, one LLM and one text-embedding,
    and a third tenant has no rows. Even with only `ModelType.LLM` selected,
    both tenants with rows are expected to be scheduled.
    """
    create_minimal_legacy_model_type_schema(sqlite_engine)
    inherit_llm_tenant_id = "00000000-0000-0000-0000-000000000511"
    inherit_embedding_tenant_id = "00000000-0000-0000-0000-000000000522"
    empty_tenant_id = "00000000-0000-0000-0000-000000000533"
    for tenant_id in (inherit_llm_tenant_id, inherit_embedding_tenant_id, empty_tenant_id):
        _insert_tenant(sqlite_engine, tenant_id=tenant_id)

    created_at = datetime(2025, 1, 3, 12, 0, 0)
    updated_at = created_at + timedelta(minutes=1)
    _insert_load_balancing_model_config(
        sqlite_engine,
        row_id="40000000-0000-0000-0000-000000000511",
        tenant_id=inherit_llm_tenant_id,
        provider_name="openai",
        model_name="gpt-4o-mini",
        model_type=ModelType.LLM.value,
        name="__inherit__",
        encrypted_config=json.dumps({"api_key": "inherit-llm"}),
        credential_id="50000000-0000-0000-0000-000000000511",
        enabled=True,
        created_at=created_at,
        updated_at=updated_at,
    )
    _insert_load_balancing_model_config(
        sqlite_engine,
        row_id="40000000-0000-0000-0000-000000000522",
        tenant_id=inherit_embedding_tenant_id,
        provider_name="openai",
        model_name="text-embedding-3-large",
        model_type=ModelType.TEXT_EMBEDDING.value,
        name="__inherit__",
        encrypted_config=json.dumps({"api_key": "inherit-embedding"}),
        credential_id="50000000-0000-0000-0000-000000000522",
        enabled=True,
        created_at=created_at,
        updated_at=updated_at,
    )

    # Each entry is (tenant_id, table names, model types) for one constructed migration.
    seen_runs: list[tuple[str, tuple[str, ...], tuple[ModelType, ...]]] = []

    class FakeMigration:
        # Stand-in for the real Migration: records constructor arguments and does no work.
        def __init__(
            self,
            *,
            tenant_id: str,
            engine: sa.Engine,
            apply: bool,
            output: io.TextIOBase,
            model_types: tuple[ModelType, ...],
            orm_models: tuple[type[object], ...],
        ) -> None:
            assert engine is sqlite_engine
            assert apply is False
            seen_runs.append((tenant_id, tuple(model.__table__.name for model in orm_models), model_types))

        def run(self) -> None:
            return None

    monkeypatch.setattr(migration_module, "Migration", FakeMigration)
    # Load-balancing tenant discovery is a deliberate exception: it scans the
    # whole table so the discovery query stays easy to understand, even when
    # the scheduled tenant set is wider than the selected model types.
    service = migration_module.LegacyModelTypeMigrationService(
        engine=sqlite_engine,
        apply=False,
        tables=("load_balancing_model_configs",),
        model_types=(ModelType.LLM,),
    )

    service.migrate()

    # Both tenants with rows are scheduled, in ascending tenant-id order; the empty tenant is not.
    assert seen_runs == [
        (inherit_llm_tenant_id, ("load_balancing_model_configs",), (ModelType.LLM,)),
        (inherit_embedding_tenant_id, ("load_balancing_model_configs",), (ModelType.LLM,)),
    ]
|
||||
|
||||
|
||||
def test_service_migrate_with_concurrency_greater_than_one_runs_tenants_in_parallel_without_changing_migration_scope(
|
||||
|
||||
Reference in New Issue
Block a user