Compare commits

...

16 Commits

Author SHA1 Message Date
79bad45f55 refactor(migration): drop comment from LB collision key 2026-05-25 22:31:21 -07:00
90d6b3db77 test(migration): cover LB dedup collision case in credential merge 2026-05-25 22:14:28 -07:00
c39cfc14fb refactor(migration): clarify LB collision key includes model_name intentionally 2026-05-25 22:06:58 -07:00
2300497a8c fix(migration): dedup load_balancing_model_configs on credential merge
When provider_model_credentials are merged (loser deleted, winner kept),
rows in load_balancing_model_configs pointing to the loser credential are
rewritten to point to the winner. If the winner already has an LB row for
the same (provider_name, model_name, model_type), the rewrite would create
a duplicate — instead, detect the collision and delete the loser LB row.

Adds _LoadBalancingCredentialDeletePlan, extends
_ProviderModelCredentialGroupPlan with load_balancing_deletions, adds
_emit_load_balancing_reference_deletions, and wires it into the credential
group plan emit path.
2026-05-25 22:00:55 -07:00
b0b96d5e01 Merge branch 'main' into feat/model-type-migration-script 2026-05-25 21:21:05 -07:00
435ca8b9f1 test(api): fix broken tests 2026-05-26 11:21:34 +08:00
e21679d980 chore(api): fix type error 2026-05-26 11:00:51 +08:00
91a0a6d27a chore(api): fix lint issues 2026-05-26 10:49:29 +08:00
5ac44589d6 Merge remote-tracking branch 'upstream/main' into feat/model-type-migration-script-2 2026-05-26 10:12:48 +08:00
2fda1318be feat(api): support concurrency in mode type migration script 2026-05-26 10:00:55 +08:00
8cd3cf7c75 feat(api): allow specify JSON output 2026-05-26 09:45:04 +08:00
7aabb67441 chore(api): rollback changes to ext_logging 2026-05-26 09:35:03 +08:00
c9327ee666 chore(api): remove dev group from default groups 2026-05-26 09:28:24 +08:00
1a25acc140 chore(api): log after updating / deletion 2026-05-26 09:24:08 +08:00
79beadb8cf chore(api): minor adjustment 2026-05-26 04:39:30 +08:00
1b70cdd51d feat(api): introduce model type migration script
Assisted-By: Codex:GPT-5.4
2026-05-26 03:27:19 +08:00
11 changed files with 4537 additions and 5 deletions

View File

@ -223,10 +223,11 @@ def initialize_extensions(app: DifyApp):
def create_migrations_app() -> DifyApp:
app = create_flask_app_with_configs()
from extensions import ext_database, ext_migrate
from extensions import ext_commands, ext_database, ext_migrate
# Initialize only required extensions
ext_database.init_app(app)
ext_migrate.init_app(app)
ext_commands.init_app(app)
return app

View File

@ -3,6 +3,7 @@ CLI command modules extracted from `commands.py`.
"""
from .account import create_tenant, reset_email, reset_password
from .data_migrate import data_migrate, legacy_model_types
from .plugin import (
extract_plugins,
extract_unique_plugins,
@ -44,6 +45,7 @@ __all__ = [
"clear_orphaned_file_records",
"convert_to_agent_apps",
"create_tenant",
"data_migrate",
"delete_archived_workflow_runs",
"export_app_messages",
"extract_plugins",
@ -52,6 +54,7 @@ __all__ = [
"fix_app_site_missing",
"install_plugins",
"install_rag_pipeline_plugins",
"legacy_model_types",
"migrate_annotation_vector_database",
"migrate_data_for_plugin",
"migrate_knowledge_vector_database",

View File

@ -0,0 +1,169 @@
import io
import os
import sys
from contextlib import AbstractContextManager, nullcontext
from pathlib import Path
from typing import cast
import click
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelType
from services.legacy_model_type_migration import (
VALID_TABLE_NAMES,
LegacyModelTypeMigrationService,
load_tenant_ids_from_file,
)
_SUPPORTED_MODEL_TYPE_CHOICES = (
ModelType.LLM.value,
ModelType.TEXT_EMBEDDING.value,
ModelType.RERANK.value,
)
_DEFAULT_CONCURRENCY = os.cpu_count() or 1
def _normalize_multi_value_option(
values: tuple[str, ...],
*,
valid_values: tuple[str, ...],
option_name: str,
) -> tuple[str, ...]:
normalized_values: list[str] = []
seen_values: set[str] = set()
for value in values:
for item in value.split(","):
normalized_item = item.strip()
if not normalized_item:
continue
if normalized_item not in valid_values:
raise click.BadParameter(
f"invalid value '{normalized_item}'. valid values: {', '.join(valid_values)}",
param_hint=option_name,
)
if normalized_item in seen_values:
continue
seen_values.add(normalized_item)
normalized_values.append(normalized_item)
return tuple(normalized_values)
@click.group(
"data-migrate",
help="Online data migration commands.",
)
def data_migrate() -> None:
"""Namespace for production data migration commands."""
@click.command(
"legacy-model-types",
help=(
"Migrate legacy provider model_type values to canonical values. "
"Default is dry-run and emits JSON lines only. "
"If --tables includes provider_model_credentials, the command may also update "
"provider_models and load_balancing_model_configs references so merged credentials stay reachable."
),
)
@click.option(
"--apply",
is_flag=True,
default=False,
help="Apply the migration. Default is dry-run.",
)
@click.option(
"--tables",
"tables",
multiple=True,
type=str,
help=(
"Limit model_type migration to specific tables. Accepts comma-separated values or repeated flags. "
"When provider_model_credentials is selected, provider_models and "
"load_balancing_model_configs may also be updated for credential reference rewrites."
"Default to: "
),
)
@click.option(
"--model-types",
"model_types",
multiple=True,
type=str,
help=(
"Canonical model types to migrate. Accepts comma-separated values or repeated flags. "
"Defaults to: `llm,text-embedding,rerank`"
),
)
@click.option(
"--tenant-id-file",
type=click.Path(exists=True, dir_okay=False, readable=True, resolve_path=True),
help="Optional file containing tenant ids, one per line.",
)
@click.option(
"--output",
type=click.Path(dir_okay=False, resolve_path=True, path_type=Path),
help="Optional file path for JSON lines event logs. Defaults to stdout.",
)
@click.option(
"--concurrency",
type=click.IntRange(min=1),
default=_DEFAULT_CONCURRENCY,
show_default=True,
help="Number of tenant-level worker threads to run in parallel.",
)
def legacy_model_types(
apply: bool,
tables: tuple[str, ...],
model_types: tuple[str, ...],
tenant_id_file: str | None,
output: Path | None,
concurrency: int = _DEFAULT_CONCURRENCY,
) -> None:
"""
Migrate legacy provider-related model_type values and emit JSON lines events.
"""
normalized_tables = _normalize_multi_value_option(
tables,
valid_values=VALID_TABLE_NAMES,
option_name="--tables",
)
normalized_model_types = _normalize_multi_value_option(
model_types,
valid_values=_SUPPORTED_MODEL_TYPE_CHOICES,
option_name="--model-types",
)
selected_model_types = (
tuple(ModelType.value_of(model_type) for model_type in normalized_model_types)
if normalized_model_types
else (
ModelType.LLM,
ModelType.TEXT_EMBEDDING,
ModelType.RERANK,
)
)
tenant_ids = load_tenant_ids_from_file(tenant_id_file) if tenant_id_file else None
output_context: AbstractContextManager[io.TextIOBase]
if output is None:
output_context = nullcontext(cast(io.TextIOBase, sys.stdout))
else:
try:
output_context = output.open("w", encoding="utf-8")
except OSError as exc:
raise click.ClickException(f"failed to open output file '{output}': {exc.strerror or exc}") from exc
with output_context as output_stream:
LegacyModelTypeMigrationService(
engine=db.engine,
apply=apply,
concurrency=concurrency,
output=cast(io.TextIOBase, output_stream),
tables=normalized_tables or None,
model_types=selected_model_types,
tenant_ids=tenant_ids,
).migrate()
data_migrate.add_command(legacy_model_types)

View File

@ -12,6 +12,7 @@ def init_app(app: DifyApp):
clear_orphaned_file_records,
convert_to_agent_apps,
create_tenant,
data_migrate,
delete_archived_workflow_runs,
export_app_messages,
extract_plugins,
@ -44,6 +45,7 @@ def init_app(app: DifyApp):
convert_to_agent_apps,
add_qdrant_index,
create_tenant,
data_migrate,
upgrade_db,
fix_app_site_missing,
migrate_data_for_plugin,

View File

@ -102,10 +102,7 @@ dify-trace-weave = { workspace = true }
[tool.uv]
default-groups = ["storage", "tools", "vdb-all", "trace-all"]
package = false
override-dependencies = [
"litellm>=1.83.10,<2.0.0",
"pyarrow>=23.0.1,<24.0.0",
]
override-dependencies = ["litellm>=1.83.10,<2.0.0", "pyarrow>=23.0.1,<24.0.0"]
[dependency-groups]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
"""Shared test helpers for backend migration tests."""

View File

@ -0,0 +1,379 @@
from __future__ import annotations
import json
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy.engine import Engine
from models.account import Tenant
from models.enums import CredentialSourceType
from models.provider import (
LoadBalancingModelConfig,
ProviderModel,
ProviderModelCredential,
ProviderModelSetting,
TenantDefaultModel,
)
LEGACY_TO_CANONICAL: dict[str, str] = {
"text-generation": "llm",
"embeddings": "text-embedding",
"reranking": "rerank",
}
UNCHANGED_MODEL_TYPES: tuple[str, ...] = ("speech2text", "moderation", "tts")
ALL_TABLE_NAMES: tuple[str, ...] = (
ProviderModel.__tablename__,
TenantDefaultModel.__tablename__,
ProviderModelSetting.__tablename__,
LoadBalancingModelConfig.__tablename__,
ProviderModelCredential.__tablename__,
)
DEFAULT_PRIMARY_TENANT_ID = "00000000-0000-0000-0000-000000000101"
DEFAULT_SECONDARY_TENANT_ID = "00000000-0000-0000-0000-000000000202"
@dataclass(frozen=True, slots=True)
class DirtyTenantFixture:
tenant_id: str
winner_credential_id: str
loser_credential_id: str
distinct_credential_id: str
provider_model_id: str
load_balancing_config_id: str
winner_load_balancing_config_id: str
provider_model_setting_id: str
tenant_default_model_id: str
embedding_provider_model_id: str
embedding_setting_id: str
loser_credential_name: str
distinct_credential_name: str
loser_encrypted_config: str
winner_encrypted_config: str
@dataclass(frozen=True, slots=True)
class DirtyDataFixture:
primary: DirtyTenantFixture
secondary: DirtyTenantFixture
def create_minimal_legacy_model_type_schema(engine: Engine) -> None:
metadata = Tenant.__table__.metadata
metadata.create_all(
engine,
tables=[
Tenant.__table__,
ProviderModel.__table__,
TenantDefaultModel.__table__,
ProviderModelSetting.__table__,
LoadBalancingModelConfig.__table__,
ProviderModelCredential.__table__,
],
checkfirst=True,
)
def drop_minimal_legacy_model_type_schema(engine: Engine) -> None:
metadata = Tenant.__table__.metadata
metadata.drop_all(
engine,
tables=[
LoadBalancingModelConfig.__table__,
ProviderModelSetting.__table__,
TenantDefaultModel.__table__,
ProviderModel.__table__,
ProviderModelCredential.__table__,
Tenant.__table__,
],
checkfirst=True,
)
def seed_legacy_model_type_dirty_data(
engine: Engine,
*,
primary_tenant_id: str = DEFAULT_PRIMARY_TENANT_ID,
secondary_tenant_id: str = DEFAULT_SECONDARY_TENANT_ID,
) -> DirtyDataFixture:
create_minimal_legacy_model_type_schema(engine)
primary = _seed_tenant(engine, tenant_id=primary_tenant_id, provider_name="openai")
secondary = _seed_tenant(engine, tenant_id=secondary_tenant_id, provider_name="openai")
return DirtyDataFixture(primary=primary, secondary=secondary)
def snapshot_legacy_model_type_state(engine: Engine) -> dict[str, list[dict[str, object]]]:
snapshots: dict[str, list[dict[str, object]]] = {}
for table_name in ALL_TABLE_NAMES:
snapshots[table_name] = fetch_table_rows(engine, table_name)
return snapshots
def fetch_table_rows(
engine: Engine,
table_name: str,
*,
tenant_id: str | None = None,
) -> list[dict[str, object]]:
sql = f"SELECT * FROM {table_name}"
params: dict[str, object] = {}
if tenant_id is not None:
sql += " WHERE tenant_id = :tenant_id"
params["tenant_id"] = tenant_id
sql += " ORDER BY id ASC"
with engine.begin() as conn:
rows = conn.execute(sa.text(sql), params).mappings().all()
result: list[dict[str, object]] = []
for row in rows:
normalized = dict(row)
for key, value in normalized.items():
if isinstance(value, datetime):
normalized[key] = value.isoformat()
elif isinstance(value, uuid.UUID):
normalized[key] = str(value)
result.append(normalized)
return result
def fetch_model_types_for_tenant(engine: Engine, table_name: str, tenant_id: str) -> list[str]:
rows = fetch_table_rows(engine, table_name, tenant_id=tenant_id)
return [str(row["model_type"]) for row in rows]
def assert_tenant_rows_use_only_canonical_model_types(engine: Engine, tenant_id: str) -> None:
for table_name in ALL_TABLE_NAMES:
model_types = fetch_model_types_for_tenant(engine, table_name, tenant_id)
assert set(model_types) <= set(LEGACY_TO_CANONICAL.values()) | set(UNCHANGED_MODEL_TYPES), (
table_name,
model_types,
)
def count_rows(engine: Engine, table_name: str, *, tenant_id: str) -> int:
with engine.begin() as conn:
stmt = sa.text(f"SELECT COUNT(*) FROM {table_name} WHERE tenant_id = :tenant_id")
return int(conn.execute(stmt, {"tenant_id": tenant_id}).scalar_one())
def _seed_tenant(engine: Engine, *, tenant_id: str, provider_name: str) -> DirtyTenantFixture:
now = datetime(2025, 1, 1, 12, 0, 0)
winner_credential_id = str(uuid4())
loser_credential_id = str(uuid4())
distinct_credential_id = str(uuid4())
provider_model_id = str(uuid4())
load_balancing_config_id = str(uuid4())
provider_model_setting_id = str(uuid4())
tenant_default_model_id = str(uuid4())
embedding_provider_model_id = str(uuid4())
embedding_setting_id = str(uuid4())
loser_credential_name = f"{tenant_id}-shared"
distinct_credential_name = f"{tenant_id}-distinct"
winner_encrypted_config = json.dumps({"api_key": f"{tenant_id}-winner"})
loser_encrypted_config = json.dumps({"api_key": f"{tenant_id}-loser"})
distinct_encrypted_config = json.dumps({"api_key": f"{tenant_id}-distinct"})
with engine.begin() as conn:
conn.execute(
Tenant.__table__.insert().values(
id=tenant_id,
name=f"Tenant {tenant_id}",
plan="basic",
status="normal",
)
)
conn.execute(
sa.text(
"""
INSERT INTO provider_model_credentials
(
id, tenant_id, provider_name, model_name,
model_type, credential_name, encrypted_config,
created_at, updated_at
)
VALUES
(
:winner_id, :tenant_id, :provider_name, 'gpt-4o-mini',
'llm', :shared_name, :winner_config,
:created_at, :winner_updated_at
),
(
:loser_id, :tenant_id, :provider_name, 'gpt-4o-mini',
'text-generation', :shared_name, :loser_config,
:created_at, :loser_updated_at
),
(
:distinct_id, :tenant_id, :provider_name, 'gpt-4o-mini',
'text-generation', :distinct_name, :distinct_config,
:created_at, :distinct_updated_at
)
"""
),
{
"winner_id": winner_credential_id,
"loser_id": loser_credential_id,
"distinct_id": distinct_credential_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"shared_name": loser_credential_name,
"distinct_name": distinct_credential_name,
"winner_config": winner_encrypted_config,
"loser_config": loser_encrypted_config,
"distinct_config": distinct_encrypted_config,
"created_at": now - timedelta(days=2),
"winner_updated_at": now,
"loser_updated_at": now - timedelta(days=1),
"distinct_updated_at": now - timedelta(hours=12),
},
)
conn.execute(
sa.text(
"""
INSERT INTO provider_models
(
id, tenant_id, provider_name, model_name,
model_type, credential_id, is_valid,
created_at, updated_at
)
VALUES
(
:provider_model_id, :tenant_id, :provider_name, 'gpt-4o-mini',
'text-generation', :loser_id, :is_valid,
:created_at, :updated_at
),
(
:embedding_provider_model_id, :tenant_id, :provider_name, 'text-embedding-3-large',
'embeddings', NULL, :is_valid,
:created_at, :updated_at
)
"""
),
{
"provider_model_id": provider_model_id,
"embedding_provider_model_id": embedding_provider_model_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"loser_id": loser_credential_id,
"is_valid": True,
"created_at": now - timedelta(days=2),
"updated_at": now - timedelta(hours=6),
},
)
conn.execute(
sa.text(
"""
INSERT INTO tenant_default_models
(id, tenant_id, provider_name, model_name, model_type, created_at, updated_at)
VALUES
(
:tenant_default_model_id, :tenant_id, :provider_name, 'gpt-4o-mini',
'text-generation', :created_at, :updated_at
)
"""
),
{
"tenant_default_model_id": tenant_default_model_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"created_at": now - timedelta(days=2),
"updated_at": now - timedelta(hours=4),
},
)
conn.execute(
sa.text(
"""
INSERT INTO provider_model_settings
(
id, tenant_id, provider_name, model_name,
model_type, enabled, load_balancing_enabled,
created_at, updated_at
)
VALUES
(
:provider_model_setting_id, :tenant_id, :provider_name, 'gpt-4o-mini',
'text-generation', :enabled, :load_balancing_enabled,
:created_at, :updated_at
),
(
:embedding_setting_id, :tenant_id, :provider_name, 'text-embedding-3-large',
'embeddings', :enabled, :embedding_load_balancing_enabled,
:created_at, :updated_at
)
"""
),
{
"provider_model_setting_id": provider_model_setting_id,
"embedding_setting_id": embedding_setting_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"enabled": True,
"load_balancing_enabled": True,
"embedding_load_balancing_enabled": False,
"created_at": now - timedelta(days=2),
"updated_at": now - timedelta(hours=3),
},
)
winner_load_balancing_config_id = str(uuid4())
conn.execute(
sa.text(
"""
INSERT INTO load_balancing_model_configs
(
id, tenant_id, provider_name, model_name, model_type,
name, encrypted_config, credential_id, credential_source_type,
enabled, created_at, updated_at
)
VALUES
(
:load_balancing_config_id, :tenant_id, :provider_name, 'gpt-4o-mini', 'text-generation',
:lb_name, :loser_config, :loser_id, :credential_source_type,
:enabled, :created_at, :updated_at
),
(
:lb_winner_id, :tenant_id, :provider_name, 'gpt-4o-mini', 'text-generation',
:winner_name, :winner_config, :winner_cred_id, :credential_source_type,
:enabled, :created_at, :winner_updated_at
)
"""
),
{
"load_balancing_config_id": load_balancing_config_id,
"lb_winner_id": winner_load_balancing_config_id,
"tenant_id": tenant_id,
"provider_name": provider_name,
"lb_name": loser_credential_name,
"loser_config": loser_encrypted_config,
"loser_id": loser_credential_id,
"winner_name": f"{tenant_id}-winner-lb",
"winner_config": winner_encrypted_config,
"winner_cred_id": winner_credential_id,
"credential_source_type": CredentialSourceType.CUSTOM_MODEL.value,
"enabled": True,
"created_at": now - timedelta(days=2),
"updated_at": now - timedelta(hours=2),
"winner_updated_at": now - timedelta(hours=1),
},
)
return DirtyTenantFixture(
tenant_id=tenant_id,
winner_credential_id=winner_credential_id,
loser_credential_id=loser_credential_id,
distinct_credential_id=distinct_credential_id,
provider_model_id=provider_model_id,
load_balancing_config_id=load_balancing_config_id,
winner_load_balancing_config_id=winner_load_balancing_config_id,
provider_model_setting_id=provider_model_setting_id,
tenant_default_model_id=tenant_default_model_id,
embedding_provider_model_id=embedding_provider_model_id,
embedding_setting_id=embedding_setting_id,
loser_credential_name=loser_credential_name,
distinct_credential_name=distinct_credential_name,
loser_encrypted_config=loser_encrypted_config,
winner_encrypted_config=winner_encrypted_config,
)

View File

@ -0,0 +1,82 @@
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
API_PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(API_PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(API_PROJECT_ROOT))
import sqlalchemy as sa
from tests.helpers.legacy_model_type_migration import (
DEFAULT_PRIMARY_TENANT_ID,
DEFAULT_SECONDARY_TENANT_ID,
create_minimal_legacy_model_type_schema,
seed_legacy_model_type_dirty_data,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Seed dirty legacy model_type rows for manual migration experiments. "
"Example: uv run --project api python api/tests/seed_legacy_model_type_dirty_data.py "
"--db-url postgresql://postgres:postgres@127.0.0.1:5432/dify"
)
)
parser.add_argument("--db-url", required=True, help="SQLAlchemy database URL for the target database.")
parser.add_argument(
"--primary-tenant-id",
default=DEFAULT_PRIMARY_TENANT_ID,
help="Tenant that will contain the main conflict scenario.",
)
parser.add_argument(
"--secondary-tenant-id",
default=DEFAULT_SECONDARY_TENANT_ID,
help="Tenant used to verify tenant filtering behavior.",
)
parser.add_argument(
"--create-minimal-schema",
action="store_true",
help="Create the minimal tables needed for the seed when running against an empty scratch database.",
)
return parser.parse_args()
def main() -> int:
args = parse_args()
engine = sa.create_engine(args.db_url)
try:
if args.create_minimal_schema:
create_minimal_legacy_model_type_schema(engine)
fixture = seed_legacy_model_type_dirty_data(
engine,
primary_tenant_id=args.primary_tenant_id,
secondary_tenant_id=args.secondary_tenant_id,
)
finally:
engine.dispose()
print(
json.dumps(
{
"primary_tenant_id": fixture.primary.tenant_id,
"secondary_tenant_id": fixture.secondary.tenant_id,
"winner_credential_id": fixture.primary.winner_credential_id,
"loser_credential_id": fixture.primary.loser_credential_id,
"provider_model_id": fixture.primary.provider_model_id,
"load_balancing_config_id": fixture.primary.load_balancing_config_id,
},
indent=2,
sort_keys=True,
)
)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -0,0 +1,127 @@
from __future__ import annotations
import importlib
import io
from collections.abc import Generator
import pytest
import sqlalchemy as sa
from tests.helpers.legacy_model_type_migration import (
assert_tenant_rows_use_only_canonical_model_types,
count_rows,
fetch_table_rows,
seed_legacy_model_type_dirty_data,
)
@pytest.fixture(scope="session")
def migration_module():
try:
return importlib.import_module("services.legacy_model_type_migration")
except ModuleNotFoundError as exc: # pragma: no cover - explicit TDD failure path
pytest.fail(
"services.legacy_model_type_migration is missing. "
"Implement LegacyModelTypeMigrationService before running these tests."
)
@pytest.fixture(params=("postgresql", "mysql"), scope="session")
def container_engine(request: pytest.FixtureRequest) -> Generator[tuple[str, sa.Engine], None, None]:
backend_name = request.param
if backend_name == "postgresql":
testcontainers_postgres = pytest.importorskip("testcontainers.postgres")
container = testcontainers_postgres.PostgresContainer("postgres:15-alpine")
else:
testcontainers_mysql = pytest.importorskip("testcontainers.mysql")
container = testcontainers_mysql.MySqlContainer("mysql:8.0")
container.start()
raw_url = container.get_connection_url()
engine_url = raw_url.replace("mysql://", "mysql+pymysql://", 1)
engine = sa.create_engine(engine_url)
try:
yield backend_name, engine
finally:
engine.dispose()
container.stop()
def test_legacy_model_type_migration_end_to_end_across_supported_backends(
migration_module,
container_engine: tuple[str, sa.Engine],
monkeypatch: pytest.MonkeyPatch,
) -> None:
backend_name, engine = container_engine
helper_module = importlib.import_module("tests.helpers.legacy_model_type_migration")
helper_module.drop_minimal_legacy_model_type_schema(engine)
fixture = seed_legacy_model_type_dirty_data(engine)
deleted_cache_keys: list[str] = []
def _record_delete(self) -> None:
deleted_cache_keys.append(self.cache_key)
monkeypatch.setattr(migration_module.ProviderCredentialsCache, "delete", _record_delete)
dry_run_output = io.StringIO()
migration_module.LegacyModelTypeMigrationService(
engine=engine,
apply=False,
output=dry_run_output,
tenant_ids=(fixture.primary.tenant_id,),
).migrate()
assert count_rows(engine, "provider_model_credentials", tenant_id=fixture.primary.tenant_id) == 3
assert deleted_cache_keys == []
apply_output = io.StringIO()
migration_module.LegacyModelTypeMigrationService(
engine=engine,
apply=True,
output=apply_output,
tenant_ids=(fixture.primary.tenant_id,),
).migrate()
first_apply_state = {
table_name: fetch_table_rows(engine, table_name, tenant_id=fixture.primary.tenant_id)
for table_name in (
"provider_models",
"tenant_default_models",
"provider_model_settings",
"load_balancing_model_configs",
"provider_model_credentials",
)
}
assert_tenant_rows_use_only_canonical_model_types(engine, fixture.primary.tenant_id)
assert count_rows(engine, "provider_model_credentials", tenant_id=fixture.primary.tenant_id) == 2
provider_model_row = next(
row for row in first_apply_state["provider_models"] if row["id"] == fixture.primary.provider_model_id
)
assert provider_model_row["credential_id"] == fixture.primary.winner_credential_id
credential_ids = {str(row["id"]) for row in first_apply_state["provider_model_credentials"]}
assert credential_ids == {
fixture.primary.winner_credential_id,
fixture.primary.distinct_credential_id,
}
lb_row = next(
row
for row in first_apply_state["load_balancing_model_configs"]
if row["id"] == fixture.primary.load_balancing_config_id
)
assert lb_row["credential_id"] == fixture.primary.winner_credential_id
assert lb_row["encrypted_config"] == fixture.primary.winner_encrypted_config
assert deleted_cache_keys, f"{backend_name} apply run should clear cache keys"
migration_module.LegacyModelTypeMigrationService(
engine=engine,
apply=True,
output=io.StringIO(),
tenant_ids=(fixture.primary.tenant_id,),
).migrate()
second_apply_state = {
table_name: fetch_table_rows(engine, table_name, tenant_id=fixture.primary.tenant_id)
for table_name in first_apply_state
}
assert second_apply_state == first_apply_state

File diff suppressed because it is too large Load Diff